当前位置: 首页 > article >正文

(arxiv 2024)即插即用多尺度注意力聚合模块MSAA,即用即起飞

题目:CM-UNet: Hybrid CNN-Mamba UNet for Remote Sensing Image Semantic Segmentation

论文地址:https://arxiv.org/pdf/2405.10530

创新点

  • 提出CM-UNet框架:基于Mamba架构的CM-UNet框架,通过整合CNN和Mamba模块,能够在遥感图像语义分割任务中高效捕捉局部和全局信息。

  • 设计CSMamba块:CSMamba块结合了通道和空间注意力机制,将Mamba模块扩展为能够处理图像长程依赖的组件,提升了特征选择和信息融合的精度。

  • 多尺度注意力聚合模块(MSAA):引入MSAA模块,聚合编码器的多尺度特征,通过空间和通道的双重聚合提高特征表达能力,替代传统的跳跃连接,更好地支持解码器的多层次信息融合。

  • 多输出监督机制:在解码器的不同层次引入多输出监督,确保各层次逐步细化分割图,从而提升最终分割精度。

方法

整体结构

       CM-UNet模型结构由ResNet编码器、多尺度注意力聚合模块(MSAA)和CSMamba解码器组成。编码器负责提取多层次特征,MSAA模块融合多尺度特征以增强表达,解码器则利用CSMamba块通过通道和空间注意力机制高效捕捉长程依赖关系,最终生成精细的分割图,并在各层解码器中加入多输出监督以优化分割结果。

  • CNN编码器:采用ResNet结构作为编码器,用于提取多层次的特征信息。与传统UNet不同,CM-UNet的编码器使用的是多尺度特征提取,以便为后续模块提供更丰富的上下文信息。

  • 多尺度注意力聚合模块(MSAA):在编码器和解码器之间,使用MSAA模块对多尺度特征进行聚合。这个模块通过空间和通道注意力机制对不同尺度的特征进行融合,增强特征表达能力,并取代了UNet中的跳跃连接。

  • CSMamba解码器:解码器采用CSMamba块,该模块结合通道和空间注意力机制,以及Mamba结构的线性时间复杂度特性,能够有效捕获图像的长程依赖关系。CSMamba解码器逐步上采样特征并生成输出分割图。同时引入多输出监督,在各层解码器中加入监督信号,确保分割图在不同层次上得到精细化生成。

即插即用模块作用

MSAA 作为一个即插即用模块,主要适用于:

  • 多尺度图像特征融合:适合需要在不同尺度上捕捉细节和背景特征的任务,例如遥感图像分割。

  • 高分辨率图像分析:在高分辨率图像中识别微小区域和复杂结构时,MSAA能帮助提升模型的特征分辨力。

  • 语义分割和目标检测:适用于需要精确识别物体边界和细节的场景,例如自动驾驶和城市规划中的图像语义分割任务。

消融实验结果

  • 消融实验对比了在不同模块组合下的性能,主要包括是否引入多尺度注意力聚合模块(MSAA)和多输出监督模块。实验结果表明,单独使用MSAA或多输出监督模块都能提高分割性能,而两者结合使用时,模型的mIoU、mF1和整体准确率(OA)达到了最佳。这表明MSAA和多输出监督模块的协同作用显著提升了模型在遥感图像语义分割中的表现。

即插即用模块

import torch
import torch.nn as nn

# 论文:CM-UNet: Hybrid CNN-Mamba UNet for Remote Sensing Image Semantic Segmentation
# 论文地址:https://arxiv.org/pdf/2405.10530


class ChannelAttentionModule(nn.Module):
    def __init__(self, in_channels, reduction=4):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttentionModule(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttentionModule, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class FusionConv(nn.Module):
    def __init__(self, in_channels, out_channels, factor=4.0):
        super(FusionConv, self).__init__()
        dim = int(out_channels // factor)
        self.down = nn.Conv2d(in_channels, dim, kernel_size=1, stride=1)
        self.conv_3x3 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.conv_5x5 = nn.Conv2d(dim, dim, kernel_size=5, stride=1, padding=2)
        self.conv_7x7 = nn.Conv2d(dim, dim, kernel_size=7, stride=1, padding=3)
        self.spatial_attention = SpatialAttentionModule()
        self.channel_attention = ChannelAttentionModule(dim)
        self.up = nn.Conv2d(dim, out_channels, kernel_size=1, stride=1)
        self.down_2 = nn.Conv2d(in_channels, dim, kernel_size=1, stride=1)

    def forward(self, x1, x2, x4):
        x_fused = torch.cat([x1, x2, x4], dim=1)
        x_fused = self.down(x_fused)
        x_fused_c = x_fused * self.channel_attention(x_fused)
        x_3x3 = self.conv_3x3(x_fused)
        x_5x5 = self.conv_5x5(x_fused)
        x_7x7 = self.conv_7x7(x_fused)
        x_fused_s = x_3x3 + x_5x5 + x_7x7
        x_fused_s = x_fused_s * self.spatial_attention(x_fused_s)

        x_out = self.up(x_fused_s + x_fused_c)

        return x_out

class MSAA(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MSAA, self).__init__()
        self.fusion_conv = FusionConv(in_channels * 3, out_channels)

    def forward(self, x1, x2, x4, last=False):
        # # x2 是从低到高,x4是从高到低的设计,x2传递语义信息,x4传递边缘问题特征补充
        # x_1_2_fusion = self.fusion_1x2(x1, x2)
        # x_1_4_fusion = self.fusion_1x4(x1, x4)
        # x_fused = x_1_2_fusion + x_1_4_fusion
        x_fused = self.fusion_conv(x1, x2, x4)
        return x_fused


if __name__ == '__main__':

    block = MSAA(in_channels=64, out_channels=128)
    x1 = torch.randn(1, 64, 64, 64)
    x2 = torch.randn(1, 64, 64, 64)
    x4 = torch.randn(1, 64, 64, 64)

    output = block(x1, x2, x4)

    # Print the shapes of the inputs and the output
    print(x1.size())
    print(x2.size())
    print(x4.size())    print(output.size())

http://www.kler.cn/a/372864.html

相关文章:

  • 利用摄像机实时接入分析平台LiteAIServer视频智能分析软件进行视频监控:过亮过暗检测算法详解
  • 同声传译器什么好用?哪款是你的会议利器推荐榜?
  • CheckPointUtilsTest
  • 学习threejs,使用粒子实现下雪特效
  • HTML5 应用程序缓存
  • Kafka-代码示例
  • ubuntu双屏只显示一个屏幕另一个黑屏
  • PowerBI 根据条件选择获得不同的表格 因为IF和SWITCH只能返回标量而不能返回表格 Power BI
  • 《Python游戏编程入门》注-第4章3
  • Scala 的trait
  • 钉钉平台开发小程序
  • Linux 常用命令二
  • 空间音频技术
  • 计算机视觉常用数据集Foggy Cityscapes的介绍、下载、转为YOLO格式进行训练
  • WinUI AOT 发布
  • 输电线路云台变焦视频监控装置在智能识别和数据安全方面有哪些具体的优势和措施?
  • 【设计模式系列】代理模式(八)
  • python爬虫抓取豆瓣数据教程
  • redis:基本全局命令-键管理(1)
  • 同WiFi网络情况下,多个手机怎么实现不同城市的IP
  • MATLAB下的四个模型的IMM例程(CV、CT左转、CT右转、CA四个模型),附源代码可复制
  • yocto 下基于SDK的 tcpdump 移植
  • 爬虫利器playwright
  • ts:常见的内置数学方法(Math)
  • Java项目练习——学生管理系统
  • MR30分布式IO:石化行业的智能化革新