当前位置: 首页 > article >正文

融合注意力机制的卷积神经网络-双向长短期记忆网络(CNN-BiLSTM-Attention)的多变量/时间序列预测/matlab代码

CNN-BiLSTM-Attention模型是一种在自然语言处理(NLP)任务中常用的强大架构,如文本分类、情感分析等。它结合了卷积神经网络(CNN)、双向长短期记忆网络(BiLSTM)和注意力机制的优势,能够捕捉局部特征和序列数据中的长程依赖关系。点击下方链接可以查看或者获取原文章对应的代码
[(1)融合注意力机制的卷积神经网络-双向长短期记忆网络(CNN-BiLSTM-Attention)的多变量/时间序列预测/matlab代码]
在这里插入图片描述

NO.1|CNN简介

CNN 主要由卷积层和池化层构成,其中卷积层利用卷积核进行电力负荷数据的有效非线性局部特征提取,池化层用于压缩提取的特征并生成更重要的特征信息,提高泛化能力。 卷积神经网络作为一种深度学习模型,广泛应用于图像识别、目标检测、图像分割和自然语言处理等领域。CNN的设计灵感来源于生物视觉系统,通过模拟人类视觉处理的方式来实现对图像等数据的高效识别和处理。CNN的核心是卷积层,它通过卷积操作自动提取输入数据的空间层次特征,并通过池化层降低特征的空间尺寸,同时保留最重要的特征信息。在经过多层卷积和池化层之后,通常会有全连接层来产生最终的输出结果,如图像分类的类别标签 。 CNN的工作原理主要包括以下几个关键步骤:

局部感受野: 每个神经元只响应图像的一个小区域,这与人类视觉系统的工作原理相似 。

	卷积操作:	通过滑动窗口(卷积核)在输入图像上进行逐元素相乘然后求和的操作,以此提取图像的局部特征 。

	激活函数:	如ReLU,用于引入非线性,帮助网络学习更复杂的特征 。

	池化层:	通过最大池化或平均池化减少特征图的空间尺寸,降低计算量并提取重要特征 。

	全连接层:	在网络的末端,将提取的特征映射转化为最终输出,如分类结果 。

CNN的训练通常采用监督学习的方法,通过调整网络中的权重和偏置来最小化预测输出和真实标签之间的差异。在训练过程中,网络学习从简单到复杂的分层特征,低层可能学习到边缘等基础特征,而高层可能学习到更复杂的形状或纹理特征

NO.2|双向LSTM(BiLSTM)

LSTM是一种特殊的递归神经网络(RNN),擅长捕捉时间序列中的长期依赖关系, 通过引入遗忘门、输入门和输出门三种门的逻辑控制单元保持和更新细胞状态,加强了长期记忆能力,可以很好地解决RNN梯度消失与梯度爆炸的问题。LSTM 通过学习时序数据的长 时间相关性,使网络可以更好、 更快地收敛,有助于提高预测精度。

	(1)遗忘门的作用是从时刻t-1	的细胞状态  中选择需要丢弃的信息,具体如公所示。	遗忘门通过读取前一时刻的隐藏层状态  和当前时刻的输入序列  ,产生一个介于0到1之	间的输出值,其中1代表完全保留信息,0代表完全舍弃信息。 

 	在公式中,  表示当前时刻t的遗忘门输出;	Wf和bf分别表示遗忘门的权重矩阵和偏置项;	σ	代表 sigmoid 激活函数。

	(2) 输入门的主要功能是根据当前时刻t的输入  决定需要保存在神经元中的信息。首先,通过 tanh 层生成时刻t的候选记忆单元状态  ,然后结合遗忘门的输出对细胞状态进行更新,得到新的细胞状态  。输入门的具体运算过程如下所示。     

其中,  是输入门在时刻t的输出状态,决定了输入  对更新细胞状态  的影响程度;	Wi和bi分别是输入门的权重和偏置项;	Wc和bc是生成候选记忆单元  的权重矩阵和偏置;	tanh为双曲正切激活函数;  	表示元素逐位相乘。	

	(3) 输出门从当前细胞状态中提取并输出关键信息。首先通过	σ层决定哪些神经元状态需要被输出,然后将这些神经元状态经过 tanh 层处理,再与	σ输出相乘,生成最终的输出值  ,该值也作为下一个时间步的隐藏层输入。输出门的计算过程如公式所示。 

	其中, 是

输出门在时刻t的输出状态; Wo和 bo分别为输出门的权重矩阵和偏置项。

	由此可见,LSTM网络在训练过程中通常采用单向的数据流,即信息只从序列的开始流向结束。这种单向训练方式可能会导致数据的时间序列特征没有得到充分利用,因为它忽略了从序列末尾到开始的潜在联系。	为了克服这一限制,BiLSTM(双向长短期记忆)网络被提出。	BiLSTM 融合了两个LSTM层:	一个处理正向序列(从前到后),另一个处理反向序列(从后到前)。	这种结构的优势在于:

	(1)正向LSTM	:捕捉时间序列中从过去到当前的信息流动,提取序列的前向依赖特征。

	(2)反向LSTM	:捕捉从未来到当前的逆向信息流动,提取序列的后向依赖特征。


	前向 LSTM、后向 LSTM 的隐藏层更新状态以	及 BiLSTM 最终输出过程如式所示。 

分别为不同层之间的激活函数。 通过结合这两个方向的信息,BiLSTM 能够 更全面地理解数据在时间序列上的特征联系,从而提高模型对数据的利用率和预测精度。 这种双向处理方式使得BiLSTM在处理诸如自然语言处理、语音识别和其他需要理解时间序列前后文信息的任务中表现出色。

NO.3|注意力机制

	注意力机制(Attention Mechanism)是深度学习中的一种关键技术,它模仿了人类在处理大量信息时选择性集中注意力的能力。	注意力模型在自然语言处理、图像识别、语音识别等多个领域都有广泛应用。



	注意力机制的核心思想是通过权重(或分数)来表示模型对输入数据中不同部分的关注程度。这些权重通常通过一个可学习的方式计算得到,以反映不同输入部分对于当前任务的重要性。注意力机制的数学模型通常包括以下几个步骤:

	(1)计算权重:对于给定的查询(Query)和一组键(Key),计算它们之间的相似度或相关性得分。 

 	(2)归一化:	使用softmax函数对得分进行归一化处理,得到每个键对应的权重。 

	(3)加权求和:	根据权重对相应的值(Value)进行加权求和,得到最终的注意力输出。	

NO.4|	CNN-BiLSTM-Attention
%% 建立模型
lgraph = layerGraph();                                                 % 建立空白网络结构
​
tempLayers = [
    sequenceInputLayer([f_, 1, 1], "Name", "sequence")                 % 建立输入层,输入数据结构为[f_, 1, 1]
    sequenceFoldingLayer("Name", "seqfold")];                          % 建立序列折叠层
lgraph = addLayers(lgraph, tempLayers);                                % 将上述网络结构加入空白结构中
​
tempLayers = convolution2dLayer([1, 1], 32, "Name", "conv_1");         % 卷积层 卷积核[1, 1] 通道数 32
lgraph = addLayers(lgraph,tempLayers);                                 % 将上述网络结构加入空白结构中
 
tempLayers = [
    reluLayer("Name", "relu_1")                                        % 激活层
    convolution2dLayer([1, 1], 64, "Name", "conv_2")                   % 卷积层 卷积核[1, 1] 通道数 64
    reluLayer("Name", "relu_2")];                                      % 激活层
lgraph = addLayers(lgraph, tempLayers);                                % 将上述网络结构加入空白结构中
​
tempLayers = [
    globalAveragePooling2dLayer("Name", "gapool")                      % 全局平均池化层
    fullyConnectedLayer(16, "Name", "fc_2")                            % SE注意力机制,通道数的1 / 4
    reluLayer("Name", "relu_3")                                        % 激活层
    fullyConnectedLayer(64, "Name", "fc_3")                            % SE注意力机制,数目和通道数相同
    sigmoidLayer("Name", "sigmoid")];                                  % 激活层
lgraph = addLayers(lgraph, tempLayers);                                % 将上述网络结构加入空白结构中
​
tempLayers = multiplicationLayer(2, "Name", "multiplication");         % 点乘的注意力
lgraph = addLayers(lgraph, tempLayers);                                % 将上述网络结构加入空白结构中
​
tempLayers = [
    sequenceUnfoldingLayer("Name", "sequnfold")                        % 建立序列反折叠层
    flattenLayer("Name", "flatten")                                    % 网络铺平层
    bilstmLayer(6, "Name", "bilstm", "OutputMode", "last")             % BiLSTM层,6个隐藏单元,输出最后一个隐藏状态
    fullyConnectedLayer(1, "Name", "fc")                               % 全连接层
    regressionLayer("Name", "regressionoutput")];                      % 回归层
lgraph = addLayers(lgraph, tempLayers);                                % 将上述网络结构加入空白结构中
​
lgraph = connectLayers(lgraph, "seqfold/out", "conv_1");               % 折叠层输出 连接 卷积层输入;
lgraph = connectLayers(lgraph, "seqfold/miniBatchSize", "sequnfold/miniBatchSize"); 
                                                                       % 折叠层输出 连接 反折叠层输入  
lgraph = connectLayers(lgraph, "conv_1", "relu_1");                    % 卷积层输出 链接 激活层
lgraph = connectLayers(lgraph, "conv_1", "gapool");                    % 卷积层输出 链接 全局平均池化
lgraph = connectLayers(lgraph, "relu_2", "multiplication/in2");        % 激活层输出 链接 相乘层
lgraph = connectLayers(lgraph, "sigmoid", "multiplication/in1");       % 全连接输出 链接 相乘层
lgraph = connectLayers(lgraph, "multiplication", "sequnfold/in");      % 点乘输出 链接 反折叠层输入

http://www.kler.cn/a/447100.html

相关文章:

  • clickhouse优化记录
  • 基于 uniapp 开发 android 播放 webrtc 流
  • Win10将WindowsTerminal设置默认终端并添加到右键(无法使用微软商店)
  • List深拷贝后,数据还是被串改
  • WeakAuras NES Script(lua)
  • 音频接口:PDM TDM128 TDM256
  • C:\Windows 文件夹
  • 大模型微调---Lora微调实战
  • jsp中的四个域对象(Spring MVC)
  • 浅谈目前我开发的前端项目用到的设计模式
  • 爬取Q房二手房房源信息
  • Partition Strategies kafka分区策略
  • <项目代码>YOLO Visdrone航拍目标识别<目标检测>
  • GESP CCF python二级编程等级考试认证真题 2024年12月
  • 基于微信小程序的绘画学习平台
  • 攻防世界easyphp
  • Leetcode中最常用的Java API——util包
  • LeetCode hot100-90
  • 纯css 实现呼吸灯效果
  • TCP/IP协议:网际层相关知识梳理
  • metasploit之ms17_010_psexec模块
  • MapBox实现深蓝色科技风格底图方案
  • android studio更改应用图片,和应用名字。
  • D101【python 接口自动化学习】- pytest进阶之fixture用法
  • Unity3D仿星露谷物语开发6之角色添加动画
  • 麒麟操作系统服务架构保姆级教程(二)ssh远程连接