TRL里面GRPOTrainer中grpo_train.py文件详解
下面是一篇面向中文读者的博客,旨在全面介绍 grpo_train.py
的整体结构和关键流程,帮助你快速理解它在做什么、如何运行,以及在 GRPO(Group Relative Policy Optimization)训练流程中扮演什么角色。
Source: anaconda3/envs/xxx/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py
1. 背景:GRPO 是什么?
GRPO(Group Relative Policy Optimization) 是一种针对大语言模型(LLM)的强化学习训练方法,可看作 PPO 的一种改进变体。
- PPO 中通常需要一个与策略模型相当规模的价值函数(critic),但在大型模型上这一点开销很大;
- GRPO 则通过一次性对同一个 prompt 采样多条回答(一个 “Group”)并进行组内比较,不再需要显式的价值函数;
- 大幅降低显存和算力需求,也使流程更直观:在同一 prompt 上把多条回答的 reward 进行对比、求相对优势(advantage),直接更新策略模型。
具体可参考笔者的另一篇博客:深度解析DeepSeek原论文中的 GRPO:带 clip 操作的完整公式与示例代码
grpo_train.py
正是针对这个算法的一个 PyTorch / Transformers 生态中的示例式实现,使你能够用“策略模型 + 参考模型 +(可选的)奖励模型”的方式来训练大语言模型。
2. 文件概览
文件中最核心的类是 GRPOTrainer(Trainer)
,继承自 transformers.Trainer
。它重写或扩展了若干方法,包括:
__init__
:初始化模型、参考模型(ref_model)、奖励模型(reward_funcs)等,并作一些超参数设置(如num_generations
,beta
等)。_prepare_inputs
:在训练循环中,每一个 batch 先对prompt
进行采样生成多条回答,调用奖励模型打分,计算组内相对优势。compute_loss
:根据 GRPO 公式,结合 KL 惩罚项和相对优势,计算最终损失并进行反向传播。prediction_step
:训练 / 验证阶段如何调用_prepare_inputs
并获取 loss。log
与create_model_card
:日志与模型卡部分,可上传到 Hugging Face Hub 做模型管理。
3. 运行逻辑全流程
在实际运行训练时,GRPOTrainer.train()
会在循环中反复调用“采样 + 计算损失 + 更新模型”这条主线逻辑。
3.1 初始化(__init__
)
- 加载策略模型(
model
)
用户可以传入一个字符串(表示从 huggingface.co 加载某个 Causal LM),或一个已经初始化好的模型对象。 - 加载参考模型(
ref_model
)
默认会克隆一份与策略模型同样初始化的ref_model
,也可在某些配置下跳过 / 用 PEFT 方式禁用;
用于后续计算 KL 惩罚,约束更新不要跑太远。 - 加载奖励函数(
reward_funcs
)- 可以是预训练好的 “SequenceClassification” 模型;
- 可以是自定义 Python 函数(对回答打分);
- 或者一个列表,意味着多种奖励加起来一起用。
- 关键超参数
num_generations
(同一个 prompt 上一次性采样多少回答,G)max_prompt_length
、max_completion_length
用于截断长度beta
:KL 正则项系数use_vllm
:是否用 vLLM 做推理
- 准备数据集
继承自Trainer
的方式,传入train_dataset
、eval_dataset
即可。
3.2 _prepare_inputs
:采样 + 打分 + 计算相对优势
具体解析请参考笔者的另一篇博客:GRPO 与 TRL实现的GRPOTrainer中_prepare_inputs函数详解
这是 GRPO 的关键。简要流程:
- 对 batch 中每个 prompt(如 batch_size=8):
- 调用模型一次性生成
num_generations
条回答(比如 G=4),所以最终会产生 8 * 4 = 32 条回答。
- 调用模型一次性生成
- 对 EOS:如果生成提早出现 EOS token,会用一个 mask(
completion_mask
)来屏蔽无效 token。 - 参考模型 log prob:用
ref_model
对完整序列(prompt + completion)计算 token 级对数概率,用于后面做 KL。 - 调用奖励模型:对每条回答打出一个 reward(可以是多个函数相加),形成
[B*G]
形状的 reward。 - 相对优势:把这
[B*G]
reshape 成[B, G]
,对同一个 prompt 的 G 条回答做“均值、标准差”,再 broadcast 回去,以得到每条回答的相对 advantage = ( ( r i − μ ) / σ (r_i - \mu)/\sigma (ri−μ)/σ)。 - 返回:把处理好的
prompt_ids
,completion_ids
,completion_mask
,ref_per_token_logps
,advantages
等信息一并打包,等待下一步在compute_loss
中使用。
3.3 compute_loss
:根据 GRPO 公式求 Loss
具体请参考笔者的另一篇博客:从公式到代码:DeepSeek大模型GRPO算法中的 compute_loss如何实现(基于TRL源代码)
在 compute_loss
中会做以下事情:
- 当前策略的 token-level log prob:对
[prompt + completion]
做前向,拿到 logits。 - KL 惩罚:参考策略与当前策略的差异,用一个近似公式(
per_token_kl
)。 - 相对优势:之前在
_prepare_inputs
已经算好advantages
。 - PPO / GRPO style 的公式:
l o s s = − ( exp ( log p θ − log p θ . d e t a c h ( ) ) × A ^ − β × K L ) . \mathrm{loss} = -\left( \exp(\log p_\theta - \log p_\theta.detach()) \times \hat{A}- \beta \times \mathrm{KL} \right). loss=−(exp(logpθ−logpθ.detach())×A^−β×KL).
并且用completion_mask
遮住无效 token,对 batch 做平均。 - 返回一个标量
loss
,供 PyTorch 做.backward()
与更新。
3.4 训练循环
- Trainer 带的主循环
train()
会不停地取一个 batch ->_prepare_inputs
->compute_loss
-> 反向传播 -> 更新参数 -> 下一个 batch … - KL 系数
beta
会让模型别离参考策略太远,而“组内相对奖励”让它学会偏向分数更高的回答。 - 如此周而复始,就完成了 GRPO 训练。
4. 主要函数/方法亮点
4.1 _set_signature_columns_if_needed
因为我们会在 _prepare_inputs
里自定义处理数据,并不直接依赖典型的 “model inputs” 结构,所以这里覆盖了父类 Trainer 的默认逻辑,告诉它只关心 "prompt"
这个字段即可。
4.2 _get_per_token_logps
具体参考笔者的另一篇博客:(trl的grpo_trainer)深度解析 _get_per_token_logps:如何计算序列中每个 Token 的对数概率
这是一个辅助函数,用来计算模型对 [prompt + completion]
每个 token 的对数概率。
- 先调用
model(...).logits
,再用gather
提取实际 token 的 logit,再减去logsumexp
得到 log(prob)。 - 同时做了“移除最后一列 logit”、“只保留后面若干 token logits”等兼容处理。
4.3 prediction_step
在评估或预测时,也需要执行 _prepare_inputs
来生成多条回答并算 loss,只不过不会再反向传播。
- 这个函数就是在
eval
阶段或预测时,对应地拿到 loss,用于打印日志或 early stopping 等操作。
4.4 log
& create_model_card
log
:将一些 metrics,比如 KL、reward、completion length 等写到日志。create_model_card
:辅助你在做完训练后生成一份 README,包含 base model 名称、训练配置信息、引用等。
5. 总结
grpo_train.py
中的 GRPOTrainer
是一个覆盖了 Trainer
部分方法、带有 GRPO 逻辑的训练器,专为“大模型 + 强化学习”场景设计。它的主要流程可以简述为:
- 数据处理 / 采样:对 prompt 生成多条回答,打分并算相对优势;
- loss 计算:引入参考模型做 KL 惩罚,把相对优势带入类似 PPO 的公式;
- 训练循环:父类
Trainer
提供的train()
方法会自动按 batch 调用_prepare_inputs
&compute_loss
,完成 RL 更新。
从而在没有显式价值函数的情况下,通过“分组比较 + KL 惩罚”高效地优化语言模型回答质量。
如果你想使用 GRPO 来做大语言模型强化学习,只需要准备好数据集(必须包含 prompt
字段)、一个或多个奖励函数(可以是预训练好的分类模型,也可以是自定义 Python 函数),再配合 GRPOTrainer
的超参数,即可快速开始训练。
后记
2025年2月22日15点57分于上海。在GPT4o大模型辅助下完成。