当前位置: 首页 > article >正文

deepseek-r1技术报告解析

文章题目:DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning在这里插入图片描述
这篇技术报告中提到了三个贡献:

  1. 实现了通过大规模强化学习(RL)训练的模型DeepSeek-R1-Zero
  2. 对DeepSeek-R1-Zero的训练方法做出改进,训练得到DeepSeek-R1模型
  3. 对DeepSeek-R1模型进行蒸馏,得到不同量级的稠密模型
    在这里插入图片描述

DeepSeek-R1-Zero的训练方法

基座模型:DeepSeek-V3-Base模型(MoE架构)
强化学习算法:GRPO
论文:DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models
代码:https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py
在这里插入图片描述
该方法主要对PPO进行了改进,原始的PPO算法需要CriticModel,然而,CriticModel需要额外的计算资源,并且也会有Reward Hacking等问题,因此在GRPO算法中去掉了CriticModel。具体实现参照huggingface的代码比较简洁,奖励的制定放在self.reward_funcs中。

# https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py
 def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
     if return_outputs:
         raise ValueError("The GRPOTrainer does not support returning outputs")

     prompts = [x["prompt"] for x in inputs]
     prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
     prompt_inputs = self.processing_class(
         prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
     )
     prompt_inputs = super()._prepare_inputs(prompt_inputs)

     if self.max_prompt_length is not None:
         prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
         prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]

     # Generate completions
     with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
         prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
         # self.num_generations has been set in self.generation_config
     prompt_length = prompt_inputs["input_ids"].size(1)
     completion_ids = prompt_completion_ids[:, prompt_length:]

     # Get the per-token log probabilities for the completions for the model and the reference model
     def get_per_token_logps(model, input_ids):
         logits = model(input_ids).logits  # (B, L, V)
         logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
         input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
         # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
         per_token_logps = []
         for logits_row, input_ids_row in zip(logits, input_ids):
             log_probs = logits_row.log_softmax(dim=-1)
             token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
             per_token_logps.append(token_log_prob)
         return torch.stack(per_token_logps)

     per_token_logps = get_per_token_logps(model, prompt_completion_ids)
     # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
     per_token_logps = per_token_logps[:, prompt_length - 1 :]

     with torch.inference_mode():
         if self.ref_model is not None:
             ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids)
         else:
             with self.accelerator.unwrap_model(model).disable_adapter():
                 ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids)
     ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]

     # Compute the KL divergence between the model and the reference model
     per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

     # Mask everything after the first EOS token
     is_eos = completion_ids == self.processing_class.eos_token_id
     device = self.accelerator.device
     eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
     eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
     sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
     completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

     # Decode the generated completions
     completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
     if is_conversational(inputs[0]):
         completions = [[{"role": "assistant", "content": completion}] for completion in completions]

     # Compute the rewards
     prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]

     rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
     for i, (reward_func, reward_processing_class) in enumerate(
         zip(self.reward_funcs, self.reward_processing_classes)
     ):
     	 # RewardModel Or Function
         if isinstance(reward_func, PreTrainedModel):
             if is_conversational(inputs[0]):
                 messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                 texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
             else:
                 texts = [p + c for p, c in zip(prompts, completions)]
             reward_inputs = reward_processing_class(
                 texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
             )
             reward_inputs = super()._prepare_inputs(reward_inputs)
             with torch.inference_mode():
                 rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
         else:
             # Repeat all input columns (but "prompt" and "completion") to match the number of generations
             reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
             for key in reward_kwargs:
                 for example in inputs:
                     # Repeat each value in the column for `num_generations` times
                     reward_kwargs[key].extend([example[key]] * self.num_generations)
             output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
             rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

     # Sum the rewards from all reward functions
     rewards = rewards_per_func.sum(dim=1)

     # Compute grouped-wise rewards
     mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
     std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

     # Normalize the rewards to compute the advantages
     mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
     std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
     advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

     # x - x.detach() allows for preserving gradients from x
     per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
     per_token_loss = -(per_token_loss - self.beta * per_token_kl)
     loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

     # Log the metrics
     reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
     for i, reward_func in enumerate(self.reward_funcs):
         if isinstance(reward_func, PreTrainedModel):
             reward_func_name = reward_func.config._name_or_path.split("/")[-1]
         else:
             reward_func_name = reward_func.__name__
         self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

     self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())

     self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

     mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
     self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

     return loss

模版的设置方法如下:

在这里插入图片描述

强化学习奖励的制定

DeepSeek-R1-Zero模型主要针对数学或者编程,它们都有一个确切的答案,奖励容易制定。

包含两部分:1. 是否回答正确问题 2. 是否回答符合格式要求

两部分奖励进行加和作为最终奖励

DeepSeek-R1-Zero存在的缺点:思维链可读性差、语言混乱,冷启动(可能一开始的时候所有的输出都比较差)

针对这些缺点,DeepSeek人员研制了DeepSeek-R1模型。

DeepSeek-R1的训练方法

冷启动

收集了几千条用于长链条的思维链用以SFT基础模型(DeepSeek-V3-Base)。

收集方法:使用few-shot,直接提示DeepSeek-R1-Zero通过反思和验证生成详细的答案,并通过人工注释来对回复进行修正。

  1. 可读性更高,思维链中包含了对思考过程的总结,更有利于人类阅读;

  2. 比DeepSeek-Zero更好证明了这是一种迭代训练的模式,未来有着潜力可以继续使用。

面向推理的强化学习

相比DeepSeek-R1-Zero模型,DeepSeek-R1增加了语言一致性奖励,以增加结果的可读性。

拒绝采样和SFT

当面向推理的强化学习收敛时,利用得到的checkpoint来为收集SFT(监督微调)数据。与主要专注于推理的初始冷启动数据不同,这个阶段整合了来自其他领域的数据,以增强模型在写作、角色扮演和其他通用任务方面的能力。

推理数据:1. 除了用于强化学习的数据之外,引入了其他的数据集(猜测应该包含问题和思维链提示结果),通过把checkpoint生成的CoT和数据集中的CoT输入到DeepSeek-V3中,让DeepSeek-V3判断是否保留这条数据;2. 过滤了混合语言、长参数和代码块的思想链。对于每个提示,采样多个响应,并只保留正确的响应。总共收集600k训练样本。

非推理数据:如写作、事实QA、自我认知和翻译,调用DeepSeek-V3,提示回答问题之前生成一个潜在的思维链;此外也抽取了一部分DeepSeek-V3的SFT训练集用作该阶段的训练。对于更简单的查询,如“hello”,不提供CoT作为响应。最后,总共收集了大约200k个与推理无关的训练样本。

使用上述约800k个样本的数据集,对DeepSeek-V3-Base进行了SFT了2个epoch。

针对所有场景的强化学习

旨在提高模型的帮助性和无害性,同时改进其推理能力。对于推理数据,坚持DeepSeek-R1-Zero中利用基于规则的奖励来指导数学、代码和逻辑推理领域的学习过程。对于一般的数据,诉诸于奖励模型来捕捉人类偏好。

蒸馏小模型方法

微调数据:800k样本
微调模型:Qwen和Llama等开源模型
微调方法:只有SFT微调(与DeepSeek-R1相比,没有第四阶段的强化学习部分)

实验

评估指标:由于贪婪解码会导致模型口吃,因此通过设置TopK(0.95)和temperature(0.6)来生成答案,每个题目生成k个回答。

pass@1为这k个回答的正确率, p a s s @ 1 = 1 / k ∑ i = 1 k p i pass@1 = 1/k \sum_{i=1}^kp_i pass@1=1/ki=1kpi

cons@64为k取值64,这64个回答多数投票的结果。
在这里插入图片描述
在这里插入图片描述
在DeepSeek-R1技术报告实验部分的讨论了小模型是否需要大规模强化学习,实验结果如下,可以得出两个结论:首先,将更强大的模型提炼成更小的模型会得到很好的结果,而依赖于文章中提到的大规模强化学习训练较小模型可能无法达到蒸馏的性能。其次,虽然蒸馏策略既经济又有效,但超越智能的边界可能仍然需要更强大的基础模型和更大规模的强化学习。表中pass@1为
在这里插入图片描述
DeepSeek-R1-Zero的实验结果如下,数学能力接近open-ai,代码能力要差于open-ai。
在这里插入图片描述
此外,在DeepSeek-R1-Zero的训练过程中发生了两个有趣的现象,1. 随着训练轮次的增加,思维链在逐渐变长,如Figure 3所示;2. 在思维链中出现了拟人化的语气来重新思考的语句,引导模型深入探索,如Table 3所示。
在这里插入图片描述
在这里插入图片描述

不成功的尝试
  1. 蒙特卡洛树搜索(计算复杂度过高)
  2. 过程奖励模型(在一般推理中明确定义一个细粒度步骤具有挑战性、确定当前的中间步骤是否正确具有挑战性、会使训练变得复杂)

进一步思考

也可以使用rule-base的方法制定一些soft-reward,例如根据思维链的重复度,长度等做一些soft-reward,或许可以改善强化学习小模型的效果。

在使用过程中发现和gpt4-o1相比deepseek-r1模型生成的思维链实在是太长了,等待时间太久,之后或许可以探索怎么让他不要思考过度。


http://www.kler.cn/a/522299.html

相关文章:

  • three.js用粒子使用canvas生成的中文字符位图材质
  • Greenplum临时表未清除导致库龄过高处理
  • 第19篇:python高级编程进阶:使用Flask进行Web开发
  • WSL 安装cuDNN
  • 【学习笔记】计算机网络(二)
  • TCP是怎么判断丢包的?
  • 在RHEL 8.10上安装开源工业物联网解决方案Thingsboard 3.9
  • 【Linux】互斥锁、基于阻塞队列、环形队列的生产消费模型、单例线程池
  • “基因合作:生命演化中的共生与目的性”
  • 【2024年华为OD机试】 (A卷,200分)- 开放日活动、取出尽量少的球(JavaScriptJava PythonC/C++)
  • 6. 使用springboot做一个音乐播放器软件项目【1.0版项目完结】附带源码~
  • Android SystemUI——最近任务列表启动(十八)
  • FPGA 26,数码管动态显示,解析与实现( 使用 Xilinx Vivado 实现数码管动态显示 )
  • 计算机网络之计算机网络基本概念
  • 【Leetcode 每日一题】45. 跳跃游戏 II
  • Linux 命令之技巧(Tips for Linux Commands)
  • QT 笔记
  • 深入探讨防抖函数中的 this 上下文
  • 论文笔记(六十五)Schmidt-EKF-based Visual-Inertial Moving Object Tracking
  • LeetCode-175. 组合两个表
  • H2 Database安装部署
  • VMware 中Ubuntu无网络连接/无网络标识解决方法【已解决】
  • PHP Error处理与优化指南
  • volatile之四类内存屏障指令 内存屏障 面试重点 底层源码
  • 多模态论文笔记——TECO
  • 已解决:Win10任务状态栏卡死点击无响应的解决方案