当前位置: 首页 > article >正文

每日Attention学习18——Grouped Attention Gate

模块出处

[ICLR 25 Submission] [link] UltraLightUNet: Rethinking U-shaped Network with Multi-kernel Lightweight Convolutions for Medical Image Segmentation


模块名称

Grouped Attention Gate (GAG)


模块作用

轻量特征融合


模块结构

在这里插入图片描述


模块特点
  • 特征融合前使用Group Conv进行处理,比标准卷积更加轻量
  • 将融合得到的粗特征视为Spatial Attention Map, 并与Encoder特征相乘,从而实现名字中"Gate"的效果
  • 相较于特征融合模块,也可以视为一种利用辅助信息(Decoder)特征以增强Encoder特征的增强模块

模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class GAG(nn.Module):
    def __init__(self, F_g, F_l, F_int, kernel_size=1, groups=1):
        super(GAG,self).__init__()
        if kernel_size == 1:
            groups = 1
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=kernel_size,stride=1,padding=kernel_size//2,groups=groups, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=kernel_size,stride=1,padding=kernel_size//2,groups=groups, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.activation = nn.ReLU(inplace=True)

        
    def forward(self,g,x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.activation(g1+x1)
        psi = self.psi(psi)
        return x*psi
    

if __name__ == '__main__':
    x1 = torch.randn([1, 64, 44, 44])
    x2 = torch.randn([1, 64, 44, 44])
    gag = GAG(F_g=64, F_l=64, F_int=64//2, kernel_size=3, groups=64//2)
    out = gag(x1, x2)
    print(out.shape)  # [1, 64, 44, 44]


http://www.kler.cn/a/533344.html

相关文章:

  • Github 2025-01-31Java开源项目日报 Top10
  • Verilog基础(三):过程
  • SQL进阶实战技巧:如何构建用户行为转移概率矩阵,深入洞察会话内活动流转?
  • RK3566-移植5.10内核Ubuntu22.04
  • Linux:文件系统(软硬链接)
  • 物业管理系统源码提升社区智能化管理效率与用户体验
  • 探索巨控GRM240系列远程模块的强大功能:物联应用新选择
  • deepseek、qwen等多种模型本地化部署
  • RabbitMQ 深度解析与最佳实践
  • 【LeetCode 刷题】贪心算法(1)-基础
  • React开发中箭头函数返回值陷阱的深度解析
  • 利用TensorFlow.js实现浏览器端机器学习:一个全面指南
  • 机器学习专业毕设选题推荐合集 人工智能
  • 4 HBase 的高级 shell 管理命令
  • [基础]端口隔离实验
  • Elasticsearch 就业形势
  • C++STL(二)——vector
  • 基于springboot河南省旅游管理系统
  • Java高频面试之SE-17
  • 糖果(安师大)
  • vscode技巧总结
  • go语言中的Stringer的使用
  • 【工具变量】中国省级八批自由贸易试验区设立及自贸区设立数据(2024-2009年)
  • JSON常用的工具方法
  • 家政预约小程序12服务详情
  • 如何自定义软件安装路径及Scoop包管理器使用全攻略