当前位置: 首页 > article >正文

第十五站:循环神经网络(RNN)与长短期记忆网络(LSTM)

1. 循环神经网络(RNN)概述

RNN 是一种非常适合处理序列数据的神经网络。与传统的前馈神经网络不同,RNN 具有一个 循环连接,它可以 记住 前一个时刻的信息,并将其传递到当前时刻。

RNN 的工作原理

  • 输入序列:RNN 接收一个序列的输入,比如时间序列数据、文本数据等。
  • 隐藏状态:RNN 的核心是其 隐藏状态,它存储了对输入序列历史的记忆。
  • 递归计算:在每一步,RNN 会计算当前时刻的隐藏状态,并将其传递到下一时刻的计算中。

RNN 的数学表示如下:
h t = σ ( W h h h t − 1 + W x h x t + b ) h_t = \sigma(W_{hh}h_{t-1} + W_{xh}x_t + b) ht=σ(Whhht1+Wxhxt+b)

  • h t h_t ht 是当前时刻的隐藏状态。
  • W h h W_{hh} Whh是隐藏状态到隐藏状态的权重。
  • W x h W_{xh} Wxh是输入到隐藏状态的权重。
  • x t x_t xt是当前时刻的输入。
  • b b b是偏置项。

2. RNN 的局限性

尽管 RNN 能够处理序列数据,但它存在 梯度消失和梯度爆炸问题。特别是在长序列上,RNN 很难保持长时间的依赖关系。

  • 梯度消失:在训练过程中,当梯度经过多次反向传播时,可能会变得非常小,导致网络无法有效学习长期依赖关系。
  • 梯度爆炸:相反,梯度也可能变得非常大,导致训练不稳定。

为了解决这些问题,我们引入了 长短期记忆网络(LSTM)

3. 长短期记忆网络(LSTM)

LSTM 是一种特殊的 RNN,它引入了 记忆单元门控机制,使得网络能够更好地学习和保持长期依赖。

LSTM 的工作原理

LSTM 通过 遗忘门、输入门和输出门 来控制信息的流动。

  • 遗忘门(Forget Gate):决定哪些信息需要丢弃。
  • 输入门(Input Gate):决定哪些信息需要存储到记忆单元中。
  • 输出门(Output Gate):决定从记忆单元中输出哪些信息。

LSTM 中的每个步骤计算如下:

  1. 遗忘门:决定当前隐藏状态中有多少信息需要被丢弃。
    f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)
  2. 输入门:决定当前输入中有多少信息需要被保存。
    i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)
  3. 记忆单元更新:更新当前的记忆单元。
    C t = f t ⋅ C t − 1 + i t ⋅ C ~ t C_t = f_t \cdot C_{t-1} + i_t \cdot \tilde{C}_t Ct=ftCt1+itC~t
  4. 输出门:决定从记忆单元中输出哪些信息。
    o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)
  5. 隐藏状态更新:最终的隐藏状态。
    h t = o t ⋅ tanh ⁡ ( C t ) h_t = o_t \cdot \tanh(C_t) ht=ottanh(Ct)

4. LSTM 的优势

  • 长期依赖:LSTM 能够更好地捕捉长期依赖关系,解决了传统 RNN 的梯度消失问题。
  • 门控机制:通过遗忘门、输入门和输出门,LSTM 控制了信息的流动,避免了无用信息的积累。

5. LSTM 示例代码:

下面是一个使用 LSTM 进行时间序列预测的简单示例代码:

import torch
import torch.nn as nn
import torch.optim as optim


# 定义 LSTM 网络结构
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)  # LSTM 层
        self.fc = nn.Linear(hidden_size, output_size)  # 全连接层

    def forward(self, x):
        # 初始化 LSTM 的隐藏状态和细胞状态
        h0 = torch.zeros(1, x.size(1), hidden_size).to(x.device)  # 隐藏状态 h0
        c0 = torch.zeros(1, x.size(1), hidden_size).to(x.device)  # 细胞状态 c0

        # LSTM 前向传播
        lstm_out, (hn, cn) = self.lstm(x, (h0, c0))

        # 使用最后一个时间步的输出进行预测
        out = self.fc(lstm_out[-1])  # lstm_out[-1] 形状为 (batch_size, hidden_size)
        return out


# 输入参数
input_size = 1  # 输入特征维度
hidden_size = 64  # LSTM 隐藏层维度
output_size = 1  # 输出维度

# 创建 LSTM 模型实例
model = LSTMModel(input_size, hidden_size, output_size)

# 定义损失函数和优化器
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器

# 假设我们有一个简单的时间序列数据
data = torch.randn(10, 100, 1)  # 形状为 (sequence_length, batch_size, input_size)
labels = torch.randn(100, 1)  # 目标值,形状为 (batch_size, output_size)

# 训练循环
for epoch in range(100):
    model.train()  # 设置模型为训练模式
    optimizer.zero_grad()  # 清空梯度

    # 预测
    output = model(data)  # 前向传播

    # 计算损失
    loss = criterion(output, labels)  # 计算损失

    # 反向传播
    loss.backward()  # 计算梯度

    # 更新参数
    optimizer.step()  # 更新模型参数

    # 输出损失值
    if epoch % 10 == 0:
        print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

关键点说明:

  1. LSTM 层:

    • self.lstm = nn.LSTM(input_size, hidden_size):定义了一个 LSTM 层,input_size 表示每个时间步的输入维度,hidden_size 是 LSTM 层的隐藏单元数量。
    • LSTM 网络有一个非常重要的特点,即它能够通过递归传递信息(记忆)来处理时间序列数据。
  2. 全连接层:

    • self.fc = nn.Linear(hidden_size, output_size):全连接层用于将 LSTM 的输出映射到最终的预测结果。
    • 在这里,我们将 hidden_size 的输出映射到 output_size,适用于回归任务。
  3. 前向传播:

    • lstm_out, (hn, cn) = self.lstm(x, (h0, c0)):将输入数据 x 传入 LSTM 层,并得到 LSTM 的输出和最后的隐藏状态 hn、细胞状态 cn
    • out = self.fc(lstm_out[-1]):选择 LSTM 输出序列的最后一个时间步的输出,传递给全连接层进行预测。
  4. 训练过程:

    • 清空梯度:每个 epoch 之前使用 optimizer.zero_grad() 清空之前计算的梯度。
    • 损失计算和反向传播:通过 criterion(output, labels) 计算损失,并通过 loss.backward() 进行反向传播来计算梯度。
    • 优化器更新optimizer.step() 用来更新模型的参数。

6. LSTM 在实际的生活中也有很多应用地方:

LSTM 广泛应用于 时间序列分析自然语言处理(NLP) 中:

  1. 时间序列预测:LSTM 可以用来预测股票价格、天气变化等序列数据。
  2. 文本生成和语言建模:LSTM 可用于生成文本或建模语言的上下文。
  3. 机器翻译:LSTM 用于翻译不同语言之间的句子。
  4. 语音识别:LSTM 可用于处理语音信号,并将其转换为文本。

结语:因为博主的一些原因,机器学习系列就更到这里,学到这里各位也应该对机器学习的基础有一定的了解,并能搭建属于自己的一个神经网络,并去进行调优,改进,并部署到实际的现实需求当中


http://www.kler.cn/a/568157.html

相关文章:

  • redis的启动方式
  • Linux——计算机网络
  • 【SDR课堂第12讲】AD9361毛刺问题总结
  • 手写RPC框架-V1版本
  • 一周学会Flask3 Python Web开发-Jinja2模版中加载静态文件
  • 2.9作业
  • 大模型最新面试题系列:训练篇之数据处理与增强
  • Python可视化大框架的研究与应用
  • 聊聊大数据测试开展方向有哪些?
  • Protobuf原理与序列化
  • Android中的四大组件及其生命周期
  • 学习笔记-单片机蓝桥杯大模板更新-米醋
  • uniapp h5页面获取跳转传参的简单方法
  • 设置电脑一接通电源就主动开机
  • OpenEuler学习笔记(三十五):搭建代码托管服务器
  • IP-----动态路由OSPF(2)
  • Docker数据卷容器实战
  • CSS 中最常用的三种选择器的详细讲解(配合实例)
  • (视频教程)Compass代谢分析详细流程及python版-R语言版下游分析和可视化
  • 从零基础到通过考试