
Daily Attention Study 20: Group Shuffle Attention

Module Source

[MICCAI 24] [link] LB-UNet: A Lightweight Boundary-Assisted UNet for Skin Lesion Segmentation


Module Name

Group Shuffle Attention (GSA)


Module Purpose

Lightweight feature learning


Module Structure

(Figure: structure diagram of the GSA module, from the original paper)


Module Highlights
  • Grouped (Group) convolutions are used to reduce computation
  • An External Attention mechanism is introduced to learn features more effectively
  • A Shuffle operation promotes information exchange between the different groups (a generic sketch follows this list)
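
To make the shuffle point concrete, below is a minimal, illustrative sketch of the generic channel-shuffle operation (in the ShuffleNet style). It is not part of LB-UNet: in the GSA code further down, the shuffle is realized more simply, by re-concatenating the four channel groups in a fixed permuted order, which likewise lets the groups exchange information.

import torch

def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
    # Reshape to (B, groups, C // groups, H, W), swap the two channel axes,
    # then flatten back so channels from different groups become interleaved.
    b, c, h, w = x.shape
    x = x.view(b, groups, c // groups, h, w)
    x = x.transpose(1, 2).contiguous()
    return x.view(b, c, h, w)

# With 8 channels and 2 groups, channel order [0..7] becomes [0, 4, 1, 5, 2, 6, 3, 7]
t = torch.arange(8).float().view(1, 8, 1, 1)
print(channel_shuffle(t, 2).flatten().tolist())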

Module Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
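    # LayerNorm supporting both channels_last (B, H, W, C) and channels_first (B, C, H, W) inputs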
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
        

class GSA(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        c_dim = dim_in // 4
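        # Four learnable 1 x c_dim x 8 x 8 maps, one per channel group. At run time they are
        # bilinearly resized to the feature-map size and act as shared, input-independent
        # attention templates (this is where the External-Attention-style idea enters the module).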
        self.share_space1 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space1)
        self.conv1 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space2 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space3 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space3)
        self.conv3 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space4 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space4)
        self.conv4 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.norm1 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
        self.norm2 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
        self.ldw = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in),
            nn.GELU(),
            nn.Conv2d(dim_in, dim_out, 1),
        )

    def forward(self, x):
        x = self.norm1(x)
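        # Split the normalized feature map into 4 channel groups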
        x1, x2, x3, x4 = torch.chunk(x, 4, dim=1)
        B, C, H, W = x1.size()
        x1 = x1 * self.conv1(F.interpolate(self.share_space1, size=x1.shape[2:4], mode='bilinear', align_corners=True))
        x2 = x2 * self.conv2(F.interpolate(self.share_space2, size=x2.shape[2:4], mode='bilinear', align_corners=True))
        x3 = x3 * self.conv3(F.interpolate(self.share_space3, size=x3.shape[2:4], mode='bilinear', align_corners=True))
        x4 = x4 * self.conv4(F.interpolate(self.share_space4, size=x4.shape[2:4], mode='bilinear', align_corners=True))
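        # Re-concatenate the groups in a fixed permuted order (the "Shuffle"), so that the
        # following depthwise + pointwise projection mixes information across groups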
        x = torch.cat([x2, x4, x1, x3], dim=1)
        x = self.norm2(x)
        x = self.ldw(x)
        return x
    

if __name__ == '__main__':
    x = torch.randn([1, 64, 44, 44])
    gsa = GSA(dim_in=64, dim_out=64)
    out = gsa(x)
    print(out.shape)  # [1, 64, 44, 44]
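    # Optional sanity check of the "lightweight" claim (not from the original post):
    # count the learnable parameters of the module for this setting.
    print(sum(p.numel() for p in gsa.parameters()), "parameters")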

