当前位置: 首页 > article >正文

导出BERT句子模型为ONNX并推理

在深度学习中,将模型导出为ONNX(Open Neural Network Exchange)格式并利用ONNX进行推理是提高推理速度和模型兼容性的一种常见做法。本文将介绍如何将BERT句子模型导出为ONNX格式,并使用ONNX Runtime进行推理,具体以中文文本处理为例。

1. 什么是ONNX?

ONNX 是一种开放的神经网络交换格式,旨在促进深度学习模型在不同平台和工具之间的共享和移植。它支持包括PyTorch、TensorFlow等多种主流框架,可以通过ONNX Runtime库高效推理。通过将模型转换为ONNX格式,我们可以获得跨平台部署的优势,并利用ONNX Runtime加速推理过程。

2. 准备工作

在导出和推理之前,需要安装以下库:

pip install torch transformers onnx onnxruntime

3. 导出BERT句子模型为ONNX

首先,我们将使用HuggingFace的transformers库加载一个预训练的BERT句子模型(text2vec-base-chinese),然后将其导出为ONNX格式。以下是导出模型的步骤和代码:

3.1 导出模型的代码

import torch
from transformers import BertTokenizer, BertModel

# 加载预训练的BERT模型和分词器
tokenizer = BertTokenizer.from_pretrained('shibing624/text2vec-base-chinese')
model = BertModel.from_pretrained('shibing624/text2vec-base-chinese')

# 读取要处理的句子
with open("corpus/words_nlu.txt", 'rt', encoding='utf-8') as f:
    nlu_words = [line.strip() for line in f.readlines()]
nlu_words.insert(0, "摄像头打开一下")  # 插入要比较的句子

# 对句子进行编码
encoded_input = tokenizer(nlu_words, padding=True, truncation=True, return_tensors='pt')

# 设置ONNX模型的保存路径
onnx_model_path = "text2vec-base-chinese.onnx"
model.eval()

# 导出模型为ONNX格式
with torch.no_grad():
    torch.onnx.export(
        model,
        (encoded_input['input_ids'], encoded_input['attention_mask']),
        onnx_model_path,
        input_names=['input_ids', 'attention_mask'],
        output_names=['last_hidden_state'],
        opset_version=14,
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}
        }
    )
print(f"ONNX模型已导出到 {onnx_model_path}")

在这段代码中,我们将text2vec-base-chinese模型导出为ONNX格式,指定了输入和输出的名称,并使用了动态轴设置(如批大小和序列长度),这样可以处理不同长度的句子。

4. 使用ONNX进行推理

导出模型后,我们可以使用ONNX Runtime进行推理。以下是基于ONNX的推理代码。该代码实现了对输入文本进行预处理、调用ONNX模型进行推理、以及对模型输出进行均值池化处理。

4.1 ONNX推理代码

import numpy as np
from onnxruntime import InferenceSession

class PIPE_NLU:
    def __init__(self, model_path="text2vec-base-chinese.onnx", vocab_path="vocab.txt") -> None:
        self.model_path = model_path
        self.vocab_path = vocab_path
        self.vocab = self.load_vocab(vocab_path)
        self.onnx_session = InferenceSession(model_path)
        print("成功加载NLU解码器")

    def load_vocab(self, vocab_path):
        """加载BERT词汇表"""
        vocab = {}
        with open(vocab_path, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                token = line.strip()
                vocab[token] = idx
        return vocab

    def tokenize(self, text):
        """将文本分词为BERT的input_ids"""
        tokens = ['[CLS]']
        for char in text:
            if char in self.vocab:
                tokens.append(char)
            else:
                tokens.append('[UNK]')
        tokens.append('[SEP]')
        input_ids = [self.vocab[token] if token in self.vocab else self.vocab['[UNK]'] for token in tokens]
        return input_ids

    def preprocess(self, texts, max_length=128):
        """对输入文本进行预处理"""
        input_ids_list = []
        attention_mask_list = []
        
        for text in texts:
            input_ids = self.tokenize(text)
            if len(input_ids) > max_length:
                input_ids = input_ids[:max_length]
            else:
                input_ids += [0] * (max_length - len(input_ids))

            attention_mask = [1 if idx != 0 else 0 for idx in input_ids]
            
            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)

        inputs = {
            'input_ids': np.array(input_ids_list, dtype=np.int64),
            'attention_mask': np.array(attention_mask_list, dtype=np.int64)
        }
        return inputs

    def mean_pooling_numpy(self, model_output, attention_mask):
        """对模型输出进行均值池化"""
        token_embeddings = model_output
        input_mask_expanded = np.expand_dims(attention_mask, -1).astype(float)
        return np.sum(token_embeddings * input_mask_expanded, axis=1) / np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)

    def compute_embeddings(self, texts):
        """计算输入文本的句子嵌入"""
        onnx_inputs = self.preprocess(texts)
        onnx_outputs = self.onnx_session.run(None, onnx_inputs)
        last_hidden_state = onnx_outputs[0]
        sentence_embeddings = self.mean_pooling_numpy(last_hidden_state, onnx_inputs['attention_mask'])
        sentence_embeddings = sentence_embeddings / np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)
        return sentence_embeddings

4.2 推理流程

  1. 加载ONNX模型:通过InferenceSession加载ONNX模型。
  2. 加载词汇表:读取BERT的词汇表,用于将输入文本转化为模型可接受的input_ids格式。
  3. 文本预处理:将输入的文本进行分词、截断或填充为固定长度,并生成相应的注意力掩码attention_mask
  4. 模型推理:通过ONNX Runtime调用模型,获取句子的最后隐藏状态输出。
  5. 均值池化:对最后的隐藏状态进行均值池化,计算出句子的嵌入向量。
  6. 归一化嵌入:将句子嵌入向量进行归一化,使得向量长度为1。

5. 总结

通过将BERT模型导出为ONNX并使用ONNX Runtime进行推理,我们可以大幅度提升推理速度,同时保持了高精度的句子嵌入计算。在实际应用中,ONNX Runtime的跨平台特性和高性能表现使其成为模型部署和推理的理想选择。

使用上述步骤,您可以轻松将BERT句子模型应用到各种自然语言处理任务中,如语义相似度计算、文本分类和句子嵌入等。


http://www.kler.cn/news/366646.html

相关文章:

  • 【mysql进阶】4-7. 通用表空间
  • 如何提高游戏的游戏性
  • 研发运营一体化(DevOps)能力成熟度模型
  • 2024软考网络工程师笔记 - 第8章.网络安全
  • linux shell 脚本语言教程(超详细!)
  • 有关spring,springboot项目的知识点
  • axios直接上传binary
  • PHP 正则表达式 修正符【m s x e ? (?i)】内部修正符 贪婪模式 后向引用 断言【总结篇】
  • 【C++初阶】一文讲通C++内存管理
  • 力扣第 420 场周赛 3324. 出现在屏幕上的字符串序列
  • Chromium127编译指南 Windows篇 - 使用 GN 工具生成构建文件(六)
  • 【二轮征稿启动】第三届环境工程与可持续能源国际会议持续收录优质稿件
  • 代码随想录day11 栈与队列
  • Android静态变量中的字段被置空了
  • 关键词搜索的“魔法咒语”:用API接口召唤商品数据
  • Ubuntu服务器搭建Tailscale Derp节点
  • 掌握ElasticSearch(四):数据类型、回复体
  • arm架构 ubuntu 部署docker
  • 校园表白墙源码修复版
  • 基于python智能推荐的丢失物品招领网站的制作,前端vue+django框架,协同过滤算法实现推荐功能
  • 【MySQL 保姆级教学】表的约束--详细(6)
  • #渗透测试#SRC漏洞挖掘# 信息收集-Shodan批量扫描
  • 新王Claude 3.5的6大应用场景
  • android 文字绘制
  • 常见的租用服务器类型和费用
  • Vue学习笔记(三、v-cloak、v-text、v-html指令)