当前位置: 首页 > article >正文

小白学Pytorch系列--Torch.nn API Recurrent Layers(8)

小白学Pytorch系列–Torch.nn API Recurrent Layers(8)

方法注释
nn.RNNBase
nn.RNN将具有tanh tanh或ReLU ReLU非线性的多层Elman RNN应用于输入序列。
nn.LSTM将多层长短期记忆(LSTM) RNN应用于输入序列。
nn.GRU将多层门控循环单元(GRU) RNN应用于输入序列。
nn.RNNCell具有tanh或ReLU非线性的Elman RNN单元。
nn.LSTMCell长短期记忆(LSTM)细胞。
nn.GRUCell门控循环单元(GRU)细胞

nn.RNNBase

重置参数数据指针,以便它们可以使用更快的代码路径。

目前,只有当模块在GPU上并且启用了cuDNN时,这才有效。否则,这就是拒绝。

nn.RNNCell

h ′ = tanh ⁡ ( W i h x + b i h + W h h h + b h h ) h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh}) h=tanh(Wihx+bih+Whhh+bhh)

rnn = nn.RNNCell(10, 20)
input = torch.randn(6, 3, 10)
hx = torch.randn(3, 20)
output = []
for i in range(6):
    hx = rnn(input[i], hx)
    output.append(hx)

nn.RNN

将具有tanh或ReLU非线性的多层Elman RNN应用于输入序列。对于输入序列中的每个元素,每个层计算以下函数:

import torch.nn as nn
import torch
model = nn.RNN(input_size=10, hidden_size=100, batch_first=True, num_layers=2, bidirectional =True)

input_tensor = torch.randn(2, 5, 10 )
output, hidden = model(input_tensor)
print(output.shape) # [bz, seq_len, hz]
print(hidden.shape) #  [num_layer*D, bz, hz]

nn.LSTMCell

长短期记忆(LSTM) Cell

>>> rnn = nn.LSTMCell(10, 20)  # (input_size, hidden_size)
>>> input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
>>> hx = torch.randn(3, 20)  # (batch, hidden_size)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(input.size()[0]):
...     hx, cx = rnn(input[i], (hx, cx))
...     output.append(hx)
>>> output = torch.stack(output, dim=0)

nn.LSTM

将多层长短期记忆(LSTM) RNN应用于输入序列。
对于输入序列中的每个元素,每一层都计算以下函数




>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0)) # [seq_len, bz, bi*hz]


nn.GRUCell



>>> rnn = nn.GRUCell(10, 20)
>>> input = torch.randn(6, 3, 10)
>>> hx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
...     hx = rnn(input[i], hx)
...     output.append(hx)

nn.GRU




>>> rnn = nn.GRU(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> output, hn = rnn(input, h0)

http://www.kler.cn/a/4263.html

相关文章:

  • python之二维几何学习笔记
  • 关于高级工程师的想法
  • 【Rust自学】13.2. 闭包 Pt.2:闭包的类型推断和标注
  • 《探索烟雾目标检测开源项目:技术与应用的深度剖析》
  • 《Keras 3 在 TPU 上的肺炎分类》
  • 45_Lua模块与包
  • 渗透测试之冰蝎实战
  • 使用 Alluxio 优化 EMR 上 Flink Join
  • 有钱还系统开发|有钱还系统顾头不顾尾?最后的人会受伤害?
  • Thinkphp 6.0路由的域名和跨域请求
  • TS常用数据类型(TypeScript常用数据类型,ts常用数据类型和js常用数据类型的区别)
  • 前端面试题之html css篇
  • Spring MVC 启动之 Handler 揭秘
  • C#学习 Day2
  • gunicorn启动flask输出调试信息
  • CocosCreator实战篇 | 实现刮刮卡和橡皮擦 | 擦除效果
  • Mysql语句复习
  • 安装及使用本地Maven仓库
  • 面经-javascript基础
  • get和post的区别
  • python使用正则表达式re
  • Linux 多线程:多线程和多进程的对比
  • ChatGPT再掀AI资本狂潮,30位科技创新VC投资者齐聚“实在智能”
  • 2023年科睿唯安官方剔除的35本SCI清单
  • 【CSS】清除浮动 ③ ( 清除浮动 - 使用 after 伪元素 | 语法简介 | 兼容低版本浏览器 | 原理分析 )
  • 深度好文,无代码平台如何解决软件开发成本居高不下?