当前位置: 首页 > article >正文

TRL里面GRPOTrainer中grpo_train.py文件详解

下面是一篇面向中文读者的博客,旨在全面介绍 grpo_train.py 的整体结构和关键流程,帮助你快速理解它在做什么、如何运行,以及在 GRPO(Group Relative Policy Optimization)训练流程中扮演什么角色。
Source: anaconda3/envs/xxx/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py

在这里插入图片描述


1. 背景:GRPO 是什么?

GRPO(Group Relative Policy Optimization) 是一种针对大语言模型(LLM)的强化学习训练方法,可看作 PPO 的一种改进变体。

  • PPO 中通常需要一个与策略模型相当规模的价值函数(critic),但在大型模型上这一点开销很大;
  • GRPO 则通过一次性对同一个 prompt 采样多条回答(一个 “Group”)并进行组内比较,不再需要显式的价值函数;
  • 大幅降低显存和算力需求,也使流程更直观:在同一 prompt 上把多条回答的 reward 进行对比、求相对优势(advantage),直接更新策略模型。

具体可参考笔者的另一篇博客:深度解析DeepSeek原论文中的 GRPO:带 clip 操作的完整公式与示例代码
grpo_train.py 正是针对这个算法的一个 PyTorch / Transformers 生态中的示例式实现,使你能够用“策略模型 + 参考模型 +(可选的)奖励模型”的方式来训练大语言模型。


2. 文件概览

文件中最核心的类是 GRPOTrainer(Trainer),继承自 transformers.Trainer。它重写或扩展了若干方法,包括:

  1. __init__:初始化模型、参考模型(ref_model)、奖励模型(reward_funcs)等,并作一些超参数设置(如 num_generations, beta 等)。
  2. _prepare_inputs:在训练循环中,每一个 batch 先对 prompt 进行采样生成多条回答,调用奖励模型打分,计算组内相对优势。
  3. compute_loss:根据 GRPO 公式,结合 KL 惩罚项和相对优势,计算最终损失并进行反向传播。
  4. prediction_step:训练 / 验证阶段如何调用 _prepare_inputs 并获取 loss。
  5. logcreate_model_card:日志与模型卡部分,可上传到 Hugging Face Hub 做模型管理。

3. 运行逻辑全流程

在实际运行训练时,GRPOTrainer.train() 会在循环中反复调用“采样 + 计算损失 + 更新模型”这条主线逻辑。

3.1 初始化(__init__

  • 加载策略模型model
    用户可以传入一个字符串(表示从 huggingface.co 加载某个 Causal LM),或一个已经初始化好的模型对象。
  • 加载参考模型ref_model
    默认会克隆一份与策略模型同样初始化的 ref_model,也可在某些配置下跳过 / 用 PEFT 方式禁用;
    用于后续计算 KL 惩罚,约束更新不要跑太远。
  • 加载奖励函数(reward_funcs
    • 可以是预训练好的 “SequenceClassification” 模型;
    • 可以是自定义 Python 函数(对回答打分);
    • 或者一个列表,意味着多种奖励加起来一起用。
  • 关键超参数
    • num_generations(同一个 prompt 上一次性采样多少回答,G)
    • max_prompt_lengthmax_completion_length 用于截断长度
    • beta:KL 正则项系数
    • use_vllm:是否用 vLLM 做推理
  • 准备数据集
    继承自 Trainer 的方式,传入 train_dataseteval_dataset 即可。

3.2 _prepare_inputs:采样 + 打分 + 计算相对优势

具体解析请参考笔者的另一篇博客:GRPO 与 TRL实现的GRPOTrainer中_prepare_inputs函数详解

这是 GRPO 的关键。简要流程:

  1. 对 batch 中每个 prompt(如 batch_size=8):
    • 调用模型一次性生成 num_generations 条回答(比如 G=4),所以最终会产生 8 * 4 = 32 条回答。
  2. 对 EOS:如果生成提早出现 EOS token,会用一个 mask(completion_mask)来屏蔽无效 token。
  3. 参考模型 log prob:用 ref_model 对完整序列(prompt + completion)计算 token 级对数概率,用于后面做 KL。
  4. 调用奖励模型:对每条回答打出一个 reward(可以是多个函数相加),形成 [B*G] 形状的 reward。
  5. 相对优势:把这 [B*G] reshape 成 [B, G],对同一个 prompt 的 G 条回答做“均值、标准差”,再 broadcast 回去,以得到每条回答的相对 advantage = ( ( r i − μ ) / σ (r_i - \mu)/\sigma (riμ)/σ)。
  6. 返回:把处理好的 prompt_ids, completion_ids, completion_mask, ref_per_token_logps, advantages 等信息一并打包,等待下一步在 compute_loss 中使用。

3.3 compute_loss:根据 GRPO 公式求 Loss

具体请参考笔者的另一篇博客:从公式到代码:DeepSeek大模型GRPO算法中的 compute_loss如何实现(基于TRL源代码)

compute_loss 中会做以下事情:

  1. 当前策略的 token-level log prob:对 [prompt + completion] 做前向,拿到 logits。
  2. KL 惩罚:参考策略与当前策略的差异,用一个近似公式(per_token_kl)。
  3. 相对优势:之前在 _prepare_inputs 已经算好 advantages
  4. PPO / GRPO style 的公式
    l o s s = − ( exp ⁡ ( log ⁡ p θ − log ⁡ p θ . d e t a c h ( ) ) × A ^ − β × K L ) . \mathrm{loss} = -\left( \exp(\log p_\theta - \log p_\theta.detach()) \times \hat{A}- \beta \times \mathrm{KL} \right). loss=(exp(logpθlogpθ.detach())×A^β×KL).
    并且用 completion_mask 遮住无效 token,对 batch 做平均。
  5. 返回一个标量 loss,供 PyTorch 做 .backward() 与更新。

3.4 训练循环

  • Trainer 带的主循环 train() 会不停地取一个 batch -> _prepare_inputs -> compute_loss -> 反向传播 -> 更新参数 -> 下一个 batch …
  • KL 系数 beta 会让模型别离参考策略太远,而“组内相对奖励”让它学会偏向分数更高的回答。
  • 如此周而复始,就完成了 GRPO 训练。

4. 主要函数/方法亮点

4.1 _set_signature_columns_if_needed

因为我们会在 _prepare_inputs 里自定义处理数据,并不直接依赖典型的 “model inputs” 结构,所以这里覆盖了父类 Trainer 的默认逻辑,告诉它只关心 "prompt" 这个字段即可。

4.2 _get_per_token_logps

具体参考笔者的另一篇博客:(trl的grpo_trainer)深度解析 _get_per_token_logps:如何计算序列中每个 Token 的对数概率
这是一个辅助函数,用来计算模型对 [prompt + completion] 每个 token 的对数概率。

  • 先调用 model(...).logits,再用 gather 提取实际 token 的 logit,再减去 logsumexp 得到 log(prob)。
  • 同时做了“移除最后一列 logit”、“只保留后面若干 token logits”等兼容处理。

4.3 prediction_step

在评估或预测时,也需要执行 _prepare_inputs 来生成多条回答并算 loss,只不过不会再反向传播。

  • 这个函数就是在 eval 阶段或预测时,对应地拿到 loss,用于打印日志或 early stopping 等操作。

4.4 log & create_model_card

  • log:将一些 metrics,比如 KL、reward、completion length 等写到日志。
  • create_model_card:辅助你在做完训练后生成一份 README,包含 base model 名称、训练配置信息、引用等。

5. 总结

grpo_train.py 中的 GRPOTrainer 是一个覆盖了 Trainer 部分方法、带有 GRPO 逻辑的训练器,专为“大模型 + 强化学习”场景设计。它的主要流程可以简述为:

  1. 数据处理 / 采样:对 prompt 生成多条回答,打分并算相对优势;
  2. loss 计算:引入参考模型做 KL 惩罚,把相对优势带入类似 PPO 的公式;
  3. 训练循环:父类 Trainer 提供的 train() 方法会自动按 batch 调用 _prepare_inputs & compute_loss,完成 RL 更新。

从而在没有显式价值函数的情况下,通过“分组比较 + KL 惩罚”高效地优化语言模型回答质量。

如果你想使用 GRPO 来做大语言模型强化学习,只需要准备好数据集(必须包含 prompt 字段)、一个或多个奖励函数(可以是预训练好的分类模型,也可以是自定义 Python 函数),再配合 GRPOTrainer 的超参数,即可快速开始训练。

后记

2025年2月22日15点57分于上海。在GPT4o大模型辅助下完成。


http://www.kler.cn/a/557506.html

相关文章:

  • CNN常用卷积核
  • 2025/2/22论文阅读
  • 使用docker配置PostgreSQL
  • [创业之路-321]:创新开拓思维和经营管理思维的比较
  • PHP post 数据丢失问题
  • 【部署优化篇十四】【十万字全景拆解:GitHub Actions自动化流水线设计圣经(DeepSeek工业级实践大公开)】
  • BGP配置华为——路径优选验证
  • CTF 代码学习日记 PHP
  • 基于GA遗传优化的BiLSTM双向长短期记忆网络序列预测算法matlab仿真,对比BiLSTM和LSTM
  • nodejs爬虫抓取数据快速入门
  • RTSP场景下的RTP与RTCP
  • Spring Boot3+Vue2极速整合:10分钟搭建DeepSeek AI对话系统
  • Linux/POSIX 多路IO复用
  • RTSP协议全解析
  • jQuery AJAX 方法详解
  • Java 异常(Exception)全面解析:类型与原理
  • 【网络安全】常见的web攻击
  • AI Agent实战:打造京东广告主的超级助手 | 京东零售技术实践
  • python与pycharm如何设置文件夹为源代码根目录
  • DeepSeek掘金——SpringBoot 调用 DeepSeek API 快速实现应用开发