【Block总结】Shuffle Attention,新型的Shuffle注意力|即插即用
一、论文信息
-
标题: SA-Net: Shuffle Attention for Deep Convolutional Neural Networks
-
论文链接: arXiv
-
代码链接: GitHub
二、创新点
Shuffle Attention(SA)模块的主要创新在于高效结合了通道注意力和空间注意力,同时通过通道重排技术降低计算复杂度。具体创新点包括:
- 通道分组: 将输入特征图的通道维度分成多个组,允许并行处理。
- 通道重排: 通过打乱通道顺序,增强模型对通道特征的表达能力。
- 融合注意力机制: 同时计算通道和空间注意力,提升特征表示的能力。
三、方法
Shuffle Attention模块的实现步骤如下:
-
特征分组: 将输入特征图的通道数通过参数G进行分组,得到多个子特征图。
-
通道注意力:
- 对每组特征使用全局平均池化(GAP),生成通道级统计数据。
- 通过学习的权重和偏置调整每组特征的通道重要性,并使用Sigmoid激活函数应用于特征图。
-
空间注意力:
- 使用组归一化(GroupNorm)计算每组特征的空间注意力。
- 经过Sigmoid激活后,将空间注意力应用于特征图的空间维度。
-
通道重排: 在计算完通道和空间注意力后,使用通道重排操作打乱通道顺序,以增强特征表达能力。
-
输出: 返回经过重排后的输出特征图。
Shuffle Attention与传统注意力机制的优势比较
Shuffle Attention(SA)是一种新型的注意力机制,旨在提高深度卷积神经网络的性能。与传统的注意力机制相比,Shuffle Attention在多个方面展现出显著的优势,尤其是在计算效率和模型性能方面。
优势如下:
-
高效的计算性能
Shuffle Attention通过将输入特征图的通道维度分成多个组,并对每个组进行并行处理,从而显著降低了计算复杂度。传统的注意力机制通常需要在全通道上进行计算,导致计算量大且效率低下。SA的设计使得在保持性能的同时,减少了参数量和计算量。例如,在ResNet50上,SA的参数量从300M降至25.56M,计算量从4.12 GFLOPs降至2.76 GFLOPs[2]。 -
融合通道和空间注意力
Shuffle Attention同时结合了通道注意力和空间注意力,能够更全面地捕捉特征之间的依赖关系。传统的注意力机制往往是将这两种注意力机制分开处理,未能充分利用它们之间的相互关系。SA通过“通道重排”操作,促进了不同组之间的信息交流,从而提升了特征的表达能力和模型的整体性能。 -
增强特征表达能力
Shuffle Attention通过通道重排和特征分组的方式,增强了模型对特征的表达能力。传统注意力机制在处理特征时,可能会忽略某些重要的通道或空间信息,而SA通过并行处理和重排,确保了所有特征都能得到充分利用,从而提高了模型的准确性。 -
适应性强
Shuffle Attention的设计使其能够灵活适应不同的网络架构和任务需求。它可以作为一个轻量级的模块,方便地集成到现有的深度学习模型中,而不需要对整个网络结构进行大幅修改。这种灵活性使得SA在各种计算机视觉任务中表现出色,包括图像分类、目标检测和实例分割等。
四、效果
Shuffle Attention模块在多个计算机视觉任务中表现出色,尤其是在图像分类和目标检测任务中。通过有效的特征提取和信息融合,SA模块能够在保持较低计算复杂度的同时,显著提高模型的性能。
五、实验结果
在ImageNet-1k、MS COCO等数据集上的实验结果表明:
- 准确率提升: SA模块在与ResNet等主干网络结合时,Top-1准确率提升超过1.34%。
- 计算复杂度降低: 相比于传统的注意力机制,SA模块在参数量和计算量上均显著减少,例如在ResNet50上,参数量从300M降至25.56M,计算量从4.12 GFLOPs降至2.76 GFLOPs。
六、总结
Shuffle Attention模块通过创新的特征分组和通道重排机制,有效地结合了通道和空间注意力,显著提升了深度卷积神经网络的性能。该模块不仅在多个基准测试中表现优异,还展示了在实际应用中的潜力,尤其是在资源受限的环境中。未来的研究可以进一步探索SA模块在更复杂任务中的应用效果。
代码
代码有错误,我做了修改。代码如下:
import torch
import torch.nn.functional
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
class sa_layer(nn.Module):
"""Constructs a Channel Spatial Group module.
Args:
k_size: Adaptive selection of kernel size
"""
def __init__(self, channel, groups=64):
super(sa_layer, self).__init__()
self.groups = groups
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.cweight = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
self.sweight = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
self.sigmoid = nn.Sigmoid()
self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))
@staticmethod
def channel_shuffle(x, groups):
b, c, h, w = x.shape
x = x.reshape(b, groups, -1, h, w)
x = x.permute(0, 2, 1, 3, 4)
# flatten
x = x.reshape(b, -1, h, w)
return x
def forward(self, x):
b, c, h, w = x.shape
x = x.reshape(b * self.groups, -1, h, w)
x_0, x_1 = x.chunk(2, dim=1)
# channel attention
xn = self.avg_pool(x_0)
xn = self.cweight * xn + self.cbias
xn = x_0 * self.sigmoid(xn)
print(xn)
# spatial attention
xs = self.gn(x_1)
xs = self.sweight * xs + self.sbias
xs = x_1 * self.sigmoid(xs)
# concatenate along channel axis
out = torch.cat([xn, xs], dim=1)
out = out.reshape(b, -1, h, w)
out = self.channel_shuffle(out, 2)
return out
if __name__ == "__main__":
dim=64
# 如果GPU可用,将模块移动到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 输入张量 (batch_size, height, width,channels)
x = torch.randn(2,dim,40,40).to(device)
# 初始化 sa_layer模块
block = sa_layer(dim,8)
print(block)
block = block.to(device)
# 前向传播
output = block(x)
print("输入:", x.shape)
print("输出:", output.shape)
输出结果: