
Re-parameterizing YOLOv9 models: converting yolo.pt to yolo-converted.pt

1. Converting yolo-m to yolo-m-converted

After training your own model with train_dual.py you end up with a yolo-series checkpoint. For example, I trained yolov9-m on the Cityscapes dataset and got cityscapes-yolov9-m.pt; running the yolov9-m_to_converted.py script below then produces cityscapes-yolov9-m-converted.pt. Remember to change nc to the number of classes in your own dataset, and to point the load and save paths at your own files.

#!/usr/bin/env python
# coding: utf-8



import torch
from models.yolo import Model

# ## Convert YOLOv9-M

# In[ ]:


device = torch.device("cpu")
cfg = "../models/detect/gelan-m.yaml"
model = Model(cfg, ch=3, nc=8, anchors=3)
#model = model.half()
model = model.to(device)
_ = model.eval()
ckpt = torch.load('../runs/train/YOLOv9-m训练Cityscapes/weights/cityscapes-yolov9-m.pt', map_location='cpu')
model.names = ckpt['model'].names
model.nc = ckpt['model'].nc


# In[ ]:


# Walk the GELAN-M state_dict and copy the matching tensor from the trained
# dual-branch checkpoint: modules with index below 22 map to index + 1 in the
# checkpoint, and the detection head's cv2/cv3/dfl branches map to the
# cv4/cv5/dfl2 branches sixteen indices later.
idx = 0
for k, v in model.state_dict().items():
    if "model.{}.".format(idx) in k:
        if idx < 22:
            kr = k.replace("model.{}.".format(idx), "model.{}.".format(idx+1))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv2.".format(idx) in k:
            kr = k.replace("model.{}.cv2.".format(idx), "model.{}.cv4.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv3.".format(idx) in k:
            kr = k.replace("model.{}.cv3.".format(idx), "model.{}.cv5.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.dfl.".format(idx) in k:
            kr = k.replace("model.{}.dfl.".format(idx), "model.{}.dfl2.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
    else:
        while True:
            idx += 1
            if "model.{}.".format(idx) in k:
                break
        if idx < 22:
            kr = k.replace("model.{}.".format(idx), "model.{}.".format(idx+1))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv2.".format(idx) in k:
            kr = k.replace("model.{}.cv2.".format(idx), "model.{}.cv4.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv3.".format(idx) in k:
            kr = k.replace("model.{}.cv3.".format(idx), "model.{}.cv5.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.dfl.".format(idx) in k:
            kr = k.replace("model.{}.dfl.".format(idx), "model.{}.dfl2.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
_ = model.eval()


# In[ ]:


m_ckpt = {'model': model.half(),
          'optimizer': None,
          'best_fitness': None,
          'ema': None,
          'updates': None,
          'opt': None,
          'git': None,
          'date': None,
          'epoch': -1}
torch.save(m_ckpt, "./cityscapes-yolov9-m-converted.pt")


When the script runs, it prints "... perfectly matched!!" for every copied key, and the converted checkpoint is written to the path given in torch.save.
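
As an optional sanity check (my own addition, not part of the original script), the saved checkpoint can be loaded back to confirm that the class count and names copied from the trained model survived. Run it from inside the yolov9 repo so the pickled Model class can be resolved:

import torch

m = torch.load('./cityscapes-yolov9-m-converted.pt', map_location='cpu')['model']
print(m.nc, m.names)   # should show your own class count and class names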

2. Inference with the re-parameterized model

With cityscapes-yolov9-m-converted.pt in hand (the weight file is 38.7 MB), run detection with detect.py.

The reported model info is: 377 layers, 19926616 parameters, 9048 gradients, 76.0 GFLOPs.

For comparison, running detection with the original cityscapes-yolov9-m.pt (the weight file is 63.1 MB):

The reported model info is: 588 layers, 32563288 parameters, 0 gradients, 130.7 GFLOPs.

On the example image both models produce identical detections: 7 cars, 2 persons, 1 rider, 1 bicycle. The detection quality is very good; several of the distant, small objects I could not even spot with my own eyes.

Re-parameterization reduces both the parameter count and the inference cost, and the authors state that the accuracy of the two models is identical, so the converted model keeps the accuracy of the original while cutting its complexity.
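
The parameter counts quoted above can be reproduced directly from the two checkpoints. A minimal sketch (my own addition; the paths are the ones used in this article, so adjust them to your own):

import torch

for path in ['./cityscapes-yolov9-m-converted.pt',
             '../runs/train/YOLOv9-m训练Cityscapes/weights/cityscapes-yolov9-m.pt']:
    m = torch.load(path, map_location='cpu')['model']
    # should roughly match the detect.py output above: ~19.9M vs ~32.6M parameters
    print(path, sum(p.numel() for p in m.parameters()), 'parameters')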

One caveat, though: while the re-parameterized model works with detect.py for inference, it currently cannot be evaluated with val.py or val_dual.py. For example, validating cityscapes-yolov9-m-converted.pt (obtained by re-parameterizing cityscapes-yolov9-m.pt) raises an error, so for now we have to wait for an upstream fix.

3. Converting yolo-s to yolo-s-converted

Run the yolov9-s_to_converted.py script below to obtain yolov9-s-converted.pt. Remember to change nc to the number of classes in your own dataset, and to point the load and save paths at your own files.

#!/usr/bin/env python
# coding: utf-8



import torch
from models.yolo import Model

# ## Convert YOLOv9-S

# In[ ]:


device = torch.device("cpu")
cfg = "../models/detect/gelan-s.yaml"
model = Model(cfg, ch=3, nc=80, anchors=3)
#model = model.half()
model = model.to(device)
_ = model.eval()
ckpt = torch.load('../models/pretrained-weights/yolov9-s.pt', map_location='cpu')
model.names = ckpt['model'].names
model.nc = ckpt['model'].nc


# In[ ]:


# Same copy loop as in the yolov9-m script, but for yolov9-s the backbone/neck
# indices match one-to-one (no shift) and the detection head's cv2/cv3/dfl
# branches map to cv4/cv5/dfl2 seven indices later.
idx = 0
for k, v in model.state_dict().items():
    if "model.{}.".format(idx) in k:
        if idx < 22:
            kr = k.replace("model.{}.".format(idx), "model.{}.".format(idx))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv2.".format(idx) in k:
            kr = k.replace("model.{}.cv2.".format(idx), "model.{}.cv4.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv3.".format(idx) in k:
            kr = k.replace("model.{}.cv3.".format(idx), "model.{}.cv5.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.dfl.".format(idx) in k:
            kr = k.replace("model.{}.dfl.".format(idx), "model.{}.dfl2.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
    else:
        while True:
            idx += 1
            if "model.{}.".format(idx) in k:
                break
        if idx < 22:
            kr = k.replace("model.{}.".format(idx), "model.{}.".format(idx))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv2.".format(idx) in k:
            kr = k.replace("model.{}.cv2.".format(idx), "model.{}.cv4.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv3.".format(idx) in k:
            kr = k.replace("model.{}.cv3.".format(idx), "model.{}.cv5.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.dfl.".format(idx) in k:
            kr = k.replace("model.{}.dfl.".format(idx), "model.{}.dfl2.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
_ = model.eval()


# In[ ]:


m_ckpt = {'model': model.half(),
          'optimizer': None,
          'best_fitness': None,
          'ema': None,
          'updates': None,
          'opt': None,
          'git': None,
          'date': None,
          'epoch': -1}
torch.save(m_ckpt, "./test_yolov9-s-converted.pt")


4. Converting yolo-c to yolo-c-converted

Run the yolov9-c_to_converted.py script below to obtain yolov9-c-converted.pt. Remember to change nc to the number of classes in your own dataset, and to point the load and save paths at your own files.

#!/usr/bin/env python
# coding: utf-8



import torch
from models.yolo import Model

# ## Convert YOLOv9-C

# In[ ]:


device = torch.device("cpu")
cfg = "../models/detect/gelan-c.yaml"
model = Model(cfg, ch=3, nc=8, anchors=3)
#model = model.half()
model = model.to(device)
_ = model.eval()
ckpt = torch.load('../runs/train/YOLOv9-m训练Cityscapes/weights/best.pt', map_location='cpu')
model.names = ckpt['model'].names
model.nc = ckpt['model'].nc


# In[ ]:



# Same index mapping as in the yolov9-m script (index + 1 for modules below 22,
# cv2/cv3/dfl -> cv4/cv5/dfl2 sixteen indices later for the head); this version
# simply omits the "perfectly matched!!" prints.
idx = 0
for k, v in model.state_dict().items():
    if "model.{}.".format(idx) in k:
        if idx < 22:
            kr = k.replace("model.{}.".format(idx), "model.{}.".format(idx+1))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
        elif "model.{}.cv2.".format(idx) in k:
            kr = k.replace("model.{}.cv2.".format(idx), "model.{}.cv4.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
        elif "model.{}.cv3.".format(idx) in k:
            kr = k.replace("model.{}.cv3.".format(idx), "model.{}.cv5.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
        elif "model.{}.dfl.".format(idx) in k:
            kr = k.replace("model.{}.dfl.".format(idx), "model.{}.dfl2.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
    else:
        while True:
            idx += 1
            if "model.{}.".format(idx) in k:
                break
        if idx < 22:
            kr = k.replace("model.{}.".format(idx), "model.{}.".format(idx+1))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
        elif "model.{}.cv2.".format(idx) in k:
            kr = k.replace("model.{}.cv2.".format(idx), "model.{}.cv4.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
        elif "model.{}.cv3.".format(idx) in k:
            kr = k.replace("model.{}.cv3.".format(idx), "model.{}.cv5.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
        elif "model.{}.dfl.".format(idx) in k:
            kr = k.replace("model.{}.dfl.".format(idx), "model.{}.dfl2.".format(idx+16))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
_ = model.eval()


# In[ ]:


m_ckpt = {'model': model.half(),
          'optimizer': None,
          'best_fitness': None,
          'ema': None,
          'updates': None,
          'opt': None,
          'git': None,
          'date': None,
          'epoch': -1}
torch.save(m_ckpt, "./cityscapes-yolov9-c-converted.pt")


5. Converting yolo-e to yolo-e-converted

Run the yolov9-e_to_converted.py script below to obtain yolov9-e-converted.pt. Remember to change nc to the number of classes in your own dataset, and to point the load and save paths at your own files.

#!/usr/bin/env python
# coding: utf-8



import torch
from models.yolo import Model

# ## Convert YOLOv9-E

# In[ ]:


device = torch.device("cpu")
cfg = "../models/detect/gelan-e.yaml"
model = Model(cfg, ch=3, nc=8, anchors=3)
#model = model.half()
model = model.to(device)
_ = model.eval()
ckpt = torch.load('../runs/train/YOLOv9-m训练Cityscapes/weights/best.pt', map_location='cpu')
model.names = ckpt['model'].names
model.nc = ckpt['model'].nc


# In[ ]:


# yolov9-e uses a different mapping: modules with index below 29 keep the same
# index, modules 29-41 are shifted by +7, and the detection head's cv2/cv3/dfl
# branches map to cv4/cv5/dfl2 seven indices later.
idx = 0
for k, v in model.state_dict().items():
    if "model.{}.".format(idx) in k:
        if idx < 29:
            kr = k.replace("model.{}.".format(idx), "model.{}.".format(idx))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif idx < 42:
            kr = k.replace("model.{}.".format(idx), "model.{}.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv2.".format(idx) in k:
            kr = k.replace("model.{}.cv2.".format(idx), "model.{}.cv4.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv3.".format(idx) in k:
            kr = k.replace("model.{}.cv3.".format(idx), "model.{}.cv5.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.dfl.".format(idx) in k:
            kr = k.replace("model.{}.dfl.".format(idx), "model.{}.dfl2.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
    else:
        while True:
            idx += 1
            if "model.{}.".format(idx) in k:
                break
        if idx < 29:
            kr = k.replace("model.{}.".format(idx), "model.{}.".format(idx))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif idx < 42:
            kr = k.replace("model.{}.".format(idx), "model.{}.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv2.".format(idx) in k:
            kr = k.replace("model.{}.cv2.".format(idx), "model.{}.cv4.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.cv3.".format(idx) in k:
            kr = k.replace("model.{}.cv3.".format(idx), "model.{}.cv5.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
        elif "model.{}.dfl.".format(idx) in k:
            kr = k.replace("model.{}.dfl.".format(idx), "model.{}.dfl2.".format(idx+7))
            model.state_dict()[k] -= model.state_dict()[k]
            model.state_dict()[k] += ckpt['model'].state_dict()[kr]
            print(k, "perfectly matched!!")
_ = model.eval()


# In[ ]:


m_ckpt = {'model': model.half(),
          'optimizer': None,
          'best_fitness': None,
          'ema': None,
          'updates': None,
          'opt': None,
          'git': None,
          'date': None,
          'epoch': -1}
torch.save(m_ckpt, "./cityscapes-yolov9-e-converted.pt")


