当前位置: 首页 > article >正文

使用ultralytics库微调 YOLO World 保持 Zero-Shot 能力

在训练 YOLO World 模型时,如果希望在特定数据集(如火灾数据集)上进行微调,同时保留模型的 Zero-Shot 能力,可以参考以下几点方法。Zero-Shot 能力指的是模型在未见过的类别上仍具备一定的推理能力,但在特定数据集上的微调有时会导致模型过度专注于新任务,从而丧失这种能力。

如何微调 YOLO World 保持 Zero-Shot 能力

  1. 保持数据集平衡(Balanced Dataset):

问题:当你只用特定的定制数据集(如火灾数据集)进行训练时,模型可能会逐渐丧失其泛化能力,变得只擅长特定任务。
解决方法:在你的定制数据集中,增加与原来 Zero-Shot 类别相关的图像。这有助于模型保持对广泛类别的识别能力。例如,可以将 YOLO World 训练过的 GQA 数据集与火灾数据集合并,使模型在学习特定任务时仍能保持泛化能力。

  1. 限制训练周期(Limited Epochs):

问题:过多的训练周期会导致模型过拟合在特定的数据集上,导致泛化能力下降。
解决方法:减少训练周期,以避免模型过度拟合。比如,10 个 epoch 是一个不错的起点。长时间训练可能会让模型专注于新任务,从而削弱 Zero-Shot 的能力。

  1. 调整学习率(Learning Rate):

问题:如果学习率太大,模型权重调整幅度过大,容易导致模型丧失之前学到的泛化能力。
解决方法:使用更小的学习率,例如 0.001 或 0.0005,以细微调整模型权重而非大幅修改。你可以尝试在 100 个 epoch 内使用较小的初始学习率,同时通过学习率调度器逐渐减小学习率。

  1. 添加自定义头(Custom Head):

问题:在某些特定任务中,完全微调整个模型可能导致模型的 Zero-Shot 能力下降。
解决方法:可以考虑只为你的定制数据集添加一个自定义头(Custom Head),而保持模型的主干不变。这意味着模型的底层特征提取能力依然保持其原有的 Zero-Shot 能力,而新增的任务会通过自定义头进行学习。

  1. 示例代码

下面是微调 YOLO World 模型的示例代码,包含了较少的训练 epoch、较小的学习率

from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.world import WorldTrainer
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils.torch_utils import de_parallel

# 配置训练参数
yaml_path = '/path/to/your/dataset.yaml'
args = dict(
    model='yolov8x-worldv2.pt', 
    data=yaml_path, 
    epochs=10,           # 从 10 个 epoch 开始,可以根据需求调整
    batch=4,             # 批次大小
    imgsz=640,           # 输入图像尺寸
    lr0=0.001,           # 固定较小的初始学习率
    optimizer='SGD',     # 使用 SGD 优化器
    weight_decay=0.0005, # 权重衰减
    momentum=0.932,      # 动量参数
    hsv_h=0.015,         # 颜色抖动范围
    hsv_s=0.7, 
    hsv_v=0.4, 
    mosaic=1.0,          # 启用 mosaic 数据增强
    augment=True,        # 启用数据增强
    save_period=1,       # 每个 epoch 保存一次模型
    patience=5,          # 提前终止策略
    device=0,            # 指定训练设备
    val=True,            # 启用验证
    plots=True,          # 绘制图像
    workers=0            # 数据加载进程
)

trainer = WorldTrainer(overrides=args)

# 开始训练
results = trainer.train()
  1. 总结

微调 YOLO World 模型以保留 Zero-Shot 能力需要从数据集平衡、训练周期、学习率以及模型架构等多个角度入手。通过合并数据集、减少训练周期以及使用较小的学习率,你可以在特定任务上获得更好的性能,同时保持模型在未见类别上的推理能力。

如果你有进一步的问题或在实验过程中遇到困难,欢迎继续讨论!


http://www.kler.cn/news/314340.html

相关文章:

  • Go小专栏 第一期
  • 【前端】ES6:Promise对象和Generator函数
  • 【MySQL 01】数据库基础
  • 配置docker的proxy指向
  • 【Proteus仿真】基于51单片机的L298N电机电速调节
  • 记录动态库项目仅生成了dll,未生成lib文件的问题
  • 深度学习02-pytorch-07-张量的拼接操作
  • 剖析Spark Shuffle原理(图文详解)
  • go 以太坊代币查余额
  • Python | Leetcode Python题解之第424题替换后的最长重复字符
  • 是德科技Keysight N4433D ECal模块 26.5GHz 4端口3.5毫米
  • 在python爬虫中xpath方式提取lxml.etree._ElementUnicodeResult转化为字符串str类型
  • RAG+Agent人工智能平台:RAGflow实现GraphRA知识库问答,打造极致多模态问答与AI编排流体验
  • 演示jvm锁存在的问题
  • Java集合(三)
  • Centos7安装chrome的问题
  • WebApi开发中依赖注入和RESTful 详解
  • OceanBase 的并发简述笔记
  • Navicate 链接Oracle 提示 Oracle Library is not loaded ,账号密码都正确地址端口也对
  • 【变化检测】基于ChangeStar建筑物(LEVIR-CD)变化检测实战及ONNX推理
  • php变量赋值javascipt变量
  • 13.面试算法-字符串常见算法题(二)
  • 【论文阅读】3D Diffuser Actor: Policy Diffusion with 3D Scene Representations
  • 人工智能与机器学习原理精解【25】
  • 【电路笔记】-运算放大器积分器
  • 数模方法论-整数规划
  • Python类及元类的创建流程
  • C#进阶-基于雪花算法的订单号设计与实现
  • [Python数据可视化] Plotly:交互式数据可视化的强大工具
  • 15.9 grafana-deployment-yaml讲解