当前位置: 首页 > article >正文

DeepSeek学习

下面详细介绍一下这个文件中与强化学习相关的代码实现思路,以及其中用到的库和函数。需要说明的是,在这个文件中,并没有看到完整的强化学习训练循环(比如 PPO、TRPO 等完整的 RL 算法实现),而是主要定义了几个用于 RL 环境下奖励计算的函数。这些奖励函数通常会被集成到一个强化学习训练流程中,用来为生成的回答提供反馈,从而指导模型优化。下面我具体讲解这部分代码及其如何与 RL 训练框架结合。


一、奖励函数的设计与实现

在这个文件中,强化学习的“反馈信号”主要体现在三个奖励函数中,每个奖励函数的核心作用如下:

  1. 正确性奖励函数

    def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
        responses = [completion[0]['content'] for completion in completions]
        q = prompts[0][-1]['content']
        extracted_responses = [extract_xml_answer(r) for r in responses]
        
        print('-'*20, f"Question:\n{q}", f"\nResponse:\n{responses[0]}", 
              f"\nExtracted:\n{extracted_responses[0]}", f"\nAnswer:\n{answer[0]}")
        
        return [1 if a in r else 0.0 for r, a in zip(extracted_responses, answer)]
    

    工作原理:

    • 提取输出内容:首先从 completions 中提取模型生成的文本。
    • 答案提取:利用之前定义的 extract_xml_answer() 函数,从生成的文本中提取出 <answer>...</answer> 内的内容。
    • 奖励判断:对每个输出,如果提取的答案中包含了标准答案,则给予奖励 1;否则奖励为 0

    用到的库和函数

    • Python 内置模块print(用于调试输出)。
    • 正则表达式模块re(在 extract_xml_answer 中使用,用于匹配 <answer> 标签)。
  2. 松散格式奖励函数

    def soft_format_reward_func(completions, **kwargs) -> list[float]:
        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
        responses = [completion[0]["content"] for completion in completions]
        matches = [re.search(pattern, r, re.DOTALL) for r in responses]
        return [2 if match else 0.0 for match in matches]
    

    工作原理:

    • 检查生成的回答是否包含 <reasoning>...</reasoning><answer>...</answer> 这两个部分(中间可以有空格或换行)。
    • 如果匹配上就给予奖励 2,否则奖励为 0
  3. 严格格式奖励函数

    def strict_format_reward_func(completions, **kwargs) -> list[float]:
        pattern = r"^\s*<reasoning>.*?</reasoning>\s*<answer>.*?</answer>\s*$"
        responses = [completion[0]["content"] for completion in completions]
        matches = [re.search(pattern, r, re.DOTALL) for r in responses]
        return [4 if match else 0.0 for match in matches]
    

    工作原理:

    • 与松散奖励类似,但要求整个生成文本完全符合预定的格式,即开头和结尾都不能有额外的字符。
    • 如果完全匹配,则奖励 4;否则奖励 0

这些奖励函数都是在 RL 环境下给出反馈信号的关键组件,它们的返回值会作为奖励信号输入到 RL 算法中,指导模型优化。


二、强化学习代码的实现思路与库调用

虽然这个文件没有完整展示强化学习的训练循环,但可以推断出其实现思路和可能用到的外部库、函数如下:

  1. 模型输出生成与奖励计算

    • 生成回答:模型在 RL 训练中会接收一批提示(prompts),并生成一批回答(completions)。这一部分通常会使用类似于 Hugging Face Transformers 的生成接口(例如调用 model.generate())。
    • 奖励计算:生成的回答会依次传入上述奖励函数(如 correctness_reward_funcsoft_format_reward_funcstrict_format_reward_func)中,计算出每个回答的奖励值。
  2. 整合到 RL 算法中

    • RL 框架:强化学习部分很可能使用了像 PPO(Proximal Policy Optimization) 这类策略梯度算法。实际项目中,常用的库包括 Hugging Face 的 TRL(用于对 Transformer 模型进行 RLHF)或其他 RL 库(例如 Stable-Baselines3)。
    • 训练循环:在 RL 训练过程中,主要步骤通常包括:
      • 采样阶段:模型基于当前策略生成一批回答。
      • 奖励评估:调用奖励函数对生成的回答进行评分。
      • 策略更新:根据计算的奖励,使用策略梯度方法(如 PPO)更新模型参数。

    调用的核心函数和库:

    • Transformers 库:用于加载预训练模型和 Tokenizer,并生成文本。

    • ModelScope:用于下载模型文件。

    • Datasets 库:用于加载和处理训练数据集。

    • Python 的 re 模块:用于实现答案和格式的正则匹配。

    • RL 训练算法的接口:虽然在这个文件中没有直接看到诸如 PPOTrainer 或 RLTrainer

      的调用,但在实际使用中,这部分代码会调用诸如 TRL 库中的 PPO 接口来构建 RL 训练循环。例如,你可能会看到如下伪代码:

      from trl import PPOTrainer, PPOConfig
      
      # 配置 PPO 参数
      ppo_config = PPOConfig(
          model_name="Qwen/Qwen2.5-0.5B-Instruct",
          learning_rate=1e-5,
          batch_size=8,
          # ...其他参数
      )
      
      # 初始化 PPOTrainer
      ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)
      
      # RL 训练循环
      for batch in training_dataloader:
          prompts = batch["prompt"]
          completions = model.generate(prompts)
          # 计算奖励,可能会调用 correctness_reward_func 等
          rewards = correctness_reward_func(prompts, completions, batch["answer"])
          # 更新模型
          ppo_trainer.step(prompts, completions, rewards)
      

      这里的

      step()
      

      方法会内部处理策略梯度、价值函数估计等更新步骤,而奖励函数正是用来计算每个生成回答的奖励信号的。

  3. 强化学习训练中的核心流程

    • 采样(Sampling):调用模型生成回答。
    • 奖励计算(Reward Computation):调用上面定义的奖励函数,对每个生成结果打分。
    • 策略更新(Policy Update):利用 RL 算法(例如 PPO)的更新规则,对模型进行参数调整,以便在未来生成更高奖励的回答。
    • 日志记录与调试:代码中使用 print 输出调试信息,这在 RL 训练中有助于观察问题、回答和奖励分数的对应关系,方便调试和验证奖励函数的有效性。

三、如何参照这段代码实现你的强化学习代码

如果你想实现自己的强化学习代码,可以参考以下步骤:

  1. 准备环境与依赖库
    • 安装并导入 TransformersDatasetsModelScope(如果需要)等库。
    • 如需完整 RL 支持,建议使用 Hugging Face 的 TRL 或其他 RL 库(例如 Stable-Baselines3)。
  2. 定义奖励函数
    • 根据你的任务需求,设计并实现奖励函数。你可以参照这里的正确性、松散格式和严格格式奖励函数。
    • 确保奖励函数返回数值型奖励,并且能够对生成文本进行解析(这里使用了正则表达式来提取 <answer> 部分)。
  3. 实现采样与生成逻辑
    • 利用 Transformers 的生成接口(例如 model.generate()),根据输入 prompt 生成回答。
  4. 集成到 RL 训练流程中
    • 构建一个 RL 训练循环:采样 → 奖励计算 → 策略更新。
    • 如果使用 TRL,你可以按照 TRL 的文档将奖励函数传入训练器中,调用诸如 ppo_trainer.step() 之类的接口进行训练。
  5. 调试与日志记录
    • 在奖励函数中增加调试信息(如示例中的 print 输出),以便在训练过程中检查生成回答与奖励计算是否符合预期。

总结

  • 奖励函数部分:代码中实现了三个奖励函数,用来衡量生成回答的正确性和格式是否满足要求。这部分代码调用了 Python 内置的 re 模块和标准数据结构操作。
  • 强化学习框架:虽然文件中没有展示完整的 RL 训练循环,但奖励函数的设计正是为 RL 训练服务的。你可以将这些奖励函数嵌入到如 PPO 这类 RL 算法的训练循环中。常用的库包括 Hugging Face 的 Transformers(生成回答)和 TRL(进行策略更新)。
  • 参考实现:你可以参照上面的伪代码示例,结合奖励函数、生成接口以及 RL 训练库,构建一个完整的强化学习训练流程。

下面是整理后的 Python 脚本文件,包含详细的注释,解释了每个部分的作用以及如何在强化学习流程中使用这些奖励函数。你可以将以下代码保存为例如 rl_training.py 文件:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
这个脚本用于下载 Qwen2.5-0.5B-Instruct 模型,加载对应的 Tokenizer,
并定义用于强化学习训练中的奖励函数。奖励函数包括:
    - 正确性奖励函数:检测生成的回答中是否包含正确答案
    - 松散格式奖励函数:检测回答是否包含符合预期的 XML 格式(宽松匹配)
    - 严格格式奖励函数:检测回答是否完全符合预定的 XML 格式(严格匹配)

在强化学习训练中,这些奖励函数将作为反馈信号,指导模型更新策略,
例如在 PPO 等 RL 算法中使用。
"""

import re
from modelscope import snapshot_download
from transformers import AutoTokenizer
from datasets import load_dataset, Dataset  # 如有需要加载数据集,可使用该库

# -----------------------------------------------------------------------------
# 模型下载与 Tokenizer 加载
# -----------------------------------------------------------------------------

# 使用 ModelScope 下载 Qwen2.5-0.5B-Instruct 模型
model_name = snapshot_download('Qwen/Qwen2.5-0.5B-Instruct')

# 使用 Hugging Face 的 AutoTokenizer 加载对应的 Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 设置 pad_token 为 eos_token,确保在序列填充时一致性
tokenizer.pad_token = tokenizer.eos_token

# -----------------------------------------------------------------------------
# 系统提示词与格式定义
# -----------------------------------------------------------------------------

# 系统提示词,指导模型按照指定的格式生成回答,包括推理过程和答案
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# XML 格式的提示模板,通过字符串 format 方法填充具体的推理和答案内容
XML_COT_FORMAT = """
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

# -----------------------------------------------------------------------------
# 辅助函数:提取 XML 格式答案
# -----------------------------------------------------------------------------

def extract_xml_answer(text: str) -> str:
    """
    从给定文本中提取 <answer>...</answer> 标签内的内容。
    
    参数:
        text: 包含 XML 格式答案的文本字符串。
    
    返回:
        去除首尾空格后的标签内答案内容,如果没有找到,则返回空字符串。
    """
    # 使用正则表达式匹配 <answer> 标签中的所有内容(支持多行匹配)
    match = re.search('<answer>(.*)</answer>', text, re.DOTALL)
    if match:
        answer = match.group(1)
    else:
        answer = ''
    return answer.strip()

# -----------------------------------------------------------------------------
# 奖励函数:强化学习训练中的奖励信号
# -----------------------------------------------------------------------------

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    正确性奖励函数:
    判断生成的回答中提取出的答案是否包含正确答案,若包含则给予奖励。
    
    参数:
        prompts: 提示词列表(通常包含问题),格式为列表的列表,例如 [[{"content": "问题内容"}]]
        completions: 模型生成的回答列表,每个元素通常是一个包含文本内容的字典,例如 [{"content": "模型回答"}]
        answer: 正确答案列表,与生成回答一一对应,例如 ["正确答案"]
        kwargs: 其他可选参数
    
    返回:
        奖励列表:若提取的答案包含正确答案,则奖励 1,否则奖励 0.
    """
    # 从 completions 中提取模型生成的文本回答
    responses = [completion[0]['content'] for completion in completions]
    
    # 从 prompts 中提取问题内容(用于调试输出)
    q = prompts[0][-1]['content']
    
    # 利用 extract_xml_answer 从回答中提取 <answer> 标签内的内容
    extracted_responses = [extract_xml_answer(r) for r in responses]
    
    # 输出调试信息:显示问题、原始回答、提取的答案和正确答案
    print('-' * 20,
          f"Question:\n{q}",
          f"\nResponse:\n{responses[0]}",
          f"\nExtracted:\n{extracted_responses[0]}",
          f"\nAnswer:\n{answer[0]}")
    
    # 逐一判断:如果正确答案包含在提取的答案中,则奖励为 1,否则为 0
    return [1 if a in r else 0.0 for r, a in zip(extracted_responses, answer)]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """
    松散格式奖励函数:
    检查生成回答是否包含符合要求的 XML 格式(即包含 <reasoning> 和 <answer> 标签),
    不要求文本完全从头到尾匹配格式。
    
    参数:
        completions: 模型生成的回答列表,每个元素是包含文本内容的字典
        kwargs: 其他可选参数
    
    返回:
        奖励列表:如果回答中包含匹配的格式,奖励 2 分,否则奖励 0 分。
    """
    # 定义正则表达式,匹配包含 <reasoning>...</reasoning> 和 <answer>...</answer> 的回答
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    # 提取回答文本
    responses = [completion[0]["content"] for completion in completions]
    # 对每个回答使用正则表达式进行匹配(支持多行匹配)
    matches = [re.search(pattern, r, re.DOTALL) for r in responses]
    # 返回奖励:匹配成功则奖励 2 分,否则 0 分
    return [2 if match else 0.0 for match in matches]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """
    严格格式奖励函数:
    检查生成回答是否完全符合预定的 XML 格式,
    要求整个文本从头到尾只包含 <reasoning> 和 <answer> 两部分,允许空白字符但无其他字符。
    
    参数:
        completions: 模型生成的回答列表,每个元素是包含文本内容的字典
        kwargs: 其他可选参数
    
    返回:
        奖励列表:如果回答完全匹配预定格式,奖励 4 分,否则奖励 0 分。
    """
    # 定义严格匹配的正则表达式,要求文本完全符合指定格式
    pattern = r"^\s*<reasoning>.*?</reasoning>\s*<answer>.*?</answer>\s*$"
    # 提取回答文本
    responses = [completion[0]["content"] for completion in completions]
    # 对每个回答使用正则表达式进行严格匹配
    matches = [re.search(pattern, r, re.DOTALL) for r in responses]
    # 返回奖励:完全匹配则奖励 4 分,否则 0 分
    return [4 if match else 0.0 for match in matches]

# -----------------------------------------------------------------------------
# 示例:如何在强化学习训练流程中使用这些奖励函数
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # 下面是一个简单的示例,展示如何调用奖励函数。
    # 在实际的强化学习训练中,你需要将以下步骤集成到模型生成、采样和策略更新中,
    # 例如结合 PPO 算法(使用 TRL 库或其他 RL 库)。

    # 构造示例提示词(prompts)
    example_prompts = [[{"content": "请问1+1等于多少?"}]]
    
    # 构造示例模型生成的回答(completions),回答格式应符合预期的 XML 格式
    example_completions = [[{"content": "<reasoning>1+1等于2,因为1和1相加得到2</reasoning><answer>2</answer>"}]]
    
    # 构造示例正确答案(answer)
    example_answers = ["2"]
    
    # 调用正确性奖励函数
    correctness_rewards = correctness_reward_func(example_prompts, example_completions, example_answers)
    print("Correctness Rewards:", correctness_rewards)
    
    # 调用松散格式奖励函数
    soft_format_rewards = soft_format_reward_func(example_completions)
    print("Soft Format Rewards:", soft_format_rewards)
    
    # 调用严格格式奖励函数
    strict_format_rewards = strict_format_reward_func(example_completions)
    print("Strict Format Rewards:", strict_format_rewards)
    
    # 在完整的强化学习训练过程中,奖励函数将与以下步骤结合:
    #   1. 利用模型(例如通过 model.generate())生成回答
    #   2. 使用奖励函数计算生成回答的奖励值
    #   3. 将奖励信号传入 RL 算法(例如 PPO)中进行策略更新
    # 以下是一个伪代码示例:
    #
    # from trl import PPOTrainer, PPOConfig
    #
    # # 配置 PPO 参数
    # ppo_config = PPOConfig(
    #     model_name="Qwen/Qwen2.5-0.5B-Instruct",
    #     learning_rate=1e-5,
    #     batch_size=8,
    #     # ...其他参数
    # )
    #
    # # 初始化 PPOTrainer
    # ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)
    #
    # for batch in training_dataloader:
    #     prompts = batch["prompt"]
    #     completions = model.generate(prompts)
    #     rewards = correctness_reward_func(prompts, completions, batch["answer"])
    #     ppo_trainer.step(prompts, completions, rewards)

说明

  • 脚本顶部部分负责下载模型、加载 Tokenizer,并设置必要的配置。
  • 接下来定义了用于生成回答格式的系统提示词与 XML 模板。
  • extract_xml_answer 函数用于解析模型回答中的 <answer> 部分。
  • 三个奖励函数分别实现了对回答正确性、松散格式和严格格式的检测。
  • 最后,在 if __name__ == "__main__": 部分提供了一个简单示例,展示如何调用这些函数。实际的强化学习训练中,这些奖励函数将嵌入到生成、奖励计算和策略更新的训练循环中。

你可以根据自己的需求进一步扩展和集成到具体的强化学习训练流程中。


http://www.kler.cn/a/554947.html

相关文章:

  • NPM如何更换淘宝镜像——Node.js国内镜像配置教程
  • 「JVS更新日志」低代码、ERP应用、智能BI、智能排产2.19更新说明
  • Linux 实操篇 组管理和权限管理、定时任务调度、Linux磁盘分区和挂载
  • Redis 中列表(List)常见命令详解
  • 抖音试水AI分身;腾讯 AI 战略调整架构;百度旗下小度官宣接入DeepSeek...|网易数智日报
  • 网络安全防护
  • 【深度学习】计算机视觉(CV)-图像生成-风格迁移(Style Transfer)
  • 接口测试-Protobuf相关
  • 【RabbitMQ业务幂等设计】RabbitMQ消息是幂等的吗?
  • 我用Ai学Android Jetpack Compose之Composable与View的区别与联系
  • LeetCode 热题 100_搜索插入位置(63_35_简单_C++)(二分查找)(”>>“ 与 “/”)
  • 【HappyBase】连接hbase报错:ecybin.ProtocolError: No protocol version header
  • A105基于SpringBoot实现的甘肃非物质文化网站
  • 宠物行业研究系列报告
  • 为什么WP建站更适合于谷歌SEO优化?
  • 【HarmonyOS之旅】基于ArkTS开发(三) -> 兼容JS的类Web开发(四) -> 常见组件(二) -> swiper
  • 油田安全系统:守护能源生命线的坚固壁垒
  • Android14(13)添加墨水屏手写API
  • 使用Termux将安卓手机变成随身AI服务器(page assist连接)
  • 【Linux网络】TCP/IP地址的有机结合(有能力VS100%???),IP地址的介绍