当前位置: 首页 > article >正文

YOLOv8改进 - 注意力篇 - 引入SK网络注意力机制

一、本文介绍

作为入门性篇章,这里介绍了SK网络注意力在YOLOv8中的使用。包含SK原理分析,SK的代码、SK的使用方法、以及添加以后的yaml文件及运行记录。

二、SK原理分析

SK官方论文地址:SK注意力文章

SK注意力机制:SK网络中的神经元可以捕获具有不同比例的目标对象,实验验证了神经元根据输入自适应地调整其感受野大小的能力。其SK模块的原理结构如下图所示。该方法由三个部分组成:Split,Fuse,Select。

Split就是一个multi-branch的操作,用不同的卷积核进行卷积得到不同的特征;

Fuse部分就是用SE的结构获取通道注意力的矩阵(N个卷积核就可以得到N个注意力矩阵,这步操作对所有的特征参数共享),这样就可以得到不同kernel经过SE之后的特征;

Select操作就是将这几个特征进行相加。

相关代码:

SK注意力的代码,如下。

from torch.nn import init
from collections import OrderedDict

class SKAttention(nn.Module):

    def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
        super().__init__()
        self.d = max(L, channel // reduction)
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),
                    ('bn', nn.BatchNorm2d(channel)),
                    ('relu', nn.ReLU())
                ]))
            )
        self.fc = nn.Linear(channel, self.d)
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d, channel))
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs = []
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats = torch.stack(conv_outs, 0)  # k,bs,channel,h,w

        ### fuse
        U = sum(conv_outs)  # bs,c,h,w

        ### reduction channel
        S = U.mean(-1).mean(-1)  # bs,c
        Z = self.fc(S)  # bs,d

        ### calculate attention weight
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights.append(weight.view(bs, c, 1, 1))  # bs,channel
        attention_weughts = torch.stack(weights, 0)  # k,bs,channel,1,1
        attention_weughts = self.softmax(attention_weughts)  # k,bs,channel,1,1

        ### fuse
        V = (attention_weughts * feats).sum(0)
        return V

四、YOLOv8中SK使用方法

1.YOLOv8中添加SK模块,首先在ultralytics/nn/modules/conv.py最后添加SK模块的代码。

2.在conv.py的开头__all__ = 内添加SK模块的类别名(SK的类别名在本文中为SKAttention)

3.在同级文件夹下的__init__.py内添加SKAttention的相关内容:(分别是from .conv import SKAttention ;以及在__all__内添加SKAttention)

4.在ultralytics/nn/tasks.py进行SK注意力机制的注册,以及在YOLOv8的yaml配置文件中添加SK即可。

首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面添加以下注册代码:(本文续接上篇文章,加在了CBAM、ECA的位置)

        elif m in {CBAM,ECA,SKAttention}:#添加注意力模块,没有CBAM、ECA的,将CBAM、ECA删除即可
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]

然后,就是新建一个名为YOLOv8_SK.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_SK.yaml)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call CPAM-yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SKAttention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

其中参数中nc,由自己的数据集决定。本文测试,采用的coco8数据集,有80个类别。

在根目录新建一个train.py文件,内容如下

from ultralytics import YOLO


# 加载一个模型
model = YOLO('ultralytics/cfg/models/v8/YOLOv8_SK.yaml')  # 从YAML建立一个新模型
# 训练模型
results = model.train(data='ultralytics/cfg/datasets/coco8.yaml', epochs=1,imgsz=640,optimizer="SGD")

训练输出:

五、总结

以上就是SK的原理及使用方式,但具体SK注意力机制的具体位置放哪里,效果更好。需要根据不同的数据集做相应的实验验证。希望本文能够帮助你入门YOLO中注意力机制的使用。


http://www.kler.cn/a/321532.html

相关文章:

  • 对javascript语言标准函数与箭头函数中this的理解(补充)
  • 解释下什么是面向对象?面向对象和面向过程的区别?
  • java开发入门学习五-流程控制
  • 【机器学习】探索机器学习与人工智能:驱动未来创新的关键技术
  • NACA四位数字翼型
  • Debian11 安装MYSQL8 签名错误
  • 计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-09-26
  • 了解网络的相关信息
  • 从0开始linux(5)——vim
  • 微信小程序-canvas
  • go语言网络编程
  • 【Linux 从基础到进阶】Kafka消息队列配置与管理
  • C/C++中的内存管理
  • c语言200例 063 信息查询
  • 数据结构 ——— 移除元素(快慢指针)
  • io流(学习笔记03)字符集
  • 大数据时代的PDF解析:技术与挑战
  • Python:百度贴吧实现自动化签到
  • Spring是什么
  • 有源蜂鸣器(5V STM32)
  • 无人机之虚拟云台技术篇
  • LeetCode 137. 只出现一次的数字 II
  • Linux安装vim超详细教程
  • MySQL重点,面试题
  • 深入Android UI开发:从自定义View到高级布局技巧的全面学习资料
  • RestSharp简介