
A Practical Guide to Stitching Modules into YOLO: Implementing and Integrating the ECA Attention Module

1. Introduction

In the field of object detection, the YOLO (You Only Look Once) family of models is widely used for its strong accuracy and real-time performance. This article walks through, step by step, how to stitch the ECA (Efficient Channel Attention) module into a YOLO model to improve detection results.

2. The ECA Module in Brief

ECA (Efficient Channel Attention) is a lightweight channel attention mechanism that improves model performance by adaptively capturing dependencies between channels. Compared with SENet, the ECA module uses far fewer parameters and is more efficient.
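
The implementation used later in this article fixes the 1D convolution kernel size at k_size=3 for simplicity. The ECA paper instead derives k adaptively from the channel count C as k = |log2(C)/γ + b/γ|, rounded up to the nearest odd number, with γ=2 and b=1. The small Python sketch below shows that rule (the helper name adaptive_kernel_size is ours, not part of any library), in case you want wider layers to get a larger cross-channel receptive field:

import math

def adaptive_kernel_size(channels: int, gamma: int = 2, b: int = 1) -> int:
    """ECA's adaptive rule: k = |log2(C)/gamma + b/gamma|, forced to be odd."""
    t = int(abs((math.log2(channels) + b) / gamma))
    return t if t % 2 else t + 1  # 1D conv kernels in ECA are always odd-sized

print(adaptive_kernel_size(64))   # -> 3
print(adaptive_kernel_size(512))  # -> 5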

3. Implementation Steps

3.1 Create the ECA Module

First, create an ECA.py file under the ultralytics/nn/modules/ directory:

import torch
import torch.nn as nn

class ECA(nn.Module):
    """Efficient Channel Attention Module.
    
    Args:
        c1 (int): Input channels
        c2 (int): Output channels (unused; kept so the signature matches the c1, c2 pair that parse_model passes)
        k_size (int, optional): Kernel size for the 1D convolution. Defaults to 3.
    """
    def __init__(self, c1, c2, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, x):
        # Global average pooling over the spatial dims: (B, C, H, W) -> (B, C, 1, 1)
        y = self.avg_pool(x)
        # Treat channels as a 1D sequence and apply a local cross-channel convolution
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Channel attention weights in (0, 1)
        y = self.sigmoid(y)
        # Re-weight the input feature map channel by channel
        return x * y.expand_as(x)
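
As a quick sanity check of the new file (our own snippet, assuming ultralytics is installed from source so ultralytics/nn/modules/ECA.py is importable), run the module on a dummy tensor. Because ECA only re-weights channels, the output shape equals the input shape, which is exactly what allows it to be dropped between existing layers without touching their channel configuration:

import torch

from ultralytics.nn.modules.ECA import ECA

x = torch.randn(2, 96, 40, 40)  # (batch, channels, height, width)
eca = ECA(96, 96)
print(eca(x).shape)             # torch.Size([2, 96, 40, 40]) -- unchanged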

3.2 Update the Module Imports

In ultralytics/nn/modules/__init__.py, add an import for the ECA module:

from .transformer import (
    AIFI,
    MLP,
    DeformableTransformerDecoder,
    DeformableTransformerDecoderLayer,
    LayerNorm2d,
    MLPBlock,
    MSDeformAttn,
    TransformerBlock,
    TransformerEncoderLayer,
    TransformerLayer,
)
from .ECA import ECA as ECAModule
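
Depending on your ultralytics version, __init__.py also maintains an __all__ tuple listing the names the package exports. If yours does, register the new name there as well; only the new entry is shown below, the existing entries stay unchanged:

__all__ = (
    # ... existing entries unchanged ...
    "ECAModule",
)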

3.3 Modify the Task File

Make the following changes in ultralytics/nn/tasks.py:

  1. Add the import for the ECA module:
from ultralytics.nn.modules import (
    DPF,
    CGAFusion,
    DetailAttentionModule,
    ECAModule,  # newly added ECA module
    AIFI,
    C1,
    C2,
    C2PSA,
)
  2. In the parse_model function, add support for the ECA module (a short example of how the arguments are assembled appears after this list):
if m in {
    Classify,
    Conv,
    ConvTranspose,
    GhostConv,
    Bottleneck,
    GhostBottleneck,
    SPP,
    SPPF,
    C2fPSA,
    C2PSA,
    DWConv,
    Focus,
    BottleneckCSP,
    C1,
    C2,
    C2f,
    C3k2,
    RepNCSPELAN4,
    ELAN1,
    ADown,
    AConv,
    SPPELAN,
    C2fAttn,
    C3,
    C3TR,
    C3Ghost,
    nn.ConvTranspose2d,
    DWConvTranspose2d,
    C3x,
    RepC3,
    PSA,
    SCDown,
    C2fCIB,
    ECAModule,  # newly added: ECA module support
    DetailAttentionModule,
}:
    c1, c2 = ch[f], args[0]
    if c2 != no:  # if not output
        c2 = make_divisible(c2 * gw, 8)

    args = [c1, c2, *args[1:]]
    if m in {C2fPSA, C2PSA}:
        args.insert(2, n)  # number of repeats
        n = 1
  3. Modify the _predict_once function to add debug output:
def _predict_once(self, x, profile=False, visualize=False, embed=None):
    """
    Perform a forward pass through the network.

    Args:
        x (torch.Tensor): The input tensor to the model.
        profile (bool):  Print the computation time of each layer if True, defaults to False.
        visualize (bool): Save the feature maps of the model if True, defaults to False.
        embed (list, optional): A list of feature vectors/embeddings to return.

    Returns:
        (torch.Tensor): The last output of the model.
    """
    y, dt, embeddings = [], [], []  # outputs
    for m in self.model:
        if m.f != -1:  # if not from previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
        
        # Debug output: layer index, type, and input shape
        print(f"\nCurrent layer: {m.i}, type: {m.type}")
        if isinstance(x, (list, tuple)):
            print(f"Input shapes: {[t.shape for t in x]}")
        else:
            print(f"Input shape: {x.shape}")
        
        if profile:
            self._profile_one_layer(m, x, dt)
        x = m(x)  # run
        
        # Debug output: output shape after running the layer
        if isinstance(x, (list, tuple)):
            print(f"Output shapes: {[t.shape for t in x]}")
        else:
            print(f"Output shape: {x.shape}")
        
        y.append(x if m.i in self.save else None)  # save output
        if visualize:
            feature_visualization(x, m.type, m.i, save_dir=visualize)
        if embed and m.i in embed:
            embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
            if m.i == max(embed):
                return torch.unbind(torch.cat(embeddings, 1), dim=0)
    return x
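
To make the argument handling in step 2 concrete: for the YAML entry - [-1, 1, ECAModule, [96]] used in the next section, ch[f] is the channel count of the layer feeding the module and args[0] is the 96 written in the YAML, so parse_model effectively constructs ECAModule(96, 96). The snippet below (illustrative values only, not part of tasks.py) mimics that:

from ultralytics.nn.modules import ECAModule

ch_f, yaml_args = 96, [96]                 # channels from the previous layer, args column of the YAML entry
c1, c2 = ch_f, yaml_args[0]
layer = ECAModule(c1, c2, *yaml_args[1:])  # equivalent to ECAModule(96, 96)
print(layer)

Also note that the shape prints added in step 3 fire on every forward pass; they are useful while wiring the module in, but you will probably want to remove them or gate them behind a flag before launching a long training run.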

4. Using ECA in YOLO

4.1 Create the Model Configuration File

Create a yolov9t-ECA.yaml file under the ultralytics/cfg/models/v9/ directory:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv9t object detection model. For Usage examples see https://docs.ultralytics.com/models/yolov9
# Baseline YOLOv9t: 917 layers, 2128720 parameters, 8.5 GFLOPs
# YOLOv9t-ECA summary: 925 layers, 2,807,131 parameters, 2,807,115 gradients, 14.3 GFLOPs

# Parameters
nc: 80 # number of classes

# GELAN backbone
backbone:
  - [-1, 1, Conv, [16, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [32, 3, 2]] # 1-P2/4
  - [-1, 1, ELAN1, [32, 32, 16]] # 2
  - [-1, 1, AConv, [64]] # 3-P3/8
  - [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]] # 4
  - [-1, 1, AConv, [96]] # 5-P4/16
  - [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]] # 6
  - [-1, 1, AConv, [128]] # 7-P5/32
  - [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]] # 8
  - [-1, 1, SPPELAN, [128, 64]] # 9

head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]] # 12
  - [-1, 1, ECAModule, [96]] # 13 - ECA module added here
  
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]] # 16
  - [-1, 1, ECAModule, [64]] # 17 - ECA module added here

  - [-1, 1, AConv, [48]]
  - [[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]] # 20 (P4/16-medium)

  - [-1, 1, AConv, [64]]
  - [[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]] # 23 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)
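
Before launching a full training run, it is worth confirming that the YAML parses and that the ECA layers appear at the intended positions. A minimal check (run from the project root, or pass the full path to the YAML) looks like this; building the model from a YAML prints the layer table, where the ECAModule rows should be visible:

from ultralytics import YOLO

model = YOLO('yolov9t-ECA.yaml')  # parse_model prints the per-layer table while building
model.info()                      # parameter and GFLOPs summary of the stitched model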

4.2 Create the Training Script

Create a train.py file:

from ultralytics import YOLO

def main():
    # Load the lightweight model configuration
    model = YOLO('yolov9t-ECA.yaml')

    # Train the model
    model.train(
        # Basic configuration
        data='D:/archive/data.yaml',  # change this to your own dataset path
        epochs=100,
        imgsz=640,
        batch=32,      
        device='0',    
        workers=8,       
        
        # Early stopping and learning rate
        patience=20,                    
        lr0=0.01,      
        
        # Data augmentation parameters
        mosaic=1.0,                    
        mixup=0.5,                     
        hsv_h=0.015,                   
        hsv_s=0.7,                     
        hsv_v=0.4,                     
        degrees=45,                    
        translate=0.1,                 
        scale=0.5,                     
        
        # Output configuration
        project='runs/detect_test',
        name='fruit_exp_test',
        
        # Checkpoint saving and visualization
        save_period=10,               
        exist_ok=True,
        plots=True,             
    )

if __name__ == '__main__':
    main()
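
Once training finishes, the best checkpoint can be loaded back for validation or inference in the usual ultralytics way. The weight path below follows the project/name values set above; the sample image path is purely illustrative:

from ultralytics import YOLO

# best.pt is written by the trainer under <project>/<name>/weights/
model = YOLO('runs/detect_test/fruit_exp_test/weights/best.pt')
metrics = model.val()                          # re-evaluate on the validation split
results = model.predict('path/to/image.jpg')   # run inference on a sample image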

The debug output printed by _predict_once confirms the input and output tensor shapes at every layer (screenshot omitted).

Training runs successfully with the stitched model (screenshot omitted).

References

  1. Wang, Q., Wu, B., Zhu, P., Li, P., Zuo, W., Hu, Q. ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks. CVPR 2020.
  2. Official YOLOv8 documentation
  3. Ultralytics YOLO documentation

