当前位置: 首页 > article >正文

每日Attention学习21——Cascade Multi-Receptive Fields

模块出处

[MICCAI 24] [link] TinyU-Net: Lighter Yet Better U-Net with Cascaded Multi-receptive Fields


模块名称

Cascade Multi-Receptive Fields (CMRF)


模块作用

轻量感受野块


模块结构

在这里插入图片描述


模块特点
  • 起点使用PWConv(PointWise Convolution, 1×1卷积)压缩通道,终点使用PWConv恢复通道,构成bottle neck结构
  • 中间使用级联的DWConv(Depthwise Convolution, 深度卷积)提取特征

模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def autopad(k, p=None, d=1):  
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
    return p


class Conv(nn.Module):
    default_act = nn.GELU()
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        self.conv   = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn     = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        self.act    = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))
    

class DWConv(Conv):
    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)


class CMRF(nn.Module):
    def __init__(self, c1, c2, N=8, shortcut=True, g=1, e=0.5):
        super().__init__()
        self.N         = N
        self.c         = int(c2 * e / self.N)
        self.add       = shortcut and c1 == c2
        self.pwconv1   = Conv(c1, c2//self.N, 1, 1)
        self.pwconv2   = Conv(c2//2, c2, 1, 1)
        self.m         = nn.ModuleList(DWConv(self.c, self.c, k=3, act=False) for _ in range(N-1))

    def forward(self, x):
        x_residual = x
        x          = self.pwconv1(x)
        x          = [x[:, 0::2, :, :], x[:, 1::2, :, :]]
        x.extend(m(x[-1]) for m in self.m)
        x[0]       = x[0] +  x[1] 
        x.pop(1)
        y          = torch.cat(x, dim=1) 
        y          = self.pwconv2(y)
        return x_residual + y if self.add else y
    

if __name__ == '__main__':
    x = torch.randn([1, 64, 44, 44])
    cmrf = CMRF(c1=64, c2=64)
    out = cmrf(x)
    print(out.shape)  # [1, 64, 44, 44]


http://www.kler.cn/a/535829.html

相关文章:

  • 简单说一下CAP理论和Base理论
  • 【正点原子K210连载】第六十七章 音频FFT实验 摘自【正点原子】DNK210使用指南-CanMV版指南
  • 变压器-000000
  • 低代码提升交付效率的公式计算
  • openssl 中 EVP_aes_256_gcm() 函数展开
  • 114,【6】攻防世界 web wzsc_文件上传
  • 华为od 勾股数元组
  • 如何在 FastAPI 中使用本地资源自定义 Swagger UI
  • ElasticSearch 学习课程入门(二)
  • 【2024华为OD-E卷-100分-箱子之字形摆放】((题目+思路+JavaC++Python解析)
  • maxun爬虫机器人介绍与部署
  • vue文档01
  • Linux系统安装Nginx详解(适用于CentOS 7)
  • C#元组和Unity Vector3
  • vue3-响应式 toRefs
  • 旅行社项目展示微信小程序功能模块和开发流程
  • STM32G4系列微控制器深度解析
  • qt使用MQTT协议连接阿里云demo
  • 学习TCL脚本的几个步骤?
  • java开发 网络安全 java开发转网络安全
  • Deepseek 接入Word处理对话框(隐藏密钥)
  • Servlet笔记(上)
  • 深入解析二分查找算法:原理、实现与变种
  • 深度学习篇---深度学习相关知识点关键名词含义
  • MySQL 缓存机制与架构解析
  • react的antd表单校验,禁止输入空格并触发校验提示