当前位置: 首页 > article >正文

MATLAB深度学习(六)——LSTM长短期神经网络原理与应用

        LSTM的应用可以参见一个相当好的视频:小车倒立摆最优控制教程 - Part1 Simulink Simscape Multibody仿真建模_哔哩哔哩_bilibili

6.1 序列建模——循环神经网络

        循环神经网络RNN是一类专门用于处理序列性数据x,···,xn的神经网络结构,其结构主要包含三个关键点:本质是一种神经网络;具有循环体结构;包括隐含状态。

6.1.1 RNN

        

        输入为x,经过Wx+b和激活函数f,得到输出y,但是实际应用中会有很多序列形式的数据

        RNN引入了隐状态h,可以对序列形式的数据提取特征,再转换输出。举个例子,一张图片那是同时出现,那就不叫作序列,但如果是视频中的平衡车,每个时间都有输出,且时间不同,输出会有变化,这种就叫做序列。

        

        h1基于上一个隐藏层的状态h0和当前的输入x1计算而来,这一点所有的序列数据都是一样的

        f一般数tanh,sigmoid,ReLU等非线性激活函数,他和CNN一样,在计算时,使用的U W b都是一样的,所以参数是共享的,你把他理解成一个向量,相当于用同一个3*1的向量去乘以每个时刻。因此我们可以得出,t时刻的RNN网络具有两个输入,分别为t时刻的输入向量x与t-1时刻的隐含状态。其中h在RNN网络的循环路上。

6.1.2 BPTT

    随时间反向传播算法(BPTT)笔记-CSDN博客

6.1.3 基于门与记忆细胞的长期依赖关系 LSTM

        

        RNN循环体内的信息渠道单一,只是通过tanh神经网络层,直接输出。从训练RNN的网络角度来看,训练过程一般使用随着时间的反向传播BPTT,沿着RNN时域展开方向的反方向对每个参数依次求偏导,由此可以得到损失函数关于各个权重矩阵的BPTT梯度表达式,由于训练过程中有些值会非常接近0或者1,因此序列长了之后会导致梯度爆炸或消失。

        改进思想是增加RNN体内循环的信息流通路径,设置相应记忆单元,其基本单元称为LSTM单元(cell)。LSTM中的重复模块则包含四个交互的层,三个Sigmoid和一个tanh层,并以一个非常特殊的方式进行交互。

上图中,σ表示的Sigmoid 激活函数与 tanh 函数类似,不同之处在于 sigmoid 是把值压缩到0~1 之间而不是 -1~1 之间。这样的设置有助于更新或忘记信息,

  • 因为任何数乘以 0 都得 0,这部分信息就会剔除掉;
  • 同样的,任何数乘以 1 都得到它本身,这部分信息就会完美地保存下来

6.1.4 LSTM的门原理

(1)忘记门

        在我们LSTM中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为“忘记门”的结构完成。该忘记门会读取上一个输出和当前输入,做一个Sigmoid 的非线性映射,然后输出一个向量(该向量每一个维度的值都在0到1之间,1表示完全保留,0表示完全舍弃,相当于记住了重要的,忘记了无关紧要的),最后与细胞状态相乘。

        (2)输入门

        第一,sigmoid层称“输入门层”决定什么值我们将要更新;第二,一个tanh层创建一个新的候选值向量\tilde{C}_{t},会被加入到状态中。下一步,我们会讲这两个信息来产生对状态的更新。

     (3)细胞状态

          现在是更新旧细胞状态的时间了,C_{t-1}更新为C_{t}。前面的步骤已经决定了将会做什么,我们现在就是实际去完成。我们把旧状态与f_{t}相乘,丢弃掉我们确定需要丢弃的信息,接着加上i_{t} * \tilde{C}_{t}。这就是新的候选值,根据我们决定更新每个状态的程度进行变化。

      (4)输出门

        最终,我们需要确定输出什么值。这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。

        首先,我们运行一个sigmoid层来确定细胞状态的哪个部分将输出出去。接着,我们把细胞状态通过tanh进行处理(得到一个在-1到1之间的值)并将它和sigmoid门的输出相乘,最终我们仅仅会输出我们确定输出的那部分

        

        从上面的分析我们就可以看到,一贯而终的记忆细胞 C,信息流通的路径,使得训练过程中的梯度信息长距离具有可行性,遗忘门f,输入门i,输出门o,来通过有用信息,删除以往信息。

6.2 训练效果

给出训练代码:

%% Data: 数据来源
fileName = {'data_45deg.mat', 'data_30deg.mat', ...
    'data_15deg.mat', 'data_5deg.mat', ...
    'data_-45deg.mat', 'data_-30deg.mat', ...
    'data_-15deg.mat', 'data_-5deg.mat' };

% training:
trainingData = [];
for i = 1:length(fileName)
    currentData = [];
    load(fileName{i});
    currentData = [out.state_data, ...
        out.control_input_data]; % 5列
    trainingData = [trainingData; currentData];
end


%% Setup: (sequence to one)
numInput = 1; % sequence 4x1 
numResponses = 1; % one 1
 
% LSTM:
numHiddenUnits = 100;
layers = [ ...
    sequenceInputLayer(numInput, Normalization="zscore")
    lstmLayer(numHiddenUnits, OutputMode="last")
    fullyConnectedLayer(numResponses)
    regressionLayer];

options = trainingOptions("adam", ...
    MaxEpochs=2000, ...
    ValidationData= ...
    {num2cell(trainingData(:,1:4),2) ...
    trainingData(:,5)}, ...
    OutputNetwork="best-validation-loss", ...
    InitialLearnRate=0.001, ...
    SequenceLength="shortest", ...
    Plots="training-progress", ...
    Verbose= false);

policy = trainNetwork(num2cell(trainingData(:,1:4),2),...
    trainingData(:,5), layers, options);

        

LSTM与RNN参考 如何从RNN起步,一步一步通俗理解LSTM_rnn lstm-CSDN博客

1111


http://www.kler.cn/a/411785.html

相关文章:

  • C++ 优先算法 —— 无重复字符的最长子串(滑动窗口)
  • 基于springboot的县市级土地使用监控系统的设计与实现
  • Vue实训---5-路由搭建
  • 使用 前端技术 创建 QR 码生成器 API1
  • django authentication 登录注册
  • KMeans聚类实验(基础入门)
  • 华为ENSP--BGP路由协议实验详解
  • 网络安全期末复习
  • docker启动kafka、zookeeper、kafdrop
  • Oracle impdp-ORA-39083,ORA-00942
  • GitLab使用操作v1.0
  • 【设计模式】【行为型模式(Behavioral Patterns)】之策略模式(Strategy Pattern)
  • 【微服务架构】Kubernetes与Docker在微服务架构中的最佳实践(详尽教程)
  • 《免费学习网站推荐1》
  • 【JAVA】Java高级:Java网络编程——TCP/IP与UDP协议基础
  • 鸿蒙中拍照上传与本地图片上传
  • JavaWeb--JDBC
  • 如何搭建一个小程序:从零开始的详细指南
  • 过滤条件包含 OR 谓词,如何进行查询优化——OceanBase SQL 优化实践
  • C++设计模式-中介者模式
  • 【31-40期】从Java反射到SSO:深度解析面试高频问题
  • 17. 【.NET 8 实战--孢子记账--从单体到微服务】--记账模块--主币种设置
  • qt 读写文本、xml文件
  • 0 基础 入门简单 linux操作 上篇 利用apt命令装13 linux搭建自己的服务器
  • 【WEB开发.js】getElementById :通过元素id属性获取HTML元素
  • SpringMVC框架---SpringMVC概述、入门案例、常用注解