大模型使用vLLM推理加速
关于vLLM推理加速,可以参考之前的帖子:vLLM加速组件XFormers与FlashAttention的区别
在使用 vLLM
进行模型推理时,即使你不显式调用 tokenizer
,vLLM
也会自动处理 tokenization。vLLM
内部会使用模型对应的 tokenizer
来对输入文本进行 tokenization。以下是一些关键点和示例代码,帮助你理解这一过程。
关键点
- 自动 Tokenization:
vLLM
会在内部自动调用tokenizer
对输入文本进行 tokenization。 - 输入格式:你可以直接传递字符串作为输入,
vLLM
会处理其余部分。 - 生成参数:你可以设置生成参数来控制生成过程。
示例代码
以下是一个完整的示例,展示了如何使用 vLLM
进行模型推理,而不显式调用 tokenizer
:
from vllm import LLM, SamplingParams
# 模型名称或路径
model_name_or_path = 'model_name'
# 设置采样参数
sampling_params = SamplingParams(
temperature=0.7,
top_k=50,
top_p=0.95,
max_tokens=50
)
# 加载预训练的模型
model = LLM(model=model_name_or_path)
# 输入文本
input_text = "Hello, how are you?"
# 将输入文本传递给模型
outputs = model.generate([input_text], sampling_params=sampling_params)
# 解码生成的输出
for output in outputs:
generated_text = output.outputs[0].text
print("Generated Text:", generated_text)
详细解释
-
加载模型:
model = LLM(model=model_name_or_path)
这一步加载了预训练的模型,并且
vLLM
会自动加载对应的tokenizer
。 -
设置采样参数:
sampling_params = SamplingParams( temperature=0.7, top_k=50, top_p=0.95, max_tokens=50 )
这些参数控制生成过程,例如温度、top-k 和 top-p 等。
-
输入文本:
input_text = "Hello, how are you?"
这是你想要生成文本的输入。
-
生成文本:
outputs = model.generate([input_text], sampling_params=sampling_params)
这一步将输入文本传递给模型,
vLLM
会自动进行 tokenization 并生成文本。 -
解码生成的输出:
for output in outputs: generated_text = output.outputs[0].text print("Generated Text:", generated_text)
这一步解码生成的
token_ids
并打印生成的文本。
调试建议
-
打印生成的原始输出:
for output in outputs: print("Raw Output:", output)
这可以帮助你检查生成的原始输出,确保每一步都正确。
-
检查生成的
token_ids
:for output in outputs: token_ids = output.outputs[0].token_ids print("Token IDs:", token_ids)
这可以帮助你确认生成的
token_ids
是否合理。 -
确保模型和 tokenizer 匹配:
确保你使用的模型和 tokenizer 是同一个预训练模型的。
示例调试
from vllm import LLM, SamplingParams
# 模型名称或路径
model_name_or_path = 'model_name'
# 设置采样参数
sampling_params = SamplingParams(
temperature=0.7,
top_k=50,
top_p=0.95,
max_tokens=50
)
# 加载预训练的模型
model = LLM(model=model_name_or_path)
# 输入文本
input_text = "Hello, how are you?"
# 将输入文本传递给模型
outputs = model.generate([input_text], sampling_params=sampling_params)
# 打印生成的原始输出
for output in outputs:
print("Raw Output:", output)
# 解码生成的输出
for output in outputs:
generated_text = output.outputs[0].text
print("Generated Text:", generated_text)
# 检查 token_ids
token_ids = output.outputs[0].token_ids
print("Token IDs:", token_ids)