当前位置: 首页 > article >正文

【复现DeepSeek-R1之Open R1实战】系列7:GRPO原理介绍、训练流程和源码深度解析

目录

    • 4.6 GRPO训练过程
      • 4.6.1 GRPO原理
      • 4.6.2 设置参考模型
      • 4.6.3 从训练集中抽取问题
      • 4.6.4 旧策略模型生成G个输出
      • 4.6.5 对每个输出用奖励模型 RM 打分
      • 4.6.6 根据目标函数做梯度更新

【复现DeepSeek-R1之Open R1实战】系列博文链接:
【复现DeepSeek-R1之Open R1实战】系列1:跑通SFT(一步步操作,手把手教学)
【复现DeepSeek-R1之Open R1实战】系列2:没有卡也能训模型!Colab跑OpenR1(附源码)
【复现DeepSeek-R1之Open R1实战】系列3:基础知识介绍
【复现DeepSeek-R1之Open R1实战】系列4:跑通GRPO!
【复现DeepSeek-R1之Open R1实战】系列5:SFT源码逐行深度解析
【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码结构解析
【复现DeepSeek-R1之Open R1实战】系列7:GRPO原理介绍、训练流程和源码深度解析

4.6 GRPO训练过程

我们挑一些重点部分的代码来分析。

4.6.1 GRPO原理

在分析源码之前,我们再回顾一下GRPO的原理。

原理可以参考该博文:【DeepSeek-R1背后的技术】系列三:强化学习(Reinforcement Learning, RL)。

核心思想如下图所示:

GRPO

核心动机:在许多实际应用中,奖励只有在序列末端才给一个分数(称之为 Result/Oucome Supervision),或在每一步给一些局部分数(Process Supervision)。不管怎么样,这个奖励本身往往是离散且比较稀疏的,要让价值网络去学习每个token的价值,可能并不划算。而如果我们在同一个问题 q 上采样多份输出 o1, o2, … , oG,经过奖励模型Reward Model之后得到对应到奖励 r1, r2, … , rG,对它们进行奖励对比,就能更好地推断哪些输出更好。由此,就能对每个输出的所有 token 做相对评分,无须明确地学到一个价值函数。

在数理推理、数学解题等场景,这个技巧尤其管用,因为常常会基于同一个题目 q 生成多个候选输出,有对有错,或者优劣程度不同。那就把它们的奖励进行一个分组内的比较,以获取相对差异,然后把相对优势视为更新策略的依据。

关键点1:分组采样与相对奖励

GRPO 中,“分组”非常关键:我们会在一个问题 q 上,采样 GRPO 份输出 o1, o2, … , oG,然后把这组输出一起送进奖励模型(或规则),得到奖励分 r = {r1, r2, … , rG},先对r做归一化(减去均值除以标准差),从而得出分组内的相对水平,这样就形成了相对奖励 r’i,最后我们把这个相对奖励赋给该输出对应的所有 token 的优势函数。简单来说:多生成几份答案,一起比较,再根据排名或分数差更新,能更直接、简洁地反映同一问题下的优劣关系,而不需要用一个显式的价值网络去学习所有中间时刻的估计。

关键点2:无需价值网络的高效策略优化

因为不再需要在每个 token 上拟合一个价值函数,我们就能大幅节省内存,因为不必再维护和 Actor 同样大的 Critic 模型。这不仅是存储层面的解放,也是训练过程中的显著加速。当然,GRPO 也会引入一些新的代价:我们要为每个问题采样一组输出(不止一条),意味着推理时要多花点算力去生成候选答案。这种方法和“自洽性采样(Self-consistency)”思路也有点类似。

具体流程如下:

伪代码

分组相对奖励A’i,t的计算方法:

我们先把每个oi的奖励ri做归一化 r’i = ( ri - mean( r ) ) / std( r ),然后令A’i,t = r’i,也就是说,输出oi的所有 token 共享同一个分数r’i。它们的好坏相对于该分组内的平均水平来衡量,而不依赖外部价值网络去“拆分”或“插值”。这样我们就得到了一个无价值网络的优势函数,核心思路就是基于相互间的比较与排序。

如果用的是过程监督(process supervision),即在推理过程中的每个关键步骤都打分,那么就会略有不同。那时每个步骤都有一个局部奖励,就可以把它依时间序列累加或折算成与 token 对应的优势。

过程监督VS结果监督:过程奖励与末端奖励的对比

  • 结果监督(Outcome Supervision):只有输出序列结束才打一个奖励,如回答对/错、得分多少。GRPO 则把这个 r rr 同样分配给序列里每个 token。
  • 过程监督(Process Supervision):对中间推理步骤也有打分(比如计算正确一步就+1,错误一步就-1)。那就得收集多个时刻的奖励,然后累加到每个 token 或步骤上,再做分组相对化。

那么问题来了,batch内如何分组?在实际操作中,我们往往会在一个 batch 中包含若干个问题 q ,对每个问题生成 G 个答案。也就是说 batch 大小 = B,每个问题生成 G 个候选,那么一次前向推理要生成 B ∗ G 条候选。然后,每个候选都送进奖励模型得到分数ri。这样做推理开销不小,如果 G 较大,会显著地增加生成次数,但换来的好处是,我们不再需要价值网络了。

延伸:迭代式强化学习——奖励模型的更新与回放机制

在实际用 GRPO 的时候,如果奖励模型 RM 也是学习得来的,那么当策略模型变强时,RM 所得到的训练样本分布会越来越“难”,这时 RM 自身也需要更新。这样就会出现迭代强化学习流程:先用当前 RM 来指导一轮策略更新,然后再用新策略生成的数据来更新 RM。为了避免灾难性遗忘,可以保留一部分旧数据(回放机制 replay buffer),让 RM 每次都在新旧数据上共同训练,这样 RM 不会完全忘记之前的问题特征。

接下来,我们就按照上面的流程,详细解读源码。

4.6.2 设置参考模型

        # Reference model
        if is_deepspeed_zero3_enabled():
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
        elif not is_peft_model(model):
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)
        else:
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None

这段代码片段展示了如何根据不同的条件创建或配置一个参考模型(ref_model),主要用于深度学习中的模型训练和评估。以下是详细的解析:

  1. DeepSpeed ZeRO-3 启用时

    • is_deepspeed_zero3_enabled():检查是否启用了 DeepSpeed 的 ZeRO-3 零冗余优化器。
    • 如果启用,则从预训练模型中加载一个因果语言模型(Causal Language Model)作为参考模型,并使用 model_init_kwargs 中的参数进行初始化。
  2. PEFT 模型未启用时

    • is_peft_model(model):检查当前模型是否是 PEFT(Parameter-Efficient Fine-Tuning)模型。
    • 如果不是 PEFT 模型,则调用 create_reference_model(model) 创建一个基于初始模型的参考模型。
  3. PEFT 模型启用时

    • 如果是 PEFT 模型,则不需要创建参考模型,因为可以通过禁用适配器(adapter)来恢复到初始模型状态,因此将 self.ref_model 设为 None

4.6.3 从训练集中抽取问题

采样器是RepeatRandomSampler类,主要是通过_prepare_inputs函数准备输入数据的。

def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
class RepeatRandomSampler(Sampler):
    """
    Sampler that repeats the indices of a dataset N times.

    Args:
        data_source (`Sized`):
            Dataset to sample from.
        repeat_count (`int`):
            Number of times to repeat each index.
        seed (`Optional[int]`):
            Random seed for reproducibility (only affects this sampler).

    Example:
    ```python
    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
    >>> list(sampler)
    [2, 2, 0, 0, 3, 3, 1, 1]
    ```
    """

    def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
        self.data_source = data_source
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        self.generator = torch.Generator()  # Create a local random generator
        if seed is not None:
            self.generator.manual_seed(seed)

    def __iter__(self):
        indexes = [
            idx
            for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
            for _ in range(self.repeat_count)
        ]
        return iter(indexes)

    def __len__(self):
        return self.num_samples * self.repeat_count


    def _get_train_sampler(self) -> Sampler:
        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
        # preventing discrepancies in group formation.
        return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)

    def _get_eval_sampler(self, eval_dataset) -> Sampler:
        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
        # preventing discrepancies in group formation.
        return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)

4.6.4 旧策略模型生成G个输出

同样在_prepare_inputs函数函数里,通过self.llm.generate()函数生成了G个输出,并做了一系列后处理操作:

# Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_vllm()
                self._last_loaded_step = self.state.global_step

            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            if self.accelerator.is_main_process:
                outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
                completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
            else:
                completion_ids = [None] * len(all_prompts_text)
            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
            completion_ids = completion_ids[process_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            # Regular generation path
            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                prompt_completion_ids = unwrapped_model.generate(
                    prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                )

            # Compute prompt length and extract completion ids
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]

4.6.5 对每个输出用奖励模型 RM 打分

  1. Reward Model初始化
        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # Reward weights
        if args.reward_weights is not None:
            if len(args.reward_weights) != len(reward_funcs):
                raise ValueError(
                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
                    f"functions ({len(reward_funcs)})"
                )
            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
        else:
            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
  1. 计算奖励分数

计算过程是在_prepare_inputs函数里实现的,主要功能模块如代码注释所示,整个计算过程和我们在上一节原理介绍里能一一对应上:


        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
                reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
        # completions may be distributed across processes
        rewards_per_func = gather(rewards_per_func)

        # Apply weights to each reward function's output and sum
        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)

        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]

4.6.6 根据目标函数做梯度更新

在compute_loss函数里,根据每个生成的优势分数advantages计算对应的损失,并加上KL正则。

梯度更新在Trainer的主函数train()里实现了,可以参考前一篇博文介绍:【复现DeepSeek-R1之Open R1实战】系列5:SFT源码逐行深度解析。

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        # Compute the per-token log probabilities for the model

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss

http://www.kler.cn/a/552971.html

相关文章:

  • 从中心化到点对点:视频通话SDK组件EasyRTC如何通过WebP2P技术实现低延迟通信
  • 双脑微状态:一种量化任务驱动的脑间非对称性的超扫描EEG新方法
  • DeepSeek模型快速部署教程-搭建自己的DeepSeek
  • UE_C++ —— Container TMap
  • Spark 和 Hive 的关系与区别
  • 旧手机热点无法提供ipv6解决方法(emui 8 热点提供ipv6)
  • windows系统本地部署DeepSeek-R1全流程指南:Ollama+Docker+OpenWebUI
  • 【Postgresql】Windows 部署 Postgresql 数据库 (图文教程)
  • Cursor实战:Web版背单词应用开发演示
  • C# 实现完善 Excel 不规则合并单元格数据导入
  • 蓝桥杯备考:二分算法之木材加工
  • 【前端学习笔记】Vite
  • 开题报告——基于Spring Boot的社区居民健康管理平台的设计与实现
  • 如何禁止本地网络访问抖音?
  • 深入理解CSS三大特性——继承、优先级与层叠
  • 【AI】mac 本地部署 Dify 实现智能体
  • 国产编辑器EverEdit - 独门暗器:自动监视剪贴板内容
  • 【核心算法篇十一】《DeepSeek对抗训练:提升模型鲁棒性的五大策略》
  • Go语言入门指南
  • Bio-ORACLE数据分享[decade 2010-2020] [Surface layers]