pytorch nn.Unflatten 和 nn.Flatten模块介绍
nn.Flatten
和 nn.Unflatten
是 PyTorch 中用于调整张量形状的模块。它们提供了对多维张量的简单变换,常用于神经网络模型的层之间的数据调整。
1. nn.Flatten
功能:
- 将输入张量展平为二维张量,通常用于将卷积层的输出展平成全连接层的输入。
- 它会将张量的指定维度范围压缩为单个维度。
构造参数:
start_dim
: 展平的起始维度(默认值为1
)。end_dim
: 展平的结束维度(默认值为-1
)。
用法:
import torch
from torch import nn
# 输入张量: [batch_size, channels, height, width]
x = torch.randn(4, 3, 32, 32)
# 展平操作
flatten = nn.Flatten(start_dim=1) # 从维度1到最后展平
y = flatten(x)
print(y.shape) # 输出: [4, 3072] (3*32*32 被展平)
适用场景:
- 通常用于从卷积层(或其他多维特征)到全连接层的过渡。
- 例如:
[batch_size, channels, height, width] -> [batch_size, features]
。
2. nn.Unflatten
功能:
- 将展平的张量还原为多维张量。
- 它通过指定目标维度和形状信息,反向操作
nn.Flatten
。
构造参数:
dim
: 需要展开的维度。unflattened_size
: 展开的形状(tuple 类型)。
用法:
import torch
from torch import nn
# 输入张量: [batch_size, features]
x = torch.randn(4, 3072)
# 还原操作
unflatten = nn.Unflatten(dim=1, unflattened_size=(3, 32, 32))
y = unflatten(x)
print(y.shape) # 输出: [4, 3, 32, 32]
适用场景:
- 通常用于从全连接层(或展平特征)还原到卷积层或其他多维表示。
- 例如:
[batch_size, features] -> [batch_size, channels, height, width]
。
对比
特性 | nn.Flatten | nn.Unflatten |
---|---|---|
主要操作 | 将多个维度压缩为一个维度 | 将一个维度展开为多个维度 |
输入 | 多维张量 | 展平的张量 |
输出 | 二维张量 | 恢复为多维张量 |
常用场景 | 用于连接卷积层和全连接层 | 用于从展平的特征恢复到多维结构 |
参数控制 | 指定展平的起始和结束维度范围 | 指定需要展开的维度和目标形状 |
实际应用示例
结合使用 Flatten 和 Unflatten:
import torch
from torch import nn
# 初始化 Flatten 和 Unflatten
flatten = nn.Flatten(start_dim=1)
unflatten = nn.Unflatten(dim=1, unflattened_size=(3, 32, 32))
# 模拟数据
x = torch.randn(4, 3, 32, 32) # [batch_size, channels, height, width]
# 展平
flat_x = flatten(x)
print(flat_x.shape) # 输出: [4, 3072]
# 恢复
unflat_x = unflatten(flat_x)
print(unflat_x.shape) # 输出: [4, 3, 32, 32]
这两个模块通过简单的接口提供了灵活的形状调整功能,是构建神经网络过程中不可或缺的工具。