当前位置: 首页 > article >正文

【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码逐行深度解析(上)

目录

  • 4 GRPO源码分析
    • 4.1 数据类 `GRPOScriptArguments`
    • 4.2 系统提示字符串 `SYSTEM_PROMPT`
    • 4.3 奖励函数
      • 4.3.1 accuracy_reward函数
      • 4.3.2 verify函数
      • 4.3.3 format_reward函数
    • 4.4 将数据集格式化为对话形式
    • 4.5 初始化GRPO Trainer


【复现DeepSeek-R1之Open R1实战】系列3:SFT和GRPO源码逐行深度解析(上)
【复现DeepSeek-R1之Open R1实战】系列5:SFT和GRPO源码逐行深度解析(中)

4 GRPO源码分析

前面两篇博文已经详细介绍了一些基础知识和SFT源码,本文继续解读GRPO源码。与SFT源码差不多的部分,我们就不展开细说了,这里只解析GRPO独特的部分。

4.1 数据类 GRPOScriptArguments

该类使用了 Python 的 dataclass 装饰器,这是一种简化类定义的方式,特别是对于那些主要用来存储数据的类。它继承自 ScriptArguments 类。

  • reward_funcs: 这是一个列表,包含了一系列可能的奖励函数名称,默认值为 ["accuracy", "format"]。这些奖励函数可能是用于评估模型性能的不同标准。

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={
            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
        },
    )
    
  • cosine_min_value_wrongcosine_max_value_wrong: 分别表示错误答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.0-0.5

  • cosine_min_value_correctcosine_max_value_correct: 分别表示正确答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.51.0

  • cosine_max_len: 表示余弦相似度尺度的最大长度,默认值为 1000

  • repetition_n_grams: 表示用于重复惩罚奖励的n-gram数量,默认值为 3

  • repetition_max_penalty: 表示重复惩罚奖励的最大负值,默认值为 -1.0

每个字段都使用了 field() 函数来定义其默认值和元数据(如帮助信息)。这有助于工具和库更好地理解和处理这些字段,例如生成命令行解析器时。

4.2 系统提示字符串 SYSTEM_PROMPT

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)

字符串描述了一个对话场景,用户先提问,助手首先思考推理过程,然后提供答案。推理过程和答案分别用 <think><answer> 标签包裹,这种格式化有助于区分和识别不同的部分,和DeepSeek-R1的思考过程格式一致。

4.3 奖励函数

奖励函数的定义如下,GRPO默认用到了accuracy_reward和format_reward这两个函数。

# Get reward functions
    REWARD_FUNCS_REGISTRY = {
        "accuracy": accuracy_reward,
        "format": format_reward,
        "reasoning_steps": reasoning_steps_reward,
        "cosine": get_cosine_scaled_reward(
            min_value_wrong=script_args.cosine_min_value_wrong,
            max_value_wrong=script_args.cosine_max_value_wrong,
            min_value_correct=script_args.cosine_min_value_correct,
            max_value_correct=script_args.cosine_max_value_correct,
            max_len=script_args.cosine_max_len,
        ),
        "repetition_penalty": get_repetition_penalty_reward(
            ngram_size=script_args.repetition_n_grams,
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": len_reward,
    }
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

这段代码定义了一个奖励函数注册表 REWARD_FUNCS_REGISTRY,并根据用户提供的配置动态生成一个奖励函数列表 reward_funcs。每个奖励函数用于评估模型输出的不同方面,如准确性、格式、推理步骤等。

  1. 注册表定义
  • accuracy: 使用 accuracy_reward 函数评估模型输出的准确性。
  • format: 使用 format_reward 函数评估模型输出的格式。
  • reasoning_steps: 使用 reasoning_steps_reward 函数评估模型输出的推理步骤。
  • cosine: 使用 get_cosine_scaled_reward 函数计算余弦相似度奖励,参数包括:
    • min_value_wrong: 错误情况下的最小值。
    • max_value_wrong: 错误情况下的最大值。
    • min_value_correct: 正确情况下的最小值。
    • max_value_correct: 正确情况下的最大值。
    • max_len: 最大长度。
  • repetition_penalty: 使用 get_repetition_penalty_reward 函数计算重复惩罚奖励,参数包括:
    • ngram_size: n-gram 的大小。
    • max_penalty: 最大惩罚值。
  • length: 使用 len_reward 函数评估模型输出的长度。
  1. 动态生成奖励函数列表
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
  • 根据 script_args.reward_funcs 中指定的奖励函数名称,从 REWARD_FUNCS_REGISTRY 中获取相应的奖励函数,并生成一个列表 reward_funcs

4.3.1 accuracy_reward函数

该函数用于计算模型生成的补全与真实答案之间的准确性奖励。它通过解析和验证生成的内容与真实答案来确定奖励值。

def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards
  • completions (list): 包含多个补全结果的列表,每个补全结果是一个包含内容的字典列表。
  • solution (list): 真实答案的列表。
  • kwargs: 其他可选参数(在本函数中未使用)。
  1. 提取补全内容

    contents = [completion[0]["content"] for completion in completions]
    
    • completions 列表中提取每个补全的第一个内容(假设每个补全是单个元素的列表),形成一个新的 contents 列表。
  2. 初始化奖励列表

    rewards = []
    
  3. 遍历每个补全和对应的真实答案

    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
    
    • 使用 zip 函数将 contentssolution 配对。
    • 对于每一对补全内容和真实答案,首先解析真实答案 sol,使用 parse 函数提取其中的信息。
  4. 处理解析结果

    if len(gold_parsed) != 0:
        answer_parsed = parse(
            content,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
    
    • 如果解析得到的真实答案 gold_parsed 非空,则继续解析生成的补全内容 content
    • 使用 LatexExtractionConfigNormalizationConfig 进行详细配置,确保解析过程中考虑了各种格式要求(如方程、单位等)。
  5. 计算奖励

    reward = float(verify(answer_parsed, gold_parsed))
    
    • 使用 verify 函数比较生成的补全解析结果和真实答案的解析结果。
    • 如果两者匹配,则返回 1.0,否则返回 0.0
  6. 处理无法解析的情况

    else:
        reward = 1.0
        print("Failed to parse gold solution: ", sol)
    
    • 如果真实答案无法解析,则默认给予奖励 1.0 并打印警告信息。
  7. 添加奖励到列表

    rewards.append(reward)
    
  8. 返回所有奖励

    return rewards
    

4.3.2 verify函数

该函数用于验证目标表达式是否与参考表达式匹配,它通过多种比较策略来处理不同的数学对象(如数字、表达式、集合、矩阵等),并提供灵活的配置选项以适应不同的需求。

def verify(
    gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, 
    target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, 
    float_rounding: int=6,
    numeric_precision: int=15,
    strict: bool=True,
    timeout_seconds: int=3
) -> bool:
  • gold: 参考或正确的表达式,可以是单个 SymPy 表达式(BasicMatrixBase)、字符串或这些类型的列表。
  • target: 需要验证的表达式,类型同 gold
  • float_rounding: 浮点数舍入的小数位数,默认为 6。
  • numeric_precision: 数值比较时考虑的小数位数,默认为 15。
  • strict: 是否启用严格比较模式,默认为 True
    • 在严格模式下:变量很重要,集合不可与元组比较。
    • 在非严格模式下:变量按位置匹配,集合可与元组比较。
  • timeout_seconds: 单次比较操作的最大超时时间(秒),默认为 3 秒。
  1. 定义内部比较函数 compare_single_extraction

    @timeout(timeout_seconds=timeout_seconds)
    def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool:
        ...
    
    • 使用装饰器 @timeout 设置超时保护,默认超时时间为 timeout_seconds
    • 比较两个表达式:
      • 如果两者都是 SymPy 表达式(BasicMatrixBase),则调用 sympy_expr_eq 进行比较。
      • 如果两者都是字符串,则进行简单的字符串比较。
  2. 定义包装函数 compare_single_extraction_wrapper

    def compare_single_extraction_wrapper(g, t):
        try:
            return compare_single_extraction(g, t)
        except Exception as e:
            logger.exception(f"Error comparing {g} and {t}")
            return False
    
    • 包装 compare_single_extraction,捕获并记录任何异常,返回 False 以避免程序中断。
  3. 处理输入列表

    if not isinstance(gold, list):
        gold = [gold]
    if not isinstance(target, list):
        target = [target]
    
    • 如果 goldtarget 不是列表,则将其转换为单元素列表,以便统一处理。
  4. 组合所有可能的比较

    return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target))
    
    • 使用 itertools.product 生成所有可能的 goldtarget 组合。
    • 对每个组合调用 compare_single_extraction_wrapper,如果任意一对匹配成功,则返回 True

4.3.3 format_reward函数

函数用于检查给定的完成文本是否符合特定的格式,它验证完成文本是否包含 <think><answer> 标签,并且这两个标签的内容是非空的。

def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
  • completions: 这是一个列表,其中每个元素都是一个包含完成内容的对象(通常是字典)。假设每个完成对象的第一个元素包含一个键 "content",其值是需要检查的文本。
  • kwargs: 其他关键字参数,这里没有使用,但可以为未来的扩展提供灵活性。
  1. 正则表达式模式定义

    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    
    • 这个正则表达式用于匹配字符串是否以 <think> 开始,紧接着是任意字符(非贪婪匹配),然后是 </think>,接着可能有任意数量的空白字符(包括换行符),最后是以 <answer> 开始并以 </answer> 结束。
    • .*? 是非贪婪匹配,确保尽可能少地匹配字符。
    • \s* 匹配零个或多个空白字符(包括换行符)。
    • re.DOTALL | re.MULTILINE 标志允许点号 . 匹配所有字符(包括换行符),并且使多行文本中的每一行都可以独立匹配。
  2. 提取完成内容

    completion_contents = [completion[0]["content"] for completion in completions]
    
    • 这里通过列表推导式从 completions 列表中提取每个完成对象的第一个元素的 "content" 字段,形成一个新的列表 completion_contents
  3. 匹配正则表达式

    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    
    • 使用 re.match 函数对 completion_contents 中的每个内容应用正则表达式模式。
    • matches 列表将包含 re.Match 对象(如果匹配成功)或 None(如果匹配失败)。
  4. 生成奖励分数

    return [1.0 if match else 0.0 for match in matches]
    
    • 最后一步是根据匹配结果生成奖励分数。如果匹配成功(即 match 不是 None),则返回 1.0;否则返回 0.0

示例代码:

completions = [
    [{"content": "<think>This is reasoning.</think><answer>This is answer.</answer>"}],
    [{"content": "<think>This is reasoning.</think>"}],
    [{"content": "<answer>This is answer.</answer>"}],
    [{"content": "This does not match."}]
]

reward_scores = format_reward(completions)
print(reward_scores)  # 输出: [1.0, 0.0, 0.0, 0.0]

在这个例子中:

  • 第一个完成内容完全匹配正则表达式,因此得分为 1.0
  • 后三个完成内容不符合要求,因此得分均为 0.0

4.4 将数据集格式化为对话形式

# Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    dataset = dataset.map(make_conversation)
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")

将一个数据集中的每个示例转换为对话格式,并确保数据集中没有多余的列(如 messages)。

  • 输入example 是一个字典,包含单个数据样本的信息,其中 "problem" 键对应的值是用户的问题或任务描述。
  • 输出:返回一个新的字典,包含一个 "prompt" 键,其值是一个对话列表:
    • 第一条消息是系统消息,内容由 SYSTEM_PROMPT 定义。
    • 第二条消息是用户消息,内容是 example["problem"]
  • dataset.map(make_conversation):使用 map 方法将 make_conversation 函数应用到数据集的每个示例上,生成新的对话格式。
  • 移除多余列:遍历数据集的每个拆分(split),如果存在 "messages" 列,则将其移除。

4.5 初始化GRPO Trainer

trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        callbacks=get_callbacks(training_args, model_args),
    )

篇幅有限,训练部分的代码我们放到下一篇博文详细解读!


http://www.kler.cn/a/551667.html

相关文章:

  • 深入解析「卡顿帧堆栈」 | UWA GPM 2.0 技术细节与常见问题
  • 25工商管理研究生复试面试问题汇总 工商管理专业知识问题很全! 工商管理复试全流程攻略 工商管理考研复试真题汇总
  • 解决DeepSeek服务器繁忙的有效方法
  • vue3项目,商城系统
  • 网络工程师 (45)网际控制报文协议ICMP
  • 分布式储能监测云平台
  • 麒麟V10离线安装docker和docker-compose
  • 1.王道_常用命令
  • 嵌入式学习第十六天--stdio(二)
  • SQL进阶技巧:如何统计用户跨端消费行为?
  • STM32 HAL库USART串口中断编程:环形缓冲区防止数据丢失
  • 【开源】基于SSM框架网上招聘系统(计算机毕业设计)+万字毕业论文+远程部署+ppt+代码讲解 ssm592
  • android studio 界面启动模拟器无反应——从命令行启动模拟器
  • LLaVA-Mini部署教程 :模态预融合与视觉符元压缩,重新定义图像视频理解边界!
  • 调试变量的变化 vs数据断点调试
  • 核函数简述
  • 【MySQL排错 】mysql: command not found 数据库安装后无法加载的解决办法
  • 高效执行自动化用例:分布式执行工具pytest-xdist实战
  • leetcod20-有效的括号
  • 基于机器学习的医疗图像分析:从图像识别到精准诊断