当前位置: 首页 > article >正文

25/1/12 算法笔记 剖析Yolov8底层逻辑

YOLOv8 是一种基于深度学习的目标检测和图像分割模型,属于 YOLO(You Only Look Once)系列的最新版本。YOLO 系列模型以其高效的实时目标检测能力而闻名,YOLOv8 在此基础上进行了一些优化和改进。

Yolov8的主要特点:

1.实时性,在速度和准确性之间得到了良好的平衡,适合实时应用。

2.多任务学习,支持多任务,包括目标检测,示例分割,语义分割

3.改进的网络结构,引入了新的网络结构和层,可能包括更深的卷积层,更高效的特征提取模块等,以提高模型的性能。

4.增强的训练策略,采用一系列新的训练策略,如数据增强,混合精度训练。

5.可拓展性,允许用户根据具体任务的需求进行自定义和扩展。

Yolov8的网络结构:

Backbone:网络的特征提取部分,负责从输入图像中提取高层次的特征。

import torch
import torch.nn as nn

class CSPNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CSPNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels // 2, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels, out_channels // 2, kernel_size=1)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        return self.conv3(torch.cat((x1, x2), dim=1))

通过两个1*1卷积将输入特征分成两块,然后将它们拼接在一起,最后通过3*3卷积处理。

Backbone会根据任务的复杂度和目标的多样性来调整层数。例如:

class YOLOv8Backbone(nn.Module):
    def __init__(self):
        super(YOLOv8Backbone, self).__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # 继续增加通道数...
        )

    def forward(self, x):
        x = self.backbone(x)
        return x

通道数的逐步增加确保了模型在不同层次上捕捉到丰富的特征,同时避免了过高的计算成本。这样的设计在实际应用中能够有效提升模型的性能和效率。

Neck:用于连接Backbone和Head,通常负责特征融合和多尺度特征的生成。

FPN 类实现了一个简单的特征金字塔网络。它通过 1x1 卷积生成横向连接的特征,并通过上采样将特征提升到更高的分辨率。

class FPN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FPN, self).__init__()
        self.lateral_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        lateral = self.lateral_conv(x)
        upsampled = self.upsample(lateral)
        return upsampled

Upsample:是一种将低分辨率数据转换为高分辨率数据的操作。它在信号处理、图像处理、计算机视觉和深度学习等领域中广泛应用。上采样的目的是增加数据的尺寸或分辨率,同时尽可能地保留原始数据的特征。

Head:是模型的输出部分,负责生成最终的检测预测,包括边界框的位置,类别概率和分割掩码。YOLOV8在此部分可能会采用新的损失函数和预测方式。

class YOLOHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(YOLOHead, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels, num_classes + 5, kernel_size=1)  # 5 for bbox (x, y, w, h, conf)

    def forward(self, x):
        x = self.conv1(x)
        return self.conv2(x)

这个示例中,类实现了Yolov的输出层,首先通过一个3*3卷积提取特征,然后通过一个1*1卷积生成边界框和类别概率的预测。

整体的网络结构

class YOLOv8(nn.Module):
    def __init__(self, num_classes):
        super(YOLOv8, self).__init__()
        self.backbone = CSPNet(in_channels=3, out_channels=64)
        self.neck = FPN(in_channels=64, out_channels=128)
        self.head = YOLOHead(in_channels=128, num_classes=num_classes)

    def forward(self, x):
        x = self.backbone(x)
        x = self.neck(x)
        return self.head(x)

损失函数

Yolov8使用多重损失函数来优化模型,包括定位损失,置信度损失和类别损失。

class YOLOLoss(nn.Module):
    def __init__(self):
        super(YOLOLoss, self).__init__()

    def forward(self, predictions, targets):
        # 计算定位损失、置信度损失和类别损失
        loc_loss = self.compute_location_loss(predictions, targets)
        conf_loss = self.compute_confidence_loss(predictions, targets)
        class_loss = self.compute_class_loss(predictions, targets)
        return loc_loss + conf_loss + class_loss

预测机制

Yolov8通过将图像划分位网络来进行目标检测,每个网络负责预测其中心点落在其内部的目标。每个网络预测以下信息:

        边界框坐标:通常以相对于网络单元的偏移量和比例进行预测

        置信度分数:表示该网络内是否有目标的概率

        类别概率:表示目标属于各个类别的概率分布

数据增强和训练策略

YOLOv8 采用多种数据增强技术,以提高模型的泛化能力。常见的数据增强方法包括:

        随机裁剪,旋转,翻转

import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
])

推理和后处理

在推理阶段,YOLOv8 会对输入图像进行处理,生成预测结果。

包括:

非极大值抑制NM5:用于消除重叠的边界框,只保留置信度最高的框。

def non_max_suppression(predictions, conf_threshold=0.5, iou_threshold=0.4):
    # 过滤低置信度的预测
    predictions = [p for p in predictions if p[4] >= conf_threshold]
    # 按照置信度排序
    predictions.sort(key=lambda x: x[4], reverse=True)
    
    keep_boxes = []
    while predictions:
        best_box = predictions.pop(0)
        keep_boxes.append(best_box)
        predictions = [p for p in predictions if compute_iou(best_box, p) < iou_threshold]
    
    return keep_boxes

阈值过滤:根据设定的阈值过滤低置信度的预测。

def filter_predictions(predictions, conf_threshold):
    return [p for p in predictions if p[4] >= conf_threshold]

总的来看其实Yolov8模型并没有这么复杂,其实是它里面的结构网络设计的非常具有合理性,使得它简单高效。

Backbone:

  • 特征提取效率:选择高效的网络结构,减少计算量,同时保留足够的特征信息。
  • 深度和宽度的平衡:合理的层数和通道数设计,使得模型在提取低级和高级特征时具有良好的表现。
  • 预训练模型:通常使用在大规模数据集(如 ImageNet)上预训练的模型,帮助加速收敛并提高准确性。

Neck:

  • 特征金字塔结构:通过特征金字塔网络(FPN)或其他融合方法,能够有效地结合来自不同层的特征,增强模型对多尺度目标的检测能力。
  • 减少信息损失:在特征融合过程中,合理的设计可以最大限度地保留重要信息,避免特征的丢失。

Head:

  • 多任务学习:通过同时预测多个输出(边界框、置信度、类别),模型能够更好地学习到目标的特征,提高检测的准确性。
  • 损失函数设计:合理的损失函数组合(如定位损失、置信度损失和类别损失)能够使模型在训练过程中更有效地优化各个任务,避免单一任务的过拟合。

网络结构也能基于任务复杂性自主调节:

 # 基于任务复杂性调整通道数
        if task_complexity == 'simple':
            self.channels = [32, 64, 128]
        elif task_complexity == 'moderate':
            self.channels = [64, 128, 256]
        else:  # complex
            self.channels = [128, 256, 512]

        # 根据类别数量调整最后一层的通道数
        self.final_channels = self.channels[-1] + num_classes

ok!明天见!


http://www.kler.cn/a/500789.html

相关文章:

  • vue3+vite图片动态地址问题 + nginx配置
  • 【2024年华为OD机试】(C卷,100分)- 单词加密(Java JS PythonC/C++)
  • 【学习笔记】理解深度学习的基础:机器学习
  • webpack打包要义
  • 什么是MVCC
  • 【ASP.NET学习】Web Pages 最简单的网页编程开发模型
  • 深入浅出Java Web开放平台:从API设计到安全保障的全方位探索
  • --- 多线程编程 基本用法 java ---
  • 从零开始开发纯血鸿蒙应用之多签名证书管理
  • A3. Springboot3.x集成LLama3.2实战
  • B+ 树的实现原理与应用场景
  • 20250112面试鸭特训营第20天
  • 移动端屏幕分辨率rem,less
  • 前端开发:HTML常见标签
  • 慧集通(DataLinkX)iPaaS集成平台-业务建模之业务对象(二)
  • Linux权限管理(用户和权限之间的关系)
  • MATLAB语言的文件操作
  • 《分布式光纤测温:解锁楼宇安全的 “高精度密码”》
  • 如何在本地部署大模型并实现接口访问( Llama3、Qwen、DeepSeek等)
  • mark 一下conductor github
  • 【前端动效】原生js实现拖拽排课效果