当前位置: 首页 > article >正文

YOLOv8改进 - 注意力篇 - 引入SCAM注意力机制

一、本文介绍

作为入门性篇章,这里介绍了SCAM注意力在YOLOv8中的使用。包含SCAM原理分析,SCAM的代码、SCAM的使用方法、以及添加以后的yaml文件及运行记录。

二、SCAM原理分析

SCAM官方论文地址:SCAM文章

SCAM官方代码地址:SCAM代码

SCAM注意力机制(空间上下文感知模块):

空间上下文感知模块(SCAM)在FEM和FFM之后,特征映射已经考虑了局部上下文信息,并且能够很好地表示小对象特征。在此阶段对小目标和背景之间的全局关系进行建模比在主干阶段更有效。利用全局上下文信息来表示像素之间的跨空间关系,可以抑制无用背景,增强目标和背景的区分能力。受GCNet和SCP的启发,SCAM由三个分支组成。第一个部分使用GAP和GMP整合全球信息。第二个分支使用1 × 1卷积生成特征映射的线性变换结果,该特征映射在图4中称为value。第三个分支使用1 × 1卷积来简化查询和键的倍数。这个卷积在图4中称为QK。随后,将第一分支和第三分支分别与第二分支矩阵相乘。得到的两个分支分别表示跨通道和空间的上下文信息。最后,利用广播Hadamard积在这两个分支上得到了SCAM的输出。

相关代码:

SCAM注意力的代码,如下。

class SCAM(nn.Module):
    def __init__(self, in_channels, reduction=1):
        super(SCAM, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = in_channels

        self.k = Conv(in_channels, 1, 1, 1)
        self.v = Conv(in_channels, self.inter_channels, 1, 1)
        self.m = Conv_withoutBN(self.inter_channels, in_channels, 1, 1)
        self.m2 = Conv(2, 1, 1, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # GAP
        self.max_pool = nn.AdaptiveMaxPool2d(1)  # GMP

    def forward(self, x):
        n, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3)

        # avg max: [N, C, 1, 1]
        avg = self.avg_pool(x).softmax(1).view(n, 1, 1, c)
        max = self.max_pool(x).softmax(1).view(n, 1, 1, c)

        # k: [N, 1, HW, 1]
        k = self.k(x).view(n, 1, -1, 1).softmax(2)

        # v: [N, 1, C, HW]
        v = self.v(x).view(n, 1, c, -1)

        # y: [N, C, 1, 1]
        y = torch.matmul(v, k).view(n, c, 1, 1)

        # y2:[N, 1, H, W]
        y_avg = torch.matmul(avg, v).view(n, 1, h, w)
        y_max = torch.matmul(max, v).view(n, 1, h, w)

        # y_cat:[N, 2, H, W]
        y_cat = torch.cat((y_avg, y_max), 1)

        y = self.m(y) * self.m2(y_cat).sigmoid()

        return x + y

四、YOLOv8中SCAM使用方法

1.YOLOv8中添加SCAM模块:

首先在ultralytics/nn/modules/conv.py最后添加SCAM模块的代码。

2.在conv.py的开头__all__ = 内添加SCAM模块的类别名:

3.在同级文件夹下的__init__.py内添加SCAM的相关内容:(分别是from .conv import SCAM ;以及在__all__内添加SCAM)

4.在ultralytics/nn/tasks.py进行LSKA注意力机制的注册,以及在YOLOv8的yaml配置文件中添加SCAM即可。

首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面添加以下注册代码:

        elif m is SCAM:
            c2 = ch[f]
            args = [c2]

然后,就是新建一个名为YOLOv8_SCAM.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_SCAM.yaml)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call CPAM-yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SCAM, [1024]]#11代表卷积核大小,可以填写7、11、23、35、41、53
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

其中参数中nc,由自己的数据集决定。本文测试,采用的coco8数据集,有80个类别。

在根目录新建一个train.py文件,内容如下

from ultralytics import YOLO

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
# 加载一个模型
    model = YOLO('ultralytics/cfg/models/v8/YOLOv8_SCAM.yaml')  # 从YAML建立一个新模型
# 训练模型
    results = model.train(data='ultralytics/cfg/datasets/coco8.yaml', epochs=1,imgsz=640,optimizer="SGD")

训练输出:​

​​

五、总结

以上就是SCAM的原理及使用方式,但具体SCAM注意力机制的具体位置放哪里,效果更好。需要根据不同的数据集做相应的实验验证。希望本文能够帮助你入门YOLO中注意力机制的使用。


http://www.kler.cn/news/327222.html

相关文章:

  • 【2025】基于Spring Boot的智慧农业小程序(源码+文档+调试+答疑)
  • plt绘画三维曲面
  • Android OTA升级
  • excel快速入门(二)
  • Redis缓存技术 基础第二篇(Redis的Java客户端)
  • Ingress Gateway 它负责处理进入集群的 HTTP 和 TCP 流量
  • 七星创客:重塑商业模式认知
  • 在 Linux 中,要让某一个线程或进程排他性地独占一个 CPU
  • AI芯片WT2605C赋能厨房家电,在线对话操控,引领智能烹饪新体验:尽享高效便捷生活
  • Linux:文件描述符介绍
  • 【SpringBoot详细教程】-08-MybatisPlus详细教程以及SpringBoot整合Mybatis-plus【持续更新】
  • 端点安全服务:全面的端点安全解决方案
  • 初识CyberBattleSim
  • sql语法学习 sql各种语法 sql增删改查 数据库各种操作 数据库指令
  • 自动化测试中如何精确模拟富文本编辑器中的输入与提交?
  • Pytorch-LSTM轴承故障一维信号分类(一)
  • 如何在 Amazon EMR 中运行 Flink CDC Pipeline Connector
  • 【笔记】如何将本地的.md变成不影响阅读的类pdf模式
  • COMP 6714-Info Retrieval and Web Search笔记week2
  • 解决 Android WebView 无法加载 H5 页面常见问题的实用指南
  • Another redis desktop manager使用说明
  • 在IntelliJ IDEA中设置文件自动定位
  • 劳易测ODT3CL1-2M漫反射传感器荣获 “2024 MM《现代制造》创新产品奖”
  • AWS Network Firewall - IGW方式配置只应许白名单域名出入站
  • SQL进阶技巧:影院2人相邻的座位如何预定?
  • QT将QBytearray的data()指针赋值给结构体指针变量后数据不正确的问题
  • Brave编译指南2024 MacOS篇-构建与运行(六)
  • 正则表达式的使用规则
  • Linux —— Socket编程(三)
  • 深入理解 C++11 Lambda 表达式及其捕获列表