24/11/4 Algorithm Notes: Snake Convolution
Snake Convolution is a convolution operation designed to improve feature extraction for slender, curved tubular structures. Its design is inspired by serpentine curves: the kernel can follow the fine detail of tubular structures across scales, which improves accuracy. Its core traits and mechanisms:
- Dynamic snake-shaped kernel: the kernel is not a fixed rectangle or square but follows a snake-like path. This lets it capture curved, non-straight structures in the image more flexibly and adapt to complex edges and textures.
- Adaptive weight adjustment: snake convolution adapts the kernel weights so that the network focuses on local features of tubular structures, such as the bifurcations and bends of blood vessels.
- Better detection accuracy and robustness: applying this module typically strengthens a detector's perception of targets across scales, shapes, and poses, improving both accuracy and robustness.
- Capturing complex structures: snake-like dynamic convolution aims to make the convolution operation flexible and adaptive enough to capture and represent complex structures in images.
- Dynamically adjusted kernel parameters: dynamic convolution adjusts the kernel parameters according to the input so that they fit local features; snake-like dynamic convolution is a special form of it, named after the shape of the kernel and the way it is applied.
- Topological continuity constraint loss: some implementations combine snake convolution with a continuity-constraint loss to strengthen feature extraction along continuous tubular structures.
The paper evaluates on a 3D cardiac vessel dataset and a 2D remote-sensing road dataset. Both tasks target tubular structure extraction, which is challenging because of fragile local structures and complex global morphology.
Standard convolution kernels are designed to extract local features. Deformable convolution kernels enrich this by adapting to the geometry of different targets, but they easily lose focus on slender tubular structures.
In implementation, snake convolution contains an offset convolution layer that predicts offsets, which are then used to adjust the kernel's shape. The process involves computing the offsets, generating coordinate maps, and interpolating the input features to realize the deformable convolution.
What determines the offsets:
- Offset prediction: the offsets are predicted by a dedicated convolution layer (offset_conv), typically with a small kernel (e.g. 3x3), that learns them from the input feature map. The prediction is driven by the input features themselves and is optimized by backpropagation during training.
- Batch normalization and activation: the predicted offsets pass through batch normalization (bn) and an activation function (usually tanh). Batch normalization stabilizes training, while tanh constrains the offsets to [-1, 1], keeping them from growing too large or too small.
- Dynamic adjustment of the offsets: in snake convolution the offsets are not fixed values; they vary with the shape and position of the target. This dynamic adjustment is realized by multiplying the offsets by an extension range (extend_scope), which controls how far they can reach.
- Coordinate mapping: the offsets are used to build coordinate maps (y_coordinate_map and x_coordinate_map) that define the new position of every point in the input feature map. These maps are computed from the offsets and the kernel's center position and drive the subsequent feature interpolation.
- Feature interpolation: using the coordinate maps, snake convolution interpolates the input feature map to obtain the adjusted features. The interpolation can be bilinear or bicubic, depending on the implementation and its requirements.
- Offset normalization: in some implementations, the predicted offsets are divided by the original feature-map size (offset_normalizer), so they become proportions of the map size rather than absolute pixel values; this helps the model adapt to inputs of different sizes. A minimal sketch of this step follows the list.
The adaptive adjustment process of snake convolution:
1. Offset computation
We compute the offsets with a convolution layer.
1.1 Define the offset convolution layer
First, define a convolution layer whose input channel count matches the main convolution's input channels and whose output channel count is twice the kernel size, because one offset must be predicted per kernel point in each direction (x and y).
self.offset_conv = nn.Conv2d(in_channels, 2 * kernel_size, 3, padding=1)  # a 3x3 predictor; padding=1 keeps the spatial size
1.2 Compute the offsets in the forward pass
offset = self.offset_conv(x)
1.3 Apply an activation function
To keep the offsets from becoming too large, an activation such as tanh is applied, constraining them to a reasonable range, typically [-1, 1].
offset = self.tanh(offset)
1.4 Compute the new sampling positions
Next, combine the original coordinate grid with the predicted offsets to compute the new sampling positions; this shifts the coordinate of every location.
# grid_x / grid_y form the original coordinate grid; offset holds the predicted shifts
# (this sketch uses only the first two offset channels, one per direction)
new_x = grid_x + offset[:, 0, :, :]
new_y = grid_y + offset[:, 1, :, :]
1.5 Feature interpolation
Using the new sampling positions, sample the input feature map with an interpolation method (such as bilinear interpolation) to obtain the adjusted feature map.
# input_feature is the input feature map; new_x and new_y are the new sampling positions.
# Note: F.grid_sample expects the grid normalized to [-1, 1] and shaped (B, H, W, 2).
output_feature = F.grid_sample(input_feature, torch.stack((new_x, new_y), dim=-1))
The feature interpolation function F.grid_sample
1.5.1 Implementation code
import torch
import torch.nn.functional as F

def bilinear_interpolate(img, x, y):
    # img: (B, C, H, W); x, y: float pixel coordinates
    x0 = x.floor().long()  # integer coordinates of the top-left neighbour
    x1 = x0 + 1
    y0 = y.floor().long()
    y1 = y0 + 1
    # Bilinear weights, computed before clamping
    wa = (x1.type_as(x) - x) * (y1.type_as(y) - y)
    wb = (x1.type_as(x) - x) * (y - y0.type_as(y))
    wc = (x - x0.type_as(x)) * (y1.type_as(y) - y)
    wd = (x - x0.type_as(x)) * (y - y0.type_as(y))
    # Clamp the indices to the valid image range
    x0 = torch.clamp(x0, 0, img.shape[3] - 1)
    x1 = torch.clamp(x1, 0, img.shape[3] - 1)
    y0 = torch.clamp(y0, 0, img.shape[2] - 1)
    y1 = torch.clamp(y1, 0, img.shape[2] - 1)
    Ia = img[:, :, y0, x0]
    Ib = img[:, :, y1, x0]
    Ic = img[:, :, y0, x1]
    Id = img[:, :, y1, x1]
    return wa * Ia + wb * Ib + wc * Ic + wd * Id

def grid_sample(img, grid):
    # img: (B, C, H, W); grid: (B, N, 2) with normalized coordinates in [-1, 1]
    B, C, H, W = img.shape
    N = grid.shape[1]
    grid_x = grid[:, :, 0]
    grid_y = grid[:, :, 1]
    # Map normalized coordinates back to pixel coordinates
    x = (grid_x + 1) / 2 * (W - 1)
    y = (grid_y + 1) / 2 * (H - 1)
    img_flat = img.view(B, C, H * W)
    # Get base coordinates and linear weights for interpolation
    x0 = x.floor().long()
    x1 = x0 + 1
    y0 = y.floor().long()
    y1 = y0 + 1
    wa = ((x1.float() - x) * (y1.float() - y)).unsqueeze(1)
    wb = ((x1.float() - x) * (y - y0.float())).unsqueeze(1)
    wc = ((x - x0.float()) * (y1.float() - y)).unsqueeze(1)
    wd = ((x - x0.float()) * (y - y0.float())).unsqueeze(1)
    # Clamp the indices so out-of-range samples replicate the border pixels
    x0 = torch.clamp(x0, 0, W - 1)
    x1 = torch.clamp(x1, 0, W - 1)
    y0 = torch.clamp(y0, 0, H - 1)
    y1 = torch.clamp(y1, 0, H - 1)
    # Interpolate: the flat index into the (H * W) dimension is row * W + column
    def gather(yy, xx):
        idx = (yy * W + xx).unsqueeze(1).expand(B, C, N)
        return torch.gather(img_flat, 2, idx)
    Ia = gather(y0, x0)
    Ib = gather(y1, x0)
    Ic = gather(y0, x1)
    Id = gather(y1, x1)
    return wa * Ia + wb * Ib + wc * Ic + wd * Id

# Example usage
img = torch.randn(1, 3, 4, 4)
grid = torch.tensor([[[0.5, 0.5], [-0.5, 1.5]]])  # (1, 2, 2): two sample points
output = grid_sample(img, grid)
print(output)  # shape (1, 3, 2): one interpolated value per channel per point
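As a sanity check, the handwritten version can be compared with PyTorch's built-in F.grid_sample; note the built-in expects the grid shaped (B, H_out, W_out, 2) with (x, y) order in the last dimension:
import torch
import torch.nn.functional as F

img = torch.randn(1, 3, 4, 4)
grid = torch.tensor([[[[0.5, 0.5], [-0.5, 0.5]]]])  # (1, 1, 2, 2): one row of two sample points
ref = F.grid_sample(img, grid, mode='bilinear', align_corners=True)
print(ref.shape)  # torch.Size([1, 3, 1, 2])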
1.5.2 Define the bilinear interpolation function
def bilinear_interpolate(img, x, y):
    # ...
1.5.3计算插值权重
x = x.floor().int() #floor() 函数是向下取整,int() 函数是将浮点数转换为整数。
y = y.floor().int()
x0 = torch.clamp(x, 0, img.shape[3]-1) #torch.clamp 函数将 x 变量中的每个元素限制在一个范围内。
x1 = x0 + 1
y0 = torch.clamp(y, 0, img.shape[2]-1)
y1 = y0 + 1
1.5.4 Gather the surrounding pixel values
Ia = img[:, :, y0, x0]
Ib = img[:, :, y1, x0]
Ic = img[:, :, y0, x1]
Id = img[:, :, y1, x1]
1.5.5 Compute the interpolation
return wa * Ia + wb * Ib + wc * Ic + wd * Id
We multiply the interpolation weights wa, wb, wc and wd by the corresponding surrounding pixel values and sum the results to obtain the interpolated value.
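A tiny worked example of these weights: sampling at (x, y) = (1.3, 2.6) gives x0 = 1, x1 = 2, y0 = 2, y1 = 3, so wa = 0.7 × 0.4 = 0.28, wb = 0.7 × 0.6 = 0.42, wc = 0.3 × 0.4 = 0.12 and wd = 0.3 × 0.6 = 0.18. The four weights always sum to 1, and the nearest neighbour (here Ib, the pixel at (y1, x0)) gets the largest weight.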
1.5.6 Define the grid_sample function
def grid_sample(img, grid):
    # img: (B, C, H, W); grid: (B, N, 2) with normalized coordinates in [-1, 1]
    B, C, H, W = img.shape
    N = grid.shape[1]
    grid_x = grid[:, :, 0]
    grid_y = grid[:, :, 1]
1.5.7 Map the normalized grid coordinates to pixel coordinates
x = (grid_x + 1) / 2 * (W - 1)
y = (grid_y + 1) / 2 * (H - 1)
1.5.8 Flatten the feature map
img_flat = img.view(B, C, H * W)  # view() reshapes the tensor so pixels can be gathered with a flat index
1.5.9 Compute the interpolation coordinates and weights
# Get base coordinates and linear weights for interpolation
x0 = x.floor().long()
x1 = x0 + 1
y0 = y.floor().long()
y1 = y0 + 1
wa = ((x1.float() - x) * (y1.float() - y)).unsqueeze(1)  # unsqueeze(1) adds a channel dimension so the weights broadcast against the gathered (B, C, N) values
wb = ((x1.float() - x) * (y - y0.float())).unsqueeze(1)
wc = ((x - x0.float()) * (y1.float() - y)).unsqueeze(1)
wd = ((x - x0.float()) * (y - y0.float())).unsqueeze(1)
# Clamp the indices so out-of-range samples replicate the border pixels
x0 = torch.clamp(x0, 0, W - 1)
x1 = torch.clamp(x1, 0, W - 1)
y0 = torch.clamp(y0, 0, H - 1)
y1 = torch.clamp(y1, 0, H - 1)
1.5.10 Interpolation sampling
# Interpolate: the flat index into the (H * W) dimension is row * W + column
def gather(yy, xx):
    idx = (yy * W + xx).unsqueeze(1).expand(B, C, N)
    return torch.gather(img_flat, 2, idx)
Ia = gather(y0, x0)
Ib = gather(y1, x0)
Ic = gather(y0, x1)
Id = gather(y1, x1)
return wa * Ia + wb * Ib + wc * Ic + wd * Id  # each weight is the contribution of the corresponding neighbouring pixel
2. Applying the offsets
Next, the computed offsets are applied to the convolution kernel to adjust its sampling positions. This amounts to generating new coordinate maps and using them to resample the input feature map. The forward sketch below simplifies the paper's coordinate-map construction (one offset map per kernel point, split into y- and x-halves):
def forward(self, f):
    offset = self.offset_conv(f)
    offset = self.bn(offset)
    offset = torch.tanh(offset)
    num_batch, num_channels, height, width = f.shape
    num_points = self.kernel_size
    # Split the 2 * kernel_size offset channels into y- and x-halves,
    # one offset map per kernel point (y first, following the reference code)
    y_offset, x_offset = torch.split(offset, num_points, dim=1)
    # Generate the base grid: kernel points spread along y, a single column along x
    y = torch.linspace(-(num_points // 2), num_points // 2, num_points)
    x = torch.zeros(1)
    y, x = torch.meshgrid(y, x, indexing='ij')
    y_spread = y.reshape(-1, 1)
    x_spread = x.reshape(-1, 1)
    y_grid = y_spread.repeat([1, width * height]).reshape([num_points, height, width]).unsqueeze(0)  # (1, k, H, W)
    x_grid = x_spread.repeat([1, width * height]).reshape([num_points, height, width]).unsqueeze(0)  # (1, k, H, W)
    # Apply the offsets, scaled by extend_scope, to obtain the new sampling coordinates
    if self.if_offset:
        y_new = y_grid + y_offset * self.extend_scope
        x_new = x_grid + x_offset * self.extend_scope
    else:
        y_new, x_new = y_grid, x_grid
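A quick shape check of that split (toy sizes, names as in the sketch above):
import torch

kernel_size = 3
offset = torch.randn(2, 2 * kernel_size, 16, 16)  # as produced by offset_conv, after bn and tanh
y_offset, x_offset = torch.split(offset, kernel_size, dim=1)
print(y_offset.shape, x_offset.shape)  # torch.Size([2, 3, 16, 16]) twice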
3. Feature interpolation
Finally, the new coordinates are used to interpolate the input feature map and obtain the adjusted features.
def _bilinear_interpolate_3D(self, input, y, x):
        # Bilinear interpolation
B, C, H, W = input.size()
x = x.contiguous()
y = y.contiguous()
x0 = torch.floor(x).long()
y0 = torch.floor(y).long()
x1 = x0 + 1
y1 = y0 + 1
x0 = torch.clamp(x0, 0, W - 1)
x1 = torch.clamp(x1, 0, W - 1)
y0 = torch.clamp(y0, 0, H - 1)
y1 = torch.clamp(y1, 0, H - 1)
Ia = input[:, :, y0, x0]
Ib = input[:, :, y1, x0]
Ic = input[:, :, y0, x1]
Id = input[:, :, y1, x1]
wa = (x1.type_as(x) - x) * (y1.type_as(y) - y)
wb = (x1.type_as(x) - x) * (y - y0.type_as(y))
wc = (x - x0.type_as(x)) * (y1.type_as(y) - y)
wd = (x - x0.type_as(x)) * (y - y0.type_as(y))
return wa * Ia + wb * Ib + wc * Ic + wd * Id
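A quick shape check of the advanced indexing used above (toy sizes; this is just the indexing pattern, not the full module):
import torch

inp = torch.randn(1, 2, 5, 5)
ys, xs = torch.meshgrid(torch.arange(5), torch.arange(5), indexing='ij')
# Indexing with two (5, 5) integer maps gathers one value per (row, col) pair
print(inp[:, :, ys, xs].shape)  # torch.Size([1, 2, 5, 5])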
The full snake convolution code
import torch
import torch.nn as nn
__all__ = ['C3k2_DSConv']
def autopad(k, p=None, d=1):
"""Pad to 'same' shape outputs."""
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
return p
class Conv(nn.Module):
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
default_act = nn.SiLU()
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given arguments including activation."""
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x)))
    def forward_fuse(self, x):
        """Apply fused convolution and activation (batch normalization folded into the conv)."""
        return self.act(self.conv(x))
class DySnakeConv(nn.Module):
def __init__(self, inc, ouc, k=3) -> None:
super().__init__()
self.conv_0 = Conv(inc, ouc, k)
self.conv_x = DSConv(inc, ouc, 0, k)
self.conv_y = DSConv(inc, ouc, 1, k)
def forward(self, x):
return torch.cat([self.conv_0(x), self.conv_x(x), self.conv_y(x)], dim=1)
class DSConv(nn.Module):
def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1):
"""
The Dynamic Snake Convolution
:param in_ch: input channel
:param out_ch: output channel
:param kernel_size: the size of kernel
:param extend_scope: the range to expand (default 1 for this method)
:param morph: the morphology of the convolution kernel is mainly divided into two types
along the x-axis (0) and the y-axis (1) (see the paper for details)
:param if_offset: whether deformation is required, if it is False, it is the standard convolution kernel
"""
super(DSConv, self).__init__()
self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
self.bn = nn.BatchNorm2d(2 * kernel_size)
self.kernel_size = kernel_size
self.dsc_conv_x = nn.Conv2d(
in_ch,
out_ch,
kernel_size=(kernel_size, 1),
stride=(kernel_size, 1),
padding=0,
)
self.dsc_conv_y = nn.Conv2d(
in_ch,
out_ch,
kernel_size=(1, kernel_size),
stride=(1, kernel_size),
padding=0,
)
self.gn = nn.GroupNorm(out_ch // 4, out_ch)
self.act = Conv.default_act
self.extend_scope = extend_scope
self.morph = morph
self.if_offset = if_offset
def forward(self, f):
offset = self.offset_conv(f)
offset = self.bn(offset)
offset = torch.tanh(offset)
input_shape = f.shape
# DSC function is not provided in the snippet, you would need to implement it or find the complete implementation
# dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph)
# return dsc(f, offset)
# For now, this is a placeholder to show where the DSC function would be called
raise NotImplementedError("DSC function is not implemented in this snippet.")
# Usage
# Create an instance of DySnakeConv
dy_snake_conv = DySnakeConv(inc=3, ouc=64, k=3)
Next, an explanation of each code segment:
1. Import the necessary libraries:
import torch
import torch.nn as nn
2. Define the module exports:
This line defines a list with the single element 'C3k2_DSConv'. If someone imports the module with from your_module import * (assuming the module is named your_module), only the name C3k2_DSConv is exported. Note that C3k2_DSConv itself is never defined in this snippet.
__all__ = ['C3k2_DSConv']
3. The autopad helper function:
def autopad(k, p=None, d=1):
"""Pad to 'same' shape outputs."""
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
return p
This function automatically computes the padding so that the convolution output keeps the same spatial size as the input ('same' padding). A few quick checks of its behaviour follow.
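Plain assertions, assuming autopad as defined above:
assert autopad(3) == 1            # 3x3 kernel -> pad 1 keeps the 'same' shape
assert autopad(5) == 2            # 5x5 kernel -> pad 2
assert autopad(3, d=2) == 2       # dilation 2 makes the effective kernel 5, so pad 2
assert autopad((3, 5)) == [1, 2]  # per-dimension kernels give per-dimension padding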
4. Define the standard convolution layer:
class Conv(nn.Module):
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
default_act = nn.SiLU()
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given arguments including activation."""
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x)))
    def forward_fuse(self, x):
        """Apply fused convolution and activation (batch normalization folded into the conv)."""
        return self.act(self.conv(x))
5. Define the dynamic snake convolution module:
class DySnakeConv(nn.Module):
def __init__(self, inc, ouc, k=3) -> None:
super().__init__()
self.conv_0 = Conv(inc, ouc, k)
self.conv_x = DSConv(inc, ouc, 0, k)
self.conv_y = DSConv(inc, ouc, 1, k)
def forward(self, x):
return torch.cat([self.conv_0(x), self.conv_x(x), self.conv_y(x)], dim=1)
This class defines a dynamic snake convolution module with three convolution layers: one standard convolution and two snake convolutions (one for the x-axis, one for the y-axis). The forward method concatenates the three outputs along the channel dimension; a shape sketch follows below.
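A shape sketch of that concatenation (hedged: it assumes every branch preserves H and W, as the paper's implementation does; the DSConv placeholder above does not run yet, so the call is commented out):
x = torch.randn(2, 3, 32, 32)
m = DySnakeConv(inc=3, ouc=16, k=3)
# y = m(x)  # each branch would give (2, 16, 32, 32), so y.shape would be (2, 48, 32, 32)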
6. Define the snake convolution layer:
import torch.nn.functional as F  # needed for F.grid_sample below

class DSConv(nn.Module):
    def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1):
        super(DSConv, self).__init__()
        self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
        self.bn = nn.BatchNorm2d(2 * kernel_size)
        self.kernel_size = kernel_size
        self.dsc_conv_x = nn.Conv2d(in_ch, out_ch, kernel_size=(kernel_size, 1), stride=(kernel_size, 1), padding=0)
        self.dsc_conv_y = nn.Conv2d(in_ch, out_ch, kernel_size=(1, kernel_size), stride=(1, kernel_size), padding=0)
        self.gn = nn.GroupNorm(out_ch // 4, out_ch)
        self.act = nn.SiLU()  # Assuming Conv.default_act is SiLU
        self.extend_scope = extend_scope
        self.morph = morph
        self.if_offset = if_offset
    def forward(self, f):
        offset = self.offset_conv(f)
        offset = self.bn(offset)
        offset = torch.tanh(offset)  # Limit the offsets to the range (-1, 1)
        # Generate the base sampling grid, normalized to [-1, 1]
        B, C, H, W = f.shape
        grid_y, grid_x = torch.meshgrid(torch.arange(H, device=f.device), torch.arange(W, device=f.device), indexing='ij')
        grid_x = grid_x.float() / (W - 1) * 2 - 1
        grid_y = grid_y.float() / (H - 1) * 2 - 1
        grid = torch.stack((grid_x, grid_y), dim=-1)[None].repeat(B, 1, 1, 1)  # (B, H, W, 2), (x, y) order
        # Apply the offsets to the grid (this simplified sketch uses only the first
        # two of the 2 * kernel_size offset channels: one x-shift and one y-shift per pixel)
        if self.if_offset:
            offset_x = offset[:, 0, :, :] * self.extend_scope
            offset_y = offset[:, 1, :, :] * self.extend_scope
            new_grid = grid + torch.stack((offset_x, offset_y), dim=-1)
        else:
            new_grid = grid
        # Sample the input feature map at the deformed grid
        sampled_features = F.grid_sample(f, new_grid, align_corners=True)
        # Convolve with the kernel matching the chosen morphology
        if self.morph == 0:  # kernel of shape (kernel_size, 1)
            output = self.dsc_conv_x(sampled_features)
        elif self.morph == 1:  # kernel of shape (1, kernel_size)
            output = self.dsc_conv_y(sampled_features)
        else:
            raise ValueError("Morph value must be 0 or 1")
        # Apply group normalization and activation
        output = self.gn(output)
        output = self.act(output)
        return output
Let's walk through the forward function.
Computing the offsets:
offset = self.offset_conv(f)
offset = self.bn(offset)
offset = torch.tanh(offset) # Limit the offset to range (-1, 1)
Generating the sampling grid:
B, C, H, W = f.shape
grid_y, grid_x = torch.meshgrid(torch.arange(H, device=f.device), torch.arange(W, device=f.device), indexing='ij')
grid_x = grid_x.float() / (W - 1) * 2 - 1
grid_y = grid_y.float() / (H - 1) * 2 - 1
grid = torch.stack((grid_x, grid_y), dim=-1)[None].repeat(B, 1, 1, 1)  # (B, H, W, 2)
torch.meshgrid generates two 2D grids, grid_y and grid_x, holding the y and x coordinate of every pixel; with indexing='ij' the first output varies along rows. torch.arange produces the sequences 0 to H-1 and 0 to W-1 on the input's device. Both grids are then rescaled to the [-1, 1] range expected by F.grid_sample. A quick demonstration of the meshgrid convention follows.
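ys, xs = torch.meshgrid(torch.arange(3), torch.arange(4), indexing='ij')
# ys.shape == xs.shape == (3, 4); ys[i, j] == i and xs[i, j] == j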
Applying the offsets:
if self.if_offset:
    offset_x = offset[:, 0, :, :] * self.extend_scope
    offset_y = offset[:, 1, :, :] * self.extend_scope
    new_grid = grid + torch.stack((offset_x, offset_y), dim=-1)
else:
    new_grid = grid
If deformation is enabled, the offsets (scaled by extend_scope) are added to the base grid to produce the new sampling grid new_grid.
Sampling the input feature map:
sampled_features = F.grid_sample(f, new_grid, align_corners=True)
F.grid_sample samples the input feature map f at the positions given by new_grid. A small check of its coordinate convention follows.
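With align_corners=True, the normalized point (-1, -1) maps exactly onto the top-left pixel:
f4 = torch.arange(16.0).view(1, 1, 4, 4)
corner = torch.tensor([[[[-1.0, -1.0]]]])  # (B, H_out=1, W_out=1, 2), last dim is (x, y)
print(F.grid_sample(f4, corner, align_corners=True))  # tensor([[[[0.]]]]) -> the top-left pixel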
Convolution along the chosen morphology:
if self.morph == 0:  # kernel of shape (kernel_size, 1)
    output = self.dsc_conv_x(sampled_features)
elif self.morph == 1:  # kernel of shape (1, kernel_size)
    output = self.dsc_conv_y(sampled_features)
else:
    raise ValueError("Morph value must be 0 or 1")
Applying group normalization and the activation function:
output = self.gn(output)
output = self.act(output)
Returning the result:
return output
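To close, a smoke test of the sketch above (hedged: unlike the paper's implementation, this simplified version shrinks the output height by the stride of dsc_conv_x):
conv = DSConv(in_ch=3, out_ch=8, morph=0, kernel_size=3)
x = torch.randn(1, 3, 32, 32)
y = conv(x)
print(y.shape)  # torch.Size([1, 8, 10, 32]): H is reduced by the (3, 1) strided kernel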