当前位置: 首页 > article >正文

onnx报错解决-bert

 

一、定义

  1. UserWarning: Provided key output for dynamic axes is not a valid input/output name warnings.warn(

  2. 案例

  3. 实体识别bert 案例

  4. 转transformers 模型到onnx 接口解读

二、实现

https://huggingface.co/docs/transformers/main_classes/onnx#transformers.onnx.FeaturesManager

  1. UserWarning: Provided key output for dynamic axes is not a valid input/output name warnings.warn(

代码:

with torch.no_grad():
    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
    torch.onnx.export(model,
                      (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),
                      "./saves/bertclassify.onnx",
                      opset_version=14,
                      input_names=["input_ids", "token_type_ids", "attention_mask"],        
                       output_names=["logits"],                                             
                      dynamic_axes =    {'input_ids': symbolic_names,
                                        'attention_mask': symbolic_names,
                                        'token_type_ids': symbolic_names,
                                        'logits': symbolic_names
                                         }
                      )

改正后:原因: input_names 名字顺序与模型定义不一致导致。为了避免错误产生,应该标准化。如下2所示。

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)
model.eval()
onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification']("./saves")
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')
from itertools import chain
with torch.no_grad():
    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
    torch.onnx.export(model,
                      (inputs["input_ids"],inputs["attention_mask"], inputs["token_type_ids"]),
                      "./saves/bertclassify.onnx",
                      opset_version=14,
                      input_names=["input_ids", "attention_mask", "token_type_ids"],      
                       output_names=["logits"],                                            
                      dynamic_axes =    {
        name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())
    }
)
# #验证是否成功
import onnx
onnx_model=onnx.load("./saves/bertclassify.onnx")
onnx.checker.check_model(onnx_model)
print("无报错,转换成功")

# #推理
import onnxruntime
ort_session=onnxruntime.InferenceSession("./saves/bertclassify.onnx", providers=['CPUExecutionProvider'])    #加载模型
ort_input={"input_ids":inputs["input_ids"].cpu().numpy(),"token_type_ids":inputs["token_type_ids"].cpu().numpy(),
           "attention_mask":inputs["attention_mask"].cpu().numpy()}
output_on = ort_session.run(["logits"], ort_input)[0]   #推理


print(output_org.detach().numpy())
print(output_on)
assert np.allclose(output_org.detach().numpy(), output_on, 10-5)  #无报错

标准化:

output_onnx_path = "./saves/bertclassify.onnx"
from itertools import chain
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')

torch.onnx.export(
    model,
    (dummy_inputs,),
    f=output_onnx_path,
    input_names=list(onnx_config.inputs.keys()),
    output_names=list(onnx_config.outputs.keys()),
    dynamic_axes={
        name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())
    },
    do_constant_folding=True,
    opset_version=14,
)

全部:

import torch
devices=torch.device("cpu")
from transformers.onnx.features import FeaturesManager
import torch
from transformers import AutoTokenizer, BertForSequenceClassification
import numpy as np
model_path = "./saves"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)

words = ["你叫什么名字"]
inputs = tokenizer(words, return_tensors='pt', padding=True)
model.eval()

onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification']("./saves")
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')
from itertools import chain

output_org = model(**inputs).logits

torch.onnx.export(
    model,
    (dummy_inputs,),
    f=output_onnx_path,
    input_names=list(onnx_config.inputs.keys()),
    output_names=list(onnx_config.outputs.keys()),
    dynamic_axes={
        name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())
    },
    do_constant_folding=True,
    opset_version=14,
)

# #验证是否成功
import onnx
onnx_model=onnx.load("./saves/bertclassify.onnx")
onnx.checker.check_model(onnx_model)
print("无报错,转换成功")

# #推理
import onnxruntime
ort_session=onnxruntime.InferenceSession("./saves/bertclassify.onnx", providers=['CPUExecutionProvider'])    #加载模型
ort_input={"input_ids":inputs["input_ids"].cpu().numpy(),"token_type_ids":inputs["token_type_ids"].cpu().numpy(),
           "attention_mask":inputs["attention_mask"].cpu().numpy()}
output_on = ort_session.run(["logits"], ort_input)[0]   #推理


print(output_org.detach().numpy())
print(output_on)
assert np.allclose(output_org.detach().numpy(), output_on, 10-5)  #无报错

无任何警告产生

                 

  1. 实体识别案例

import onnxruntime
from itertools import chain
from transformers.onnx.features import FeaturesManager

config = ner_config
tokenizer = ner_tokenizer
model = ner_model
output_onnx_path = "bert-ner.onnx"

onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification'](config)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')

torch.onnx.export(
    model,
    (dummy_inputs,),
    f=output_onnx_path,
    input_names=list(onnx_config.inputs.keys()),
    output_names=list(onnx_config.outputs.keys()),
    dynamic_axes={
        name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())
    },
    do_constant_folding=True,
    opset_version=onnx_config.default_onnx_opset,       #默认,报错改为14
)
  1. 转transformers 模型到onnx 接口解读

Huggingface:导出transformers模型到onnx_ONNX_程序员架构进阶_InfoQ写作社区

https://zhuanlan.zhihu.com/p/684444410


http://www.kler.cn/a/414343.html

相关文章:

  • 第六届国际科技创新学术交流大会暨信息技术与计算机应用学术会议(ITCA 2024)
  • 在 Django 中创建和使用正整数、负数、小数等数值字段
  • shell脚本基础学习_总结篇(完结)
  • OpenCV截取指定图片区域
  • 嵌入式linux C++通用makefile模板
  • transformer学习笔记-神经网络原理
  • Leetcode 面试150题 189. 轮转数组 中等
  • React UI设计黑色蒙层#000000 80%,首次打开弹出,点击图片可以关闭
  • Figma入门-铅笔钢笔工具
  • 大数据笔记
  • Mybatis:Mybatis快速入门
  • 如何将MinIO数据迁移到阿里云OSS
  • LLMs之ell:ell(轻量级函数式提示工程框架)的简介、安装和使用方法、案例应用之详细攻略
  • python+django自动化平台(一键执行sql) 前端vue-element展示
  • 应急响应靶机——easy溯源
  • 算法的NPU终端移植:深入探讨与实践指南
  • 豆包MarsCode算法题:三数之和问题
  • 论 AI(人工智能)的现状
  • 商汤绝影打造A New Member For U,让汽车拥有“有趣灵魂”
  • 力扣 搜索旋转排序数组-33
  • Qt UI设计 菜单栏无法输入名字
  • faiss库中ivf-sq(ScalarQuantizer,标量量化)代码解读-3
  • 自动驾驶科研资料整理
  • 【再谈设计模式】装配器模式 ~复杂结构构建的巧匠
  • 注意http-proxy-middleware要解决跨域问题,想修改origin请求头不要设置changeOrigin=true
  • DeSTSeg: Segmentation Guided Denoising Student-Teacher for Anomaly Detection