
Daily Attention Study 26: Dynamic Weighted Feature Fusion

Module Source

[ACM MM 23] [link] [code] Efficient Parallel Multi-Scale Detail and Semantic Encoding Network for Lightweight Semantic Segmentation


Module Name

Dynamic Weighted Feature Fusion (DWFF)


Module Purpose

Fusion of two feature levels (shallow and deep)


Module Structure

[Figure: structure of the DWFF module]


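Module Idea

We propose the DWFF strategy, which selectively attends to the most informative parts of the feature maps in order to combine shallow and deep features effectively and improve segmentation accuracy. DWFF can weight shallow features more heavily in regions with fine-grained detail, and deep features more heavily in regions with richer semantic information, yielding a better feature combination and more accurate segmentation.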
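The weighting idea can be illustrated with a few lines of plain Python. This is only a toy sketch, not the paper's implementation: the function names `softmax` and `fuse_channel` and the scalar inputs are made up for illustration. It shows the core mechanism used by the code in this post, where each channel gets one score per branch, the scores are normalized with a softmax across branches, and the fused value is the resulting weighted sum. Note that in the listing in this post the scores come from globally pooled features, so the weights are per channel rather than per pixel.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def fuse_channel(shallow_value, deep_value, shallow_score, deep_score):
    """Fuse one channel of two branches (hypothetical helper).

    The two scores play the role of the attention logits; after the
    softmax they become weights that sum to 1 across the two branches.
    """
    w_shallow, w_deep = softmax([shallow_score, deep_score])
    fused = w_shallow * shallow_value + w_deep * deep_value
    return fused, (w_shallow, w_deep)


# Equal scores: both branches contribute equally.
fused_equal, weights_equal = fuse_channel(1.0, 3.0, 0.0, 0.0)

# A higher shallow score pulls the fused value toward the shallow branch.
fused_shallow, weights_shallow = fuse_channel(1.0, 3.0, 2.0, 0.0)

print(fused_equal, weights_equal)
print(fused_shallow, weights_shallow)
```

Because the weights are a softmax output, the fused value always lies between the two branch values; the module never amplifies a branch beyond its original magnitude.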


Module Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class DWFF(nn.Module):
    def __init__(self,
                 in_channels: int,
                 height: int = 2,
                 reduction: int = 8,
                 bias: bool = False) -> None:
        super(DWFF, self).__init__()

        self.height = height
        d = max(int(in_channels / reduction), 4)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(in_channels, d, 1, padding=0, bias=bias),
            nn.BatchNorm2d(d),
            nn.LeakyReLU(0.2)
        )
        # One 1x1 conv per input branch; each produces per-channel attention logits.
        self.fcs = nn.ModuleList([
            nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias)
            for _ in range(self.height)
        ])
        self.softmax = nn.Softmax(dim=1)

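    def forward(self, inp_feats):
        # inp_feats: list of `height` tensors, each of shape (B, C, H, W).
        batch_size = inp_feats[0].shape[0]
        n_feats = inp_feats[0].shape[1]
        # Stack the branches: (B, height * C, H, W) -> (B, height, C, H, W).
        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])
        # Sum the branches, then pool to a global per-channel descriptor (B, C, 1, 1).
        feats_U = torch.sum(inp_feats, dim=1)
        feats_S = self.avg_pool(feats_U)
        # Squeeze to d channels: (B, d, 1, 1).
        feats_Z = self.conv_du(feats_S)
        # Per-branch attention logits, reshaped to (B, height, C, 1, 1).
        attention_vectors = [fc(feats_Z) for fc in self.fcs]
        attention_vectors = torch.cat(attention_vectors, dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
        # Softmax over dim=1 (the branch axis): for each channel, the weights sum to 1 across branches.
        attention_vectors = self.softmax(attention_vectors)
        # Weighted sum of the branches: (B, C, H, W).
        feats_V = torch.sum(inp_feats * attention_vectors, dim=1)
        return feats_V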
    

if __name__ == '__main__':
    dwff = DWFF(in_channels=64)
    x1 = torch.randn([2, 64, 16, 16])
    x2 = torch.randn([2, 64, 16, 16])
    out = dwff([x1, x2])
    print(out.shape)  # torch.Size([2, 64, 16, 16])
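
The test above feeds two tensors of identical shape, which the `torch.cat` plus `view` steps in `forward` require. Shallow and deep features in a segmentation network usually differ in both resolution and channel count, so they have to be aligned first. The following is a minimal sketch of one way to do that, assuming a 1x1 convolution for channel projection and bilinear upsampling; the class name `AlignToShallow` and the chosen shapes are hypothetical and are not taken from the paper or its code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AlignToShallow(nn.Module):
    """Hypothetical helper: make a deep feature map match a shallow one.

    Projects the deep map to the shallow channel count with a 1x1 conv,
    then upsamples it bilinearly to the shallow spatial size.
    """

    def __init__(self, deep_channels: int, shallow_channels: int) -> None:
        super().__init__()
        self.proj = nn.Conv2d(deep_channels, shallow_channels, kernel_size=1, bias=False)

    def forward(self, deep: torch.Tensor, shallow: torch.Tensor) -> torch.Tensor:
        deep = self.proj(deep)
        return F.interpolate(deep, size=shallow.shape[2:], mode="bilinear", align_corners=False)


# Illustrative shapes only: a high-resolution shallow map and a
# lower-resolution, wider deep map.
shallow = torch.randn(2, 64, 32, 32)
deep = torch.randn(2, 128, 16, 16)

align = AlignToShallow(deep_channels=128, shallow_channels=64)
deep_aligned = align(deep, shallow)
print(deep_aligned.shape)
```

Once the shapes match, the two maps can be passed to the module as a list, for example `DWFF(in_channels=64)([shallow, deep_aligned])`.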

