探索 PyTorch 中的 ConvTranspose2d 及其转置卷积家族
探索 PyTorch 中的 ConvTranspose2d 及其转置卷积家族
在深度学习领域,尤其是图像处理任务中,卷积神经网络(CNN)扮演着重要角色。而当我们需要在网络中进行上采样(Upsampling)时,转置卷积(Transpose Convolution)就成为了不可或缺的工具。今天,我们以 PyTorch 中的 ConvTranspose2d
为核心,深入探讨它的功能、使用方式,并介绍它的“家族成员”——其他转置卷积相关函数。
什么是 ConvTranspose2d?
ConvTranspose2d
是 PyTorch 中 torch.nn
模块提供的一个二维转置卷积层,也常被称为“反卷积”(Deconvolution),尽管这个名称在学术上并不完全准确。它的本质是通过卷积操作将输入特征图的空间尺寸(宽和高)放大,通常用于上采样任务。
与普通的卷积层(Conv2d
)将输入特征图尺寸缩小的功能相反,ConvTranspose2d
的主要作用是:
- 上采样:增大特征图的空间分辨率。
- 特征恢复:在解码器(如 U-Net 的扩展路径)中恢复细节信息。
- 生成任务:在生成对抗网络(GAN)等模型中生成高分辨率输出。
在 U-Net 等分割网络中,ConvTranspose2d
常用于“扩展路径”,通过放大特征图并结合跳跃连接(Skip Connection)逐步重建输入图像的细节。
定义与参数
ConvTranspose2d
的基本定义如下:
torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)
in_channels
:输入特征图的通道数。out_channels
:输出特征图的通道数。kernel_size
:卷积核的大小,例如2
或(2, 2)
。stride
:卷积核滑动的步幅,默认值为 1,增大步幅会显著放大输出尺寸。padding
:输入边缘填充的像素数,默认值为 0。output_padding
:调整输出尺寸的额外填充,用于精确控制输出大小。groups
:分组卷积的组数,默认值为 1。bias
:是否添加偏置项,默认值为True
。dilation
:卷积核元素之间的间距,默认值为 1。
工作原理
转置卷积的核心思想是将普通卷积的“前向过程”反转。普通卷积通过卷积核滑动和加权求和缩小特征图,而转置卷积则通过在输入特征之间插入零(即“稀疏化”),再应用卷积核,生成更大的输出特征图。这种操作可以看作是对输入特征的“放大重建”。
例如,输入一个 2x2 的特征图,使用 stride=2
的转置卷积,输出尺寸会变为 4x4(具体尺寸还与 kernel_size
和 padding
有关)。
使用示例
以下是一个简单的例子:
import torch
import torch.nn as nn
# 定义一个转置卷积层
upconv = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=2)
# 输入张量:1个样本,1个通道,2x2 的特征图
x = torch.tensor([[[[1., 2.],
[3., 4.]]]])
# 应用转置卷积
y = upconv(x)
print(y.shape) # 输出:torch.Size([1, 1, 4, 4])
print(y)
在这个例子中,输入 2x2 的特征图被放大为 4x4,具体输出值取决于卷积核的权重。
ConvTranspose2d 的家族成员
转置卷积并非孤立存在,PyTorch 提供了一系列相关函数,统称为“转置卷积家族”。它们针对不同维度和需求设计,以下是几个常见成员:
1. ConvTranspose1d - 一维转置卷积
- 功能:对一维序列数据进行上采样。
- 使用场景:适用于时间序列、音频信号等任务。
- 示例:
upconv1d = nn.ConvTranspose1d(1, 1, kernel_size=2, stride=2)
x = torch.tensor([[[1., 2., 3.]]]) # 1x3 输入
y = upconv1d(x)
print(y.shape) # 输出:torch.Size([1, 1, 6])
2. ConvTranspose3d - 三维转置卷积
- 功能:对三维数据(如体视数据)进行上采样。
- 使用场景:常用于医学影像(如 CT/MRI)或视频处理。
- 示例:
upconv3d = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2)
x = torch.randn(1, 1, 4, 4, 4) # 4x4x4 输入
y = upconv3d(x)
print(y.shape) # 输出:torch.Size([1, 1, 8, 8, 8])
3. 与普通卷积的关系
虽然 ConvTranspose2d
和 Conv2d
是“互逆”的概念,但它们并非严格的数学逆操作。转置卷积的权重是可学习的,因此它更像是一种参数化的上采样方法,而非简单的“反卷积”。
在 U-Net 中的应用
在U-Net 代码中(详情见笔者的另一篇博客:深入了解 PyTorch 中的 MaxPool2d 及其池化家族函数),ConvTranspose2d
是扩展路径的核心组件:
self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
d3 = self.up3(b) # 上采样
d3 = torch.cat([d3, e3], dim=1) # 跳跃连接
补充全部代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
def __init__(self, in_channels=1, out_channels=2):
super(UNet, self).__init__()
# 收缩路径
self.enc1 = self.conv_block(in_channels, 64)
self.enc2 = self.conv_block(64, 128)
self.enc3 = self.conv_block(128, 256)
self.pool = nn.MaxPool2d(2)
# 底部
self.bottom = self.conv_block(256, 512)
# 扩展路径
self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.dec3 = self.conv_block(512, 256) # 拼接后通道数为 256+256=512
self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.dec2 = self.conv_block(256, 128)
self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.dec1 = self.conv_block(128, 64)
# 输出层
self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)
def conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
nn.ReLU(inplace=True)
)
def forward(self, x):
# 收缩路径
e1 = self.enc1(x)
e2 = self.enc2(self.pool(e1))
e3 = self.enc3(self.pool(e2))
b = self.bottom(self.pool(e3))
# 扩展路径
d3 = self.up3(b)
d3 = torch.cat([d3, e3], dim=1) # 跳跃连接
d3 = self.dec3(d3)
d2 = self.up2(d3)
d2 = torch.cat([d2, e2], dim=1)
d2 = self.dec2(d2)
d1 = self.up1(d2)
d1 = torch.cat([d1, e1], dim=1)
d1 = self.dec1(d1)
# 输出
out = self.out_conv(d1)
return out
# 测试代码
if __name__ == "__main__":
model = UNet(in_channels=1, out_channels=2)
x = torch.randn(1, 1, 572, 572) # 输入示例:单通道 572x572 图像
y = model(x)
print(y.shape) # 输出:torch.Size([1, 2, 388, 388])
在这里,up3
将底部特征图从 512 通道上采样到 256 通道,同时将空间尺寸放大一倍(例如从 56x56 到 112x112)。随后通过跳跃连接与编码路径的特征图 e3
拼接,进一步恢复细节。
U-Net 的这种设计充分利用了转置卷积的上采样能力,结合跳跃连接保留了低层次特征,使模型在图像分割任务中表现出色。
与其他上采样方法的对比
除了转置卷积,PyTorch 还提供了其他上采样方法,如:
nn.Upsample
:基于插值(如双线性插值)的上采样,计算简单但缺乏学习能力。nn.MaxUnpool2d
:基于池化索引的上采样,需与MaxPool2d
配合使用。
相比之下,ConvTranspose2d
的优势在于其卷积核是可训练的,可以根据任务需求学习最佳的上采样方式。
总结
ConvTranspose2d
是深度学习中实现上采样的强大工具,广泛应用于图像分割、生成模型等任务。它通过转置卷积操作放大特征图,并结合可学习的参数提供灵活性。它的家族成员(如 ConvTranspose1d
和 ConvTranspose3d
)进一步扩展了应用场景,覆盖一维到三维数据。
在实际使用中,选择 ConvTranspose2d
还是其他上采样方法,取决于任务需求:如果需要可学习的特征重建,转置卷积是首选;如果只追求简单放大,则插值方法可能更高效。希望这篇博客能让你对 ConvTranspose2d
及其家族有更清晰的认识!
分析ConvTranspose2d
和 torch.cat
的作用
代码背景
这是 U-Net 模型中“扩展路径”的一部分,具体代码如下:
self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
d3 = self.up3(b) # 上采样
d3 = torch.cat([d3, e3], dim=1) # 跳跃连接
我们需要搞清楚:
ConvTranspose2d(512, 256, kernel_size=2, stride=2)
如何影响通道数和分辨率。torch.cat([d3, e3], dim=1)
后的通道数和分辨率变化。
输入假设
假设输入 b
是底部特征图(即 self.bottom
的输出),其形状为 [batch_size, 512, H, W]
,其中:
batch_size
是批量大小(通常为 1 或更多)。512
是通道数。H
和W
是特征图的空间分辨率(宽和高,例如 56x56)。
在 U-Net 中,底部特征图通常是经过多次池化(MaxPool2d(2)
)后的结果,因此分辨率较小。
第一步:ConvTranspose2d
的作用
定义
self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
d3 = self.up3(b)
in_channels=512
:输入通道数为 512。out_channels=256
:输出通道数为 256。kernel_size=2
:转置卷积核大小为 2x2。stride=2
:步幅为 2,表示输出分辨率会放大两倍。
通道数变化
- 输入
b
的通道数是 512。 ConvTranspose2d
通过卷积操作将通道数从 512 减少到 256。- 因此,
d3
的通道数为 256。
分辨率变化
转置卷积的输出尺寸可以通过以下公式计算:
H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding
W_out = (W_in - 1) * stride - 2 * padding + kernel_size + output_padding
默认情况下,padding=0
和 output_padding=0
,代入参数:
H_in = H
,stride = 2
,kernel_size = 2
,padding = 0
,output_padding = 0
。H_out = (H - 1) * 2 - 2 * 0 + 2 + 0 = 2H - 2 + 2 = 2H
。- 同理,
W_out = 2W
。
所以:
- 如果输入
b
的分辨率是H x W
(例如 56x56),则d3
的分辨率变为 2H x 2W(例如 112x112)。
中间结果
经过 d3 = self.up3(b)
后:
- 通道数:256。
- 分辨率:2H x 2W(例如 112x112)。
d3
的形状为[batch_size, 256, 2H, 2W]
。
第二步:torch.cat
的作用
定义
d3 = torch.cat([d3, e3], dim=1) # 跳跃连接
torch.cat
是沿着指定维度(这里是dim=1
,即通道维度)拼接两个张量。d3
是上采样后的特征图,形状为[batch_size, 256, 2H, 2W]
。e3
是收缩路径中的特征图(来自self.enc3
),形状需要与d3
的空间分辨率匹配。
e3 的形状
在 U-Net 中,e3
是编码路径中第三层的输出,经过 self.enc3 = self.conv_block(128, 256)
处理:
- 通道数为 256。
- 分辨率取决于前面的池化操作。假设输入图像是 572x572:
e1
:经过卷积后为 568x568(因无填充,572 - 3 + 1 = 568)。e2
:池化后为 284x284(568 / 2)。e3
:池化后为 142x142(284 / 2)。b
:池化后为 71x71(142 / 2,假设向下取整)。d3
:上采样后为 142x142(71 * 2)。
- 所以,
e3
的分辨率是 142x142,形状为[batch_size, 256, 142, 142]
。
由于 d3
的分辨率(2H x 2W,例如 142x142)与 e3
的分辨率匹配,它们可以在通道维度上拼接。
通道数变化
d3
的通道数为 256。e3
的通道数为 256。torch.cat([d3, e3], dim=1)
将两个张量的通道数相加:256 + 256 = 512。
分辨率变化
torch.cat
只在通道维度上操作,不改变空间分辨率。- 因此,分辨率保持为 2H x 2W(例如 142x142)。
最终结果
经过 d3 = torch.cat([d3, e3], dim=1)
后:
- 通道数:512。
- 分辨率:2H x 2W(例如 142x142)。
d3
的形状为[batch_size, 512, 2H, 2W]
。
回答疑问
-
“将通道变小,但是分辨率加倍”:
- 是的,
ConvTranspose2d(512, 256, kernel_size=2, stride=2)
将通道数从 512 减小到 256,同时分辨率从 H x W 加倍到 2H x 2W。这是转置卷积的典型行为:通过减少通道数换取更大的空间尺寸。
- 是的,
-
“然后 cat 一下呢?通道数不变?分辨率怎么变化”:
- 错了,
torch.cat([d3, e3], dim=1)
后通道数会变化,从 256 增加到 512(因为拼接了e3
的 256 个通道)。 - 分辨率不变,仍然是 2H x 2W,因为
cat
只影响通道维度,不改变宽和高。
- 错了,
总结
- 初始 b:
[batch_size, 512, H, W]
(例如 [1, 512, 71, 71])。 - 经过 up3:
d3
变为[batch_size, 256, 2H, 2W]
(例如 [1, 256, 142, 142])。 - 经过 cat:
d3
变为[batch_size, 512, 2H, 2W]
(例如 [1, 512, 142, 142])。
通道数先减半(512 -> 256),分辨率加倍(H x W -> 2H x 2W),然后通过跳跃连接拼接后通道数又加倍(256 -> 512),分辨率保持不变。这正是 U-Net 的设计精髓:通过上采样和跳跃连接逐步恢复空间信息,同时融合多尺度特征。
后记
2025年3月13日15点45分于上海,在Grok 3大模型辅助下完成。