【Block总结】高效多尺度注意力EMA,超越SE、CBAM、SA、CA等注意力|即插即用
论文信息
标题: Efficient Multi-Scale Attention Module with Cross-Spatial Learning
作者: Daliang Ouyang, Su He, Guozhong Zhang, Mingzhu Luo, Huaiyong Guo, Jian Zhan, Zhijie Huang
论文链接: https://arxiv.org/pdf/2305.13563v2
GitHub链接: https://github.com/YOLOonMe/EMA-attention-module
创新点
该论文提出了一种新颖的高效多尺度注意力模块(EMA),旨在通过跨空间学习来提升特征表示的效果,同时降低计算开销。EMA模块的设计重点在于:
- 信息保留: 在每个通道上保留信息,确保特征的完整性。
- 计算效率: 通过重塑部分通道为批处理维度,减少计算负担。
- 多尺度学习: 结合多尺度特征,增强模型对不同尺度信息的捕捉能力。
方法
EMA模块的核心方法包括:
-
通道重塑: 将部分通道重塑为批处理维度,并将通道维度分组为多个子特征,以实现更高效的信息处理。
-
跨维度交互: 通过跨维度交互,聚合两个并行分支的输出特征,捕获像素级的成对关系。
-
并行子网络: 设计多尺度并行子网络,以建立短期和长期依赖关系,从而增强特征表示能力。
EMA模块的信息保留与计算效率平衡
信息保留机制
EMA(Efficient Multi-Scale Attention)模块通过以下几种方式实现信息的有效保留:
-
通道重塑: EMA模块将部分通道重塑为批处理维度,并将通道维度分组为多个子特征。这种设计确保了每个通道的信息能够被有效保留,同时避免了通道维度的削减,从而增强了特征的表达能力[1][3]。
-
跨维度交互: 在EMA模块中,两个并行分支的输出特征通过跨维度交互进行聚合。这种交互机制能够捕捉到像素级的成对关系,从而进一步提升特征的丰富性和准确性[2][3]。
-
多尺度并行子网络: EMA模块采用了多尺度并行子网络结构,结合了1x1和3x3卷积核的特征处理。这种结构能够有效捕获不同尺度的信息,确保在特征提取过程中不会丢失重要信息[2][3]。
计算效率提升
在计算效率方面,EMA模块通过以下方式优化了计算过程:
-
减少计算开销: 通过将部分通道重塑为批处理维度,EMA模块能够在不显著增加计算成本的情况下,保持高效的信息处理。这种方法使得模型在处理大规模数据时更加高效[1][2]。
-
并行处理: EMA模块的设计允许多个子网络并行处理特征,这不仅提高了计算效率,还减少了模型的顺序处理需求,从而加快了整体计算速度[3]。
-
适度的模型尺寸: EMA模块的设计确保了模型的尺寸适中,适合在移动终端等资源受限的环境中部署。这种设计使得EMA模块在保持性能的同时,能够有效降低计算资源的消耗[3][2]。
EMA模块通过创新的设计实现了信息保留与计算效率的平衡。其通道重塑、跨维度交互和多尺度并行处理的策略,不仅确保了特征信息的完整性,还显著提高了计算效率。这使得EMA模块在计算机视觉任务中表现出色,尤其是在小目标检测和图像分类等应用中,展现了其广泛的应用潜力和实际意义。
效果
实验结果表明,EMA模块在多个计算机视觉任务中表现优异,尤其是在小目标检测和图像分类任务中,相较于传统的注意力机制(如ECA、CBAM、CA),EMA模块显著提高了特征表示的清晰度和准确性。
实验结果
在广泛的消融研究和实验中,EMA模块在以下数据集上进行了评估:
- CIFAR-100
- ImageNet-1k
- MS COCO
- VisDrone2019
实验结果显示,EMA模块在这些基准测试中均取得了优于现有方法的性能,尤其在小目标检测任务中,表现出明显的优势。
总结
Efficient Multi-Scale Attention Module with Cross-Spatial Learning通过创新的设计和有效的实现,成功地提升了计算机视觉任务中的特征表示能力,同时降低了计算复杂度。该模块的提出为未来的研究提供了新的思路,尤其是在需要高效处理大规模数据的应用场景中,EMA模块展现了其广泛的应用潜力。
代码
import torch
from torch import nn
class EMA(nn.Module):
def __init__(self, channels, c2=None, factor=32):
super(EMA, self).__init__()
self.groups = factor
assert channels // self.groups > 0
self.softmax = nn.Softmax(-1)
self.agp = nn.AdaptiveAvgPool2d((1, 1))
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))
self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)
def forward(self, x):
b, c, h, w = x.size()
group_x = x.reshape(b * self.groups, -1, h, w) # b*g,c//g,h,w
x_h = self.pool_h(group_x)
x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
x_h, x_w = torch.split(hw, [h, w], dim=2)
x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
x2 = self.conv3x3(group_x)
x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
x12 = x2.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw
x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
x22 = x1.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw
weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
return (group_x * weights.sigmoid()).reshape(b, c, h, w)
if __name__ == "__main__":
# 如果GPU可用,将模块移动到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 输入张量 (batch_size, channels, height, width)
x = torch.randn(1,32,40,40).to(device)
# 初始化 pconv 模块
dim=32
block = EMA(dim,factor=8)
print(block)
block = block.to(device)
# 前向传播
output = block(x)
print("输入:", x.shape)
print("输出:", output.shape)
输出结果: