当前位置: 首页 > article >正文

构造一个工具(TravelSQLAgentTool),利用大语言模型(例如 Llama 模型)来完成 SQL 查询代理工具

完整代码:

from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
import os
from dotenv import load_dotenv
load_dotenv()


class TravelSQLAgentTool:
    """
    A tool for interacting with a travel-related SQL database using an LLM (Language Model) to generate and execute SQL queries.

    This tool enables users to ask travel-related questions, which are transformed into SQL queries by a language model.
    The SQL queries are executed on the provided SQLite database, and the results are processed by the language model to
    generate a final answer for the user.

    Attributes:
        sql_agent_llm (LLAMA): An instance of a LLAMA language model used to generate and process SQL queries.
        system_role (str): A system prompt template that guides the language model in answering user questions based on SQL query results.
        db (SQLDatabase): An instance of the SQL database used to execute queries.
        chain (RunnablePassthrough): A chain of operations that creates SQL queries, executes them, and generates a response.

    Methods:
        __init__: Initializes the TravelSQLAgentTool by setting up the language model, SQL database, and query-answering pipeline.
    """

    def __init__(self, llm: str, sqldb_directory: str, llm_temerature: float) -> None:
        """
        Initializes the TravelSQLAgentTool with the necessary configurations.

        Args:
            llm (str): The name of the language model to be used for generating and interpreting SQL queries.
            sqldb_directory (str): The directory path where the SQLite database is stored.
            llm_temerature (float): The temperature setting for the language model, controlling response randomness.
        """
        #  初始化 Llama 模型,使用 Groq 后端
        #  "llama-3.3-70b-specdec"
        self.sql_agent_llm = init_chat_model(llm, model_provider="groq", temperature=llm_temerature)

        self.db = SQLDatabase.from_uri(
            f"sqlite:///{sqldb_directory}")
#         print(self.db.get_usable_table_names())

        # 定义自定义提示模板,用于生成 SQL 查询
        custom_prompt = PromptTemplate(
            input_variables=["dialect", "input", "table_info", "top_k"],
            template="""You are a SQL expert using {dialect}.
        Given the following table schema:
        {table_info}
        Generate a syntactically correct SQL query to answer the question: "{input}".
        Do not Limit {top_k} the results.
        Return only the SQL query without any additional commentary or Markdown formatting.
        """
        )

        # write_query
        write_query = create_sql_query_chain(self.sql_agent_llm, self.db,prompt=custom_prompt)
        execute_query = QuerySQLDataBaseTool(db=self.db)

        # answer
        self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
            Question: {question}\n
            SQL Query: {query}\n
            SQL Result: {result}\n
            Answer:
            """
        answer_prompt = PromptTemplate.from_template(
            self.system_role)
        answer = answer_prompt | self.sql_agent_llm | StrOutputParser()
        
        # 8. 定义一个调试链 debug_chain,用于打印 write_query 生成的 SQL 查询。
        #    这里使用 RunnablePassthrough 执行一个 lambda 函数:
        #    lambda data: (print("write_query execution result:", data["query"]), data)[1]
        #    解释:先打印 data 字典中 "query" 对应的 SQL 语句,然后将原始 data 返回,以便后续链继续处理。
        debug_chain = RunnablePassthrough(lambda data: (print("write_query execution result:", data["query"]), data)[1])
        
        # 9. 构造完整的处理链 chain_ex:
        #    - 首先调用 write_query 生成 SQL 查询,并将结果存储到字典的 "query" 字段中;
        #    - 接着通过 debug_chain 打印出生成的 SQL 查询;
        #    - 然后使用 execute_query 执行 SQL 查询,结果存入 "result" 字段(此处利用 itemgetter 提取 "query" 字段后传递给执行工具);
        #    - 最后将问题、SQL 查询以及查询结果传递给 answer 链,生成最终回答。
        # chain
        chain_ex = (
            RunnablePassthrough.assign(query=write_query)
            | debug_chain
            | RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
            | answer
        )
        
        # 利用 bind 绑定固定参数到链中
        bound_chain = chain_ex.bind(
            dialect=self.db.dialect,
            table_info=self.db.get_table_info(),
            top_k=55
        )
        self.chain = bound_chain
        
sqldb_directory = here("data/Chinook.db")
query = "and calculate the number of all Playlist"

@tool
def query_travel_sqldb(query: str) -> str:

    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    agent = TravelSQLAgentTool(
        llm="llama3-70b-8192", #TOOLS_CFG.travel_sqlagent_llm
        sqldb_directory= sqldb_directory,#TOOLS_CFG.travel_sqldb_directory
        llm_temerature=0 #TOOLS_CFG.travel_sqlagent_llm_temperature
    )
    response = agent.chain.invoke({"question": query})
    return response

print(query_travel_sqldb(query))

下面我将分步骤、用通俗易懂的语言详细解释这段代码的含义和作用,并举例说明各部分的工作流程。


1. 整体概述

这段代码的主要目的是构造一个工具(TravelSQLAgentTool),它能利用大语言模型(LLM,例如 Llama 模型)来完成以下任务:

  1. 根据用户提出的“旅行”相关问题生成一条 SQL 查询语句;
  2. 将这条 SQL 查询语句发送给一个 SQLite 数据库执行,获得查询结果;
  3. 最后,再利用语言模型将查询结果转换成用户易懂的答案

同时,为了调试方便,在生成 SQL 查询后会打印出这条查询语句,这样你就能看到 LLM 生成的 SQL 语句内容。


2. 模块导入和环境设置

from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
import os
from dotenv import load_dotenv
load_dotenv()
  • 导入各个模块
    这些模块主要用于创建提示模板、构造数据处理链、连接 SQL 数据库以及注册工具函数。

  • 加载环境变量
    load_dotenv() 用于加载 .env 文件中的环境变量,方便管理敏感信息(如 API 密钥)。


3. TravelSQLAgentTool 类的定义

这个类封装了整个查询流程。我们逐行看它的初始化方法 __init__

3.1 初始化语言模型和数据库

self.sql_agent_llm = init_chat_model(llm, model_provider="groq", temperature=llm_temerature)
  • 作用:调用 init_chat_model 初始化一个 LLM 模型,这里传入的 llm 参数(例如 “llama3-70b-8192”)指定使用哪个模型。
  • 参数解释
    • model_provider="groq":表示使用 Groq 后端;
    • temperature=llm_temerature:温度参数决定了模型回答的随机性(0 表示确定性很高)。
self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
  • 作用:通过给定的数据库路径(如 "data/Chinook.db")构造一个 SQLite 数据库连接实例,后续执行 SQL 查询时会使用它。

3.2 定义生成 SQL 查询的提示模板

custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
        Given the following table schema:
        {table_info}
        Generate a syntactically correct SQL query to answer the question: "{input}".
        Do not Limit {top_k} the results.
        Return only the SQL query without any additional commentary or Markdown formatting.
        """
)
  • 作用:构造一个模板,指导语言模型如何生成 SQL 查询。
  • 模板内容说明
    • {dialect}:SQL 的方言(例如 SQLite)。
    • {table_info}:数据库中各表的结构信息;
    • {input}:用户的问题;
    • {top_k}:限制查询返回记录条数的参数(不过这里实际上是说明“不要限制”)。
  • 举例
    如果用户的问题是“计算所有 Playlist 的数量”,模板就会要求模型生成类似:
    SELECT COUNT(*) FROM Playlist;
    
    的 SQL 查询。

3.3 创建 SQL 查询和执行链

write_query = create_sql_query_chain(self.sql_agent_llm, self.db, prompt=custom_prompt)
execute_query = QuerySQLDataBaseTool(db=self.db)
  • write_query

    • 作用:利用前面定义的提示模板,让语言模型根据问题生成 SQL 查询。
    • 举例:针对问题“calculate the number of all Playlist”,write_query 可能生成:
      SELECT COUNT(*) FROM Playlist;
      
  • execute_query

    • 作用:定义了一个工具,将生成的 SQL 查询发送到数据库执行,并返回查询结果。

3.4 定义生成最终答案的系统提示和链

self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
            Question: {question}\n
            SQL Query: {query}\n
            SQL Result: {result}\n
            Answer:
            """
answer_prompt = PromptTemplate.from_template(self.system_role)
answer = answer_prompt | self.sql_agent_llm | StrOutputParser()
  • system_role

    • 作用:定义了一个系统提示,告诉模型如何根据用户问题、生成的 SQL 查询和 SQL 查询结果生成最终的答案。
    • 举例:如果 SQL 查询返回的结果为 25,系统提示就会告诉模型“用户问:计算所有 Playlist 数量,SQL 查询是……,查询结果是 25,请生成一个最终的自然语言答案”。
  • answer 链

    • 通过将提示、语言模型和一个输出解析器(将模型的文本输出转换成纯文本)串联起来,构成生成最终答案的链。

3.5 定义调试链(debug_chain)

debug_chain = RunnablePassthrough(lambda data: (print("write_query execution result:", data["query"]), data)[1])
  • 作用
    这个调试链在数据传递过程中打印出 data["query"] 的内容,即打印出由 write_query 生成的 SQL 查询语句,然后把原始数据继续返回给后续链的步骤。
  • 举例说明
    假设上一步生成的 data 为:
    {"query": "SELECT COUNT(*) FROM Playlist;"}
    
    那么这个 lambda 函数会先打印:
    write_query execution result: SELECT COUNT(*) FROM Playlist;
    
    然后返回原始数据 {"query": "SELECT COUNT(*) FROM Playlist;"},不对数据做任何修改。

3.6 构造完整的链(chain_ex)

chain_ex = (
    RunnablePassthrough.assign(query=write_query)
    | debug_chain
    | RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
    | answer
)

这段代码构造了一个数据处理流水线,每一步的含义如下:

  1. RunnablePassthrough.assign(query=write_query)

    • 作用:调用 write_query 生成 SQL 查询,并将生成的查询结果存储到数据字典的 "query" 键中。
    • 举例:生成的数据可能为:
      {"query": "SELECT COUNT(*) FROM Playlist;"}
      
  2. | debug_chain

    • 作用:将上一步生成的数据传入 debug_chain,打印出 SQL 查询,同时不改变数据。
    • 举例:会打印出上述 SQL 查询语句。
  3. | RunnablePassthrough.assign(result=itemgetter(“query”) | execute_query)

    • 作用:利用 itemgetter("query") 从数据字典中提取 SQL 查询语句,然后将其传递给 execute_query 工具,执行 SQL 查询,并将执行结果存储到数据字典的 "result" 键中。
    • 举例:假如数据库中 Playlist 表有 25 条记录,则查询结果可能为:
      {"query": "SELECT COUNT(*) FROM Playlist;", "result": 25}
      
  4. | answer

    • 作用:将包含问题、SQL 查询和查询结果的数据传入答案生成链,得到最终的自然语言回答。

3.7 绑定固定参数

bound_chain = chain_ex.bind(
    dialect=self.db.dialect,
    table_info=self.db.get_table_info(),
    top_k=55
)
self.chain = bound_chain
  • 作用
    • bind 方法将一些固定参数(如 SQL 的方言、数据库表结构信息、以及 top_k 参数)绑定到流水线中,确保每次调用链时这些参数都自动传递进去。
    • 绑定后,将整个链赋值给 self.chain,这样后续调用就会按照这个步骤顺序执行。

4. 全局部分与工具函数定义

4.1 指定数据库路径和示例查询

sqldb_directory = here("data/Chinook.db")
query = "and calculate the number of all Playlist"
  • 作用
    • 使用 here 函数定位到数据库文件所在的路径(相对路径 “data/Chinook.db”);
    • 定义一个示例查询字符串,意思是“计算所有 Playlist 的数量”。

4.2 定义工具函数 query_travel_sqldb

@tool
def query_travel_sqldb(query: str) -> str:
    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    agent = TravelSQLAgentTool(
        llm="llama3-70b-8192", # 指定使用的语言模型
        sqldb_directory= sqldb_directory, # 数据库文件路径
        llm_temerature=0 # LLM 温度设为 0,表示回答比较确定,不引入随机性
    )
    response = agent.chain.invoke({"question": query})
    return response
  • 作用
    • 这个函数使用 @tool 装饰器注册成工具函数,方便外部调用。
    • 内部实例化一个 TravelSQLAgentTool 对象,并传入模型名称、数据库路径及温度参数;
    • 然后调用构造好的链(agent.chain)的 invoke 方法,把用户的问题(键名为 "question")传入整个链进行处理,最终得到回答。
  • 举例说明
    当你调用 query_travel_sqldb("and calculate the number of all Playlist") 时:
    1. 首先,write_query 生成 SQL 查询 SELECT COUNT(*) FROM Playlist;
    2. 调试链打印出这个 SQL 查询;
    3. 执行查询后,假设结果为 25;
    4. 最后,语言模型根据系统提示生成答案,比如“总共有 25 个 Playlist。”。

4.3 最后执行并打印结果

print(query_travel_sqldb(query))
  • 作用
    调用上面定义的工具函数,并将最终生成的答案打印出来。

5. 总结

这段代码构建了一个基于 LLM 的 SQL 查询代理工具,其工作流程为:

  1. 用户输入一个查询问题;
  2. 工具通过自定义提示模板调用语言模型生成 SQL 查询;
  3. 生成的 SQL 查询在调试链中打印出来以便检查;
  4. SQL 查询被发送到 SQLite 数据库执行,获得结果;
  5. 语言模型根据问题、SQL 查询和结果生成最终的自然语言答案;
  6. 最终答案返回并打印出来。

这种设计使得查询流程高度自动化,同时又便于调试和检查中间步骤,帮助理解每一步是如何工作的。


http://www.kler.cn/a/566535.html

相关文章:

  • XR应用测试:探索虚拟与现实的边界
  • unity pico开发 一:环境准备
  • 核弹级技术革命——搭配deepseek-r1满血版的腾讯云ai助手(codex)仅用14天独立开发出适配ARM架构的微内核操作系统!
  • 遇到liunx服务器IO负载,读IO流量峰值347MB/s,排查并解决。
  • 【STM32F103ZET6——库函数】4.串口通讯
  • Web3.py 入门笔记
  • 用大白话解释基础框架Spring Boot——像“装修套餐”一样简单
  • 2025年如何实现安卓、iOS、鸿蒙跨平台开发
  • vscode中使用PlatformIO创建工程加载慢
  • xss自动化扫描工具-DALFox
  • 阿里云的 ECS(Elastic Compute Service)实例
  • CD9.【C++ Dev】对“auto替换为变量实际的类型”的解释
  • AI大模型-提示工程学习笔记18—推理与行动的协同 (ReAct)
  • Go语言学习笔记(四)
  • idea导入新项目pom报错设置
  • PHP面试题--后端部分
  • MyBatis教程
  • 本地部署Deepseek+Cherry Studio
  • sklearn中的决策树-分类树:实例-分类树在合成数据集上的表现
  • 【MySQL】服务正在启动或停止中,请稍候片刻后再试一次【解决方案】