计算机视觉之 GSoP 注意力模块
计算机视觉之 GSoP 注意力模块
一、简介
GSopBlock
是一个自定义的神经网络模块,主要用于实现 GSoP(Global Second-order Pooling)注意力机制。GSoP 注意力机制通过计算输入特征的协方差矩阵,捕捉全局二阶统计信息,从而增强模型的表达能力。
原论文:《Global Second-order Pooling Convolutional Networks (arxiv.org)》
二、语法和参数
语法
class GSopBlock(nn.Module):
def __init__(self, in_channels, mid_channels):
...
def forward(self, x):
...
参数
in_channels
:输入特征的通道数。mid_channels
:中间层的通道数,用于调整特征维度。
三、实例
3.1 初始化和前向传播
- 代码
import torch
import torch.nn as nn
class GSopBlock(nn.Module):
def __init__(self, in_channels, mid_channels):
super(GSopBlock, self).__init__()
self.conv2d1 = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True)
)
self.row_wise_conv = nn.Sequential(
nn.Conv2d(
mid_channels, 4*mid_channels,
kernel_size=(mid_channels, 1),
groups=mid_channels, bias=False),
nn.BatchNorm2d(4*mid_channels),
)
self.conv2d2 = nn.Sequential(
nn.Conv2d(4*mid_channels, in_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
# Step 1: 调整通道数
x = self.conv2d1(x)
batch_size, channels, height, width = x.size()
# Step 2: 展平输入
x_flat = x.view(batch_size, channels, -1)
# Step 3: 计算协方差矩阵
x_mean = x_flat.mean(dim=-1, keepdim=True)
x_centered = x_flat - x_mean
cov_matrix = torch.bmm(x_centered, x_centered.transpose(1, 2)) / (height * width)
cov_matrix = cov_matrix.unsqueeze(-1)
# Step 4: 行方向卷积
cov_features = self.row_wise_conv(cov_matrix)
# Step 5: 生成权重向量
weight_vector = self.conv2d2(cov_features)
# Step 6: 计算最终输出
x_out = x * weight_vector
return x_out
- 输出
经过加权后的图像
3.2 应用在示例数据上
- 代码
import torch
# 创建示例输入数据
input_tensor = torch.randn(1, 64, 32, 32) # (batch_size, in_channels, height, width)
# 初始化 GSopBlock 模块
gsop_block = GSopBlock(in_channels=64, mid_channels=16)
# 前向传播
output_tensor = gsop_block(input_tensor)
print(output_tensor.shape)
- 输出
torch.Size([1, 64, 32, 32])
四、注意事项
GSopBlock
模块适用于捕捉输入特征之间的全局二阶统计信息,增强模型的表达能力。- 在使用
GSopBlock
时,确保输入特征的通道数和中间层的通道数设置合理,以避免计算开销过大。 - 该模块主要用于图像数据处理,适用于各种计算机视觉任务,如图像分类、目标检测等。