Daily Attention Learning #20: Group Shuffle Attention
Module Source
[MICCAI 24] [link] LB-UNet: A Lightweight Boundary-Assisted UNet for Skin Lesion Segmentation
Module Name
Group Shuffle Attention (GSA)
Module Purpose
Lightweight feature learning
Module Structure
Module Features
- Uses grouped (depthwise) convolutions to reduce computation; see the parameter-count sketch after this list
- Introduces an External Attention-style mechanism (the learnable `share_space` maps) to learn features more effectively
- A shuffle operation promotes information exchange between the groups; see the second sketch below
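To make the first point concrete, here is a minimal sketch (mine, not from the paper) comparing the parameter count of a standard 3x3 convolution with the depthwise variant (groups equal to channels) used throughout the module:

```python
import torch.nn as nn

c = 64
standard = nn.Conv2d(c, c, kernel_size=3, padding=1)             # full 3x3 convolution
depthwise = nn.Conv2d(c, c, kernel_size=3, padding=1, groups=c)  # one 3x3 filter per channel

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(standard))   # 36928 = 64*64*9 + 64
print(count(depthwise))  # 640   = 64*9  + 64
```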
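For the shuffle point, the module uses a fixed permutation (`torch.cat([x2, x4, x1, x3], dim=1)`) rather than a random shuffle. The toy example below (my illustration, with 8 labelled channels) shows how channels that started in different groups end up adjacent, so the subsequent 1x1 convolution mixes them:

```python
import torch

x = torch.arange(8).view(1, 8, 1, 1)           # 8 channels labelled 0..7
x1, x2, x3, x4 = torch.chunk(x, 4, dim=1)      # groups: [0,1] [2,3] [4,5] [6,7]
shuffled = torch.cat([x2, x4, x1, x3], dim=1)  # permuted group order
print(shuffled.flatten().tolist())             # [2, 3, 6, 7, 0, 1, 4, 5]
```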
Module Code
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """LayerNorm supporting channels_last (B, H, W, C) and channels_first (B, C, H, W) layouts."""
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Manual layer norm over the channel dimension
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class GSA(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        c_dim = dim_in // 4
        # Four learnable 8x8 shared attention templates (external-attention style),
        # one per channel group; each is upsampled to the feature resolution at runtime.
        self.share_space1 = nn.Parameter(torch.ones(1, c_dim, 8, 8), requires_grad=True)
        self.conv1 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),  # depthwise 3x3
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)  # pointwise 1x1
        )
        self.share_space2 = nn.Parameter(torch.ones(1, c_dim, 8, 8), requires_grad=True)
        self.conv2 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space3 = nn.Parameter(torch.ones(1, c_dim, 8, 8), requires_grad=True)
        self.conv3 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space4 = nn.Parameter(torch.ones(1, c_dim, 8, 8), requires_grad=True)
        self.conv4 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.norm1 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
        self.norm2 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
        # Final lightweight depthwise-separable projection
        self.ldw = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in),
            nn.GELU(),
            nn.Conv2d(dim_in, dim_out, 1),
        )

    def forward(self, x):
        x = self.norm1(x)
        # Split channels into 4 groups
        x1, x2, x3, x4 = torch.chunk(x, 4, dim=1)
        # Reweight each group by its upsampled shared attention map
        x1 = x1 * self.conv1(F.interpolate(self.share_space1, size=x1.shape[2:4], mode='bilinear', align_corners=True))
        x2 = x2 * self.conv2(F.interpolate(self.share_space2, size=x2.shape[2:4], mode='bilinear', align_corners=True))
        x3 = x3 * self.conv3(F.interpolate(self.share_space3, size=x3.shape[2:4], mode='bilinear', align_corners=True))
        x4 = x4 * self.conv4(F.interpolate(self.share_space4, size=x4.shape[2:4], mode='bilinear', align_corners=True))
        # Shuffle: concatenate the groups in permuted order so the following
        # pointwise convolution mixes information across groups
        x = torch.cat([x2, x4, x1, x3], dim=1)
        x = self.norm2(x)
        x = self.ldw(x)
        return x


if __name__ == '__main__':
    x = torch.randn([1, 64, 44, 44])
    gsa = GSA(dim_in=64, dim_out=64)
    out = gsa(x)
    print(out.shape)  # torch.Size([1, 64, 44, 44])
```
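As a quick sanity check on the "lightweight" claim (my own count, not a number reported in the paper), the whole module at dim_in = dim_out = 64 holds only about 10.9k parameters:

```python
gsa = GSA(dim_in=64, dim_out=64)
n_params = sum(p.numel() for p in gsa.parameters())
print(n_params)  # 10880 for dim_in = dim_out = 64
```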