当前位置: 首页 > article >正文

torch.nn.Sequential的用法

文章目录

  • 介绍
  • 基本用法
  • 添加命名层
  • 动态添加层
  • 嵌套使用
  • 与自定义前向传播的区别

介绍

torch.nn.Sequential 是 PyTorch 中的一个容器模块,用于将多个神经网络层按顺序组合在一起。它可以让我们以更加简洁的方式定义前向传播的网络结构,适合简单的线性堆叠模型。

基本用法

torch.nn.Sequential 按照定义的顺序将多个层组合在一起,输入数据会依次通过这些层。

import torch.nn as nn  

# 定义一个简单的网络  
model = nn.Sequential(  
    nn.Linear(10, 20),  # 全连接层:输入 10,输出 20  
    nn.ReLU(),          # 激活函数:ReLU  
    nn.Linear(20, 1)    # 全连接层:输入 20,输出 1  
)

当调用 model(input) 时,输入会依次通过 Sequential 中的每一层。

import torch  

input = torch.randn(5, 10)  # 输入:batch_size=5, features=10  
output = model(input)       # 前向传播  
print(output.shape)         # 输出:torch.Size([5, 1])

添加命名层

可以为每一层指定名称,方便后续访问或调试。

model = nn.Sequential(  
    ('fc1', nn.Linear(10, 20)),  # 命名为 'fc1'  
    ('relu1', nn.ReLU()),        # 命名为 'relu1'  
    ('fc2', nn.Linear(20, 1))    # 命名为 'fc2'  
)

通过名称或索引访问某一层:

print(model.fc1)  # 访问名为 'fc1' 的层  
print(model[0])   # 通过索引访问第一层

动态添加层

可以通过 add_module 方法动态添加层。

model = nn.Sequential()  
model.add_module('fc1', nn.Linear(10, 20))  # 添加第一层  
model.add_module('relu1', nn.ReLU())        # 添加激活函数  
model.add_module('fc2', nn.Linear(20, 1))   # 添加第二层

嵌套使用

nn.Sequential 可以嵌套使用,用于构建更复杂的网络。

model = nn.Sequential(  
    nn.Sequential(  
        nn.Linear(10, 20),  
        nn.ReLU()  
    ),  
    nn.Sequential(  
        nn.Linear(20, 10),  
        nn.ReLU()  
    ),  
    nn.Linear(10, 1)  
)

与自定义前向传播的区别

nn.Sequential 适合简单的线性堆叠模型,但如果需要更复杂的前向传播逻辑(如分支、跳跃连接等),需要继承 nn.Module 并自定义 forward 方法。
使用 nn.Sequential

model = nn.Sequential(  
    nn.Linear(10, 20),  
    nn.ReLU(),  
    nn.Linear(20, 1)  
)

自定义 forward

class CustomModel(nn.Module):  
    def __init__(self):  
        super(CustomModel, self).__init__()  
        self.fc1 = nn.Linear(10, 20)  
        self.fc2 = nn.Linear(20, 1)  
        self.relu = nn.ReLU()  

    def forward(self, x):  
        x = self.fc1(x)  
        x = self.relu(x)  
        x = self.fc2(x)  
        return x  

model = CustomModel()

http://www.kler.cn/a/460195.html

相关文章:

  • 实现单例模式的五种方式
  • unity中Timeline动画的播放和播放中如何判断播放结束
  • Flume的安装和使用
  • C++11右值与列表初始化
  • 深入解析 Wireshark 的 TLS 设置:应用场景与实操技巧
  • JS-判断字段值是否为空
  • Markov test笔记
  • 对于爬虫的配置和管理,涉及到的模块和功能主要包括
  • stm32week1+2
  • C++系列之引用
  • SQL 实战:正则表达式匹配 – 高效数据筛选与文本解析
  • 数据库-MySQL-sql有in会走索引吗?(易理解)
  • Java包装类型的缓存
  • solr9.7 单机安装教程
  • Uniapp在浏览器拉起导航
  • 自动驾驶新纪元:城区NOA功能如何成为智能驾驶技术的分水岭
  • (七)- plane/crtc/encoder/connector objects
  • SQL 实战:使用 CTE(公用表达式)优化递归与多层复杂查询
  • Mysql的事务隔离机制
  • 性能与安全测试综合部分
  • 实验八 指针2
  • 常见cms获取Shell漏洞(Wordpress、dedecms、ASPCMS、PhpMyadmin)
  • 深入了解 Zookeeper:原理与应用(选举篇)
  • Supermap iClient Webgl 粒子特效案例-消防场景
  • C++并发:线程管控
  • Android 部分操作(待补充