深度学习之使用yolo网络训练kitti数据集:kitti数据集转换为VOC格式
参考博文:
YOLOv3训练KITTI数据集_kitti训练集-CSDN博客
数据集来源:
darknet yolov3 训练 kitti数据集_kitti数据集下载-CSDN博客
(这里需要下载4个压缩包,少了解压会出现报错)
xml_to_yolo_txt.py文件在将xml文件转换为txt文件时发生不一致报错,这是我修改后的代码:
import glob
import xml.etree.ElementTree as ET
# 这里的类名为我们 xml 里面的类名,顺序现在不需要考虑
class_names = ['Car', 'Cyclist', 'Pedestrian']
# xml 文件路径
path = '/xmls/'
# 转换一个 xml 文件为 txt
def single_xml_to_txt(xml_file):
tree = ET.parse(xml_file)
root = tree.getroot()
# 保存的 txt 文件路径
txt_file = xml_file.split('.')[0] + '.txt'
with open(txt_file, 'w') as txt_file:
for member in root.findall('object'):
picture_width = int(root.find('size')[0].text)
picture_height = int(root.find('size')[1].text)
class_name = member[0].text
class_num = class_names.index(class_name)
# 修改元素访问方式
box_x_min = int(member.find('bndbox')[0].text) # 左上角横坐标
box_y_min = int(member.find('bndbox')[1].text) # 左上角纵坐标
box_x_max = int(member.find('bndbox')[2].text) # 右下角横坐标
box_y_max = int(member.find('bndbox')[3].text) # 右下角纵坐标
x_center = float(box_x_min + box_x_max) / (2 * picture_width)
y_center = float(box_y_min + box_y_max) / (2 * picture_height)
width = float(box_x_max - box_x_min) / picture_width
height = float(box_y_max - box_y_min) / picture_height
print(class_num, x_center, y_center, width, height)
txt_file.write(str(class_num) + ' ' + str(x_center) + ' ' + str(y_center) + ' ' + str(width) + ' ' + str(height) + '\n')
# 转换文件夹下的所有 xml 文件为 txt
def dir_xml_to_txt(path):
for xml_file in glob.glob(path + '*.xml'):
single_xml_to_txt(xml_file)
dir_xml_to_txt(path)
xml文件和txt文件现在存放在了一个目录之下,我习惯将txt文件存放在labels文件下:
import os
import shutil
def move_txt_files(source_dir, destination_dir):
# 确保目标目录存在,如果不存在则创建
if not os.path.exists(destination_dir):
os.makedirs(destination_dir)
# 遍历源目录下的所有文件
for root, dirs, files in os.walk(source_dir):
for file in files:
if file.endswith('.txt'):
source_file = os.path.join(root, file)
destination_file = os.path.join(destination_dir, file)
# 移动文件
shutil.move(source_file, destination_file)
if __name__ == "__main__":
source_dir = '/xmls/'
destination_dir = '/labels/'
move_txt_files(source_dir, destination_dir)
后面生成train和val的代码建议去深度学习环境中进行。