当前位置: 首页 > article >正文

基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现

背景

LlamaFactory 的 LoRA 微调功能非常便捷,微调后的模型,没有直接支持 vllm 推理,故导致推理速度不够快。

LlamaFactory 目前支持通过 VLLM API 进行部署,调用 API 时的响应速度,仍然没有vllm批量推理的速度快。

如果模型是通过 LlamaFactory 微调的,为了确保数据集的一致性,建议在推理时也使用 LlamaFactory 提供的封装数据集。

简介

在上述的背景下,我们使用 LlamaFactory 原生数据集,支持 lora的 vllm 批量推理。
完整代码如下:

import json
import os
from typing import List

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer

def vllm_infer():
    model_args, data_args, training_args, finetuning_args, generating_args = (
        get_train_args()
    )
    tokenizer = load_tokenizer(model_args)["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    eval_dataset = get_dataset(
        template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
    )["eval_dataset"]

    prompts = [item["input_ids"] for item in eval_dataset]
    prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)

    labels = [
        list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
        for item in eval_dataset
    ]
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    sampling_params = SamplingParams(
        temperature=generating_args.temperature,
        top_k=generating_args.top_k,
        top_p=generating_args.top_p,
        max_tokens=2048,
    )

    if model_args.adapter_name_or_path:
        if isinstance(model_args.adapter_name_or_path, list):
            lora_requests = []
            for i, _lora_path in enumerate(model_args.adapter_name_or_path):
                lora_requests.append(
                    LoRARequest(f"lora_adapter_{i}", i, lora_path=_lora_path)
                )
        else:
            lora_requests = LoRARequest(
                "lora_adapter_0", 0, lora_path=model_args.adapter_name_or_path
            )

        enable_lora = True
    else:
        lora_requests = None
        enable_lora = False

    llm = LLM(
        model=model_args.model_name_or_path,
        trust_remote_code=True,
        tokenizer=model_args.model_name_or_path,
        enable_lora=enable_lora,
    )

    outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)

    if not os.path.exists(training_args.output_dir):
        os.makedirs(training_args.output_dir, exist_ok=True)

    output_prediction_file = os.path.join(
        training_args.output_dir, "generated_predictions.jsonl"
    )

    with open(output_prediction_file, "w", encoding="utf-8") as writer:
        res: List[str] = []
        for text, pred, label in zip(prompts, outputs, labels):
            res.append(
                json.dumps(
                    {"prompt": text, "predict": pred.outputs[0].text, "label": label},
                    ensure_ascii=False,
                )
            )
        writer.write("\n".join(res))

vllm.yaml 示例:

## model
model_name_or_path: qwen/Qwen2.5-7B-Instruct
# adapter_name_or_path: lora模型

### method
stage: sft
do_predict: true
finetuning_type: lora

### dataset
dataset_dir: 数据集路径
eval_dataset: 数据集
template: qwen
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: output/
overwrite_output_dir: true

### eval
predict_with_generate: true

程序调用:

python vllm_infer.py vllm.yaml

程序运行速度:

Processed prompts: 100%|| 1000/1000 [01:56<00:00,  8.60it/s, est. speed input: 5169.35 toks/s, output: 811.57

总结

本方案在原生 LlamaFactory 数据集的基础上,支持 LoRA 的 vllm 批量推理,能提升了推理效率。


http://www.kler.cn/a/416572.html

相关文章:

  • 刷LeetCode hot100--1.哈希表
  • QuantLib-python使用心得(持续更新)
  • 学习笔记044——HashMap源码学习2
  • 【WPS】【EXCEL】将单元格中字符按照分隔符拆分按行填充到其他单元格
  • LangChain——HTML文本分割 多种文本分割
  • 解决 java -jar 报错:xxx.jar 中没有主清单属性
  • Go语言技巧:快速统一字符串中的换行符,解决跨平台问题
  • T507 buildroot linux4.9之RTC8563开发调试
  • SQLModel与FastAPI结合:构建用户增删改查接口
  • 海盗王用golang重写的AccountServer功能
  • Facebook Audience Network优化指南
  • 学习笔记042——如何通过IDEA中自带的数据库组件导出MySQL数据
  • Jmeter测试工具的安装和使用,mac版本,jmeter版本5.2.1
  • 《向量数据库指南》——稀疏激活:解锁大数据处理新纪元
  • 【游戏引擎之路】登神长阶(十五)——DirectX12龙书:行百里者半九十(学习阶段完结)
  • 介绍一下atoi(arr);(c基础)
  • 汽车驾校寒冬,新增无人机飞手培训技术详解
  • GPT打字机效果—— fetchEventSouce进行sse流式请求
  • Oracle LinuxR7安装Oracle 12.2 RAC集群实施(DNS解析)
  • 【大数据学习 | Spark-SQL】定义UDF和DUAF,UDTF函数
  • 使用Java来构筑一个基础的项目完全梳理(二):前端vue搭建
  • SpringBoot小知识(3):热部署知识
  • LLM - 使用 LLaMA-Factory 微调 Qwen2-VL DPO(LoRA) 图像数据集 教程 (3)
  • 力扣 最长回文字串-5
  • EXCEL截取某一列从第一个字符开始到特定字符结束的字符串到新的一列
  • Websocket——化神篇