当前位置: 首页 > article >正文

【论文笔记】独属于CV的注意力机制CBAM-Convolutional Block Attention Module

目录

写在前面

一、基数和宽度

二、通道注意力模块(Channel Attention Module)

三、空间注意力模块(Spatial Attention Module)

四、CBAM(Convolutional Block Attention Module)

五、总结


写在前面

        CBAM论文地址:https://arxiv.org/abs/1807.06521

        CBAM(Convolutional Block Attention Module)是2018年被提出的,不同于ViT的Attention,CBAM是为CNN量身定做的Attention模块,实现简单、效果好,你值得拥有。

        为了提高CNN的性能,我们可以从深度(depth)、宽度(width)和基数(cardinality)三个方面入手。深度很好理解,就是模型的层数,层数越多模型越深。下面说一说基数和宽度的区别。

一、基数和宽度

基数(cardinality):指的是并行分支的数量。

宽度(width):指每一层卷积的卷积核数量(即输出特征图的通道数)。

举个GoogLeNet的例子:

        GoogLeNet 的设计中,Inception模块通过组合多个不同大小的卷积核(例如 1x1、3x3、5x5)和池化操作来提取不同尺度的特征。

        增加“宽度”的效果: 我们有一个 Inception 模块,包含 1x1、3x3 和 5x5 的卷积层,以及一个 3x3 的最大池化层。如果我们在该 Inception 模块中增加每个卷积操作的通道数,例如将 1x1 卷积层的输出通道数从 32 增加到 64,将 3x3 卷积层的输出通道数从 64 增加到 128,这种操作就增加了网络的“宽度”。增加“宽度”意味着每个 Inception 模块可以提取更多的特征信息,但同时也增加了计算成本。

        增加基数: 如果我们将每个卷积和池化操作进一步拆分为多个组,例如在每个卷积操作中使用组卷积(group convolution),那么这些并行组卷积操作的数量就类似于增加了“基数”。每个组卷积操作都是一个独立的路径,这些路径的数量增加就代表了基数的增加。基数不仅节省了参数的总数,而且比深度和宽度这两个因素具有更强的表示能力。

        可以看下图,蓝色的线表示模型的基数,红色的数字表示宽度。

        CBAM由两个顺序的子模块组成:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。CAM解决的问题是重要的信息是什么(‘what’ is meaningful given an input image)、SAM解决重要的信息在哪里(‘where’ is an informative part)。这两个模块都使用了增加基数的方式,提升模型的表达能力。

二、通道注意力模块(Channel Attention Module

        我们知道CNN的每个通道可以提取不同的特征(也就是Feature Map),通道注意力模块(Channel Attention Module)的主要作用是自适应地调整和增强输入Feature Map中每个通道的重要性。它通过学习每个通道对于当前任务的重要性权重,从而对不同通道进行加权,增强关键信息的表达,同时抑制不相关或冗余的特征。这种机制能够使神经网络更高效地利用信息,提高模型的性能和表达能力。

        通俗的说,CAM就是判断每个Feature Map的重要程度。即下图中相同颜色的部分会有一个标记重要程度的权重。

        CAM使用并行的平均池化和最大池化,每个池化分别经过两个卷积模块,然后相加经过sigmoid得到注意力概率图,公式如下:

        其中σ为Sigmoid型函数,MLP这里使用的是两个卷积层,权重W_0W_1是共享的。

        公式不直观,示意图如下,假设输入是一个32通道244x244的Feature Map,输出是32x1x1,表示32个Feature Map的重要程度。

        使用pooling可以捕获每个通道特征的强度。平均池化(Average Pooling) 能够捕获每个通道的整体激活分布信息。它反映了一个通道中所有特征点的平均响应,能够平滑地代表特征的整体强度。最大池化(Max Pooling) 捕获的是特征图中的最强激活信号。它能够突出特征图中最显著的特征,强调特征中的极端值。

        平均池化提供了特征分布的全局性信息,而最大池化提供了最显著特征的信息。通过结合这两种池化方法,通道注意力模块能够更好地理解哪些特征通道在给定输入图像中最重要。

        两种池化之后的卷积层共享参数,而且两个卷积层的中间维度使用in_planes//16,最大限度的减少参数,因为注意力模块只提供一个注意力概率图,提取特征并不是它的首要任务,所以不需要太多参数。

        代码示例:

class ChannelAttention(nn.Module):
    def __init__(self, in_planes):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

        

三、空间注意力模块(Spatial Attention Module

        SAM首先沿着通道轴分别使用平均池化和最大池化操作,并将它们连接起来,然后经过一个7x7的卷积层将通道数变成1,最后是Sigmoid得到结果,是一个宽高和输入一致通道数为1的概率图。

        SAM就是判断Feature Map中每个部分(像素)的重要程度,需要综合每个部分所有通道的特征。即下图中相同颜色的部分会有一个标记重要程度的权重。

公式如下:

       其中,σ表示Sigmoid函数,f^{7*7}表示大小为7×7的卷积运算。

        补充一下,这里的卷积核大小是7x7,而上面CAM的是1x1,这是因为CAM关注的是整个特征图中的全局通道信息1x1 卷积核适合这种只需在通道维度上操作的情况。SAM关注的是特征图中的局部和全局空间信息7x7 卷积核适合捕捉空间维度上的局部和全局特征。

        示意图如下,仍然假设输入是一个32通道244x244的Feature Map,输出是1x244x244。

        同时使用平均池化和最大池化的原因和CAM类似,平均池化提供了特征分布的全局性信息,而最大池化提供了最显著特征的信息。通过结合这两种池化方法,SAM能够更好地理解Feature Map中哪些区域最重要。

四、CBAM(Convolutional Block Attention Module)

        有了CAM和SAM,就剩最后一个问题,这两个模块怎么摆放,是并行放置还是顺行方式呢?作者发现顺序排列比平行排列的效果更好。下面是CBAM完整的结构图,我们随意在CBAM之前放几个卷积层:

CBAM完整的代码:

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class DemoNet(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(DemoNet, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        out = self.relu(out)

        return out


if __name__ == '__main__':
    input = torch.randn(1, 3, 224, 224)
    demo_net = DemoNet(inplanes=3, planes=32)
    output = demo_net(input)
    print(output)

        这里还有一个CBAM应用到ResNet的完整例子:https://github.com/luuuyi/CBAM.PyTorch

五、总结

1.CBAM是一个轻量级和通用的模块,它可以无缝地集成到任何CNN架构中,而开销可以忽略不计,并且可以与基础CNN一起进行端到端训练;

2.通道注意力模块(Channel Attention Module)关注每个通道的Feature Map的重要程度;

3.空间注意力模块(Spatial Attention Module)关注Feature Map每个部分(像素)的重要程度;

4.CBAM由通道注意力模块和空间注意力模块两个模块组成,同时兼顾了通道与空间特征的表达,相比传统的卷积层参数更少、效果更好。

        CBAM就介绍到这里,关注不迷路(*^▽^*)


http://www.kler.cn/a/284789.html

相关文章:

  • 15 个改变世界的开源项目:塑造现代技术的先锋力量
  • https网站 请求http图片报错:net::ERR_SSL_PROTOCOL_ERROR
  • Vue中优雅的使用Echarts的三种方式
  • MySQL如何利用索引优化ORDER BY排序语句
  • 【C/C++】CreateThread 与 _beginthreadex, 应该使用哪一个?为什么?
  • AI 大模型如何赋能电商行业,引领变革
  • Ubuntu上安装配置(jdk/tomcat/ufw防火墙/mysql)+mysql卸载
  • ssm面向企事业单位的项目申报小程序论文源码调试讲解
  • 大数据处理从零开始————1.Hadoop介绍
  • 50ETF期权合约要素有哪些?50ETF期权合约组成构成分享
  • MFC工控项目实例之九选择下拉菜单主界面文本框显示菜单名
  • Python算法工程师面试整理-Python 在算法中的应用
  • Java基础——方法引用、单元测试、XML、注解
  • mysql集群
  • es重启后调大恢复并发参数,加速分片分配
  • 美团8/31—24年秋招【技术】第四场
  • 算法的空间复杂度
  • 【Redis】持久化——rdb机制
  • 零基础国产GD32单片机编程入门(九)低功耗模式实战含源码
  • 掌握CHECK约束:确保数据准确性的关键技巧
  • 【网络】HTTPS——HTTP的安全版本
  • GalaChain 全面剖析:为 Web3 游戏和娱乐而生的创新区块链
  • 速盾:Nginx使用CDN之后获取真实的用户IP
  • 机器学习--核心要点总结
  • k8s 存储
  • Spark自定义函数例子