论文阅读(二十四):SA-Net: Shuffle Attention for Deep Convolutional Neural Networks
文章目录
- Abstract
- 1.Introduction
- 2.Shuffle Attention
- 3.Code
论文:SA-Net:Shuffle Attention for Deep Convolutional Neural Networks(SA-Net:置换注意力机制)
论文链接:SA-Net:Shuffle Attention for Deep Convolutional Neural Networks
代码链接:Github
Abstract
计算机视觉的注意力机制主要有空间注意力机制和通道注意力机制两种,分别旨在捕获像素(空间域)依赖性和通道依赖性,尽管将它们融合在一起可能会获得更好的性能,但也会增加计算开销。本文提出一种高效的置换注意力机制 S h u f f l e A t t e n t i o n ( S A ) Shuffle\;Attention(SA) ShuffleAttention(SA),其将通道维度分组到多个子特征中,然后再并行处理。对于每个子特征,SA 利用 S h u f f l e U n i t Shuffle\;Unit ShuffleUnit来描述空间和通道维度的特征依赖关系。之后将所有子特征聚合,并采用 c h a n n e l s h u f f l e channel shuffle channelshuffle运算来实现不同子特征之间的信息通信。
1.Introduction
常见的注意力机制,如GCNet(Gcnet: Non-local networks meet squeeze-excitation networks and beyond)、CBAM(CBAM: convolutional block attention module),将空间注意力和通道注意力整合到一个模块中,但也带来较大的计算负担。受ShuffleNet v2(Shufflenet V2: practical guidelines for efficient CNN architecture design)的启发,本文针对深度卷积神经网络提出了置换注意力机制SA(Shuffle Attention)。它将通道维度分为多个子特征,然后利用Shuffle Unit为每个子特征集成互补的通道和空间注意力模块。
2.Shuffle Attention
S
h
u
f
f
l
e
A
t
t
e
n
t
i
o
n
Shuffle\;Attention
ShuffleAttention机制包含两种运算:
【1.特征分组】
设有特征图
X
∈
R
C
×
H
×
W
X∈R^{C×H×W}
X∈RC×H×W,SA沿通道维度将
X
X
X分为
G
G
G组,即,
X
=
X
1
,
X
2
,
.
.
.
X
g
,
X
k
∈
R
C
g
×
H
×
W
X={X_1,X_2,...X_g},X_k∈R^{\frac{C}{g}×H×W}
X=X1,X2,...Xg,Xk∈RgC×H×W。通过attention模块为每个子特征生成相应的重要性系数。具体来说,在每个注意力单元的开始,将输入
X
k
X_k
Xk沿通道维度拆分为两个分支
X
k
1
、
X
k
2
∈
R
C
2
g
×
H
×
W
X_{k1}、X_{k2}∈R^{\frac{C}{2g}×H×W}
Xk1、Xk2∈R2gC×H×W。一个分支利用通道的相互关系生成通道注意力图,另一个分支利用特征的空间关系生成空间注意力图。
【2.通道注意力图】
完全捕获通道之间的依赖关系的常见模块,如SE(Squeeze-and-Excitation Networks)模块,其会带来太多的参数。本文提出了一种替代方案,与SE模块的思想一样,先通过全局平均池化(GAP)操作来收集空间域的所有信息,将
X
k
1
X_{k1}
Xk1转换为向量
1
×
1
×
C
2
g
1×1×\frac{C}{2g}
1×1×2gC。计算公式:
之后通过简单的门控机制(
F
c
F_c
Fc)与
s
i
g
m
o
i
d
sigmoid
sigmoid函数(
σ
σ
σ)生成通道注意力图,将其与
X
k
1
X_{k1}
Xk1相乘,即可完全捕获通道之间的依赖关系。计算公式:
【3.空间注意力图】
空间注意力图用于捕获位置信息(语义信息),其一般是通道注意力的补充。具体来说,对
X
k
2
X_{k2}
Xk2使用组归一化来捕获空间域的统计信息,与生成通道注意力图的方式相同,使用简单的门控机制(
F
c
F_c
Fc)与
s
i
g
m
o
i
d
sigmoid
sigmoid函数(
σ
σ
σ)生成空间注意力图,将其与
X
k
2
X_{k2}
Xk2相乘,即可完全捕获空间域信息。计算公式:
【4.特征融合】
先通过
C
o
n
c
a
t
Concat
Concat操作将特征图融合得到
X
k
′
=
[
X
k
1
′
,
X
k
2
′
]
∈
R
C
2
G
×
H
×
W
X'_k=[X'_{k1},X'_{k2}]∈R^{\frac{C}{2G}×H×W}
Xk′=[Xk1′,Xk2′]∈R2GC×H×W。最后采用与ShuffleNetV2相同的思想,采用通道置换操作(channel shuffle)。进行组间通信。SA的最终输出具有与输入相同的尺寸。
3.Code
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
class sa_layer(nn.Module):
"""Constructs a Channel Spatial Group module.
Args:
k_size: Adaptive selection of kernel size
"""
def __init__(self, channel, groups=64):
super(sa_layer, self).__init__()
self.groups = groups
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.cweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
self.sweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
self.sigmoid = nn.Sigmoid()
self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))
@staticmethod
def channel_shuffle(x, groups):
b, c, h, w = x.shape
x = x.reshape(b, groups, -1, h, w)
x = x.permute(0, 2, 1, 3, 4)
# flatten
x = x.reshape(b, -1, h, w)
return x
def forward(self, x):
#1.特征分组
b, c, h, w = x.shape
x = x.reshape(b * self.groups, -1, h, w)
x_0, x_1 = x.chunk(2, dim=1)
#2.通道注意力图
xn = self.avg_pool(x_0)
xn = self.cweight * xn + self.cbias
xn = x_0 * self.sigmoid(xn)
#3.空间注意力图
xs = self.gn(x_1)
xs = self.sweight * xs + self.sbias
xs = x_1 * self.sigmoid(xs)
#特征融合
out = torch.cat([xn, xs], dim=1)
out = out.reshape(b, -1, h, w)
out = self.channel_shuffle(out, 2)
return out