当前位置: 首页 > article > 正文

[目标检测] 训练之前要做什么

背景:我们要训练一个 YOLOv8 模型。在训练之前,数据集的处理是影响模型效果的关键因素。

Step1 定义规则

先明确什么算"人"、什么算"车"。例如,"人"可能只包括站着的人,而骑电动车/自行车的人就不算作"人"。

Step2 收集数据集

1. 自己标注:按照上述规则进行标注和清洗。

2. 使用他人已标注好的数据集:最好先把标签可视化,检查标注是否符合上述规则。下面是可视化代码。

import os
import cv2

def visualize_yolo_boxes(image_folder, label_folder, output_folder):
    """Draw YOLO-format bounding boxes onto images and save the results.

    For every entry in ``image_folder`` the matching label file
    ``<label_folder>/<image stem>.txt`` is read. Each label line must have
    exactly five whitespace-separated fields:
    ``class_id x_center y_center width height``.
    The four box values are scaled by the image width/height, i.e. they are
    treated as normalised coordinates. Lines with a different field count,
    or with fields that cannot be parsed as numbers, are skipped. The
    annotated image is written to ``output_folder`` under the original file
    name.

    Args:
        image_folder: Directory containing the images.
        label_folder: Directory containing the YOLO ``.txt`` label files.
        output_folder: Directory for the visualised images; created if missing.
    """
    os.makedirs(output_folder, exist_ok=True)

    for image_name in os.listdir(image_folder):
        image_path = os.path.join(image_folder, image_name)
        label_path = os.path.join(label_folder, os.path.splitext(image_name)[0] + '.txt')

        # Images without a label file are reported and skipped.
        if not os.path.exists(label_path):
            print(f"标签文件不存在: {label_path}")
            continue

        image = cv2.imread(image_path)
        if image is None:
            print(f"无法读取图片: {image_path}")
            continue

        img_height, img_width = image.shape[:2]

        with open(label_path, 'r') as f:
            lines = f.readlines()

        for line in lines:
            parts = line.strip().split()
            if len(parts) != 5:
                continue

            try:
                class_id = int(parts[0])
                x_center, y_center, width, height = (float(p) for p in parts[1:])
            except ValueError:
                # A single malformed line should not abort the whole run.
                print(f"无法解析标签行: {label_path}: {line.strip()}")
                continue

            # Convert normalised values to pixels.
            x_center *= img_width
            y_center *= img_height
            width *= img_width
            height *= img_height

            # Top-left and bottom-right corners of the box.
            x1 = int(x_center - width / 2)
            y1 = int(y_center - height / 2)
            x2 = int(x_center + width / 2)
            y2 = int(y_center + height / 2)

            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Placing the text at y1 - 10 puts it above row 0 (invisible) for
            # boxes touching the top edge; clamp the origin into the image.
            text_origin = (max(x1, 0), max(y1 - 10, 20))
            cv2.putText(image, f'Class {class_id}', text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        output_path = os.path.join(output_folder, image_name)
        # cv2.imwrite can signal failure with a False return value instead of
        # raising, so check it before reporting success.
        if cv2.imwrite(output_path, image):
            print(f"已保存可视化结果: {output_path}")
        else:
            print(f"保存失败: {output_path}")

# Example usage (runs only when executed as a script, not on import)
if __name__ == "__main__":
    # os.path.join uses the platform's separator; a backslash literal such
    # as r'.\train\images' is only a nested directory path on Windows.
    image_folder = os.path.join('.', 'train', 'images')
    label_folder = os.path.join('.', 'train', 'labels')
    output_folder = os.path.join('.', 'train', 'visualize')

    visualize_yolo_boxes(image_folder, label_folder, output_folder)

Step3 修改标签

修改标签主要是指合并(或筛选)类别。例如,已标注好的数据集可能把车辆细分为 truck、car、bus 等类别,而我们只需要统一的 car 类,因此需要把这些类别合并。下面是合并代码(不在映射规则中的类别会被丢弃)。

import os
import glob

def process_line(line, mapping=None):
    """Remap the class id of one YOLO label line.

    The line is split on whitespace. The first field is parsed as an integer
    class id and replaced by its mapped value. The remaining fields are
    copied unchanged and re-joined with single spaces.

    Default mapping (the original hard-coded rule):
        0, 1, 6 -> 1
        2       -> 0
    Any class id not present in the mapping is dropped.

    Args:
        line: One line of a YOLO ``.txt`` label file.
        mapping: Optional dict ``{old_id: new_id}`` replacing the default
            mapping. Ids absent from it are dropped.

    Returns:
        The rewritten line with a trailing newline. Returns ``None`` if the
        line is blank, has fewer than five fields, has a non-integer class id,
        or its class id is not in the mapping.
    """
    if mapping is None:
        mapping = {0: 1, 1: 1, 6: 1, 2: 0}

    parts = line.strip().split()
    # A box line needs a class id plus four coordinates. Anything shorter
    # (including blank lines) is malformed and would yield a corrupt output line.
    if len(parts) < 5:
        return None

    try:
        old_id = int(parts[0])
    except ValueError:
        return None

    new_id = mapping.get(old_id)
    if new_id is None:
        return None

    # Replace the class id and keep the rest of the line as-is.
    return f"{new_id} {' '.join(parts[1:])}\n"

def process_file(input_path, output_path):
    """Remap every line of one label file and write the result.

    Each line is passed through ``process_line``. Lines for which it returns
    ``None`` are dropped. The parent directory of ``output_path`` is created
    if needed. I/O and decoding errors are printed and the file is skipped,
    so one bad file does not stop a batch run.

    Args:
        input_path: Path of the source label file.
        output_path: Path the remapped label file is written to.
    """
    try:
        with open(input_path, 'r') as infile:
            processed_lines = [
                processed for line in infile
                if (processed := process_line(line)) is not None
            ]

        # os.path.dirname() is "" for a bare file name, and os.makedirs("")
        # raises, so only create a directory when there is one.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_path, 'w') as outfile:
            outfile.writelines(processed_lines)

        print(f"Processed {input_path} -> {output_path}")

    except (OSError, UnicodeDecodeError) as e:
        print(f"Error processing {input_path}: {str(e)}")

def process_directory(input_dir, output_dir):
    """Remap every ``*.txt`` label file directly inside ``input_dir``.

    Each file is written to ``output_dir`` under the same name. The output
    directory is created if missing. Subdirectories are not searched.

    Args:
        input_dir: Directory holding the source label files.
        output_dir: Directory that receives the remapped label files.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Escape the directory part: paths containing glob metacharacters such as
    # "[" or "]" would otherwise be treated as a pattern and match nothing.
    pattern = os.path.join(glob.escape(input_dir), "*.txt")

    # Sort so files are processed (and logged) in a reproducible order.
    for input_path in sorted(glob.glob(pattern)):
        relative_path = os.path.relpath(input_path, input_dir)
        output_path = os.path.join(output_dir, relative_path)
        process_file(input_path, output_path)

# Example usage
if __name__ == "__main__":
    import sys

    # Optional override from the command line:
    #     python script.py <input_dir> <output_dir>
    # Without arguments the example paths below are used.
    if len(sys.argv) >= 3:
        input_directory, output_directory = sys.argv[1], sys.argv[2]
    else:
        input_directory = r"F:\1-dataset\raw\add_vehicle_person\combine\vehicle_class\valid\labels"  # Replace with your input directory
        output_directory = r"F:\1-dataset\raw\add_vehicle_person\combine\vehicle_class\valid\labels_person_car"  # Replace with your output directory

    process_directory(input_directory, output_directory)

 


http://www.kler.cn/a/585461.html

相关文章:

  • 【深度学习|目标检测】YOLO系列anchor-based原理详解
  • SpringBoot旅游管理系统的设计与实现
  • [Java]栈 虚拟机栈 栈帧讲解
  • 蓝桥每日打卡--查找有序数组中的目标值
  • kotlin与MVVM的结合使用总结(二)
  • 工业三防平板AORO-P300 Ultra,开创铁路检修与调度数字化新范式
  • python多种数据类型输出为Excel文件
  • 【模块化编程】数据标签 转 独热编码
  • SSL 和 TLS 认证
  • 汉朔科技业绩高增长:市占率国内外遥遥领先,核心技术创新强劲
  • 六十天前端强化训练之第十七天React Hooks 入门:useState 深度解析
  • 嵌入式硬件--开发工具-AD使用常用操作
  • 今日《AI-人工智能-编程》-3月13日
  • 音视频处理工具 FFmpeg 指令的使用(超级详细!)
  • 电子电子架构 --- 车载ECU信息安全
  • Golang | 每日一练 (5)
  • LabVIEW电池内阻精确测量系统
  • Python基于深度学习的身份证识别考勤系统【附源码、文档说明】
  • 数据炼丹与硬件互动:预测湿度的武学之道
  • 【day13】营销系统:优惠券核销流程