当前位置: 首页 > article >正文

PyTorch模型转换ONNX 入门

目录

前言

什么是ONNX文件

ONNX文件简介

ONNX 文件的主要特点

ONNX 文件的基本结构

pytorch模型转ONNX模型

前期准备

cuda安装

pytorch安装

ONNX模块安装

pytorch模型导出ONNX模型 

简单体验

查看ONNX文件

 torch.onnx.export函数

 加载ONNX模型


前言

什么是ONNX文件

ONNX文件简介

ONNX(Open Neural Network Exchange)是一种开放的文件格式,用于表示机器学习模型,旨在促进不同框架之间的互操作性。

ONNX 文件通常以 .onnx 为扩展名,能够存储神经网络的结构和权重,使得模型可以在不同的深度学习框架(如 TensorFlow、PyTorch、Caffe 等)之间进行转换和部署

ONNX 文件的主要特点

  1. 跨平台兼容性:ONNX 支持在多个框架之间共享模型,用户可以在一个框架中训练模型,然后将其导出为 ONNX 格式,以便在另一个框架中进行推理。

  2. 开放标准:ONNX 是一个开放的标准,由多个行业合作伙伴共同开发和维护。它为机器学习社区提供了一个统一的格式。

  3. 高效性:ONNX 文件能够有效地存储模型的计算图、参数和操作,这样可以更高效地进行推理。

  4. 支持多种操作:ONNX 定义了一组标准操作符,支持多种神经网络架构,包括卷积神经网络(CNN)、循环神经网络(RNN)等。

  5. 工具支持:ONNX 提供了一系列工具和库,支持将模型从不同框架导出到 ONNX 格式,也支持从 ONNX 文件加载模型进行推理。

ONNX 文件的基本结构

一个典型的 ONNX 文件包含以下内容:

  • 计算图:描述了模型的结构,包括各层的连接关系。
  • 参数:模型的权重和偏置值。
  • 元数据:关于模型的一些附加信息,如输入输出的形状、数据类型等。

pytorch模型转ONNX模型

前期准备

用户需事先安装cuda、cudnn(可选)和pytorch

cuda安装

windows下cuda的安装见

windows安装cuda与cudnn-CSDN博客

linux下的cuda安装见

【CUDA】Ubuntu系统如何安装CUDA保姆级教程(2022年最新)_ubuntu安装cuda-CSDN博客

无论在哪个系统上安装cuda,只要输入以下命令时有信息输出即表示安装成功

nvcc -V

pytorch安装

当cuda安装成功后,输入nvcc -V命令查看cuda版本号,然后进入pytorch官网,下载对应cuda版本的pytorch即可

Previous PyTorch Versions | PyTorch

无论什么系统,只要在命令行输出以下结果,即表示pytorch安装成功

ONNX模块安装

pip install onnx
pip install onnxruntime

pytorch模型导出ONNX模型 

简单体验

首先使用pytorch写一个简单的网络模型

import torch
import torchvision
import numpy as np
 
devide=torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义一个简单的PyTorch 模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(64 * 8 * 8, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return x
 
# 创建模型实例
model = MyModel().to(devide)
 
# 指定模型输入尺寸
dummy_input = torch.randn(1, 3, 32, 32).to(devide)
 
# 将PyTorch模型转为ONNX模型
torch.onnx.export(model, dummy_input, 'mymodel.onnx',  do_constant_folding=False)

如上所示,我们手写了一个简单的网络模型,可以看到,转ONNX模型文件的代码只有最后一行

torch.onnx.export(model, dummy_input, 'mymodel.onnx',  do_constant_folding=False)

事实上,torch.onnx.export函数就是将一个torch模型转换为ONNX文件的函数

查看ONNX文件

上述代码成功运行后,会在本地生成一个mymodel.onnx文件,该文件的打开需要使用netron,有关netron的安装见

netron安装(windows && linux)-CSDN博客

安装成功后,使用netron打开mymodel.onnx,如下所示

 torch.onnx.export函数

上述我们写了一个小demo体验了torch模型转换为ONNX文件,并查看了ONNX文件到底是什么,接下来我们来看torch模型转换ONNX文件的核心函数中参数含义都是什么

torch.onnx.export(
            model, 
            args, 
            f, 
            export_params=True, 
            opset_version=10, 
            do_constant_folding=True, 
            input_names=['input'], 
            output_names=['output'], 
            dynamic_axes=None, 
            verbose=False, 
            example_outputs=None, 
            keep_initializers_as_inputs=None)

参数详解

  1. model:

    • 类型: torch.nn.Module
    • 描述: 被转换的 PyTorch 模型
  2. args:

    • 类型: tuple 或 torch.Tensor
    • 描述: torch模型的输入示例,可以是一个单一的张量或多个张量(以元组的形式)。这些输入数据用于执行模型,确定模型的输入形状。
  3. f:

    • 类型: str 或 Path
    • 描述: 导出模型的目标文件路径或文件名,通常以 .onnx 作为扩展名
  4. export_params:

    • 类型: bool,默认: True
    • 描述: 是否将模型的参数(权重和偏置)也导出到 ONNX 文件中。如果设置为 True,导出的模型会包含所有的参数。
  5. opset_version:

    • 类型: int,默认: 9
    • 描述: 指定要使用的 ONNX 操作集版本。不同版本可能支持不同的操作和功能。设置合适的版本可以确保兼容性。
  6. do_constant_folding:

    • 类型: bool,默认: True
    • 描述: 是否进行常量折叠优化。常量折叠会在导出过程中将一些常量计算提前,从而简化模型的计算图,提升推理效率。
  7. input_names:

    • 类型: list,默认: None
    • 描述: 输入张量的名称列表。可以用来指定导出模型输入的名称,便于后续在其他框架中识别。
  8. output_names:

    • 类型: list,默认: None
    • 描述: 输出张量的名称列表。类似于 input_names,用于指定导出模型输出的名称。
  9. dynamic_axes:

    • 类型: dict 或 None,默认: None
    • 描述: 允许动态维度的输入输出。在导出时,可以指定某些维度是动态的,这样在推理时输入的形状可以变化。例如:
      dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
      这表示 input 和 output 的第一个维度是动态的(例如 batch size)。
  10. verbose:

    • 类型: bool,默认: False
    • 描述: 是否在导出时打印详细信息。如果设置为 True,会显示更多的调试信息,便于跟踪导出过程中的问题。
  11. example_outputs:

    • 类型: tuple 或 torch.Tensor,默认: None
    • 描述: 用于指定模型的示例输出,这有助于 ONNX 在导出时进行类型推断。可以提供一个或多个输出张量,以便更好地推断输出的形状和类型。
  12. keep_initializers_as_inputs:

    • 类型: bool,默认: None
    • 描述: 是否将模型的初始值(权重)作为输入保存。如果设置为 True,则初始值将被视为模型的输入之一,而不是存储在模型的参数中。

下面是一个常用的模板

import torch.onnx 
 
# 转为ONNX
def Convert_ONNX(model): 
 
    # 设置模型为推理模式
    model.eval() 
 
    # 设置模型输入的尺寸
    dummy_input = torch.randn(1, input_size, requires_grad=True)  
 
    # 导出ONNX模型  
    torch.onnx.export(model,         # model being run 
         dummy_input,       # model input (or a tuple for multiple inputs) 
         "xxx.onnx",       # where to save the model  
         export_params=True,  # store the trained parameter weights inside the model file 
         opset_version=10,    # the ONNX version to export the model to 
         do_constant_folding=True,  # whether to execute constant folding for optimization 
         input_names = ['modelInput'],   # the model's input names 
         output_names = ['modelOutput'], # the model's output names 
         dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 
                                'modelOutput' : {0 : 'batch_size'}}) 
    print(" ") 
    print('Model has been converted to ONNX')
 
 
if __name__ == "__main__": 
 
    # 构建模型并训练
    # xxxxxxxxxxxx
 
    # 测试模型精度
    #testAccuracy() 
 
    # 加载模型结构与权重
    model = Network() 
    path = "myFirstModel.pth" 
    model.load_state_dict(torch.load(path)) 
 
    # 转换为ONNX 
    Convert_ONNX(model)

 加载ONNX模型

导出ONNX模型后,加载ONNX模型需要用到onnxruntime库,以下是一个导出ONNX模型的示例

import onnxruntime as ort
 
# 加载 ONNX 模型
ort_session = ort.InferenceSession("model.onnx")
 
# 准备输入信息
input_info = ort_session.get_inputs()[0]
input_name = input_info.name
input_shape = input_info.shape
input_type = input_info.type
 
 
# 运行ONNX模型
outputs = ort_session.run(input_name, input_data)
 
# 获取输出信息
output_info = ort_session.get_outputs()[0]
output_name = output_info.name
output_shape = output_info.shape
output_data = outputs[0]
 
print("outputs:", outputs)
print("output_info :", output_info )
print("output_name :", output_name )
print("output_shape :", output_shape )
print("output_data :", output_data )

在以下案例中,我们首先将resnet-18模型导出为ONNX模型,然后再加载导出的ONNX模型,最后对比torch模型和ONNX模型的输出差异

import torch
import torchvision.models as models
import onnx
import onnxruntime
 
# 加载 PyTorch 模型
model = models.resnet18(pretrained=True)
model.eval()
 
# 定义输入和输出张量的名称和形状
input_names = ["input"]
output_names = ["output"]
batch_size = 1
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
 
# 将 PyTorch 模型转换为 ONNX 格式
torch.onnx.export(
    model,  # 要转换的 PyTorch 模型
    torch.randn(input_shape),  # 模型输入的随机张量
    "resnet18.onnx",  # 保存的 ONNX 模型的文件名
    input_names=input_names,  # 输入张量的名称
    output_names=output_names,  # 输出张量的名称
    dynamic_axes={input_names[0]: {0: "batch_size"}, output_names[0]: {0: "batch_size"}}  # 动态轴,即输入和输出张量可以具有不同的批次大小
)
 
# 加载 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")
onnx_model_graph = onnx_model.graph
onnx_session = onnxruntime.InferenceSession(onnx_model.SerializeToString())
 
# 使用随机张量测试 ONNX 模型
x = torch.randn(input_shape).numpy()
onnx_output = onnx_session.run(output_names, {input_names[0]: x})[0]
 
print(f"PyTorch output: {model(torch.from_numpy(x)).detach().numpy()[0, :5]}")
print(f"ONNX output: {onnx_output[0, :5]}")

运行结果如下所示

PyTorch output: [0.22972351 2.4930785  2.4462368  2.7443404  4.7080407 ]
ONNX output: [0.22972152 2.4930775  2.4462373  2.7443395  4.708042  ]


http://www.kler.cn/news/360982.html

相关文章:

  • 24下河南秋季教资认定保姆级教程
  • 【YOLO系列】YOLO11原理和深入解析——待完善
  • 《深度学习》Dlib 人脸应用实例 性别年龄预测 案例实现
  • 传输层协议UDP详解
  • 【OpenGauss源码学习 —— (VecSortAgg)】
  • 集合分类及打印的方式
  • SDUT数据结构与算法第四次机测
  • Prometheus 告警
  • MySQL实现主从同步
  • 一个汉字占几个字节、JS中如何获得一个字符串占用多少字节?
  • 前端性能优化之加载篇
  • ubuntu安装boost、x264、FFMPEG
  • 前端项目中遇到的技术问题
  • 字节流写入文件
  • Java基于SSM微信小程序物流仓库管理系统设计与实现(源码+lw+数据库+讲解等)
  • Linux中安装tesserocr遇到的那些坑
  • go-zero系列-限流(并发控制)及hey压测
  • 【JAVA】第三张_Eclipse下载、安装、汉化
  • ruoyi框架配置多数据源
  • C++11——智能指针