当前位置: 首页 > article >正文

torch.nn.Sequential介绍

torch.nn.Sequential 是 PyTorch 中一个模块容器,用于将一系列层或模块按顺序连接在一起,简化前向传播过程。在 Sequential 中,所有的子模块会按照添加的顺序被执行,适合那些有明确顺序的神经网络结构,比如卷积神经网络、全连接网络等。

主要特点

  • 按顺序执行: 将多个子模块按顺序组合,前向传播时依次调用。
  • 简洁代码: 减少显式定义 forward 方法的需求,对于简单的网络结构,使用 Sequential 可以大大简化代码。
  • 嵌套支持Sequential 容器可以嵌套,允许将多个 Sequential 容器嵌套在一起。

使用方式

  1. 直接传入模块: 可以通过将模块按顺序传入 Sequential
  2. 有序字典: 可以使用 OrderedDict 来为每个模块指定名字。

基本用法

1. 直接传入模块

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
)

input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)

在这个例子中,Sequential 中包含了两个 Linear 层和一个 ReLU 激活函数,前向传播时,输入会依次通过这些层。

2. 使用 OrderedDict

from collections import OrderedDict
import torch
import torch.nn as nn

model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(10, 20)),
    ('relu', nn.ReLU()),
    ('fc2', nn.Linear(20, 5))
]))

input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)

在这里,使用 OrderedDict 给每个层指定了名字,方便在访问时更具可读性。

访问子模块

可以通过索引或名称访问 Sequential 中的子模块。

# 按索引访问
print(model[0])  # Linear(10, 20)

# 按名称访问(如果使用了 OrderedDict)
print(model.fc1)  # Linear(10, 20)

嵌套使用

Sequential 容器可以嵌套其他 Sequential 容器或其他模块,形成更复杂的模型结构。

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Sequential(
        nn.Linear(20, 30),
        nn.ReLU()
    ),
    nn.Linear(30, 5)
)

input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)

优势

  • 简洁明了Sequential 适合那些模型结构比较清晰、需要按顺序堆叠层的神经网络模型,避免了手动写 forward 函数。
  • 易于嵌套: 可以将复杂的子结构封装为单独的 Sequential,方便组合复杂的网络。

注意事项

  • 不能处理复杂的前向传播逻辑:Sequential 仅适用于简单的前向传播操作,如果有条件判断或多个输入/输出的情况,仍然需要手动定义 forward 函数。

http://www.kler.cn/news/356404.html

相关文章:

  • 线性可分支持向量机的原理推导 最大化几何间隔d 公式解析
  • D36【python 接口自动化学习】- python基础之函数
  • VUE 开发——Vue学习(四)—— 智慧商城项目
  • Javascript中的堆内存和栈内存
  • mysql--数据类型
  • 前端vue项目使用Decimal.js做加减乘除求余运算
  • C++20中头文件source_location的使用
  • 大数据学习-Clickhouse
  • 数据结构——链表,哈希表
  • makefile和make
  • JavaWeb学习(3)
  • [项目详解][boost搜索引擎#1] 概述 | 去标签 | 数据清洗 | scp
  • 024 elasticsearch集群
  • 生财合伙人推荐 - 鞠海深-群控
  • 霍夫圆型硬币检测Matlab程序
  • GitHub与GitCode
  • vuefor循环动态展示图片不显示
  • ARM指令集和汇编语言的关联学习
  • 设计模式——代理模式(6)
  • 408算法题leetcode--第33天