当前位置: 首页 > article >正文

LSTM神经网络时间序列

LSTM(长短期记忆网络)是一种特殊的循环神经网络(RNN),通过引入记忆单元(Memory Cell)和门控机制,克服了传统RNN在处理长时间依赖问题中的梯度消失和梯度爆炸问题。以下是建模过程和背后的原理。

clc
clear
load('data.mat')
data=RTS'
%% 序列的前485个用于训练,后10个用于验证神经网络,然后往后预测10个数据。
dataTrain = data(1:485);    %定义训练集
dataTest = data(486:495);    %该数据是用来在最后与预测值进行对比的
%% 数据预处理
mu = mean(dataTrain);    %求均值 
sig = std(dataTrain);      %求均差 
dataTrainStandardized = (dataTrain - mu) / sig;    
%% 输入的每个时间步,LSTM网络学习预测下一个时间步,这里交错一个时间步效果最好。
XTrain = dataTrainStandardized(1:end-1);  
YTrain = dataTrainStandardized(2:end);  
%% 一维特征lstm网络训练
numFeatures = 1;   %特征为一维
numResponses = 1;  %输出也是一维
numHiddenUnits = 200;   %创建LSTM回归网络,指定LSTM层的隐含单元个数200。可调
 
layers = [ ...
    sequenceInputLayer(numFeatures)    %输入层
    lstmLayer(numHiddenUnits)  % lstm层,如果是构建多层的LSTM模型,可以修改。
    fullyConnectedLayer(numResponses)    %为全连接层,是输出的维数。
    regressionLayer];      %其计算回归问题的半均方误差模块 。即说明这不是在进行分类问题。
 
%指定训练选项,求解器设置为adam, 1000轮训练。
%梯度阈值设置为 1。指定初始学习率 0.01,在 125 轮训练后通过乘以因子 0.2 来降低学习率。
options = trainingOptions('adam', ...
    'MaxEpochs',1000, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...      
    'LearnRateSchedule','piecewise', ...%每当经过一定数量的时期时,学习率就会乘以一个系数。
    'LearnRateDropPeriod',400, ...      %乘法之间的纪元数由“ LearnRateDropPeriod”控制。可调
    'LearnRateDropFactor',0.15, ...      %乘法因子由参“ LearnRateDropFactor”控制,可调
    'Verbose',0,  ...  %如果将其设置为true,则有关训练进度的信息将被打印到命令窗口中。默认值为true。
    'Plots','training-progress');    %构建曲线图 将'training-progress'替换为none
net = trainNetwork(XTrain,YTrain,layers,options); 
net = predictAndUpdateState(net,XTrain);  %将新的XTrain数据用在网络上进行初始化网络状态
[net,YPred] = predictAndUpdateState(net,YTrain(end));  %用训练的最后一步来进行预测第一个预测值,给定一个初始值。这是用预测值更新网络状态特有的。
%% 进行用于验证神经网络的数据预测(用预测值更新网络状态)
for i = 2:20  %从第二步开始,这里进行20次单步预测(10为用于验证的预测值,10为往后预测的值。一共20个)
    [net,YPred(:,i)] = predictAndUpdateState(net,YPred(:,i-1),'ExecutionEnvironment','cpu');  %predictAndUpdateState函数是一次预测一个值并更新网络状态
end
%% 验证神经网络
YPred = sig*YPred + mu;      %使用先前计算的参数对预测去标准化。
rmse = sqrt(mean((YPred(1:10)-dataTest).^2)) ;     %计算均方根误差 (RMSE)。
subplot(2,1,1)
plot(dataTrain(1:end))   %先画出前面485个数据,是训练数据。
hold on
idx = 486:(485+20);   %为横坐标
plot(idx,YPred(1:20),'.-')  %显示预测值
hold off
xlabel("Time")
ylabel("Case")
title("Forecast")
legend(["Observed" "Forecast"])
subplot(2,1,2)
plot(data)
xlabel("Time")
ylabel("Case")
title("Dataset")
%% 继续往后预测2023年的数据
figure(2)
idx = 486:(485+20);   %为横坐标
plot(idx,YPred(1:20),'.-')  %显示预测值
hold off
net = resetState(net);


http://www.kler.cn/a/418342.html

相关文章:

  • Xilinx FPGA内部资源组成和说明汇总
  • AIS介绍
  • 【QNX+Android虚拟化方案】129 - USB眼图参数配置
  • 【最新鸿蒙开发——应用导航设计】
  • 智慧银行反欺诈大数据管控平台方案(一)
  • 202页MES项目需求方案深入解读,学习MES系统设计规划
  • 使用Docker在Ubuntu 22.04上部署MySQL数据库的完整指南
  • 如何在 VPS 上设置 Apache 并使用免费签名的 SSL 证书
  • Uniapp 使用自定义字体
  • Linux下如何安装JDK
  • 粒子群算法优化RBF网络
  • spark同步mysql数据到sqlserver
  • Latex相关问题
  • 基于yolov8、yolov5的铝材缺陷检测识别系统(含UI界面、训练好的模型、Python代码、数据集)
  • 强国复兴项目携手易方达基金、广发基金 高效推进扶贫金发放与保障房建设
  • windows C#-相等比较
  • 《windows堆内存剖析(一)》
  • ChromeBook11 HP G7EE 刷入Ubuntu的记录
  • 鲲鹏麒麟安装离线版MySQL5.7
  • 吉客云数据集成技巧:智能实现MySQL物料信息查询
  • 栈-数组描述(C++)
  • mysql查询语句执行全流程
  • 10x 性能提升,ProtonBase 为教育行业提供统一的数据库和数仓体验
  • 【C#设计模式(16)——解释器模式(Interpreter Pattern)】
  • 搭建业务的性能优化指南
  • [C/C++]排序算法1、冒泡排序