当前位置: 首页 > article >正文

C++:Opencv读取ONNX模型,通俗易懂

1. 准备 ONNX 模型

假设你已经有一个训练好的 ONNX 模型文件。可以从各类深度学习框架(如 PyTorch、TensorFlow)中导出 ONNX 模型。例如,下面是一个简单的 PyTorch 模型导出为 ONNX 文件的示例:

import torch
import torchvision.models as models

# Load a pre-trained model (e.g., ResNet18)
model = models.resnet18(pretrained=True)
model.eval()

# Dummy input for tracing
dummy_input = torch.randn(1, 3, 224, 224)

# Export the model to ONNX format
torch.onnx.export(model, dummy_input, "resnet18.onnx")

2. 读取 ONNX 模型

在 OpenCV 中,你可以使用 cv::dnn::readNetFromONNX 函数加载 ONNX 模型。

#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>

int main() {
    // 模型文件路径
    std::string modelFile = "resnet18.onnx";

    // 从 ONNX 文件中读取模型
    cv::dnn::Net net = cv::dnn::readNetFromONNX(modelFile);

    // 检查模型是否成功加载
    if (net.empty()) {
        std::cerr << "Failed to load network!" << std::endl;
        return -1;
    }

    return 0;
}

3. 预处理输入图像

在进行推理之前,需要将输入图像预处理成模型所需的格式。通常,这包括调整图像大小、归一化等。

#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>

int main() {
    // 模型文件路径
    std::string modelFile = "resnet18.onnx";
    cv::dnn::Net net = cv::dnn::readNetFromONNX(modelFile);

    if (net.empty()) {
        std::cerr << "Failed to load network!" << std::endl;
        return -1;
    }

    // 读取输入图像
    cv::Mat img = cv::imread("image.jpg");

    if (img.empty()) {
        std::cerr << "Failed to read image!" << std::endl;
        return -1;
    }

    // 将图像调整为模型所需的大小和格式
    cv::Mat blob = cv::dnn::blobFromImage(img, 1.0, cv::Size(224, 224), cv::Scalar(104.0, 117.0, 123.0), true, false);

    // 设置网络的输入
    net.setInput(blob);

    // 执行前向传播以获得输出
    cv::Mat output = net.forward();

    // 输出处理
    std::cout << "Output size: " << output.size << std::endl;

    return 0;
}

4. 进行推理

在前面的代码中,已经包含了执行推理的步骤。net.forward() 函数会返回模型的输出结果。

5. 处理和显示结果

通常,推理结果需要根据模型的输出进行处理。例如,如果是图像分类模型,你可能需要将输出的向量映射到类别标签

#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>

int main() {
    // 模型文件路径
    std::string modelFile = "resnet18.onnx";
    cv::dnn::Net net = cv::dnn::readNetFromONNX(modelFile);

    if (net.empty()) {
        std::cerr << "Failed to load network!" << std::endl;
        return -1;
    }

    // 读取输入图像
    cv::Mat img = cv::imread("image.jpg");

    if (img.empty()) {
        std::cerr << "Failed to read image!" << std::endl;
        return -1;
    }

    // 将图像调整为模型所需的大小和格式
    cv::Mat blob = cv::dnn::blobFromImage(img, 1.0, cv::Size(224, 224), cv::Scalar(104.0, 117.0, 123.0), true, false);

    // 设置网络的输入
    net.setInput(blob);

    // 执行前向传播以获得输出
    cv::Mat output = net.forward();

    // 处理输出
    cv::Point classId;
    double confidence;
    cv::minMaxLoc(output, 0, &confidence, 0, &classId);
    
    std::cout << "Predicted class ID: " << classId.x << std::endl;
    std::cout << "Confidence: " << confidence << std::endl;

    return 0;
}


http://www.kler.cn/a/290511.html

相关文章:

  • linux centos挂载未分配的磁盘空间
  • 【生物信息】如何使用 h5py 读取 HDF5 格式文件中的数据并将其转换为 NumPy 数组
  • Ubuntu中使用miniconda安装R和R包devtools
  • [python3]Excel解析库-xlwt
  • 线性代数考研笔记
  • 【Web】0基础学Web—节点操作、发表神评妙论、事件添加和移除、事件冒泡和事件捕获
  • jmeter响应断言、json断言、断言持续时间、大小断言操作
  • 暴力破解和撞库攻击有什么区别,怎么防御暴力破解和撞库攻击
  • FPGA进阶教程16 同一块FPGA的两个网口实现arp自通信
  • Opencv中的直方图(4)局部直方图均衡技术函数createCLAHE()的使用
  • windows修改升级时间
  • 九、安装artifactory并配置PostgreSQL--失败了
  • 如何通过本地服务器来测试环信的回调功能
  • powershell自动提交git脚本
  • Android 13 aosp 恢复出厂设置流程
  • 快消品渠道开发方案,让你拥有源源不断的批发客户!
  • 2.门锁_STM32_舵机设备实现
  • oracle 定时任务dbms_job 增删改查
  • slice
  • 一篇文章深入了解Oracle常用命令
  • xxe漏洞
  • 小型集群分析
  • 【IPV6从入门到起飞】3-域名解析动态IPV6(阿里云)
  • 学习大数据DAY49 考后练习题
  • python学习8:dict字典的定义,操作和方法,跟json有什么区别?
  • 通过查找真实IP bypass WAF