当前位置: 首页 > article >正文

Matlab多输入单输出之倾斜手写数字识别

本文主要介绍使用matlab构建多输入单输出的网络架构,来实现倾斜的手写数字识别,使用concatenationLayer来拼接特征,实现网络输入多个特征。

1.加载训练数据

加载数据:手写数字的图像、真实数字标签和数字顺时针旋转的角度。

load DigitsDataTrain

网络的输入数据类型需要是datastore,使用 arrayDatastore 将三个普通矩阵变为datastore,最后再使用combine合并。

dsX1Train = arrayDatastore(XTrain,IterationDimension=4);
dsX2Train = arrayDatastore(anglesTrain);
dsTTrain = arrayDatastore(labelsTrain);
dsTrain = combine(dsX1Train,dsX2Train,dsTTrain);

显示20个随机训练图像:

numObservationsTrain = numel(labelsTrain);
idx = randperm(numObservationsTrain,20);

figure
tiledlayout("flow");
for i = 1:numel(idx)
    nexttile
    imshow(XTrain(:,:,:,idx(i)))
    title("Angle: " + anglesTrain(idx(i)))
end

图片

2.设计网络架构

设计如下的网络结构:

图片

  • 对于图像输入,指定一个大小与输入数据匹配的图像输入层。

  • 对于特征输入,指定一个大小与输入特征数量匹配的特征输入层。

  • 对于图像输入分支,进行卷积、批归一化和ReLU层块,其中卷积层有16个5×5滤波器。

  • 为了将批归一化层的输出转换为特征向量,需要用一个大小为50的全连接层。

  • 要将第一个全连接层的输出与特征输入连接起来,使用flatten layer将全连接层中的 "SSCB" (空间、空间、通道、批处理)输出展平,使其具有 "CB" 格式。

  • 沿第一维度(channel维度)将平坦层的输出与特征输入连接起来

  • 对于分类输出,包括一个输出大小与类数匹配的全连接层,然后是softmax层。

创建一个空的神经网络:

net = dlnetwork;

创建一个网络主分支,并将其添加到网络中:

[h,w,numChannels,numObservations] = size(XTrain);
numFeatures = 1;
classNames = categories(labelsTrain);
numClasses = numel(classNames);

imageInputSize = [h w numChannels];
filterSize = 5;
numFilters = 16;

layers = [
    imageInputLayer(imageInputSize,Normalization="none")
    convolution2dLayer(filterSize,numFilters)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(50)
    flattenLayer
    concatenationLayer(1,2,Name="cat")
    fullyConnectedLayer(numClasses)
    softmaxLayer];

net = addLayers(net,layers);

将feature input layer添加到网络中,并将其连接到 concatenation layer的第二个输入:

featInput = featureInputLayer(numFeatures,Name="features");
net = addLayers(net,featInput);
net = connectLayers(net,"features","cat/in2");

在绘图中可视化网络:

figure
plot(net)

3.训练网络

使用SGDM优化器进行训练,训练15个epochs,以0.01的学习率进行训练,在图表中显示训练进度并监控accuracy指标,不显示详细输出。

options = trainingOptions("sgdm", ...
    MaxEpochs=15, ...
    InitialLearnRate=0.01, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=0);

使用 trainnet 函数训练神经网络,对于分类使用交叉熵损失。

net = trainnet(dsTrain,net,"crossentropy",options);

图片

4.测试网络

通过将测试集上的预测与真实标签进行比较来测试网络的分类准确性,加载测试数据:

load DigitsDataTest

使用 minibatchpredict 函数进行预测,并使用 scores2label 函数将分数转换为标签。

scores = minibatchpredict(net,XTest,anglesTest);
YTest = scores2label(scores,classNames);

在混淆图中可视化预测:

figure
confusionchart(labelsTest,YTest)

图片

评估分类准确性:

accuracy = mean(YTest == labelsTest)

accuracy = 0.9878

查看一些预测的图像:

idx = randperm(size(XTest,4),9);
figure
tiledlayout(3,3)
for i = 1:9
    nexttile
    I = XTest(:,:,:,idx(i));
    imshow(I)

    label = string(YTest(idx(i)));
    title("Predicted Label: " + label)
end

图片

5.不用角度特征训练和测试网络

% 网络设计
net_without_feature = dlnetwork;
layers = [
    imageInputLayer(imageInputSize,Normalization="none")
    convolution2dLayer(filterSize,numFilters)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

net_without_feature = addLayers(net_without_feature,layers);
% 网络训练
options = trainingOptions("sgdm", ...
    MaxEpochs=15, ...
    InitialLearnRate=0.01, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=0);

dsTrain_without_feature = combine(dsX1Train,dsTTrain);

net_without_feature = trainnet(dsTrain_without_feature,net_without_feature,"crossentropy",options);

图片

% 在混淆矩阵中可视化预测。
scores = minibatchpredict(net_without_feature,XTest);
YTest = scores2label(scores,classNames);
figure
confusionchart(labelsTest,YTest)

图片

% 评估分类准确性。
accuracy = mean(YTest == labelsTest)

accuracy = 0.9858


http://www.kler.cn/a/403030.html

相关文章:

  • c++中操作数据库的常用函数
  • Vue3-后台管理系统
  • 力扣 LeetCode 235. 二叉搜索树的最近公共祖先(Day10:二叉树)
  • 第 24 章 -Golang 性能优化
  • 二叉树路径相关算法题|带权路径长度WPL|最长路径长度|直径长度|到叶节点路径|深度|到某节点的路径非递归(C)
  • gitlab:使用脚本批量下载项目,实现全项目检索
  • os库的常见使用
  • 星融元与焱融科技AI分布式存储软硬件完成兼容性互认证
  • 13.C++内存管理2(C++ new和delete的使用和原理详解,内存泄漏问题)
  • 数据结构(双向链表——c语言实现)
  • Restful API 规范详解
  • 单片机学习笔记 2. LED灯闪烁
  • c++--------《set 和 map》
  • C++手写PCD文件
  • 使用Kotlin写一个将字符串加密成short数组,然后可以解密还原成原始的字符串的功能
  • 前端页面自适应等比例缩放 Flexible+rem方案
  • 小程序-基于java+SpringBoot+Vue的超市购物系统设计与实现
  • 【React 进阶】掌握 React18 全部 Hooks
  • 鸿蒙原生应用开发元服务 元服务是什么?和App的关系?(保姆级步骤)
  • 详解八大排序(一)------(插入排序,选择排序,冒泡排序,希尔排序)
  • Linux驱动开发第2步_“物理内存”和“虚拟内存”的映射
  • EDA实验设计-led灯管动态显示;VHDL;Quartus编程
  • Ubuntu24.04LTS设置root用户可远程登录
  • Flutter踩坑记录(一)debug运行生成的项目,不能手动点击运行
  • Qt5-雷达项目
  • C++零基础入门:趣味学信息学奥赛从“Hello World”开始