百度飞浆目标检测PPYOLOE模型在PC端、Jetson上的部署(python)
部署目标检测模型前,需要配置好paddlepaddle的环境:
开始使用_飞桨-源于产业实践的开源深度学习平台 (paddlepaddle.org.cn)
PC端和Jetson板卡端的部署方法相同,如下(直接放置部署和测试代码):
import paddle.inference
import cv2
import numpy as np
import time
from paddle.inference import Config, PrecisionType, create_predictor
import os
import json
import argparse
from PIL import Image
import yaml
import glob
def config_init(model_dir):
model_file = glob.glob(model_dir + "/*.pdmodel")[0]
params_file = glob.glob(model_dir + "/*.pdiparams")[0]
config = Config()
config.set_prog_file(model_file)
config.set_params_file(params_file)
config.switch_ir_optim() # 开启IR优化
config.enable_memory_optim() # 启用内存优化
config.enable_use_gpu(500, 0) # 500是内存大小,0是GPU编号
return create_predictor(config) # 初始化一个预测器对象
class PPYoloeModel:
def __init__(self, model_dir):
self.yaml_file = glob.glob(model_dir + "/*.yml")[0]
self.predictor = config_init(model_dir) # 初始化一个预测器对象
self.mean = [0.0, 0.0, 0.0]
self.std = [1.0, 1.0, 1.0]
self.img_size = 640
self.threshold = 0.6 # 置信度阈值
def read_yaml(self):
with open(self.yaml_file, 'r', encoding='utf-8') as infer_cfg: # 使用with语句确保文件正确关闭
yaml_reader = yaml.safe_load(infer_cfg) # 直接从文件对象中读取
label_list = yaml_reader['label_list'] # 获取分类
# print(label_list)
return label_list
def resize(self, img):
if not isinstance(img, np.ndarray):
raise TypeError('image type is not numpy.')
im_shape = img.shape
im_scale_x = float(self.img_size) / float(im_shape[1]) # 可以保持宽高比
im_scale_y = float(self.img_size) / float(im_shape[0])
img = cv2.resize(img, None, None, fx=im_scale_x, fy=im_scale_y) # fx和fy是缩放因子(倍数)
return img
def normalize(self, img):
img = img / 255.0 # 归一化到[0,1]
mean = np.array(self.mean)[np.newaxis, np.newaxis, :] # np.newaxis增加新的轴(维度)
std = np.array(self.std)[np.newaxis, np.newaxis, :]
img -= mean
img /= std
return img
def pre_process(self, img):
img = self.resize(img)
img = img[:, :, ::-1].astype('float32') # bgr -> rgb
img = self.normalize(img)
img = img.transpose((2, 0, 1)) # hwc -> chw
return img[np.newaxis, :]
def run_predict(self, img):
input_names = self.predictor.get_input_names()
for i, name in enumerate(input_names):
input_tensor = self.predictor.get_input_handle(name) # 获取张量句柄
input_tensor.reshape(img[i].shape) # 重塑输入张量
input_tensor.copy_from_cpu(img[i].copy()) # 从CPU复制图像数据到输入张量
self.predictor.run()
results = []
output_names = self.predictor.get_output_names()
for i, name in enumerate(output_names):
output_tensor = self.predictor.get_output_handle(name)
output_data = output_tensor.copy_to_cpu() # 获取检测结果
results.append(output_data)
return results
def get_result(self, img):
result = []
scale_factor = (
np.array([self.img_size * 1.0 / img.shape[0], self.img_size * 1.0 / img.shape[1]])
.reshape((1, 2))
.astype(np.float32)
)
img = self.pre_process(img)
results = self.run_predict([img, scale_factor])
for res in results[0]:
score = res[1]
if score > self.threshold:
result.append(res) # 只保留置信度大于阈值的目标(同时可能有多个目标)
return result
# result是一个列表,里面可能存有一个或多个目标的信息
# 对于单个目标,第一个值是类别的id,第二个值是置信度,后四个值分别是xmin, ymin, xmax, ymax
def draw_bbox_cv(self, img): # 对于实时摄像头图片的后处理
label_list = self.read_yaml()
result = self.get_result(img)
for res in result:
cat_id, score, bbox = res[0], res[1], res[2:] # cat_id表示类别的id;score是得分;bbox是列表,含有xmin, ymin, xmax, ymax四个元素
bbox = [int(i) for i in bbox] # 对坐标取整
xmin, ymin, xmax, ymax = bbox
cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 255), 2) # 2表示线宽
print('category id is {}, bbox is {}'.format(cat_id, bbox))
print(f"中心点:{(xmin + xmax) / 2},{(ymin + ymax) / 2}")
try:
label_id = label_list[int(cat_id)] # 通过类别id查找对应类别名称
# 在图像上打印类别 FONT_HERSHEY_SIMPLEX为字体类型
cv2.putText(img, label_id, (int(xmin), int(ymin + 15)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0),
2)
# 在图像上打印得分
cv2.putText(img, str(round(score, 2)), (int(xmin - 35), int(ymin + 15)), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
(0, 255, 0), 2)
print(round(score, 2))
except KeyError:
pass
# 调用示例
if __name__ == "__main__":
ppyoloe_model = PPYoloeModel("这里输入模型所在文件夹的路径") # 初始化一个ppyoloe目标检测模型对象
cap = cv2.VideoCapture(0) # 初始化摄像头
while True:
ret, image = cap.read() # 获取图像
if not ret:
print("无法正确读取图像!")
break
ppyoloe_model.draw_bbox_cv(image)
cv2.imshow("frame", image)
if cv2.waitKey(1) & 0xFF == ord('q'):
break