在 crag 中用 LangGraph 进行评分知识精炼-下
在上一次给大家展示了基本的 Rag
检索过程,着重描述了增强检索中的知识精炼和补充检索,这些都是 crag
的一部分,这篇内容结合 langgraph
给大家展示通过检索增强生成(Retrieval-Augmented Generation, RAG
)的工作流,来处理问题并生成答案。
好了,下面我们直接开始代码。
定义graph
首先我们还是先定义 web
搜索工具,这里的 web
搜索工具作为补充检索很重要:
### Search
from langchain_community.tools.tavily_search import TavilySearchResults
web_search_tool = TavilySearchResults(k=3)
然后定义我们 graph
的基本状态结构:
from typing import List
from typing_extensions import TypedDict
class GraphState(TypedDict):
"""
表示我们图的状态。
属性:
question: 问题
generation: 大语言模型生成的内容
web_search: 是否使用web搜索
documents: 检索出来的文档列表
"""
question: str
generation: str
web_search: str
documents: List[str]
定义工作流节点
定义检索与问题相关的文档的节点:
def retrieve(state):
"""
检索文档
参数:
state (dict): 当前图状态
返回:
state (dict): 更新后的状态,包含检索到的文档
"""
print("---RETRIEVE---")
question = state["question"]
# 檢索相关文档
documents = retriever.get_relevant_documents(question)
return {"documents": documents, "question": question}
定义基于检索到的文档生成答案的节点:
def generate(state):
"""
生成答案
参数:
state (dict): 当前图状态
返回:
state (dict): 更新后的状态,包含生成的答案
"""
print("---GENERATE---")
question = state["question"]
documents = state["documents"]
# RAG generation
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
定义评估检索到的文档是否与问题相关的节点:
def grade_documents(state):
"""
评估检索到的文档是否与问题相关
参数:
state (dict): 当前图状态
返回:
state (dict): 更新后的状态,仅包含相关文档
"""
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
question = state["question"]
documents = state["documents"]
# Score each doc
filtered_docs = []
web_search = "No"
for d in documents:
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
web_search = "Yes"
continue
return {"documents": filtered_docs, "question": question, "web_search": web_search}
重写问题以生成更好的查询:
def transform_query(state):
"""
重写问题以生成更好的查询
参数:
state (dict): 当前图状态
返回:
state (dict): 更新后的状态,包含重写后的问题
"""
print("---TRANSFORM QUERY---")
question = state["question"]
documents = state["documents"]
# Re-write question
better_question = question_rewriter.invoke({"question": question})
return {"documents": documents, "question": better_question}
基于重写后的问题进行网络搜索,并将结果添加到文档列表中。
def web_search(state):
"""
基于重写后的问题进行网络搜索
参数:
state (dict): 当前图状态
返回:
state (dict): 更新后的状态,包含网络搜索结果
"""
print("---WEB SEARCH---")
question = state["question"]
documents = state["documents"]
# Web search
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
documents.append(web_results)
return {"documents": documents, "question": question}
最后定义我们的条件边,决定是生成答案还是重写问题。
def decide_to_generate(state):
"""
决定是生成答案还是重写问题
参数:
state (dict): 当前图状态
返回:
str: 下一个节点的决策("transform_query" 或 "generate")
"""
print("---ASSESS GRADED DOCUMENTS---")
# state["question"]
web_search = state["web_search"]
# state["documents"]
if web_search == "Yes":
# All documents have been filtered check_relevance
# We will re-generate a new query
print(
"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
)
return "transform_query"
else:
# We have relevant documents, so generate answer
print("---DECISION: GENERATE---")
return "generate"
这里我们稍微等一下,做个小总结,上面定义的一系列的工作流节点主要是下面几个用途:
1.检索文档: 调用 retrieve 函数,获取与问题相关的文档。
2:评估文档相关性: 调用 grade_documents 函数,过滤掉不相关的文档。
3:决定下一步:如果所有文档都不相关,调用 transform_query 重写问题,然后进行网络搜索(web_search)。如果有相关文档,调用 generate 生成答案。生成答案: 调用 generate 函数,基于相关文档生成最终答案。
创建工作流
我们把上面定义好的节点和边组织在一起:
from langgraph.graph import END, StateGraph, START
# 创建工作流
workflow = StateGraph(GraphState)
# 定义节点
workflow.add_node("retrieve", retrieve) # 检索文档
workflow.add_node("grade_documents", grade_documents) # 评估文档相关性
workflow.add_node("generate", generate) # 生成答案
workflow.add_node("transform_query", transform_query) # 重写问题
workflow.add_node("web_search_node", web_search) # 网络搜索
# 构建工作流
workflow.add_edge(START, "retrieve") # 从 START 到 retrieve
workflow.add_edge("retrieve", "grade_documents") # 从 retrieve 到 grade_documents
workflow.add_conditional_edges(
"grade_documents", # 从 grade_documents 出发
decide_to_generate, # 根据 decide_to_generate 的返回值决定下一步
{
"transform_query": "transform_query", # 如果返回 "transform_query",跳转到 transform_query 节点
"generate": "generate", # 如果返回 "generate",跳转到 generate 节点
},
)
workflow.add_edge("transform_query", "web_search_node") # 从 transform_query 到 web_search_node
workflow.add_edge("web_search_node", "generate") # 从 web_search_node 到 generate
workflow.add_edge("generate", END) # 从 generate 到 END
# 编译工作流
app = workflow.compile()
from IPython.display import display, Image
display(Image(graph.get_graph().draw_mermaid_png()))
得到下面的图形化结果:
我们这里就完成了整个工作流的逻辑框架,下面我们来调用这个 graph
试一下,看能得出什么结果:
from pprint import pprint
inputs = {"question": "agent memory 的类型有哪些?"}
for output in graph.stream(inputs):
for key, value in output.items():
# Node
pprint(f"Node '{key}':")
pprint(f"value '{value}':")
pprint("\n---\n")
我们分析执行流程日志就可以发现,我们采用一个基于状态的工作流 (StateGraph
),通过不同的节点和边来处理问题并生成答案的方式串联起了我们 crag
的整个流程,然后可以看到它是怎么来进行检索,然后怎么来调用工具的,到最后怎么完成知识精炼和补充,然后大模型返回给我们增强后的答案。大家可以根据上面的代码自己试一下。