当前位置: 首页 > article >正文

每日Attention学习19——Convolutional Multi-Focal Attention

每日Attention学习19——Convolutional Multi-Focal Attention

模块出处

[ICLR 25 Submission] [link] UltraLightUNet: Rethinking U-shaped Network with Multi-kernel Lightweight Convolutions for Medical Image Segmentation


模块名称

Convolutional Multi-Focal Attention (CMFA)


模块作用

轻量解码器


模块结构

在这里插入图片描述


模块特点
  • 使用最大池化与平均池化构建通道注意力
  • 使用Channel Max与Channel Average构建空间注意力
  • 核心思想与CBAM较为类似,串联通道注意力与空间注意力

模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7, 11), 'kernel size must be 3 or 7 or 11'
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, out_planes=None, ratio=16):
        super(ChannelAttention, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        if self.in_planes < ratio:
            ratio = self.in_planes
        self.reduced_channels = self.in_planes // ratio
        if self.out_planes == None:
            self.out_planes = in_planes
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.activation = nn.ReLU(inplace=True)
        self.fc1 = nn.Conv2d(in_planes, self.reduced_channels, 1, bias=False)
        self.fc2 = nn.Conv2d(self.reduced_channels, self.out_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_pool_out = self.avg_pool(x) 
        avg_out = self.fc2(self.activation(self.fc1(avg_pool_out)))
        max_pool_out= self.max_pool(x)
        max_out = self.fc2(self.activation(self.fc1(max_pool_out)))
        out = avg_out + max_out
        return self.sigmoid(out) 
    

class CMFA(nn.Module):
    def __init__(self, in_planes, out_planes=None,):
        super(CMFA, self).__init__()
        self.ca = ChannelAttention(in_planes=64, out_planes=64)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = x*self.ca(x)
        x = x*self.sa(x)
        return x
    

if __name__ == '__main__':
    x = torch.randn([1, 64, 44, 44])
    cmfa = CMFA(in_planes=64, out_planes=64)
    out = cmfa(x)
    print(out.shape)  # [1, 64, 44, 44]


http://www.kler.cn/a/533652.html

相关文章:

  • 除了网页,还有哪些方式可以访问deepseek r1
  • JPA使用@EntityGraph立即加载关联实体
  • 深度剖析八大排序算法
  • Verilog语言学习总结
  • pytorch使用SVM实现文本分类
  • C++:结构体和类
  • Java学习进阶路线
  • 标准库发送数据深入理解USART
  • Windows下安装mkcert
  • 9. k8s二进制集群之kube-controller-manager部署
  • TensorFlow深度学习实战(6)——回归分析详解
  • Deepseek技术浅析(四):专家选择与推理机制
  • AI开发模式:ideal或vscode + 插件continue+DeepSeek R1
  • 0205算法:最长连续序列、三数之和、排序链表
  • 2024年12月 Scratch 图形化(四级)真题解析 中国电子学会全国青少年软件编程等级考试
  • 工作总结:上线篇
  • 你也在这里
  • MYSQL简单查询
  • 【JavaScript】《JavaScript高级程序设计 (第4版) 》笔记-Chapter3-语言基础
  • 力扣-哈希表-1 两数之和
  • Baklib如何实现内容管理平台的智能化升级与数据整合
  • Docker深度解析:安装各大环境
  • [加餐]指针和动态内存管理
  • 网络安全——Span 安全监控
  • 请求响应(接上篇)
  • 【字节青训营-9】:初探字节微服务框架 Hertz 基础使用及进阶(下)