当前位置: 首页 > article >正文

计算机视觉之 GSoP 注意力模块

计算机视觉之 GSoP 注意力模块

一、简介

GSopBlock 是一个自定义的神经网络模块,主要用于实现 GSoP(Global Second-order Pooling)注意力机制。GSoP 注意力机制通过计算输入特征的协方差矩阵,捕捉全局二阶统计信息,从而增强模型的表达能力。

原论文:《Global Second-order Pooling Convolutional Networks (arxiv.org)》

二、语法和参数

语法
class GSopBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        ...
    def forward(self, x):
        ...
参数
  • in_channels:输入特征的通道数。
  • mid_channels:中间层的通道数,用于调整特征维度。

三、实例

3.1 初始化和前向传播
  • 代码
import torch
import torch.nn as nn

class GSopBlock(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(GSopBlock, self).__init__()
        self.conv2d1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True)
        )
        self.row_wise_conv = nn.Sequential(
            nn.Conv2d(
                mid_channels, 4*mid_channels,
                kernel_size=(mid_channels, 1),
                groups=mid_channels, bias=False),
            nn.BatchNorm2d(4*mid_channels),
        )
        self.conv2d2 = nn.Sequential(
            nn.Conv2d(4*mid_channels, in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Step 1: 调整通道数
        x = self.conv2d1(x)
        batch_size, channels, height, width = x.size()

        # Step 2: 展平输入
        x_flat = x.view(batch_size, channels, -1)

        # Step 3: 计算协方差矩阵
        x_mean = x_flat.mean(dim=-1, keepdim=True)
        x_centered = x_flat - x_mean
        cov_matrix = torch.bmm(x_centered, x_centered.transpose(1, 2)) / (height * width)
        cov_matrix = cov_matrix.unsqueeze(-1)

        # Step 4: 行方向卷积
        cov_features = self.row_wise_conv(cov_matrix)

        # Step 5: 生成权重向量
        weight_vector = self.conv2d2(cov_features)

        # Step 6: 计算最终输出
        x_out = x * weight_vector

        return x_out
  • 输出
经过加权后的图像
3.2 应用在示例数据上
  • 代码
import torch

# 创建示例输入数据
input_tensor = torch.randn(1, 64, 32, 32)  # (batch_size, in_channels, height, width)

# 初始化 GSopBlock 模块
gsop_block = GSopBlock(in_channels=64, mid_channels=16)

# 前向传播
output_tensor = gsop_block(input_tensor)
print(output_tensor.shape)
  • 输出
torch.Size([1, 64, 32, 32])

四、注意事项

  1. GSopBlock 模块适用于捕捉输入特征之间的全局二阶统计信息,增强模型的表达能力。
  2. 在使用 GSopBlock 时,确保输入特征的通道数和中间层的通道数设置合理,以避免计算开销过大。
  3. 该模块主要用于图像数据处理,适用于各种计算机视觉任务,如图像分类、目标检测等。


http://www.kler.cn/a/288906.html

相关文章:

  • 计算机毕业设计Python+卷积神经网络租房推荐系统 租房大屏可视化 租房爬虫 hadoop spark 58同城租房爬虫 房源推荐系统
  • 《Mycat核心技术》第06章:Mycat问题处理总结
  • SpringCloud 系列教程:微服务的未来(二)Mybatis-Plus的条件构造器、自定义SQL、Service接口基本用法
  • langchain使用FewShotPromptTemplate出现KeyError的解决方案
  • 【C++基础】09、结构体
  • vue3封装而成的APP ,在版本更新后,页面显示空白
  • 《第二十六章 IO 流 - 字节流》
  • 在项目中使用 redis存储 数据,提高 项目运行速度
  • 【Linux】 理解 Linux 中的 `dup2` 函数
  • Spring框架中的@EventListener注解浅谈
  • 【C++ Primer Plus习题】8.2
  • 直播路由器的原理是什么
  • Linux CentOS 7.39 安装mysql8
  • rabbitmq发送的消息接收不到
  • 告别文档处理烦恼,PDF Guru Anki一键搞定所有
  • 多目标应用:基于双存档模型的多模态多目标进化算法(MMOHEA)的移动机器人路径规划研究(提供MATLAB代码)
  • C语言之猜数字小游戏
  • 【苍穹外卖】Day3 菜品接口
  • dinput8.dll错误应该如何修复呢?五种快速修复dinput8.dll错误的问题
  • SpringBoot开发——初步了解SpringBoot
  • CephX 认证机制及用户管理
  • 功能测试常用的测试用例大全
  • 大模型入门 ch01:大模型概述
  • 强化学习,第 5 部分:时间差异学习
  • 数据结构——单链表相关操作
  • C# 开发环境搭建(Avalonia UI、Blazor Web UI、Web API 应用示例)