当前位置: 首页 > article >正文

【Block总结】高效多尺度注意力EMA,超越SE、CBAM、SA、CA等注意力|即插即用

论文信息

标题: Efficient Multi-Scale Attention Module with Cross-Spatial Learning

作者: Daliang Ouyang, Su He, Guozhong Zhang, Mingzhu Luo, Huaiyong Guo, Jian Zhan, Zhijie Huang

论文链接: https://arxiv.org/pdf/2305.13563v2

GitHub链接: https://github.com/YOLOonMe/EMA-attention-module
在这里插入图片描述

创新点

该论文提出了一种新颖的高效多尺度注意力模块(EMA),旨在通过跨空间学习来提升特征表示的效果,同时降低计算开销。EMA模块的设计重点在于:

  • 信息保留: 在每个通道上保留信息,确保特征的完整性。
  • 计算效率: 通过重塑部分通道为批处理维度,减少计算负担。
  • 多尺度学习: 结合多尺度特征,增强模型对不同尺度信息的捕捉能力。

方法

EMA模块的核心方法包括:

  1. 通道重塑: 将部分通道重塑为批处理维度,并将通道维度分组为多个子特征,以实现更高效的信息处理。

  2. 跨维度交互: 通过跨维度交互,聚合两个并行分支的输出特征,捕获像素级的成对关系。

  3. 并行子网络: 设计多尺度并行子网络,以建立短期和长期依赖关系,从而增强特征表示能力。

在这里插入图片描述

EMA模块的信息保留与计算效率平衡

信息保留机制

EMA(Efficient Multi-Scale Attention)模块通过以下几种方式实现信息的有效保留:

  1. 通道重塑: EMA模块将部分通道重塑为批处理维度,并将通道维度分组为多个子特征。这种设计确保了每个通道的信息能够被有效保留,同时避免了通道维度的削减,从而增强了特征的表达能力[1][3]。

  2. 跨维度交互: 在EMA模块中,两个并行分支的输出特征通过跨维度交互进行聚合。这种交互机制能够捕捉到像素级的成对关系,从而进一步提升特征的丰富性和准确性[2][3]。

  3. 多尺度并行子网络: EMA模块采用了多尺度并行子网络结构,结合了1x1和3x3卷积核的特征处理。这种结构能够有效捕获不同尺度的信息,确保在特征提取过程中不会丢失重要信息[2][3]。

计算效率提升

在计算效率方面,EMA模块通过以下方式优化了计算过程:

  1. 减少计算开销: 通过将部分通道重塑为批处理维度,EMA模块能够在不显著增加计算成本的情况下,保持高效的信息处理。这种方法使得模型在处理大规模数据时更加高效[1][2]。

  2. 并行处理: EMA模块的设计允许多个子网络并行处理特征,这不仅提高了计算效率,还减少了模型的顺序处理需求,从而加快了整体计算速度[3]。

  3. 适度的模型尺寸: EMA模块的设计确保了模型的尺寸适中,适合在移动终端等资源受限的环境中部署。这种设计使得EMA模块在保持性能的同时,能够有效降低计算资源的消耗[3][2]。

EMA模块通过创新的设计实现了信息保留与计算效率的平衡。其通道重塑、跨维度交互和多尺度并行处理的策略,不仅确保了特征信息的完整性,还显著提高了计算效率。这使得EMA模块在计算机视觉任务中表现出色,尤其是在小目标检测和图像分类等应用中,展现了其广泛的应用潜力和实际意义。

效果

实验结果表明,EMA模块在多个计算机视觉任务中表现优异,尤其是在小目标检测和图像分类任务中,相较于传统的注意力机制(如ECA、CBAM、CA),EMA模块显著提高了特征表示的清晰度和准确性。

实验结果

在广泛的消融研究和实验中,EMA模块在以下数据集上进行了评估:

  • CIFAR-100
  • ImageNet-1k
  • MS COCO
  • VisDrone2019

实验结果显示,EMA模块在这些基准测试中均取得了优于现有方法的性能,尤其在小目标检测任务中,表现出明显的优势。

总结

Efficient Multi-Scale Attention Module with Cross-Spatial Learning通过创新的设计和有效的实现,成功地提升了计算机视觉任务中的特征表示能力,同时降低了计算复杂度。该模块的提出为未来的研究提供了新的思路,尤其是在需要高效处理大规模数据的应用场景中,EMA模块展现了其广泛的应用潜力。

代码

import torch
from torch import nn

class EMA(nn.Module):
    def __init__(self, channels, c2=None, factor=32):
        super(EMA, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)


if __name__ == "__main__":
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, channels, height, width)
    x = torch.randn(1,32,40,40).to(device)
    # 初始化 pconv 模块
    dim=32
    block = EMA(dim,factor=8)
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)

输出结果:

在这里插入图片描述


http://www.kler.cn/a/526751.html

相关文章:

  • [STM32 - 野火] - - - 固件库学习笔记 - - -十二.基本定时器
  • 爬虫基础(五)爬虫基本原理
  • 基于 NodeJs 一个后端接口的创建过程及其规范 -- 【elpis全栈项目】
  • Linux 6.x版本内核的proc目录组织
  • 网络安全技术简介
  • Origami Agents:AI驱动的销售研究工具,助力B2B销售团队高效增长
  • RK3568 opencv播放视频
  • 第23节课:前端调试技巧—掌握浏览器开发者工具与性能优化
  • 理解PLT表和GOT表
  • 新春登蛇山:告别岁月,启航未来
  • LeetCode 0219.存在重复元素 II:哈希表
  • 【Leetcode刷题记录】166. 分数到小数
  • [EAI-022] FuSe,在VLA模型基础上,融合触觉和语音等异构模态信息
  • 动态规划两个数组dp问题系列一>最长公共子序列
  • 网站快速收录:利用RSS订阅提升效率
  • fpga系列 硬件:FPGA VITIS PS端HELLO WORLD在 ZYNQ EBAZ4203板上实现
  • ADC 精度 第二部分:总的未调整误差解析
  • 33333333333
  • Autogen_core 测试代码:test_cancellation.py
  • Electron工具Electron Fiddle
  • 【TypeScript】TypeScript 运算符
  • AI 的安全性与合规性:实践中的最佳安全策略
  • 【Block总结】PKI 模块,无膨胀多尺度卷积,增强特征提取的能力|即插即用
  • 【华为OD-E卷 - 分积木 100分(python、java、c++、js、c)】
  • Autogen_core: test_code_executor.py
  • 算法---快速排序