Parsing the function-calling template of Llama 3.1 8B Instruct
I have been studying function calling recently, but I could not find a tutorial on doing function calling with Llama 3.1 8B Instruct, so I put together a small working example myself for reference. First, install the dependencies:
pip install transformers
pip install langchain
Example code:
import json
import torch
from langchain_core.prompts import PromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "<your path to >/llama3.1-8b-instruct"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
sys_message="""Cutting Knowledge Date: December 2023
Today Date: 23 July 2024
When you receive a tool call response, use the output to format an answer to the original user question.
You are a helpful assistant with tool calling capabilities.
"""
user_message="""
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
{{functions}}
Question: {{query}}
"""
prompt_template = PromptTemplate(template=user_message, input_variables=["functions", "query"], template_format="jinja2")
query = "what is the weather like in San Francisco?"
functions = [{
    "type": "function",
    "function": {
        "name": "get_current_conditions",
        "description": "Get the current weather conditions for a specific location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g., San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["Celsius", "Fahrenheit"],
                    "description": "The temperature unit to use. Infer this from the user's location."
                }
            },
            "required": ["location", "unit"]
        }
    }
}]
messages = [
    {"role": "system", "content": sys_message},
    # Serialize the tool schema with json.dumps so the model sees valid JSON
    # rather than a Python repr with single quotes.
    {"role": "user", "content": prompt_template.format(query=query, functions=json.dumps(functions, indent=2))},
]
# Render as a plain string first if you want to inspect the filled-in template:
# inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# add_generation_prompt=True appends the assistant header so the model starts its reply directly.
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True, return_tensors="pt")
print(inputs)
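# For reference, rendering with tokenize=False shows roughly how the Llama 3.1
# chat template wraps the two messages (exact whitespace may vary slightly
# between tokenizer versions):
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   Cutting Knowledge Date: December 2023
#   ...<|eot_id|><|start_header_id|>user<|end_header_id|>
#
#   Given the following functions, please respond with a JSON ...<|eot_id|><|start_header_id|>assistant<|end_header_id|>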
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
out = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
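The model should reply with a single JSON object such as {"name": "get_current_conditions", "parameters": {"location": "San Francisco, CA", "unit": "Fahrenheit"}}. Executing the function and feeding its output back to the model is still your job. Below is a minimal sketch of that round trip; get_current_conditions here is a hypothetical stub (a real version would call a weather API), and the "tool" role message relies on the Llama 3.1 chat template's handling of tool responses, which may differ slightly across tokenizer versions:
# Continues the script above: parse the tool call and execute it locally.
raw = tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)

def get_current_conditions(location, unit):
    # Hypothetical stub -- swap in a real weather API call here.
    return json.dumps({"location": location, "temperature": 21, "unit": unit})

available_tools = {"get_current_conditions": get_current_conditions}
try:
    call = json.loads(raw.strip())
    tool_output = available_tools[call["name"]](**call["parameters"])
except (json.JSONDecodeError, KeyError, TypeError) as e:
    tool_output = f"tool call failed: {e}"

# Append the call and its result, then let the model phrase the final answer.
messages.append({"role": "assistant", "content": raw})
messages.append({"role": "tool", "content": tool_output})
followup = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True, return_tensors="pt")
followup = {k: v.to(model.device) for k, v in followup.items()}
final_out = model.generate(**followup, max_new_tokens=128)
print(tokenizer.decode(final_out[0][followup["input_ids"].shape[-1]:], skip_special_tokens=True))
In practice the model may wrap the JSON in extra text, so it is safer to extract the first {...} span from raw before calling json.loads.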