当前位置: 首页 > article >正文

Matlab单输入多输出之同时识别手写数字类别和倾斜角度

本文将介绍如何使用matlab构建单输入多输出网络架构,同时识别手写数字类别和倾斜角度。为了实现多输入和多输出,使用了skip connection 和 addition技巧。skip connection将网络在某一层后分为两个分支,单独设置分支来训练倾斜角度的输出。此外,要训练具有多个输出的深度学习网络,还需要自定义损失函数。

1.加载训练数据

加载数字数据:该数据包含数字图像、数字标签及其与垂直方向的旋转角度。

load DigitsDataTrain

为图像、标签和角度创建一个 arrayDatastore 对象,然后使用 combine 函数创建一个包含所有训练数据的单一数据存储。

dsXTrain = arrayDatastore(XTrain,IterationDimension=4); % IterationDimension 是设置循环读取的维度是哪个?
dsT1Train = arrayDatastore(labelsTrain);
dsT2Train = arrayDatastore(anglesTrain);

dsTrain = combine(dsXTrain,dsT1Train,dsT2Train);
% 读取数据
classNames = categories(labelsTrain); % 
numClasses = numel(classNames); % 数字有10个
numObservations = numel(labelsTrain); % 一共有5000个手写数字图片

查看训练数据中的一些图像:

idx = randperm(numObservations,64); % 从所有图片里随机抽取64图片
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

2.定义深度学习模型

网络架构:

  • A convolution-batchnorm-ReLU block with 16 5-by-5 filters.

  • Two convolution-batchnorm-ReLU blocks each with 32 3-by-3 filters.

  • A skip connection around the previous two blocks containing a convolution-batchnorm-ReLU block with 32 1-by-1 convolutions.

  • Merge the skip connection using addition.

  • For classification output, a branch with a fully connected operation of size 10 (the number of classes) and a softmax operation.

  • For the regression output, a branch with a fully connected operation of size 1 (the number of responses).

typefiter numfiter sizeactivation备注
1conv-batchnorm-ReLU165x5ReLU用于捕捉图像的基本特征,如边缘和纹理。
2conv-batchnorm-ReLU323x3ReLU预测数字标签,提取更复杂和抽象的特征
3conv-batchnorm-ReLU323x3ReLU预测数字标签,提取更复杂和抽象的特征
4conv-batchnorm-ReLU321x1ReLUskip connection,用了1×1卷积
5addition---合并数字标签和倾斜角度
6afully connected--Softmax分类输出,输出大小为10
6bfully connected---回归输出,输出大小为1

注意事项:

  1. 第2层和第3层是两个相同的卷积-批归一化-ReLU块。

  2. 第4层的跳跃连接跨越了第2层和第3层。

  3. 第6层分为两个分支:6a用于分类输出,6b用于回归输出。

这个架构结合了卷积神经网络的特征提取能力和跳跃连接的优势,同时支持多任务学习(分类和回归)。

net = dlnetwork;

layers = [
    imageInputLayer([28 28 1],Normalization="none")

    convolution2dLayer(5,16,Padding="same")
    batchNormalizationLayer
    reluLayer(Name="relu_1")

    convolution2dLayer(3,32,Padding="same",Stride=2)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer

    additionLayer(2,Name="add")

    fullyConnectedLayer(numClasses)
    softmaxLayer(Name="softmax")];
  
net = addLayers(net,layers);
layers = [
    convolution2dLayer(1,32,Stride=2,Name="conv_skip")
    batchNormalizationLayer
    reluLayer(Name="relu_skip")];

net = addLayers(net,layers);
net = connectLayers(net,"relu_1","conv_skip");
net = connectLayers(net,"relu_skip","add/in2");

添加用于回归的全连接层:

layers = fullyConnectedLayer(1,Name="fc_2");
net = addLayers(net,layers);
net = connectLayers(net,"add","fc_2");

查看网络架构:

figure
plot(net)
analyzeNetwork(net)

图片

3.训练神经网络

options = trainingOptions("adam", ...
    Plots="training-progress", ...
    Verbose=false);

使用 trainnet 函数训练神经网络。对于分类,使用自定义损失函数,即预测和目标标签的交叉熵损失加上预测和目标角度的均方误差损失的0.1倍。

将自定义损失函数定义为函数句柄。定义一个损失,该损失对应于预测和目标标签的交叉熵损失加上预测和目标角度的均方误差,按0.1倍缩放。

lossFcn = @(Y1,Y2,T1,T2) crossentropy(Y1,T1) + 0.1*mse(Y2,T2);

训练神经网络:

net = trainnet(dsTrain,net,lossFcn,options);

图片

4.测试模型

加载用于测试的数字数据,数据包含数字图像、数字标签及其与垂直方向的旋转角度。

load DigitsDataTest
dsXTest = arrayDatastore(XTest,IterationDimension=4);
dsT1Test = arrayDatastore(labelsTest);
dsT2Test = arrayDatastore(anglesTest);

dsTest = combine(dsXTest,dsT1Test,dsT2Test);

使用 minibatchpredict 函数进行预测,我们的模型有两个输出,一个是预测的数字标签,另一个是数字的倾斜角度

[scores,Y2] = minibatchpredict(net,dsTest);
Y1 = scores2label(scores,classNames);

计算标签的预测accuracy:

accuracy = mean(Y1 == labelsTest) % accuracy = 0.9840

计算预测角度和目标角度之间的均方根误差:

err = rmse(Y2,anglesTest) % err = 8.4071

查看一些带有预测的图像:以红色显示预测的倾斜角度,以绿色显示实际倾斜角度。

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    % 绘制倾斜角度
    sz = size(I,1);
    offset = sz/2;
    % 预测的
    degreePred = Y2(idx(i));
    plot(offset*[1-tand(degreePred) 1+tand(degreePred)],[sz 0],"r--") % 之所以y是[sz 0],是因为imshow的y轴是越大在越下方)
    % 实际的
    degreeTest = anglesTest(idx(i));
    plot(offset*[1-tand(degreeTest) 1+tand(degreeTest)],[sz 0],"g--")

    hold off
    % 显示预测label
    labelPred = Y1(idx(i));
    label = labelsTest(idx(i));
    titleStr = sprintf('预测:%s(%.2f)\n实际:%s(%.2f)',string(labelPred),degreePred,string(label),degreeTest);
    % 判断预测是否正确,设置标题颜色
    if labelPred == label
        titleColor = 'g';  % 绿色
    else
        titleColor = 'r';  % 红色
    end

    title(titleStr, 'Color', titleColor);

end


http://www.kler.cn/a/400465.html

相关文章:

  • 每日一练 | 包过滤防火墙的工作原理
  • python+Django+MySQL+echarts+bootstrap制作的教学质量评价系统,包括学生、老师、管理员三种角色
  • 【含开题报告+文档+PPT+源码】基于springboot的教师评价系统的设计与实现
  • 游戏引擎学习第12天
  • Sql进阶:字段中包含CSV,如何通过Sql解析CSV成多行多列?
  • 如何在 Ubuntu 上安装 Emby 媒体服务器
  • 用 Android Studio 从零开发一个多功能计算器应用
  • 集群聊天服务器(9)一对一聊天功能
  • 数据科学与SQL:如何计算排列熵?| 基于SQL实现
  • 10月回顾 | Apache SeaTunnel社区动态与进展一览
  • 【jvm】方法区的理解
  • 讨论大语言模型在学术文献应用中的未来与所带来的可能性和担忧
  • C++笔试面试题
  • leetcode 扫描线专题 06-leetcode.836 rectangle-overlap 力扣.836 矩形重叠
  • 无人机动力系统节能技术的未来发展趋势——CKESC电调小课堂12.1
  • Python 神经网络项目常用语法
  • C++---智能指针和内存泄露
  • 【网络安全 | 漏洞挖掘】邮件HTML注入
  • 群控系统服务端开发模式-应用开发-前端部门功能开发
  • 传奇996_25——ctrl+f11,ui标签,绘制自定义面板的参数
  • 学习大数据DAY61 宽表加工
  • 【惠州大亚湾】之维修戴尔服务器DELLR730XD
  • vue下载后端提供的文件/播放音频文件
  • Redis的缓存穿透、缓存雪崩、缓存击穿问题及有效解决方案
  • 初始Python篇(3)—— 列表
  • 大数据新视界 -- 大数据大厂之 Impala 性能优化:集群资源动态分配的智慧(上)(23 / 30)