Matlab单输入多输出之同时识别手写数字类别和倾斜角度
本文将介绍如何使用matlab构建单输入多输出网络架构,同时识别手写数字类别和倾斜角度。为了实现多输入和多输出,使用了skip connection 和 addition技巧。skip connection将网络在某一层后分为两个分支,单独设置分支来训练倾斜角度的输出。此外,要训练具有多个输出的深度学习网络,还需要自定义损失函数。
1.加载训练数据
加载数字数据:该数据包含数字图像、数字标签及其与垂直方向的旋转角度。
load DigitsDataTrain
为图像、标签和角度创建一个 arrayDatastore 对象,然后使用 combine 函数创建一个包含所有训练数据的单一数据存储。
dsXTrain = arrayDatastore(XTrain,IterationDimension=4); % IterationDimension 是设置循环读取的维度是哪个?
dsT1Train = arrayDatastore(labelsTrain);
dsT2Train = arrayDatastore(anglesTrain);
dsTrain = combine(dsXTrain,dsT1Train,dsT2Train);
% 读取数据
classNames = categories(labelsTrain); %
numClasses = numel(classNames); % 数字有10个
numObservations = numel(labelsTrain); % 一共有5000个手写数字图片
查看训练数据中的一些图像:
idx = randperm(numObservations,64); % 从所有图片里随机抽取64图片
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)
2.定义深度学习模型
网络架构:
-
A convolution-batchnorm-ReLU block with 16 5-by-5 filters.
-
Two convolution-batchnorm-ReLU blocks each with 32 3-by-3 filters.
-
A skip connection around the previous two blocks containing a convolution-batchnorm-ReLU block with 32 1-by-1 convolutions.
-
Merge the skip connection using addition.
-
For classification output, a branch with a fully connected operation of size 10 (the number of classes) and a softmax operation.
-
For the regression output, a branch with a fully connected operation of size 1 (the number of responses).
type | fiter num | fiter size | activation | 备注 | |
---|---|---|---|---|---|
1 | conv-batchnorm-ReLU | 16 | 5x5 | ReLU | 用于捕捉图像的基本特征,如边缘和纹理。 |
2 | conv-batchnorm-ReLU | 32 | 3x3 | ReLU | 预测数字标签,提取更复杂和抽象的特征 |
3 | conv-batchnorm-ReLU | 32 | 3x3 | ReLU | 预测数字标签,提取更复杂和抽象的特征 |
4 | conv-batchnorm-ReLU | 32 | 1x1 | ReLU | skip connection,用了1×1卷积 |
5 | addition | - | - | - | 合并数字标签和倾斜角度 |
6a | fully connected | - | - | Softmax | 分类输出,输出大小为10 |
6b | fully connected | - | - | - | 回归输出,输出大小为1 |
注意事项:
-
第2层和第3层是两个相同的卷积-批归一化-ReLU块。
-
第4层的跳跃连接跨越了第2层和第3层。
-
第6层分为两个分支:6a用于分类输出,6b用于回归输出。
这个架构结合了卷积神经网络的特征提取能力和跳跃连接的优势,同时支持多任务学习(分类和回归)。
net = dlnetwork;
layers = [
imageInputLayer([28 28 1],Normalization="none")
convolution2dLayer(5,16,Padding="same")
batchNormalizationLayer
reluLayer(Name="relu_1")
convolution2dLayer(3,32,Padding="same",Stride=2)
batchNormalizationLayer
reluLayer
convolution2dLayer(3,32,Padding="same")
batchNormalizationLayer
reluLayer
additionLayer(2,Name="add")
fullyConnectedLayer(numClasses)
softmaxLayer(Name="softmax")];
net = addLayers(net,layers);
layers = [
convolution2dLayer(1,32,Stride=2,Name="conv_skip")
batchNormalizationLayer
reluLayer(Name="relu_skip")];
net = addLayers(net,layers);
net = connectLayers(net,"relu_1","conv_skip");
net = connectLayers(net,"relu_skip","add/in2");
添加用于回归的全连接层:
layers = fullyConnectedLayer(1,Name="fc_2");
net = addLayers(net,layers);
net = connectLayers(net,"add","fc_2");
查看网络架构:
figure
plot(net)
analyzeNetwork(net)
3.训练神经网络
options = trainingOptions("adam", ...
Plots="training-progress", ...
Verbose=false);
使用 trainnet 函数训练神经网络。对于分类,使用自定义损失函数,即预测和目标标签的交叉熵损失加上预测和目标角度的均方误差损失的0.1倍。
将自定义损失函数定义为函数句柄。定义一个损失,该损失对应于预测和目标标签的交叉熵损失加上预测和目标角度的均方误差,按0.1倍缩放。
lossFcn = @(Y1,Y2,T1,T2) crossentropy(Y1,T1) + 0.1*mse(Y2,T2);
训练神经网络:
net = trainnet(dsTrain,net,lossFcn,options);
4.测试模型
加载用于测试的数字数据,数据包含数字图像、数字标签及其与垂直方向的旋转角度。
load DigitsDataTest
dsXTest = arrayDatastore(XTest,IterationDimension=4);
dsT1Test = arrayDatastore(labelsTest);
dsT2Test = arrayDatastore(anglesTest);
dsTest = combine(dsXTest,dsT1Test,dsT2Test);
使用 minibatchpredict 函数进行预测,我们的模型有两个输出,一个是预测的数字标签,另一个是数字的倾斜角度
[scores,Y2] = minibatchpredict(net,dsTest);
Y1 = scores2label(scores,classNames);
计算标签的预测accuracy:
accuracy = mean(Y1 == labelsTest) % accuracy = 0.9840
计算预测角度和目标角度之间的均方根误差:
err = rmse(Y2,anglesTest) % err = 8.4071
查看一些带有预测的图像:以红色显示预测的倾斜角度,以绿色显示实际倾斜角度。
idx = randperm(size(XTest,4),9);
figure
for i = 1:9
subplot(3,3,i)
I = XTest(:,:,:,idx(i));
imshow(I)
hold on
% 绘制倾斜角度
sz = size(I,1);
offset = sz/2;
% 预测的
degreePred = Y2(idx(i));
plot(offset*[1-tand(degreePred) 1+tand(degreePred)],[sz 0],"r--") % 之所以y是[sz 0],是因为imshow的y轴是越大在越下方)
% 实际的
degreeTest = anglesTest(idx(i));
plot(offset*[1-tand(degreeTest) 1+tand(degreeTest)],[sz 0],"g--")
hold off
% 显示预测label
labelPred = Y1(idx(i));
label = labelsTest(idx(i));
titleStr = sprintf('预测:%s(%.2f)\n实际:%s(%.2f)',string(labelPred),degreePred,string(label),degreeTest);
% 判断预测是否正确,设置标题颜色
if labelPred == label
titleColor = 'g'; % 绿色
else
titleColor = 'r'; % 红色
end
title(titleStr, 'Color', titleColor);
end