当前位置: 首页 > article >正文

【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!

【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!

【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!


文章目录

  • 【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!
  • 前言
    • 1. 多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块
    • `forward` 前向传播函数:
    • 2. 块级模块 Block
    • `forward` 前向传播函数:
    • 3. 融合模块 Fusion
    • `forward` 前向传播函数:


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

论文地址:https://ieeexplore.ieee.org/document/10247595

前言

在这里插入图片描述
该代码实现了一个多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块 Mutilscal_MHSA、一个块级模块 Block 以及一个融合模块 Fusion。此代码用于遥感图像语义分割模型 CMTFNet 中,主要通过多尺度卷积、MHSA 和融合机制增强图像特征提取。以下是逐行代码解析:

在这里插入图片描述

1. 多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块

class Mutilscal_MHSA(nn.Module):
    def __init__(self, dim, num_heads, atten_drop = 0., proj_drop = 0., dilation = [3, 5, 7], fc_ratio=4, pool_ratio=16):
        super(Mutilscal_MHSA, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.atten_drop = nn.Dropout(atten_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.MSC = MutilScal(dim=dim, fc_ratio=fc_ratio, dilation=dilation, pool_ratio=pool_ratio)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim//fc_ratio, kernel_size=1),
            nn.ReLU6(),
            nn.Conv2d(in_channels=dim//fc_ratio, out_channels=dim, kernel_size=1),
            nn.Sigmoid()
        )
        self.kv = Conv(dim, 2 * dim, 1)
  • __init__ 构造函数:
  • super(Mutilscal_MHSA, self).__init__(): 初始化父类 nn.Module
  • assert dim % num_heads == 0: 确保特征维度 dim 可被头数 num_heads 整除。
  • self.dim、self.num_heads: 初始化维度和多头数量。
  • head_dim = dim // num_heads: 每个头的维度大小。
  • self.scale = head_dim ** -0.5: 计算缩放因子,用于稳定点积结果。
  • self.atten_drop、self.proj_drop: 设置注意力和投影的 dropout 层。
  • self.MSC = MutilScal(...): 多尺度卷积模块,用于提取多尺度特征。
  • self.avgpool = nn.AdaptiveAvgPool2d(1): 全局平均池化,将特征图缩小至 (1,1)。
  • self.fc = nn.Sequential(...): 两层全连接网络,用于生成通道注意力权重。
  • self.kv = Conv(dim, 2 * dim, 1): 卷积层,将输入特征转换为键值对。

forward 前向传播函数:

    def forward(self, x):
        u = x.clone()
        B, C, H, W = x.shape
        kv = self.MSC(x)
        kv = self.kv(kv)

        B1, C1, H1, W1 = kv.shape

        q = rearrange(x, 'b (h d) (hh) (ww) -> (b) h (hh ww) d', h=self.num_heads,
                      d=C // self.num_heads, hh=H, ww=W)
        k, v = rearrange(kv, 'b (kv h d) (hh) (ww) -> kv (b) h (hh ww) d', h=self.num_heads,
                         d=C // self.num_heads, hh=H1, ww=W1, kv=2)

        dots = (q @ k.transpose(-2, -1)) * self.scale
        attn = dots.softmax(dim=-1)
        attn = self.atten_drop(attn)
        attn = attn @ v

        attn = rearrange(attn, '(b) h (hh ww) d -> b (h d) (hh) (ww)', h=self.num_heads,
                         d=C // self.num_heads, hh=H, ww=W)
        c_attn = self.avgpool(x)
        c_attn = self.fc(c_attn)
        c_attn = c_attn * u
        return attn + c_attn
  • u = x.clone(): 复制输入 x,用于残差连接。
  • B, C, H, W = x.shape: 获取输入张量的维度信息。
  • kv = self.MSC(x): 将输入 x 传入多尺度卷积模块以提取键值特征。
  • kv = self.kv(kv): 使用 kv 卷积层进一步处理特征。
  • B1, C1, H1, W1 = kv.shape: 获取键值特征的维度信息。
  • q = rearrange(...): 重排 xquery 形式,适用于多头自注意力。
  • k, v = rearrange(...): 重排 kv 为键和值形式,适用于多头自注意力。
  • dots = (q @ k.transpose(-2, -1)) * self.scale: 计算缩放的查询键点积。
  • attn = dots.softmax(dim=-1): 计算点积的 softmax,生成注意力权重。
  • attn = self.atten_drop(attn): 应用注意力 dropout。
  • attn = attn @ v: 将注意力权重和值相乘,得到新的特征表示。
  • attn = rearrange(...): 重排 attn 为原始特征形状。
  • c_attn = self.avgpool(x): 对 x 进行全局平均池化。
  • c_attn = self.fc(c_attn): 通过全连接层生成通道注意力权重。
  • c_attn = c_attn * u: 将通道注意力权重与输入 u 相乘。
  • return attn + c_attn: 返回多头自注意力特征和通道注意力特征的和。

2. 块级模块 Block

class Block(nn.Module):
    def __init__(self, dim=512, num_heads=16,  mlp_ratio=4, pool_ratio=16, drop=0., dilation=[3, 5, 7],
                 drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Mutilscal_MHSA(dim, num_heads=num_heads, atten_drop=drop, proj_drop=drop, dilation=dilation,
                                   pool_ratio=pool_ratio, fc_ratio=mlp_ratio)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim // mlp_ratio)

        self.mlp = E_FFN(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer,
                         drop=drop)
  • super().__init__(): 初始化父类。
  • self.norm1 = norm_layer(dim): 归一化层。
  • self.attn = Mutilscal_MHSA(...): 多尺度多头自注意力模块。
  • self.drop_path = DropPath(...): 随机丢弃路径,用于防止过拟合。
  • mlp_hidden_dim = int(dim // mlp_ratio): 计算多层感知机的隐藏层维度。
  • self.mlp = E_FFN(...): 全连接前馈网络。

forward 前向传播函数:

    def forward(self, x):

        x = x + self.drop_path(self.norm1(self.attn(x)))
        x = x + self.drop_path(self.mlp(x))

        return x
  • x = x + self.drop_path(self.norm1(self.attn(x))): 对注意力模块进行归一化、添加残差连接。
  • x = x + self.drop_path(self.mlp(x)): 对全连接层输出添加残差连接。
  • return x: 返回块的输出。

3. 融合模块 Fusion

class Fusion(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super(Fusion, self).__init__()

        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = eps
        self.post_conv = SeparableConvBNReLU(dim, dim, 5)
  • super(Fusion, self).__init__(): 初始化父类。
  • self.weights = nn.Parameter(...): 创建两个可训练的权重参数。
  • self.eps = eps: 用于避免除零的 epsilon。
  • self.post_conv = SeparableConvBNReLU(...): 可分离卷积层,融合后的卷积处理。

forward 前向传播函数:

    def forward(self, x, res):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weights = nn.ReLU6()(self.weights)
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
        x = fuse_weights[0] * res + fuse_weights[1] * x
        x = self.post_conv(x)
        return x
  • x = F.interpolate(...): 上采样 x
  • weights = nn.ReLU6()(self.weights): 对权重参数应用 ReLU6 激活。
  • fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps): 归一化权重。
  • x = fuse_weights[0] * res + fuse_weights[1] * x: 加权融合 xres
  • x = self.post_conv(x): 通过可分离卷积进一步处理。
  • return x: 返回融合后的特征。

这些模块配合在一起实现了多尺度、多头自注意力机制以及融合处理,有效提升遥感图像语义分割性能。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz


http://www.kler.cn/a/384567.html

相关文章:

  • 【wxWidgets GUI设计教程 - 高级布局与窗口管理】
  • JAVA:生成唯一的ID
  • 【每日刷题】Day151
  • 无人车之路径规划篇
  • 微信小程序开发,诗词鉴赏app,诗词推荐实现(二)
  • 少儿编程教育的多维度对比:软件类、硬件类与软硬件结合课程的选择
  • 丹摩征文活动|详解 DAMODEL(丹摩智算)平台:为 AI 开发者量身打造的智算云服务
  • 三周精通FastAPI:30 API、标签元数据和文档 URL
  • 大语言模型训练的全过程:预训练、微调、RLHF
  • Axure设计之左右滚动组件教程(动态面板)
  • ArcGIS Pro SDK (二十三)实时要素类
  • windows、linux安装jmeter及设置中文显示
  • Oracle 23AI创建示例库
  • idea | 搭建 SpringBoot 项目之配置 Maven
  • 第十五届蓝桥杯C/C++B组题解——数字接龙
  • 线性表之链表详解
  • Chrome与火狐哪个浏览器的隐私追踪功能更好
  • 实用篇:简单RTC时钟使用手册!
  • 跨境独立站新手,如何用DuoPlus云手机破局海外社媒引流?
  • C语言 | Leetcode C语言题解之第542题01矩阵
  • 正则表达式在Kotlin中的应用:提取图片链接
  • Istio Gateway发布服务
  • 一文了解Android的Doze模式
  • 前端开发设计模式——原型模式
  • Linux文件系统详解
  • 【Axure高保真原型】视频列表播放器