【Block总结】TFF和SFF模块,时间和尺度的特征融合|即插即用
论文信息
标题: STNet: Spatial and Temporal feature fusion network for change detection in remote sensing images
作者: Xiaowen Ma, Jiawei Yang, Tingfeng Hong, Mengting Ma, Ziyan Zhao, Tian Feng, Wei Zhang
发表日期: 2023年4月22日
链接:https://arxiv.org/pdf/2304.11422
摘要: 本文提出了一种新的遥感变化检测(RSCD)网络STNet,旨在通过空间和时间特征融合来识别遥感图像中的变化。该方法通过设计时态特征融合模块和空间特征融合模块,强调感兴趣的变化并恢复变化表征的空间细节。
创新点
-
跨时间特征融合模块(TFF): 提出了一种基于跨时间门控机制的特征融合模块,用于双时相特征的融合。该模块通过选择性地增强目标变化信息并抑制非目标变化,提升了变化检测的准确性。
-
跨尺度特征融合模块(SFF): 首次采用跨尺度注意力机制,利用高层次特征引导低层次特征的建模,能够捕捉变化目标的细粒度空间信息,恢复变化表示的空间细节。
-
高效的多尺度特征交互设计: 设计了一个轻量化的深度神经网络框架,使用ResNet-18作为特征提取的骨干网络,结合TFF和SFF模块,有效减少了参数量和计算成本。
-
创新的损失函数: 采用结合Focal Loss和Dice Loss的混合损失函数,解决变化检测中正负样本不平衡的问题。
方法
-
整体架构: STNet的结构包括输入双时相遥感影像,使用共享权重的ResNet-18提取多尺度特征,随后通过TFF和SFF模块进行特征融合,最后通过轻量化解码器生成高精度变化检测图。
-
时间特征融合模块(TFF):
- 目标是通过跨时间门控机制融合双时相特征,强调目标变化并抑制非目标变化。
- 工作流程包括对双时相特征进行逐元素相减,生成粗粒度变化表示,并通过深度可分离卷积提取特征。
-
空间特征融合模块(SFF):
- 目标是通过跨尺度注意力机制融合多尺度特征,恢复变化表示的空间细节。
- 工作流程涉及高层次特征与低层次特征的交互,使用注意力机制计算像素间的关系。
-
解码器与变化检测图生成: 轻量化解码器将各尺度的变化表示上采样到统一尺寸,并结合通道注意力模块生成最终的变化检测图。
跨时间特征融合模块(TFF)
跨时间特征融合模块(Temporal Feature Fusion Module, TFF)是STNet网络中的一个关键组成部分,旨在有效融合双时相特征,以提高遥感图像变化检测的准确性。该模块通过跨时间门控机制,选择性地增强目标变化信息,同时抑制非目标变化,从而提升变化检测的性能。
主要功能
-
特征融合: TFF模块通过对双时相特征进行逐元素相减,生成初步的粗粒度变化表示。这一过程帮助识别出在两个时间点之间发生的变化。
-
门控机制: 该模块使用门控机制对特征进行加权融合。具体而言,粗粒度变化表示与原始双时相特征进行拼接,并通过深度可分离卷积提取特征,生成权重。这些权重用于调整融合过程,确保重要的变化信息得到强调,而非目标变化则被抑制。
-
增强变化信息: 通过选择性地增强目标变化信息,TFF模块能够有效提高变化检测的准确性,尤其是在复杂场景中。
原理
-
输入特征: TFF模块接收来自两个时间点的特征表示(例如, R 1 R_1 R1和 R 2 R_2 R2)。
-
粗粒度变化表示: 通过逐元素相减,计算出粗粒度变化表示 R c = R 1 − R 2 R_c = R_1 - R_2 Rc=R1−R2。
-
特征拼接与卷积: 将 R c R_c Rc与 R 1 R_1 R1和 R 2 R_2 R2分别进行拼接,并通过深度可分离卷积提取特征,生成权重 W 1 W_1 W1和 W 2 W_2 W2。
-
融合输出: 使用门控机制,根据生成的权重调整 R 1 R_1 R1和 R 2 R_2 R2的融合,最终输出融合后的时间特征 R t R_t Rt。
提升效果
通过引入TFF模块,STNet在多个遥感变化检测基准数据集(如WHU、LEVIR-CD和CLCD)上表现出显著的性能提升。具体而言,TFF模块的有效性在于:
-
提高准确性: 通过增强目标变化信息,TFF模块显著提高了变化检测的准确性和召回率。
-
减少计算成本: 由于采用了深度可分离卷积,TFF模块在保持性能的同时有效减少了计算量和参数量,使得STNet在实际应用中更加高效。
-
性能提升: STNet在三个遥感变化检测的基准数据集(WHU、LEVIR-CD和CLCD)上,在F1分数、IoU、整体准确率等指标上取得了领先的表现。
-
参数与计算量: STNet的参数量为14.6M,计算量为9.61G FLOPs,显著低于许多现有方法,证明了其轻量化设计的有效性。
-
消融实验: 通过对比基础模型与添加TFF和SFF模块的性能差异,验证了这两个模块的有效性,表明它们的联合使用具有协同增益。
空间特征融合模块(SFF)
空间特征融合模块(Spatial-aware Feature Fusion Module, SFF)是STNet网络中的一个重要组成部分,旨在通过集成多尺度特征来提高遥感图像变化检测的精度。SFF模块的设计使得模型能够同时处理小范围的纹理异常和大范围的结构缺陷,确保在检测过程中保留正常样本的细节信息并精确重建异常区域。
主要功能
-
多尺度特征集成: SFF模块能够有效整合来自不同尺度的特征,增强模型对各种尺度变化的适应能力。这一特性使得模型在处理复杂场景时,能够更好地捕捉到细微的变化。
-
细节保留: 通过融合高层次和低层次特征,SFF模块确保了在重建异常区域时,正常样本的细节信息不会丢失,从而提高了变化检测的准确性。
-
适应性强: SFF模块的设计使其能够适应不同类型的变化,无论是小范围的纹理变化还是大范围的结构变化,都能有效处理。
原理
-
特征提取: SFF模块首先从高层次特征 H j H_j Hj中提取重要信息。这些高层次特征通常包含丰富的语义信息,但在空间定位上相对粗糙。
-
特征映射: 通过卷积操作 F ( H j ) F(H_j) F(Hj)对高层次特征进行映射,以提取更具代表性的特征信息。
-
特征融合: 将映射后的高层次特征与低层次输出 P i P_i Pi进行融合,生成最终输出 Q i Q_i Qi。这一过程确保了不同尺度特征的有效结合,从而增强了模型的特征表达能力。
公式表示
SFF模块的输出可以通过以下公式表示:
Q i = P i + ∑ j = 1 J F ( H j ) Q_i = P_i + \sum_{j=1}^{J} F(H_j) Qi=Pi+j=1∑JF(Hj)
其中:
- Q i Q_i Qi:第 i i i层的最终低尺度输出。
- P i P_i Pi:第 i i i层的低尺度输出特征。
- H j H_j Hj:第 j j j层的高尺度输出特征。
- F ( ⋅ ) F(\cdot) F(⋅):表示卷积操作,通常包括3x3卷积、归一化和激活函数。
提升效果
SFF模块的引入显著提升了STNet在遥感图像变化检测中的性能,具体体现在以下几个方面:
-
提高检测精度: 通过有效融合多尺度特征,SFF模块增强了模型对变化的敏感性,提升了检测的准确性和召回率。
-
减少计算复杂度: SFF模块的设计考虑到了计算效率,使得在保持高性能的同时,模型的计算量和参数量得以控制。
-
增强模型鲁棒性: 通过多尺度特征的融合,SFF模块提高了模型在复杂场景下的鲁棒性,使其能够更好地应对各种变化情况。
空间特征融合模块(SFF)是STNet中不可或缺的部分,通过创新的特征融合策略,为遥感图像变化检测提供了强大的支持。其在多尺度特征的有效集成和细节信息的保留方面表现出色,显著提升了模型的整体性能和实用性。
总结
STNet通过创新的跨时间和跨尺度特征融合方法,为遥感图像变化检测提供了一种高效的解决方案。其在多个基准数据集上的优异表现证明了该方法的有效性和实用性,标志着遥感变化检测领域的一次重要进展。未来的研究可以进一步探索该网络在其他应用场景中的潜力。代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
def conv_3x3(in_channel, out_channel):
return nn.Sequential(
nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True)
)
def dsconv_3x3(in_channel, out_channel):
return nn.Sequential(
nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, groups=in_channel),
nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, groups=1),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True)
)
def conv_1x1(in_channel, out_channel):
return nn.Sequential(
nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True)
)
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
nn.ReLU(),
nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
out = avg_out + max_out
return self.sigmoid(out)
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)
class SelfAttentionBlock(nn.Module):
"""
query_feats: (B, C, h, w)
key_feats: (B, C, h, w)
value_feats: (B, C, h, w)
output: (B, C, h, w)
"""
def __init__(self, key_in_channels, query_in_channels, transform_channels, out_channels,
key_query_num_convs, value_out_num_convs):
super(SelfAttentionBlock, self).__init__()
self.key_project = self.buildproject(
in_channels=key_in_channels,
out_channels=transform_channels,
num_convs=key_query_num_convs,
)
self.query_project = self.buildproject(
in_channels=query_in_channels,
out_channels=transform_channels,
num_convs=key_query_num_convs
)
self.value_project = self.buildproject(
in_channels=key_in_channels,
out_channels=transform_channels,
num_convs=value_out_num_convs
)
self.out_project = self.buildproject(
in_channels=transform_channels,
out_channels=out_channels,
num_convs=value_out_num_convs
)
self.transform_channels = transform_channels
def forward(self, query_feats, key_feats, value_feats):
batch_size = query_feats.size(0)
query = self.query_project(query_feats)
query = query.reshape(*query.shape[:2], -1)
query = query.permute(0, 2, 1).contiguous() # (B, h*w, C)
key = self.key_project(key_feats)
key = key.reshape(*key.shape[:2], -1) # (B, C, h*w)
value = self.value_project(value_feats)
value = value.reshape(*value.shape[:2], -1)
value = value.permute(0, 2, 1).contiguous() # (B, h*w, C)
sim_map = torch.matmul(query, key)
sim_map = (self.transform_channels ** -0.5) * sim_map
sim_map = F.softmax(sim_map, dim=-1) # (B, h*w, K)
context = torch.matmul(sim_map, value) # (B, h*w, C)
context = context.permute(0, 2, 1).contiguous()
context = context.reshape(batch_size, -1, *query_feats.shape[2:]) # (B, C, h, w)
context = self.out_project(context) # (B, C, h, w)
return context
def buildproject(self, in_channels, out_channels, num_convs):
convs = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
for _ in range(num_convs - 1):
convs.append(
nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
)
if len(convs) > 1:
return nn.Sequential(*convs)
return convs[0]
class TFF(nn.Module):
def __init__(self, in_channel, out_channel):
super(TFF, self).__init__()
self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
self.catconv = dsconv_3x3(in_channel * 2, out_channel)
self.convA = nn.Conv2d(in_channel, 1, 1)
self.convB = nn.Conv2d(in_channel, 1, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, xA, xB):
x_diff = xA - xB
x_diffA = self.catconvA(torch.cat([x_diff, xA], dim=1))
x_diffB = self.catconvB(torch.cat([x_diff, xB], dim=1))
A_weight = self.sigmoid(self.convA(x_diffA))
B_weight = self.sigmoid(self.convB(x_diffB))
xA = A_weight * xA
xB = B_weight * xB
x = self.catconv(torch.cat([xA, xB], dim=1))
return x
class SFF(nn.Module):
def __init__(self, in_channel):
super(SFF, self).__init__()
self.conv_small = conv_1x1(in_channel, in_channel)
self.conv_big = conv_1x1(in_channel, in_channel)
self.catconv = conv_3x3(in_channel * 2, in_channel)
self.attention = SelfAttentionBlock(
key_in_channels=in_channel,
query_in_channels=in_channel,
transform_channels=in_channel // 2,
out_channels=in_channel,
key_query_num_convs=2,
value_out_num_convs=1
)
def forward(self, x_small, x_big):
img_size = x_big.size(2), x_big.size(3)
x_small = F.interpolate(x_small, img_size, mode="bilinear", align_corners=False)
x = self.conv_small(x_small) + self.conv_big(x_big)
new_x = self.attention(x, x, x_big)
out = self.catconv(torch.cat([new_x, x_big], dim=1))
return out
if __name__ == '__main__':
block = SFF(64)
x_small = torch.rand(1, 64, 20, 20)
x_big = torch.rand(1, 64, 40, 40)
output = block(x_small, x_big)
# 打印输入和输出的形状
print(f"Input Small: {x_small.shape}")
print(f"Input Big: {x_big.shape}")
print(f"Output: {output.shape}")
block = TFF(64,32)
x_A = torch.rand(1, 64, 40, 40)
x_B = torch.rand(1, 64, 40, 40)
output = block(x_A, x_B)
# 打印输入和输出的形状
print(f"Input A: {x_A.shape}")
print(f"Input B: {x_B.shape}")
print(f"Output: {output.shape}")
输出结果: