当前位置: 首页 > article >正文

百度飞浆目标检测PPYOLOE模型在PC端、Jetson上的部署(python)

部署目标检测模型前,需要配置好paddlepaddle的环境:
开始使用_飞桨-源于产业实践的开源深度学习平台 (paddlepaddle.org.cn)

PC端和Jetson板卡端的部署方法相同,如下(直接放置部署和测试代码):

import paddle.inference
import cv2
import numpy as np
import time
from paddle.inference import Config, PrecisionType, create_predictor
import os
import json
import argparse
from PIL import Image
import yaml
import glob

def config_init(model_dir):
    model_file = glob.glob(model_dir + "/*.pdmodel")[0]
    params_file = glob.glob(model_dir + "/*.pdiparams")[0]

    config = Config()
    config.set_prog_file(model_file)
    config.set_params_file(params_file)
    config.switch_ir_optim()  # 开启IR优化
    config.enable_memory_optim()  # 启用内存优化
    config.enable_use_gpu(500, 0)  # 500是内存大小,0是GPU编号
    return create_predictor(config)  # 初始化一个预测器对象

class PPYoloeModel:
    def __init__(self, model_dir):
        self.yaml_file = glob.glob(model_dir + "/*.yml")[0]
        self.predictor = config_init(model_dir)  # 初始化一个预测器对象

        self.mean = [0.0, 0.0, 0.0]
        self.std = [1.0, 1.0, 1.0]

        self.img_size = 640
        self.threshold = 0.6  # 置信度阈值

    def read_yaml(self):
        with open(self.yaml_file, 'r', encoding='utf-8') as infer_cfg:  # 使用with语句确保文件正确关闭
            yaml_reader = yaml.safe_load(infer_cfg)  # 直接从文件对象中读取
            label_list = yaml_reader['label_list']  # 获取分类
        # print(label_list)
        return label_list

    def resize(self, img):
        if not isinstance(img, np.ndarray):
            raise TypeError('image type is not numpy.')
        im_shape = img.shape 
        im_scale_x = float(self.img_size) / float(im_shape[1])  # 可以保持宽高比
        im_scale_y = float(self.img_size) / float(im_shape[0])
        img = cv2.resize(img, None, None, fx=im_scale_x, fy=im_scale_y)  # fx和fy是缩放因子(倍数)
        return img

    def normalize(self, img):
        img = img / 255.0  # 归一化到[0,1]
        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]  # np.newaxis增加新的轴(维度)
        std = np.array(self.std)[np.newaxis, np.newaxis, :]
        img -= mean  
        img /= std 
        return img

    def pre_process(self, img):
        img = self.resize(img)
        img = img[:, :, ::-1].astype('float32')  # bgr -> rgb
        img = self.normalize(img)
        img = img.transpose((2, 0, 1))  # hwc -> chw
        return img[np.newaxis, :]

    def run_predict(self, img):
        input_names = self.predictor.get_input_names()
        for i, name in enumerate(input_names):
            input_tensor = self.predictor.get_input_handle(name)  # 获取张量句柄
            input_tensor.reshape(img[i].shape)  # 重塑输入张量
            input_tensor.copy_from_cpu(img[i].copy())  # 从CPU复制图像数据到输入张量
        self.predictor.run()
        results = []
        output_names = self.predictor.get_output_names()
        for i, name in enumerate(output_names):
            output_tensor = self.predictor.get_output_handle(name)
            output_data = output_tensor.copy_to_cpu()  # 获取检测结果
            results.append(output_data)
        return results

    def get_result(self, img):
        result = []
        scale_factor = (
            np.array([self.img_size * 1.0 / img.shape[0], self.img_size * 1.0 / img.shape[1]])
            .reshape((1, 2))
            .astype(np.float32)
        )
        img = self.pre_process(img)
        results = self.run_predict([img, scale_factor])
        for res in results[0]:
            score = res[1]
            if score > self.threshold:
                result.append(res)  # 只保留置信度大于阈值的目标(同时可能有多个目标)
        return result
        # result是一个列表,里面可能存有一个或多个目标的信息
        # 对于单个目标,第一个值是类别的id,第二个值是置信度,后四个值分别是xmin, ymin, xmax, ymax

    def draw_bbox_cv(self, img):  # 对于实时摄像头图片的后处理
        label_list = self.read_yaml()
        result = self.get_result(img)

        for res in result:
            cat_id, score, bbox = res[0], res[1], res[2:]  # cat_id表示类别的id;score是得分;bbox是列表,含有xmin, ymin, xmax, ymax四个元素
            bbox = [int(i) for i in bbox]  # 对坐标取整
            xmin, ymin, xmax, ymax = bbox
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 255), 2)  # 2表示线宽
            print('category id is {}, bbox is {}'.format(cat_id, bbox))
            print(f"中心点:{(xmin + xmax) / 2},{(ymin + ymax) / 2}")
            try:
                label_id = label_list[int(cat_id)]  # 通过类别id查找对应类别名称
                # 在图像上打印类别   FONT_HERSHEY_SIMPLEX为字体类型
                cv2.putText(img, label_id, (int(xmin), int(ymin + 15)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0),
                            2)
                # 在图像上打印得分
                cv2.putText(img, str(round(score, 2)), (int(xmin - 35), int(ymin + 15)), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                            (0, 255, 0), 2)
                print(round(score, 2))
            except KeyError:
                pass

# 调用示例
if __name__ == "__main__":
    ppyoloe_model = PPYoloeModel("这里输入模型所在文件夹的路径")  # 初始化一个ppyoloe目标检测模型对象
    cap = cv2.VideoCapture(0)  # 初始化摄像头

    while True:
        ret, image = cap.read()  # 获取图像
        if not ret:
            print("无法正确读取图像!")
            break
        ppyoloe_model.draw_bbox_cv(image)
        cv2.imshow("frame", image)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

http://www.kler.cn/news/290678.html

相关文章:

  • React 创建和嵌套组件
  • 策略规划:在MySQL中实现数据恢复的全面指南
  • [Python图论]在用图nx.shortest_path求解最短路径时,节点之间有多条边edge,会如何处理?
  • 【MySQL】索引使用规则——(覆盖索引,单列索引,联合索引,前缀索引,SQL提示,数据分布影响,查询失效情况)
  • Proteus 仿真设计:开启电子工程创新之门
  • Unity3D中控制3D场景中游戏对象显示层级的详解
  • 构建数据恢复的硬件基础:MySQL中的硬件要求详解
  • draw.io图片保存路径如何设置
  • linux(ubuntu)安装QT-ros插件
  • Ferrari求解四次方程
  • VTK随笔十三:QT与VTK的交互
  • jupyter 笔记本中如何判定bash块是否执行完毕
  • CentOS7 yum 报错解决方案
  • FFmpeg源码:get_audio_frame_duration、av_get_audio_frame_duration2函数分析
  • Splasthop 安全远程访问帮助企业对抗 Cobalt Strike 载荷网络攻击
  • 鸿蒙(API 12 Beta6版)图形【NativeImage开发指导 (C/C++)】方舟2D图形服务
  • git---gitignore--忽略文件
  • 【C++】对比讲解构造函数和析构函数
  • 智能优化特征选择|基于鲸鱼WOA优化算法实现的特征选择研究Matlab程序(KNN分类器)
  • idea对项目中的文件操作没有权限
  • 海外合规|新加坡网络安全认证计划简介(三)-Cyber Trust
  • SpringBoot+Redis极简整合
  • springboot集成七牛云上传文件
  • Python画笔案例-030 实现打点之斜正方
  • MATLAB 中的对数计算
  • torch、torchvision、torchtext版本兼容问题
  • ubuntu 22.04安装NVIDIA驱动和CUDA
  • 传统CV算法——基于 SIFT 特征点检测与匹配实现全景图像拼接
  • Java实现根据某个字段对集合进行去重并手动选择被保留的对象
  • vuex 基础使用