【Block总结】ELGCA模块,池化-转置(PT)注意力和深度卷积有效聚合局部和全局上下文信息
ELGCA结构
论文题目:ELGC-Net: Efficient Local-Global Context Aggregation for Remote Sensing Change Detection
论文链接:https://arxiv.org/pdf/2403.17909
官方github:https://github.com/techmn/elgcnet
高效局部-全局上下文聚合器(ELGCA)是ELGC-Net框架中的核心模块,旨在有效捕获遥感图像中的局部和全局上下文信息。其结构主要包括以下几个方面:
-
Siamese编码器:ELGC-Net采用Siamese架构,通过两个相同的编码器分别处理输入的时间序列图像。这种设计使得模型能够有效地提取和比较两个图像之间的特征。
-
池化-转置(PT)注意力机制:ELGCA引入了一种新颖的PT注意力机制,通过池化操作来增强全局上下文的捕获,同时利用转置卷积来恢复特征图的空间分辨率。这种机制能够在保持计算效率的同时,捕获更丰富的上下文信息。
-
深度卷积:通过深度卷积,ELGCA能够在不同的空间尺度上提取特征,从而增强模型对细微变化的敏感性。这种方法有效地减少了模型的参数数量,同时提高了特征提取的效率。
-
融合模块:在特征提取后,ELGCA通过融合模块将来自不同图像的特征进行整合,生成最终的变化检测输出。这一过程确保了模型能够综合考虑局部和全局信息,从而提高检测精度。
优点
ELGCA的设计带来了多个显著的优点:
-
提高检测精度:通过有效聚合局部和全局上下文信息,ELGCA显著提高了遥感变化检测的准确性。实验结果表明,ELGC-Net在多个挑战性数据集上均超越了现有的最先进方法。
-
减少计算复杂性:ELGCA通过引入高效的注意力机制和深度卷积,减少了模型的计算复杂性和参数数量,使得ELGC-Net在资源受限的环境中也能高效运行。
-
适应性强:ELGCA能够自适应地处理不同尺度和不同类型的遥感图像,增强了模型在多种应用场景下的适用性,包括城市监测、环境变化分析等。
-
轻量化设计:ELGC-Net还提供了一个轻量级变体(ELGC-Net-LW),在保持性能的同时,进一步减少了计算资源的需求,适合在边缘设备上进行实时变化检测。
综上所述,高效局部-全局上下文聚合器(ELGCA)通过其创新的设计和实现,不仅提升了遥感变化检测的性能,还为实际应用提供了更高的灵活性和效率。
代码结构
import torch
import torch.nn as nn
class ELGCA(nn.Module):
"""
Efficient local global context aggregation module
dim: number of channels of input
heads: number of heads utilized in computing attention
"""
def __init__(self, dim, heads=4):
super().__init__()
self.heads = heads
self.dwconv = nn.Conv2d(dim // 2, dim // 2, 3, padding=1, groups=dim // 2)
self.qkvl = nn.Conv2d(dim // 2, (dim // 4) * self.heads, 1, padding=0)
self.pool_q = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
self.pool_k = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.act = nn.GELU()
def forward(self, x):
B, C, H, W = x.shape
x1, x2 = torch.split(x, [C // 2, C // 2], dim=1)
# apply depth-wise convolution on half channels
x1 = self.act(self.dwconv(x1))
# linear projection of other half before computing attention
x2 = self.act(self.qkvl(x2))
x2 = x2.reshape(B, self.heads, C // 4, H, W)
q = torch.sum(x2[:, :-3, :, :, :], dim=1)
k = x2[:, -3, :, :, :]
q = self.pool_q(q)
k = self.pool_k(k)
v = x2[:, -2, :, :, :].flatten(2)
lfeat = x2[:, -1, :, :, :]
qk = torch.matmul(q.flatten(2), k.flatten(2).transpose(1, 2))
qk = torch.softmax(qk, dim=1).transpose(1, 2)
x2 = torch.matmul(qk, v).reshape(B, C // 4, H, W)
x = torch.cat([x1, lfeat, x2], dim=1)
return x
if __name__ == '__main__':
# 创建一个随机输入张量,形状为 (batch_size,height×width,channels)
input = torch.rand(1, 64,40, 40)
# 实例化ELGCA模块
block = ELGCA(64,4)
# 前向传播
output = block(input)
# 打印输入和输出的形状
print(input.size())
print(output.size())
输出结果
torch.Size([1, 64, 40, 40])
torch.Size([1, 64, 40, 40])
主要组件
-
初始化方法 (
__init__
):dim
: 输入通道数。heads
: 注意力头的数量,默认为4。dwconv
: 深度卷积层,用于对输入的一半通道进行卷积操作,增强局部特征提取。qkvl
: 线性投影层,将另一半通道的特征映射到多个注意力头。pool_q
和pool_k
: 池化层,用于对查询(Q)和键(K)进行下采样,分别使用平均池化和最大池化。act
: 激活函数,使用GELU(高斯误差线性单元)。
-
前向传播方法 (
forward
):- 输入
x
的形状为(B, C, H, W)
,其中B
是批量大小,C
是通道数,H
和W
是特征图的高度和宽度。 - 使用
torch.split
将输入x
分为两部分x1
和x2
,每部分的通道数为C // 2
。 - 对
x1
应用深度卷积,提取局部特征。 - 对
x2
进行线性投影,准备计算注意力。 - 将
x2
重塑为(B, heads, C // 4, H, W)
,以便于后续的注意力计算。 - 计算查询(Q)、键(K)和值(V):
q
是x2
的前heads-3
个通道的和。k
是x2
的倒数第三个通道。v
是x2
的倒数第二个通道,展平为二维。lfeat
是x2
的最后一个通道,保留用于后续拼接。
- 对
q
和k
进行池化,减少计算量。 - 计算注意力权重
qk
,并通过softmax归一化。 - 使用注意力权重对值(V)进行加权求和,得到新的特征图
x2
。 - 最后,将
x1
、lfeat
和x2
在通道维度上拼接,形成最终输出。
- 输入
总结
ELGCA模块通过深度卷积和注意力机制的结合,有效地聚合了局部和全局上下文信息。这种设计不仅提高了特征提取的效率,还降低了计算复杂性,使得ELGC-Net在遥感变化检测任务中表现出色。