【深度学习】Pytorch:在 ResNet 中加入注意力机制
在这篇教程中,我们将介绍如何在 ResNet 网络中加入注意力机制模块。我们将通过对标准 ResNet50 进行改进,向网络中添加两个自定义的注意力模块,并展示如何实现这一过程。
为什么要加入注意力机制
注意力机制可以帮助神经网络专注于图像中重要的特征区域,从而提高模型的性能。在卷积神经网络中,加入注意力机制能够有效增强特征提取能力,减少冗余信息的干扰,尤其在处理复杂图像时,能够提升网络的表现。
在本教程中,我们将使用一种通用的注意力模块,您可以根据需求自行替换或改进该模块。
代码实现
导入依赖
我们需要以下 PyTorch 库来构建网络:
import torch
import torch.nn as nn
from torchvision import models
定义注意力模块
首先,我们需要定义一个注意力模块。这里我们使用了一个简单的通道注意力机制(如 SE 模块、CBAM 模块等),你可以根据需求选择不同类型的注意力模块。
假设我们已经有一个注意力模块类(AttentionModule
),它的结构可以像下面这样:
class AttentionModule(nn.Module):
def __init__(self, in_channels):
super(AttentionModule, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 16, kernel_size=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(in_channels // 16, in_channels, kernel_size=1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
attention = self.conv1(x)
attention = self.relu(attention)
attention = self.conv2(attention)
attention = self.sigmoid(attention)
return x * attention
这段代码定义了一个简单的注意力模块。它通过两个卷积层和一个 Sigmoid 函数来生成一个通道注意力映射,并通过该映射加权输入特征图。
构建 ResNet 与注意力机制集成的模型
现在我们将创建一个新的模型类 ResNetWithAttention
,该模型继承自 nn.Module
,并将注意力模块插入到 ResNet 的关键位置。在这个示例中,我们将注意力模块插入到网络的卷积层输出之后,并在最后一层卷积层后再次插入。
class ResNetWithAttention(nn.Module):
def __init__(self, attention_cls, pretrained=True):
super(ResNetWithAttention, self).__init__()
# 使用预训练的 ResNet50
self.base_model = models.resnet50(pretrained=pretrained)
# 创建注意力模块
self.attention_layer1 = attention_cls(64) # 第一层卷积后
self.attention_layer2 = attention_cls(2048) # 最后一层卷积后
def forward(self, x):
# ResNet50的前向传播过程
x = self.base_model.conv1(x) # 初始卷积层
x = self.base_model.bn1(x) # 批归一化
x = self.base_model.relu(x) # 激活函数
# 第一个注意力模块:第一层卷积后
x = self.attention_layer1(x)
# 最大池化层
x = self.base_model.maxpool(x)
# ResNet的残差层
x = self.base_model.layer1(x)
x = self.base_model.layer2(x)
x = self.base_model.layer3(x)
x = self.base_model.layer4(x)
# 第二个注意力模块:最后一层卷积后
x = self.attention_layer2(x)
# 平均池化
x = self.base_model.avgpool(x)
# 展平并通过全连接层
x = torch.flatten(x, 1)
x = self.base_model.fc(x)
return x
在这个模型中,我们通过 attention_cls
参数动态地将任何类型的注意力模块传入模型。模型首先使用基础的 ResNet50 结构,之后我们将自定义的注意力模块应用到两个关键位置:一个是在第一层卷积之后,另一个是在最后的卷积层之后。
训练模型
使用该模型的训练过程与标准的 ResNet 模型相同。你可以像使用普通的 ResNet 模型一样训练和评估 ResNetWithAttention
。下面是训练的一般流程:
# 示例:初始化模型并进行训练
attention_cls = AttentionModule # 可以替换为其他类型的注意力模块
model = ResNetWithAttention(attention_cls)
# 选择优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# 假设 train_loader 是数据加载器
for epoch in range(num_epochs):
model.train()
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
可加入的注意力模块
通道注意力模块
-
SE (Squeeze-and-Excitation) 模块:最经典的通道注意力模块,使用全局平均池化后生成通道级注意力,通过全连接层建模通道之间的关系。
class SEBlock(nn.Module): def __init__(self, in_channels, reduction=16): super(SEBlock, self).__init__() # 定义第一个全连接层,将输入通道数压缩为 in_channels // reduction self.fc1 = nn.Linear(in_channels, in_channels // reduction) # 定义第二个全连接层,将通道数恢复为原始输入通道数 self.fc2 = nn.Linear(in_channels // reduction, in_channels) # 定义Sigmoid激活函数,用于生成注意力权重 self.sigmoid = nn.Sigmoid() def forward(self, x): # 获取输入张量的 batch_size 和 channels batch_size, channels, _, _ = x.size() # 对输入张量的空间维度(高度和宽度)进行全局平均池化 squeeze = torch.mean(x, dim=(2, 3)) # 通过第一个全连接层进行通道压缩 squeeze = self.fc1(squeeze) # 通过ReLU激活函数和第二个全连接层进行通道扩展 squeeze = self.fc2(F.relu(squeeze)) # 使用Sigmoid生成注意力权重,并调整形状以匹配输入张量的维度 attention = self.sigmoid(squeeze).view(batch_size, channels, 1, 1) # 将注意力权重应用到输入张量上,进行通道加权 return x * attention
-
ECA (Efficient Channel Attention) 模块:通过 1D 卷积建模通道间的依赖关系,减少了计算量,提升了效率
class ECABlock(nn.Module): def __init__(self, channels, kernel_size=3): super(ECABlock, self).__init__() # 定义1D卷积层,用于学习通道间的注意力权重 self.conv = nn.Conv1d(1, 1, kernel_size, padding=kernel_size // 2, bias=False) def forward(self, x): # 获取输入张量的 batch_size 和 channels batch_size, channels, _, _ = x.size() # 对输入张量的空间维度(高度和宽度)进行全局平均池化,并调整形状 y = F.adaptive_avg_pool2d(x, 1).view(batch_size, channels, 1) # 调整形状以适配1D卷积的输入格式 y = y.view(batch_size, 1, channels) # 通过1D卷积层学习通道间的注意力权重 y = self.conv(y) # 使用Sigmoid激活函数生成注意力权重 y = torch.sigmoid(y) # 将注意力权重应用到输入张量上,进行通道加权 return x * y.view(batch_size, channels, 1, 1).expand_as(x)
空间注意力模块
-
CBAM (Convolutional Block Attention Module) 模块:结合了通道注意力和空间注意力,首先进行通道注意力加权,然后通过空间卷积生成空间注意力。
class CBAM(nn.Module): def __init__(self, in_channels, reduction=16): super(CBAM, self).__init__() # 通道注意力模块(SEBlock),用于学习通道间的注意力权重 self.channel_attention = SEBlock(in_channels, reduction) # 空间注意力模块,使用1x1卷积核学习空间注意力权重 self.spatial_attention = nn.Conv2d(2, 1, kernel_size=7, padding=3) def forward(self, x): # 应用通道注意力模块,对输入特征进行通道加权 x = self.channel_attention(x) # 计算输入特征在通道维度上的平均值 avg_out = torch.mean(x, dim=1, keepdim=True) # 计算输入特征在通道维度上的最大值 max_out, _ = torch.max(x, dim=1, keepdim=True) # 将平均值和最大值拼接在一起 spatial_out = torch.cat([avg_out, max_out], dim=1) # 通过空间注意力模块学习空间注意力权重 spatial_out = self.spatial_attention(spatial_out) # 使用Sigmoid激活函数生成空间注意力权重 spatial_attention = torch.sigmoid(spatial_out) # 将空间注意力权重应用到输入特征上,进行空间加权 return x * spatial_attention
-
Coordinate Attention 模块:通过空间坐标信息提升特征建模能力,增强空间特征表达。
class CoordinateAttention(nn.Module): def __init__(self, in_channels, reduction=16): super(CoordinateAttention, self).__init__() # 定义1x1卷积层,用于压缩通道数 self.fc = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1) # 定义1x1卷积层,用于恢复通道数 self.fc_out = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1) def forward(self, x): # 获取输入张量的 batch_size、channels、height 和 width batch_size, channels, height, width = x.size() # 对输入张量的空间维度(高度和宽度)进行全局平均池化 avg_out = torch.mean(x, dim=[2, 3], keepdim=True) # 对输入张量的空间维度(高度和宽度)进行全局最大池化 max_out = torch.amax(x, dim=[2, 3], keepdim=True) # 通过1x1卷积层压缩通道数 avg_out = self.fc(avg_out) max_out = self.fc(max_out) # 通过1x1卷积层恢复通道数 avg_out = self.fc_out(avg_out) max_out = self.fc_out(max_out) # 将平均池化和最大池化的结果相加,并应用到输入张量上 out = x * (avg_out + max_out) return out
双重注意力模块
-
Dual Attention 模块:结合了通道和空间的双重注意力机制,增强了特征的表征能力。
class DualAttentionBlock(nn.Module): def __init__(self, in_channels, reduction=16): super(DualAttentionBlock, self).__init__() # 通道注意力模块(SEBlock),用于学习通道间的注意力权重 self.channel_attention = SEBlock(in_channels, reduction) # 空间注意力模块(CBAM),用于学习空间上的注意力权重 self.spatial_attention = CBAM(in_channels, reduction) def forward(self, x): # 应用通道注意力模块,对输入特征进行通道加权 x = self.channel_attention(x) # 应用空间注意力模块,对输入特征进行空间加权 x = self.spatial_attention(x) return x
全局依赖建模模块
-
Non-local 模块:通过自注意力机制建模全局依赖关系,提升对长距离特征的建模能力。
class NonLocalBlock(nn.Module): def __init__(self, in_channels): super(NonLocalBlock, self).__init__() # 输入通道数 self.in_channels = in_channels # 中间通道数,通常为输入通道数的一半 self.inter_channels = in_channels // 2 # 定义1x1卷积层,用于生成查询(query)特征 self.query_conv = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1) # 定义1x1卷积层,用于生成键(key)特征 self.key_conv = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1) # 定义1x1卷积层,用于生成值(value)特征 self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1) # 定义Softmax函数,用于计算注意力权重 self.softmax = nn.Softmax(dim=-1) def forward(self, x): # 获取输入张量的 batch_size、通道数、高度和宽度 batch_size, C, H, W = x.size() # 通过查询卷积层生成查询特征,并调整形状 query = self.query_conv(x).view(batch_size, self.inter_channels, -1) # 通过键卷积层生成键特征,并调整形状 key = self.key_conv(x).view(batch_size, self.inter_channels, -1) # 通过值卷积层生成值特征,并调整形状 value = self.value_conv(x).view(batch_size, C, -1) # 计算查询特征和键特征的相似度(亲和矩阵) affinity = torch.bmm(query.transpose(1, 2), key) # 使用Softmax计算注意力权重 attention = self.softmax(affinity) # 将注意力权重应用到值特征上,得到加权输出 out = torch.bmm(value, attention.transpose(1, 2)) # 调整输出形状以匹配输入张量的维度 out = out.view(batch_size, C, H, W) # 将加权输出与输入张量相加,实现残差连接 return out + x
-
Attention U-Net 模块:在 U-Net 结构中引入注意力模块,适用于图像分割任务,能够自适应地选择重要区域进行特征增强。
class AttentionGate(nn.Module): def __init__(self, in_channels): super(AttentionGate, self).__init__() # 定义门控通道数和中间通道数 gating_channels = in_channels inter_channels = in_channels // 2 # 定义1x1卷积层,用于处理输入特征 self.conv1 = nn.Conv2d(in_channels, inter_channels, kernel_size=1) # 定义1x1卷积层,用于处理门控特征 self.conv2 = nn.Conv2d(gating_channels, inter_channels, kernel_size=1) # 定义1x1卷积层,用于生成注意力权重 self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1) # 定义Sigmoid激活函数,用于生成注意力权重 self.sigmoid = nn.Sigmoid() def forward(self, x): # 门控信号与输入特征相同 gating = x # 对输入特征进行1x1卷积 x1 = self.conv1(x) # 对门控信号进行1x1卷积 x2 = self.conv2(gating) # 将两个卷积结果相加,并通过ReLU激活函数 attention = self.sigmoid(self.psi(F.relu(x1 + x2))) # 将注意力权重应用到输入特征上,进行加权 return x * attention
总结
通过将注意力模块集成到 ResNet 中,我们能够增强模型对重要特征的关注,从而提高性能。你可以根据需要选择不同的注意力机制,并在模型中任意位置插入这些模块。