计算机视觉之 SE 注意力模块
计算机视觉之 SE 注意力模块
一、简介
SEBlock
是一个自定义的神经网络模块,主要用于实现 Squeeze-and-Excitation(SE)注意力机制。SE 注意力机制通过全局平均池化和全连接层来重新校准通道的权重,从而增强模型的表达能力。
原论文:《Squeeze-and-Excitation Networks》
二、语法和参数
语法
class SEBlock(nn.Module):
def __init__(self, in_channels, reduction=16):
...
def forward(self, x):
...
参数
in_channels
:输入特征的通道数。reduction
:通道缩减比例,默认为 16。
三、实例
3.1 初始化和前向传播
- 代码
import torch
import torch.nn as nn
class SEBlock(nn.Module):
def __init__(self, in_channels, reduction=16):
super(SEBlock, self).__init__()
reduced_channels = max(in_channels // reduction, 1)
self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, reduced_channels, bias=False),
nn.ReLU(inplace=True),
nn.Linear(reduced_channels, in_channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
batch_size, channels, _, _ = x.size()
# Squeeze
y = self.global_avg_pool(x).view(batch_size, channels)
# Excitation
y = self.fc(y).view(batch_size, channels, 1, 1)
# Scale
return x * y.expand_as(x)
- 输出
加权图像输出
3.2 应用在示例数据上
- 代码
import torch
# 创建示例输入数据
input_tensor = torch.randn(1, 64, 32, 32) # (batch_size, in_channels, height, width)
# 初始化 SEBlock 模块
se_block = SEBlock(in_channels=64, reduction=16)
# 前向传播
output_tensor = se_block(input_tensor)
print(output_tensor.shape)
- 输出
torch.Size([1, 64, 32, 32])
四、注意事项
SEBlock
模块通过全局平均池化和全连接层来重新校准通道的权重,从而增强模型的表达能力。- 在使用
SEBlock
时,确保输入特征的通道数和缩减比例设置合理,以避免计算开销过大。 - 该模块主要用于图像数据处理,适用于各种计算机视觉任务,如图像分类、目标检测等。