当前位置: 首页 > article >正文

【深度学习】SAB:空间注意力

@[toc]SAB:空间注意力

SAB:空间注意力

SA(Spatial Attention)空间注意力模块是一种用于计算机视觉任务的注意力机制,它的核心思想是通过学习特征图在空间维度上的重要性,自适应地增强或抑制不同空间位置的特征响应。与通道注意力(如 SENet 或 CBAM 中的通道注意力模块)不同,空间注意力关注的是特征图的空间维度(高度和宽度),而不是通道维度

SA 的核心思想

1.空间注意力机制:

通过对特征图在通道维度上进行聚合(如最大池化或平均池化),生成空间注意力图。

使用卷积层学习空间位置之间的关系。

通过 Sigmoid 函数生成空间注意力权重,对原始特征图进行加权。

残差连接:

为了保留原始特征信息,通常会在空间注意力模块的输出上添加残差连接。

SA 的结构

通道聚合:

对输入特征图在通道维度上进行聚合(如最大池化或平均池化),生成一个空间描述图。

卷积层:

使用卷积层学习空间位置之间的关系。

通常使用一个 7x7 或 3x3 的卷积核来捕获局部空间信息。

Sigmoid 激活函数:

对卷积层的输出进行归一化,生成空间注意力权重。

特征加权:

将生成的空间注意力权重与输入特征图相乘,得到加权后的特征。

残差连接:

将加权后的特征与输入特征相加,保留原始信息。

代码实现

import torch
import torch.nn as nn

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        """
        初始化 SA 模块
        :param kernel_size: 卷积核大小,默认为 7
        """
        super(SpatialAttention, self).__init__()
        self.kernel_size = kernel_size

        # 通道聚合(最大池化和平均池化)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # 卷积层
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        前向传播
        :param x: 输入特征图,形状为 [batch_size, channel, height, width]
        :return: 加权后的特征图
        """
        # 通道聚合
        max_out = self.max_pool(x)  # [batch_size, channel, 1, 1]
        avg_out = self.avg_pool(x)  # [batch_size, channel, 1, 1]

        # 拼接
        out = torch.cat([max_out, avg_out], dim=1)  # [batch_size, 2, height, width]

        # 卷积层
        out = self.conv(out)  # [batch_size, 1, height, width]

        # 生成空间注意力权重
        spatial_weights = self.sigmoid(out)  # [batch_size, 1, height, width]

        # 特征加权
        return x * spatial_weights

总结

SA 空间注意力模块通过学习特征图在空间维度上的重要性,能够自适应地增强重要区域的特征并抑制不重要区域的特征。


http://www.kler.cn/a/502856.html

相关文章:

  • 解决Qt打印中文字符出现乱码
  • lerna使用指南
  • Cython全教程2 多种定义方式
  • 【大数据】机器学习 -----关于data.csv数据集分析案例
  • 解读若依微服务架构图:架构总览、核心模块解析、消息与任务处理、数据存储与缓存、监控与日志
  • Compose 的集成与导航
  • 【深度学习】数据操作入门
  • web-app uniapp监测屏幕大小的变化对数组一行展示数据作相应处理
  • vue3+ts的<img :src=““ >写法
  • Unity搭配VS Code使用
  • 基于“大型园区”网络设计
  • LeetCode 3270.求出数字答案:每位分别计算 或 for循环
  • 重回C语言之老兵重装上阵(三)C语言储存类
  • 【Uniapp-Vue3】@import导入css样式及scss变量用法与static目录
  • 数据结构:栈(Stack)和队列(Queue)—面试题(一)
  • 2、第一个GO 程序
  • Win32汇编学习笔记09.SEH和反调试
  • 数据结构(Java版)第七期:LinkedList与链表(二)
  • 3 生成器(Builder)模式
  • Java算法 数据结构 栈 队列 优先队列 比较器
  • C#中前台线程与后台线程的区别及设置方法
  • 《自动驾驶与机器人中的SLAM技术》ch8:基于 IESKF 的紧耦合 LIO 系统
  • 灌区闸门自动化控制系统-精准渠道量测水-灌区现代化建设
  • 相加交互效应函数发布—适用于逻辑回归、cox回归、glmm模型、gee模型
  • RabbitMQ 在 Spring Boot 项目中的深度应用与实战解析
  • Java异步任务