构造一个工具(TravelSQLAgentTool),利用大语言模型(例如 Llama 模型)来完成 SQL 查询代理工具
完整代码:
from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
import os
from dotenv import load_dotenv
load_dotenv()
class TravelSQLAgentTool:
"""
A tool for interacting with a travel-related SQL database using an LLM (Language Model) to generate and execute SQL queries.
This tool enables users to ask travel-related questions, which are transformed into SQL queries by a language model.
The SQL queries are executed on the provided SQLite database, and the results are processed by the language model to
generate a final answer for the user.
Attributes:
sql_agent_llm (LLAMA): An instance of a LLAMA language model used to generate and process SQL queries.
system_role (str): A system prompt template that guides the language model in answering user questions based on SQL query results.
db (SQLDatabase): An instance of the SQL database used to execute queries.
chain (RunnablePassthrough): A chain of operations that creates SQL queries, executes them, and generates a response.
Methods:
__init__: Initializes the TravelSQLAgentTool by setting up the language model, SQL database, and query-answering pipeline.
"""
def __init__(self, llm: str, sqldb_directory: str, llm_temerature: float) -> None:
"""
Initializes the TravelSQLAgentTool with the necessary configurations.
Args:
llm (str): The name of the language model to be used for generating and interpreting SQL queries.
sqldb_directory (str): The directory path where the SQLite database is stored.
llm_temerature (float): The temperature setting for the language model, controlling response randomness.
"""
# 初始化 Llama 模型,使用 Groq 后端
# "llama-3.3-70b-specdec"
self.sql_agent_llm = init_chat_model(llm, model_provider="groq", temperature=llm_temerature)
self.db = SQLDatabase.from_uri(
f"sqlite:///{sqldb_directory}")
# print(self.db.get_usable_table_names())
# 定义自定义提示模板,用于生成 SQL 查询
custom_prompt = PromptTemplate(
input_variables=["dialect", "input", "table_info", "top_k"],
template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Do not Limit {top_k} the results.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
# write_query
write_query = create_sql_query_chain(self.sql_agent_llm, self.db,prompt=custom_prompt)
execute_query = QuerySQLDataBaseTool(db=self.db)
# answer
self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
Question: {question}\n
SQL Query: {query}\n
SQL Result: {result}\n
Answer:
"""
answer_prompt = PromptTemplate.from_template(
self.system_role)
answer = answer_prompt | self.sql_agent_llm | StrOutputParser()
# 8. 定义一个调试链 debug_chain,用于打印 write_query 生成的 SQL 查询。
# 这里使用 RunnablePassthrough 执行一个 lambda 函数:
# lambda data: (print("write_query execution result:", data["query"]), data)[1]
# 解释:先打印 data 字典中 "query" 对应的 SQL 语句,然后将原始 data 返回,以便后续链继续处理。
debug_chain = RunnablePassthrough(lambda data: (print("write_query execution result:", data["query"]), data)[1])
# 9. 构造完整的处理链 chain_ex:
# - 首先调用 write_query 生成 SQL 查询,并将结果存储到字典的 "query" 字段中;
# - 接着通过 debug_chain 打印出生成的 SQL 查询;
# - 然后使用 execute_query 执行 SQL 查询,结果存入 "result" 字段(此处利用 itemgetter 提取 "query" 字段后传递给执行工具);
# - 最后将问题、SQL 查询以及查询结果传递给 answer 链,生成最终回答。
# chain
chain_ex = (
RunnablePassthrough.assign(query=write_query)
| debug_chain
| RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
| answer
)
# 利用 bind 绑定固定参数到链中
bound_chain = chain_ex.bind(
dialect=self.db.dialect,
table_info=self.db.get_table_info(),
top_k=55
)
self.chain = bound_chain
sqldb_directory = here("data/Chinook.db")
query = "and calculate the number of all Playlist"
@tool
def query_travel_sqldb(query: str) -> str:
"""Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
agent = TravelSQLAgentTool(
llm="llama3-70b-8192", #TOOLS_CFG.travel_sqlagent_llm
sqldb_directory= sqldb_directory,#TOOLS_CFG.travel_sqldb_directory
llm_temerature=0 #TOOLS_CFG.travel_sqlagent_llm_temperature
)
response = agent.chain.invoke({"question": query})
return response
print(query_travel_sqldb(query))
下面我将分步骤、用通俗易懂的语言详细解释这段代码的含义和作用,并举例说明各部分的工作流程。
1. 整体概述
这段代码的主要目的是构造一个工具(TravelSQLAgentTool),它能利用大语言模型(LLM,例如 Llama 模型)来完成以下任务:
- 根据用户提出的“旅行”相关问题生成一条 SQL 查询语句;
- 将这条 SQL 查询语句发送给一个 SQLite 数据库执行,获得查询结果;
- 最后,再利用语言模型将查询结果转换成用户易懂的答案。
同时,为了调试方便,在生成 SQL 查询后会打印出这条查询语句,这样你就能看到 LLM 生成的 SQL 语句内容。
2. 模块导入和环境设置
from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
import os
from dotenv import load_dotenv
load_dotenv()
-
导入各个模块:
这些模块主要用于创建提示模板、构造数据处理链、连接 SQL 数据库以及注册工具函数。 -
加载环境变量:
load_dotenv()
用于加载.env
文件中的环境变量,方便管理敏感信息(如 API 密钥)。
3. TravelSQLAgentTool 类的定义
这个类封装了整个查询流程。我们逐行看它的初始化方法 __init__
。
3.1 初始化语言模型和数据库
self.sql_agent_llm = init_chat_model(llm, model_provider="groq", temperature=llm_temerature)
- 作用:调用
init_chat_model
初始化一个 LLM 模型,这里传入的llm
参数(例如 “llama3-70b-8192”)指定使用哪个模型。 - 参数解释:
model_provider="groq"
:表示使用 Groq 后端;temperature=llm_temerature
:温度参数决定了模型回答的随机性(0 表示确定性很高)。
self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
- 作用:通过给定的数据库路径(如
"data/Chinook.db"
)构造一个 SQLite 数据库连接实例,后续执行 SQL 查询时会使用它。
3.2 定义生成 SQL 查询的提示模板
custom_prompt = PromptTemplate(
input_variables=["dialect", "input", "table_info", "top_k"],
template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Do not Limit {top_k} the results.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
- 作用:构造一个模板,指导语言模型如何生成 SQL 查询。
- 模板内容说明:
{dialect}
:SQL 的方言(例如 SQLite)。{table_info}
:数据库中各表的结构信息;{input}
:用户的问题;{top_k}
:限制查询返回记录条数的参数(不过这里实际上是说明“不要限制”)。
- 举例:
如果用户的问题是“计算所有 Playlist 的数量”,模板就会要求模型生成类似:
的 SQL 查询。SELECT COUNT(*) FROM Playlist;
3.3 创建 SQL 查询和执行链
write_query = create_sql_query_chain(self.sql_agent_llm, self.db, prompt=custom_prompt)
execute_query = QuerySQLDataBaseTool(db=self.db)
-
write_query:
- 作用:利用前面定义的提示模板,让语言模型根据问题生成 SQL 查询。
- 举例:针对问题“calculate the number of all Playlist”,write_query 可能生成:
SELECT COUNT(*) FROM Playlist;
-
execute_query:
- 作用:定义了一个工具,将生成的 SQL 查询发送到数据库执行,并返回查询结果。
3.4 定义生成最终答案的系统提示和链
self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
Question: {question}\n
SQL Query: {query}\n
SQL Result: {result}\n
Answer:
"""
answer_prompt = PromptTemplate.from_template(self.system_role)
answer = answer_prompt | self.sql_agent_llm | StrOutputParser()
-
system_role:
- 作用:定义了一个系统提示,告诉模型如何根据用户问题、生成的 SQL 查询和 SQL 查询结果生成最终的答案。
- 举例:如果 SQL 查询返回的结果为 25,系统提示就会告诉模型“用户问:计算所有 Playlist 数量,SQL 查询是……,查询结果是 25,请生成一个最终的自然语言答案”。
-
answer 链:
- 通过将提示、语言模型和一个输出解析器(将模型的文本输出转换成纯文本)串联起来,构成生成最终答案的链。
3.5 定义调试链(debug_chain)
debug_chain = RunnablePassthrough(lambda data: (print("write_query execution result:", data["query"]), data)[1])
- 作用:
这个调试链在数据传递过程中打印出data["query"]
的内容,即打印出由 write_query 生成的 SQL 查询语句,然后把原始数据继续返回给后续链的步骤。 - 举例说明:
假设上一步生成的data
为:
那么这个 lambda 函数会先打印:{"query": "SELECT COUNT(*) FROM Playlist;"}
然后返回原始数据write_query execution result: SELECT COUNT(*) FROM Playlist;
{"query": "SELECT COUNT(*) FROM Playlist;"}
,不对数据做任何修改。
3.6 构造完整的链(chain_ex)
chain_ex = (
RunnablePassthrough.assign(query=write_query)
| debug_chain
| RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
| answer
)
这段代码构造了一个数据处理流水线,每一步的含义如下:
-
RunnablePassthrough.assign(query=write_query)
- 作用:调用
write_query
生成 SQL 查询,并将生成的查询结果存储到数据字典的"query"
键中。 - 举例:生成的数据可能为:
{"query": "SELECT COUNT(*) FROM Playlist;"}
- 作用:调用
-
| debug_chain
- 作用:将上一步生成的数据传入
debug_chain
,打印出 SQL 查询,同时不改变数据。 - 举例:会打印出上述 SQL 查询语句。
- 作用:将上一步生成的数据传入
-
| RunnablePassthrough.assign(result=itemgetter(“query”) | execute_query)
- 作用:利用
itemgetter("query")
从数据字典中提取 SQL 查询语句,然后将其传递给execute_query
工具,执行 SQL 查询,并将执行结果存储到数据字典的"result"
键中。 - 举例:假如数据库中 Playlist 表有 25 条记录,则查询结果可能为:
{"query": "SELECT COUNT(*) FROM Playlist;", "result": 25}
- 作用:利用
-
| answer
- 作用:将包含问题、SQL 查询和查询结果的数据传入答案生成链,得到最终的自然语言回答。
3.7 绑定固定参数
bound_chain = chain_ex.bind(
dialect=self.db.dialect,
table_info=self.db.get_table_info(),
top_k=55
)
self.chain = bound_chain
- 作用:
bind
方法将一些固定参数(如 SQL 的方言、数据库表结构信息、以及 top_k 参数)绑定到流水线中,确保每次调用链时这些参数都自动传递进去。- 绑定后,将整个链赋值给
self.chain
,这样后续调用就会按照这个步骤顺序执行。
4. 全局部分与工具函数定义
4.1 指定数据库路径和示例查询
sqldb_directory = here("data/Chinook.db")
query = "and calculate the number of all Playlist"
- 作用:
- 使用
here
函数定位到数据库文件所在的路径(相对路径 “data/Chinook.db”); - 定义一个示例查询字符串,意思是“计算所有 Playlist 的数量”。
- 使用
4.2 定义工具函数 query_travel_sqldb
@tool
def query_travel_sqldb(query: str) -> str:
"""Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
agent = TravelSQLAgentTool(
llm="llama3-70b-8192", # 指定使用的语言模型
sqldb_directory= sqldb_directory, # 数据库文件路径
llm_temerature=0 # LLM 温度设为 0,表示回答比较确定,不引入随机性
)
response = agent.chain.invoke({"question": query})
return response
- 作用:
- 这个函数使用
@tool
装饰器注册成工具函数,方便外部调用。 - 内部实例化一个 TravelSQLAgentTool 对象,并传入模型名称、数据库路径及温度参数;
- 然后调用构造好的链(
agent.chain
)的invoke
方法,把用户的问题(键名为"question"
)传入整个链进行处理,最终得到回答。
- 这个函数使用
- 举例说明:
当你调用query_travel_sqldb("and calculate the number of all Playlist")
时:- 首先,write_query 生成 SQL 查询
SELECT COUNT(*) FROM Playlist;
; - 调试链打印出这个 SQL 查询;
- 执行查询后,假设结果为 25;
- 最后,语言模型根据系统提示生成答案,比如“总共有 25 个 Playlist。”。
- 首先,write_query 生成 SQL 查询
4.3 最后执行并打印结果
print(query_travel_sqldb(query))
- 作用:
调用上面定义的工具函数,并将最终生成的答案打印出来。
5. 总结
这段代码构建了一个基于 LLM 的 SQL 查询代理工具,其工作流程为:
- 用户输入一个查询问题;
- 工具通过自定义提示模板调用语言模型生成 SQL 查询;
- 生成的 SQL 查询在调试链中打印出来以便检查;
- SQL 查询被发送到 SQLite 数据库执行,获得结果;
- 语言模型根据问题、SQL 查询和结果生成最终的自然语言答案;
- 最终答案返回并打印出来。
这种设计使得查询流程高度自动化,同时又便于调试和检查中间步骤,帮助理解每一步是如何工作的。