【Block总结】SCSA,探索空间与通道注意力之间的协同效应|即插即用
论文信息
该论文于2025年1月27日发布,探讨了空间注意力和通道注意力的协同作用,提出了一种新的空间与通道协同注意力模块(SCSA)。该模块由可共享多语义空间注意力(SMSA)和渐进通道自注意力(PCSA)组成,旨在提升视觉任务中的特征提取能力。
-
论文链接:https://arxiv.org/pdf/2407.05128
-
GitHub链接:https://github.com/HZAI-ZJNU/SCSA
创新点
- 多语义空间注意力(SMSA):整合多种语义信息,通过渐进压缩策略将空间先验信息注入通道自注意力中。
- 渐进通道自注意力(PCSA):基于通道单头自注意力机制,增强特征交互,缓解多语义信息之间的差异。
- 协同机制:通过空间注意力引导通道注意力的学习,提升模型的整体性能。
方法
SCSA的实现方法包括以下几个步骤:
- 特征分解:将输入特征分解为多个独立的子特征,以便高效提取多语义空间信息。
- 轻量级卷积:在每个子特征内应用不同大小的深度一维卷积,捕获不同的语义空间结构。
- 空间注意力图生成:通过组归一化处理不同的子特征,生成空间注意力图。
- 通道自注意力计算:利用渐进式压缩和单头自注意力机制,计算通道间的相似性并缓解语义差异。
SCSA与其他注意力机制的具体改进
SCSA(Spatial and Channel Synergistic Attention)是一种新型的注意力机制,旨在结合空间注意力和通道注意力的优势,以提升深度学习模型在视觉任务中的表现。与传统的注意力机制相比,SCSA在多个方面进行了改进。
具体改进
-
多语义空间信息的利用
- SCSA通过可共享的多语义空间注意力(SMSA)模块,充分利用了输入图像中的多语义空间信息。这一模块采用多尺度深度共享的1D卷积,能够捕捉到丰富的空间特征,从而增强局部和全局特征的表示能力[1][2]。
-
通道特征的精细化处理
- SCSA中的渐进式通道自注意力(PCSA)模块,通过输入感知的自注意力机制,能够有效地精炼通道特征。这一机制不仅减轻了多语义信息之间的语义差异,还确保了通道特征的稳健整合,从而提升了模型的整体性能[1][2]。
-
协同效应的引入
- SCSA通过将空间注意力和通道注意力模块并行组合,利用它们之间的协同效应。空间注意力帮助模型聚焦于重要的空间区域,而通道注意力则强调重要的特征通道。两者的结合使得模型能够同时关注最具信息量的空间位置和特征通道,从而实现更优的决策[1][2]。
-
性能提升
- 在多个基准测试中,SCSA表现出色,超越了现有的最先进注意力机制。例如,在ImageNet-1K分类、MSCOCO目标检测和ADE20K分割任务中,SCSA均展示了显著的性能提升,尤其在低光照和小目标场景下的表现尤为突出[2][1]。
-
处理语义差异的能力
- SCSA有效地处理了由于多语义信息引起的语义差异和交互问题。通过精细化的通道特征处理,SCSA能够更好地整合不同特征通道的信息,提升了模型在复杂场景下的泛化能力[2]。
SCSA通过整合空间和通道注意力的优势,显著提升了特征提取的能力,并在多个视觉任务中取得了优异的表现。其在多语义信息利用、通道特征精细化处理、协同效应引入及性能提升等方面的具体改进,使其在深度学习领域中成为一种具有潜力的注意力机制。
效果
实验结果表明,SCSA在多个视觉任务中表现优异,超越了现有的最先进注意力机制。具体效果包括:
- 图像分类:在ImageNet-1K上,SCSA实现了最高的Top-1准确率。
- 目标检测:在MSCOCO上,SCSA在不同检测器上均表现出色,尤其在小目标和低光照场景中。
- 语义分割:在ADE20K上,SCSA显著提高了mIoU,证明了其在细粒度任务中的有效性。
实验结果
研究团队在七个基准数据集上进行了广泛的实验,包括:
- 分类:ImageNet-1K
- 目标检测:MSCOCO、Pascal VOC、VisDrone、ExDark
- 分割:ADE20K、MSCOCO
实验结果显示,SCSA在各个任务中均优于其他即插即用的注意力机制,展现出强大的泛化能力。
总结
SCSA模块通过有效整合空间和通道注意力的优势,显著提升了特征提取的能力,并在多个视觉任务中取得了优异的表现。该研究为未来的深度学习模型设计提供了新的思路,尤其是在处理复杂视觉任务时,SCSA的引入可能会成为一种重要的工具。
代码
import typing as t
import torch
import torch.nn as nn
from einops import rearrange
__all__ = ['SCSA']
class SCSA(nn.Module):
def __init__(
self,
dim: int,
head_num: int,
window_size: int = 7,
group_kernel_sizes: t.List[int] = [3, 5, 7, 9],
qkv_bias: bool = False,
fuse_bn: bool = False,
down_sample_mode: str = 'avg_pool',
attn_drop_ratio: float = 0.,
gate_layer: str = 'sigmoid',
):
super(SCSA, self).__init__()
self.dim = dim
self.head_num = head_num
self.head_dim = dim // head_num
self.scaler = self.head_dim ** -0.5
self.group_kernel_sizes = group_kernel_sizes
self.window_size = window_size
self.qkv_bias = qkv_bias
self.fuse_bn = fuse_bn
self.down_sample_mode = down_sample_mode
assert self.dim // 4, 'The dimension of input feature should be divisible by 4.'
self.group_chans = group_chans = self.dim // 4
self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
padding=group_kernel_sizes[0] // 2, groups=group_chans)
self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
padding=group_kernel_sizes[1] // 2, groups=group_chans)
self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
padding=group_kernel_sizes[2] // 2, groups=group_chans)
self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
padding=group_kernel_sizes[3] // 2, groups=group_chans)
self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
self.norm_h = nn.GroupNorm(4, dim)
self.norm_w = nn.GroupNorm(4, dim)
self.conv_d = nn.Identity()
self.norm = nn.GroupNorm(1, dim)
self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.attn_drop = nn.Dropout(attn_drop_ratio)
self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()
if window_size == -1:
self.down_func = nn.AdaptiveAvgPool2d((1, 1))
else:
if down_sample_mode == 'recombination':
self.down_func = self.space_to_chans
# dimensionality reduction
self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)
elif down_sample_mode == 'avg_pool':
self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
elif down_sample_mode == 'max_pool':
self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
The dim of x is (B, C, H, W)
"""
# Spatial attention priority calculation
b, c, h_, w_ = x.size()
# (B, C, H)
x_h = x.mean(dim=3)
l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)
# (B, C, W)
x_w = x.mean(dim=2)
l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)
x_h_attn = self.sa_gate(self.norm_h(torch.cat((
self.local_dwc(l_x_h),
self.global_dwc_s(g_x_h_s),
self.global_dwc_m(g_x_h_m),
self.global_dwc_l(g_x_h_l),
), dim=1)))
x_h_attn = x_h_attn.view(b, c, h_, 1)
x_w_attn = self.sa_gate(self.norm_w(torch.cat((
self.local_dwc(l_x_w),
self.global_dwc_s(g_x_w_s),
self.global_dwc_m(g_x_w_m),
self.global_dwc_l(g_x_w_l)
), dim=1)))
x_w_attn = x_w_attn.view(b, c, 1, w_)
x = x * x_h_attn * x_w_attn
# Channel attention based on self attention
# reduce calculations
y = self.down_func(x)
y = self.conv_d(y)
_, _, h_, w_ = y.size()
# normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v
y = self.norm(y)
q = self.q(y)
k = self.k(y)
v = self.v(y)
# (B, C, H, W) -> (B, head_num, head_dim, N)
q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
# (B, head_num, head_dim, head_dim)
attn = q @ k.transpose(-2, -1) * self.scaler
attn = self.attn_drop(attn.softmax(dim=-1))
# (B, head_num, head_dim, N)
attn = attn @ v
# (B, C, H_, W_)
attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))
# (B, C, 1, 1)
attn = attn.mean((2, 3), keepdim=True)
attn = self.ca_gate(attn)
return attn * x
if __name__ == "__main__":
# 如果GPU可用,将模块移动到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 输入张量 (batch_size, height, width,channels)
x = torch.randn(1,32,40,40).to(device)
# 初始化 HWD 模块
dim=32
block = SCSA(dim=32, head_num=8, window_size=7)
print(block)
block = block.to(device)
# 前向传播
output = block(x)
print("输入:", x.shape)
print("输出:", output.shape)
输出结果: