当前位置: 首页 > article >正文

分类预测 | MATLAB实现CNN-BiLSTM-Attention多输入分类预测

分类预测 | MATLAB实现CNN-BiLSTM-Attention多输入分类预测

目录

    • 分类预测 | MATLAB实现CNN-BiLSTM-Attention多输入分类预测
      • 分类效果
      • 基本介绍
      • 模型描述
      • 程序设计
      • 参考资料

分类效果

1
2
3
4
5
6

基本介绍

MATLAB实现CNN-BiLSTM-Attention多输入分类预测,CNN-BiLSTM结合注意力机制多输入分类预测。

模型描述

Matlab实现CNN-BiLSTM-Attention多变量分类预测
1.data为数据集,格式为excel,12个输入特征,输出四个类别;
2.MainCNN_BiLSTM_AttentionNC.m为主程序文件,运行即可;

注意程序和数据放在一个文件夹,运行环境为Matlab200b及以上。
4.注意力机制模块:
SEBlock(Squeeze-and-Excitation Block)是一种聚焦于通道维度而提出一种新的结构单元,为模型添加了通道注意力机制,该机制通过添加各个特征通道的重要程度的权重,针对不同的任务增强或者抑制对应的通道,以此来提取有用的特征。该模块的内部操作流程如图,总体分为三步:首先是Squeeze 压缩操作,对空间维度的特征进行压缩,保持特征通道数量不变。融合全局信息即全局池化,并将每个二维特征通道转换为实数。实数计算公式如公式所示。该实数由k个通道得到的特征之和除以空间维度的值而得,空间维数为H*W。其次是Excitation激励操作,它由两层全连接层和Sigmoid函数组成。如公式所示,s为激励操作的输出,σ为激活函数sigmoid,W2和W1分别是两个完全连接层的相应参数,δ是激活函数ReLU,对特征先降维再升维。最后是Reweight操作,对之前的输入特征进行逐通道加权,完成原始特征在各通道上的重新分配。

1
2

程序设计

  • 完整程序和数据获取方式1:同等价值程序兑换;
  • 完整程序和数据获取方式2:私信博主获取。
%% CNN模型建立

layers = [
    imageInputLayer([size(input,1) 1 1])     %输入层参数设置
    convolution2dLayer(3,16,'Padding','same')%卷积层的核大小、数量,填充方式
    reluLayer                                %relu激活函数
    fullyConnectedLayer(384) % 384 全连接层神经元
    fullyConnectedLayer(384) % 384 全连接层神经元
    fullyConnectedLayer(1)   % 输出层神经元
    regressionLayer];        % 添加回归层,用于计算损失值
%% 模型训练与测试

options = trainingOptions('adam', ...
    'MaxEpochs',20, ...
    'MiniBatchSize',16, ...
    'InitialLearnRate',0.005, ...
    'GradientThreshold',1, ...
    'Verbose',false,...
    'Plots','training-progress',...
    'ValidationData',{testD,targetD_test'});
% 训练
net = trainNetwork(trainD,targetD',layers,options);
————————————————
版权声明:本文为CSDN博主「机器学习之心」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
tempLayers = multiplicationLayer(2, "Name", "multiplication");         % 点乘的注意力
lgraph = addLayers(lgraph, tempLayers);                                % 将上述网络结构加入空白结构中

tempLayers = [
    sequenceUnfoldingLayer("Name", "sequnfold")                        % 建立序列反折叠层
    flattenLayer("Name", "flatten")                                    % 网络铺平层
    bilstmLayer(6, "Name", "bilstm", "OutputMode", "last")             % BiLSTM层
    fullyConnectedLayer(num_class)                                     % 全连接层
    softmaxLayer                                                       % 损失函数层
    classificationLayer];                                              % 分类层
lgraph = addLayers(lgraph, tempLayers);                                % 将上述网络结构加入空白结构中

lgraph = connectLayers(lgraph, "seqfold/out", "conv_1");               % 折叠层输出 连接 卷积层输入;
lgraph = connectLayers(lgraph, "seqfold/miniBatchSize", "sequnfold/miniBatchSize"); 

参考资料

[1] https://blog.csdn.net/kjm13182345320/article/details/129943065?spm=1001.2014.3001.5501
[2] https://blog.csdn.net/kjm13182345320/article/details/129919734?spm=1001.2014.3001.5501


http://www.kler.cn/a/6540.html

相关文章:

  • nodejs利用子进程child_process执行命令及child.stdout输出数据
  • 代码加入SFTP JAVA ---(小白篇3)
  • MySQL数据库下载及安装教程
  • kafka常用命令(持续更新)
  • 智慧商城:购物车模块基本静态结构 + 构建vuex cart模块,获取数据存储(异步actions)
  • 如何缩放组件
  • 回归预测 | MATLAB实现GA-BiLSTM遗传算法优化双向长短期记忆网络的数据多输入单输出回归预测
  • 技术动态 | 基于GPT-4的知识图谱构建能力评测
  • 【C++】开散列哈希表封装实现unordered_map和unordered_set
  • HTML - Javascript - JS可变参数函数
  • Stable Diffusion 安装教程
  • opencv_c++学习(二)
  • 使用JSR303对数据进行校验【JAVA】
  • Linux reset子系统和驱动实例
  • GEE:栅格转矢量
  • Jackson之ObjectMapper常用用法
  • 【异常解决】java: 无法访问org.springframework.boot.SpringApplication的解决方案
  • 中级软考有没有必要考?
  • php+mysql仓储进销存仓库管理系统
  • 【C++】多态(二)
  • 从零开始学习InfluxDB:安装和使用入门教程
  • C++ using:软件设计中的面向对象编程技巧
  • 分库分表笔记
  • MVVM理解、object.defineProperty、数据代理
  • Android动态换肤框架实现小结
  • 射频接收机概述