每日Attention学习21——Cascade Multi-Receptive Fields
模块出处
[MICCAI 2024] TinyU-Net: Lighter Yet Better U-Net with Cascaded Multi-receptive Fields
模块名称
Cascade Multi-Receptive Fields (CMRF)
模块作用
轻量感受野块
模块结构
模块特点
- 起点使用PWConv(Pointwise Convolution, 1×1卷积)压缩通道，终点使用PWConv恢复通道，构成bottleneck结构
- 中间将压缩后的特征按奇偶通道拆分为两组，其中一组送入级联的DWConv(Depthwise Convolution, 深度卷积)：每一级以上一级的输出为输入，感受野逐级增大；最后将两组相加的结果与各级DWConv的输出在通道维拼接，送入终点PWConv
模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def autopad(k, p=None, d=1):
    """Compute 'same'-style padding for a (possibly dilated) kernel.

    Args:
        k: kernel size, either an int or an iterable of per-dimension sizes.
        p: explicit padding; when not None it is returned unchanged.
        d: dilation factor.

    Returns:
        ``p`` if given, otherwise half the effective kernel size (an int for an
        int ``k``, else a list with one entry per dimension).
    """
    if d > 1:
        # A dilated kernel spans d * (k - 1) + 1 input positions.
        if isinstance(k, int):
            k = d * (k - 1) + 1
        else:
            k = [d * (size - 1) + 1 for size in k]
    if p is None:
        if isinstance(k, int):
            p = k // 2
        else:
            p = [size // 2 for size in k]
    return p
class Conv(nn.Module):
    """Conv2d (no bias) followed by BatchNorm2d and an activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: padding; ``None`` means 'same'-style padding via ``autopad``.
        g: number of conv groups.
        d: dilation.
        act: ``True`` selects ``default_act``; an ``nn.Module`` is used as-is;
            anything else disables the activation (``nn.Identity``).
    """

    # Class-level default activation; this single module object is shared by
    # every instance constructed with act=True.
    default_act = nn.GELU()

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        padding = autopad(k, p, d)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply conv -> batch norm -> activation."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

    def forward_fuse(self, x):
        """Apply conv -> activation, skipping batch norm.

        NOTE(review): presumably intended for use after BN has been folded into
        the conv weights; that fusion step is not shown here — confirm.
        """
        return self.act(self.conv(x))
class DWConv(Conv):
    """Grouped ``Conv`` with ``groups = gcd(c1, c2)``.

    When ``c1 == c2`` the group count equals the channel count, i.e. a
    depthwise convolution (one filter group per channel).
    """

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
        groups = math.gcd(c1, c2)
        super().__init__(c1, c2, k=k, s=s, g=groups, d=d, act=act)
class CMRF(nn.Module):
    """Cascade Multi-Receptive Fields (CMRF) block (TinyU-Net).

    Data flow:
      1. ``pwconv1`` (1x1) maps ``c1`` -> ``2 * self.c`` channels.
      2. The result is split into even- and odd-indexed channels (``self.c`` each).
      3. The odd half passes through a cascade of ``N - 1`` 3x3 DWConvs; each
         stage consumes the previous stage's output.
      4. The even half is added to the odd half, and this sum plus every cascade
         output (``N`` maps of ``self.c`` channels) are concatenated.
      5. ``pwconv2`` (1x1) maps ``N * self.c`` -> ``c2`` channels.
      6. The input is added back when ``shortcut`` is set and ``c1 == c2``.

    Args:
        c1: input channels.
        c2: output channels.
        N: number of concatenated maps; the cascade has ``N - 1`` DWConv stages.
        shortcut: enable the residual connection (only effective if ``c1 == c2``).
        g: unused; kept for signature compatibility.
        e: width ratio; the concatenated width is ``N * int(c2 * e / N)``.

    Raises:
        ValueError: if ``N < 1`` or ``int(c2 * e / N) < 1``.
    """

    def __init__(self, c1, c2, N=8, shortcut=True, g=1, e=0.5):
        super().__init__()
        if N < 1:
            raise ValueError(f"CMRF requires N >= 1, got N={N}")
        self.N = N
        self.c = int(c2 * e / self.N)
        if self.c < 1:
            raise ValueError(
                f"CMRF branch width int(c2 * e / N) must be >= 1, got c2={c2}, e={e}, N={N}"
            )
        self.add = shortcut and c1 == c2
        # Both projection widths derive from self.c so the even/odd split, the
        # cascade and the concatenation always agree on channel counts.
        # For e=0.5 and c2 divisible by 2*N these are c2 // N and c2 // 2.
        self.pwconv1 = Conv(c1, 2 * self.c, 1, 1)
        self.pwconv2 = Conv(self.N * self.c, c2, 1, 1)
        self.m = nn.ModuleList(DWConv(self.c, self.c, k=3, act=False) for _ in range(N - 1))

    def forward(self, x):
        """Run the block; spatial size is preserved (1x1 convs and padded 3x3 DWConvs)."""
        x_residual = x
        x = self.pwconv1(x)
        # Even / odd channel halves, self.c channels each.
        branches = [x[:, 0::2, :, :], x[:, 1::2, :, :]]
        # Cascade: each stage reads the most recently produced map, so the
        # receptive field grows with every stage.
        for stage in self.m:
            branches.append(stage(branches[-1]))
        # Fuse the untouched even half into the odd half, then drop the odd half.
        branches[0] = branches[0] + branches[1]
        branches.pop(1)
        y = self.pwconv2(torch.cat(branches, dim=1))
        return x_residual + y if self.add else y
if __name__ == '__main__':
    # Smoke test: a 64-channel feature map should come back with the same shape.
    dummy_input = torch.randn(1, 64, 44, 44)
    block = CMRF(c1=64, c2=64)
    result = block(dummy_input)
    print(result.shape)  # torch.Size([1, 64, 44, 44])