YOLOv8 训练15种动物目标检测模型
1. 下载数据集
https://hyper.ai/datasets/31084
2. 将图片分类文件夹转换为yolo格式标注文件
(这里相当于将整张图片作为一个标注框,检测结果不太精确)
代码:
import json
import os
import shutil
import cv2
import matplotlib.pyplot as plt
"""
将文件夹分类转yolo格式
"""
target = "./animal_data/"
def convert(size, box):
dw = size[1]
dh = size[0]
# box x1 y1 x2 y2
x = (box[0] + box[2]) / 2.0
y = (box[1] + box[3]) / 2.0
w = box[2] - box[0]
h = box[3] - box[1]
x = x / dw
w = w / dw
y = y / dh
h = h / dh
if w >= 1:
w = 0.99
if h >= 1:
h = 0.99
return (x, y, w, h)
classify = os.listdir(target)
classify.remove("train")
train_dir = target + "train/"
train_images = train_dir + "images/"
train_labels = train_dir + "labels/"
if not os.path.exists(train_labels):
os.mkdir(train_labels)
os.mkdir(train_images)
for idx in range(len(classify)):
item = classify[idx]
clazz = target + item
# 遍历分类下面的所有图片
imgs = os.listdir(clazz)
for i in imgs:
shutil.copyfile(os.path.join(clazz, i), os.path.join(train_images, i))
img = cv2.imread(os.path.join(train_images, i))
filename, _ = os.path.splitext(i)
with open(train_labels + filename + ".txt", "w") as f:
box = convert(img.shape, (0, 0, img.shape[0], img.shape[1]))
f.write(str(idx)+" " + " ".join([str(a) for a in box]))
这里没对数据集进行train、test拆分
3. 进行训练
yaml:
path: C:\Users\lhq\Desktop\15-animals-data\animal_data
train: "C://Users//lhq//Desktop//15-animals-data//animal_data/train/"
val: "C://Users//lhq//Desktop//15-animals-data//animal_data/train/"
test: "C://Users//lhq//Desktop//15-animals-data//animal_data/train/"
nc: 15
names:
0: 熊
1: 鸟
2: 猫
3: 奶牛
4: 鹿
5: 狗
6: 海豚
7: 大象
8: 长颈鹿
9: 马
10: 袋鼠
11: 狮子
12: 熊猫
13: 老虎
14: 斑马
train.py:
from ultralytics import YOLO
from ultralytics.utils import DEFAULT_CFG
from datetime import datetime
current_time = datetime.now()
time_str = current_time.strftime("%Y-%m-%d_%H-%M-%S")
DEFAULT_CFG.save_dir = f"./models/{time_str}"
if __name__ == "__main__":
model = YOLO("yolov8n.pt")
# Train the model
results = model.train(data="animal.yaml", epochs=200, imgsz=224, device=0, save=True, save_period=1 ,batch=16)
4.检测
代码:
from ultralytics import YOLO
# Load a model
model = YOLO('best.pt')
# 可以是文件夹或图片
model.predict("u950468431,735343220fm253fmtauto.jpg", imgsz=224, save=True, device=0,plots=True)
检测结果:
部分数据推理不准确,比如:白色的猫可能检测出熊猫