当前位置: 首页 > article >正文

第R3周:LSTM-火灾温度预测:3. nn.LSTM() 函数详解

nn.LSTM 是 PyTorch 中用于创建长短期记忆(Long Short-Term Memory,LSTM)模型的类。LSTM 是一种循环神经网络(Recurrent Neural Network,RNN)的变体,用于处理序列数据,能够有效地捕捉长期依赖关系。

语法

torch.nn.LSTM(input_size, hidden_size, num_layers=1, 
              bias=True, batch_first=False, 
              dropout=0, bidirectional=False)

● input_size: 输入特征的维度。
● hidden_size: 隐藏状态的维度,也是输出特征的维度。
● num_layers(可选参数): LSTM 层的数量,默认为 1。
● bias(可选参数): 是否使用偏置,默认为 True。
● batch_first(可选参数): 如果为 True,则输入和输出张量的形状为 (batch_size, seq_len, feature_size),默认为 False,张量的形状为(seq_len, batch_size, feature_dim)。
● dropout(可选参数): 如果非零,将在 LSTM 层的输出上应用 dropout,防止过拟合。默认为 0。
● bidirectional(可选参数): 如果为 True,则使用双向 LSTM,输出维度将翻倍。默认为 False。

示例

import torch
import torch.nn as nn

# 定义一个单向 LSTM 模型
input_size  = 10
hidden_size = 20
num_layers  = 2
batch_size  = 3
seq_len     = 5

lstm = nn.LSTM(input_size, hidden_size, num_layers)

# 构造一个输入张量
input_tensor = torch.randn(seq_len, batch_size, input_size)

# 初始化隐藏状态和细胞状态
h0 = torch.randn(num_layers, batch_size, hidden_size)
c0 = torch.randn(num_layers, batch_size, hidden_size)

# 将输入传递给 LSTM 模型
output, (hn, cn) = lstm(input_tensor, (h0, c0))

print("Output shape:", output.shape)    # 输出特征的形状
print("Hidden state shape:", hn.shape)  # 最后一个时间步的隐藏状态的形状
print("Cell state shape:", cn.shape)    # 最后一个时间步的细胞状态的形状

代码输出

Output shape: torch.Size([5, 3, 20])
Hidden state shape: torch.Size([2, 3, 20])
Cell state shape: torch.Size([2, 3, 20])

注意事项

● input_size 指定了输入数据的特征维度,hidden_size 指定了 LSTM 层的隐藏状态维度,num_layers 指定了 LSTM 的层数。
● LSTM 的输入张量的形状通常是 (seq_len, batch_size, input_size),但如果设置了 batch_first=True,则形状为 (batch_size, seq_len, input_size)。
● LSTM 的输出包括输出张量和最后一个时间步的隐藏状态和细胞状态。
● 可以通过 bidirectional=True 参数创建双向 LSTM,它会将输入序列分别从前向和后向传播,并将两个方向的隐藏状态拼接在一起作为输出。
● 在使用 LSTM 时,通常需要注意输入数据的序列长度,以及是否需要对输入数据进行填充或截断,以保证输入数据的长度是一致的。


http://www.kler.cn/a/310416.html

相关文章:

  • 【mysql】使用宝塔面板在云服务器上安装MySQL数据库并实现远程连接
  • 树形dp总结
  • HTTP协议基础
  • Unity3D实现视频和模型融合效果
  • 区块链技术在慈善捐赠中的应用
  • 算法——移除链表元素(leetcode203)
  • Matlab 的.m 文件批量转成py文件
  • HTML讲解(一)body部分
  • IDEA去除掉虚线,波浪线,和下划线实线的方法
  • 【微服务-注册中心】
  • 初识Maven:Java项目管理工具
  • 鸿蒙Harmony应用开发,数据驾驶舱 项目结构搭建
  • Android使用LiquidFun物理引擎实现果冻碰撞效果
  • JAVA网络编程【基于TCP和UDP协议】超详细!!!
  • C编程演奏中文版“生日快乐歌”
  • Java html生成pdf和图片
  • 【kafka】基本概念
  • SAP学习笔记 - 开发06 - CDSView + Fiori Element 之 List Report
  • Vue3 项目引入阿里 iconfont 图标和字体的多种方式
  • 2024 VMpro 虚拟机中如何给Ubuntu Linux操作系统配置联网
  • 总结拓展十:SAP开发计划(上)
  • 新峰商城之分类三级联动实现
  • PyQt / PySide + Pywin32 + ctypes 自定义标题栏窗口 + 完全还原 Windows 原生窗口边框特效项目
  • html+css+js网页设计 旅游 龙门石窟8个页面
  • HarmonyOS ArkTS 用户首选项的开发及测试
  • AI大模型与产品经理:替代与合作的深度剖析