当前位置: 首页 > article >正文

深度学习模型格式转换:pytorch2onnx(包含自定义操作符)

       将PyTorch模型转换为ONNX(Open Neural Network Exchange)格式是实现模型跨平台部署和优化推理性能的一种常见方法。PyTorch 提供了多种方式来完成这一转换,以下是几种主要的方法: 

一、静态模型转换

使用 torch.onnx.export()

   torch.onnx.export() 是 PyTorch 官方推荐的最常用方法,适用于大多数情况。它允许你将一个 PyTorch 模型及其输入数据一起导出为 ONNX 格式。

基本用法
import torch
import torch.onnx

# 假设你有一个训练好的模型 `model` 和一个示例输入 `dummy_input`
model = ...  # 你的 PyTorch 模型
dummy_input = torch.randn(1, 3, 224, 224)  # 示例输入,形状取决于模型的输入要求

# 设置模型为评估模式
model.eval()

# 导出为 ONNX 文件
torch.onnx.export(
    model,                    # 要导出的模型
    dummy_input,              # 模型的输入张量
    "model.onnx",             # 输出文件名
    export_params=True,       # 是否导出模型参数
    opset_version=11,         # ONNX 操作集版本
    do_constant_folding=True, # 是否执行常量折叠优化
    input_names=['input'],    # 输入节点名称
    output_names=['output'],  # 输出节点名称
    dynamic_axes={'input': {0: 'batch_size'},  # 动态轴,支持可变批次大小
                  'output': {0: 'batch_size'}}
)
关键参数说明
  • model: 要导出的 PyTorch 模型。
  • dummy_input: 一个与模型输入形状匹配的张量,用于模拟实际输入。
  • export_params: 是否导出模型的参数(权重和偏置)。通常设置为 True
  • opset_version: 指定要使用的 ONNX 操作集版本。不同的版本可能支持不同的操作符。建议使用较新的版本(如 11 或 13)。
  • do_constant_folding: 是否执行常量折叠优化,可以减少模型的计算量。
  • input_names 和 output_names: 指定 ONNX 模型的输入和输出节点的名称,方便后续加载和调用。
  • dynamic_axes: 指定哪些维度是动态的(即可以在推理时变化),例如批次大小或序列长度。

二、复杂模型转换

       对于一些复杂的模型,特别是包含控制流(如条件语句、循环等)的模型,torch.onnx.export() 可能无法直接处理。这时可以先使用 torch.jit.trace() 将模型转换为 TorchScript 格式,然后再导出为 ONNX。

基本用法

 

import torch
import torch.onnx

# 假设你有一个训练好的模型 `model` 和一个示例输入 `dummy_input`
model = ...  # 你的 PyTorch 模型
dummy_input = torch.randn(1, 3, 224, 224)  # 示例输入

# 设置模型为评估模式
model.eval()

# 使用 torch.jit.trace 将模型转换为 TorchScript
traced_model = torch.jit.trace(model, dummy_input)

# 导出为 ONNX 文件
torch.onnx.export(
    traced_model,            # 已经转换为 TorchScript 的模型
    dummy_input,             # 模型的输入张量
    "traced_model.onnx",     # 输出文件名
    export_params=True,      # 是否导出模型参数
    opset_version=11,        # ONNX 操作集版本
    do_constant_folding=True,# 是否执行常量折叠优化
    input_names=['input'],   # 输入节点名称
    output_names=['output'], # 输出节点名称
    dynamic_axes={'input': {0: 'batch_size'},  # 动态轴
                  'output': {0: 'batch_size'}}
)

三、动态模型转换

使用 torch.onnx.dynamo_export()

   torch.onnx.dynamo_export() 是 PyTorch 2.0 引入的新功能,基于 PyTorch 的 Dynamo 编译器。它旨在提供更好的性能和更广泛的模型支持,尤其是对于那些包含动态控制流的模型。

基本用法
import torch

# 假设你有一个训练好的模型 `model` 和一个示例输入 `dummy_input`
model = ...  # 你的 PyTorch 模型
dummy_input = torch.randn(1, 3, 224, 224)  # 示例输入

# 设置模型为评估模式
model.eval()

# 使用 dynamo_export 导出为 ONNX 文件
torch.onnx.dynamo_export(
    model,                    # 要导出的模型
    dummy_input,              # 模型的输入张量
    "dynamo_model.onnx"       # 输出文件名
)

        注意torch.onnx.dynamo_export() 是 PyTorch 2.0 中引入的功能,确保你使用的是最新版本的 PyTorch。

四、自定义操作符模型转换

       自定义操作符(Custom Operator)是指那些不在标准 PyTorch 或 ONNX 操作集中的操作符。当你需要实现某些特定的功能或优化时,可能需要编写自定义的操作符,并将其注册到 ONNX 中以便在导出和推理时使用。

例子:实现一个自定义的 ReLU6 操作符

假设我们想要实现一个自定义的 ReLU6 操作符。ReLU6 是一种常用的激活函数,它与标准的 ReLU 类似,但有一个上限值 6。其数学表达式为:

1. 实现自定义操作符

       首先,我们需要在 C++ 中实现这个自定义操作符,并编译成一个共享库。PyTorch 提供了 torch::jit::custom_ops 接口来注册自定义操作符,而 ONNX 则提供了 onnxruntime 来注册自定义操作符。

1.1 在 PyTorch 中实现自定义操作符

       我们可以在 C++ 中实现 ReLU6 操作符,并通过 PyTorch 的 torch::jit::custom_ops 接口将其注册到 PyTorch 中:

// custom_relu6.cpp
#include <torch/script.h>
#include <torch/custom_class.h>

// 定义自定义的 ReLU6 操作符
torch::Tensor custom_relu6(const torch::Tensor& input) {
    return torch::clamp(input, 0, 6);
}

// 注册自定义操作符
static auto registry = torch::RegisterOperators("custom_ops::relu6", &custom_relu6);
1.2 编译自定义操作符

       接下来,我们需要将这个 C++ 文件编译成一个共享库(例如 .so 文件),以便在 Python 中加载:

# 使用 PyTorch 提供的工具进行编译
python -m pip install torch torchvision torchaudio
python -m torch.utils.cpp_extension.build_ext --inplace custom_relu6.cpp

这会生成一个名为 custom_relu6.so 的共享库文件;

2. 在 PyTorch 中使用自定义操作符

       现在我们可以在 Python 中加载并使用这个自定义操作符;

import torch
import torch.nn as nn
import custom_relu6  # 加载编译后的共享库

# 定义一个使用自定义 ReLU6 操作符的模型
class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv(x)
        # 调用自定义的 ReLU6 操作符
        x = torch.ops.custom_ops.relu6(x)
        return x

# 创建模型实例
model = CustomModel()
model.eval()

# 准备示例输入
dummy_input = torch.randn(1, 3, 224, 224)

# 运行模型
output = model(dummy_input)
print(output.shape)  # 输出形状应为 (1, 16, 224, 224)
3. 将自定义操作符导出为 ONNX

       为了将包含自定义操作符的模型导出为 ONNX 格式,我们需要告诉 ONNX 如何处理这个自定义操作符。我们可以使用 torch.onnx.register_custom_op_symbolic 来定义 ONNX 符号函数,从而在导出时正确处理自定义操作符。

3.1 定义 ONNX 符号函数

       我们需要定义一个符号函数,告诉 ONNX 如何表示 custom_ops::relu6 操作符。这个符号函数会生成相应的 ONNX 操作符节点。

import torch.onnx
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args

# 定义 ONNX 符号函数
@parse_args('v')
def symbolic_custom_relu6(g, input):
    # 使用 ONNX 的 Clip 操作符来实现 ReLU6
    return g.op("Clip", input, min_f=0.0, max_f=6.0)

# 注册自定义操作符的符号函数
register_custom_op_symbolic('custom_ops::relu6', symbolic_custom_relu6, 9)  # 9 表示 ONNX 操作集版本
3.2 导出为 ONNX

       现在我们可以将模型导出为 ONNX 格式,并确保自定义操作符被正确处理。

# 导出为 ONNX 文件
torch.onnx.export(
    model,                    # 要导出的模型
    dummy_input,              # 模型的输入张量
    "custom_model.onnx",      # 输出文件名
    export_params=True,       # 是否导出模型参数
    opset_version=9,          # ONNX 操作集版本
    do_constant_folding=True, # 是否执行常量折叠优化
    input_names=['input'],    # 输入节点名称
    output_names=['output'],  # 输出节点名称
    dynamic_axes={'input': {0: 'batch_size'},  # 动态轴
                  'output': {0: 'batch_size'}}
)
4. 在 ONNX Runtime 中使用自定义操作符

       为了在 ONNX Runtime 中使用自定义操作符,我们需要将自定义操作符的实现编译成一个 ONNX Runtime 扩展库,并在推理时加载该扩展库。

4.1 实现 ONNX Runtime 自定义操作符

       我们需要在 C++ 中实现 ReLU6 操作符,并将其注册到 ONNX Runtime 中。

// custom_relu6_onnxruntime.cpp
#include "onnxruntime/core/providers/cpu/cpu_provider_factory.h"
#include "onnxruntime/core/framework/op_kernel.h"

namespace onnxruntime {

class CustomRelu6 : public OpKernel {
public:
  explicit CustomRelu6(const OpKernelInfo& info) : OpKernel(info) {}

  Status Compute(OpKernelContext* context) const override {
    // 获取输入张量
    const Tensor* input_tensor = context->Input<Tensor>(0);
    if (!input_tensor) return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is null");

    // 获取输出张量
    Tensor* output_tensor = context->Output(0, input_tensor->Shape());
    if (!output_tensor) return Status(common::ONNXRUNTIME, common::FAIL, "Output tensor is null");

    // 获取输入和输出的数据指针
    float* input_data = input_tensor->template Data<float>();
    float* output_data = output_tensor->template Data<float>();

    // 计算 ReLU6
    size_t size = input_tensor->Shape().Size();
    for (size_t i = 0; i < size; ++i) {
      output_data[i] = std::min(std::max(input_data[i], 0.0f), 6.0f);
    }

    return Status::OK();
  }
};

ONNX_OPERATOR_KERNEL(
    Relu6,  // 操作符名称
    kOnnxDomain,  // 命名空间
    9,  // 操作集版本
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),  // 数据类型约束
    CustomRelu6);  // 自定义操作符类
}
4.2 编译 ONNX Runtime 自定义操作符

       我们将上述代码编译成一个动态链接库(例如 .so 文件),以便在 ONNX Runtime 中加载。

# 使用 ONNX Runtime 提供的工具进行编译
g++ -shared -fPIC -o custom_relu6_onnxruntime.so custom_relu6_onnxruntime.cpp -lonnxruntime
4.3 在 ONNX Runtime 中加载自定义操作符

       最后,我们在 Python 中使用 onnxruntime 加载自定义操作符,并运行推理。

import onnxruntime as ort
import numpy as np

# 加载 ONNX 模型
ort_session = ort.InferenceSession("custom_model.onnx", providers=['CPUExecutionProvider'])

# 加载自定义操作符的扩展库
ort_session.load_custom_ops_library("custom_relu6_onnxruntime.so")

# 准备输入数据
ort_inputs = {'input': dummy_input.numpy()}  # 将 PyTorch 张量转换为 NumPy 数组

# 运行推理
ort_outs = ort_session.run(None, ort_inputs)

# 获取 PyTorch 模型的输出
with torch.no_grad():
    torch_out = model(dummy_input)

# 比较 ONNX 和 PyTorch 的输出
np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)

print("ONNX 模型验证通过!")

总结

  • 自定义操作符:当你的模型中包含不在标准 PyTorch 或 ONNX 操作集中的操作符时,你可以通过编写自定义操作符来实现这些功能。
  • PyTorch 中的自定义操作符:可以使用 torch::jit::custom_ops 接口在 C++ 中实现自定义操作符,并通过共享库加载到 PyTorch 中。
  • ONNX 中的自定义操作符:可以通过 torch.onnx.register_custom_op_symbolic 定义符号函数,告诉 ONNX 如何处理自定义操作符。然后,在 ONNX Runtime 中,可以通过编译自定义操作符的实现并加载扩展库来支持推理。
  • 复杂性:实现自定义操作符通常比较复杂,因为它涉及到跨语言编程(C++ 和 Python)、编译和链接等多个步骤。然而,这对于实现特定功能或优化模型是非常有用的。

       通过这个例子,你可以看到如何从头实现一个自定义操作符,并将其集成到 PyTorch 和 ONNX 中。


http://www.kler.cn/a/458845.html

相关文章:

  • GRAPE——RLAIF微调VLA模型:通过偏好对齐提升机器人策略的泛化能力(含24年具身模型汇总)
  • 【服务器】上传文件到服务器并训练深度学习模型下载服务器文件到本地
  • 【笔记】在虚拟机中通过apache2给一个主机上配置多个web服务器
  • android.enableJetifier=true的作用:V4包的类自动编程成了androidx包的类,实现androidx的向下兼容
  • 深度学习论文: RemDet: Rethinking Efficient Model Design for UAV Object Detection
  • ListenAI 1.0.6 | 解锁订阅的文本转语音工具,支持朗读文档和网页
  • 当现代教育技术遇上仓颉---探秘华为仓颉编程语言与未来教育技术的接轨
  • 电子电器架构 ---什么是智能电动汽车上的BMS?
  • VScode怎么重启
  • C# init 关键字的使用
  • 【ArcGIS Pro/GeoScene Pro】可视化时态数据
  • javaweb线上问题排查(若依定时任务)
  • 分布式版本管理工具——git 中忽略文件的版本跟踪(初级方法及高级方法)
  • 进程、线程和协程是什么,以及他们之间的区别
  • K-means 聚类:Python 和 Scikit-learn实现
  • uniapp 微信小程序开发使用高德地图定位SDK
  • ZYQN MPSoc系列芯片综述
  • MOS管驱动方案汇总
  • WeNet:面向生产的流式和非流式端到端语音识别工具包
  • 下载mysql免安装版和配置
  • 计算机网络-L2TP VPN基础实验配置
  • LeetCode-正则表达式匹配(010)
  • 为什么C++支持函数重载而C语言不支持?
  • “技术学习”(Technical Learning)在英文中的多种表达方式
  • 第十六届蓝桥杯模拟赛(第一期)(C语言)
  • HarmonyOS NEXT 实战之元服务:静态案例效果---本地生活服务