当前位置: 首页 > article >正文

【Pytorch】nn.RNN、nn.LSTM 和 nn.GRU的输入和输出形状

nn.RNN、nn.LSTM 和 nn.GRU的输入和输出形状

      • 输入形状
        • 通用输入参数
        • 特殊情况(LSTM)
      • 输出形状
        • nn.RNN 和 nn.GRU
        • nn.LSTM
      • 代码示例

输入形状

通用输入参数

这三个模块通常接收以下两种形式的输入:

  • 输入序列:形状为 (seq_len, batch_size, input_size)
    • seq_len:表示序列的长度,即时间步的数量。例如在处理文本时,它可以是句子的单词数量;在处理时间序列数据时,它可以是时间点的数量。
    • batch_size:表示每次输入的样本数量。在训练模型时,通常会将多个样本组成一个批次进行处理,以提高计算效率。
    • input_size:表示每个时间步输入的特征维度。例如,在处理图像序列时,它可以是图像的特征向量维度;在处理文本时,它可以是词向量的维度。
  • 初始隐藏状态:形状为 (num_layers * num_directions, batch_size, hidden_size)
    • num_layers:表示 RNN 层数。如果设置为多层 RNN,信息会在不同层之间依次传递。
    • num_directions:表示 RNN 的方向数,取值为 1(单向 RNN)或 2(双向 RNN)。双向 RNN 会同时考虑序列的正向和反向信息。
    • hidden_size:表示隐藏层的维度,即每个时间步输出的隐藏状态的特征数量。
特殊情况(LSTM)

对于 nn.LSTM,除了初始隐藏状态外,还需要一个初始细胞状态,其形状与初始隐藏状态相同,即 (num_layers * num_directions, batch_size, hidden_size)

输出形状

nn.RNN 和 nn.GRU
  • 输出序列:形状为 (seq_len, batch_size, num_directions * hidden_size)。它包含了每个时间步的隐藏状态输出,其中 num_directions 取决于 RNN 是否为双向。如果是单向 RNN,num_directions 为 1;如果是双向 RNN,num_directions 为 2,输出的特征维度会翻倍。
  • 最终隐藏状态:形状为 (num_layers * num_directions, batch_size, hidden_size)。它表示最后一个时间步的隐藏状态,用于后续的任务,如分类或预测。
nn.LSTM
  • 输出序列:形状同样为 (seq_len, batch_size, num_directions * hidden_size),含义与 nn.RNNnn.GRU 的输出序列类似。
  • 最终隐藏状态和细胞状态:最终隐藏状态和细胞状态的形状均为 (num_layers * num_directions, batch_size, hidden_size)。最终隐藏状态和细胞状态一起保存了 LSTM 在最后一个时间步的信息。

代码示例

import torch
import torch.nn as nn

# 定义参数
input_size = 10
hidden_size = 20
num_layers = 2
batch_size = 3
seq_len = 5
num_directions = 1  # 单向 RNN

# 创建 RNN 模型
rnn = nn.RNN(input_size, hidden_size, num_layers)
# 创建 LSTM 模型
lstm = nn.LSTM(input_size, hidden_size, num_layers)
# 创建 GRU 模型
gru = nn.GRU(input_size, hidden_size, num_layers)

# 生成随机输入序列
input_seq = torch.randn(seq_len, batch_size, input_size)
# 初始化隐藏状态
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)

# 运行 RNN
rnn_output, hn_rnn = rnn(input_seq, h0)
print("RNN 输出序列形状:", rnn_output.shape)
print("RNN 最终隐藏状态形状:", hn_rnn.shape)

# 初始化 LSTM 的细胞状态
c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
# 运行 LSTM
lstm_output, (hn_lstm, cn_lstm) = lstm(input_seq, (h0, c0))
print("LSTM 输出序列形状:", lstm_output.shape)
print("LSTM 最终隐藏状态形状:", hn_lstm.shape)
print("LSTM 最终细胞状态形状:", cn_lstm.shape)

# 运行 GRU
gru_output, hn_gru = gru(input_seq, h0)
print("GRU 输出序列形状:", gru_output.shape)
print("GRU 最终隐藏状态形状:", hn_gru.shape)

在上述代码中,我们定义了输入序列和初始隐藏状态,并分别使用 nn.RNNnn.LSTMnn.GRU 对输入序列进行处理,最后打印出它们的输出形状,帮助你更好地理解输入输出形状的特点。


http://www.kler.cn/a/537057.html

相关文章:

  • 信息学奥赛一本通 2101:【23CSPJ普及组】旅游巴士(bus) | 洛谷 P9751 [CSP-J 2023] 旅游巴士
  • python编程-内置函数reversed(),repr(),chr()详解
  • Windows Docker笔记-制作、加载镜像
  • 深度学习 Pytorch 神经网络的学习
  • 【FPGA】 MIPS 12条整数指令 【3】
  • 【回溯+剪枝】单词搜索,你能用递归解决吗?
  • 荣耀内置的远程控制怎样用?荣耀如何远程控制其他品牌的手机?
  • 【GitHub】GitHub 2FA 双因素认证 ( 使用 Microsoft Authenticator 应用进行二次验证 )
  • 121,【5】 buuctf web [RoarCTF 2019] Easy Calc
  • 树莓集团双流布局,元宇宙产业园点亮科技之光
  • 如何确保爬虫不会违反平台规则?
  • 为什么关系模型不叫表模型
  • Redis基础--常用数据结构的命令及底层编码
  • DeepSeek Window本地私有化部署
  • Ubuntu Crontab 日志在什么位置 ?
  • 京东java面试流程_java京东社招面试经历
  • ES6 迭代器 (`Iterator`)使用总结
  • flutter Selector 使用
  • StarSpider 星蛛 爬虫 Java框架 可以实现 lazy爬取 实现 HTML 文件的编译,子标签缓存等操作
  • 前端导出pdf,所见即所得
  • 芯科科技的BG22L和BG24L带来应用优化的超低功耗蓝牙®连接
  • Spring Boot 有哪些优点
  • 【Redis】事务因WATCH的键被修改而失败 事务队列中的操作被自动丢弃 UNWATCH的应用场景
  • 视频编辑质量评价的开源项目 VE-Bench 介绍
  • 使用deepseek快速创作ppt
  • 基于物联网技术的智能寻车引导系统方案:工作原理、核心功能及系统架构