当前位置: 首页 > article >正文

基于CNN-LSTM的时序预测MATLAB实战

卷积神经网络(CNN)用于提取时间序列数据中的局部空间特征,通过卷积层和池化层的堆叠,CNN能够有效捕获数据中的短期模式和局部依赖关系。长短时记忆网络(LSTM)用于处理时间序列数据,特别擅长捕捉数据中的长期依赖关系,LSTM通过引入门控机制和记忆单元来解决长期依赖问题。基于CNN-LSTM的时序预测结合上述两种网络的优点,同时考虑了时空特征,模型通常能够获得更高的预测精度。

一、算法原理

1.1 CNN原理 

卷积神经网络具有局部连接、权值共享和空间相关等特性。卷积神经网络结构包含卷积层、激活层和池化层。

(a)二维卷积层将滑动卷积滤波器应用于输入。该层通过沿输入垂直和水平方向 移动滤波器对输入进行卷积,并计算权重与输入的点积,然后加入一个偏置项。具体表达式为:

图片

卷积层的功能是对输入数据进行特征提取,其内部包含多个卷积核,也称为感受野。将输入图像和卷积核做卷积运算,可以增强原始信号特征的同时降低噪声,卷积运算的具体过程如图所示。

图片

卷积运算的具体过程

(2)激活函数

在卷积神经网络中,常用的激活函数包括 Sigmoid 函数、Tanh 函数、Swish 函数和 Relu 函数。Relu 函数解决了 Sigmoid 函数和 Tanh 函数梯度消失的问题, 提高了模型收敛的速度,受到的广泛学者的欢迎。

(3)池化层

池化层又称为下采样层,池化层分为平均池化层和最大池化层。其中,最大池化层通过将输入分为矩形池化区域,并计算每个区域的最大值来执行下采样, 而平均池化层则是计算池化区域的平均值来执行下采样。池化层的池化过程如图所示。

图片

池化过程示意图

1.2 LSTM原理 

LSTM采用循环神经网络( Recurrent Neural Network,RNN )架构[8],它是专门为从序列中学习长期依赖关系而设计的。LSTM可以使用4个组件:输入门、输出门、遗忘门和具有自循环连接的单元来移除或添加块状态的信息。

图片

LSTM网络结构

设输入序列共有 k 个时间步,LSTM 门控机制 结构为遗忘门、输入门和输出门,xt携带网络输入值 作为向量引入系统,ht 通过隐含层对 LSTM 细胞进 行输出,ct携带着 LSTM 细胞状态进行运算。LSTM 运算规则如下:

图片

计算后保留 ct与 ht,用于下一时间步的计算;最后一步计算完成后,将隐藏层向量 hk作为输出与本组序列对应的预测值对比,得出损失函数值,依据梯度下降算法,优化权重和偏置参数,以此训练出迭代次数范围内最精确的网络参数。

1.3 CNN-LSTM框架

以时序预测为例,本次使用的CNN-LSTM框架如图所示。

图片

CNN-LSTM框架

二、具体模型框架

通过layerGraph来组合CNN和LSTM层,并使用trainNetwork函数进行模型训练,使用Dropout技术来减少模型的过拟合现象。

对LSTM参数进行设置,确定数据输入的特征维度,即时间步长的特征数量。numhidden_units1numhidden_units2numhidden_units3分别代表不同LSTM层中的隐含单元数。例如,可以设置为50、20、100,表示不同层的神经元数量。Dropout层用于减少过拟合,例如设置为0.3,表示在训练过程中随机丢弃30%的神经元。

使用Adam优化算法进行训练,这是一种改进的梯度优化算法,能够解决传统SGD算法的一些问题。训练完成后使用测试数据集进行预测,并采用如均方误差(MSE)、均方根误差(RMSE)、平均绝对误差(MAE)等评价指标来定量评估模型性能。

clc
clear
load('Train.mat')
load('Test.mat')

% LSTM 层设置,参数设置
inputSize = size(Train_xNorm{1},1);   %数据输入x的特征维度
outputSize = 1;  %数据输出y的维度  
numhidden_units1=50;
numhidden_units2= 20;
numhidden_units3=100;
%
opts = trainingOptions('adam', ...
    'MaxEpochs',10, ...
    'GradientThreshold',1,...
    'ExecutionEnvironment','cpu',...
    'InitialLearnRate',0.001, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',2, ...   %2个epoch后学习率更新
    'LearnRateDropFactor',0.5, ...
    'Shuffle','once',...  % 时间序列长度
    'SequenceLength',k,...
    'MiniBatchSize',24,...
    'Verbose',0);
%% lstm

layers = [ ...
    
    sequenceInputLayer([inputSize,1,1],'name','input')   %输入层设置
    sequenceFoldingLayer('name','fold')
    convolution2dLayer([2,1],10,'Stride',[1,1],'name','conv1')
    batchNormalizationLayer('name','batchnorm1')
    reluLayer('name','relu1')
    maxPooling2dLayer([1,3],'Stride',1,'Padding','same','name','maxpool')
    sequenceUnfoldingLayer('name','unfold')
    flattenLayer('name','flatten')
    lstmLayer(numhidden_units1,'Outputmode','sequence','name','hidden1') 
    dropoutLayer(0.3,'name','dropout_1')
    lstmLayer(numhidden_units2,'Outputmode','last','name','hidden2') 
    dropoutLayer(0.3,'name','drdiopout_2')
    fullyConnectedLayer(outputSize,'name','fullconnect')   % 全连接层设置(影响输出维度)(cell层出来的输出层) %
    tanhLayer('name','softmax')
    regressionLayer('name','output')];

lgraph = layerGraph(layers)
lgraph = connectLayers(lgraph,'fold/miniBatchSize','unfold/miniBatchSize');
plot(lgraph)
%
% 网络训练
tic

net = trainNetwork(Train_xNorm,Train_yNorm,lgraph,opts);

%% 测试
figure
Predict_Ynorm = net.predict(Test_xNorm);
Predict_Y  = mapminmax('reverse',Predict_Ynorm',yopt);
Predict_Y = Predict_Y';

plot(Predict_Y,'g-')
hold on 

plot(Test_y);    
legend('预测值','实际值')

通过比较预测值和实际值可以评估模型的性能,通常使用图表来展示预测结果与实际结果的对比:

图片

基于CNN-LSTM的时序预测模型结合了CNN的特征提取能力和LSTM的时序建模能力,使其在处理具有复杂空间和时间依赖性的时间序列数据方面表现出色。在MATLAB中实现这一模型,可以利用其强大的内置函数和深度学习工具箱,进行有效的模型构建、训练和评估。


http://www.kler.cn/a/410172.html

相关文章:

  • 【Linux学习】【Ubuntu入门】2-3 make工具和makefile引入
  • 即时通讯服务器被ddos攻击了怎么办?
  • css效果
  • 在英文科技论文中分号后面的单词首字母需不需要大写
  • 大数据实验4-HBase
  • 神经网络12-Time-Series Transformer (TST)模型
  • 【C++篇】深度解析 C++ List 容器:底层设计与实现揭秘
  • 类和对象plus版
  • 量化交易系统开发-实时行情自动化交易-4.4.做市策略
  • 蓝桥杯每日真题 - 第22天
  • Wireshark抓取HTTPS流量技巧
  • 量子生成对抗网络
  • docker镜像、容器、仓库介绍
  • Flutter:启动屏逻辑处理01:修改默认APP启动图标
  • 大数据实战——MapReduce案例实践
  • Node.js的http模块:创建HTTP服务器、客户端示例
  • uniapp+vue2重新进入小程序就清除缓存,设备需要重新扫码
  • 实现飞2米高的框架(c语言版)
  • 网络渗透测试工具推荐与简介
  • MySQL获取数据库内所有表格数据总数
  • 解决getSubject is supported only if a security manager is allowed
  • HarmonyOS NEXT应用元服务开发Intents Kit(意图框架服务)习惯推荐方案开发者测试
  • 计算机网络——第3章 数据链路层(自学笔记)
  • 机器学习中数据集Upsampling和Downsampling是什么意思?中英文介绍
  • Spring Boot 实战:基于 Validation 注解实现分层数据校验与校验异常拦截器统一返回处理
  • IDEA 2024安装指南(含安装包以及使用说明 cannot collect jvm options 问题一)