当前位置: 首页 > article >正文

YOLOv8改进 - 注意力篇 - 引入CBAM注意力机制

一、本文介绍

作为入门性第一篇,这里介绍了CBAM注意力在YOLOv8中的使用。包含CBAM原理分析,CBAM的代码、CBAM的使用方法、以及添加以后的yaml文件及运行记录。

二、CBAM原理分析

CBAM官方论文地址:CBAM论文

CBAM的pytorch版代码:CBAM的pytorch版代码

CBAM:卷积块注意力模块,由通道注意力和空间注意力组成。其中通道注意力机制主要检测目标的内容信息,空间注意力主要检测目标位置信息。模块先应用通道注意力,再利用空间注意力;其原理结构如下图所示。

相关代码:

在YOLOv8中,作者已经集成了cbam注意力的代码,仅未应用。

class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""

    def __init__(self, channels: int) -> None:
        """Initializes the class and sets the basic configurations and instance variables required."""
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""
        return x * self.act(self.fc(self.pool(x)))


class SpatialAttention(nn.Module):
    """Spatial-attention module."""

    def __init__(self, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Apply channel and spatial attention on input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))


class CBAM(nn.Module):
    """Convolutional Block Attention Module."""

    def __init__(self, c1, kernel_size=7):
        """Initialize CBAM with given input channel (c1) and kernel size."""
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Applies the forward pass through C1 module."""
        return self.spatial_attention(self.channel_attention(x))

四、YOLOv8中CBAM使用方法

YOLOv8中CBAM模块,作者存于ultralytics/nn/modules/conv.py中。

我们使用CBAM模块,仅需在ultralytics/nn/tasks.py进行CBAM注意力机制的注册,以及在YOLOv8的yaml配置文件中添加CBAM即可。

首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面添加以下注册代码:

        elif m in {CBAM,MHSA,SEAttention,ECA,ShuffleAttention,ECA_SA,SE_SA,SA_ECA,SA_ShuffleAttention,CBAM_base,PECA_SA,CPAM, CPAM_SA, MSCA_SA, SKAttention, DoubleAttention,PCPAM_SA,SA_CPAM,CoordAtt}:#自己加的注意力模块
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]

然后,就是新建一个名为YOLOv8_CBAM.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_CBAM.yaml)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call CPAM-yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, CBAM, [1024,7]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

其中参数中nc,由自己的数据集决定。本文测试,采用的coco8数据集,有80个类别。

在根目录新建一个train.py文件,内容如下

from ultralytics import YOLO


# 加载一个模型
model = YOLO('ultralytics/cfg/models/v8/YOLOv8_CBAM.yaml')  # 从YAML建立一个新模型
# 训练模型
results = model.train(data='ultralytics/cfg/datasets/coco8.yaml', epochs=1,imgsz=640,optimizer="SGD")

训练输出:

五、总结

以上就是CBAM的原理及使用方式,但具体CBAM注意力机制的具体位置放哪里,效果更好。需要根据不同的数据集做相应的实验验证。希望本文能够帮助你入门YOLO中注意力机制的使用。


http://www.kler.cn/news/308735.html

相关文章:

  • TCP.IP四层模型
  • Redis命令:redis-cli
  • 【乐企】基础请求封装
  • 【基于C++的产品入库管理系统】
  • Java项目实战II基于Java+Spring Boot+MySQL的图书管理系统的设计与实现 (源码+数据库+文档)
  • 关于yolov5遇到空标签导致训练暂停的解决
  • C++基于select和epoll的TCP服务器
  • 计算机毕业设计 毕业季一站式旅游服务定制平台的设计与实现 Java实战项目 附源码+文档+视频讲解
  • sshj使用代理连接服务器
  • as 类型断言
  • 动手学深度学习(四)卷积神经网络-下
  • 飞书项目管理使用攻略
  • MySQL基于GTID同步模式搭建主从复制
  • Spring Boot-API版本控制问题
  • 【Linux修行路】信号的产生
  • AI与自然语言处理(NLP):中秋诗词生成
  • ffmpeg硬件解码一般流程
  • 关于RabbitMQ重复消费的解决方案
  • 大数据新视界 --大数据大厂之数据挖掘入门:用 R 语言开启数据宝藏的探索之旅
  • 图数据库的力量:深入理解与应用 Neo4j
  • Vue2知识点
  • makefile 的语法(7):函数 word wordlist words firstword lastword ;
  • SurrealDB:现代应用的端到端云原生数据库解决方案
  • Golang | Leetcode Golang题解之第401题二进制手表
  • 【图像拼接】基于SIFT/SURF特征算法的图像拼接,matlab实现
  • 【重学 MySQL】三十三、流程控制函数
  • 探索未来游戏边界:AI驱动的开放世界RPG引擎与UGC平台
  • 【每日一题】LeetCode 2332.坐上公交的最晚时间(数组、双指针、二分查找、排序)
  • 大数据新视界 --大数据大厂之Kafka消息队列实战:实现高吞吐量数据传输
  • Wophp靶场漏洞挖掘