transformers蒸馏版本对话小模型
Facebook/blenderbot-400M-distill
- 链接: https://huggingface.co/facebook/blenderbot-400M-distill
- 参数规模: 虽然名称中有 400M,但这是一个蒸馏版本,其原始模型参数量接近 2B,蒸馏后的模型更适合在资源受限的设备上运行,保持了对话能力。
- 特点:
- BlenderBot 系列模型是 Meta (Facebook) 开源的对话模型,在多轮对话方面表现出色。
- 这是一个蒸馏版本,模型体积小,推理速度快。
- 专注于对话任务,具有较强的对话生成能力。
- 支持多种对话技能,如情感表达、知识检索等。
- 支持英语。
- 适用场景: 适合对话系统、聊天机器人、客服机器人等。显存要求不高,基本上有显卡都能跑。
如何使用 (基于 transformers
库):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
import torch
# 模型名称
model_name = "blenderbot-400M-distill"
# 加载分词器
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 加载模型
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# 移动模型到设备
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# 输入对话
conversation = [
"Who are you?",
"How are you?",
]
# 处理对话输入
inputs = tokenizer(conversation, return_tensors="pt", padding=True).to(device)
# 生成回复
outputs = model.generate(**inputs, max_new_tokens=50)
reply0 = tokenizer.decode(outputs[0], skip_special_tokens=True)
reply1 = tokenizer.decode(outputs[1], skip_special_tokens=True)
print("Reply0:", reply0)
print("Reply1:", reply1)
使用结果: