Daily Attention Learning 26: Dynamic Weighted Feature Fusion
Module source
[ACM MM 23] [link] [code] Efficient Parallel Multi-Scale Detail and Semantic Encoding Network for Lightweight Semantic Segmentation
Module name
Dynamic Weighted Feature Fusion (DWFF)
Module function
Dual-level (shallow/deep) feature fusion
Module structure
Module idea
We propose the DWFF strategy, which selectively focuses on the most informative parts of the feature maps in order to combine shallow and deep features effectively and improve segmentation accuracy. DWFF can weight shallow features more heavily in regions with fine-grained detail and deep features more heavily in regions rich in semantic information, yielding a better feature combination and more accurate segmentation.
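To make the weighting concrete, here is a toy sketch (invented values, not from the paper): a softmax over the branch axis produces per-channel weights that sum to 1, so each output channel is a convex mix of the shallow and deep responses.

import torch

# Toy illustration (invented values, not from the paper): for two branches and
# three channels, softmax over the branch axis yields per-channel weights that
# sum to 1 in every column.
logits = torch.tensor([[2.0, 0.2, -1.0],    # branch 0: shallow feature logits
                       [-0.5, 1.5, 1.0]])   # branch 1: deep feature logits
weights = torch.softmax(logits, dim=0)      # shape (2, 3); each column sums to 1
shallow = torch.tensor([10.0, 10.0, 10.0])  # dummy per-channel responses
deep = torch.tensor([2.0, 2.0, 2.0])
fused = weights[0] * shallow + weights[1] * deep
print(weights.sum(dim=0))  # tensor([1., 1., 1.])
print(fused)               # each channel leans toward whichever branch won the softmax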
Module code
import torch
import torch.nn as nn
import torch.nn.functional as F


class DWFF(nn.Module):
    def __init__(self,
                 in_channels: int,
                 height: int = 2,
                 reduction: int = 8,
                 bias: bool = False) -> None:
        super(DWFF, self).__init__()
        self.height = height  # number of feature branches to fuse
        d = max(int(in_channels / reduction), 4)  # bottleneck width of the channel descriptor

        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling -> (B, C, 1, 1)
        # Squeeze: compress the pooled descriptor into a compact d-dim vector
        self.conv_du = nn.Sequential(
            nn.Conv2d(in_channels, d, 1, padding=0, bias=bias),
            nn.BatchNorm2d(d),
            nn.LeakyReLU(0.2)
        )
        # Excite: one 1x1 conv per branch maps the descriptor back to C logits
        self.fcs = nn.ModuleList([])
        for _ in range(self.height):
            self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias))
        # Softmax over the branch dimension -> per-channel weights summing to 1
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inp_feats):
        # inp_feats: list of `height` tensors, each of shape (B, C, H, W)
        batch_size = inp_feats[0].shape[0]
        n_feats = inp_feats[0].shape[1]

        # Stack the branches: (B, height*C, H, W) -> (B, height, C, H, W)
        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats,
                                   inp_feats.shape[2], inp_feats.shape[3])

        # Fuse branches by summation, then build a global channel descriptor
        feats_U = torch.sum(inp_feats, dim=1)
        feats_S = self.avg_pool(feats_U)
        feats_Z = self.conv_du(feats_S)

        # One attention vector per branch, normalized across branches
        attention_vectors = [fc(feats_Z) for fc in self.fcs]
        attention_vectors = torch.cat(attention_vectors, dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
        attention_vectors = self.softmax(attention_vectors)

        # Weighted sum of the branches: a dynamic, per-channel convex combination
        feats_V = torch.sum(inp_feats * attention_vectors, dim=1)
        return feats_V


if __name__ == '__main__':
    dwff = DWFF(in_channels=64)
    x1 = torch.randn([2, 64, 16, 16])  # e.g. shallow feature
    x2 = torch.randn([2, 64, 16, 16])  # e.g. deep feature
    out = dwff([x1, x2])
    print(out.shape)  # torch.Size([2, 64, 16, 16])
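In an encoder-decoder network, the shallow and deep features usually differ in spatial size, so the deep map must be resampled before fusion. Below is a hypothetical usage sketch (the bilinear upsampling step and the shapes are assumptions, not from the paper); both inputs are assumed to already share the channel count, otherwise a 1x1 conv would be needed to align channels first.

import torch
import torch.nn.functional as F

# Hypothetical usage (shapes and upsampling are assumptions, not from the
# paper): fuse a high-resolution shallow feature with an upsampled
# low-resolution deep feature using the DWFF module defined above.
shallow = torch.randn(2, 64, 32, 32)  # fine-grained detail
deep = torch.randn(2, 64, 16, 16)     # semantic context
deep_up = F.interpolate(deep, size=shallow.shape[2:],
                        mode='bilinear', align_corners=False)

dwff = DWFF(in_channels=64)
fused = dwff([shallow, deep_up])
print(fused.shape)  # torch.Size([2, 64, 32, 32])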