【Block总结】FreqFusion特征融合模块,适用于分割、检测任务|即插即用
论文信息
标题: Frequency-aware Feature Fusion for Dense Image Prediction
作者: Linwei Chen, Ying Fu, Lin Gu, Chenggang Yan, Tatsuya Harada, Gao Huang
论文链接:https://arxiv.org/pdf/2408.12879
GitHub链接:https://github.com/Linwei-Chen/FreqFusion
创新点
本论文提出了一种新的特征融合方法,称为频率感知特征融合(FreqFusion),旨在解决现有密集图像预测模型中存在的两个主要问题:
- 类别内不一致性:由于高频特征的干扰,导致同一类别内的特征值快速变化。
- 边界模糊:融合特征的边界缺乏准确的高频信息,造成边界位移。
方法
FreqFusion方法集成了三个关键组件:
- 自适应低通滤波器(ALPF)生成器:用于预测空间变化的低通滤波器,旨在减少上采样过程中对象内部的高频成分,从而降低类别内的不一致性。
- 偏移生成器:通过重采样,将不一致的特征替换为更一致的特征,以提高同一类别目标特征的一致性。
- 自适应高通滤波器(AHPF)生成器:用于增强在下采样过程中丢失的高频细节边界信息,从而提高边界的清晰度。
这三个组件协同工作,旨在恢复具有一致类别信息和明确边界的融合特征。
效果
通过综合可视化和定量分析,FreqFusion显著提高了特征的一致性,并锐化了对象的边界。该方法在多个密集预测任务中表现出色,能够有效改善现有模型的性能。
实验结果
在各种密集图像预测任务中,FreqFusion的实验结果如下:
- 语义分割:在轻量化语义分割模型SegNeXt上提升了2.4 mIoU;在强大的Mask2Former上提升了1.4 mIoU。
- 目标检测:在Faster R-CNN上提升了1.9 AP。
- 实例分割:在Mask R-CNN上提升了1.7 box AP和1.3 mask AP。
- 全景分割:在PanopticFCN上提升了2.5 PQ。
这些结果表明,FreqFusion在提高特征质量和模型性能方面具有显著的优势,尤其是在处理复杂的图像密集预测任务时。
代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import warnings
import numpy as np
def xavier_init(module: nn.Module,
                gain: float = 1,
                bias: float = 0,
                distribution: str = 'normal') -> None:
    """Apply Xavier/Glorot init to ``module.weight`` and a constant to ``module.bias``.

    Attributes that are missing or ``None`` are skipped, so parameter-free
    modules and layers built with ``bias=False`` are accepted.

    Args:
        module: Module whose parameters are initialised in place.
        gain: Scaling factor forwarded to the Xavier initialiser.
        bias: Constant written into ``module.bias``.
        distribution: ``'uniform'`` or ``'normal'``; picks the Xavier variant.
    """
    assert distribution in ['uniform', 'normal']
    weight = getattr(module, 'weight', None)
    if weight is not None:
        initialiser = nn.init.xavier_uniform_ if distribution == 'uniform' else nn.init.xavier_normal_
        initialiser(weight, gain=gain)
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
def carafe(x, normed_mask, kernel_size, group=1, up=1):
b, c, h, w = x.shape
_, m_c, m_h, m_w = normed_mask.shape
print('x', x.shape)
print('normed_mask', normed_mask.shape)
# assert m_c == kernel_size ** 2 * up ** 2
assert m_h == up * h
assert m_w == up * w
pad = kernel_size // 2
# print(pad)
pad_x = F.pad(x, pad=[pad] * 4, mode='reflect')
# print(pad_x.shape)
unfold_x = F.unfold(pad_x, kernel_size=(kernel_size, kernel_size), stride=1, padding=0)
# unfold_x = unfold_x.reshape(b, c, 1, kernel_size, kernel_size, h, w).repeat(1, 1, up ** 2, 1, 1, 1, 1)
unfold_x = unfold_x.reshape(b, c * kernel_size * kernel_size, h, w)
unfold_x = F.interpolate(unfold_x, scale_factor=up, mode='nearest')
# normed_mask = normed_mask.reshape(b, 1, up ** 2, kernel_size, kernel_size, h, w)
unfold_x = unfold_x.reshape(b, c, kernel_size * kernel_size, m_h, m_w)
normed_mask = normed_mask.reshape(b, 1, kernel_size * kernel_size, m_h, m_w)
res = unfold_x * normed_mask
# test
# res[:, :, 0] = 1
# res[:, :, 1] = 2
# res[:, :, 2] = 3
# res[:, :, 3] = 4
res = res.sum(dim=2).reshape(b, c, m_h, m_w)
# res = F.pixel_shuffle(res, up)
# print(res.shape)
# print(res)
return res
def normal_init(module, mean=0, std=1, bias=0):
    """Draw ``module.weight`` from N(mean, std**2) and fill ``module.bias`` with ``bias``.

    Missing or ``None`` attributes are skipped.
    """
    weight = getattr(module, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight, mean, std)
    module_bias = getattr(module, 'bias', None)
    if module_bias is not None:
        nn.init.constant_(module_bias, bias)
def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and ``module.bias`` with ``bias``.

    Missing or ``None`` attributes are skipped.
    """
    for attr_name, fill_value in (('weight', val), ('bias', bias)):
        param = getattr(module, attr_name, None)
        if param is not None:
            nn.init.constant_(param, fill_value)
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Call ``F.interpolate``, optionally warning about misaligned upsampling.

    A warning is emitted only when all of the following hold:
    ``warning`` is set, ``size`` is given, ``align_corners`` is truthy, the
    target is larger than the input in at least one dimension, all four
    sizes exceed 1, and neither ``out - 1`` is a multiple of ``in - 1``.
    """
    if warning and size is not None and align_corners:
        in_h, in_w = (int(dim) for dim in input.shape[2:])
        out_h, out_w = (int(dim) for dim in size)
        is_upsampling = out_h > in_h or out_w > in_w
        # the modulo checks below are only safe once every size exceeds 1
        all_above_one = min(out_h, out_w, in_h, in_w) > 1
        if (is_upsampling and all_above_one
                and (out_h - 1) % (in_h - 1)
                and (out_w - 1) % (in_w - 1)):
            warnings.warn(
                f'When align_corners={align_corners}, '
                'the output would more aligned if '
                f'input size {(in_h, in_w)} is `x+1` and '
                f'out size {(out_h, out_w)} is `nx+1`')
    return F.interpolate(input, size=size, scale_factor=scale_factor, mode=mode,
                         align_corners=align_corners)
def hamming2D(M, N):
    """Build a separable 2-D Hamming window.

    Args:
        M: Number of rows.
        N: Number of columns.

    Returns:
        ``numpy.ndarray`` of shape ``(M, N)`` whose entry ``[i, j]`` equals
        ``np.hamming(M)[i] * np.hamming(N)[j]``.
    """
    row_window = np.hamming(M)
    col_window = np.hamming(N)
    # broadcasting a column against a row is exactly the outer product
    return row_window[:, np.newaxis] * col_window[np.newaxis, :]
class FreqFusion(nn.Module):
    """Frequency-aware fusion of a high-resolution and a low-resolution feature map.

    Both inputs are first projected to ``compressed_channels`` with 1x1
    convolutions. From these projections the module builds:

    * low-pass kernels (``content_encoder``, the ALPF generator), used by
      :func:`carafe` to upsample ``lr_feat`` to the size of ``hr_feat``;
    * optionally, high-pass kernels (``content_encoder2``, the AHPF generator).
      ``hr_feat - carafe(hr_feat, kernels)`` is added to ``hr_feat``
      (``hr_residual=True``) or replaces it;
    * optionally, a :class:`LocalSimGuidedSampler` (``feature_resample=True``)
      that resamples the upsampled ``lr_feat``.

    Combining the returned ``hr_feat`` and ``lr_feat`` is left to the caller.

    Args:
        hr_channels: Channels of ``hr_feat``.
        lr_channels: Channels of ``lr_feat``.
        scale_factor: Multiplies the kernel generators' output channels by
            ``scale_factor**2``. NOTE(review): forward never pixel-shuffles the
            generated kernels, so values other than 1 look unsupported.
        lowpass_kernel: Side length of the low-pass kernels.
        highpass_kernel: Side length of the high-pass kernels.
        up_group: Channel groups passed to :func:`carafe`.
        encoder_kernel: Kernel size of the kernel generators. Also used as the
            offset generator's kernel size.
        encoder_dilation: Dilation of the kernel generators.
        compressed_channels: Width of the 1x1 channel compressors.
        align_corners: Forwarded to :func:`resize` when ``semi_conv=False``
            and ``upsample_mode`` is not ``'nearest'``.
        upsample_mode: Interpolation mode used to pre-resize ``lr_feat`` when
            ``semi_conv=False``.
        feature_resample: Enable the offset generator.
        feature_resample_group: Number of groups of the offset generator.
        comp_feat_upsample: With ``semi_conv=True``, upsample the
            low-resolution part of the kernel logits with predicted low-pass
            kernels instead of nearest interpolation. Only implemented together
            with ``use_high_pass=True``; otherwise forward raises
            ``NotImplementedError``.
        use_high_pass: Enable the AHPF branch.
        use_low_pass: Stored only; not read by this class.
        hr_residual: Add the high-pass response to ``hr_feat`` instead of
            replacing ``hr_feat`` with it.
        semi_conv: If True, kernel logits are the sum of logits predicted at
            both resolutions, and ``lr_feat`` is upsampled by exactly 2 inside
            :func:`carafe`. If False, logits come from the summed compressed
            features, and ``lr_feat`` is resized to ``hr_feat``'s size first.
        hamming_window: Multiply every kernel by a 2-D Hamming window before
            renormalising.
        feature_resample_norm: Passed as ``norm`` to the offset generator.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 hr_channels,
                 lr_channels,
                 scale_factor=1,
                 lowpass_kernel=5,
                 highpass_kernel=3,
                 up_group=1,
                 encoder_kernel=3,
                 encoder_dilation=1,
                 compressed_channels=64,
                 align_corners=False,
                 upsample_mode='nearest',
                 feature_resample=False,  # use offset generator or not
                 feature_resample_group=4,
                 comp_feat_upsample=True,  # use ALPF & AHPF for init upsampling
                 use_high_pass=True,
                 use_low_pass=True,
                 hr_residual=True,
                 semi_conv=True,
                 hamming_window=True,  # for regularization, do not matter really
                 feature_resample_norm=True,
                 **kwargs):
        super().__init__()
        self.scale_factor = scale_factor
        self.lowpass_kernel = lowpass_kernel
        self.highpass_kernel = highpass_kernel
        self.up_group = up_group
        self.encoder_kernel = encoder_kernel
        self.encoder_dilation = encoder_dilation
        self.compressed_channels = compressed_channels
        # 1x1 projections of both inputs to a common, cheaper width
        self.hr_channel_compressor = nn.Conv2d(hr_channels, self.compressed_channels, 1)
        self.lr_channel_compressor = nn.Conv2d(lr_channels, self.compressed_channels, 1)
        self.content_encoder = nn.Conv2d(  # ALPF generator
            self.compressed_channels,
            lowpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
            self.encoder_kernel,
            padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
            dilation=self.encoder_dilation,
            groups=1)

        self.align_corners = align_corners
        self.upsample_mode = upsample_mode
        self.hr_residual = hr_residual
        self.use_high_pass = use_high_pass
        self.use_low_pass = use_low_pass
        self.semi_conv = semi_conv
        self.feature_resample = feature_resample
        self.comp_feat_upsample = comp_feat_upsample
        if self.feature_resample:
            self.dysampler = LocalSimGuidedSampler(in_channels=compressed_channels, scale=2, style='lp',
                                                   groups=feature_resample_group, use_direct_scale=True,
                                                   kernel_size=encoder_kernel, norm=feature_resample_norm)
        if self.use_high_pass:
            self.content_encoder2 = nn.Conv2d(  # AHPF generator
                self.compressed_channels,
                highpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
                self.encoder_kernel,
                padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
                dilation=self.encoder_dilation,
                groups=1)

        self.hamming_window = hamming_window
        lowpass_pad = 0
        highpass_pad = 0
        if self.hamming_window:
            # (1, 1, k, k) windows, broadcast onto every kernel in kernel_normalizer
            self.register_buffer('hamming_lowpass', torch.FloatTensor(
                hamming2D(lowpass_kernel + 2 * lowpass_pad, lowpass_kernel + 2 * lowpass_pad))[None, None])
            self.register_buffer('hamming_highpass', torch.FloatTensor(
                hamming2D(highpass_kernel + 2 * highpass_pad, highpass_kernel + 2 * highpass_pad))[None, None])
        else:
            self.register_buffer('hamming_lowpass', torch.FloatTensor([1.0]))
            self.register_buffer('hamming_highpass', torch.FloatTensor([1.0]))
        self.init_weights()

    def init_weights(self):
        """Initialise the convolutions of this module.

        Every ``nn.Conv2d`` owned by FreqFusion gets Xavier-uniform init.
        The two kernel generators are then re-initialised with N(0, 0.001**2).
        Convolutions inside ``self.dysampler`` are skipped: its constructor
        already applies a near-zero offset init and a zero gate init, and
        overwriting those with Xavier would randomise the initial sampling
        offsets.
        """
        skip = set(self.dysampler.modules()) if self.feature_resample else set()
        for m in self.modules():
            if isinstance(m, nn.Conv2d) and m not in skip:
                xavier_init(m, distribution='uniform')
        normal_init(self.content_encoder, std=0.001)
        if self.use_high_pass:
            normal_init(self.content_encoder2, std=0.001)

    def kernel_normalizer(self, mask, kernel, scale_factor=None, hamming=1):
        """Turn raw kernel logits into windowed kernels that sum to one.

        Args:
            mask: Logits of shape ``(N, groups * kernel**2, H, W)``. When
                ``scale_factor`` is given, the logits are first pixel-shuffled
                by that factor.
            kernel: Kernel side length.
            scale_factor: Optional pixel-shuffle factor applied first.
            hamming: Window multiplied onto every ``kernel x kernel`` kernel.
                Either a ``(1, 1, kernel, kernel)`` tensor, a broadcastable
                tensor, or the scalar 1.

        Returns:
            Kernels of shape ``(N, groups * kernel**2, H', W')``, group-major.
            Each group's ``kernel**2`` taps at each pixel sum to 1.
        """
        if scale_factor is not None:
            mask = F.pixel_shuffle(mask, scale_factor)
        n, mask_c, h, w = mask.size()
        groups = mask_c // (kernel ** 2)
        # softmax over the k*k taps of every (group, pixel)
        mask = mask.view(n, groups, -1, h, w)
        mask = F.softmax(mask, dim=2, dtype=mask.dtype)
        # move the taps to the last two dims so the window broadcasts over them
        mask = mask.view(n, groups, kernel, kernel, h, w)
        mask = mask.permute(0, 1, 4, 5, 2, 3).reshape(n, -1, kernel, kernel)
        mask = mask * hamming
        # windowing breaks the sum-to-one property of the softmax; restore it
        mask = mask / mask.sum(dim=(-1, -2), keepdim=True)
        mask = mask.reshape(n, groups, h, w, -1)
        mask = mask.permute(0, 1, 4, 2, 3).reshape(n, -1, h, w).contiguous()
        return mask

    def forward(self, hr_feat, lr_feat, use_checkpoint=False):  # use check_point to save GPU memory
        """Fuse ``hr_feat`` with ``lr_feat``.

        Args:
            hr_feat: High-resolution features, ``(B, hr_channels, H, W)``.
            lr_feat: Low-resolution features, ``(B, lr_channels, h, w)``. With
                ``semi_conv=True``, ``(H, W)`` must equal ``(2h, 2w)``; this is
                enforced by the assertions in :func:`carafe`.
            use_checkpoint: Run the computation under
                ``torch.utils.checkpoint.checkpoint``.

        Returns:
            A tuple ``(mask_lr, hr_feat, lr_feat)``:

            * the normalised low-pass kernels at ``hr_feat`` resolution;
            * the high-pass-processed ``hr_feat`` (unchanged if
              ``use_high_pass=False``);
            * ``lr_feat`` upsampled to ``(H, W)``.
        """
        if use_checkpoint:
            return checkpoint(self._forward, hr_feat, lr_feat)
        return self._forward(hr_feat, lr_feat)

    def _forward(self, hr_feat, lr_feat):
        """Implementation of :meth:`forward` (see there for shapes)."""
        compressed_hr_feat = self.hr_channel_compressor(hr_feat)
        compressed_lr_feat = self.lr_channel_compressor(lr_feat)
        if self.semi_conv:
            if self.comp_feat_upsample:
                if self.use_high_pass:
                    # initial high-pass kernel logits, predicted from hr_feat
                    mask_hr_hr_feat = self.content_encoder2(compressed_hr_feat)
                    # normalise them into the initial high-pass kernels
                    mask_hr_init = self.kernel_normalizer(mask_hr_hr_feat, self.highpass_kernel,
                                                          hamming=self.hamming_highpass)
                    # sharpen the compressed hr_feat: x + (x - lowpass(x)) = x + highpass(x)
                    compressed_hr_feat = compressed_hr_feat + compressed_hr_feat - carafe(
                        compressed_hr_feat, mask_hr_init, self.highpass_kernel, self.up_group, 1)

                    # initial low-pass kernel logits, predicted from the sharpened hr_feat
                    mask_lr_hr_feat = self.content_encoder(compressed_hr_feat)
                    # normalise them into the initial low-pass kernels
                    mask_lr_init = self.kernel_normalizer(mask_lr_hr_feat, self.lowpass_kernel,
                                                          hamming=self.hamming_lowpass)

                    # second part of the low-pass logits, predicted from lr_feat
                    mask_lr_lr_feat_lr = self.content_encoder(compressed_lr_feat)
                    # upsample that part to hr resolution with the initial low-pass kernels
                    mask_lr_lr_feat = F.interpolate(
                        carafe(mask_lr_lr_feat_lr, mask_lr_init, self.lowpass_kernel, self.up_group, 2),
                        size=compressed_hr_feat.shape[-2:], mode='nearest')
                    # combine both parts of the low-pass logits
                    mask_lr = mask_lr_hr_feat + mask_lr_lr_feat

                    # refined low-pass kernels from the combined logits
                    mask_lr_init = self.kernel_normalizer(mask_lr, self.lowpass_kernel,
                                                          hamming=self.hamming_lowpass)
                    # high-pass logits from lr_feat, upsampled with the refined low-pass kernels
                    mask_hr_lr_feat = F.interpolate(
                        carafe(self.content_encoder2(compressed_lr_feat), mask_lr_init, self.lowpass_kernel,
                               self.up_group, 2),
                        size=compressed_hr_feat.shape[-2:], mode='nearest')
                    # final high-pass kernel logits
                    mask_hr = mask_hr_hr_feat + mask_hr_lr_feat
                else:
                    raise NotImplementedError('comp_feat_upsample=True requires use_high_pass=True')
            else:
                mask_lr = self.content_encoder(compressed_hr_feat) + F.interpolate(
                    self.content_encoder(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest')
                if self.use_high_pass:
                    mask_hr = self.content_encoder2(compressed_hr_feat) + F.interpolate(
                        self.content_encoder2(compressed_lr_feat), size=compressed_hr_feat.shape[-2:],
                        mode='nearest')
        else:
            compressed_x = F.interpolate(compressed_lr_feat, size=compressed_hr_feat.shape[-2:],
                                         mode='nearest') + compressed_hr_feat
            mask_lr = self.content_encoder(compressed_x)
            if self.use_high_pass:
                mask_hr = self.content_encoder2(compressed_x)

        mask_lr = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)
        if self.semi_conv:
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 2)
        else:
            lr_feat = resize(
                input=lr_feat,
                size=hr_feat.shape[2:],
                mode=self.upsample_mode,
                align_corners=None if self.upsample_mode == 'nearest' else self.align_corners)
            lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 1)

        if self.use_high_pass:
            mask_hr = self.kernel_normalizer(mask_hr, self.highpass_kernel, hamming=self.hamming_highpass)
            # high-frequency residual: x - lowpass(x)
            hr_feat_hf = hr_feat - carafe(hr_feat, mask_hr, self.highpass_kernel, self.up_group, 1)
            if self.hr_residual:
                hr_feat = hr_feat_hf + hr_feat
            else:
                hr_feat = hr_feat_hf

        if self.feature_resample:
            lr_feat = self.dysampler(hr_x=compressed_hr_feat,
                                     lr_x=compressed_lr_feat, feat2sample=lr_feat)

        return mask_lr, hr_feat, lr_feat
class LocalSimGuidedSampler(nn.Module):
    """
    Offset generator in FreqFusion.

    Predicts sampling offsets from normalised high- and low-resolution
    features and from their local cosine-similarity maps
    (:func:`compute_similarity`). It then resamples ``feat2sample`` with
    ``F.grid_sample`` on a grid ``scale`` times denser than the
    low-resolution input. Only ``scale=2`` and ``style='lp'`` are supported
    (asserted in the constructor).

    Args:
        in_channels: Channels of ``hr_x`` and ``lr_x``. Must be divisible by
            ``groups``. Must be at least 8 when ``norm=True``, because of
            ``GroupNorm(in_channels // 8)``.
        scale: Upsampling factor (must be 2).
        style: Must be ``'lp'``.
        groups: Channel groups; each group gets its own offsets.
        use_direct_scale: If True, offsets are gated by a learned sigmoid.
            If False, they are multiplied by 0.25.
        kernel_size: Kernel size of the offset and gate convolutions.
        local_window: Window side of the similarity maps. Gives
            ``local_window**2 - 1`` similarity channels.
        sim_type: Stored only; forward always uses cosine similarity.
        norm: Apply GroupNorm to both inputs.
        direction_feat: ``'sim'``: the offset convs see only the similarity
            maps and the gate convs see the features. ``'sim_concat'``: both
            see features concatenated with the similarity maps.
    """

    def __init__(self, in_channels, scale=2, style='lp', groups=4, use_direct_scale=True, kernel_size=1, local_window=3,
                 sim_type='cos', norm=True, direction_feat='sim_concat'):
        super().__init__()
        assert scale == 2
        assert style == 'lp'
        self.scale = scale
        self.style = style
        self.groups = groups
        self.local_window = local_window
        self.sim_type = sim_type
        self.direction_feat = direction_feat

        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            # (x, y) offset for every group and every one of the scale x scale sub-positions
            out_channels = 2 * groups * scale ** 2

        sim_channels = local_window ** 2 - 1  # every window tap except the centre
        if self.direction_feat == 'sim':
            offset_in, gate_in = sim_channels, in_channels
        elif self.direction_feat == 'sim_concat':
            offset_in = gate_in = in_channels + sim_channels
        else:
            raise NotImplementedError

        # low-resolution branch; near-zero / zero init keeps the initial offsets at init_pos
        self.offset = nn.Conv2d(offset_in, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        normal_init(self.offset, std=0.001)
        if use_direct_scale:
            self.direct_scale = nn.Conv2d(gate_in, out_channels, kernel_size=kernel_size,
                                          padding=kernel_size // 2)
            constant_init(self.direct_scale, val=0.)

        # high-resolution branch; pixel_unshuffle later maps 2*groups channels to out_channels
        hr_out_channels = 2 * groups
        self.hr_offset = nn.Conv2d(offset_in, hr_out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        normal_init(self.hr_offset, std=0.001)
        if use_direct_scale:
            self.hr_direct_scale = nn.Conv2d(gate_in, hr_out_channels, kernel_size=kernel_size,
                                             padding=kernel_size // 2)
            constant_init(self.hr_direct_scale, val=0.)

        self.norm = norm
        if self.norm:
            self.norm_hr = nn.GroupNorm(in_channels // 8, in_channels)
            self.norm_lr = nn.GroupNorm(in_channels // 8, in_channels)
        else:
            self.norm_hr = nn.Identity()
            self.norm_lr = nn.Identity()
        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        """Sub-pixel positions of the ``scale x scale`` outputs around each low-res pixel centre.

        The values are in low-resolution pixel units; for ``scale=2`` they are
        ±0.25. The returned shape is ``(1, 2 * groups * scale**2, 1, 1)``.
        """
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset, scale=None):
        """Bilinearly resample ``x`` at the positions described by ``offset``.

        Args:
            x: Features to sample, ``(B, C, H_x, W_x)``, with ``C`` divisible
                by ``groups``. Coordinates are normalised, so ``H_x, W_x``
                need not match ``offset``.
            offset: ``(B, 2 * groups * scale**2, H, W)``. These are offsets in
                pixel units of the ``H x W`` grid, added to the pixel centres.
            scale: Output-grid density; defaults to ``self.scale``.

        Returns:
            Tensor of shape ``(B, C, scale * H, scale * W)``.
        """
        if scale is None:
            scale = self.scale
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H, device=x.device) + 0.5
        coords_w = torch.arange(W, device=x.device) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).to(dtype=x.dtype)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        # map pixel coordinates to grid_sample's [-1, 1] range (align_corners=False convention)
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), scale).view(
            B, 2, -1, scale * H, scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, x.size(-2), x.size(-1)), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, scale * H, scale * W)

    def forward(self, hr_x, lr_x, feat2sample):
        """Predict offsets from ``hr_x`` / ``lr_x`` and resample ``feat2sample``.

        ``hr_x`` must be ``scale`` times the spatial size of ``lr_x``, because
        the hr branch is ``pixel_unshuffle``-d onto the lr grid.
        """
        hr_x = self.norm_hr(hr_x)
        lr_x = self.norm_lr(lr_x)

        if self.direction_feat == 'sim':
            hr_sim = compute_similarity(hr_x, self.local_window, dilation=2, sim='cos')
            lr_sim = compute_similarity(lr_x, self.local_window, dilation=2, sim='cos')
        elif self.direction_feat == 'sim_concat':
            hr_sim = torch.cat([hr_x, compute_similarity(hr_x, self.local_window, dilation=2, sim='cos')], dim=1)
            lr_sim = torch.cat([lr_x, compute_similarity(lr_x, self.local_window, dilation=2, sim='cos')], dim=1)
            # in 'sim_concat' mode the gate convolutions also take the concatenated features
            hr_x, lr_x = hr_sim, lr_sim
        offset = self.get_offset_lp(hr_x, lr_x, hr_sim, lr_sim)
        return self.sample(feat2sample, offset)

    def get_offset_lp(self, hr_x, lr_x, hr_sim, lr_sim):
        """Predict offsets of shape ``(B, 2 * groups * scale**2, h, w)`` on the low-res grid.

        ``hr_sim`` / ``lr_sim`` feed the ``offset`` / ``hr_offset`` convolutions.
        ``hr_x`` / ``lr_x`` feed the ``direct_scale`` gates.
        """
        offset = self.offset(lr_sim) + F.pixel_unshuffle(self.hr_offset(hr_sim), self.scale)
        if hasattr(self, 'direct_scale'):
            gate = self.direct_scale(lr_x) + F.pixel_unshuffle(self.hr_direct_scale(hr_x), self.scale)
            offset = offset * gate.sigmoid()
        else:
            offset = offset * 0.25
        return offset + self.init_pos

    def get_offset(self, hr_x, lr_x):
        """Offsets when one tensor per resolution feeds both offset and gate convs.

        This matches the ``direction_feat='sim_concat'`` path of :meth:`forward`.
        """
        if self.style == 'pl':
            raise NotImplementedError
        return self.get_offset_lp(hr_x, lr_x, hr_x, lr_x)
def compute_similarity(input_tensor, k=3, dilation=1, sim='cos'):
    """Similarity between every pixel and its ``k x k`` (dilated) neighbours.

    Args:
        input_tensor: Tensor of shape ``[B, C, H, W]``.
        k: Neighbourhood side length. Odd values keep the unfolded map at ``H x W``.
        dilation: Spacing between neighbourhood taps.
        sim: ``'cos'`` for cosine similarity over channels, ``'dot'`` for the
            raw dot product. Anything else raises ``NotImplementedError``.

    Returns:
        Tensor of shape ``[B, k*k - 1, H, W]``. The centre tap (a pixel
        compared with itself) is removed. Neighbours outside the image come
        from unfold's zero padding and therefore score 0.
    """
    B, C, H, W = input_tensor.shape
    n_taps = k * k
    centre = n_taps // 2
    # [B, C*k*k, H*W] -> [B, C, k*k, H, W]
    neighbourhood = F.unfold(input_tensor, k, padding=(k // 2) * dilation, dilation=dilation)
    neighbourhood = neighbourhood.reshape(B, C, n_taps, H, W)
    anchor = neighbourhood[:, :, centre:centre + 1]

    if sim == 'cos':
        scores = F.cosine_similarity(anchor, neighbourhood, dim=1)
    elif sim == 'dot':
        scores = (anchor * neighbourhood).sum(dim=1)
    else:
        raise NotImplementedError

    # drop the self-similarity of the centre tap
    return torch.cat([scores[:, :centre], scores[:, centre + 1:]], dim=1)
if __name__ == '__main__':
    # Smoke test. With the default semi_conv=True, hr_feat must be exactly
    # twice the spatial size of lr_feat. Sizes are kept small because carafe
    # unfolds C * k*k channels per pixel: at 512x512 with 128 channels this
    # needs several GB. no_grad avoids keeping those tensors for autograd.
    hr_feat = torch.rand(1, 128, 64, 64)
    lr_feat = torch.rand(1, 128, 32, 32)
    model = FreqFusion(hr_channels=128, lr_channels=128)
    with torch.no_grad():
        mask_lr, hr_feat, lr_feat = model(hr_feat=hr_feat, lr_feat=lr_feat)
    print(mask_lr.shape)
    print(hr_feat.shape)
    print(lr_feat.shape)