torch.nn.Sequential介绍
torch.nn.Sequential
是 PyTorch 中一个模块容器,用于将一系列层或模块按顺序连接在一起,简化前向传播过程。在 Sequential
中,所有的子模块会按照添加的顺序被执行,适合那些有明确顺序的神经网络结构,比如卷积神经网络、全连接网络等。
主要特点
- 按顺序执行: 将多个子模块按顺序组合,前向传播时依次调用。
- 简洁代码: 减少显式定义
forward
方法的需求,对于简单的网络结构,使用Sequential
可以大大简化代码。 - 嵌套支持:
Sequential
容器可以嵌套,允许将多个Sequential
容器嵌套在一起。
使用方式
- 直接传入模块: 可以通过将模块按顺序传入
Sequential
。 - 有序字典: 可以使用
OrderedDict
来为每个模块指定名字。
基本用法
1. 直接传入模块
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
)
input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)
在这个例子中,Sequential
中包含了两个 Linear
层和一个 ReLU
激活函数,前向传播时,输入会依次通过这些层。
2. 使用 OrderedDict
from collections import OrderedDict
import torch
import torch.nn as nn
model = nn.Sequential(OrderedDict([
('fc1', nn.Linear(10, 20)),
('relu', nn.ReLU()),
('fc2', nn.Linear(20, 5))
]))
input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)
在这里,使用 OrderedDict
给每个层指定了名字,方便在访问时更具可读性。
访问子模块
可以通过索引或名称访问 Sequential
中的子模块。
# 按索引访问
print(model[0]) # Linear(10, 20)
# 按名称访问(如果使用了 OrderedDict)
print(model.fc1) # Linear(10, 20)
嵌套使用
Sequential
容器可以嵌套其他 Sequential
容器或其他模块,形成更复杂的模型结构。
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Sequential(
nn.Linear(20, 30),
nn.ReLU()
),
nn.Linear(30, 5)
)
input_tensor = torch.randn(1, 10)
output = model(input_tensor)
print(output)
优势
- 简洁明了:
Sequential
适合那些模型结构比较清晰、需要按顺序堆叠层的神经网络模型,避免了手动写forward
函数。 - 易于嵌套: 可以将复杂的子结构封装为单独的
Sequential
,方便组合复杂的网络。
注意事项
- 不能处理复杂的前向传播逻辑:
Sequential
仅适用于简单的前向传播操作,如果有条件判断或多个输入/输出的情况,仍然需要手动定义forward
函数。