当前位置: 首页 > article >正文

在 crag 中用 LangGraph 进行评分知识精炼-下

在上一次给大家展示了基本的 Rag 检索过程,着重描述了增强检索中的知识精炼和补充检索,这些都是 crag 的一部分,这篇内容结合 langgraph 给大家展示通过检索增强生成(Retrieval-Augmented Generation, RAG)的工作流,来处理问题并生成答案。
好了,下面我们直接开始代码。

定义graph

首先我们还是先定义 web 搜索工具,这里的 web 搜索工具作为补充检索很重要:

### Search
from langchain_community.tools.tavily_search import TavilySearchResults

web_search_tool = TavilySearchResults(k=3)

然后定义我们 graph 的基本状态结构:

from typing import List
from typing_extensions import TypedDict
class GraphState(TypedDict):
    """
    表示我们图的状态。

    属性:
        question: 问题
        generation: 大语言模型生成的内容
        web_search: 是否使用web搜索
        documents: 检索出来的文档列表
    """
    question: str
    generation: str
    web_search: str
    documents: List[str]

定义工作流节点

定义检索与问题相关的文档的节点:

def retrieve(state):
    """
    检索文档

    参数:
        state (dict): 当前图状态

    返回:
        state (dict): 更新后的状态,包含检索到的文档
    """
    print("---RETRIEVE---")
    question = state["question"]

    # 檢索相关文档
    documents = retriever.get_relevant_documents(question)
    return {"documents": documents, "question": question}

定义基于检索到的文档生成答案的节点:

def generate(state):
    """
    生成答案

    参数:
        state (dict): 当前图状态

    返回:
        state (dict): 更新后的状态,包含生成的答案
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

定义评估检索到的文档是否与问题相关的节点:

def grade_documents(state):
    """
    评估检索到的文档是否与问题相关

    参数:
        state (dict): 当前图状态

    返回:
        state (dict): 更新后的状态,仅包含相关文档
    """

    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            web_search = "Yes"
            continue
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

重写问题以生成更好的查询:

def transform_query(state):
    """
    重写问题以生成更好的查询

    参数:
        state (dict): 当前图状态

    返回:
        state (dict): 更新后的状态,包含重写后的问题
    """

    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}

基于重写后的问题进行网络搜索,并将结果添加到文档列表中。

def web_search(state):
    """
    基于重写后的问题进行网络搜索

    参数:
        state (dict): 当前图状态

    返回:
        state (dict): 更新后的状态,包含网络搜索结果
    """

    print("---WEB SEARCH---")
    question = state["question"]
    documents = state["documents"]

    # Web search
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    documents.append(web_results)

    return {"documents": documents, "question": question}

最后定义我们的条件边,决定是生成答案还是重写问题。

def decide_to_generate(state):
    """
    决定是生成答案还是重写问题

    参数:
        state (dict): 当前图状态

    返回:
        str: 下一个节点的决策("transform_query" 或 "generate")
    """

    print("---ASSESS GRADED DOCUMENTS---")
    # state["question"]
    web_search = state["web_search"]
    # state["documents"]

    if web_search == "Yes":
        # All documents have been filtered check_relevance
        # We will re-generate a new query
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"

这里我们稍微等一下,做个小总结,上面定义的一系列的工作流节点主要是下面几个用途:

1.检索文档: 调用 retrieve 函数,获取与问题相关的文档。
2:评估文档相关性: 调用 grade_documents 函数,过滤掉不相关的文档。
3:决定下一步:如果所有文档都不相关,调用 transform_query 重写问题,然后进行网络搜索(web_search)。如果有相关文档,调用 generate 生成答案。生成答案: 调用 generate 函数,基于相关文档生成最终答案。

创建工作流

我们把上面定义好的节点和边组织在一起:

from langgraph.graph import END, StateGraph, START

# 创建工作流
workflow = StateGraph(GraphState)

# 定义节点
workflow.add_node("retrieve", retrieve)  # 检索文档
workflow.add_node("grade_documents", grade_documents)  # 评估文档相关性
workflow.add_node("generate", generate)  # 生成答案
workflow.add_node("transform_query", transform_query)  # 重写问题
workflow.add_node("web_search_node", web_search)  # 网络搜索

# 构建工作流
workflow.add_edge(START, "retrieve")  # 从 START 到 retrieve
workflow.add_edge("retrieve", "grade_documents")  # 从 retrieve 到 grade_documents
workflow.add_conditional_edges(
    "grade_documents",  # 从 grade_documents 出发
    decide_to_generate,  # 根据 decide_to_generate 的返回值决定下一步
    {
        "transform_query": "transform_query",  # 如果返回 "transform_query",跳转到 transform_query 节点
        "generate": "generate",  # 如果返回 "generate",跳转到 generate 节点
    },
)
workflow.add_edge("transform_query", "web_search_node")  # 从 transform_query 到 web_search_node
workflow.add_edge("web_search_node", "generate")  # 从 web_search_node 到 generate
workflow.add_edge("generate", END)  # 从 generate 到 END

# 编译工作流
app = workflow.compile()

from IPython.display import display, Image
display(Image(graph.get_graph().draw_mermaid_png()))

得到下面的图形化结果:
在这里插入图片描述
我们这里就完成了整个工作流的逻辑框架,下面我们来调用这个 graph 试一下,看能得出什么结果:

from pprint import pprint

inputs = {"question": "agent memory 的类型有哪些?"}
for output in graph.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        pprint(f"value '{value}':")

    pprint("\n---\n")

在这里插入图片描述
我们分析执行流程日志就可以发现,我们采用一个基于状态的工作流 (StateGraph),通过不同的节点和边来处理问题并生成答案的方式串联起了我们 crag 的整个流程,然后可以看到它是怎么来进行检索,然后怎么来调用工具的,到最后怎么完成知识精炼和补充,然后大模型返回给我们增强后的答案。大家可以根据上面的代码自己试一下。


http://www.kler.cn/a/529589.html

相关文章:

  • 通信易懂唠唠SOME/IP——SOME/IP协议简介
  • Linux-CentOS的yum源
  • 【教程】在CMT上注册账号并声明Conflicts
  • 【Python】第七弹---Python基础进阶:深入字典操作与文件处理技巧
  • Hot100之普通数组
  • Longformer:处理长文档的Transformer模型
  • 7 [拒绝Github投毒通过Sharp4SuoBrowser分析VisualStudio隐藏文件]
  • redis原理之数据结构
  • c语言二级注意事项
  • 使用 Numpy 自定义数据集,使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数
  • C29.【C++ Cont】STL库:动态顺序表(vector容器)
  • LeetCode //C - 567. Permutation in String
  • IM 即时通讯系统-42-基于netty实现的IM服务端,提供客户端jar包,可集成自己的登录系统
  • 【Redis】Redis 经典面试题解析:深入理解 Redis 的核心概念与应用
  • java基础概念63-多线程
  • 【xdoj-离散线上练习】T251(C++)
  • AI技术路线(marked)
  • LeetCode 344: 反转字符串
  • Zabbix 推送告警 消息模板 美化(钉钉Webhook机器人、邮件)
  • 无人机飞手光伏吊运、电力巡检、农林植保技术详解
  • kamailio的kamctl的使用
  • [c语言日寄]C语言类型转换规则详解
  • ZYNQ-AXI DMA+AXI-S FIFO回环学习
  • DirectShow过滤器开发-读视频文件过滤器(再写)
  • 本地缓存~
  • 功防世界 Web_php_include