YOLOv11改进策略【注意力机制篇】| 添加SE、CBAM、ECA、CA、Swin Transformer等注意力和多头注意力机制
前言
这篇文章带来一个经典注意力模块的汇总,虽然有些模块已经发布很久了,但后续的注意力模块也都是在此基础之上进行改进的,对于初学者来说还是有必要去学习了解一下,以加深对模块,模型的理解。
文章目录
- 前言
- 一、为什么要引入注意力机制?
- 二、SE
- 2.1 SE的原理
- 2.2 SE的实现代码
- 三、CBAM
- 3.1 CBAM的原理
- 3.2 CBAM的实现代码
- 四、ECA
- 4.1 ECA的原理
- 4.2 ECA的实现代码
- 五、CA
- 5.1 CA的原理
- 5.2 CA的实现代码
- 六、Swin Transformer
- 6.1 Swin Transformer的原理
- 6.2 Swin Transformer的实现代码
- 七、添加步骤
- 1. 修改ultralytics/nn/modules/block.py
- 2. 修改ultralytics/nn/modules/__init__.py
- 3. 修改ultralytics/nn/modules/tasks.py
- 八、yaml模型文件
- 九、成功运行结果
一、为什么要引入注意力机制?
来源:注意力机制的设计灵感来源于人类视觉系统。当我们在观察外界事物时,会自动将注意力集中在重要或感兴趣的区域,而忽略无关信息。计算机视觉中的注意力机制就是在试图模拟这一过程,以提高模型的感知和理解能力。
问题:随着图像数据量的增加,模型需要处理的信息量也随之增大。传统的卷积神经网络在处理大量数据时可能会遇到信息过载的问题,导致性能下降。注意力机制通过有选择地关注重要信息,帮助模型在海量数据中筛选出关键内容,从而提高检测精度。
好处:
- 注意力机制能够赋予输入数据的不同部分以不同的权重,使模型更加关注重要的特征信息。
- 通过生成热力图,显示模型在做出决策时关注的具体区域,有助于更好地理解模型的决策过程,增强模型的可解释性。
- 注意力机制使模型在处理不同数据集和任务时能够更灵活地调整其关注点。有助于提升模型的泛化能力,使其在面对新数据集或新任务时仍能保持较高的性能水平。
除了能够提升性能外,其最主要的还是其即插即用的特性,无论模块放在什么地方,都可以运行查看训练效果,更方便炼丹成功~
二、SE
2.1 SE的原理
通道注意力模块关注于网络中每个通道的重要性,通过为每个通道分配不同的权重,使得网络能够更加关注那些对任务更为关键的通道特征,从而提高模型的性能。其中主要涉及Squeeze
和Excitation
两个操作。
- Squeeze操作:通过
全局平均池化
将每个通道的特征图压缩为一个实数。 - Excitation操作:利用两个
全连接层
(先降维后升维)和一个ReLU激活函数
来学习通道间的依赖关系,并通过sigmoid函数
生成权重向量。 - Scale操作:将学习到的通道权重与原始特征图进行逐通道相乘,实现特征的重标定。
论文:https://arxiv.org/abs/1709.01507
源码:https://github.com/hujie-frank/SENet
2.2 SE的实现代码
import torch.nn as nn
# SE
class SE(nn.Module):
def __init__(self, c1, ratio=16):
super(SE, self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.l1 = nn.Linear(c1, c1 // ratio, bias=False)
self.relu = nn.ReLU(inplace=True)
self.l2 = nn.Linear(c1 // ratio, c1, bias=False)
self.sig = nn.Sigmoid()
def forward(self, x):
b, c, _, _ = x.size()
y = self.avgpool(x).view(b, c)
y = self.l1(y)
y = self.relu(y)
y = self.l2(y)
y = self.sig(y)
y = y.view(b, c, 1, 1)
return x * y.expand_as(x)
注意❗:在7.2、7.3小节
中的__init__.py
和tasks.py
文件中需要声明的模块名称为:SE
。
三、CBAM
3.1 CBAM的原理
CBAM注意力模块由通道注意力模块
和空间注意力模块
两部分组成。它通过顺序地应用通道注意力和空间注意力,使得网络能够自适应地关注到输入特征图中最重要的通道和空间位置,从而提高模型的表征能力。
-
通道注意力模块(CAM):
此部分的操作步骤与SE通道注意力
模块的步骤一致。 -
空间注意力模块(SAM):
- 特征提取:在通道注意力模块处理后的特征图上,分别进行基于通道维度的
最大池化
和平均池化
操作,以生成两个新的特征图。 - 特征融合:将两个池化后的特征图在通道维度上进行
拼接
(concatenate),然后通过一个卷积层进行特征融合,生成空间注意力权重图。 - 特征增强:将空间注意力权重图与原始特征图进行
逐元素相乘
,实现特征的增强,使得模型能够关注到更重要的空间位置信息。
- 特征提取:在通道注意力模块处理后的特征图上,分别进行基于通道维度的
论文:https://arxiv.org/abs/1807.06521
源码:https://github.com/Jongchan/attention-module
3.2 CBAM的实现代码
import torch
import torch.nn as nn
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
self.relu = nn.ReLU()
self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
out = self.sigmoid(avg_out + max_out)
return out
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), "kernel size must be 3 or 7"
padding = 3 if kernel_size == 7 else 1
self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# 1*h*w
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
# 2*h*w
x = self.conv(x)
# 1*h*w
return self.sigmoid(x)
class CBAM(nn.Module):
def __init__(self, c1, c2, ratio=16, kernel_size=7):
super(CBAM, self).__init__()
self.channel_attention = ChannelAttention(c1, ratio)
self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x):
out = self.channel_attention(x) * x
# c*h*w
# c*h*w * 1*h*w
out = self.spatial_attention(out) * out
return out
注意❗:在7.2、7.3小节
中的__init__.py
和tasks.py
文件中需要声明的模块名称为:CBAM
。
四、ECA
4.1 ECA的原理
ECA注意力模块的核心思想是在不增加过多计算成本和参数的情况下,通过引入一种有效的通道注意力机制,来增强网络对关键特征的关注能力。它避免了通道注意力机制中可能存在的降维操作带来的性能损失,通过一种自适应的跨通道交互策略来实现通道权重的生成。步骤如下:
- 特征压缩:这里和
SE注意力
中的Squeeze操作
一致,省略啦。 - 特征学习:
ECA
使用一维卷积来代替SE注意力
中的全连接层,来学习通道间的依赖关系。这里的一维卷积核大小是自适应的,与通道维度成正比,以确保不同通道数的特征图都能有效地进行跨通道交互。通过一维卷积,ECA
能够直接捕获局部跨通道交互信息,而无需进行复杂的降维和升维操作。 - 特征重标定:这里也和
SE注意力
中的Scale操作
一致。
论文:https://arxiv.org/abs/1910.03151
源码:https://github.com/BangguWu/ECANet
4.2 ECA的实现代码
import torch
import torch.nn as nn
class ECA(nn.Module):
def __init__(self, c1, c2, k_size=3):
super(ECA, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
y = self.avg_pool(x)
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
y = self.sigmoid(y)
return x * y.expand_as(x)
注意❗:在7.2、7.3小节
中的__init__.py
和tasks.py
文件中需要声明的模块名称为:ECA
。
五、CA
5.1 CA的原理
CA注意力模块的核心思想是将位置信息嵌入到通道注意力中,以更精确地捕捉到图像中的空间分布特征。与普通的通道注意力机制不同的是,CA不仅关注通道间的依赖关系,还通过引入坐标信息来增强模型对空间细节的敏感度。步骤如下:
- 特征分解:输入特征图通常具有C(通道数)、H(高度)、W(宽度)三个维度。CA首先通过两个并行的
全局平均池化
操作,分别沿垂直(高度)和水平(宽度)方向聚合输入特征,生成两个包含方向特定信息的特征图。这两个特征图分别捕捉了高度和宽度方向上的空间信息。 - 特征编码:将两个池化后的特征图在通道维度上
拼接
,并通过一个1x1的二维卷积层
来融合和转换特征。然后,对卷积后的特征图进行批量归一化
并非线性激活
。 - 注意力图生成:将批量归一化和激活后的特征图分裂为两个特征图,分别对应于高度和宽度方向。接着,通过另外两个
1x1的二维卷积层
分别处理这两个特征图,并应用Sigmoid激活函数
,生成两个注意力图。这两个注意力图分别沿宽度和高度方向对输入特征图进行重标定。 - 特征重加权:将通过Sigmoid激活的注意力图与原始的输入特征图相乘,以重新加权原始特征。这样,重要的特征会被放大,而不重要的特征则会减弱,从而增强模型对关键信息的关注能力。
论文:https://arxiv.org/abs/2103.02907v1
源码:https://github.com/houqb/CoordAttention
5.2 CA的实现代码
import torch
import torch.nn as nn
class h_sigmoid(nn.Module):
def __init__(self, inplace=True):
super(h_sigmoid, self).__init__()
self.relu = nn.ReLU6(inplace=inplace)
def forward(self, x):
return self.relu(x + 3) / 6
class h_swish(nn.Module):
def __init__(self, inplace=True):
super(h_swish, self).__init__()
self.sigmoid = h_sigmoid(inplace=inplace)
def forward(self, x):
return x * self.sigmoid(x)
class CoordAtt(nn.Module):
def __init__(self, inp, oup, reduction=32):
super(CoordAtt, self).__init__()
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))
mip = max(8, inp // reduction)
self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
self.bn1 = nn.BatchNorm2d(mip)
self.act = h_swish()
self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
def forward(self, x):
identity = x
n, c, h, w = x.size()
# c*1*W
x_h = self.pool_h(x)
# c*H*1
# C*1*h
x_w = self.pool_w(x).permute(0, 1, 3, 2)
y = torch.cat([x_h, x_w], dim=2)
# C*1*(h+w)
y = self.conv1(y)
y = self.bn1(y)
y = self.act(y)
x_h, x_w = torch.split(y, [h, w], dim=2)
x_w = x_w.permute(0, 1, 3, 2)
a_h = self.conv_h(x_h).sigmoid()
a_w = self.conv_w(x_w).sigmoid()
out = identity * a_w * a_h
return out
注意❗:在7.2、7.3小节
中的__init__.py
和tasks.py
文件中需要声明的模块名称为:CoordAtt
。
六、Swin Transformer
6.1 Swin Transformer的原理
Swin Transformer通过分层设计结合多个等级的窗口划分来降低计算复杂度,并提出位移窗口使相邻的窗口之间进行交互,从而达到全局建模的能力。在Swin Transformer
模型中最重要的是模块是窗口多头自注意力(W-MSA)
和移动窗口多头自注意力(SW-MSA)
,用于自注意力的计算。
- 窗口多头自注意力(W-MSA):
- 划分窗口:将特征图划分为多个固定大小的窗口。
- 自注意力计算:在每个窗口内独立计算多头自注意力,此时计算复杂度与窗口内的小块数量成线性关系,从而降低了整体计算复杂度。
- 输出:得到每个窗口内的自注意力特征图。
- 移动窗口多头自注意力(SW-MSA):
- 窗口移动:在W-MSA之后,通过移动窗口的方式改变窗口的划分,使得相邻的窗口之间能够产生交互。
- 自注意力计算:在新的窗口划分下再次计算多头自注意力。
- 输出:得到移动窗口后的自注意力特征图。
论文:https://arxiv.org/abs/2103.14030
源码:https://github.com/microsoft/Swin-Transformer
6.2 Swin Transformer的实现代码
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
nn.init.normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
# print(attn.dtype, v.dtype)
try:
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
except:
#print(attn.dtype, v.dtype)
x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
B, H, W, C = x.shape
assert H % window_size == 0, 'feature map h and w can not divide by window size'
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class SwinTransformerLayer(nn.Module):
def __init__(self, dim, num_heads, window_size=8, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
# if min(self.input_resolution) <= self.window_size:
# # if window size is larger than input resolution, we don't partition windows
# self.shift_size = 0
# self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def create_mask(self, H, W):
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x):
# reshape x[b c h w] to x[b l c]
_, _, H_, W_ = x.shape
Padding = False
if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
Padding = True
# print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
pad_r = (self.window_size - W_ % self.window_size) % self.window_size
pad_b = (self.window_size - H_ % self.window_size) % self.window_size
x = F.pad(x, (0, pad_r, 0, pad_b))
# print('2', x.shape)
B, C, H, W = x.shape
L = H * W
x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) # b, L, c
# create mask from init to forward
if self.shift_size > 0:
attn_mask = self.create_mask(H, W).to(x.device)
else:
attn_mask = None
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W) # b c h w
if Padding:
x = x[:, :, :H_, :W_] # reverse padding
return x
class SwinTransformerBlock(nn.Module):
def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
super().__init__()
self.conv = None
if c1 != c2:
self.conv = Conv(c1, c2)
# remove input_resolution
self.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])
def forward(self, x):
if self.conv is not None:
x = self.conv(x)
x = self.blocks(x)
return x
class STCSPA(nn.Module):
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super(STCSPA, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1, 1)
num_heads = c_ // 32
self.m = SwinTransformerBlock(c_, c_, num_heads, n)
#self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
def forward(self, x):
y1 = self.m(self.cv1(x))
y2 = self.cv2(x)
return self.cv3(torch.cat((y1, y2), dim=1))
class STCSPB(nn.Module):
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super(STCSPB, self).__init__()
c_ = int(c2) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1, 1)
num_heads = c_ // 32
self.m = SwinTransformerBlock(c_, c_, num_heads, n)
#self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
def forward(self, x):
x1 = self.cv1(x)
y1 = self.m(x1)
y2 = self.cv2(x1)
return self.cv3(torch.cat((y1, y2), dim=1))
class STCSPC(nn.Module):
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super(STCSPC, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(c_, c_, 1, 1)
self.cv4 = Conv(2 * c_, c2, 1, 1)
num_heads = c_ // 32
self.m = SwinTransformerBlock(c_, c_, num_heads, n)
#self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
def forward(self, x):
y1 = self.cv3(self.m(self.cv1(x)))
y2 = self.cv2(x)
return self.cv4(torch.cat((y1, y2), dim=1))
注意❗:在7.2、7.3小节
中的__init__.py
和tasks.py
文件中需要声明的模块名称为:STCSPA, STCSPB, STCSPC
,在模型中使用哪个选哪个就行。
七、添加步骤
此处在模型配置中以SE通道注意力
为例,列举的其他注意力模块添加步骤与此完全一致。
1. 修改ultralytics/nn/modules/block.py
此处需要修改的文件是ultralytics/nn/modules/block.py
common.py中定义了网络结构的通用模块
,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。
SE
添加后如下:
2. 修改ultralytics/nn/modules/init.py
此处需要修改的文件是ultralytics/nn/modules/__init__.py
__init__.py
文件中定义了所有模块的初始化,我们只需要将block.py
中的新的模块命添加到对应的函数即可。
SE
在block.py
中实现,所有要添加在:
from .block import (
C1,
C2,
...
SE
)
3. 修改ultralytics/nn/modules/tasks.py
在tasks.py
文件中,需要在两处位置添加各模块类名称。
首先:在函数声明中引入SE
其次:在parse_model函数
中注册SE模块
八、yaml模型文件
在代码配置完成后,配置模型的YAML文件。
此处以models/detect/yolov10m.yaml
为例,在同目录下创建一个用于自己数据集训练的模型文件yolov10m-SE.yaml
。
将yolov10m.yaml
中的内容复制到yolov10m-SE.yaml
文件下,修改nc
数量等于自己数据中目标的数量。
在骨干网络的最后一层添加SE模块
,只需要填入一个参数,通道数,和前一层通道数一致。还需要注意的是,由于PAN+FPN的颈部模型结构存在,层之间的匹配也要记得修改,维度要匹配上。
📌 放在此处的目的是让网络能够学习到更深层的语义信息,因为此时特征图尺寸小,包含全局信息。若是希望网络能够更加关注局部信息,可尝试将注意力模块添加到网络的浅层。
📌 当然由于其即插即用的特性,加在哪里都是可以的,但是想要真的有效,还需要根据模型结构,数据集特性等多方面因素,多做实验进行验证。
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# yolo task=detect mode=train model=yolov11m.yaml data=data.yaml device=0 epochs=300 batch=16 imgsz=640 workers=10
# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
# YOLO11n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 2, C3k2, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 2, C3k2, [1024, True]]
- [-1, 1, SE, [1024]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 2, C2PSA, [1024]] # 10
# YOLO11n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, C3k2, [512, False]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 14], 1, Concat, [1]] # cat head P4
- [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 11], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
- [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)
九、成功运行结果
打印网络模型可以看到SE模块
已经加入到模型中,并可以进行训练了。
其他模块如:CBAM、ECA、CA、Swin Transformer
这些模块和SE
的添加步骤完全一致。
并且对于这里未提到的注意力模块的添加步骤也是一样的,只要加入模块代码,并将其添加到模型中即可。
from n params module arguments
0 -1 1 1856 ultralytics.nn.modules.conv.Conv [3, 64, 3, 2]
1 -1 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
2 -1 1 111872 ultralytics.nn.modules.block.C3k2 [128, 256, 1, True, 0.25]
3 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
4 -1 1 444928 ultralytics.nn.modules.block.C3k2 [256, 512, 1, True, 0.25]
5 -1 1 2360320 ultralytics.nn.modules.conv.Conv [512, 512, 3, 2]
6 -1 1 1380352 ultralytics.nn.modules.block.C3k2 [512, 512, 1, True]
7 -1 1 2360320 ultralytics.nn.modules.conv.Conv [512, 512, 3, 2]
8 -1 1 1380352 ultralytics.nn.modules.block.C3k2 [512, 512, 1, True]
9 -1 1 1024 ultralytics.nn.modules.block.SE [512, 512]
10 -1 1 656896 ultralytics.nn.modules.block.SPPF [512, 512, 5]
11 -1 1 990976 ultralytics.nn.modules.block.C2PSA [512, 512, 1]
12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
13 [-1, 6] 1 0 ultralytics.nn.modules.conv.Concat [1]
14 -1 1 1642496 ultralytics.nn.modules.block.C3k2 [1024, 512, 1, True]
15 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
16 [-1, 4] 1 0 ultralytics.nn.modules.conv.Concat [1]
17 -1 1 542720 ultralytics.nn.modules.block.C3k2 [1024, 256, 1, True]
18 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2]
19 [-1, 14] 1 0 ultralytics.nn.modules.conv.Concat [1]
20 -1 1 1511424 ultralytics.nn.modules.block.C3k2 [768, 512, 1, True]
21 -1 1 2360320 ultralytics.nn.modules.conv.Conv [512, 512, 3, 2]
22 [-1, 11] 1 0 ultralytics.nn.modules.conv.Concat [1]
23 -1 1 1642496 ultralytics.nn.modules.block.C3k2 [1024, 512, 1, True]
24 [17, 20, 23] 1 1411795 ultralytics.nn.modules.head.Detect [1, [256, 512, 512]]
YOLOv11m-SE summary: 415 layers, 20,054,803 parameters, 20,054,787 gradients, 68.2 GFLOPs