当前位置: 首页 > article >正文

从零开始使用YOLOv11——Yolo检测detect数据集自建格式转换为模型训练格式:20w+图片1w+类别代码测试成功

        在之前的文章中记录了YOLO环境的配置安装和基本命令的一些使用,上一篇博文的地址快速链接:从零开始使用YOLOv8——环境配置与极简指令(CLI)操作:1篇文章解决—直接使用:模型部署 and 自建数据集:训练微调-CSDN博客

        使用YOLO作为目标检测任务的平台一个好处是,其搭建了非常简洁明了的训练命令行模式,可以便捷的对自建数据集进行微调。

        在对自己数据集进行模型训练前,非常重要费时的就是对数据的预处理,包括数据清洗、统计信息分析、数据格式转换。本文专注于将自己数据 json 格式转为YOLO训练支持的 txt 数据格式,并给出可以复用的数据集构建代码,代码已上传至Gitee平台。

        Gitee链接:https://gitee.com/machine-bai-xue/yolo-source-code-analysis

        如果链接失效,访问404拒绝,可以直接在Gitee码云主页搜索——“机器白学”,所有项目中的YOLO源码实验就是本系列所有实验代码。

        

目录

一、初始自建数据集Json格式

        1.文件存放格式

        2.标签JSON格式

二、YOLO训练数据集创建类

        1.直接使用

        2.可视化检查

        3.完整代码与扩展


一、初始自建数据集Json格式

        1.文件存放格式

        首先约定一下数据的初始格式,本文选择最简单的 JSON 列表数据集格式作标签保存。所有图片在一个文件夹(img)下,所有便签在另一个文件夹(json)下。

        2.标签JSON格式

        对标签数据具体来说,坐标数据(下图红框)和类别数据(下图蓝框)存放在同一个列表里,前四个为左上右下两点的xy绝对坐标值,类别字符信息放在后面。

        (如果还有除了类别外的其余信息可以直接在列表后添加,最后构建数据集只需取出对应索引即可。例如,对于同一批图片数据和坐标数据可能存在多种分类任务,就官方coco例子来说,其给的类别是详细分类后的物体名称——长颈鹿、花瓶、杯子......如果想构建一个大致分类的检测模型,如将细化类改为抽象类名称——动物、装饰品、日用品......只需在列表后面继续添加即可,如下第一个图第一个框可以改为——【385, 60, 600, 357, “giraffe”, “动物”】

二、YOLO训练数据集创建类

        将任意训练数据集搭建成初始的Json格式并存放按文件夹存放后,即可使用下面的数据转换类生成多个符合训练标准的数据集格式——包括训练集train和验证集val、yaml配置文件、txt标签文件等部分。

        1.直接使用

        首先总览整个类的使用。首先确保 opencv-python(cv2)和 pillow(PIL)库正确安装在环境里了。

        导入定义的转化类(可在文章最后直接复制,或者在Gitee地址下载对应py文件),实例初始化。初始化中三个关于文件地址的基本参数是必须存在的。

        初始化基本参数按输入顺序含义归纳在下面表格。

img_path所有图片存放文件地址(str)
label_path所有Json格式标签数据文件地址(str)
save_pathYolo数据集结果保存地址(str)

        另外初始化中还有几个可以调整的附加参数。其影响数据集搭建的某些细节部分。

train_ratio训练集占总数据量的比重,小数数据格式(float)
cls_id训练数据集中类别标签在原始数据中的索引位置(int,>=4)
seed设置打乱数据集文件名的随机种子——随机分配训练和验证数据(int)

         配置完参数后,直接使用类下的 dataset_main() 方法就可以自动生成训练验证数据集和yaml配置文件了。

        所有数据按照设定的划分比例随机采样分开。

        其中 cls_freq.json 是统计所有类别出现的频率字典,字典键对应类别名,值对应出现的次数。可以根据其频率查看哪些类训练样本偏少,决定是否要进行数据增强操作。

        .yaml 文件是YOLO训练的配置文件。其中names是按出现频率排序的类别和标签索引对。 

        2.可视化检查

        转换类中还定义了一些可视化函数,可以检查数据是否正确。其中对于初始数据只需直接使用visual_json_main() 方法即可。

        还可以直接使用类中的可视化函数,进行自定义的检查。

        3.完整代码与扩展

        下面将完整类代码放在下面,可以对其中相应函数方法进行修改实现扩展任务。欢迎批评指正。

import os
import json
import random
import cv2
import yaml
import numpy as np
from PIL import Image, ImageDraw, ImageFont

class YOLO_Dataset_Creator:
    def __init__(self, img_path, label_path, save_path, train_ratio=0.85, cls_id=4, seed=42):
        self.img_path, self.labels, self.data = img_path, label_path, save_path
        self.train = train_ratio
        self.cls = None
        self.cls_id = cls_id
        self.seed = seed

    def dataset_main(self):
        # 读取图片信息划分train和val集
        self.tr_name, self.val_name = self.divide_dataset(self.labels)
        # 根据json中类别信息生成配置yaml文件
        self.cls = self.generate_yaml(self.labels)
        # 生成yolo的txt训练数据集
        self.dataset_create(self.data)

    def divide_dataset(self, json_path):
        # 根据图片获取所有文件名信息
        total_file_list = []
        for file in os.listdir(json_path):
            if file.lower().split('.')[-1] in ['json']:
                base = file.split('.')[0]
                total_file_list.append(base)
        # 随机打乱后按比例生成训练train和验证val
        random.seed(self.seed)
        random.shuffle(total_file_list)
        length = len(total_file_list)
        tr_file_list = [tr for tr in total_file_list[:int(self.train * length)]]
        val_file_list = [te for te in total_file_list[int(self.train * length):]]
        return tr_file_list, val_file_list

    def statis_info(self, json_path):
        cls_dict = dict()
        for file in os.listdir(json_path):
            if file.lower().split('.')[-1] in ['json']:
                jsondir = os.path.join(json_path, file)
                with open(jsondir, 'r', encoding='utf-8') as f:
                    box_cls_list = json.load(f)
                for box_cls in box_cls_list:
                    cls = box_cls[self.cls_id]
                    if str(cls) not in cls_dict.keys():
                        cls_dict[str(cls)] = 1
                    else:
                        cls_dict[str(cls)] +=1
        cls_dictdir = os.path.join(self.data, 'cls_freq.json')
        with open(cls_dictdir, 'w') as f:
            json.dump(cls_dict, f)
        return cls_dict

    def generate_yaml(self, json_path):
        # 获得类别频率字典
        cls_dict = self.statis_info(json_path)
        # 生成yaml配置文件
        sorted_cls = sorted(cls_dict, key=cls_dict.get, reverse=True)
        names_dict = {}
        clses_dict = {}
        for id, c in enumerate(sorted_cls):
            names_dict[id] = c
            clses_dict[c] = id
        yaml_dict = {"path":self.data,
                     "train":"images/train",
                     "val":"images/val",
                     "names":names_dict}
        yaml_savedir = os.path.join(self.data, 'HP_Data.yaml')
        with open(yaml_savedir, "w") as f:
            yaml.dump(yaml_dict,f)
        print('yaml success')
        return clses_dict

    def dataset_create(self,data_path):
        # 必要子文件生成
        tr_img = os.path.join(data_path, 'images/train')
        va_img = os.path.join(data_path, 'images/val')
        tr_lab = os.path.join(data_path, 'labels/train')
        va_lab = os.path.join(data_path, 'labels/val')
        # 创建文件夹
        for file in [tr_img, va_img, tr_lab, va_lab]:
            os.makedirs(file, exist_ok=True)
        # train和val
        self._train(tr_img, tr_lab)
        self._val(va_img, va_lab)

    def _train(self, tr_img, tr_lab):
        # 生产训练集train
        for name in self.tr_name:
            print(name, 'start')
            # 图片复制保存移动
            imgsave = os.path.join(tr_img, name+'.jpg')
            jpgdir = os.path.join(self.img_path, name+'.jpg')
            # 标签
            txtsave = os.path.join(tr_lab, name+'.txt')
            jsondir = os.path.join(self.labels, name+'.json')
            self.label_create(jsondir, txtsave, name, jpgdir, imgsave)

    def _val(self, va_img, va_lab):
        # 生产验证集val
        for name in self.val_name:
            # 图片复制保存移动
            imgsave = os.path.join(va_img, name + '.jpg')
            jpgdir = os.path.join(self.img_path, name + '.jpg')
            # 标签
            txtsave = os.path.join(va_lab, name + '.txt')
            jsondir = os.path.join(self.labels, name + '.json')
            self.label_create(jsondir, txtsave, name, jpgdir, imgsave)

    def label_create(self, jsondir, txtsave, name, jpgdir, imgsave):
        # 图片信息
        img = cv2.imread(jpgdir)
        if img is None:
            return print(name, 'jpg is empty')
        height, width, _ = img.shape
        # 框信息
        with open(jsondir, 'r', encoding='utf-8') as f:
            box_list = json.load(f)

        box_str_list = []
        for box_cls in box_list:
            box = box_cls[:4]
            conv_box = self.normalize((width, height), box)
            text = box_cls[self.cls_id]
            cls = self.cls[str(text)]
            conv_box.insert(0, int(cls))
            box_str = " ".join(str(item) for item in conv_box)+'\n'
            box_str_list.append(box_str)
        if box_str_list!=[]:
            with open(txtsave, "w", encoding="utf-8") as f:
                f.writelines(box_str_list)

        if os.path.exists(txtsave):
            cv2.imwrite(imgsave, img)

    def normalize(self, size, box):  # size:(原图w,原图h) , box:(xmin,xmax,ymin,ymax)
        # 锚框归一化
        dw = 1. / size[0]  # 1/w
        dh = 1. / size[1]  # 1/h
        x = (box[0] + box[2]) / 2.0  # 物体在图中的中心点x坐标
        y = (box[1] + box[3]) / 2.0  # 物体在图中的中心点y坐标
        w = box[2] - box[0]  # 物体实际像素宽度
        h = box[3] - box[1]  # 物体实际像素高度
        x = x * dw  # 物体中心点x的坐标比(相当于 x/原图w)
        w = w * dw  # 物体宽度的宽度比(相当于 w/原图w)
        y = y * dh  # 物体中心点y的坐标比(相当于 y/原图h)
        h = h * dh  # 物体宽度的宽度比(相当于 h/原图h)
        return [x, y, w, h]  # 返回 相对于原图的物体中心点的x坐标比,y坐标比,宽度比,高度比,取值范围[0-1]

    def denormalize(self, size, normalized_box):  # size: (原图w, 原图h), normalized_box: [x, y, w, h]
        # 提取原图的宽度和高度
        w, h = size
        # 将中心点坐标和宽高还原为原图的像素坐标
        x_center = normalized_box[0] * w
        y_center = normalized_box[1] * h
        box_width = normalized_box[2] * w
        box_height = normalized_box[3] * h
        # 计算还原后的边界框坐标
        xmin = x_center - box_width / 2.0
        xmax = x_center + box_width / 2.0
        ymin = y_center - box_height / 2.0
        ymax = y_center + box_height / 2.0
        return [int(xmin), int(ymin), int(xmax), int(ymax)]  # 返回还原后的边界框坐标

    def visual_json_main(self, visfile):
        for file in os.listdir(self.img_path):
            base = file.split('.')[0]
            jpgdir = os.path.join(self.img_path, file)
            img = cv2.imread(jpgdir)
            if img is None:
                return print('img is empty')
            jsondir = os.path.join(self.labels, base+'.json')
            with open(jsondir, 'r', encoding='utf-8') as f:
                boxes_list = json.load(f)
            box_list = []
            for box_ in boxes_list:
                box = box_[:4]
                text = box_[self.cls_id]
                box.append(text)
                box_list.append(box)
            visdir = os.path.join(visfile, file)
            vis = self.visual_word(box_list, img ,(0,255,0))
            cv2.imwrite(visdir, vis)
            print(file,' success')

    def visual_box(self, box_list, img, color):
        for idx, box in enumerate(box_list):
            if len(box) == 4:
                l, t, r, b = box
                # 图片画框
                cv2.rectangle(img, (int(l), int(t)), (int(r), int(b)), color, thickness=2, lineType=cv2.LINE_AA)
            elif len(box) == 8:
                pts = np.array(box, np.int32)
                pts = pts.reshape((-1, 1, 2))
                cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
        return img

    def visual_word(self, word_list, img, color):
        font_path = "C:\Windows\Fonts\SimHei.ttf"
        font_size = 45
        for idx, box in enumerate(word_list):
            l, t, r, b, word = box
            # 图片画框
            cv2.rectangle(img, (l, t), (r, b), color, thickness=2, lineType=cv2.LINE_AA)
            caption = f"{word}"
            pil_img = Image.new('RGB', (0, 0))
            draw = ImageDraw.Draw(pil_img)
            text_size = draw.textbbox((0, 0), caption, font=ImageFont.truetype(font_path, font_size))
            text_width = text_size[2] - text_size[0]
            text_height = text_size[3] - text_size[1]
            # 使用 OpenCV 画文本背景框
            cv2.rectangle(img, (r, t), (r + text_width, t + text_height), color, -1)
            # 使用 PIL 绘制中文标签
            img = self.put_chinese_text(img, caption, (r, t), font_size, font_path, (0, 0, 0))
        return img

    def put_chinese_text(self, img, text, position, font_size, font_path, text_color):
        # 创建一个 PIL 图像
        pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_img)
        # 加载字体
        font = ImageFont.truetype(font_path, font_size)
        # 绘制文本
        draw.text(position, text, font=font, fill=text_color)
        # 将 PIL 图像转换回 OpenCV 格式
        img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
        return img

http://www.kler.cn/a/392065.html

相关文章:

  • C#发票识别、发票查验接口集成、电子发票(航空运输电子行程单)
  • GxtWaitCursor:Qt下基于RAII的鼠标等待光标类
  • 软件工程概论项目(二),node.js的配置,npm的使用与vue的安装
  • ubuntu20.04安装FLIR灰点相机BFS-PGE-16S2C-CS的ROS驱动
  • python装饰器的使用以及私有化
  • 软件测试面试2024最新热点问题
  • PointMamba: A Simple State Space Model for Point Cloud Analysis——点云论文阅读(10)
  • 边缘计算与推理算力:智能时代的加速引擎
  • 开源大模型推理引擎现状及常见推理优化方法总结
  • ubontu安装anaconda
  • 简单理解回调函数
  • Jenkins配置步骤
  • Spring学习笔记_30——事务接口PlatformTransactionManager
  • 汽车牌照识别系统的设计与仿真(论文+源码)
  • Vue 组件间传值指南:Vue 组件通信的七种方法
  • 软考系统架构设计师论文:论多源数据集成及应用
  • 企业“3D官网”主要有哪些功能?
  • labview实现定时器的功能
  • ❤React-React 组件基础(类组件)
  • Redhat7.9 安装 KingbaseES 金仓数据库 V9单机版(命令行安装)
  • 【设计模式】单例设计模式
  • openresty入门教程:ngx.print ngx.say ngx.log
  • Java LeetCode练习
  • Unity3D
  • 八、Spring Boot集成Spring Security之前后分离认证最佳实现测试
  • 多个摄像机画面融合:找到同一个目标在多个画面中的伪三维坐标,找出这几个摄像头间的转换矩阵