当前位置: 首页 > article >正文

利用 LangChain 和一个大语言模型(LLM)构建一个链条,自动从用户输入的问题中提取相关的 SQL 表信息,再生成对应的 SQL 查询

示例代码:

from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic

# 系统消息,要求 LLM 返回与问题相关的 SQL 表类别
system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""

# 初始化 LLM 模型
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)

# 创建提取链:将用户问题转换为 Table 模型的实例
category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)

# 定义一个函数,根据 Table 对象映射到具体的 SQL 表名
def get_tables(categories: List[Table]) -> List[str]:
    """将类别名称映射到对应的 SQL 表名列表."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables

# 将类别提取链与映射函数组合,得到一个返回 SQL 表名列表的链
table_chain = category_chain | get_tables 

# 定义自定义 SQL 提示模板,用于生成 SQL 查询
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Don't limit the results to {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)

# 创建 SQL 查询链
query_chain = create_sql_query_chain(table_extractor_llm, db, prompt=custom_prompt)

# 利用 bind 将固定参数绑定到 SQL 查询链中
bound_chain = query_chain.bind(
    dialect=db.dialect,
    table_info=db.get_table_info(),
    top_k=55
)

# 将输入中的 "question" 键复制到 "input" 键,同时保留原始数据
table_chain = (lambda x: {**x, "input": x["question"]}) | table_chain

# 使用 RunnablePassthrough.assign 将提取到的表名添加到上下文中,然后与 SQL 查询链组合
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | bound_chain

# 调用整个链,生成 SQL 查询
query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
)
print(query)

这段代码主要展示如何利用 LangChain 和一个大语言模型(LLM)构建一个链条,自动从用户输入的问题中提取相关的 SQL 表信息,再生成对应的 SQL 查询。下面我将分步详细解释每个部分的作用,并通过举例说明每段代码的输入和输出。


1. 定义系统消息和初始化 LLM 模型

system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""
  • 作用:
    这段系统消息告诉 LLM:请根据用户的问题返回与问题相关的 SQL 表类别,这里限定了两类——“Music”和“Business”。

  • 举例:
    如果用户的问题涉及音乐信息(例如歌曲、专辑等),那么 LLM 会返回 “Music”;如果涉及客户、发票等信息,则返回 “Business”。

table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)
  • 作用:
    初始化一个 LLM 模型(此处使用 llama3-70b-8192,由 groq 提供,温度设为 0 以保证回答确定性),后续会用这个模型进行类别提取和 SQL 查询生成。

  • 输出:
    返回一个 LLM 实例 table_extractor_llm


2. 创建提取链:从问题中抽取相关表类别

category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
  • 作用:
    这里利用 create_extraction_chain_pydantic 创建了一个链,该链的任务是将用户输入的问题转换为一个或多个符合 Pydantic 模型 Table 的实例。也就是说,LLM 会分析问题并输出如 Table(name="Music")Table(name="Business") 的结果。

  • 输入:
    用户问题(例如 “What are all the genres of Alanis Morisette songs? Do not repeat!”)。

  • 输出:
    一个或多个 Table 对象,指明问题相关的表类别。例如,对于这个问题,可能返回 [Table(name="Music")]
    在这里插入图片描述


3. 定义映射函数,将类别映射到具体的 SQL 表名

def get_tables(categories: List[Table]) -> List[str]:
    """将类别名称映射到对应的 SQL 表名列表."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables
  • 作用:
    此函数接收前面提取链返回的 Table 对象列表,根据类别名称映射到具体的 SQL 表名列表:

    • 如果类别是 “Music”,则映射为音乐相关的多个表(如 Album、Artist、Genre 等)。
    • 如果类别是 “Business”,则映射为商业相关的表(如 Customer、Invoice 等)。
  • 举例:

    • 输入: [Table(name="Music")]
    • 输出: ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"]

4. 组合提取链和映射函数

table_chain = category_chain | get_tables
  • 作用:
    利用管道操作符(|)将 category_chainget_tables 组合起来。整个链条(table_chain)的作用就是:接收用户问题 → 利用 LLM 提取相关类别 → 将类别映射为具体的 SQL 表名列表。

  • 输入:
    一个包含用户问题的字典(例如 {"question": "..."})。

  • 输出:
    一个 SQL 表名列表,如上例中的音乐相关表名列表。
    在这里插入图片描述


5. 定义自定义 SQL 提示模板

custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Don't limit the results to {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
  • 作用:
    该模板为生成 SQL 查询提供指令:

    • 指定 SQL 方言(如 MySQL、PostgreSQL 等)。
    • 提供数据库的表结构信息。
    • 告诉 LLM 根据问题({input})生成正确的 SQL 查询。
    • 不要限制返回行数,并且只返回 SQL 语句本身,无额外说明。
  • 举例:
    如果传入:

    • dialect: “SQLite”
    • input: “What are all the genres of Alanis Morisette songs? Do not repeat!”
    • table_info: 数据库所有表的结构信息
    • top_k: 55
      那么模板会指导 LLM 输出类似下面的 SQL 查询(实际内容由 LLM 根据 schema 生成):
    SELECT DISTINCT Genre.Name FROM Track
    JOIN Genre ON Track.GenreId = Genre.GenreId
    JOIN Artist ON Track.ArtistId = Artist.ArtistId
    WHERE Artist.Name = 'Alanis Morisette';
    

6. 创建 SQL 查询链并绑定固定参数

query_chain = create_sql_query_chain(table_extractor_llm, db, prompt=custom_prompt)
  • 作用:
    利用同一个 LLM 实例和预定义的 SQL 提示模板,创建一个 SQL 查询链。该链将根据数据库表结构(db)和用户问题生成 SQL 查询。
bound_chain = query_chain.bind(
    dialect=db.dialect,
    table_info=db.get_table_info(),
    top_k=55
)
  • 作用:
    通过 bind 方法将一些固定的参数绑定到 SQL 查询链上:

    • dialect:数据库使用的 SQL 方言。
    • table_info:数据库中所有表的结构信息。
    • top_k:限制返回的行数,这里设定为 55 行,但指令中说明不要限制,所以其实这个参数仅作为提示的一部分。
  • 输出:
    得到一个参数已经固定的 SQL 查询链 bound_chain,后续调用时只需要传入用户问题(以及其他动态数据)。


7. 调整输入数据格式

table_chain = (lambda x: {**x, "input": x["question"]}) | table_chain
  • 作用:
    这行代码先用一个 lambda 函数将输入字典中的 "question" 键复制一份到 "input" 键,目的是统一变量名称(因为上面的 SQL 提示模板要求有 input 变量)。然后再将结果传递给 table_chain

  • 举例:

    • 输入: {"question": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
    • lambda 输出: {"question": "What are all the genres of Alanis Morisette songs? Do not repeat!", "input": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
    • 最终经过 table_chain 输出: 列表形式的 SQL 表名,如 ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"]

8. 组合整个链条,生成最终 SQL 查询

full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | bound_chain
  • 作用:
    这里使用 RunnablePassthrough.assign 将从 table_chain 得到的 SQL 表名列表赋值到上下文中的 table_names_to_use 键,然后通过管道传递给已经绑定参数的 SQL 查询链 bound_chain。这一步确保了在生成 SQL 查询时,上下文中不仅包含用户的原始问题,还包含了与之相关的 SQL 表名信息。

  • 输入:
    包含用户问题的字典(经过前面的处理已包含 "input" 键)。

  • 输出:
    经过整个链条处理后,输出最终生成的 SQL 查询语句。


9. 调用链条并生成 SQL 查询

query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs? Do not repeat!"}
)
print(query)
  • 作用:
    这里将包含用户问题的字典传递给 full_chain。整个流程如下:

    1. 提取表类别:首先通过 table_chain"question" 转换为 "input",然后利用 LLM 提取出与问题相关的类别(预期为 “Music”)。
    2. 映射表名称:根据类别映射出所有与音乐相关的 SQL 表名。
    3. 生成 SQL 查询:利用绑定好的 bound_chain(包含 SQL 模板、数据库 schema 信息等),结合用户的问题和上下文信息,生成一个正确的 SQL 查询。
  • 输出举例:
    假设 LLM 理解问题并生成的 SQL 查询可能为:

    SELECT DISTINCT Genre.Name
    FROM Track
    JOIN Genre ON Track.GenreId = Genre.GenreId
    JOIN Artist ON Track.ArtistId = Artist.ArtistId
    WHERE Artist.Name = 'Alanis Morisette';
    

    (实际生成的 SQL 语句会依赖于 LLM 的理解和数据库的 schema 信息。)


最后运行这个SQL语句

db.run(query)

输出:
在这里插入图片描述

总结

这段代码整体实现了一个智能化的数据查询过程:

  • 输入: 用户问题(如关于 Alanis Morisette 歌曲的查询)。
  • 内部处理:
    1. 利用 LLM 提取相关 SQL 表类别。
    2. 根据类别映射出具体的 SQL 表名称。
    3. 结合数据库的表结构和预定义的 SQL 提示模板,生成正确的 SQL 查询语句。
  • 输出: 一条 SQL 查询语句,用来从数据库中获取答案。

这种链式结构使得整个流程模块化、可扩展:可以分别替换提取逻辑、映射逻辑和 SQL 查询生成逻辑,非常适合在实际应用中自动生成数据库查询。


http://www.kler.cn/a/568554.html

相关文章:

  • 基于MATLAB 的GUI设计
  • 【2025-03-02】基础算法:二叉树 相同 对称 平衡 右视图
  • Pytorch实现之结合mobilenetV2和FPN的GAN去雾算法
  • Windows搭建jenkins服务
  • 【Linux】【网络】不同子网下的客户端和服务器通信其它方式
  • DeepSeek-R1 大模型实战:腾讯云 HAI 平台 3 分钟极速部署指南
  • .net开源商城_C#开源商城源码_.netcore开源商城多少钱
  • 机器学习:线性回归,梯度下降,多元线性回归
  • Django数据迁移
  • 从零开始用react + tailwindcss + express + mongodb实现一个聊天程序(八) 聊天框用户列表
  • Java 网络八股(2) TCP6大核心机制/异常处理
  • 基于单片机的智能宿舍管理系统(论文+源码)
  • 【3天快速入门WPF】11-附加属性
  • 【MongoDB】在Windows11下安装与使用
  • 蓝桥杯web第三天
  • h5 IOS端渐变的兼容问题 渐变实现弧形效果
  • Ubuntu 下 nginx-1.24.0 源码分析 - ngx_init_cycle 函数 - 详解(9)
  • LeetCode 2353. 设计食物评分系统题解
  • Qt 的 Lambda 捕获局部变量导致 UI 更新异常的分析与解决
  • Solar2月应急响应公益月赛