基于YOLOv8与SKU110K数据集实现超市货架物品目标检测与计算
概述
本文旨在基于检测到的物品位置信息,分析、计数并提取相关目标。通过对检测结果的坐标数据进行分析,将确定货架的数量以及货架上的物品数量。为此,这里将使用 SKU110K 数据集来训练目标检测模型。该数据集包含商店货架上物品的边界框标注,仅包含一个名为“物品”的类别。
SKU110K数据集
在 SKU-110K 是专注于密集场景下的目标检测问题。此类场景中的图像包含大量外观相似甚至相同的物体,且物体位置紧密相邻。这些场景通常是人造的,例如零售货架展示、交通和城市景观图像。
数据集 SKU-110K 中的一张典型图像,展示了密集排列的物体
算法实现
创建环境与导入相关库,使用预训练模型进行预测,并将结果赋值给一个变量。
conda create -n yolov8 python=3.8
activate ylolv8
conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install ultralytics
测试环境:
import numpy as np
from ultralytics import YOLO
model = YOLO('best.pt')
result = model.predict(
source='test_88.jpg',
conf=0.45,
save=True
)
模型训练
关于模型训练,可以看之前关于YOLOv8如果训练自定义数据
提取边界框坐标并排序
提取边界框的 xyxy 格式坐标,并将其转换为 NumPy 数组。这些数据分别代表边界框的 xmin、ymin、xmax 和 ymax 坐标。以下是前 25 个输出:
arrxy = result[0].boxes.xyxy
coordinates = np.array(arrxy)
coordinates[:25]
array([[ 2082, 1426, 2318, 1635],
[ 2356, 1106, 2678, 1321],
[ 1927, 2442, 2284, 2799],
[ 647, 961, 865, 1149],
[ 2101, 1644, 2323, 1841],
[ 1565, 2472, 1913, 2822],
[ 2334, 1420, 2567, 1640],
[ 1094, 957, 1301, 1138],
[ 967, 3186, 1243, 3369],
[ 873, 956, 1087, 1148],
[ 739, 3466, 993, 3662],
[ 1318, 968, 1512, 1138],
[ 1528, 2948, 1782, 3159],
[ 2109, 1914, 2497, 2128],
[ 1329, 2482, 1559, 2662],
[ 1264, 2947, 1522, 3156],
[ 466, 3655, 738, 3847],
[ 2139, 2135, 2527, 2330],
[ 1526, 1129, 1764, 1330],
[ 468, 3469, 731, 3646],
[ 691, 2976, 963, 3181],
[ 2233, 384, 2458, 582],
[ 1256, 3161, 1509, 3370],
[ 426, 0, 913, 154],
[ 975, 2964, 1252, 3179]], dtype=float32)
由于图像的 (0,0) 坐标代表图像的左上角,因此需要相应地对它们进行排序。为此,计算 x 和 y 坐标的中点,并根据 y 坐标进行排序。
arrxy = result[0].boxes.xyxy
coordinates = np.array(arrxy)
x_coords = (coordinates[:, 0] + coordinates[:, 2]) / 2
y_coords = (coordinates[:, 1] + coordinates[:, 3]) / 2
midpoints = np.column_stack((x_coords, y_coords))
rounded_n_sorted_arr = np.round(midpoints[midpoints[:, 1].argsort()]).astype(int)
print(rounded_n_sorted_arr[:25])
[[2762 63]
[2463 66]
[1998 68]
[ 670 77]
[ 241 80]
[1547 370]
[ 978 378]
[1370 399]
[2088 416]
[2102 476]
[ 916 478]
[2346 483]
[ 363 504]
[2774 514]
[1842 527]
[1392 542]
[2559 544]
[1178 552]
[1599 554]
[ 652 579]
[ 916 662]
[3009 677]
[2778 678]
[2122 684]
[2364 688]]
使用 OpenCV 可视化数据
使用 OpenCV 分析物品之间的坐标关系,并通过将物品中心的坐标用深蓝色圆圈标出,直观地理解这些关系。从排序后的坐标可以看出,y 轴的显著增加表明物品已经跳到了货架上。
例如,y 轴从 80 增加到 370,变化非常大,因此可以断定这是一个货架。
计数物品和货架
从分析中可以看出,如果 y 轴的增加超过某个值,就可以认为是一个货架。在编写代码时,有机会在不同图像上进行测试,可以说如果图像捕捉到了所有货架并且是平行拍摄的,那么 y 轴的跳跃值超过 130 就可以认为是一个货架。
count = 1
objects = 0
group_sizes = []
for i in range(1, len(rounded_n_sorted_arr)):
if rounded_n_sorted_arr[i][1] - rounded_n_sorted_arr[i - 1][1] > 130:
group_sizes.append(objects + 1)
count += 1
objects = 0
else:
objects += 1
group_sizes.append(objects + 1)
for i, size in enumerate(group_sizes):
print(f"第 {i + 1} 层货架上有 {size} 件商品")
转换为参数化程序
将所有代码整合在一起。使用 argparse
库,可以通过以下代码提供输入文件路径。作为输出,将打印出每个货架上有多少件商品的数据。
from typing import List
import numpy as np
from ultralytics import YOLO
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--image_path', type=str, required=True, help='图像路径')
args = parser.parse_args()
image_path = args.image_path
class ShelfDetector:
def __init__(self, model_path: str, confidence: float = 0.45):
self.model = YOLO(model_path)
self.confidence = confidence
def detect_shelves(self, image_path: str) -> List[int]:
result = self.model.predict(source=image_path, conf=self.confidence, save=False)
arrxy = result[0].boxes.xyxy
coordinates = np.array(arrxy)
x_coords = (coordinates[:, 0] + coordinates[:, 2]) / 2
y_coords = (coordinates[:, 1] + coordinates[:, 3]) / 2
midpoints = np.column_stack((x_coords, y_coords))
sorted_midpoints = midpoints[midpoints[:, 1].argsort()]
rounded_n_sorted_arr = np.round(sorted_midpoints).astype(int)
group_sizes = []
objects = 0
for i in range(1, len(rounded_n_sorted_arr)):
if rounded_n_sorted_arr[i][1] - rounded_n_sorted_arr[i - 1][1] > 130:
group_sizes.append(objects + 1)
objects = 0
else:
objects += 1
group_sizes.append(objects + 1)
return group_sizes
detector = ShelfDetector('best.pt')
result = detector.detect_shelves(image_path)
for i, size in enumerate(result):
print(f"第 {i + 1} 层货架上有 {size} 件商品")
在我们的 CMD 提示符中,输出将如下所示。通过这种方式,我们可以提供不同图像的路径来运行代码。