当前位置: 首页 > article >正文

【Block总结】Shuffle Attention,新型的Shuffle注意力|即插即用

一、论文信息

  • 标题: SA-Net: Shuffle Attention for Deep Convolutional Neural Networks

  • 论文链接: arXiv

  • 代码链接: GitHub
    在这里插入图片描述

二、创新点

Shuffle Attention(SA)模块的主要创新在于高效结合了通道注意力和空间注意力,同时通过通道重排技术降低计算复杂度。具体创新点包括:

  • 通道分组: 将输入特征图的通道维度分成多个组,允许并行处理。
  • 通道重排: 通过打乱通道顺序,增强模型对通道特征的表达能力。
  • 融合注意力机制: 同时计算通道和空间注意力,提升特征表示的能力。
    在这里插入图片描述

三、方法

Shuffle Attention模块的实现步骤如下:

  1. 特征分组: 将输入特征图的通道数通过参数G进行分组,得到多个子特征图。

  2. 通道注意力:

    • 对每组特征使用全局平均池化(GAP),生成通道级统计数据。
    • 通过学习的权重和偏置调整每组特征的通道重要性,并使用Sigmoid激活函数应用于特征图。
  3. 空间注意力:

    • 使用组归一化(GroupNorm)计算每组特征的空间注意力。
    • 经过Sigmoid激活后,将空间注意力应用于特征图的空间维度。
  4. 通道重排: 在计算完通道和空间注意力后,使用通道重排操作打乱通道顺序,以增强特征表达能力。

  5. 输出: 返回经过重排后的输出特征图。

Shuffle Attention与传统注意力机制的优势比较

Shuffle Attention(SA)是一种新型的注意力机制,旨在提高深度卷积神经网络的性能。与传统的注意力机制相比,Shuffle Attention在多个方面展现出显著的优势,尤其是在计算效率和模型性能方面。

优势如下:

  1. 高效的计算性能
    Shuffle Attention通过将输入特征图的通道维度分成多个组,并对每个组进行并行处理,从而显著降低了计算复杂度。传统的注意力机制通常需要在全通道上进行计算,导致计算量大且效率低下。SA的设计使得在保持性能的同时,减少了参数量和计算量。例如,在ResNet50上,SA的参数量从300M降至25.56M,计算量从4.12 GFLOPs降至2.76 GFLOPs[2]。

  2. 融合通道和空间注意力
    Shuffle Attention同时结合了通道注意力和空间注意力,能够更全面地捕捉特征之间的依赖关系。传统的注意力机制往往是将这两种注意力机制分开处理,未能充分利用它们之间的相互关系。SA通过“通道重排”操作,促进了不同组之间的信息交流,从而提升了特征的表达能力和模型的整体性能。

  3. 增强特征表达能力
    Shuffle Attention通过通道重排和特征分组的方式,增强了模型对特征的表达能力。传统注意力机制在处理特征时,可能会忽略某些重要的通道或空间信息,而SA通过并行处理和重排,确保了所有特征都能得到充分利用,从而提高了模型的准确性。

  4. 适应性强
    Shuffle Attention的设计使其能够灵活适应不同的网络架构和任务需求。它可以作为一个轻量级的模块,方便地集成到现有的深度学习模型中,而不需要对整个网络结构进行大幅修改。这种灵活性使得SA在各种计算机视觉任务中表现出色,包括图像分类、目标检测和实例分割等。

四、效果

Shuffle Attention模块在多个计算机视觉任务中表现出色,尤其是在图像分类和目标检测任务中。通过有效的特征提取和信息融合,SA模块能够在保持较低计算复杂度的同时,显著提高模型的性能。

五、实验结果

在ImageNet-1k、MS COCO等数据集上的实验结果表明:

  • 准确率提升: SA模块在与ResNet等主干网络结合时,Top-1准确率提升超过1.34%。
  • 计算复杂度降低: 相比于传统的注意力机制,SA模块在参数量和计算量上均显著减少,例如在ResNet50上,参数量从300M降至25.56M,计算量从4.12 GFLOPs降至2.76 GFLOPs。

六、总结

Shuffle Attention模块通过创新的特征分组和通道重排机制,有效地结合了通道和空间注意力,显著提升了深度卷积神经网络的性能。该模块不仅在多个基准测试中表现优异,还展示了在实际应用中的潜力,尤其是在资源受限的环境中。未来的研究可以进一步探索SA模块在更复杂任务中的应用效果。

代码

代码有错误,我做了修改。代码如下:


import torch
import torch.nn.functional
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter


class sa_layer(nn.Module):
    """Constructs a Channel Spatial Group module.

    Args:
        k_size: Adaptive selection of kernel size
    """

    def __init__(self, channel, groups=64):
        super(sa_layer, self).__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cweight = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        self.sweight = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))

        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))
    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape

        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.shape

        x = x.reshape(b * self.groups, -1, h, w)
        x_0, x_1 = x.chunk(2, dim=1)

        # channel attention
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias

        xn = x_0 * self.sigmoid(xn)
        print(xn)
        # spatial attention
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1 * self.sigmoid(xs)

        # concatenate along channel axis
        out = torch.cat([xn, xs], dim=1)
        out = out.reshape(b, -1, h, w)

        out = self.channel_shuffle(out, 2)
        return out



if __name__ == "__main__":
    dim=64
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, height, width,channels)
    x = torch.randn(2,dim,40,40).to(device)
    # 初始化 sa_layer模块

    block = sa_layer(dim,8)
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)

输出结果:
在这里插入图片描述


http://www.kler.cn/a/529664.html

相关文章:

  • 低成本、高附加值,具有较强的可扩展性和流通便利性的行业
  • vim的特殊模式-可视化模式
  • MYSQL--一条SQL执行的流程,分析MYSQL的架构
  • nginx目录结构和配置文件
  • Keepalived高可用集群企业应用实例二
  • 索引的底层数据结构、B+树的结构、为什么InnoDB使用B+树而不是B树呢
  • 在C语言中使用条件变量实现线程同步
  • w187社区养老服务平台的设计与实现
  • M|哪吒之魔童闹海
  • 【网络】传输层协议TCP(重点)
  • Python虚拟环境
  • Redis万字面试题汇总
  • 虚幻基础16:locomotion direction
  • 使用ollama在本地部署一个deepseek大模型
  • 面渣逆袭之Java基础篇3
  • LLMs之DeepSeek:Math-To-Manim的简介(包括DeepSeek R1-Zero的详解)、安装和使用方法、案例应用之详细攻略
  • XML DOM 节点
  • 详解Kafka并行计算架构
  • 深入了解 ls 命令及其选项
  • 【AI】探索自然语言处理(NLP):从基础到前沿技术及代码实践
  • unity免费资源2025-2-2
  • 涡旋光特性及多种模型、涡旋光仿真
  • final-关键字
  • 穷举vs暴搜vs深搜vs回溯vs剪枝系列一>单词搜索
  • wax到底是什么意思
  • 【高级篇 / IPv6】(7.6) ❀ 03. 宽带IPv6 - ADSL拨号宽带上网配置 ❀ FortiGate 防火墙