当前位置: 首页 > article >正文

CBAM-2018学习笔记

名称:

Convolutional Block Attention Module (CBAM)

来源:

CBAM: Convolutional Block Attention Module

相关工作:

#ResNet #GoogleNet #ResNeXt #Network-engineering #Attention-mechanism

创新点:

fpg0umoj.4ze.png

贡献:

  • 提出CBAM
  • 验证了其有效性
  • 改善提高了以往模型的性能

代码:

  
import torch  
from torch import nn  
  
  
class ChannelAttention(nn.Module):  
    def __init__(self, in_planes, ratio=16):  
        super(ChannelAttention, self).__init__()  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  
        self.max_pool = nn.AdaptiveMaxPool2d(1)  
  
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)  
        self.relu1 = nn.ReLU()  
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))  
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))  
        out = avg_out + max_out  
        return self.sigmoid(out)  
  
  
class SpatialAttention(nn.Module):  
    def __init__(self, kernel_size=7):  
        super(SpatialAttention, self).__init__()  
  
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'  
        padding = 3 if kernel_size == 7 else 1  
  
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)  # 7,3     3,1  
        self.sigmoid = nn.Sigmoid()  
  
    def forward(self, x):  
        avg_out = torch.mean(x, dim=1, keepdim=True)  
        max_out, _ = torch.max(x, dim=1, keepdim=True)  
        x = torch.cat([avg_out, max_out], dim=1)  
        x = self.conv1(x)  
        return self.sigmoid(x)  
  
  
class CBAM(nn.Module):  
    def __init__(self, in_planes, ratio=16, kernel_size=7):  
        super(CBAM, self).__init__()  
        self.ca = ChannelAttention(in_planes, ratio)  
        self.sa = SpatialAttention(kernel_size)  
  
    def forward(self, x):  
        out = x * self.ca(x)  
        result = out * self.sa(out)  
        return result  
  
  
# 输入 N C H W,  输出 N C H Wif __name__ == '__main__':  
    block = CBAM(64)  
    input = torch.rand(3, 64, 32, 32)  
    output = block(input)  
    print(input.size(), output.size())

http://www.kler.cn/a/516690.html

相关文章:

  • 深度学习系列76:流式tts的一个简单实现
  • 小游戏源码开发搭建技术栈和服务器配置流程
  • 什么是COLLATE排序规则?
  • 编写子程序
  • MyBatis-Plus的插件
  • 探索WPF中的RelativeSource:灵活的资源绑定利器
  • 如何使 LLaMA-Factory 支持 google/gemma-2-2b-jpn-it 的微调
  • 网络(二)协议
  • GIT的常规使用
  • 【MySQL — 数据库增删改查操作】深入解析MySQL的create insert 操作
  • docker 启动镜像命令集合
  • Java 大视界 -- Java 大数据中的异常检测技术与应用(61)
  • ESP8266 OTA固件启动日志里分区解析【2M flash】
  • 【Java实现 通过Easy Excel完成对excel文本数据的读写】
  • 递归的本质
  • Rman还原
  • Yii框架中的Cart组件:实现购物车功能
  • GC(垃圾回收)的分类
  • 使用 Elasticsearch 导航检索增强生成图表
  • linux-centosubuntu本地源配置
  • 蓝桥杯练习日常|c/c++竞赛常用库函数
  • 使用Python爬虫获取1688店铺所有商品信息的完整指南
  • C#高级:常用的扩展方法大全
  • ubuntu系统docker环境搭建
  • STM32调试手段:重定向printf串口
  • 重载C++运算符