LangGraph Source Code Analysis | Structured Output
Table of Contents
- The with_structured_output Method
  - Main Functionality
  - Core Parameters
  - Return Value
  - Core Logic
    - Binding Tools and Selecting a Parser
    - Parsing into Structured Output
  - Source Code
- Parsers
  - JsonOutputKeyToolsParser
    - Calling the Parent Class's parse_result
    - Source Code
  - PydanticToolsParser
    - Overview
    - _parse_obj: Parsing JSON into a Pydantic Model
    - parse_result: Parsing the LLM-Generated Data
    - get_format_instructions: Returning Format Instructions
    - Source Code
- Summary
The with_structured_output Method
Main Functionality
Allows the user to convert the model (LLM) output into a specific data format:
- a dict
- a JSON Schema
- a TypedDict
- a Pydantic class: Pydantic is a Python library for data validation and parsing. BaseModel is a base class in Pydantic used to define data models (BaseModel provides a declarative way to define a data model, including field types, validation rules, and default values).
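For readers unfamiliar with Pydantic, the core idea of a declarative data model with field validation can be sketched with the standard library alone. This is a toy stand-in, not Pydantic itself; the class name mirrors the examples later in the article:

```python
from dataclasses import dataclass, fields


@dataclass
class AnswerWithJustification:
    # A declarative data model: field names, types, and a default value.
    answer: str
    justification: str = "not provided"

    def __post_init__(self):
        # A crude stand-in for the validation Pydantic's BaseModel performs:
        # every declared field must actually hold a str.
        for f in fields(self):
            if not isinstance(getattr(self, f.name), str):
                raise TypeError(f"field {f.name!r} must be a str")


ok = AnswerWithJustification(answer="They weigh the same")
print(ok.justification)  # -> not provided

try:
    AnswerWithJustification(answer=42)
except TypeError as e:
    print("validation failed:", e)
```

Real Pydantic goes much further (type coercion, nested models, JSON schema generation), but the shape is the same: declare fields once, get validation at construction time.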
Core Parameters
schema: specifies the format of the model output. It can be:
- an OpenAI function/tool schema
- a JSON Schema
- a TypedDict class (supported since version 0.2.26)
- a Pydantic class: if a Pydantic class is provided, the model output is automatically converted into an instance of that class and its fields are validated.
include_raw: whether to return the raw output.
- If True, a dict with the three keys "raw", "parsed", and "parsing_error" is returned.
- If False, only the parsed structured output is returned (or a parsing error is raised).
Return Value
Returns a Runnable object. Given an input, it generates a model response and parses that response into data matching the specified schema:
- When include_raw=False:
  - if schema is a Pydantic class, a Pydantic object is returned;
  - otherwise, a dict of structured data is returned.
- When include_raw=True:
  - a dict with the three keys "raw", "parsed", and "parsing_error" is returned.
Core Logic
First, the model must support tool use, because the output is later formatted through bind_tools.
Binding Tools and Selecting a Parser

```python
llm = self.bind_tools([schema], tool_choice="any")
if isinstance(schema, type) and is_basemodel_subclass(schema):
    output_parser = PydanticToolsParser(
        tools=[cast(TypeBaseModel, schema)], first_tool_only=True
    )
else:
    key_name = convert_to_openai_tool(schema)["function"]["name"]
    output_parser = JsonOutputKeyToolsParser(
        key_name=key_name, first_tool_only=True
    )
```

- The output format schema is bound to the LLM as a tool.
- If schema is a Pydantic class, PydanticToolsParser is used for parsing.
- For any other schema type, JsonOutputKeyToolsParser is used.
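The parser-selection branch can be mimicked without LangChain installed. Everything below is a stand-in whose names merely mirror the real langchain_core utilities (`is_basemodel_subclass`, `convert_to_openai_tool`), to show which branch fires for which schema type:

```python
# Toy stand-ins; the real classes live in pydantic and langchain_core.
class BaseModel:          # plays the role of pydantic.BaseModel
    pass


class Joke(BaseModel):    # a user-supplied Pydantic-style schema
    pass


def is_basemodel_subclass(obj) -> bool:
    return isinstance(obj, type) and issubclass(obj, BaseModel)


def convert_to_openai_tool(schema) -> dict:
    # The real converter builds a full OpenAI tool spec; only the
    # "function.name" lookup matters for parser selection here.
    name = schema.__name__ if isinstance(schema, type) else schema["name"]
    return {"type": "function", "function": {"name": name}}


def select_parser(schema) -> str:
    if is_basemodel_subclass(schema):
        return "PydanticToolsParser"
    key_name = convert_to_openai_tool(schema)["function"]["name"]
    return f"JsonOutputKeyToolsParser(key_name={key_name!r})"


print(select_parser(Joke))                         # -> PydanticToolsParser
print(select_parser({"name": "joke", "parameters": {}}))
```

A Pydantic class keeps its identity through the parser (you get model instances back); any dict-style schema only carries a function name, so the parser keys off that name instead.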
Parsing into Structured Output

```python
if include_raw:
    parser_assign = RunnablePassthrough.assign(
        parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
    )
    parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
    parser_with_fallback = parser_assign.with_fallbacks(
        [parser_none], exception_key="parsing_error"
    )
    return RunnableMap(raw=llm) | parser_with_fallback
else:
    return llm | output_parser
```

- If include_raw is True, a more elaborate parser chain is used: the raw output is passed through, a fallback handles parse failures, and the exception is captured under "parsing_error".
- If include_raw is False, a simpler chain is used: the LLM piped directly into the output parser.
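The behaviour of the include_raw=True chain — keep the raw message, attach "parsed" and "parsing_error", and fall back to parsed=None on failure — can be sketched with plain functions. This is a simplification of the Runnable machinery, using JSON parsing as a stand-in output parser:

```python
import json


def parse(raw: str) -> dict:
    # Stand-in for output_parser: parse the raw model text as JSON.
    return json.loads(raw)


def invoke_with_raw(raw: str) -> dict:
    # Mirrors RunnableMap(raw=llm) | parser_assign.with_fallbacks([parser_none]):
    out = {"raw": raw, "parsed": None, "parsing_error": None}
    try:
        out["parsed"] = parse(raw)    # parser_assign path
    except Exception as e:
        out["parsing_error"] = e      # fallback path records the error
    return out


good = invoke_with_raw('{"answer": "42"}')
bad = invoke_with_raw("not json at all")
print(good["parsed"])   # -> {'answer': '42'}
print(bad["parsed"], type(bad["parsing_error"]).__name__)
```

Either way the caller always receives a dict with the same three keys, which is exactly what makes include_raw=True safe to use when parse failures must not crash the chain.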
Source Code

```python
def with_structured_output(
    self,
    schema: Union[typing.Dict, type],  # noqa: UP006
    *,
    include_raw: bool = False,
    **kwargs: Any,
) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]:  # noqa: UP006
    """Model wrapper that returns outputs formatted to match the given schema.

    Args:
        schema:
            The output schema. Can be passed in as:
            - an OpenAI function/tool schema,
            - a JSON Schema,
            - a TypedDict class (support added in 0.2.26),
            - or a Pydantic class.
            If ``schema`` is a Pydantic class then the model output will be a
            Pydantic instance of that class, and the model-generated fields will be
            validated by the Pydantic class. Otherwise the model output will be a
            dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
            for more on how to properly specify types and descriptions of
            schema fields when specifying a Pydantic or TypedDict class.

            .. versionchanged:: 0.2.26

                Added support for TypedDict class.

        include_raw:
            If False then only the parsed structured output is returned. If
            an error occurs during model output parsing it will be raised. If True
            then both the raw model response (a BaseMessage) and the parsed model
            response will be returned. If an error occurs during output parsing it
            will be caught and returned as well. The final output is always a dict
            with keys "raw", "parsed", and "parsing_error".

    Returns:
        A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.

        If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
        an instance of ``schema`` (i.e., a Pydantic object).

        Otherwise, if ``include_raw`` is False then Runnable outputs a dict.

        If ``include_raw`` is True, then Runnable outputs a dict with keys:
        - ``"raw"``: BaseMessage
        - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
        - ``"parsing_error"``: Optional[BaseException]

    Example: Pydantic schema (include_raw=False):
        .. code-block:: python

            from pydantic import BaseModel


            class AnswerWithJustification(BaseModel):
                '''An answer to the user question along with justification for the answer.'''

                answer: str
                justification: str


            llm = ChatModel(model="model-name", temperature=0)
            structured_llm = llm.with_structured_output(AnswerWithJustification)

            structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")

            # -> AnswerWithJustification(
            #     answer='They weigh the same',
            #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
            # )

    Example: Pydantic schema (include_raw=True):
        .. code-block:: python

            from pydantic import BaseModel


            class AnswerWithJustification(BaseModel):
                '''An answer to the user question along with justification for the answer.'''

                answer: str
                justification: str


            llm = ChatModel(model="model-name", temperature=0)
            structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)

            structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")

            # -> {
            #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
            #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
            #     'parsing_error': None
            # }

    Example: Dict schema (include_raw=False):
        .. code-block:: python

            from pydantic import BaseModel
            from langchain_core.utils.function_calling import convert_to_openai_tool


            class AnswerWithJustification(BaseModel):
                '''An answer to the user question along with justification for the answer.'''

                answer: str
                justification: str


            dict_schema = convert_to_openai_tool(AnswerWithJustification)
            llm = ChatModel(model="model-name", temperature=0)
            structured_llm = llm.with_structured_output(dict_schema)

            structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")

            # -> {
            #     'answer': 'They weigh the same',
            #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
            # }
    """  # noqa: E501
    # Reject any unsupported keyword arguments
    if kwargs:
        msg = f"Received unsupported arguments {kwargs}"
        raise ValueError(msg)
    # Import the parser classes
    from langchain_core.output_parsers.openai_tools import (
        JsonOutputKeyToolsParser,
        PydanticToolsParser,
    )

    # If bind_tools was not overridden, the model cannot support with_structured_output
    if self.bind_tools is BaseChatModel.bind_tools:
        msg = "with_structured_output is not implemented for this model."
        raise NotImplementedError(msg)
    # Bind the tool and select a parser
    llm = self.bind_tools([schema], tool_choice="any")
    if isinstance(schema, type) and is_basemodel_subclass(schema):
        output_parser: OutputParserLike = PydanticToolsParser(
            tools=[cast(TypeBaseModel, schema)], first_tool_only=True
        )
    else:
        key_name = convert_to_openai_tool(schema)["function"]["name"]
        output_parser = JsonOutputKeyToolsParser(
            key_name=key_name, first_tool_only=True
        )
    if include_raw:
        parser_assign = RunnablePassthrough.assign(
            parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
        )
        parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
        parser_with_fallback = parser_assign.with_fallbacks(
            [parser_none], exception_key="parsing_error"
        )
        return RunnableMap(raw=llm) | parser_with_fallback
    else:
        return llm | output_parser
```
Parsers
JsonOutputKeyToolsParser
Calling the Parent Class's parse_result
JsonOutputKeyToolsParser delegates most of the work to its parent's parse_result via super() (see the source listing below), so let's look at that first.
Purpose: parse the result of an LLM call into a list of tool calls.
Parameters:
- result: a list containing the LLM call results, usually Generation objects.
- partial: an optional boolean, defaulting to False, indicating whether to parse partial JSON. If True, a JSON object containing the keys returned so far is produced; if False, the full JSON object is produced.
```python
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
    # Take the first generation from the result list
    generation = result[0]
    # This parser only works with chat generations
    if not isinstance(generation, ChatGeneration):
        msg = "This output parser can only be used with a chat generation."
        raise OutputParserException(msg)
    # Extract the message from the generation
    message = generation.message
    if isinstance(message, AIMessage) and message.tool_calls:
        tool_calls = [dict(tc) for tc in message.tool_calls]
        for tool_call in tool_calls:
            if not self.return_id:
                _ = tool_call.pop("id")
    else:
        try:
            raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
        except KeyError:
            return []
        tool_calls = parse_tool_calls(
            raw_tool_calls,
            partial=partial,
            strict=self.strict,
            return_id=self.return_id,
        )
    # for backwards compatibility
    for tc in tool_calls:
        tc["type"] = tc.pop("name")
    if self.first_tool_only:
        return tool_calls[0] if tool_calls else None
    return tool_calls
```
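The key transformations in the happy path — copying message.tool_calls, dropping ids, renaming "name" to "type", and collapsing to the first call — can be replayed on a hand-written tool-call list. The dict shapes below follow the AIMessage.tool_calls convention; the function itself is a simplified stand-in, not the real parser:

```python
def extract_tool_calls(message_tool_calls, *, return_id=False, first_tool_only=True):
    # Copy each tool call so the original message is not mutated.
    tool_calls = [dict(tc) for tc in message_tool_calls]
    for tc in tool_calls:
        if not return_id:
            tc.pop("id")
        # for backwards compatibility: "name" becomes "type"
        tc["type"] = tc.pop("name")
    if first_tool_only:
        return tool_calls[0] if tool_calls else None
    return tool_calls


calls = [
    {"name": "Joke", "args": {"setup": "...", "punchline": "..."}, "id": "call_1"},
    {"name": "Joke", "args": {"setup": "...", "punchline": "..."}, "id": "call_2"},
]
first = extract_tool_calls(calls)
print(first["type"], "id" in first)  # -> Joke False
```

With first_tool_only=True and no tool calls at all, the result is None rather than an error, which is what lets the fallback chain above report a clean parse failure.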
Source Code

```python
class JsonOutputKeyToolsParser(JsonOutputToolsParser):
    """Parse tools from OpenAI response."""

    key_name: str
    """The type of tools to return."""

    def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
        """Parse the result of an LLM call to a list of tool calls.

        Args:
            result: The result of the LLM call.
            partial: Whether to parse partial JSON.
                If True, the output will be a JSON object containing
                all the keys that have been returned so far.
                If False, the output will be the full JSON object.
                Default is False.

        Returns:
            The parsed tool calls.
        """
        parsed_result = super().parse_result(result, partial=partial)
        # Return only the first tool call that matches key_name
        if self.first_tool_only:
            single_result = (
                parsed_result
                if parsed_result and parsed_result["type"] == self.key_name
                else None
            )
            if self.return_id:
                return single_result
            elif single_result:
                return single_result["args"]
            else:
                return None
        # Handle multiple tool calls
        parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
        if not self.return_id:
            parsed_result = [res["args"] for res in parsed_result]
        return parsed_result
```
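What key_name adds on top of the parent parser is filtering by tool name and unwrapping "args". A toy replay of the multi-call branch (same dict shapes as above; not the real parser class):

```python
def filter_by_key(parsed_result, key_name, *, return_id=False):
    # parsed_result: list of {"type": ..., "args": ...} dicts produced
    # by the parent parse_result.
    matches = [res for res in parsed_result if res["type"] == key_name]
    if not return_id:
        # Unwrap the tool arguments, discarding the tool-call envelope.
        matches = [res["args"] for res in matches]
    return matches


parsed = [
    {"type": "Joke", "args": {"punchline": "..."}},
    {"type": "Other", "args": {"x": 1}},
]
print(filter_by_key(parsed, "Joke"))  # -> [{'punchline': '...'}]
```

In with_structured_output only one tool (the schema) is ever bound, so in practice the filter is a safety net against the model calling something unexpected.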
PydanticToolsParser
Overview
- Input: the JSON data returned by the LLM (nested inside a Generation object).
- Processing:
  - validate the JSON data and map it onto a Pydantic model;
  - raise a dedicated exception if anything goes wrong.
- Output: a Pydantic model instance, ready for further use.
Note that the method bodies excerpted below come from PydanticOutputParser (as the class definition in the source listing shows); the idea is the same: JSON in, validated Pydantic object out.
_parse_obj: Parsing JSON into a Pydantic Model

```python
def _parse_obj(self, obj: dict) -> TBaseModel:
    if PYDANTIC_MAJOR_VERSION == 2:
        try:
            if issubclass(self.pydantic_object, pydantic.BaseModel):
                return self.pydantic_object.model_validate(obj)
            elif issubclass(self.pydantic_object, pydantic.v1.BaseModel):
                return self.pydantic_object.parse_obj(obj)
            else:
                msg = f"Unsupported model version for PydanticOutputParser: \
                    {self.pydantic_object.__class__}"
                raise OutputParserException(msg)
        except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
            raise self._parser_exception(e, obj) from e
    else:  # pydantic v1
        try:
            return self.pydantic_object.parse_obj(obj)
        except pydantic.ValidationError as e:
            raise self._parser_exception(e, obj) from e
```
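The version dispatch — model_validate for Pydantic v2 models, parse_obj for v1 — can be illustrated with duck-typed stand-ins. The class names here are invented for the sketch; only the method names mirror the real Pydantic APIs:

```python
class V2Model:
    # Pydantic-v2 style: classmethod model_validate.
    def __init__(self, **data):
        self.__dict__.update(data)

    @classmethod
    def model_validate(cls, obj: dict):
        return cls(**obj)


class V1Model:
    # Pydantic-v1 style: classmethod parse_obj.
    def __init__(self, **data):
        self.__dict__.update(data)

    @classmethod
    def parse_obj(cls, obj: dict):
        return cls(**obj)


def parse_obj_dispatch(model_cls, obj: dict):
    # Mirrors the spirit of _parse_obj: prefer the v2 entry point,
    # fall back to the v1 one.
    if hasattr(model_cls, "model_validate"):
        return model_cls.model_validate(obj)
    return model_cls.parse_obj(obj)


a = parse_obj_dispatch(V2Model, {"answer": "42"})
b = parse_obj_dispatch(V1Model, {"answer": "42"})
print(a.answer, b.answer)  # -> 42 42
```

The real code dispatches on PYDANTIC_MAJOR_VERSION and issubclass checks rather than hasattr, because pydantic.v1.BaseModel can coexist inside a v2 installation.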
parse_result: Parsing the LLM-Generated Data

```python
def parse_result(
    self, result: list[Generation], *, partial: bool = False
) -> Optional[TBaseModel]:
    """Parse the result of an LLM call to a pydantic object."""
    try:
        json_object = super().parse_result(result)
        return self._parse_obj(json_object)
    except OutputParserException as e:
        if partial:
            return None
        raise e
```

- The parent class's parse_result() turns the Generation objects into JSON data.
- _parse_obj() then parses that JSON into a Pydantic model.
- If parsing fails and partial=True, None is returned; otherwise the exception is raised.
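The partial contract is easy to demonstrate in isolation: the same parse failure either surfaces as an exception or degrades to None. A toy parser with stdlib JSON in place of the real Generation pipeline:

```python
import json


class ParseError(Exception):
    # Stand-in for OutputParserException.
    pass


def parse_result(text: str, *, partial: bool = False):
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        if partial:
            return None  # swallow the error, e.g. while streaming partial output
        raise ParseError(str(e)) from e


print(parse_result('{"a": 1}'))             # -> {'a': 1}
print(parse_result('{"a":', partial=True))  # -> None
```

This is why streaming callers pass partial=True: half-received JSON is expected and should yield None, not an exception, until the full object arrives.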
get_format_instructions: Returning Format Instructions

```python
def get_format_instructions(self) -> str:
    """Return the format instructions for the JSON output."""
    schema = dict(self.pydantic_object.model_json_schema().items())
    if "title" in schema:
        del schema["title"]
    if "type" in schema:
        del schema["type"]
    schema_str = json.dumps(schema, ensure_ascii=False)
    return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)
```

- Fetch the Pydantic model's JSON schema and delete the fields that are not needed (title and type).
- Return the formatted JSON schema string as the model's format instructions.
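The pruning step can be run on a hand-written dict shaped like model_json_schema() output (the schema content below is illustrative, and the template string is omitted):

```python
import json


def format_instructions(schema: dict) -> str:
    # Copy first, so the caller's schema is not mutated.
    reduced = dict(schema.items())
    # Drop the fields the prompt does not need.
    reduced.pop("title", None)
    reduced.pop("type", None)
    # ensure_ascii=False keeps non-ASCII field descriptions readable.
    return json.dumps(reduced, ensure_ascii=False)


schema = {
    "title": "AnswerWithJustification",
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}
print(format_instructions(schema))
```

Note that "type" nested inside properties survives; only the top-level "title" and "type" are removed, since the surrounding instruction text already tells the model it is emitting a JSON object.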
Source Code

```python
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
    """Parse an output using a pydantic model."""

    pydantic_object: Annotated[type[TBaseModel], SkipValidation()]  # type: ignore
    """The pydantic model to parse."""

    def _parse_obj(self, obj: dict) -> TBaseModel:
        if PYDANTIC_MAJOR_VERSION == 2:
            try:
                if issubclass(self.pydantic_object, pydantic.BaseModel):
                    return self.pydantic_object.model_validate(obj)
                elif issubclass(self.pydantic_object, pydantic.v1.BaseModel):
                    return self.pydantic_object.parse_obj(obj)
                else:
                    msg = f"Unsupported model version for PydanticOutputParser: \
                        {self.pydantic_object.__class__}"
                    raise OutputParserException(msg)
            except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
                raise self._parser_exception(e, obj) from e
        else:  # pydantic v1
            try:
                return self.pydantic_object.parse_obj(obj)
            except pydantic.ValidationError as e:
                raise self._parser_exception(e, obj) from e

    def _parser_exception(
        self, e: Exception, json_object: dict
    ) -> OutputParserException:
        json_string = json.dumps(json_object)
        name = self.pydantic_object.__name__
        msg = f"Failed to parse {name} from completion {json_string}. Got: {e}"
        return OutputParserException(msg, llm_output=json_string)

    def parse_result(
        self, result: list[Generation], *, partial: bool = False
    ) -> Optional[TBaseModel]:
        """Parse the result of an LLM call to a pydantic object.

        Args:
            result: The result of the LLM call.
            partial: Whether to parse partial JSON objects.
                If True, the output will be a JSON object containing
                all the keys that have been returned so far.
                Defaults to False.

        Returns:
            The parsed pydantic object.
        """
        try:
            json_object = super().parse_result(result)
            return self._parse_obj(json_object)
        except OutputParserException as e:
            if partial:
                return None
            raise e

    def parse(self, text: str) -> TBaseModel:
        """Parse the output of an LLM call to a pydantic object.

        Args:
            text: The output of the LLM call.

        Returns:
            The parsed pydantic object.
        """
        return super().parse(text)

    def get_format_instructions(self) -> str:
        """Return the format instructions for the JSON output.

        Returns:
            The format instructions for the JSON output.
        """
        # Copy schema to avoid altering original Pydantic schema.
        schema = dict(self.pydantic_object.model_json_schema().items())
        # Remove extraneous fields.
        reduced_schema = schema
        if "title" in reduced_schema:
            del reduced_schema["title"]
        if "type" in reduced_schema:
            del reduced_schema["type"]
        # Ensure json in context is well-formed with double quotes.
        schema_str = json.dumps(reduced_schema, ensure_ascii=False)
        return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

    @property
    def _type(self) -> str:
        return "pydantic"

    @property
    @override
    def OutputType(self) -> type[TBaseModel]:
        """Return the pydantic model."""
        return self.pydantic_object
```
Summary
- Structured output builds on an LLM that has already been fine-tuned for tool use.
- The tool-use capability lets the LLM, given a list of available tools, intelligently pick the most suitable one for the scenario, and then generate a JSON request conforming to that tool's API calling conventions.
- LangGraph's structured output reuses this capability: it treats the given schema as a tool, asks the LLM to emit a "request" matching the schema's description, and then strips away the redundant fields that the tool-use machinery adds to the response.
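Putting the summary together: bind the schema as a tool, let the model emit a tool call, and have the parser strip the tool-use plumbing. A fake LLM makes the whole path runnable; everything here is a stand-in (no LangChain required), with the tool-call args hard-coded where a real model would generate them:

```python
class FakeLLM:
    # Pretends to be a tool-calling model: always "calls" the bound schema-tool.
    def __init__(self, tool_name: str):
        self.tool_name = tool_name

    def invoke(self, prompt: str) -> dict:
        # A real model would generate these args from the prompt.
        return {
            "tool_calls": [
                {"id": "call_0", "name": self.tool_name,
                 "args": {"answer": "They weigh the same"}}
            ]
        }


def structured_invoke(llm, prompt: str) -> dict:
    # llm | output_parser in miniature.
    message = llm.invoke(prompt)
    call = message["tool_calls"][0]   # first_tool_only=True
    return call["args"]               # drop id/name: the tool-use plumbing


llm = FakeLLM("AnswerWithJustification")
result = structured_invoke(llm, "A pound of bricks or a pound of feathers?")
print(result)  # -> {'answer': 'They weigh the same'}
```

The caller never sees the tool-call envelope, only data in the requested shape, which is exactly the contract with_structured_output provides.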