当前位置: 首页 > article >正文

论文阅读(二十四):SA-Net: Shuffle Attention for Deep Convolutional Neural Networks

文章目录

  • Abstract
  • 1.Introduction
  • 2.Shuffle Attention
  • 3.Code


  论文:SA-Net:Shuffle Attention for Deep Convolutional Neural Networks(SA-Net:置换注意力机制)
  论文链接:SA-Net:Shuffle Attention for Deep Convolutional Neural Networks
  代码链接:Github

Abstract

  计算机视觉的注意力机制主要有空间注意力机制和通道注意力机制两种,分别旨在捕获像素(空间域)依赖性和通道依赖性,尽管将它们融合在一起可能会获得更好的性能,但也会增加计算开销。本文提出一种高效的置换注意力机制 S h u f f l e    A t t e n t i o n ( S A ) Shuffle\;Attention(SA) ShuffleAttention(SA),其将通道维度分组到多个子特征中,然后再并行处理。对于每个子特征,SA 利用 S h u f f l e    U n i t Shuffle\;Unit ShuffleUnit来描述空间和通道维度的特征依赖关系。之后将所有子特征聚合,并采用 c h a n n e l s h u f f l e channel shuffle channelshuffle运算来实现不同子特征之间的信息通信。

1.Introduction

  常见的注意力机制,如GCNet(Gcnet: Non-local networks meet squeeze-excitation networks and beyond)、CBAM(CBAM: convolutional block attention module),将空间注意力和通道注意力整合到一个模块中,但也带来较大的计算负担。受ShuffleNet v2(Shufflenet V2: practical guidelines for efficient CNN architecture design)的启发,本文针对深度卷积神经网络提出了置换注意力机制SA(Shuffle Attention)。它将通道维度分为多个子特征,然后利用Shuffle Unit为每个子特征集成互补的通道和空间注意力模块。

2.Shuffle Attention

在这里插入图片描述
   S h u f f l e    A t t e n t i o n Shuffle\;Attention ShuffleAttention机制包含两种运算:
【1.特征分组】
  设有特征图 X ∈ R C × H × W X∈R^{C×H×W} XRC×H×W,SA沿通道维度将 X X X分为 G G G组,即, X = X 1 , X 2 , . . . X g , X k ∈ R C g × H × W X={X_1,X_2,...X_g},X_k∈R^{\frac{C}{g}×H×W} X=X1,X2,...Xg,XkRgC×H×W。通过attention模块为每个子特征生成相应的重要性系数。具体来说,在每个注意力单元的开始,将输入 X k X_k Xk沿通道维度拆分为两个分支 X k 1 、 X k 2 ∈ R C 2 g × H × W X_{k1}、X_{k2}∈R^{\frac{C}{2g}×H×W} Xk1Xk2R2gC×H×W。一个分支利用通道的相互关系生成通道注意力图,另一个分支利用特征的空间关系生成空间注意力图。
【2.通道注意力图】
  完全捕获通道之间的依赖关系的常见模块,如SE(Squeeze-and-Excitation Networks)模块,其会带来太多的参数。本文提出了一种替代方案,与SE模块的思想一样,先通过全局平均池化(GAP)操作来收集空间域的所有信息,将 X k 1 X_{k1} Xk1转换为向量 1 × 1 × C 2 g 1×1×\frac{C}{2g} 1×1×2gC。计算公式:
在这里插入图片描述
之后通过简单的门控机制( F c F_c Fc)与 s i g m o i d sigmoid sigmoid函数( σ σ σ)生成通道注意力图,将其与 X k 1 X_{k1} Xk1相乘,即可完全捕获通道之间的依赖关系。计算公式:
在这里插入图片描述
【3.空间注意力图】
  空间注意力图用于捕获位置信息(语义信息),其一般是通道注意力的补充。具体来说,对 X k 2 X_{k2} Xk2使用组归一化来捕获空间域的统计信息,与生成通道注意力图的方式相同,使用简单的门控机制( F c F_c Fc)与 s i g m o i d sigmoid sigmoid函数( σ σ σ)生成空间注意力图,将其与 X k 2 X_{k2} Xk2相乘,即可完全捕获空间域信息。计算公式:
在这里插入图片描述
【4.特征融合】
  先通过 C o n c a t Concat Concat操作将特征图融合得到 X k ′ = [ X k 1 ′ , X k 2 ′ ] ∈ R C 2 G × H × W X'_k=[X'_{k1},X'_{k2}]∈R^{\frac{C}{2G}×H×W} Xk=[Xk1,Xk2]R2GC×H×W。最后采用与ShuffleNetV2相同的思想,采用通道置换操作(channel shuffle)。进行组间通信。SA的最终输出具有与输入相同的尺寸。

3.Code

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
class sa_layer(nn.Module):
    """Constructs a Channel Spatial Group module.
    Args:
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, channel, groups=64):
        super(sa_layer, self).__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))
    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten
        x = x.reshape(b, -1, h, w)
        return x
    def forward(self, x):
    	#1.特征分组
        b, c, h, w = x.shape
        x = x.reshape(b * self.groups, -1, h, w)
        x_0, x_1 = x.chunk(2, dim=1)
        #2.通道注意力图
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0 * self.sigmoid(xn)
        #3.空间注意力图
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1 * self.sigmoid(xs)
        #特征融合
        out = torch.cat([xn, xs], dim=1)
        out = out.reshape(b, -1, h, w)
        out = self.channel_shuffle(out, 2)
        return out

http://www.kler.cn/news/362213.html

相关文章:

  • CMOS 图像传感器:像素寻址与信号处理
  • SpringCloud学习记录|day6
  • 【从零开始的LeetCode-算法】3075. 幸福值最大化的选择方案
  • 认识一下:__asm { int 80h; LINUX - sys_fork }
  • 2024软考网络工程师笔记 - 第10章.组网技术
  • 【Python】爬虫
  • linux系统下载安装nvidia显卡驱动
  • Qml的使用
  • Unity加载界面制作
  • Docker:安装 Syslog-ng 的技术指南
  • Build an Android project and get a `.apk` file on a Debian 11 command line
  • Java项目-基于Springboot的车辆充电桩项目(源码+说明).zip
  • c++基础算法练习(1)
  • Android SELinux——调试工具audio2allow介绍(十三)
  • Linux系列-Linux的常见指令(三)
  • 错误0x80070522:客户端没有所需的特权
  • C++ set和map的模拟实现
  • 在Debian上安装向日葵
  • 深度学习-卷积神经网络-基于VGG16模型, 实现猫狗二分类(文末附带数据集下载链接, 长期有效)
  • oracle_查询建表语句
  • 【毕业设计】基于SpringBoot的网上商城系统
  • 【C语言刷力扣】2006.差的绝对值为K的数对数目
  • CTFHUB技能树之SQL——布尔盲注
  • 前端模块化技术 IIFE、CMD、UMD
  • 智能去毛刺:2D视觉引导机器人如何重塑制造业未来
  • MySQL 指定字段排序