(arxiv 2024)即插即用多尺度注意力聚合模块MSAA,即用即起飞
题目:CM-UNet: Hybrid CNN-Mamba UNet for Remote Sensing Image Semantic Segmentation
论文地址:https://arxiv.org/pdf/2405.10530
创新点
-
提出CM-UNet框架:基于Mamba架构的CM-UNet框架,通过整合CNN和Mamba模块,能够在遥感图像语义分割任务中高效捕捉局部和全局信息。
-
设计CSMamba块:CSMamba块结合了通道和空间注意力机制,将Mamba模块扩展为能够处理图像长程依赖的组件,提升了特征选择和信息融合的精度。
-
多尺度注意力聚合模块(MSAA):引入MSAA模块,聚合编码器的多尺度特征,通过空间和通道的双重聚合提高特征表达能力,替代传统的跳跃连接,更好地支持解码器的多层次信息融合。
-
多输出监督机制:在解码器的不同层次引入多输出监督,确保各层次逐步细化分割图,从而提升最终分割精度。
方法
整体结构
CM-UNet模型结构由ResNet编码器、多尺度注意力聚合模块(MSAA)和CSMamba解码器组成。编码器负责提取多层次特征,MSAA模块融合多尺度特征以增强表达,解码器则利用CSMamba块通过通道和空间注意力机制高效捕捉长程依赖关系,最终生成精细的分割图,并在各层解码器中加入多输出监督以优化分割结果。
-
CNN编码器:采用ResNet结构作为编码器,用于提取多层次的特征信息。与传统UNet不同,CM-UNet的编码器使用的是多尺度特征提取,以便为后续模块提供更丰富的上下文信息。
-
多尺度注意力聚合模块(MSAA):在编码器和解码器之间,使用MSAA模块对多尺度特征进行聚合。这个模块通过空间和通道注意力机制对不同尺度的特征进行融合,增强特征表达能力,并取代了UNet中的跳跃连接。
-
CSMamba解码器:解码器采用CSMamba块,该模块结合通道和空间注意力机制,以及Mamba结构的线性时间复杂度特性,能够有效捕获图像的长程依赖关系。CSMamba解码器逐步上采样特征并生成输出分割图。同时引入多输出监督,在各层解码器中加入监督信号,确保分割图在不同层次上得到精细化生成。
即插即用模块作用
MSAA 作为一个即插即用模块,主要适用于:
-
多尺度图像特征融合:适合需要在不同尺度上捕捉细节和背景特征的任务,例如遥感图像分割。
-
高分辨率图像分析:在高分辨率图像中识别微小区域和复杂结构时,MSAA能帮助提升模型的特征分辨力。
-
语义分割和目标检测:适用于需要精确识别物体边界和细节的场景,例如自动驾驶和城市规划中的图像语义分割任务。
消融实验结果
-
消融实验对比了在不同模块组合下的性能,主要包括是否引入多尺度注意力聚合模块(MSAA)和多输出监督模块。实验结果表明,单独使用MSAA或多输出监督模块都能提高分割性能,而两者结合使用时,模型的mIoU、mF1和整体准确率(OA)达到了最佳。这表明MSAA和多输出监督模块的协同作用显著提升了模型在遥感图像语义分割中的表现。
即插即用模块
import torch
import torch.nn as nn
# 论文:CM-UNet: Hybrid CNN-Mamba UNet for Remote Sensing Image Semantic Segmentation
# 论文地址:https://arxiv.org/pdf/2405.10530
class ChannelAttentionModule(nn.Module):
def __init__(self, in_channels, reduction=4):
super(ChannelAttentionModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
out = avg_out + max_out
return self.sigmoid(out)
class SpatialAttentionModule(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttentionModule, self).__init__()
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)
class FusionConv(nn.Module):
def __init__(self, in_channels, out_channels, factor=4.0):
super(FusionConv, self).__init__()
dim = int(out_channels // factor)
self.down = nn.Conv2d(in_channels, dim, kernel_size=1, stride=1)
self.conv_3x3 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
self.conv_5x5 = nn.Conv2d(dim, dim, kernel_size=5, stride=1, padding=2)
self.conv_7x7 = nn.Conv2d(dim, dim, kernel_size=7, stride=1, padding=3)
self.spatial_attention = SpatialAttentionModule()
self.channel_attention = ChannelAttentionModule(dim)
self.up = nn.Conv2d(dim, out_channels, kernel_size=1, stride=1)
self.down_2 = nn.Conv2d(in_channels, dim, kernel_size=1, stride=1)
def forward(self, x1, x2, x4):
x_fused = torch.cat([x1, x2, x4], dim=1)
x_fused = self.down(x_fused)
x_fused_c = x_fused * self.channel_attention(x_fused)
x_3x3 = self.conv_3x3(x_fused)
x_5x5 = self.conv_5x5(x_fused)
x_7x7 = self.conv_7x7(x_fused)
x_fused_s = x_3x3 + x_5x5 + x_7x7
x_fused_s = x_fused_s * self.spatial_attention(x_fused_s)
x_out = self.up(x_fused_s + x_fused_c)
return x_out
class MSAA(nn.Module):
def __init__(self, in_channels, out_channels):
super(MSAA, self).__init__()
self.fusion_conv = FusionConv(in_channels * 3, out_channels)
def forward(self, x1, x2, x4, last=False):
# # x2 是从低到高,x4是从高到低的设计,x2传递语义信息,x4传递边缘问题特征补充
# x_1_2_fusion = self.fusion_1x2(x1, x2)
# x_1_4_fusion = self.fusion_1x4(x1, x4)
# x_fused = x_1_2_fusion + x_1_4_fusion
x_fused = self.fusion_conv(x1, x2, x4)
return x_fused
if __name__ == '__main__':
block = MSAA(in_channels=64, out_channels=128)
x1 = torch.randn(1, 64, 64, 64)
x2 = torch.randn(1, 64, 64, 64)
x4 = torch.randn(1, 64, 64, 64)
output = block(x1, x2, x4)
# Print the shapes of the inputs and the output
print(x1.size())
print(x2.size())
print(x4.size()) print(output.size())