当前位置: 首页 > article >正文

DeepSeek-R1大模型学习笔记

DeepSeek-R1模型架构设计

DeepSeek-R1基于DeepSeek-V3 base模型,提出了一系列训练策略,包括基于纯强化学习的训练(DeepSeek-R1-Zero)、基于多阶段的训练和冷启动(DeepSeek-R1)、知识蒸馏等。下面的思维导图摘自另一个推文,我感觉把DeepSeek整体的框架概括得很清晰:

在这里插入图片描述

专家混合模型(MoE)

MoE在每次推理时选择性地激活部分模型参数,在不成比例增加计算成本的情况下,可以扩展模型参数。在DeepSeek-V2中就已经提出了用于FFN层的DeepSeekMoE。

  • 动态专家分配:根据token的上下文动态分配合适的专家
  • 强化学习(RL)引导路由:与DeepSeek-V2不同,DeepSeek-R1采用强化学习来引导专家利用,确保计算负载平衡
  • DeepSeek-V2引入辅助损失进行负载均衡,确保令牌在专家之间的分配更加均衡。DeepSeek-V3和DeepSeek-R1进一步采用用auxiliary-loss-free load balancing实现负载均衡,引入一个expert bias,这个bias只影响专家路由,而不影响任何梯度。动态调整bias,专家overloaded则降低bias,专家unoverloaded则增大bias。简单来说就是用加法高效地对gating score进行re-weight的过程
  • DeepSeek-R1总参数量671B,通过MoE对单个token的激活参数量仅37B,这个和DeepSeek-V3是一致的

在这里插入图片描述
Auxiliary-Loss-Free Load Balancing
和DeepSeek-V3一样,DeepSeek-R1采用了细粒度的MoE,一些expert作为共享expert,另一些expert作为routed expert进行动态激活。对于第t个token u t u_t ut,下面是MoE计算的过程:

在这里插入图片描述
以前基于auxiliary loss的方法需要修改loss function,当auxiliary loss很大时会影响模型性能。那么Auxiliary-Loss-Free则是在gating value g g g的基础上,额外加上了bias来实现负载均衡:

在这里插入图片描述
注意bias只影响专家路由,而不影响任何梯度。专家overloaded则降低bias,专家unoverloaded则增大bias。调整的速度由超参数 γ \gamma γ控制,这个和反向传播的梯度更新过程类似。

下图是该方法的出处:Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts文章所提出的负载均衡策略:
在这里插入图片描述
和DeepSeek-V2一样,DeepSeek-V3和DeepSeek-R1都采用了限制设备数量的MoE,并且不会再训练时做token dropping了。

多头潜在注意力(MLA)

MLA通过将QKV矩阵投影到低维潜在空间,显著降低计算和内存成本。DeepSeek-V2中就提出了用MLA来替代传统的多头自注意力。

MLA和其他注意力的对比如下,KV cache以一个更低的维度去存储和计算。
在这里插入图片描述
K和V的联合压缩如下:
在这里插入图片描述

真正推理时,cache的就是低维的 c t K V c_t^{KV} ctKV,并且down-proj和up-proj矩阵可以分别被吸收进 W Q W^Q WQ W O W^O WO中,不会造成额外的计算开销。这个方法和Palu: Compressing KV-Cache with Low-Rank Projection那篇文章一致。具体的融合过程如下(以 W Q W^Q WQ的融合为例):

在这里插入图片描述

为了在训练时降低激活的memory,也对query做低秩压缩:
在这里插入图片描述
【还没理解到对query低秩分解怎么省计算,算的时候不需要重构回去?】

RoPE位置编码兼容性考虑
但是KV cache的低秩压缩和RoPE位置编码并不兼容!如果对 k t C k_t^C ktC做RoPE, W U K W^{UK} WUK会和位置敏感的RoPE矩阵耦合在一起,从而不能在推理时被吸收进 W Q W^Q WQ中(这里应该强调一下吸收是totally offline完成的),带来额外的计算。
进一步理解 W U K W^{UK} WUK和RoPE矩阵的耦合:与生成当前token相关的RoPE矩阵位于 W Q W^{Q} WQ W U K W^{UK} WUK之间,而矩阵乘法不满足交换律。
于是DeepSeek-V2提出了解耦RoPE策略,用额外的多头query和一个共享key来计算RoPE,然后和原本的query和key拼接起来。至于这里怎么得到的额外query和key,就是用来两个额外的线性层来算得的。

在这里插入图片描述

下图体现了MLA的整个过程,值得注意的一点是,MLA的低秩分解是基于训练的,而非用SVD之类的方式post-training分解直接推理(比如Pula文章)。

在这里插入图片描述

训练策略

DeepSeek-R1-Zero

在这里插入图片描述

DeepSeek-R1-Zero直接在DeepSeek-V3 base模型的基础上用纯的Group Relative Policy Optimization (GRPO)强化学习算法,而不引入Supervised Fine-tuning (SFT)训练。

DeepSeek-R1-Zero采用基于规则的奖励机制,包含1)accuracy奖励和2)格式奖励,模板如下图所示,思考过程和回答过程需要放在对应的tag中。
在这里插入图片描述

DeepSeek-R1-Zero的缺点
DeepSeek-R1-Zero面临着可读性差和语言混合(比如中英文混杂)等挑战。

DeepSeek-R1

在DeepSeek-R1-Zero的基础上,DeepSeek-R1加入了冷启动,并且用高质量数据做SFT+RL训练,得到了当今的“最强大模型”。下面是关键技术和训练流程:
冷启动SFT
在高质量的CoT数据集上做SFT,改善DeepSeek-R1-Zero可读性较差的问题。
面向推理的RL
冷启动训练后,用和DeepSeek-R1-Zero一样的RL训练策略继续训练,并且采用语言一致性奖励改善语言混杂的问题,进一步增强模型的可读性。
拒绝采样和SFT
RL训练收敛后,用拒绝采样生成600k推理数据,并与非推理数据(写作、事实质量保证、自我认知和翻译)融合,然后进行SFT,以适应非推理场景。总共大约800k数据,训练DeepSeek-V3-Base 2个epoch。
全场景RL
进一步使模型与人类的偏好保持一致,进一步用了二阶段RL阶段提高模型的帮助性和无害性,同时改进其功能推理能力。对于推理任务,继续用基于规则的奖励;对于通用任务,采用偏好奖励模型。

多token预测(MTP)

MTP使DeepSeek-R1并行预测多个token,从而显著提高推理速度。MTP已经在DeepSeek-V3中已经被用于训练的目标。

  • 并行解码:通过允许在相同的上下文窗口内进行多个token生成预测,扩展了自回归框架
  • 动态预测视距:根据模型置信度调整每步预测的token数量
  • 强化学习引导的token选择:确保多token预测中的一致性,并减少错误传播
  • 训练时MTP包含了多个MTP模块,主要用于提升模型的推理性能,在推理时,可以直接将多个MTP模块丢弃,只保留主模型,然后推理。也可以重新利用这些MTP模块,借助speculative decoding来加速推理过程。

在这里插入图片描述

FP8量化

DeepSeek-R1利用8位浮点数(FP8)量化,减少内存使用和计算成本,同时保持数值稳定性。

  • 自适应位宽:根据计算需求动态调整不同层的位宽精度
  • 感知损失的量化:使用损失敏感的缩放函数,确保在不同计算阶段保持数值精度

知识蒸馏

用DeepSeek-R1直接蒸馏Qwen2.5-Math-1.5B, Qwen2.5-Math-7B, Qwen2.5-14B, Qwen2.5-32B, Llama-3.1-8B和Llama-3.3-70B-Instruct,都有显著的性能提升。蒸馏时只在800k个样本上用了SFT,而并没有用RL。


http://www.kler.cn/a/533118.html

相关文章:

  • 【Git】一、初识Git Git基本操作详解
  • 物联网领域的MQTT协议,优势和应用场景
  • 【PyQt】pyqt小案例实现简易文本编辑器
  • 【数据结构】栈与队列
  • 机器学习--1.KNN机器学习入门
  • pytorch图神经网络处理图结构数据
  • 用Python实现SVM分类器:从数据到决策边界可视化,以鸢尾花数据集为例
  • DeepSeek 本地部署全攻略
  • Java使用Jsoup处理报文简单样例
  • CSS in JS
  • 【LeetCode: 922. 按奇偶排序数组 II + 双指针】
  • 个人c项目 java项目解释
  • 力扣 45. 跳跃游戏 II
  • 3. k8s二进制集群之负载均衡器高可用部署
  • 7. k8s二进制集群之Kube ApiServer部署
  • Oracle日常管理(8)——OS日常管理(1)
  • WPS计算机二级•幻灯片的配色、美化与动画
  • Day 28 卡玛笔记
  • JAVA篇12 —— 泛型的使用(待完善)
  • 多线程的常用方法
  • 高等代数笔记—域与一元多项式
  • 中国证券基本知识汇总
  • HTB:Administrator[WriteUP]
  • 【01-Qt-C++-android】
  • Redis --- 秒杀优化方案(阻塞队列+基于Stream流的消息队列)
  • 100.3 AI量化面试题:解释配对交易(Pairs Trading)的原理,并说明如何选择配对股票以及设计交易信号