利用 LangChain 和一个大语言模型(LLM)构建一个链条,自动从用户输入的问题中提取相关的 SQL 表信息,再生成对应的 SQL 查询
示例代码:
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
# 系统消息,要求 LLM 返回与问题相关的 SQL 表类别
system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:
Music
Business"""
# 初始化 LLM 模型
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)
# 创建提取链:将用户问题转换为 Table 模型的实例
category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
# 定义一个函数,根据 Table 对象映射到具体的 SQL 表名
def get_tables(categories: List[Table]) -> List[str]:
"""将类别名称映射到对应的 SQL 表名列表."""
tables = []
for category in categories:
if category.name == "Music":
tables.extend(
[
"Album",
"Artist",
"Genre",
"MediaType",
"Playlist",
"PlaylistTrack",
"Track",
]
)
elif category.name == "Business":
tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
return tables
# 将类别提取链与映射函数组合,得到一个返回 SQL 表名列表的链
table_chain = category_chain | get_tables
# 定义自定义 SQL 提示模板,用于生成 SQL 查询
custom_prompt = PromptTemplate(
input_variables=["dialect", "input", "table_info", "top_k"],
template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Don't limit the results to {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
# 创建 SQL 查询链
query_chain = create_sql_query_chain(table_extractor_llm, db, prompt=custom_prompt)
# 利用 bind 将固定参数绑定到 SQL 查询链中
bound_chain = query_chain.bind(
dialect=db.dialect,
table_info=db.get_table_info(),
top_k=55
)
# 将输入中的 "question" 键复制到 "input" 键,同时保留原始数据
table_chain = (lambda x: {**x, "input": x["question"]}) | table_chain
# 使用 RunnablePassthrough.assign 将提取到的表名添加到上下文中,然后与 SQL 查询链组合
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | bound_chain
# 调用整个链,生成 SQL 查询
query = full_chain.invoke(
{"question": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
)
print(query)
这段代码主要展示如何利用 LangChain 和一个大语言模型(LLM)构建一个链条,自动从用户输入的问题中提取相关的 SQL 表信息,再生成对应的 SQL 查询。下面我将分步详细解释每个部分的作用,并通过举例说明每段代码的输入和输出。
1. 定义系统消息和初始化 LLM 模型
system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:
Music
Business"""
-
作用:
这段系统消息告诉 LLM:请根据用户的问题返回与问题相关的 SQL 表类别,这里限定了两类——“Music”和“Business”。 -
举例:
如果用户的问题涉及音乐信息(例如歌曲、专辑等),那么 LLM 会返回 “Music”;如果涉及客户、发票等信息,则返回 “Business”。
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)
-
作用:
初始化一个 LLM 模型(此处使用 llama3-70b-8192,由 groq 提供,温度设为 0 以保证回答确定性),后续会用这个模型进行类别提取和 SQL 查询生成。 -
输出:
返回一个 LLM 实例table_extractor_llm
。
2. 创建提取链:从问题中抽取相关表类别
category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
-
作用:
这里利用create_extraction_chain_pydantic
创建了一个链,该链的任务是将用户输入的问题转换为一个或多个符合 Pydantic 模型Table
的实例。也就是说,LLM 会分析问题并输出如Table(name="Music")
或Table(name="Business")
的结果。 -
输入:
用户问题(例如 “What are all the genres of Alanis Morisette songs? Do not repeat!”)。 -
输出:
一个或多个Table
对象,指明问题相关的表类别。例如,对于这个问题,可能返回[Table(name="Music")]
。
3. 定义映射函数,将类别映射到具体的 SQL 表名
def get_tables(categories: List[Table]) -> List[str]:
"""将类别名称映射到对应的 SQL 表名列表."""
tables = []
for category in categories:
if category.name == "Music":
tables.extend(
[
"Album",
"Artist",
"Genre",
"MediaType",
"Playlist",
"PlaylistTrack",
"Track",
]
)
elif category.name == "Business":
tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
return tables
-
作用:
此函数接收前面提取链返回的Table
对象列表,根据类别名称映射到具体的 SQL 表名列表:- 如果类别是 “Music”,则映射为音乐相关的多个表(如 Album、Artist、Genre 等)。
- 如果类别是 “Business”,则映射为商业相关的表(如 Customer、Invoice 等)。
-
举例:
- 输入:
[Table(name="Music")]
- 输出:
["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"]
- 输入:
4. 组合提取链和映射函数
table_chain = category_chain | get_tables
-
作用:
利用管道操作符(|
)将category_chain
和get_tables
组合起来。整个链条(table_chain
)的作用就是:接收用户问题 → 利用 LLM 提取相关类别 → 将类别映射为具体的 SQL 表名列表。 -
输入:
一个包含用户问题的字典(例如{"question": "..."}
)。 -
输出:
一个 SQL 表名列表,如上例中的音乐相关表名列表。
5. 定义自定义 SQL 提示模板
custom_prompt = PromptTemplate(
input_variables=["dialect", "input", "table_info", "top_k"],
template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Don't limit the results to {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
-
作用:
该模板为生成 SQL 查询提供指令:- 指定 SQL 方言(如 MySQL、PostgreSQL 等)。
- 提供数据库的表结构信息。
- 告诉 LLM 根据问题(
{input}
)生成正确的 SQL 查询。 - 不要限制返回行数,并且只返回 SQL 语句本身,无额外说明。
-
举例:
如果传入:dialect
: “SQLite”input
: “What are all the genres of Alanis Morisette songs? Do not repeat!”table_info
: 数据库所有表的结构信息top_k
: 55
那么模板会指导 LLM 输出类似下面的 SQL 查询(实际内容由 LLM 根据 schema 生成):
SELECT DISTINCT Genre.Name FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId JOIN Artist ON Track.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alanis Morisette';
6. 创建 SQL 查询链并绑定固定参数
query_chain = create_sql_query_chain(table_extractor_llm, db, prompt=custom_prompt)
- 作用:
利用同一个 LLM 实例和预定义的 SQL 提示模板,创建一个 SQL 查询链。该链将根据数据库表结构(db
)和用户问题生成 SQL 查询。
bound_chain = query_chain.bind(
dialect=db.dialect,
table_info=db.get_table_info(),
top_k=55
)
-
作用:
通过bind
方法将一些固定的参数绑定到 SQL 查询链上:dialect
:数据库使用的 SQL 方言。table_info
:数据库中所有表的结构信息。top_k
:限制返回的行数,这里设定为 55 行,但指令中说明不要限制,所以其实这个参数仅作为提示的一部分。
-
输出:
得到一个参数已经固定的 SQL 查询链bound_chain
,后续调用时只需要传入用户问题(以及其他动态数据)。
7. 调整输入数据格式
table_chain = (lambda x: {**x, "input": x["question"]}) | table_chain
-
作用:
这行代码先用一个 lambda 函数将输入字典中的"question"
键复制一份到"input"
键,目的是统一变量名称(因为上面的 SQL 提示模板要求有input
变量)。然后再将结果传递给table_chain
。 -
举例:
- 输入:
{"question": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
- lambda 输出:
{"question": "What are all the genres of Alanis Morisette songs? Do not repeat!", "input": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
- 最终经过 table_chain 输出: 列表形式的 SQL 表名,如
["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"]
- 输入:
8. 组合整个链条,生成最终 SQL 查询
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | bound_chain
-
作用:
这里使用RunnablePassthrough.assign
将从table_chain
得到的 SQL 表名列表赋值到上下文中的table_names_to_use
键,然后通过管道传递给已经绑定参数的 SQL 查询链bound_chain
。这一步确保了在生成 SQL 查询时,上下文中不仅包含用户的原始问题,还包含了与之相关的 SQL 表名信息。 -
输入:
包含用户问题的字典(经过前面的处理已包含"input"
键)。 -
输出:
经过整个链条处理后,输出最终生成的 SQL 查询语句。
9. 调用链条并生成 SQL 查询
query = full_chain.invoke(
{"question": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
)
print(query)
-
作用:
这里将包含用户问题的字典传递给full_chain
。整个流程如下:- 提取表类别:首先通过
table_chain
将"question"
转换为"input"
,然后利用 LLM 提取出与问题相关的类别(预期为 “Music”)。 - 映射表名称:根据类别映射出所有与音乐相关的 SQL 表名。
- 生成 SQL 查询:利用绑定好的
bound_chain
(包含 SQL 模板、数据库 schema 信息等),结合用户的问题和上下文信息,生成一个正确的 SQL 查询。
- 提取表类别:首先通过
-
输出举例:
假设 LLM 理解问题并生成的 SQL 查询可能为:SELECT DISTINCT Genre.Name FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId JOIN Artist ON Track.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alanis Morisette';
(实际生成的 SQL 语句会依赖于 LLM 的理解和数据库的 schema 信息。)
最后运行这个SQL语句
db.run(query)
输出:
总结
这段代码整体实现了一个智能化的数据查询过程:
- 输入: 用户问题(如关于 Alanis Morisette 歌曲的查询)。
- 内部处理:
- 利用 LLM 提取相关 SQL 表类别。
- 根据类别映射出具体的 SQL 表名称。
- 结合数据库的表结构和预定义的 SQL 提示模板,生成正确的 SQL 查询语句。
- 输出: 一条 SQL 查询语句,用来从数据库中获取答案。
这种链式结构使得整个流程模块化、可扩展:可以分别替换提取逻辑、映射逻辑和 SQL 查询生成逻辑,非常适合在实际应用中自动生成数据库查询。