当前位置: 首页 > article >正文

使⽤MATLAB进⾏⽬标检测

目录

  • 数据准备
  • 定义模型并训练
  • 用测试集评估性能
  • 推理过程
  • ⼀⾏代码查看⽹络结构
  • ⼀⾏代码转onnx
  • 结语

⼈⽣苦短,我⽤MATLAB。
Pytorch在深度学习领域占据了半壁江⼭,最主要的原因是⽣态完善,⽽且api直观易⽤。但谁能想到现在MATLAB⽤起来⽐Pytorch还好⽤。从数据集划分到训练,再到性能验证和画图,仅仅使⽤了⼏⼗⾏代码。炼丹师们终于可以解放编码时间,把⾃⼰的精⼒放在摸⻥(划掉)算法本身上了。
下⾯⽤⾃⼰的数据集训练⼀个YOLOv4,看看MATLAB到底怎么个事。

数据准备

我的数据内容是⾦属锻件磁粉探伤所显示的零件缺陷。如下例所示,轴孔上有⼀处明显缺陷:
在这里插入图片描述
数据集格式为Pascal VOC格式(详⻅http://host.robots.ox.ac.uk/pascal/VOC/)。⾸先要将格式的标签转换为MATLAB能够使⽤的格式。使⽤以下函数:

function convertXMLtoMAT(xmlFolder, outputMATFile)
	% xmlFolder: 存放所有XML⽂件的⽂件夹
	% outputMATFile: 输出的MAT⽂件名
	xmlFiles = dir(fullfile(xmlFolder, '*.xml')); imageFilenames = {};
	boundingBoxes = {};

	for i = 1:length(xmlFiles) 
		try
			xmlFilePath	fullfile(xmlFiles(i).folder, xmlFiles(i).name); xmlDoc = xmlread(xmlFilePath);
			filenameNode = xmlDoc.getElementsByTagName('filename').item(0);
			filename = char(filenameNode.getFirstChild.getData); 
			pathNode = xmlDoc.getElementsByTagName('path').item(0); 
			newPath = fullfile(pwd, 'img', filename);
			xminNode = xmlDoc.getElementsByTagName('xmin').item(0); 
			yminNode = xmlDoc.getElementsByTagName('ymin').item(0); 
			xmaxNode = xmlDoc.getElementsByTagName('xmax').item(0); 
			ymaxNode = xmlDoc.getElementsByTagName('ymax').item(0); 
			xmin = str2double(xminNode.getFirstChild.getData);
			ymin = str2double(yminNode.getFirstChild.getData); 
			xmax = str2double(xmaxNode.getFirstChild.getData); 
			ymax = str2double(ymaxNode.getFirstChild.getData);
			
			width = xmax - xmin; height = ymax - ymin;
			bbox = [xmin, ymin, width, height];
			
			imageFilenames{end+1, 1} = newPath; boundingBoxes{end+1, 1} = bbox;
		catch
			% pass
		end
	end
			
	% 创建table并保存为MAT⽂件
	myDataset = table(imageFilenames, boundingBoxes, ... 
		'VariableNames', {'imageFilename', 'defect'});
	save(outputMATFile, 'myDataset');
	% 输出MAT⽂件的前⼏⾏
	disp('数据集的前⼏⾏:');
	disp(myDataset(1:4, :));

end

转换后的标签⽂件就变成⼀个mat⽂件,加载后就是以下的样⼦:

data = load("my_dataset.mat"); 
defectDataset = data.myDataset;
% 显示数据
disp(defectDataset(1:5, :));
table sample_num * 2 :
			imageFilename					defect
		---------------------		---------------------
		{'img/00001.jpg'}			{[2635 435 133 682]}
		{'img/00001_z0.jpg'}		{[2574 1086 125 561]}
		{'img/00001_z1.jpg'}		{[2569 720 386 596]}
		{'img/00001_z2.jpg'}		{[2608 951 303 647]}
		{'img/00001_z3.jpg'}		{[1748 947 303 606]}

第⼆步,划分数据集。⼀般来讲,按照7:2:1的⽐例拆分训练集、测试集和验证集。

trainingRatio = 0.7;
validationRatio = 0.1;

rng("default");
shuffledIndices = randperm(height(defectDataset));
idx = floor(trainingRatio * length(shuffledIndices) );

trainingIdx = 1:idx;
trainingDataTbl = defectDataset(shuffledIndices(trainingIdx),:);

validationIdx = idx+1 : idx + 1 + floor(validationRatio * length(shuffledIndices) );
validationDataTbl = defectDataset(shuffledIndices(validationIdx),:);

testIdx = validationIdx(end)+1 : length(shuffledIndices); 
testDataTbl = defectDataset(shuffledIndices(testIdx),:);

第三步,创建 Datastore Datastore 是MATLAB特有的⼀个概念,作⽤类似于Pytorch中的DataLoader 。它集成了⼀些常⽤的操作,如计数,按序读取,打乱顺序,合并数据集等⽅法。MATLAB的Datastore ⽐torch的DataLoader 更⽅便,⼀般⽆需⾃⼰定义,拿来就⽤。

imdsTrain = imageDatastore(trainingDataTbl{:,"imageFilename"}); 
bldsTrain = boxLabelDatastore(trainingDataTbl(:,"defect"));

imdsValidation = imageDatastore(validationDataTbl{:,"imageFilename"}); 
bldsValidation = boxLabelDatastore(validationDataTbl(:,"defect"));

imdsTest = imageDatastore(testDataTbl{:,"imageFilename"}); 
bldsTest = boxLabelDatastore(testDataTbl(:,"defect"));
% 合并
trainingData = combine(imdsTrain,bldsTrain); 
validationData = combine(imdsValidation,bldsValidation); 
testData = combine(imdsTest,bldsTest);

最后,为了验证数据是否正确,可以画⼀个样本检查⼀下:

data = read(trainingData); 
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,"Rectangle",bbox,LineWidth=10); 
annotatedImage = imresize(annotatedImage,2);

figure 
imshow(annotatedImage); 
reset(trainingData);

在这里插入图片描述

定义模型并训练

这部分是本⽂的重点,但是却很短,因为MATLAB的流程太简洁了,狠狠爱了。

⾸先定义模型的主要参数:

optimizer = "adam"; % 优化器
gradientDecayFactor = 0.9; % 梯度衰减因⼦ 
squaredGradientDecayFactor = 0.999; % 平⽅梯度衰减因⼦ 
initialLearnRate = 0.001; % 初始学习率 
learnRateSchedule = 'none'; % 学习率衰减策略 
miniBatchSize = 4; % 批⼤⼩
L2Regularization = 0.0005; % L2正则化 
MaxEpochs = 100; % 迭代轮数
inputSize = [416 416 3]; 
className = "defect";

第⼆步,设置yolov4的Anchors

rng("default") 
trainingDataForEstimation =transform(trainingData,@(data)preprocessData(data,inputSize)); 
numAnchors = 6;
[anchors,meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors);

area = anchors(:, 1).*anchors(:,2);
[~,idx] = sort(area,"descend");

anchors = anchors(idx,:); 
anchorBoxes = {anchors(1:3,:)
anchors(4:6,:)};

最后开始训练模型就ok了:

detector = yolov4ObjectDetector("tiny-yolov4-coco", ... 
className,anchorBoxes,InputSize=inputSize);
%  如果你选择了从头开始训练,那么就得定义⼀⼤堆参数
options = trainingOptions(optimizer, ...
	GradientDecayFactor=gradientDecayFactor, ... 
	SquaredGradientDecayFactor=squaredGradientDecayFactor, ... 
	InitialLearnRate=initialLearnRate, ...
	LearnRateSchedule=learnRateSchedule, ... 
	MiniBatchSize=miniBatchSize, ...
	L2Regularization=L2Regularization, ... 
	MaxEpochs=MaxEpochs, ...
	DispatchInBackground=true, ... 
	ResetInputNormalization=true, ... 
	Shuffle="every-epoch", ...
	VerboseFrequency=20, ... 
	ValidationFrequency=1000, ... 
	CheckpointPath=tempdir, ...
	ValidationData=validationData, ... 
	OutputNetwork="best-validation-loss");

% 开始训练
[detector,info] = trainYOLOv4ObjectDetector(augmentedTrainingData,detector,options);
% 保存
save('yolov4_detector.mat',"detector"); % 模型
save('yolov4_detector_info.mat',"info"); % 训练过程

训练时会有如下输出:

Computing Input Normalization Statistics.
*************************************************************************  
Training a YOLO v4 Object Detector for the following object classes:

* defect

正在使⽤  'Processes' 配置⽂件启动并⾏池(parpool)...
已连接到具有 8 个⼯作进程的并⾏池。

Epoch	Iteration	TimeElapsed	LearnRate TrainingLoss	ValidationLoss
								
1	1		00:00:10		0.001		1174.1	1124.1
1	20		00:00:36		0.001		80.843	
1	40		00:00:52		0.001		30.598	
1	60		00:01:08		0.001		20.106	
1	80		00:01:21		0.001		86.183	
1	100		00:01:34		0.001		14.755	
2	120		00:01:53		0.001		13.844	
2	140		00:02:06		0.001		13.443	
2	160		00:02:19		0.001		12.067	
2	180		00:02:32		0.001		10.023	
2	200		00:02:45		0.001		11.157	
2	220		00:02:59		0.001		11.613	

*************************************************************************  
Detector training complete.
*************************************************************************

训练完成后,模型和相关的记录会分别保存在yolov4_detector.matyolov4_detector_info.mat中。

用测试集评估性能

⽤来评估性能的

detectionResults = detect(detector,testData,Threshold=0.01); 
metrics = evaluateObjectDetection(detectionResults,testData); 
AP = averagePrecision(metrics);
[precision,recall] = precisionRecall(metrics,ClassName=className);

% 画pr图
figure 1 
plot(recall{:},precision{:}) 
xlabel("Recall") 
ylabel("Precision")
grid on
title(sprintf("Average Precision = %.2f",AP)) 
imshow(I);

% 再画个loss曲线
figure 2 
plot(info.TrainingLoss)
xlabel("迭代次数")
ylabel("loss") title('Loss曲线')

在这里插入图片描述
在这里插入图片描述

你可以说我练的不怎么样,但不能说MATLAB不⾏,因为在torch上也是⼀样的效果(哭…

推理过程

I = imread("img/00029.jpg"); 
[bboxes,scores,labels] = detect(detector,I);
I = insertObjectAnnotation(I,"rectangle",bboxes,scores, ...
	LineWidth=10,FontSize=72); figure
imshow(I)

在这里插入图片描述

⼀⾏代码查看⽹络结构

MATLAB⾥⾯⼀⾏代码可以做到图形化的查看⽹络结构,更专注于分析算法。

net = detector.Network 
analyzeNetwork(net)

在这里插入图片描述

⼀⾏代码转onnx

exportONNXNetwork(net,'yolov4.onnx')

导出⽹络后,可以将该⽹络导⼊到其他深度学习框架中进⾏推理。妈妈再也不⽤担⼼我的部署了。

结语

MATLAB完全解放了炼丹师的编码时间,让炼丹师能够专注于算法本身,我愿称之为最好的深度学习框架(如果你有licence的话)


http://www.kler.cn/a/400649.html

相关文章:

  • AWD脚本编写_1
  • ROS Action
  • 在MATLAB中导入TXT文件的若干方法
  • 大数据算法考试习题
  • STM32串口——5个串口的使用方法
  • Linux:调试器-gdb/cgdb
  • 数字化转型的三个阶段:信息化、数字化、数智化
  • 软考-信息安全-网络安全体系与网络安全模型
  • 高级java面试---spring.factories文件的解析源码API机制
  • Vue基础(2)_el和data的两种写法
  • uni-ui自动化导入
  • element ui 走马灯一页展示多个数据实现
  • MATLAB绘制克莱因瓶
  • QT6学习第三天
  • 驰骋资讯高速:Spring Boot汽车新闻网站
  • Idea中创建和联系MySQL等数据库
  • C#中的方法
  • 【数据中台资料大合集】大数据平台、数据湖、指标池建设,数据中台底层数据采集管理,ETL数据清洗(Word原件,PPT原件)
  • lua脚本语言基本原理
  • mysql的mvcc机制中,read view是什么时候生成的?
  • 游戏引擎学习第13天
  • 使用 JavaScript 制作 To-Do List
  • 06 - Clickhouse的表引擎
  • 【3D Slicer】的小白入门使用指南十
  • React(一)
  • 【Golang】——Gin 框架中的路由与请求处理