当前位置: 首页 > article >正文

LangGraph 源码分析 | 结构化输出

文章目录

    • with_structured_output 方法
      • 主要功能
      • 核心参数
      • 返回值
      • 核心逻辑
        • 绑定工具与选择解析器
        • 解析并结构化输出
      • 源代码
    • 解析器
      • JsonOutputKeyToolsParser
        • 调用父类的 parse_result
        • 源代码
      • PydanticToolsParser
        • 概览
        • _parse_obj:解析 JSON 为 Pydantic 模型
        • parse_result:根据 LLM 生成的数据解析结果
        • get_format_instructions:返回格式化说明
        • 源代码
    • 总结

with_structured_output 方法

主要功能

允许用户将模型(LLM)的输出转换为特定的数据格式:

  • 字典
  • JSON Schema
  • TypedDict
  • Pydantic 类 :Pydantic 是一个用于数据验证和解析的 Python 库。BaseModel 是 Pydantic 库中的一个基类,用于定义数据模型【BaseModel 提供了一种声明性的方式来定义数据模型,包括字段类型、验证规则和默认值】

核心参数

  • schema: 指定模型输出的格式,可以是:
    • OpenAI函数/工具的schema
    • JSON Schema格式
    • TypedDict类 (支持自0.2.26版本)
    • Pydantic类:如果提供的是Pydantic类,则会自动将模型输出转换为该Pydantic对象,并对字段进行验证。
  • include_raw
    • 是否返回原始输出。如果为True,返回一个包含"raw"、"parsed"和"parsing_error"三个键的字典。
    • 如果为False,只返回解析后的结构化输出(或者抛出解析错误)。

返回值

返回的是一个 Runnable 对象。该对象会根据输入生成模型响应,并将响应解析成符合指定schema格式的数据:

  • include_raw=False时:
    • 如果schema是Pydantic类,则返回一个Pydantic对象。
    • 否则,返回一个字典格式的结构化数据。
  • include_raw=True时:
    • 返回一个包含"raw""parsed""parsing_error"三个键的字典。

核心逻辑

首先,模型需要支持Tool use,因为后续需要通过bind_tools来格式化输出

绑定工具与选择解析器
llm = self.bind_tools([schema], tool_choice="any")
if isinstance(schema, type) and is_basemodel_subclass(schema):
    output_parser = PydanticToolsParser(
        tools=[cast(TypeBaseModel, schema)], first_tool_only=True
    )
else:
    key_name = convert_to_openai_tool(schema)["function"]["name"]
    output_parser = JsonOutputKeyToolsParser(
        key_name=key_name, first_tool_only=True
    )

  • 将模型输出的格式schema以工具的方式,绑定到 LLM 上
  • 如果schema是 Pydantic 类,则使用PydanticToolsParser解析
  • 如果是其他类型,则使用JsonOutputKeyToolsParser解析
解析并结构化输出
if include_raw:
    parser_assign = RunnablePassthrough.assign(
        parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
    )
    parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
    parser_with_fallback = parser_assign.with_fallbacks(
        [parser_none], exception_key="parsing_error"
    )
    return RunnableMap(raw=llm) | parser_with_fallback
else:
    return llm | output_parser
  • 如果include_rawTrue,则使用一个更复杂的解析器链,包括原始数据解析、回退机制和异常处理
  • 如果include_rawFalse,则使用一个更简单的解析器链

源代码

    def with_structured_output(
        self,
        schema: Union[typing.Dict, type],  # noqa: UP006
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]:  # noqa: UP006
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema:
                The output schema. Can be passed in as:
                    - an OpenAI function/tool schema,
                    - a JSON Schema,
                    - a TypedDict class (support added in 0.2.26),
                    - or a Pydantic class.
                If ``schema`` is a Pydantic class then the model output will be a
                Pydantic instance of that class, and the model-generated fields will be
                validated by the Pydantic class. Otherwise the model output will be a
                dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
                for more on how to properly specify types and descriptions of
                schema fields when specifying a Pydantic or TypedDict class.

                .. versionchanged:: 0.2.26

                        Added support for TypedDict class.

            include_raw:
                If False then only the parsed structured output is returned. If
                an error occurs during model output parsing it will be raised. If True
                then both the raw model response (a BaseMessage) and the parsed model
                response will be returned. If an error occurs during output parsing it
                will be caught and returned as well. The final output is always a dict
                with keys "raw", "parsed", and "parsing_error".

        Returns:
            A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.

            If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
            an instance of ``schema`` (i.e., a Pydantic object).

            Otherwise, if ``include_raw`` is False then Runnable outputs a dict.

            If ``include_raw`` is True, then Runnable outputs a dict with keys:
                - ``"raw"``: BaseMessage
                - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
                - ``"parsing_error"``: Optional[BaseException]

        Example: Pydantic schema (include_raw=False):
            .. code-block:: python

                from pydantic import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")

                # -> AnswerWithJustification(
                #     answer='They weigh the same',
                #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
                # )

        Example: Pydantic schema (include_raw=True):
            .. code-block:: python

                from pydantic import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
                #     'parsing_error': None
                # }

        Example: Dict schema (include_raw=False):
            .. code-block:: python

                from pydantic import BaseModel
                from langchain_core.utils.function_calling import convert_to_openai_tool

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                dict_schema = convert_to_openai_tool(AnswerWithJustification)
                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(dict_schema)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'answer': 'They weigh the same',
                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
                # }
        """  # noqa: E501
        # 是否包含无效参数
        if kwargs:
            msg = f"Received unsupported arguments {kwargs}"
            raise ValueError(msg)

        # 导入解析器模块
        from langchain_core.output_parsers.openai_tools import (
            JsonOutputKeyToolsParser,
            PydanticToolsParser,
        )

        # 判断模型是否支持 with_structured_output
        if self.bind_tools is BaseChatModel.bind_tools:
            msg = "with_structured_output is not implemented for this model."
            raise NotImplementedError(msg)
        # 绑定工具与选择解析器
        llm = self.bind_tools([schema], tool_choice="any")
        if isinstance(schema, type) and is_basemodel_subclass(schema):
            output_parser: OutputParserLike = PydanticToolsParser(
                tools=[cast(TypeBaseModel, schema)], first_tool_only=True
            )
        else:
            key_name = convert_to_openai_tool(schema)["function"]["name"]
            output_parser = JsonOutputKeyToolsParser(
                key_name=key_name, first_tool_only=True
            )
        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

解析器

JsonOutputKeyToolsParser

调用父类的 parse_result

作用:解析 LLM 调用的结果,并将其转化为工具调用的列表

参数:

  • result: 这是一个包含 LLM 调用结果的列表,通常是 Generation 类型的对象
  • partial: 一个可选的布尔参数,默认值为 False。指示是否解析部分 JSON。如果为 True,则返回的将是包含已返回键的 JSON 对象;如果为 False,则返回完整的 JSON 对象。
    def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
        # 从 result 列表中获取第一个生成结果
        generation = result[0]
        # 检查 generation 是否是 ChatGeneration 类型
        if not isinstance(generation, ChatGeneration):
            msg = "This output parser can only be used with a chat generation."
            raise OutputParserException(msg)
        # 从 generation 中提取消息
        message = generation.message
        if isinstance(message, AIMessage) and message.tool_calls:
            tool_calls = [dict(tc) for tc in message.tool_calls]
            for tool_call in tool_calls:
                if not self.return_id:
                    _ = tool_call.pop("id")
        else:
            try:
                raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
            except KeyError:
                return []
            tool_calls = parse_tool_calls(
                raw_tool_calls,
                partial=partial,
                strict=self.strict,
                return_id=self.return_id,
            )
        # for backwards compatibility
        for tc in tool_calls:
            tc["type"] = tc.pop("name")

        if self.first_tool_only:
            return tool_calls[0] if tool_calls else None
        return tool_calls
源代码
class JsonOutputKeyToolsParser(JsonOutputToolsParser):
    """Parse tools from OpenAI response."""

    key_name: str
    """The type of tools to return."""

    def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
        """Parse the result of an LLM call to a list of tool calls.

        Args:
            result: The result of the LLM call.
            partial: Whether to parse partial JSON.
                If True, the output will be a JSON object containing
                all the keys that have been returned so far.
                If False, the output will be the full JSON object.
                Default is False.

        Returns:
            The parsed tool calls.
        """
        parsed_result = super().parse_result(result, partial=partial)

        # 只需要返回第一个匹配的工具结果
        if self.first_tool_only:
            single_result = (
                parsed_result
                if parsed_result and parsed_result["type"] == self.key_name
                else None
            )
            if self.return_id:
                return single_result
            elif single_result:
                return single_result["args"]
            else:
                return None
        # 处理多个工具结果
        parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
        if not self.return_id:
            parsed_result = [res["args"] for res in parsed_result]
        return parsed_result

PydanticToolsParser

概览
  1. 输入:LLM 返回的 JSON 数据(嵌套在 Generation 对象内)。
  2. 处理
  • 验证并将 JSON 数据映射为 Pydantic 模型。
  • 如果出错,会抛出自定义的异常。
  1. 输出:返回一个 Pydantic 模型对象,便于进一步处理或使用。
_parse_obj:解析 JSON 为 Pydantic 模型
def _parse_obj(self, obj: dict) -> TBaseModel:
    if PYDANTIC_MAJOR_VERSION == 2:
        try:
            if issubclass(self.pydantic_object, pydantic.BaseModel):
                return self.pydantic_object.model_validate(obj)
            elif issubclass(self.pydantic_object, pydantic.v1.BaseModel):
                return self.pydantic_object.parse_obj(obj)
            else:
                msg = f"Unsupported model version for PydanticOutputParser: \
                        {self.pydantic_object.__class__}"
                raise OutputParserException(msg)
        except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
            raise self._parser_exception(e, obj) from e
    else:
        try:
            return self.pydantic_object.parse_obj(obj)
        except pydantic.ValidationError as e:
            raise self._parser_exception(e, obj) from e

parse_result:根据 LLM 生成的数据解析结果
def parse_result(
    self, result: list[Generation], *, partial: bool = False
) -> Optional[TBaseModel]:
    """Parse the result of an LLM call to a pydantic object."""
    try:
        json_object = super().parse_result(result)
        return self._parse_obj(json_object)
    except OutputParserException as e:
        if partial:
            return None
        raise e

  1. 使用父类的 parse_result() 方法将 Generation 对象转换为 JSON 数据。
  2. 使用 _parse_obj() 方法进一步解析为 Pydantic 模型
  3. 如果解析失败且 partial=True,则返回 None;否则抛出异常。
get_format_instructions:返回格式化说明
def get_format_instructions(self) -> str:
    """Return the format instructions for the JSON output."""
    schema = dict(self.pydantic_object.model_json_schema().items())
    if "title" in schema:
        del schema["title"]
    if "type" in schema:
        del schema["type"]
    schema_str = json.dumps(schema, ensure_ascii=False)
    return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

  • 获取 Pydantic 模型的 JSON schema,并删除不必要的字段(如 titletype)。
  • 返回格式化后的 JSON schema 字符串,作为模型的格式说明。
源代码
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
    """Parse an output using a pydantic model."""

    pydantic_object: Annotated[type[TBaseModel], SkipValidation()]  # type: ignore
    """The pydantic model to parse."""

    def _parse_obj(self, obj: dict) -> TBaseModel:
        if PYDANTIC_MAJOR_VERSION == 2:
            try:
                if issubclass(self.pydantic_object, pydantic.BaseModel):
                    return self.pydantic_object.model_validate(obj)
                elif issubclass(self.pydantic_object, pydantic.v1.BaseModel):
                    return self.pydantic_object.parse_obj(obj)
                else:
                    msg = f"Unsupported model version for PydanticOutputParser: \
                            {self.pydantic_object.__class__}"
                    raise OutputParserException(msg)
            except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
                raise self._parser_exception(e, obj) from e
        else:  # pydantic v1
            try:
                return self.pydantic_object.parse_obj(obj)
            except pydantic.ValidationError as e:
                raise self._parser_exception(e, obj) from e

    def _parser_exception(
        self, e: Exception, json_object: dict
    ) -> OutputParserException:
        json_string = json.dumps(json_object)
        name = self.pydantic_object.__name__
        msg = f"Failed to parse {name} from completion {json_string}. Got: {e}"
        return OutputParserException(msg, llm_output=json_string)

    def parse_result(
        self, result: list[Generation], *, partial: bool = False
    ) -> Optional[TBaseModel]:
        """Parse the result of an LLM call to a pydantic object.

        Args:
            result: The result of the LLM call.
            partial: Whether to parse partial JSON objects.
                If True, the output will be a JSON object containing
                all the keys that have been returned so far.
                Defaults to False.

        Returns:
            The parsed pydantic object.
        """
        try:
            json_object = super().parse_result(result)
            return self._parse_obj(json_object)
        except OutputParserException as e:
            if partial:
                return None
            raise e

    def parse(self, text: str) -> TBaseModel:
        """Parse the output of an LLM call to a pydantic object.

        Args:
            text: The output of the LLM call.

        Returns:
            The parsed pydantic object.
        """
        return super().parse(text)

    def get_format_instructions(self) -> str:
        """Return the format instructions for the JSON output.

        Returns:
            The format instructions for the JSON output.
        """
        # Copy schema to avoid altering original Pydantic schema.
        schema = dict(self.pydantic_object.model_json_schema().items())

        # Remove extraneous fields.
        reduced_schema = schema
        if "title" in reduced_schema:
            del reduced_schema["title"]
        if "type" in reduced_schema:
            del reduced_schema["type"]
        # Ensure json in context is well-formed with double quotes.
        schema_str = json.dumps(reduced_schema, ensure_ascii=False)

        return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

    @property
    def _type(self) -> str:
        return "pydantic"

    @property
    @override
    def OutputType(self) -> type[TBaseModel]:
        """Return the pydantic model."""
        return self.pydantic_object

总结

  • 结构化输出建立在 LLM 已经经过 Tool use 微调的基础上
  • Tool use 功能允许在特定场景下,根据可用的工具列表,LLM 能够智能地选择最合适的工具。接着,模型会根据该工具的 API 调用指南,生成符合要求的 JSON 格式请求
  • 而 LangGraph 的结构化输出复用了 Tool use 的功能,将给定的模式 schema 看成一个工具,要求 LLM 根据 schema 的描述,输出符合要求的“请求格式”,并取出 Tool use 功能返回的冗余字段

http://www.kler.cn/news/362951.html

相关文章:

  • Hallo2 长视频和高分辨率的音频驱动的肖像图像动画 (数字人技术)
  • 【Jmeter】jmeter指定jdk版本启动
  • 【保姆级教程】DolphinScheduler本地部署与远程访问详细步骤解析
  • 全新子比主题7.9.2开心版 子比主题最新版源码
  • 汽配企业数字工厂管理系统实施规划方案
  • 第四十三条:方法引用优先于Lambda
  • Umi UI报错:连接失败,请尝试重启dev服务
  • 从一个简单的计算问题,看国内几个大语言模型推理逻辑能力
  • 市面上什么台灯性价比高?五款超强实力护眼台灯测评推荐!
  • SVN小乌龟 create patch 和 apply patch 功能
  • 基于Multisim的水温控制电路设计与仿真
  • 51单片机应用——直流电机PWM调速
  • TikTok营销实用技巧与数据分析工具:视频洞察
  • konvajs -基础图形-标签-箭头,动画,学习笔记
  • GORM框架中的预加载功能Preload详解
  • Java智慧工地管理平台SaaS源码:打造安全、高效、绿色、智能的建筑施工新生态
  • 如何在PyCharm中安全地设置和使用API Key
  • 开源项目 - yolo v5 物体检测 手检测 深度学习
  • vue使用xlsx以及file-saver进行下载xlsx文件以及Unit8Array、ArrayBuffer、charCodeAt的使用
  • C# 简单排序方法
  • VS 插入跟踪点,依赖断点,临时断点的区别
  • Linux中vim的三种主要模式和具体用法
  • SpringBootWeb请求响应
  • ReactOS系统中搜索给定长度的空间地址区间中的二叉树
  • 外呼机器人的功能特点
  • 即插即用篇 | YOLOv10 引入 MogaBlock | 多阶门控聚合网络 | ICLR 2024