当前位置: 首页 > article >正文

【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署

【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署

提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论

文章目录

  • 【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署
  • 前言
  • Windows平台搭建依赖环境
  • 模型转换--pytorch转onnx
  • ONNXRuntime推理代码
  • 总结


前言

本期将讲解深度学习图像分类网络VggNet模型的部署,对于该算法的基础知识,可以参考博主【VggNet模型算法Pytorch版本详解】博文。
读者可以通过学习 【onnx部署】部署系列学习文章目录的onnxruntime系统学习–Python篇 的内容,系统的学习OnnxRuntime部署不同任务的onnx模型。


Windows平台搭建依赖环境

在【入门基础篇】中详细的介绍了onnxruntime环境的搭建以及ONNXRuntime推理核心流程代码,不再重复赘述。


模型转换–pytorch转onnx

import torch
import torchvision as tv
def resnet2onnx():
    # 使用torch提供的预训练权重 1000分类
    model = tv.models.vgg16(pretrained=True)
    model.eval()
    model.cpu()
    dummy_input1 = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, (dummy_input1), "vgg16.onnx", verbose=True, opset_version=11)
if __name__ == "__main__":
    resnet2onnx()


如下图,torchvision本身提供了不少经典的网络,为了减少教学复杂度,这里博主直接使用了torchvision提供的ResNet网络,并下载和加载了它提供的训练权重。这里可以替换成自己的搭建的ResNet网络以及自己训练的训练权重。


ONNXRuntime推理代码

需要配置imagenet_classes.txt【百度云下载,提取码:rkz7 】文件存储1000类分类标签,假设是用户自定的分类任务,需要根据实际情况作出修改,并将其放置到工程目录下(推荐)。

这里需要将vgg16.onnx放置到工程目录下(推荐),并且将以下推理代码拷贝到新建的py文件中,并执行查看结果。

import onnxruntime as ort
import cv2
import numpy as np

# 加载标签文件获得分类标签
def read_class_names(file_path="./imagenet_classes.txt"):
    class_names = []
    try:
        with open(file_path, 'r') as fp:
            for line in fp:
                name = line.strip()
                if name:
                    class_names.append(name)
    except IOError:
        print("could not open file...")
        import sys
        sys.exit(-1)
    return class_names

# 主函数
def main():
    # 预测的目标标签数
    labels = read_class_names()

    # 测试图片
    image_path = "./lion.jpg"
    image = cv2.imread(image_path)
    # cv2.imshow("输入图", image)
    # cv2.waitKey(0)

    # 设置会话选项
    sess_options = ort.SessionOptions()
    # 0=VERBOSE, 1=INFO, 2=WARN, 3=ERROR, 4=FATAL
    sess_options.log_severity_level = 3
    # 优化器级别:基本的图优化级别
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    # 线程数:4
    sess_options.intra_op_num_threads = 4
    # 设备使用优先使用GPU而是才是CPU,列表中的顺序决定了执行提供者的优先级
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

    # onnx训练模型文件
    onnxpath = "./vgg16.onnx"

    # 加载模型并创建会话
    session = ort.InferenceSession(onnxpath, sess_options=sess_options, providers=providers)

    input_nodes_num = len(session.get_inputs())     # 输入节点输
    output_nodes_num = len(session.get_outputs())   # 输出节点数
    input_node_names = []                           # 输入节点名称
    output_node_names = []                          # 输出节点名称

    # 获取模型输入信息
    for i in range(input_nodes_num):
        # 获得输入节点的名称并存储
        input_name = session.get_inputs()[i].name
        input_node_names.append(input_name)
        # 显示输入图像的形状
        input_shape = session.get_inputs()[i].shape
        ch, input_h, input_w = input_shape[1], input_shape[2], input_shape[3]
        print(f"input format: {ch}x{input_h}x{input_w}")

    # 获取模型输出信息
    for i in range(output_nodes_num):
        # 获得输出节点的名称并存储
        output_name = session.get_outputs()[i].name
        output_node_names.append(output_name)
        # 显示输出结果的形状
        output_shape = session.get_outputs()[i].shape
        num, nc = output_shape[0], output_shape[1]
        print(f"output format: {num}x{nc}")

    input_shape = session.get_inputs()[0].shape
    input_h, input_w = input_shape[2], input_shape[3]
    print(f"input format: {input_shape[1]}x{input_h}x{input_w}")

    # 预处理输入数据
    # 默认是BGR需要转化成RGB
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # 对图像尺寸进行缩放
    blob = cv2.resize(rgb, (input_w, input_h))
    blob = blob.astype(np.float32)
    # 对图像进行标准化处理
    blob /= 255.0   # 归一化
    blob -= np.array([0.485, 0.456, 0.406])  # 减去均值
    blob /= np.array([0.229, 0.224, 0.225])  # 除以方差
    #CHW-->NCHW 维度扩展
    timg = cv2.dnn.blobFromImage(blob)
    # ---blobFromImage 可以用以下替换---
    # blob = blob.transpose(2, 0, 1)
    # blob = np.expand_dims(blob, axis=0)
    # -------------------------------

    # 模型推理
    try:
        ort_outputs = session.run(output_names=output_node_names, input_feed={input_node_names[0]: timg})
    except Exception as e:
        print(e)
        ort_outputs = None

    # 后处理推理结果
    prob = ort_outputs[0]
    max_index = np.argmax(prob)     # 获得最大值的索引
    print(f"label id: {max_index}")
    # 在测试图像上加上预测的分类标签
    label_text = labels[max_index]
    cv2.putText(image, label_text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2, 8)
    cv2.imshow("输入图像", image)
    cv2.waitKey(0)

if __name__ == '__main__':
    main()

图片预测为猎豹(cheetah),没有准确预测出狮子(lion),但是这个图片难度很大,在1000分类中预测的比较接近的。

其实图像分类网络的部署代码基本是一致的,几乎不需要修改,只需要修改传入的图片数据已经训练模型权重即可。


总结

尽可能简单、详细的讲解了Python下onnxruntime环境部署VggNet模型的过程。


http://www.kler.cn/a/303595.html

相关文章:

  • SQLAlchemy: python类的属性值为None,数据为JSON类型,插入数据库为‘ NULL‘字符串,而不是真正的NULL
  • H2数据库在单元测试中的应用
  • 矩阵求逆的几种方式
  • 3BB学习transformer日记,attention原理
  • 研究线段树的最大子段和
  • Angular 最新版本和 Vue 对比完整指南
  • 项目进度一
  • 数据库常规操作
  • vue引入三维模型
  • 【绿盟科技盟管家-注册/登录安全分析报告】
  • 2024CCPC网络预选赛
  • raksmart大带宽服务器租用
  • mycat双主高可用架构部署-MySQL5.7环境部署第一台
  • 「漏洞复现」紫光电子档案管理系统 selectFileRemote SQL注入漏洞
  • TestCraft - GPT支持的测试想法生成器和自动化测试生成器
  • 前端使用COS上传文件
  • 为什么要进行MySQL增量备份?
  • 【数据结构和算法实践-树-LeetCode112-路径总和】
  • 力扣: 四数相加II
  • Linux安装管理多版本JDK
  • CSS Clip-Path:重塑元素边界的艺术
  • mysql慢sql问题修复
  • 计算机毕业设计 自习室座位预约系统的设计与实现 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试
  • qt操作excel(QAxObject详细介绍)
  • 论文解读《LaMP: When Large Language Models Meet Personalization》
  • 浏览器插件利器--allWebPluginV2.0.0.20-alpha版发布