Deeplearning4j 深度学习之 Yolo Tiny 实现目标检测

Yolo Tiny 是 Yolo2 的简化版，虽然有些过时，但在很多目标检测应用场景中仍然很实用。本示例利用 Deeplearning4j 构建 Yolo 算法实现目标检测，下图是本示例的网络结构:

 


// Network input size expected by the TinyYOLO architecture. Note: the model below is
// built with TinyYOLO.builder() and trained from scratch, not loaded from pretrained weights.
int width = 416;
int height = 416;
int nChannels = 3;
// Output grid size: 416 / 13 = 32 pixels per grid cell.
int gridWidth = 13;
int gridHeight = 13;
// Number of object classes (must match the number of lines in classes.txt).
int nClasses = 5;
// Prior (anchor) boxes for the Yolo2OutputLayer as {width, height} in grid-cell units
// (see the K-means snippet, which prints cluster centres divided by 32).
double[][] priorBoxes = { { 1.5, 2.2 }, { 1.4, 1.95 }, { 1.8, 3.3 }, { 2.4, 2.9 }, { 1.7, 2.2 } };
// Minimum confidence for a detection to be reported.
double detectionThreshold = 0.8;
// parameters for the training phase
int batchSize = 10;
int nEpochs = 20;
int seed = 123;
Random rng = new Random(seed);
// Backslashes must be escaped in Java string literals: "D:\train" would contain a TAB
// character ("\t"), and "D:\model.zip" would not compile ("\m" is an illegal escape).
File imageDir = new File("D:\\train");
// 90% of the images for training, 10% for testing.
InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(null, 0.9, 0.1);
InputSplit trainData = data[0];
InputSplit testData = data[1];
ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth,
		new YoloLabelProvider(imageDir.getAbsolutePath()));
recordReaderTrain.initialize(trainData);

ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth,
		new YoloLabelProvider(imageDir.getAbsolutePath()));
recordReaderTest.initialize(testData);
// ObjectDetectionRecordReader performs regression, so we need to specify it here
RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true);
// Scale pixel values from [0, 255] to [0, 1].
train.setPreProcessor(new ImagePreProcessingScaler(0, 1));

// Batch size 1 so each test example can be visualized on its own.
RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
test.setPreProcessor(new ImagePreProcessingScaler(0, 1));

ComputationGraph model;
String modelFilename = "D:\\model.zip";

// Reuse a previously trained model if one exists; otherwise build, train and save it.
if (new File(modelFilename).exists()) {
	this.output("Load model...");
	model = ComputationGraph.load(new File(modelFilename), true);
} else {
	this.output("Build model...");
	model = TinyYOLO.builder().numClasses(nClasses).priorBoxes(priorBoxes).build().init();
	System.out.println(model.summary(InputType.convolutional(height, width, nChannels)));
	this.output("Train model...");
	model.setListeners(new ScoreIterationListener(1));
	model.fit(train, nEpochs);
	ModelSerializer.writeModel(model, modelFilename, true);
}
// visualize results on the test set
NativeImageLoader imageLoader = new NativeImageLoader();
CanvasFrame frame = new CanvasFrame("WatermelonDetection");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model
		.getOutputLayer(0);
List<String> labels = train.getLabels();
// Metadata is needed to recover each test image's file name and original size.
test.setCollectMetaData(true);
Scalar[] colormap = { RED, BLUE, GREEN, CYAN, YELLOW, MAGENTA, ORANGE, PINK, LIGHTBLUE, VIOLET };
while (test.hasNext() && frame.isVisible()) {
	org.nd4j.linalg.dataset.DataSet ds = test.next();
	RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
	INDArray features = ds.getFeatures();
	INDArray results = model.outputSingle(features);
	List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
	String imageName = new File(metadata.getURI()).getName();

	// Features were scaled to [0, 1]; scale back to 8-bit and resize to the original image size.
	Mat mat = imageLoader.asMat(features);
	Mat convertedMat = new Mat();
	mat.convertTo(convertedMat, CV_8U, 255, 0);
	int w = metadata.getOrigW();
	int h = metadata.getOrigH();
	Mat image = new Mat();
	resize(convertedMat, image, new Size(w, h));
	for (DetectedObject obj : objs) {
		// Box corners are in grid-cell units; convert them to pixels of the original image.
		double[] xy1 = obj.getTopLeftXY();
		double[] xy2 = obj.getBottomRightXY();
		int classIdx = obj.getPredictedClass();
		String label = labels.get(classIdx);
		// Wrap around so more classes than colours does not throw.
		Scalar color = colormap[classIdx % colormap.length];
		int x1 = (int) Math.round(w * xy1[0] / gridWidth);
		int y1 = (int) Math.round(h * xy1[1] / gridHeight);
		int x2 = (int) Math.round(w * xy2[0] / gridWidth);
		int y2 = (int) Math.round(h * xy2[1] / gridHeight);
		rectangle(image, new Point(x1, y1), new Point(x2, y2), color);
		putText(image, label, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, color);
	}
	frame.setTitle(imageName + " - WatermelonDetection");
	frame.setCanvasSize(w, h);
	frame.showImage(converter.convert(image));
	// Block until a key is pressed before showing the next test image.
	frame.waitKey();
}
frame.dispose();
	

参数讲解
图片的宽高：int width = 416; int height = 416; 是 Tiny YOLO 网络固定的输入尺寸。

图片的通道数：彩色图为 int nChannels = 3，灰度图则为 nChannels = 1，默认为 3。

输出网格的宽高：int gridWidth = 13; int gridHeight = 13;。Tiny YOLO 把 416x416 的输入下采样 32 倍，得到 13x13 的输出网格（每个网格单元对应 32x32 像素），因此在输入尺寸不变的情况下这两个值不能改变。

待检测的类别个数，本示例为 5 个：int nClasses = 5;

先验框（anchor）的宽高：double[][] priorBoxes = { { 1.5, 2.2 }, { 1.4, 1.95 }, { 1.8, 3.3 }, { 2.4, 2.9 }, { 1.7, 2.2 } }; 每个先验框为 {宽, 高}，单位是网格单元（1 个单元 = 32 像素）。Yolo2 中的先验框是对训练集的标注框做 K-means 聚类得到的，代码如下:


		// Estimate YOLO prior boxes by clustering the labelled box sizes (in pixels) with an IOU-based distance.
		YoloLabelProvider svhnLabelProvider = new YoloLabelProvider(trainDir.getAbsolutePath());
		DistanceMeasure distanceMeasure = new YoloIOUDistanceMeasure();
		// 5 clusters (one per prior box), at most 15 iterations.
		KMeansPlusPlusClusterer<ImageObjectWrapper> clusterer = new KMeansPlusPlusClusterer<>(5, 15, distanceMeasure);
		// Pixels per output-grid cell (416 / 13).
		final double gridCellPixels = 32d;
		File[] pngFiles = trainDir.listFiles((dir, name) -> name.endsWith(".png"));
		// listFiles returns null (not an empty array) when the directory cannot be read.
		if (pngFiles == null) {
			throw new IllegalStateException("Cannot list image files in " + trainDir.getAbsolutePath());
		}
		List<ImageObjectWrapper> clusterInput = Stream.of(pngFiles)
				.flatMap(png -> svhnLabelProvider.getImageObjectsForPath(png.getName()).stream())
				.map(imageObject -> new ImageObjectWrapper(imageObject))
				.filter(wrapper -> {
					double[] point = wrapper.getPoint();
					// Skip boxes no larger than one grid cell in both dimensions.
					return !(point[0] <= gridCellPixels && point[1] <= gridCellPixels);
				})
				.collect(Collectors.toList());
		List<CentroidCluster<ImageObjectWrapper>> clusterResults = clusterer.cluster(clusterInput);
		for (CentroidCluster<ImageObjectWrapper> centroidCluster : clusterResults) {
			double[] point = centroidCluster.getCenter().getPoint();
			List<ImageObjectWrapper> members = centroidCluster.getPoints();
			System.out.println(
					"width:" + point[0] + "  height:" + point[1] + " ratio:" + point[1] / point[0] + " size:" + members.size());
			// Centroid in grid-cell units, i.e. a candidate priorBoxes entry.
			System.out.println("bbox amount:" + point[0] / gridCellPixels + "," + point[1] / gridCellPixels);
			if (!members.isEmpty()) {
				ImageObjectWrapper maxWidthImage = members.stream()
						.max(Comparator.comparingDouble(ImageObjectWrapper::getWidth)).get();
				ImageObjectWrapper maxHeightImage = members.stream()
						.max(Comparator.comparingDouble(ImageObjectWrapper::getHeight)).get();
				System.out.println(" width:" + maxWidthImage.getWidth() + " height:" + maxHeightImage.getHeight());
			}
			System.out.println("-----------");
		}
	

上述代码主要通过 K-means 聚类获取训练样本中有代表性的先验框宽高。需要把 K-means 的距离度量方法改写为基于 IOU 的形式（即 d = 1 − IOU），具体可参照《YOLO v2目标检测详解二 计算iou》- 灰信网（软件开发博客聚合）。

detectionThreshold 是目标检测的置信度阈值：值越高，检测出来的物体个数越少，误检越少、准确率越高，但漏检也可能随之增多。

我的训练集是用 LabelImg 标注的，格式为 YOLO，训练样本如下。注意图片的大小要与参数 416x416 一致。

 

标签类别文件为 classes.txt，包括五个类别：xi、cake、dan、ss、bi。

标签解析类 YoloLabelProvider 的代码如下，其主要作用是把 LabelImg 生成的 txt 标注数据转换成算法可以识别的 ImageObject 列表:

public class YoloLabelProvider implements ImageObjectLabelProvider {
	private String baseDirectory;
	private List<String> labels;

	public YoloLabelProvider(String baseDirectory) {
		this.baseDirectory = baseDirectory;
		Assert.notNull(baseDirectory, "标签目录不能为空");
		if (!new File(baseDirectory).exists()) {
			throw new IllegalStateException(
					"baseDirectory directory does not exist. txt files should be " + "present at  Expected location: " + baseDirectory);
		}
		String classTxtPath = FilenameUtils.concat(this.baseDirectory, "classes.txt");
		File classFile = new File(classTxtPath);
		Assert.isTrue(classFile.exists(), "classTxtPath does not exist");
		try {
			labels = Files.readAllLines(classFile.toPath());
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public List<ImageObject> getImageObjectsForPath(String path) {
		int idx = path.lastIndexOf('/');
		idx = Math.max(idx, path.lastIndexOf('\'));
		String filename = path.substring(idx + 1, path.length() - 4); //-4: ".png"
		String txtPath = FilenameUtils.concat(this.baseDirectory, filename + ".txt");
		String pngPath = FilenameUtils.concat(this.baseDirectory, filename + ".png");
		File txtFile = new File(txtPath);
		if (!txtFile.exists()) {
			throw new IllegalStateException("Could not find TXT file for image " + path + "; expected at " + txtPath);
		}
		List<String> readAllLines = null;
		BufferedImage image = null;
		try {
			image = ImageIO.read(Paths.get(pngPath).toFile());
			readAllLines = Files.readAllLines(txtFile.toPath());
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
		int width = image.getWidth();
		int height = image.getHeight();
		List<ImageObject> imageObjects = readAllLines.stream().map(line -> {
			String[] data = StringUtils.split(line, " ");
			int centerX = Math.round(Float.valueOf(data[1]) * width);
			int centerY = Math.round(Float.valueOf(data[2]) * height);
			int bboxWidth = Math.round(Float.valueOf(data[3]) * width);
			int bboxHeight = Math.round(Float.valueOf(data[4]) * height);
			int xmin = centerX - (bboxWidth / 2);
			int ymin = centerY - (bboxHeight / 2);
			int xmax = centerX + (bboxWidth / 2);
			int ymax = centerY + (bboxHeight / 2);
			ImageObject imageObject = new ImageObject(xmin, ymin, xmax, ymax, this.labels.get(Integer.valueOf(data[0])));
			return imageObject;
		}).collect(Collectors.toList());
		return imageObjects;
	}

	@Override
	public List<ImageObject> getImageObjectsForPath(URI uri) {
		return getImageObjectsForPath(uri.toString());
	}

}

训练 300 多张图片大约用了 4 个小时，结果如下:

 

 

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇

)">
下一篇>>