当前位置: 首页 > article >正文

【Block总结】SGE注意力机制

一、论文介绍

论文链接:https://arxiv.org/pdf/1905.09646

  • 研究背景:论文首先提及了在计算机视觉领域,特征分组的思想由来已久,并介绍了相关背景。
  • 研究目的:旨在通过引入SGE模块,改善特征图的空间分布,提升模型对特定语义特征的表示能力。
  • 实验平台:实验代码和预训练模型可在https://github.com/implus/PytorchInsight 获取。
    在这里插入图片描述

二、创新点

  • SGE模块:提出了一种新的空间分组增强模块,该模块具有轻量级、几乎不需要额外参数和计算量的特点。
  • 特征表示增强:SGE模块能够显著改善特征图的空间分布,提升模型对特定语义特征的捕捉能力。
  • 对比实验:在ImageNet和COCO 2017数据集上与多种先进的注意力模块进行了对比实验,验证了SGE模块的有效性。
    在这里插入图片描述

三、方法

  • SGE模块设计:介绍了SGE模块的具体设计,包括如何通过分组和增强操作来改善特征图的空间分布。
  • 特征分组:将特征图按空间位置进行分组,每组包含特定区域的特征向量。
  • 增强操作:对每个组的特征向量进行增强处理,以提高其表示能力。

四、模块作用

  • 改善空间分布:SGE模块通过分组和增强操作,改善了特征图的空间分布,使得语义相关区域的激活值更加突出。
  • 抑制噪声:在增强语义相关区域的同时,SGE模块还能有效抑制大量噪声,提高特征表示的清晰度。
  • 提升准确性:实验结果表明,引入SGE模块后,模型的准确性得到了显著提升。

五、改进的效果

  • ImageNet数据集:在ImageNet数据集上,SGE-ResNet50和SGE-ResNet101的Top-1和Top-5准确率均超过了多种先进的注意力模块。
  • COCO 2017数据集:在COCO 2017数据集上,SGE模块也表现出色,提升了对象检测的平均精度(AP)。
  • 可视化分析:通过可视化分析,进一步验证了SGE模块在改善特征图空间分布和提升语义特征表示能力方面的有效性。

代码:

import torch
from torch import nn
from torch.nn.parameter import Parameter

__all__ = ['SGELayer']


class SGELayer(nn.Module):
    def __init__(self, groups=64):
        super(SGELayer, self).__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.weight = Parameter(torch.zeros(1, groups, 1, 1))
        self.bias = Parameter(torch.ones(1, groups, 1, 1))
        self.sig = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()
        x = x.view(b * self.groups, -1, h, w)
        xn = x * self.avg_pool(x)
        xn = xn.sum(dim=1, keepdim=True)

        t = xn.view(b * self.groups, -1)
        t = t - t.mean(dim=1, keepdim=True)
        std = t.std(dim=1, keepdim=True) + 1e-5
        t = t / std

        t = t.view(b, self.groups, h, w)
        t = t * self.weight + self.bias
        t = t.view(b * self.groups, 1, h, w)

        x = x * self.sig(t)
        x = x.view(b, c, h, w)

        return x
if __name__ == '__main__':
    # 生成随机输入数据
    input_data = torch.randn(1, 32, 640, 480)
    mca = SGELayer(16)
    output = mca(input_data)
    # 打印输入和输出形状
    print("Input size:", input_data.size())
    print("Output size:", output.size())

http://www.kler.cn/a/465993.html

相关文章:

  • [Linux]进程间通信-共享内存与消息队列
  • 小程序学习06——uniapp组件常规引入和easycom引入语法
  • connect to host github.com port 22: Connection timed out 的解决方法
  • 【GUI-pyqt5】QWidget类
  • 框架Tensorflow2
  • MATLAB程序转C# WPF,dll集成,混合编程
  • linux内核PWM子系统笔记
  • 论文精读:Root Cause Analysis in Microservice Using Neural Granger Causal Discovery
  • 用python重写了座位表生成器
  • 仓库叉车高科技安全辅助设备——AI防碰撞系统N2024G-2
  • 【74HC192减法24/20/72进制】2022-5-17
  • 在 pandas.Grouper() 中,freq 参数用于指定时间频率,它定义了如何对时间序列数据进行分组。freq 的值可以是多种时间单位
  • 发现一个可用的免费docker镜像源
  • AI智能生成PPT,告别手工操作的新选择
  • 安卓11 SysteUI添加按钮以及下拉状态栏的色温调节按钮
  • MATLAB画柱状图
  • 【Spring学习】为什么Spring中的IOC(控制反转)能够降低耦合性(解耦)?
  • springboot和vue项目前后端交互
  • 竞品分析对于ASO优化的重要性
  • MySql---进阶篇(六)---SQL优化
  • 在 SQL 中获取第m个开始的n条记录方法汇总
  • 亚远景-ASPICE与ISO 26262:汽车软件开发与功能安全的协同作用
  • GitHub Actions 工作流编写指南
  • Mysql8主从复制(兼容低高版本)
  • 【AI部署】腾讯云每月1w小时免费GPU获取
  • DBSCAN 聚类 和 gmm 聚类 测试