当前位置: 首页 > article >正文

使用onnxruntime c++ API实现yolov5m视频检测

@[使用onnxruntime c++ API实现yolov5m视频检测]

本文演示了yolov5m从模型导出到onnxruntime推理的过程

一.创建容器

docker run --shm-size=32g -ti  --privileged --net=host \
    --rm \
    -v $PWD:/home -w /home ghcr.io/intel/llvm/ubuntu2204_base /bin/bash

二.安装依赖

apt install libopencv-dev -y
wget https://github.com/microsoft/onnxruntime/releases/download/v1.19.2/onnxruntime-linux-x64-1.19.2.tgz
tar -xf onnxruntime-linux-x64-1.19.2.tgz

三.生成onnx模型

rm yolov5 -rf
git clone https://github.com/ultralytics/yolov5.git
cd yolov5
pip install -r requirements.txt
wget https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m.pt
python export.py --weights yolov5m.pt --include onnx --img 640
mv yolov5m.onnx ../
cd ..

四.生成类别名

tee gen_names.py<<-'EOF'
import yaml
data=yaml.safe_load(open("yolov5/data/coco.yaml","r"))
with open("coco.names","w") as f:
    for name in list(data['names'].values()):
        f.write(f"{name}\n")
EOF
python gen_names.py

五.运行测试程序

tee yolov5_onnxruntime.cpp<<-'EOF'
#include <iostream>
#include <fstream>
#include <vector>
#include <algorithm>
#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>

using namespace std;
using namespace cv;

// NMS参数
float confThreshold = 0.25; // 置信度阈值
float nmsThreshold = 0.45;  // NMS 阈值
int inpWidth = 640;         // 网络输入宽度
int inpHeight = 640;        // 网络输入高度

// COCO 数据集类别名称
vector<string> classes;

// 加载类别名称
void loadClasses(const string& classesFile) {
    ifstream ifs(classesFile.c_str());
    string line;
    while (getline(ifs, line)) {
        classes.push_back(line);
    }
}

// 后处理,解析模型输出并进行NMS
void postprocess(const Mat& frame, const vector<vector<Mat>>& outputs);

// 自定义 NMSBoxes 函数
void NMSBoxesCustom(const vector<Rect>& boxes, const vector<float>& scores, float scoreThreshold, float nmsThreshold, vector<int>& indices);

int main(int argc, char** argv) {
    // 检查参数
    if (argc != 4) {
        cout << "用法: ./yolov5_onnxruntime <yolov5.onnx> <classes.names> <input.mp4>" << endl;
        return -1;
    }

    string model_path = argv[1];
    string classesFile = argv[2];
    string video_path = argv[3];

    // 加载类别名称
    loadClasses(classesFile);

    // 初始化 ONNX Runtime 环境
    Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "YoloV5");
    Ort::SessionOptions session_options;
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

    // 如果需要使用GPU,需要启用CUDA
    // OrtCUDAProviderOptions cuda_options;
    // session_options.AppendExecutionProvider_CUDA(cuda_options);

    // 创建会话
    printf("%s\n",model_path.c_str());
    Ort::Session session(env, model_path.c_str(), session_options);

    // 获取输入输出节点信息
    Ort::AllocatorWithDefaultOptions allocator;

    // 输入节点
    size_t num_input_nodes = session.GetInputCount();
    printf("num_input_nodes:%d\n",num_input_nodes);
    vector<const char*> input_node_names(num_input_nodes);
    vector<int64_t> input_node_dims;
    for (int i = 0; i < num_input_nodes; i++) {
        // 获取输入节点名
        Ort::AllocatedStringPtr input_name = session.GetInputNameAllocated(i, allocator);
        printf("input_name:%s\n",input_name.get());
        input_node_names[i] = strdup(input_name.get());

        // 获取输入节点维度
        Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
        input_node_dims = tensor_info.GetShape();
    }

    // 输出节点
    size_t num_output_nodes = session.GetOutputCount();
    printf("num_output_nodes:%d\n",num_output_nodes);
    vector<const char*> output_node_names(num_output_nodes);
    for (int i = 0; i < num_output_nodes; i++) {
        Ort::AllocatedStringPtr output_name = session.GetOutputNameAllocated(i, allocator);
        printf("output_name:%s\n",output_name.get());
        output_node_names[i] = strdup(output_name.get());
    }
    
    // 打开视频文件
    VideoCapture cap(video_path);
    if (!cap.isOpened()) {
        cerr << "无法打开视频文件!" << endl;
        return -1;
    }

    Mat frame;
    while (cap.read(frame)) {
        // 图像预处理
        Mat img;
        resize(frame, img, Size(inpWidth, inpHeight));
        cvtColor(img, img, COLOR_BGR2RGB);
        img.convertTo(img, CV_32F, 1.0 / 255.0);

        // 转换为CHW格式
        vector<float> img_data;
        int channels = img.channels();
        int img_h = img.rows;
        int img_w = img.cols;
        img_data.resize(channels * img_h * img_w);

        vector<Mat> chw;
        for (int i = 0; i < channels; ++i) {
            Mat channel(img.rows, img.cols, CV_32FC1, img_data.data() + i * img_h * img_w);
            chw.push_back(channel);
        }
        split(img, chw);

        // 创建输入张量
        array<int64_t, 4> input_shape{1, channels, img_h, img_w};
        size_t input_tensor_size = img_data.size();
        Ort::Value input_tensor = Ort::Value::CreateTensor<float>(allocator.GetInfo(), img_data.data(), input_tensor_size, input_shape.data(), input_shape.size());

        // 进行推理
        auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names.data(), &input_tensor, 1, output_node_names.data(), num_output_nodes);

        // 解析输出
        vector<vector<Mat>> outputs;
        for (auto& tensor : output_tensors) {
            float* output_data = tensor.GetTensorMutableData<float>();
            auto type_info = tensor.GetTensorTypeAndShapeInfo();
            vector<int64_t> output_shape = type_info.GetShape();

            // 将输出数据转换为 Mat
            int rows = output_shape[1];
            int dimensions = output_shape[2];
            Mat output = Mat(rows, dimensions, CV_32F, output_data);

            // 将输出添加到列表
            outputs.push_back({output});
        }

        // 后处理
        postprocess(frame, outputs);

        // 显示结果
        imwrite("Detection.jpg", frame);
        break;
    }

    cap.release();
    destroyAllWindows();
    return 0;
}

void postprocess(const Mat& frame, const vector<vector<Mat>>& outputs) {
    // 存储检测结果
    vector<int> classIds;
    vector<float> confidences;
    vector<Rect> boxes;

    int img_w = frame.cols;
    int img_h = frame.rows;

    float x_factor = img_w / (float)inpWidth;
    float y_factor = img_h / (float)inpHeight;

    // 遍历检测结果
    for (size_t i = 0; i < outputs.size(); ++i) {
        Mat detections = outputs[i][0];
        int rows = detections.rows;

        for (int r = 0; r < rows; ++r) {
            float confidence = detections.at<float>(r, 4);

            if (confidence >= confThreshold) {
                Mat scores = detections.row(r).colRange(5, detections.cols);
                Point classIdPoint;
                double maxClassScore;
                minMaxLoc(scores, 0, &maxClassScore, 0, &classIdPoint);

                if (maxClassScore >= confThreshold) {
                    // 解析坐标
                    float cx = detections.at<float>(r, 0);
                    float cy = detections.at<float>(r, 1);
                    float w = detections.at<float>(r, 2);
                    float h = detections.at<float>(r, 3);

                    int left = int((cx - 0.5 * w) * x_factor);
                    int top = int((cy - 0.5 * h) * y_factor);
                    int width = int(w * x_factor);
                    int height = int(h * y_factor);

                    classIds.push_back(classIdPoint.x);
                    confidences.push_back(confidence);
                    boxes.push_back(Rect(left, top, width, height));
                }
            }
        }
    }

    // 执行自定义非极大值抑制
    vector<int> indices;
    NMSBoxesCustom(boxes, confidences, confThreshold, nmsThreshold, indices);

    // 绘制检测框
    for (size_t i = 0; i < indices.size(); ++i) {
        int idx = indices[i];
        Rect box = boxes[idx];

        // 绘制边界框
        rectangle(frame, box, Scalar(0, 255, 0), 2);

        // 显示类别名称和置信度
        string label = format("%.2f", confidences[idx]);
        if (!classes.empty()) {
            CV_Assert(classIds[idx] < (int)classes.size());
            label = classes[classIds[idx]] + ":" + label;
        }
        printf("%02d %04d,%04d,%04d,%04d\n",classIds[idx],box.x,box.y,box.width,box.height);
        int baseLine;
        Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
        int top = max(box.y, labelSize.height);
        rectangle(frame, Point(box.x, top - labelSize.height),
                  Point(box.x + labelSize.width, top + baseLine),
                  Scalar::all(255), FILLED);
        putText(frame, label, Point(box.x, top),
                FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0,0,0), 1);
    }
}

// 自定义 NMSBoxes 函数实现
void NMSBoxesCustom(const vector<Rect>& boxes, const vector<float>& scores, float scoreThreshold, float nmsThreshold, vector<int>& indices) {
    // 创建一个向量,包含每个框的索引
    vector<int> idxs;
    for (size_t i = 0; i < scores.size(); ++i) {
        if (scores[i] >= scoreThreshold) {
            idxs.push_back(i);
        }
    }

    // 如果没有满足条件的框,返回空的索引
    if (idxs.empty()) {
        return;
    }

    // 根据置信度分数对索引进行排序(从高到低)
    sort(idxs.begin(), idxs.end(), [&scores](int i1, int i2) {
        return scores[i1] > scores[i2];
    });

    vector<bool> suppressed(idxs.size(), false);

    // 进行 NMS 处理
    for (size_t i = 0; i < idxs.size(); ++i) {
        if (suppressed[i]) {
            continue;
        }
        int idx_i = idxs[i];
        indices.push_back(idx_i);
        Rect box_i = boxes[idx_i];

        for (size_t j = i + 1; j < idxs.size(); ++j) {
            if (suppressed[j]) {
                continue;
            }
            int idx_j = idxs[j];
            Rect box_j = boxes[idx_j];

            // 计算 IoU(交并比)
            float iou = (box_i & box_j).area() / float((box_i | box_j).area());

            // 如果 IoU 大于阈值,抑制当前框
            if (iou > nmsThreshold) {
                suppressed[j] = true;
            }
        }
    }
}
EOF
g++ yolov5_onnxruntime.cpp -o yolov5_onnxruntime  `pkg-config --cflags --libs opencv4` \
    -I onnxruntime-linux-x64-1.19.2/include -L onnxruntime-linux-x64-1.19.2/lib -lonnxruntime \
    -Wl,-rpath onnxruntime-linux-x64-1.19.2/lib
./yolov5_onnxruntime yolov5m.onnx coco.names input.mp4    

http://www.kler.cn/a/384244.html

相关文章:

  • 道可云人工智能元宇宙每日资讯|2024国际虚拟现实创新大会将在青岛举办
  • 信号-2-信号捕捉
  • 使用Python简单实现客户端界面
  • 汽修行业员工培训SOP的智能化搭建
  • xrc的比赛
  • SDL基本使用
  • 进入半导体行业需要具备哪些能力?
  • Scala的List
  • 计算机体系结构知识(一)
  • 前端零基础学习Day-Six
  • MySQL 导出数据
  • 鸿蒙多线程开发——并发模型对比(Actor与内存共享)
  • qt QTextDocument详解
  • 56合并区间 go解题
  • 【经验分享】六西格玛管理培训适合哪些人参加?
  • docker 拉取MySQL8.0镜像以及安装
  • C#笔记(4)
  • 带点符号的 TypeScript 实用程序类型 NestedKeyOf 在严格模式下失败
  • 卷积神经网络——paddle部分
  • 初阶数据结构【单链表及其接口的实现】
  • 分数阶傅里叶变换与信息熵怎么用于信号处理?
  • 基于SpringBoot+Vue+HTML的美食食谱系统的设计与实现
  • Spark程序的监控
  • Python配合Flask搭建简单的个人博客案例demo
  • 【react】Redux基础用法
  • 【Linux】进程控制——创建,终止,等待回收