MATLAB深度学习(六)——LSTM长短期神经网络原理与应用
LSTM的应用可以参见一个相当好的视频:小车倒立摆最优控制教程 - Part1 Simulink Simscape Multibody仿真建模_哔哩哔哩_bilibili
6.1 序列建模——循环神经网络
循环神经网络RNN是一类专门用于处理序列性数据x,···,xn的神经网络结构,其结构主要包含三个关键点:本质是一种神经网络;具有循环体结构;包括隐含状态。
6.1.1 RNN
输入为x,经过Wx+b和激活函数f,得到输出y,但是实际应用中会有很多序列形式的数据
RNN引入了隐状态h,可以对序列形式的数据提取特征,再转换输出。举个例子,一张图片那是同时出现,那就不叫作序列,但如果是视频中的平衡车,每个时间都有输出,且时间不同,输出会有变化,这种就叫做序列。
h1基于上一个隐藏层的状态h0和当前的输入x1计算而来,这一点所有的序列数据都是一样的
f一般数tanh,sigmoid,ReLU等非线性激活函数,他和CNN一样,在计算时,使用的U W b都是一样的,所以参数是共享的,你把他理解成一个向量,相当于用同一个3*1的向量去乘以每个时刻。因此我们可以得出,t时刻的RNN网络具有两个输入,分别为t时刻的输入向量x与t-1时刻的隐含状态。其中h在RNN网络的循环路上。
6.1.2 BPTT
随时间反向传播算法(BPTT)笔记-CSDN博客
6.1.3 基于门与记忆细胞的长期依赖关系 LSTM
RNN循环体内的信息渠道单一,只是通过tanh神经网络层,直接输出。从训练RNN的网络角度来看,训练过程一般使用随着时间的反向传播BPTT,沿着RNN时域展开方向的反方向对每个参数依次求偏导,由此可以得到损失函数关于各个权重矩阵的BPTT梯度表达式,由于训练过程中有些值会非常接近0或者1,因此序列长了之后会导致梯度爆炸或消失。
改进思想是增加RNN体内循环的信息流通路径,设置相应记忆单元,其基本单元称为LSTM单元(cell)。LSTM中的重复模块则包含四个交互的层,三个Sigmoid和一个tanh层,并以一个非常特殊的方式进行交互。
上图中,σ表示的Sigmoid 激活函数与 tanh 函数类似,不同之处在于 sigmoid 是把值压缩到0~1 之间而不是 -1~1 之间。这样的设置有助于更新或忘记信息,
- 因为任何数乘以 0 都得 0,这部分信息就会剔除掉;
- 同样的,任何数乘以 1 都得到它本身,这部分信息就会完美地保存下来
6.1.4 LSTM的门原理
(1)忘记门
在我们LSTM中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为“忘记门”的结构完成。该忘记门会读取上一个输出和当前输入,做一个Sigmoid 的非线性映射,然后输出一个向量(该向量每一个维度的值都在0到1之间,1表示完全保留,0表示完全舍弃,相当于记住了重要的,忘记了无关紧要的),最后与细胞状态相乘。
(2)输入门
第一,sigmoid层称“输入门层”决定什么值我们将要更新;第二,一个tanh层创建一个新的候选值向量,会被加入到状态中。下一步,我们会讲这两个信息来产生对状态的更新。
(3)细胞状态
现在是更新旧细胞状态的时间了,更新为。前面的步骤已经决定了将会做什么,我们现在就是实际去完成。我们把旧状态与相乘,丢弃掉我们确定需要丢弃的信息,接着加上。这就是新的候选值,根据我们决定更新每个状态的程度进行变化。
(4)输出门
最终,我们需要确定输出什么值。这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。
首先,我们运行一个sigmoid层来确定细胞状态的哪个部分将输出出去。接着,我们把细胞状态通过tanh进行处理(得到一个在-1到1之间的值)并将它和sigmoid门的输出相乘,最终我们仅仅会输出我们确定输出的那部分
从上面的分析我们就可以看到,一贯而终的记忆细胞 C,信息流通的路径,使得训练过程中的梯度信息长距离具有可行性,遗忘门f,输入门i,输出门o,来通过有用信息,删除以往信息。
6.2 训练效果
给出训练代码:
%% Data: 数据来源
fileName = {'data_45deg.mat', 'data_30deg.mat', ...
'data_15deg.mat', 'data_5deg.mat', ...
'data_-45deg.mat', 'data_-30deg.mat', ...
'data_-15deg.mat', 'data_-5deg.mat' };
% training:
trainingData = [];
for i = 1:length(fileName)
currentData = [];
load(fileName{i});
currentData = [out.state_data, ...
out.control_input_data]; % 5列
trainingData = [trainingData; currentData];
end
%% Setup: (sequence to one)
numInput = 1; % sequence 4x1
numResponses = 1; % one 1
% LSTM:
numHiddenUnits = 100;
layers = [ ...
sequenceInputLayer(numInput, Normalization="zscore")
lstmLayer(numHiddenUnits, OutputMode="last")
fullyConnectedLayer(numResponses)
regressionLayer];
options = trainingOptions("adam", ...
MaxEpochs=2000, ...
ValidationData= ...
{num2cell(trainingData(:,1:4),2) ...
trainingData(:,5)}, ...
OutputNetwork="best-validation-loss", ...
InitialLearnRate=0.001, ...
SequenceLength="shortest", ...
Plots="training-progress", ...
Verbose= false);
policy = trainNetwork(num2cell(trainingData(:,1:4),2),...
trainingData(:,5), layers, options);
LSTM与RNN参考 如何从RNN起步,一步一步通俗理解LSTM_rnn lstm-CSDN博客
1111