当前位置: 首页 > article >正文

【即插即用涨点模块】CAA上下文锚点注意力机制:有效捕捉全局信息,助力高效涨点【附源码+注释】

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】
3.【手势识别系统开发】4.【人脸面部活体检测系统开发】
5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】
7.【YOLOv8多目标识别与自动标注软件开发】8.【基于深度学习的行人跌倒检测系统】
9.【基于深度学习的PCB板缺陷检测系统】10.【基于深度学习的生活垃圾分类目标检测系统】
11.【基于深度学习的安全帽目标检测系统】12.【基于深度学习的120种犬类检测与识别系统】
13.【基于深度学习的路面坑洞检测系统】14.【基于深度学习的火焰烟雾检测系统】
15.【基于深度学习的钢材表面缺陷检测系统】16.【基于深度学习的舰船目标分类检测系统】
17.【基于深度学习的西红柿成熟度检测系统】18.【基于深度学习的血细胞检测与计数系统】
19.【基于深度学习的吸烟/抽烟行为检测系统】20.【基于深度学习的水稻害虫检测与识别系统】
21.【基于深度学习的高精度车辆行人检测与计数系统】22.【基于深度学习的路面标志线检测与识别系统】
23.【基于深度学习的智能小麦害虫检测识别系统】24.【基于深度学习的智能玉米害虫检测识别系统】
25.【基于深度学习的200种鸟类智能检测与识别系统】26.【基于深度学习的45种交通标志智能检测与识别系统】
27.【基于深度学习的人脸面部表情识别系统】28.【基于深度学习的苹果叶片病害智能诊断系统】
29.【基于深度学习的智能肺炎诊断系统】30.【基于深度学习的葡萄簇目标检测系统】
31.【基于深度学习的100种中草药智能识别系统】32.【基于深度学习的102种花卉智能识别系统】
33.【基于深度学习的100种蝴蝶智能识别系统】34.【基于深度学习的水稻叶片病害智能诊断系统】
35.【基于与ByteTrack的车辆行人多目标检测与追踪系统】36.【基于深度学习的智能草莓病害检测与分割系统】
37.【基于深度学习的复杂场景下船舶目标检测系统】38.【基于深度学习的农作物幼苗与杂草检测系统】
39.【基于深度学习的智能道路裂缝检测与分析系统】40.【基于深度学习的葡萄病害智能诊断与防治系统】
41.【基于深度学习的遥感地理空间物体检测系统】42.【基于深度学习的无人机视角地面物体检测系统】
43.【基于深度学习的木薯病害智能诊断与防治系统】44.【基于深度学习的野外火焰烟雾检测系统】
45.【基于深度学习的脑肿瘤智能检测系统】46.【基于深度学习的玉米叶片病害智能诊断与防治系统】
47.【基于深度学习的橙子病害智能诊断与防治系统】48.【基于深度学习的车辆检测追踪与流量计数系统】
49.【基于深度学习的行人检测追踪与双向流量计数系统】50.【基于深度学习的反光衣检测与预警系统】
51.【基于深度学习的危险区域人员闯入检测与报警系统】52.【基于深度学习的高密度人脸智能检测与统计系统】
53.【基于深度学习的CT扫描图像肾结石智能检测系统】54.【基于深度学习的水果智能检测系统】
55.【基于深度学习的水果质量好坏智能检测系统】56.【基于深度学习的蔬菜目标检测与识别系统】
57.【基于深度学习的非机动车驾驶员头盔检测系统】58.【太基于深度学习的阳能电池板检测与分析系统】
59.【基于深度学习的工业螺栓螺母检测】60.【基于深度学习的金属焊缝缺陷检测系统】
61.【基于深度学习的链条缺陷检测与识别系统】62.【基于深度学习的交通信号灯检测识别】
63.【基于深度学习的草莓成熟度检测与识别系统】64.【基于深度学习的水下海生物检测识别系统】
65.【基于深度学习的道路交通事故检测识别系统】66.【基于深度学习的安检X光危险品检测与识别系统】
67.【基于深度学习的农作物类别检测与识别系统】68.【基于深度学习的危险驾驶行为检测识别系统】
69.【基于深度学习的维修工具检测识别系统】70.【基于深度学习的维修工具检测识别系统】
71.【基于深度学习的建筑墙面损伤检测系统】72.【基于深度学习的煤矿传送带异物检测系统】
73.【基于深度学习的老鼠智能检测系统】74.【基于深度学习的水面垃圾智能检测识别系统】
75.【基于深度学习的遥感视角船只智能检测系统】76.【基于深度学习的胃肠道息肉智能检测分割与诊断系统】
77.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统】78.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统】
79.【基于深度学习的果园苹果检测与计数系统】80.【基于深度学习的半导体芯片缺陷检测系统】
81.【基于深度学习的糖尿病视网膜病变检测与诊断系统】82.【基于深度学习的运动鞋品牌检测与识别系统】
83.【基于深度学习的苹果叶片病害检测识别系统】84.【基于深度学习的医学X光骨折检测与语音提示系统】
85.【基于深度学习的遥感视角农田检测与分割系统】86.【基于深度学习的运动品牌LOGO检测与识别系统】
87.【基于深度学习的电瓶车进电梯检测与语音提示系统】88.【基于深度学习的遥感视角地面房屋建筑检测分割与分析系统】
89.【基于深度学习的医学CT图像肺结节智能检测与语音提示系统】

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

目录

  • 摘要
  • 创新点
  • 方法总结
  • CAA模块的作用
  • CAA源码与注释

在这里插入图片描述

论文地址:https://arxiv.org/pdf/2403.06258
代码地址:https://github.com/NUST-Machine-Intelligence-Laboratory/PKINet

摘要

本文提出了一种名为**Poly Kernel Inception Network (PKINet)的轻量级特征提取网络,用于解决遥感图像中目标检测面临的挑战,特别是目标尺度变化大和上下文信息多样的问题。PKINet通过并行使用不同尺度的卷积核来提取多尺度的目标特征,并结合Context Anchor Attention (CAA)**模块来捕获长距离的上下文信息。实验表明,PKINet在四个遥感目标检测基准数据集(DOTA-v1.0、DOTA-v1.5、HRSC2016和DIOR-R)上均取得了优异的性能。

创新点

  1. 多尺度卷积核设计:PKINet采用了Inception风格的卷积模块,并行使用不同尺度的卷积核(无空洞卷积)来提取多尺度的纹理特征,避免了传统大卷积核引入的背景噪声和空洞卷积导致的特征稀疏问题。
    在这里插入图片描述

  2. Context Anchor Attention (CAA)模块:引入CAA模块,通过全局平均池化和1D条状卷积来捕获长距离的上下文信息,增强中心区域的特征表示。

  3. 轻量级设计:通过使用深度可分离卷积和1D卷积,PKINet在保持高性能的同时,显著减少了模型的计算量和参数量。

方法总结

PKINet由四个阶段组成,每个阶段采用Cross-Stage Partial (CSP)**结构,输入特征被分为两部分,分别通过一个简单的**Feed-Forward Network (FFN)**和一系列**PKI Block进行处理。每个PKI Block包含一个PKI Module和一个CAA Module

  1. PKI Module:通过并行的小卷积核和深度可分离卷积提取多尺度的局部特征,并通过1x1卷积进行通道融合,生成具有丰富上下文信息的特征。
  2. CAA Module:通过全局平均池化和1D条状卷积捕获长距离的上下文信息,生成注意力权重,用于增强PKI Module的输出特征。

CAA模块的作用

CAA模块的主要作用是捕获长距离的上下文信息,增强中心区域的特征表示。具体来说:

  1. 全局平均池化:首先对输入特征进行全局平均池化,提取全局上下文信息。
  2. 1D条状卷积:通过水平和垂直方向的1D卷积,近似模拟大卷积核的效果,捕获长距离的像素关系。
  3. 注意力机制:通过Sigmoid函数生成注意力权重,用于增强PKI Module的输出特征,从而提升模型对长距离上下文信息的感知能力。
    在这里插入图片描述

通过PKI Module和CAA Module的协同工作,PKINet能够有效地提取具有局部和全局上下文信息的自适应特征,显著提升了遥感目标检测的性能。

CAA源码与注释

# 论文:Poly Kernel Inception Network for Remote Sensing Detection(CVPR 2024)
# 论文地址:https://arxiv.org/pdf/2403.06258
# 代码地址:https://github.com/NUST-Machine-Intelligence-Laboratory/PKINet
# Context Anchor Attention (CAA) module

from typing import Optional
import torch.nn as nn
import torch

class ConvModule(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            groups: int = 1,
            norm_cfg: Optional[dict] = None,
            act_cfg: Optional[dict] = None):
        super().__init__()
        layers = []
        # 添加卷积层
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=(norm_cfg is None)))
        # 添加归一化层(如果配置了)
        if norm_cfg:
            norm_layer = self._get_norm_layer(out_channels, norm_cfg)
            layers.append(norm_layer)
        # 添加激活层(如果配置了)
        if act_cfg:
            act_layer = self._get_act_layer(act_cfg)
            layers.append(act_layer)
        # 将所有层组合成一个顺序容器
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        # 前向传播
        return self.block(x)

    def _get_norm_layer(self, num_features, norm_cfg):
        # 根据配置获取归一化层
        if norm_cfg['type'] == 'BN':
            return nn.BatchNorm2d(num_features, momentum=norm_cfg.get('momentum', 0.1), eps=norm_cfg.get('eps', 1e-5))
        # 如果需要,可以添加更多归一化类型
        raise NotImplementedError(f"Normalization layer '{norm_cfg['type']}' is not implemented.")

    def _get_act_layer(self, act_cfg):
        # 根据配置获取激活层
        if act_cfg['type'] == 'ReLU':
            return nn.ReLU(inplace=True)
        if act_cfg['type'] == 'SiLU':
            return nn.SiLU(inplace=True)
        # 如果需要,可以添加更多激活类型
        raise NotImplementedError(f"Activation layer '{act_cfg['type']}' is not implemented.")

class CAA(nn.Module):
    """Context Anchor Attention"""
    def __init__(
            self,
            channels: int,
            h_kernel_size: int = 11,
            v_kernel_size: int = 11,
            norm_cfg: Optional[dict] = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: Optional[dict] = dict(type='SiLU')):
        super().__init__()
        # 平均池化层,用于提取上下文信息
        self.avg_pool = nn.AvgPool2d(7, 1, 3)
        # 第一个卷积模块
        self.conv1 = ConvModule(channels, channels, 1, 1, 0, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # 水平方向的深度可分离卷积模块
        self.h_conv = ConvModule(channels, channels, (1, h_kernel_size), 1, (0, h_kernel_size // 2), groups=channels, norm_cfg=None, act_cfg=None)
        # 垂直方向的深度可分离卷积模块
        self.v_conv = ConvModule(channels, channels, (v_kernel_size, 1), 1, (v_kernel_size // 2, 0), groups=channels, norm_cfg=None, act_cfg=None)
        # 第二个卷积模块
        self.conv2 = ConvModule(channels, channels, 1, 1, 0, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # Sigmoid激活函数,用于生成注意力权重
        self.act = nn.Sigmoid()

    def forward(self, x):
        # 前向传播过程
        # 1. 通过平均池化层提取上下文信息
        # 2. 通过conv1卷积模块
        # 3. 通过h_conv进行水平方向的深度可分离卷积
        # 4. 通过v_conv进行垂直方向的深度可分离卷积
        # 5. 通过conv2卷积模块
        # 6. 通过Sigmoid激活函数生成注意力权重
        attn_factor = self.act(self.conv2(self.v_conv(self.h_conv(self.conv1(self.avg_pool(x))))))
        return attn_factor

# 示例用法,打印输入和输出的形状
if __name__ == "__main__":
    input = torch.randn(1, 64, 128, 128) # 输入 B C H W
    block = CAA(64)
    output = block(input)
    print(input.size())
    print(output.size())


在这里插入图片描述

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!


http://www.kler.cn/a/585196.html

相关文章:

  • 21天 - 说说 TCP 的四次挥手?TCP 的粘包和拆包能说说吗?说说 TCP 拥塞控制的步骤?
  • 谷歌 Gemini 2.0 Flash实测:1条指令自动出图+配故事!
  • el-table 插槽踩过的坑 :slot-scope 和#default的区别
  • 代码随想录-回溯
  • 如何优雅地将Collection转为Map?
  • 平安养老险广西分公司2025年“3∙15”金融消费者权益教育宣传活动暨南湖公园健步行活动
  • 【C语言】编译和链接详解
  • Redis的缓存雪崩、缓存击穿、缓存穿透与缓存预热、缓存降级
  • 2025-03-15 学习记录--C/C++-PTA 练习3-4 统计字符
  • 【3D视觉学习笔记2】摄像机的标定、畸变的建模、2D/3D变换
  • python如何获取三个小时之前的时间并输出
  • MATLAB 控制系统设计与仿真 - 26
  • python画图文字显示不全+VScode新建jupyter文件
  • 构建分类树(ElementPlus的二级数据模型)
  • [S32K]SPI
  • Python 语言因其广泛的库与框架资源,诸如 `requests`、`BeautifulSoup
  • 证券交易系统的流程
  • pytorch lightning ddp 逆天分配显存方式
  • 关于重构分析查询界面的思考(未完)
  • 基于Hadoop的城市道路交通数据的可视化分析-Flask