当前位置: 首页 > article >正文

探索 PyTorch 中的 ConvTranspose2d 及其转置卷积家族

探索 PyTorch 中的 ConvTranspose2d 及其转置卷积家族

在深度学习领域,尤其是图像处理任务中,卷积神经网络(CNN)扮演着重要角色。而当我们需要在网络中进行上采样(Upsampling)时,转置卷积(Transpose Convolution)就成为了不可或缺的工具。今天,我们以 PyTorch 中的 ConvTranspose2d 为核心,深入探讨它的功能、使用方式,并介绍它的“家族成员”——其他转置卷积相关函数。

什么是 ConvTranspose2d?

ConvTranspose2d 是 PyTorch 中 torch.nn 模块提供的一个二维转置卷积层,也常被称为“反卷积”(Deconvolution),尽管这个名称在学术上并不完全准确。它的本质是通过卷积操作将输入特征图的空间尺寸(宽和高)放大,通常用于上采样任务。

与普通的卷积层(Conv2d)将输入特征图尺寸缩小的功能相反,ConvTranspose2d 的主要作用是:

  1. 上采样:增大特征图的空间分辨率。
  2. 特征恢复:在解码器(如 U-Net 的扩展路径)中恢复细节信息。
  3. 生成任务:在生成对抗网络(GAN)等模型中生成高分辨率输出。

在 U-Net 等分割网络中,ConvTranspose2d 常用于“扩展路径”,通过放大特征图并结合跳跃连接(Skip Connection)逐步重建输入图像的细节。

定义与参数

ConvTranspose2d 的基本定义如下:

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)
  • in_channels:输入特征图的通道数。
  • out_channels:输出特征图的通道数。
  • kernel_size:卷积核的大小,例如 2(2, 2)
  • stride:卷积核滑动的步幅,默认值为 1,增大步幅会显著放大输出尺寸。
  • padding:输入边缘填充的像素数,默认值为 0。
  • output_padding:调整输出尺寸的额外填充,用于精确控制输出大小。
  • groups:分组卷积的组数,默认值为 1。
  • bias:是否添加偏置项,默认值为 True
  • dilation:卷积核元素之间的间距,默认值为 1。

工作原理

转置卷积的核心思想是将普通卷积的“前向过程”反转。普通卷积通过卷积核滑动和加权求和缩小特征图,而转置卷积则通过在输入特征之间插入零(即“稀疏化”),再应用卷积核,生成更大的输出特征图。这种操作可以看作是对输入特征的“放大重建”。

例如,输入一个 2x2 的特征图,使用 stride=2 的转置卷积,输出尺寸会变为 4x4(具体尺寸还与 kernel_sizepadding 有关)。

使用示例

以下是一个简单的例子:

import torch
import torch.nn as nn

# 定义一个转置卷积层
upconv = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=2)

# 输入张量:1个样本,1个通道,2x2 的特征图
x = torch.tensor([[[[1., 2.],
                    [3., 4.]]]])

# 应用转置卷积
y = upconv(x)
print(y.shape)  # 输出:torch.Size([1, 1, 4, 4])
print(y)

在这个例子中,输入 2x2 的特征图被放大为 4x4,具体输出值取决于卷积核的权重。

ConvTranspose2d 的家族成员

转置卷积并非孤立存在,PyTorch 提供了一系列相关函数,统称为“转置卷积家族”。它们针对不同维度和需求设计,以下是几个常见成员:

1. ConvTranspose1d - 一维转置卷积

  • 功能:对一维序列数据进行上采样。
  • 使用场景:适用于时间序列、音频信号等任务。
  • 示例
upconv1d = nn.ConvTranspose1d(1, 1, kernel_size=2, stride=2)
x = torch.tensor([[[1., 2., 3.]]])  # 1x3 输入
y = upconv1d(x)
print(y.shape)  # 输出:torch.Size([1, 1, 6])

2. ConvTranspose3d - 三维转置卷积

  • 功能:对三维数据(如体视数据)进行上采样。
  • 使用场景:常用于医学影像(如 CT/MRI)或视频处理。
  • 示例
upconv3d = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2)
x = torch.randn(1, 1, 4, 4, 4)  # 4x4x4 输入
y = upconv3d(x)
print(y.shape)  # 输出:torch.Size([1, 1, 8, 8, 8])

3. 与普通卷积的关系

虽然 ConvTranspose2dConv2d 是“互逆”的概念,但它们并非严格的数学逆操作。转置卷积的权重是可学习的,因此它更像是一种参数化的上采样方法,而非简单的“反卷积”。

在 U-Net 中的应用

在U-Net 代码中(详情见笔者的另一篇博客:深入了解 PyTorch 中的 MaxPool2d 及其池化家族函数),ConvTranspose2d 是扩展路径的核心组件:

self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
d3 = self.up3(b)  # 上采样
d3 = torch.cat([d3, e3], dim=1)  # 跳跃连接

补充全部代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=2):
        super(UNet, self).__init__()

        # 收缩路径
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.pool = nn.MaxPool2d(2)

        # 底部
        self.bottom = self.conv_block(256, 512)

        # 扩展路径
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(512, 256)  # 拼接后通道数为 256+256=512
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(128, 64)

        # 输出层
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # 收缩路径
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottom(self.pool(e3))

        # 扩展路径
        d3 = self.up3(b)
        d3 = torch.cat([d3, e3], dim=1)  # 跳跃连接
        d3 = self.dec3(d3)
        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        # 输出
        out = self.out_conv(d1)
        return out

# 测试代码
if __name__ == "__main__":
    model = UNet(in_channels=1, out_channels=2)
    x = torch.randn(1, 1, 572, 572)  # 输入示例:单通道 572x572 图像
    y = model(x)
    print(y.shape)  # 输出:torch.Size([1, 2, 388, 388])

在这里,up3 将底部特征图从 512 通道上采样到 256 通道,同时将空间尺寸放大一倍(例如从 56x56 到 112x112)。随后通过跳跃连接与编码路径的特征图 e3 拼接,进一步恢复细节。

U-Net 的这种设计充分利用了转置卷积的上采样能力,结合跳跃连接保留了低层次特征,使模型在图像分割任务中表现出色。

与其他上采样方法的对比

除了转置卷积,PyTorch 还提供了其他上采样方法,如:

  • nn.Upsample:基于插值(如双线性插值)的上采样,计算简单但缺乏学习能力。
  • nn.MaxUnpool2d:基于池化索引的上采样,需与 MaxPool2d 配合使用。
    相比之下,ConvTranspose2d 的优势在于其卷积核是可训练的,可以根据任务需求学习最佳的上采样方式。

总结

ConvTranspose2d 是深度学习中实现上采样的强大工具,广泛应用于图像分割、生成模型等任务。它通过转置卷积操作放大特征图,并结合可学习的参数提供灵活性。它的家族成员(如 ConvTranspose1dConvTranspose3d)进一步扩展了应用场景,覆盖一维到三维数据。

在实际使用中,选择 ConvTranspose2d 还是其他上采样方法,取决于任务需求:如果需要可学习的特征重建,转置卷积是首选;如果只追求简单放大,则插值方法可能更高效。希望这篇博客能让你对 ConvTranspose2d 及其家族有更清晰的认识!

分析ConvTranspose2dtorch.cat 的作用

代码背景

这是 U-Net 模型中“扩展路径”的一部分,具体代码如下:

self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
d3 = self.up3(b)  # 上采样
d3 = torch.cat([d3, e3], dim=1)  # 跳跃连接

我们需要搞清楚:

  1. ConvTranspose2d(512, 256, kernel_size=2, stride=2) 如何影响通道数和分辨率。
  2. torch.cat([d3, e3], dim=1) 后的通道数和分辨率变化。

输入假设

假设输入 b 是底部特征图(即 self.bottom 的输出),其形状为 [batch_size, 512, H, W],其中:

  • batch_size 是批量大小(通常为 1 或更多)。
  • 512 是通道数。
  • HW 是特征图的空间分辨率(宽和高,例如 56x56)。

在 U-Net 中,底部特征图通常是经过多次池化(MaxPool2d(2))后的结果,因此分辨率较小。


第一步:ConvTranspose2d 的作用

定义
self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
d3 = self.up3(b)
  • in_channels=512:输入通道数为 512。
  • out_channels=256:输出通道数为 256。
  • kernel_size=2:转置卷积核大小为 2x2。
  • stride=2:步幅为 2,表示输出分辨率会放大两倍。
通道数变化
  • 输入 b 的通道数是 512。
  • ConvTranspose2d 通过卷积操作将通道数从 512 减少到 256。
  • 因此,d3 的通道数为 256
分辨率变化

转置卷积的输出尺寸可以通过以下公式计算:

H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding
W_out = (W_in - 1) * stride - 2 * padding + kernel_size + output_padding

默认情况下,padding=0output_padding=0,代入参数:

  • H_in = Hstride = 2kernel_size = 2padding = 0output_padding = 0
  • H_out = (H - 1) * 2 - 2 * 0 + 2 + 0 = 2H - 2 + 2 = 2H
  • 同理,W_out = 2W

所以:

  • 如果输入 b 的分辨率是 H x W(例如 56x56),则 d3 的分辨率变为 2H x 2W(例如 112x112)。
中间结果

经过 d3 = self.up3(b) 后:

  • 通道数:256。
  • 分辨率:2H x 2W(例如 112x112)。
  • d3 的形状为 [batch_size, 256, 2H, 2W]

第二步:torch.cat 的作用

定义
d3 = torch.cat([d3, e3], dim=1)  # 跳跃连接
  • torch.cat 是沿着指定维度(这里是 dim=1,即通道维度)拼接两个张量。
  • d3 是上采样后的特征图,形状为 [batch_size, 256, 2H, 2W]
  • e3 是收缩路径中的特征图(来自 self.enc3),形状需要与 d3 的空间分辨率匹配。
e3 的形状

在 U-Net 中,e3 是编码路径中第三层的输出,经过 self.enc3 = self.conv_block(128, 256) 处理:

  • 通道数为 256
  • 分辨率取决于前面的池化操作。假设输入图像是 572x572:
    • e1:经过卷积后为 568x568(因无填充,572 - 3 + 1 = 568)。
    • e2:池化后为 284x284(568 / 2)。
    • e3:池化后为 142x142(284 / 2)。
    • b:池化后为 71x71(142 / 2,假设向下取整)。
    • d3:上采样后为 142x142(71 * 2)。
  • 所以,e3 的分辨率是 142x142,形状为 [batch_size, 256, 142, 142]

由于 d3 的分辨率(2H x 2W,例如 142x142)与 e3 的分辨率匹配,它们可以在通道维度上拼接。

通道数变化
  • d3 的通道数为 256。
  • e3 的通道数为 256。
  • torch.cat([d3, e3], dim=1) 将两个张量的通道数相加:256 + 256 = 512
分辨率变化
  • torch.cat 只在通道维度上操作,不改变空间分辨率。
  • 因此,分辨率保持为 2H x 2W(例如 142x142)。
最终结果

经过 d3 = torch.cat([d3, e3], dim=1) 后:

  • 通道数:512。
  • 分辨率:2H x 2W(例如 142x142)。
  • d3 的形状为 [batch_size, 512, 2H, 2W]

回答疑问

  1. “将通道变小,但是分辨率加倍”

    • 是的,ConvTranspose2d(512, 256, kernel_size=2, stride=2) 将通道数从 512 减小到 256,同时分辨率从 H x W 加倍到 2H x 2W。这是转置卷积的典型行为:通过减少通道数换取更大的空间尺寸。
  2. “然后 cat 一下呢?通道数不变?分辨率怎么变化”

    • 错了,torch.cat([d3, e3], dim=1) 后通道数会变化,从 256 增加到 512(因为拼接了 e3 的 256 个通道)。
    • 分辨率不变,仍然是 2H x 2W,因为 cat 只影响通道维度,不改变宽和高。

总结

  • 初始 b[batch_size, 512, H, W](例如 [1, 512, 71, 71])。
  • 经过 up3d3 变为 [batch_size, 256, 2H, 2W](例如 [1, 256, 142, 142])。
  • 经过 catd3 变为 [batch_size, 512, 2H, 2W](例如 [1, 512, 142, 142])。

通道数先减半(512 -> 256),分辨率加倍(H x W -> 2H x 2W),然后通过跳跃连接拼接后通道数又加倍(256 -> 512),分辨率保持不变。这正是 U-Net 的设计精髓:通过上采样和跳跃连接逐步恢复空间信息,同时融合多尺度特征。

后记

2025年3月13日15点45分于上海,在Grok 3大模型辅助下完成。


http://www.kler.cn/a/584690.html

相关文章:

  • SolidWorks中文完整版+教程百度云资源分享
  • 【JavaScript 】1. 什么是 Node.js?(JavaScript 服务器环境)
  • 【Flutter】第一次textEditingController.text获取到空字符串
  • 医院本地化DeepSeek R1对接混合数据库技术实战方案研讨
  • 性能优化:服务器性能影响网站加载速度分析
  • 如何从零编写自己的.NET IoT设备驱动
  • 第54天:Web攻防-SQL注入数据类型参数格式JSONXML编码加密符号闭合复盘报告
  • JVM 详解:Java 虚拟机的核心机制
  • k8s中的控制器的使用
  • DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)之添加列宽调整功能,示例Table14_06带搜索功能的固定表头表格
  • Linux C++ 编程死锁详解
  • MyBatis一对多查询方式
  • uniapp实现 uview1 u-button的水波纹效果
  • Jump Desktop for Mac v9.0.94 优秀的远程桌面连接工具 支持M、Intel芯片
  • 数据结构——顺序表seqlist
  • PostgreSQL16 的双向逻辑复制
  • Netty基础—4.NIO的使用简介一
  • 【贪心算法5】
  • 使用DeepSeek完成一个简单嵌入式开发
  • 如何优化AI模型的Prompt:深度指南