当前位置: 首页 > article >正文

vLLM结构化输出(Guided Decoding)

简介

vLLM 的结构化输出特性是通过“引导式解码”(Guided Decoding)实现的,这一功能允许模型在生成文本时遵循特定的格式约束,例如 JSON 模式或正则表达式,从而确保生成的内容符合预期的结构化要求。

后端引擎

启动vLLM时,可以指定--guided-decoding-backend参数来设置引导式编码的具体实现引擎,最新版本默认使用的是xgrammar。可以有以下三种选择:

  • outlines-dev/outlines
  • mlc-ai/xgrammar
  • noamgat/lm-format-enforcer

优势

  1. 输出结果符合预期,不需要额外的兼容逻辑。不使用引导式编码时,模型的输出通常无法控制,导致生成的内容通常需要额外的处理逻辑去兼容,且无法兼容所有情况。
  2. 性能更好。不使用引导式编码时,模型有可能生成与你预期格式无关的token,导致整理耗时较大。

如何使用

chat completion接口的extra_body可以指定输出的格式。

分类任务

示例代码使用guided_choice引导模型生成指定的分类

from openai import OpenAI
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="-",
)

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-3B-Instruct",
    messages=[
        {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
    ],
    extra_body={"guided_choice": ["positive", "negative"]},
)
print(completion.choices[0].message.content)

正则格式

示例代码使用guided_regex引导模型生成邮箱格式的输出

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-3B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com and new line. Example result: alan.turing@enigma.com\n",
        }
    ],
    extra_body={"guided_regex": "\w+@\w+\.com\n", "stop": ["\n"]},
)
print(completion.choices[0].message.content)

JSON格式

示例代码使用guided_json引导模型生成json格式输出(使用pydantic只是为了得到json schema,你也可以手动提供json schema)

from pydantic import BaseModel
from enum import Enum

class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"


class CarDescription(BaseModel):
    brand: str
    model: str
    car_type: CarType


json_schema = CarDescription.model_json_schema()

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-3B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Generate a JSON with the brand, model and car_type of the most iconic car from the 90's",
        }
    ],
    extra_body={"guided_json": json_schema},
)
print(completion.choices[0].message.content)

EBNF语法格式

EBNF 是 Extended Backus-Naur Form(扩展巴科斯-诺尔范式) 的缩写,它是一种用于描述上下文无关语法的标准化表示法。EBNF 是 BNF 的扩展版本,比 BNF 更加简洁和易读,广泛用于定义编程语言、协议以及其他形式化语言的语法规则。

示例代码使用guided_grammar用于指导生成符合特定规则(SQL 查询格式)的语言结构。

simplified_sql_grammar = """
    ?start: select_statement

    ?select_statement: "SELECT " column_list " FROM " table_name

    ?column_list: column_name ("," column_name)*

    ?table_name: identifier

    ?column_name: identifier

    ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
"""

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-3B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.",
        }
    ],
    extra_body={"guided_grammar": simplified_sql_grammar},
)
print(completion.choices[0].message.content)

实现原理

结构化输出流程图

构建logits_processor

这里以xgrammar作为示例,使用transformers进行模型推理时,只需要在generate方法的入参,指定logits_processor就行。

import xgrammar as xgr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

device = "cuda"  # Or "cpu", etc.
model_name = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, device_map=device
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

# 1. 组装inputs
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Introduce yourself in JSON briefly."},
]
texts = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer(texts, return_tensors="pt").to(model.device)

# 2. 获取compiled grammar
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_builtin_json_grammar()
# Other ways: provide a json schema string
# compiled_grammar = grammar_compiler.compile_json_schema(json_schema_string)
# Or provide an EBNF string
# compiled_grammar = grammar_compiler.compile_grammar(ebnf_string)

# 3. generate时,指定compiled_grammar作为logits_processor
xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
generated_ids = model.generate(
    **model_inputs, max_new_tokens=512, logits_processor=[xgr_logits_processor]
)
generated_ids = generated_ids[0][len(model_inputs.input_ids[0]) :]
print(tokenizer.decode(generated_ids, skip_special_tokens=True))

屏蔽

logits_processor的内部处理逻辑大致如下:

# 1. 初始化grammar matcher,并实例化一个bitmask
matcher = xgr.GrammarMatcher(compiled_grammar)
token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)

# 模型 LLM 推理过程,logits 为模型的推理结果
for logits in LLM.inference(**model_inputs)
	# 2. 使用 GrammarMatcher 计算 bitmask 并应用到 logits 上
    matcher.fill_next_token_bitmask(token_bitmask)
    xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device))
    
    # 3. 获取下一个token id (使用softmax得到概率值,再根据具体的取样算法获取最终生成的token_id)
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    next_token_id = np.random.choice(list(range(full_vocab_size)), p=probs)
    
    # 4. 更新 GrammarMatcher 的状态
    assert matcher.accept_token(next_token_id), f"Invalid token: {next_token_id}"
    
    # 5. 检查终止条件
    if next_token_id == tokenizer.eos_token_id:
        break
        
    # 6. 将 token_id 转换为 token 并返回
    next_token = tokenizer.decode([next_token_id])[0]  # 根据实际的 tokenizer 转换 token_id 为 token
    yield next_token  # 逐步返回每个生成的 token
    

总结

语言模型的输出由模型的推理结果(logits)通过采样或其他策略生成。在结构化生成场景中,GrammarMatcher 实例化一个 bitmask ,对模型的生成过程进行约束,确保输出符合预定义的语法规则(如 JSON 格式或特定语言的 EBNF 语法)。


http://www.kler.cn/a/464449.html

相关文章:

  • 渗透测试-非寻常漏洞案例
  • 【开源社区openEuler实践】hpcrunner
  • AfuseKt1.4.4 | 刮削视频播放器,支持阿里云盘和自动海报墙
  • golang 编程规范 - 项目目录结构
  • SASS 简化代码开发的基本方法
  • 今日复盘103周五(189)
  • C语言中的va_list
  • 云架构Web端的工业MES系统设计之区分工业过程
  • 工业路由器是什么?ER5000为何是领先5G路由器行业
  • 鸿蒙HarmonyOS开发:系统服务(拨打电话、网络搜索、联系人、位置服务、拉起弹框请求用户授权)
  • OpenCV报错:应用程序无法正常启动0xc000007b
  • Hack The Box-Starting Point系列Responder
  • CSS列表、表格、鼠标、滤镜样式设置
  • 深入理解 C 语言预处理:从源文件到可执行程序的关键步骤
  • Vue3实战教程》24:Vue3自定义指令
  • linux下安装达梦数据库v8详解
  • 通过Dockerfile来实现项目可以指定读取不同环境的yml包
  • 24.Java 新特性扩展(重复注解、类型注解)
  • Docker隔离及资源限制原理
  • 参观华为-拓宽全球视野
  • ip属地是看运营商吗还是手机
  • 【C语言 采集数据 精简排序】
  • 数字化转型 · OCR 技术如何打破效率瓶颈?
  • SpringMVC(六)拦截器
  • 栈及栈的操作
  • 【three.js】材质(Material)