当前位置: 首页 > article >正文

计算机视觉之 SE 注意力模块

计算机视觉之 SE 注意力模块

一、简介

SEBlock 是一个自定义的神经网络模块,主要用于实现 Squeeze-and-Excitation(SE)注意力机制。SE 注意力机制通过全局平均池化和全连接层来重新校准通道的权重,从而增强模型的表达能力。

原论文:《Squeeze-and-Excitation Networks》

二、语法和参数

语法
class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction=16):
        ...
    def forward(self, x):
        ...
参数
  • in_channels:输入特征的通道数。
  • reduction:通道缩减比例,默认为 16。

三、实例

3.1 初始化和前向传播
  • 代码
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        reduced_channels = max(in_channels // reduction, 1)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, reduced_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channels, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch_size, channels, _, _ = x.size()
        # Squeeze
        y = self.global_avg_pool(x).view(batch_size, channels)
        # Excitation
        y = self.fc(y).view(batch_size, channels, 1, 1)
        # Scale
        return x * y.expand_as(x)
  • 输出
加权图像输出
3.2 应用在示例数据上
  • 代码
import torch

# 创建示例输入数据
input_tensor = torch.randn(1, 64, 32, 32)  # (batch_size, in_channels, height, width)

# 初始化 SEBlock 模块
se_block = SEBlock(in_channels=64, reduction=16)

# 前向传播
output_tensor = se_block(input_tensor)
print(output_tensor.shape)
  • 输出
torch.Size([1, 64, 32, 32])

四、注意事项

  1. SEBlock 模块通过全局平均池化和全连接层来重新校准通道的权重,从而增强模型的表达能力。
  2. 在使用 SEBlock 时,确保输入特征的通道数和缩减比例设置合理,以避免计算开销过大。
  3. 该模块主要用于图像数据处理,适用于各种计算机视觉任务,如图像分类、目标检测等。


http://www.kler.cn/a/288377.html

相关文章:

  • 左值引用(Lvalue Reference)和右值引用(Rvalue Reference)详解
  • 设计模式 行为型 责任链模式(Chain of Responsibility Pattern)与 常见技术框架应用 解析
  • 【SpringAOP】Spring AOP 底层逻辑:切点表达式与原理简明阐述
  • JavaScript系列(16)--原型继承
  • 04、Redis深入数据结构
  • GO随记:不使用主键id 如何分表与mysql大表
  • 微信小程序接入客服功能
  • 逆向工程核心原理 Chapter23 | DLL注入
  • 【舍入,取整,取小数,取余数丨Excel 函数】
  • 探索四川财谷通信息技术有限公司抖音小店的独特魅力
  • 收银系统源码-收银台UI自定义
  • 51单片机-第九节-AT24C02存储器(I2C总线)
  • 代码随想录算法训练营第35天 | 416.分割等和子集
  • PLUTO: 推动基于模仿学习的自动驾驶规划的极限
  • AI智能电销机器人的优势是什么,有什么特点?
  • Python群发邮件:如何实现Python邮件群发?
  • 浅谈sizeof() 函数在Arduino中的使用
  • 代码随想录算法训练营_day35
  • ARM 异常处理(21)
  • dfs算法复习
  • Express与SQLite集成教程:轻松实现数据库操作
  • 【概率与统计 动态规划】 808. 分汤
  • Unity3D DOTS系列之BlobAsset核心机制详解
  • UFUG2601-OJ palindrome
  • idea便捷操作
  • Kubernetes 1.20 上将容器从 Docker Engine 改为 Containerd