Matlab多输入单输出之倾斜手写数字识别
本文主要介绍使用matlab构建多输入单输出的网络架构,来实现倾斜的手写数字识别,使用concatenationLayer来拼接特征,实现网络输入多个特征。
1.加载训练数据
加载数据:手写数字的图像、真实数字标签和数字顺时针旋转的角度。
load DigitsDataTrain
网络的输入数据类型需要是datastore,使用 arrayDatastore 将三个普通矩阵变为datastore,最后再使用combine合并。
dsX1Train = arrayDatastore(XTrain,IterationDimension=4);
dsX2Train = arrayDatastore(anglesTrain);
dsTTrain = arrayDatastore(labelsTrain);
dsTrain = combine(dsX1Train,dsX2Train,dsTTrain);
显示20个随机训练图像:
numObservationsTrain = numel(labelsTrain);
idx = randperm(numObservationsTrain,20);
figure
tiledlayout("flow");
for i = 1:numel(idx)
nexttile
imshow(XTrain(:,:,:,idx(i)))
title("Angle: " + anglesTrain(idx(i)))
end
2.设计网络架构
设计如下的网络结构:
-
对于图像输入,指定一个大小与输入数据匹配的图像输入层。
-
对于特征输入,指定一个大小与输入特征数量匹配的特征输入层。
-
对于图像输入分支,进行卷积、批归一化和ReLU层块,其中卷积层有16个5×5滤波器。
-
为了将批归一化层的输出转换为特征向量,需要用一个大小为50的全连接层。
-
要将第一个全连接层的输出与特征输入连接起来,使用flatten layer将全连接层中的 "SSCB" (空间、空间、通道、批处理)输出展平,使其具有 "CB" 格式。
-
沿第一维度(channel维度)将平坦层的输出与特征输入连接起来
-
对于分类输出,包括一个输出大小与类数匹配的全连接层,然后是softmax层。
创建一个空的神经网络:
net = dlnetwork;
创建一个网络主分支,并将其添加到网络中:
[h,w,numChannels,numObservations] = size(XTrain);
numFeatures = 1;
classNames = categories(labelsTrain);
numClasses = numel(classNames);
imageInputSize = [h w numChannels];
filterSize = 5;
numFilters = 16;
layers = [
imageInputLayer(imageInputSize,Normalization="none")
convolution2dLayer(filterSize,numFilters)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(50)
flattenLayer
concatenationLayer(1,2,Name="cat")
fullyConnectedLayer(numClasses)
softmaxLayer];
net = addLayers(net,layers);
将feature input layer添加到网络中,并将其连接到 concatenation layer的第二个输入:
featInput = featureInputLayer(numFeatures,Name="features");
net = addLayers(net,featInput);
net = connectLayers(net,"features","cat/in2");
在绘图中可视化网络:
figure
plot(net)
3.训练网络
使用SGDM优化器进行训练,训练15个epochs,以0.01的学习率进行训练,在图表中显示训练进度并监控accuracy指标,不显示详细输出。
options = trainingOptions("sgdm", ...
MaxEpochs=15, ...
InitialLearnRate=0.01, ...
Plots="training-progress", ...
Metrics="accuracy", ...
Verbose=0);
使用 trainnet 函数训练神经网络,对于分类使用交叉熵损失。
net = trainnet(dsTrain,net,"crossentropy",options);
4.测试网络
通过将测试集上的预测与真实标签进行比较来测试网络的分类准确性,加载测试数据:
load DigitsDataTest
使用 minibatchpredict 函数进行预测,并使用 scores2label 函数将分数转换为标签。
scores = minibatchpredict(net,XTest,anglesTest);
YTest = scores2label(scores,classNames);
在混淆图中可视化预测:
figure
confusionchart(labelsTest,YTest)
评估分类准确性:
accuracy = mean(YTest == labelsTest)
accuracy = 0.9878
查看一些预测的图像:
idx = randperm(size(XTest,4),9);
figure
tiledlayout(3,3)
for i = 1:9
nexttile
I = XTest(:,:,:,idx(i));
imshow(I)
label = string(YTest(idx(i)));
title("Predicted Label: " + label)
end
5.不用角度特征训练和测试网络
% 网络设计
net_without_feature = dlnetwork;
layers = [
imageInputLayer(imageInputSize,Normalization="none")
convolution2dLayer(filterSize,numFilters)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer];
net_without_feature = addLayers(net_without_feature,layers);
% 网络训练
options = trainingOptions("sgdm", ...
MaxEpochs=15, ...
InitialLearnRate=0.01, ...
Plots="training-progress", ...
Metrics="accuracy", ...
Verbose=0);
dsTrain_without_feature = combine(dsX1Train,dsTTrain);
net_without_feature = trainnet(dsTrain_without_feature,net_without_feature,"crossentropy",options);
% 在混淆矩阵中可视化预测。
scores = minibatchpredict(net_without_feature,XTest);
YTest = scores2label(scores,classNames);
figure
confusionchart(labelsTest,YTest)
% 评估分类准确性。
accuracy = mean(YTest == labelsTest)
accuracy = 0.9858