当前位置: 首页 > article >正文

LLM - 大模型 ScallingLaws 的 C=6ND 公式推导 教程(1)

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145185794


Scaling Laws

Scaling Laws (缩放法则) 是大模型领域中,用于描述 模型性能(Loss) 与 模型规模N、数据量D、计算资源C 之间关系的经验规律,揭示在大模型中,随着模型参数数量、数据集大小和计算资源的增加,模型性能的变化模式,指导更高效地分配资源,优化模型训练过程,实现更好的性能。这些规律不仅有助于预测不同规模模型的表现,还能为模型设计和训练提供理论依据,是推动大模型发展和应用的重要理论基础。

  • Paper: Scaling Laws for Neural Language Models
  • 其他参考:计算 大语言模型(多模态) 的参数量

系列文章:

  1. 大模型 ScallingLaws 的 C=6ND 公式推导
  2. 大模型 ScallingLaws 的 CLM 和 MLM 中不同系数
  3. 大模型 ScallingLaws 的迁移学习与混合训练

对于 Decoder-Only 模型,计算量 C C C (Flops)、模型参数量 N,数据大小 D D D (Tokens),三者近似满足 C ≈ 6 N D C \approx 6ND C6ND

1. 模型参数量 (N)

假设 Decoder 堆叠层数是 l l lAttention 隐藏层维度是 d d dFeedForward 维度是 4 d 4d 4d,其中,忽略 Embedding、Norm 和 Bias。

Transformer 的每 1 层包括 Self-AttentionMLP 等 2 个部分:

  • Self-Attention 的 参数量,包括 W Q , W K , W V , W O W_{Q},W_{K},W_{V},W_{O} WQ,WK,WV,WO 等 4个部分,维度均是 R d × d \mathbb{R}^{d \times d} Rd×d,整体参数量是 4 d 2 4d^2 4d2 (暂时忽略 MQA)
  • MLP 的参数量,只包括 W u p , W d o w n W_{up},W_{down} Wup,Wdown,维度均是 R d × 4 d \mathbb{R}^{d \times 4d} Rd×4d,整体参数量 2 ∗ 4 ∗ d 2 = 8 d 2 2*4*d^{2}=8d^{2} 24d2=8d2,(暂时忽略 W g a t e W_{gate} Wgate)
  • 全部层数 l l l 参数量,即 12 l d 2 12ld^{2} 12ld2

2. 模型计算量 (C)

模型的前向推理的计算量:

计算量的单位是 FLOPs (Floating Point Operations),对于矩阵 A ∈ R m × n , B ∈ R n × p A \in \mathbb{R}^{m \times n},B \in \mathbb{R}^{n \times p} ARm×n,BRn×p A B AB AB相乘的计算量为 2 m n p 2mnp 2mnp,计算每个元素 c i , j c_{i,j} ci,j 包括 1 次加法 1 次乘法,即每个点积运算都有 n n n 次 乘法和 n − 1 n-1 n1 次加法,即 2 × m n p 2 \times mnp 2×mnp

模型的反向推理的计算量是前向推理的 2 倍,即:

前向只需要结果往后传递,反向除了需要梯度往前传递,还需要计算当前参数 W W W 的梯度,更新当前的参数 W W W,因此计算量是 2 倍。

Decoder 的输入是 X ∈ R b × s × d X \in \mathbb{R}^{b \times s \times d} XRb×s×d,其中 b b b 是 batch size, s s s 是序列长度, d d d 是模型维度。

其中 Self-Attention 的 计算量:

  • 输入层计算: Q = X W Q , K = X W K , V = X W V Q=XW_{Q},K=XW_{K},V=XW_{V} Q=XWQ,K=XWK,V=XWV,即 3 × b × ( 2 × s × d × d ) = 6 b s d 2 3 \times b \times (2 \times s \times d \times d) = 6bsd^{2} 3×b×(2×s×d×d)=6bsd2
  • Attention 计算 Score: A = Q K ⊤ A=QK^{\top} A=QK ,使用 bmm (批次矩阵乘法),batch size 不变,计算过程是 b × R s × d × R d × s = b × R s × s b \times \mathbb{R}^{s \times d} \times \mathbb{R}^{d \times s} = b \times \mathbb{R}^{s \times s} b×Rs×d×Rd×s=b×Rs×s,计算量即 b × ( 2 × s × d × s ) = 2 b s 2 d b \times (2 \times s \times d \times s) = 2bs^{2}d b×(2×s×d×s)=2bs2d
  • Score 与 V 计算: X ′ = A V X^{'}=AV X=AV,计算过程是 b × R s × s × R s × d = b × R s × d b \times \mathbb{R}^{s \times s} \times \mathbb{R}^{s \times d} = b \times \mathbb{R}^{s \times d} b×Rs×s×Rs×d=b×Rs×d,计算量即 b × ( 2 × s × s × d ) = 2 b s 2 d b \times (2 \times s \times s \times d)=2bs^{2}d b×(2×s×s×d)=2bs2d
  • 输出层计算: X ′ W O X^{'}W_{O} XWO,计算过程是 b × R s × d × R d × d = b × R s × d b \times \mathbb{R}^{s \times d} \times \mathbb{R}^{d \times d} = b \times \mathbb{R}^{s \times d} b×Rs×d×Rd×d=b×Rs×d,计算量即 b × ( 2 × s × d × d ) = 2 b s d 2 b \times (2 \times s \times d \times d)=2bsd^{2} b×(2×s×d×d)=2bsd2
  • 合计: C A t t e n t i o n = 8 b s d 2 + 4 b s 2 d = b s d ( 8 d + 4 s ) C_{Attention}=8bsd^{2}+4bs^{2}d=bsd(8d+4s) CAttention=8bsd2+4bs2d=bsd(8d+4s)

其中 MLP 的 计算量,升维和降维的计算量相同:

  • 升维 X W u p XW_{up} XWup,计算过程是 b × R s × d × R d × 4 d = b × R s × 4 d b \times \mathbb{R}^{s \times d} \times \mathbb{R}^{d \times 4d} = b \times \mathbb{R}^{s \times 4d} b×Rs×d×Rd×4d=b×Rs×4d,计算量 b × ( 2 × s × d × 4 d ) = 8 b s d 2 b \times (2 \times s \times d \times 4d)=8bsd^{2} b×(2×s×d×4d)=8bsd2
  • 同理,降维也是一样。
  • 合计: C M L P = 16 b s d 2 C_{MLP}=16bsd^{2} CMLP=16bsd2

则每层的计算量:

C L a y e r = C A t t e n i o n + C M L P = 24 b s d 2 + 4 b s 2 d = b s d ( 24 d + 4 s ) C f o r w a r d = l b s d ( 24 d + 4 s ) C_{Layer}=C_{Attenion}+C_{MLP}=24bsd^{2}+4bs^{2}d=bsd(24d+4s) \\ C_{forward}=lbsd(24d+4s) CLayer=CAttenion+CMLP=24bsd2+4bs2d=bsd(24d+4s)Cforward=lbsd(24d+4s)

反向传播是正向传播的 2 倍,合计是 3 倍,即:

C = 3 × C f o r w a r d = 72 l b s d 2 + 12 l b s 2 d = 12 l b s d ( 6 d + s ) C= 3 \times C_{forward} = 72lbsd^{2} + 12lbs^{2}d = 12lbsd(6d + s) C=3×Cforward=72lbsd2+12lbs2d=12lbsd(6d+s)

1.3 合计

模型参数量是 N = 12 l d 2 N=12ld^{2} N=12ld2,计算量是 C = l b s d ( 72 d + 12 s ) C=lbsd(72d + 12s) C=lbsd(72d+12s),假设 s ≪ 6 d s \ll 6d s6d,那么:

C = 12 l d 2 × b s × ( 6 + s d ) = 6 × b s × 12 l d 2 × ( 1 + s 6 d ) = 6 × b s × N C = 12ld^{2} \times bs \times (6+\frac{s}{d}) = 6 \times bs \times 12ld^{2} \times (1+\frac{s}{6d}) = 6 \times bs \times N C=12ld2×bs×(6+ds)=6×bs×12ld2×(1+6ds)=6×bs×N

那么每个 Token 的计算量,即 除以 b s bs bs,整体计算量再 乘以 全部数据集(Token) D D D,即:

C = 6 × N × D C=6 \times N \times D C=6×N×D

参考:

  • 知乎 - 为什么反向计算是前向耗时的两倍?
  • GitHub - backprop_FLOPs.py
  • 知乎 - 腾讯算出 MoE 模型的 Scaling Law
  • 知乎 - 解析大模型中的 Scaling Law

http://www.kler.cn/a/508818.html

相关文章:

  • FPGA车牌识别
  • 本地仓库管理之当前分支内的操作
  • DPIN与CESS Network达成全球战略合作,推动DePIN与AI领域创新突破
  • linux下的NFS和FTP部署
  • [0242].第4-3章:SpringBoot2核心技术笔记
  • 【服务治理中间件】consul介绍和基本原理
  • 1.Spring AI 从入门到实践
  • 2025年应用与API安全展望:挑战与机遇并存
  • 青少年编程与数学 02-007 PostgreSQL数据库应用 05课题、结构化查询语言(SQL)
  • 1.6 阅读k8s源码的准备工作
  • Android 12.0 息屏休眠后立即启动屏保功能实现
  • SpringMVC 实战指南:打造高效 Web 应用的秘籍
  • 外包公司名单一览表(成都)
  • 《C++11》中的显式虚函数重载:深入理解与应用
  • 【数据分析(二)】初探 Pandas
  • Java中线程的学习
  • EI Scopus双检索 | 2025年第四届信息与通信工程国际会议(JCICE 2025)
  • redis.call()和redis.pcall()的区别
  • uniapp 微信小程序 editor 富文本编辑器
  • SpringBoot中Get请求和POST请求接收参数详解
  • STM32--定时器输出pwm知识点_stm32 pwm-CSDN博客
  • Python毕业设计选题:基于django+vue的智能租房系统的设计与实现
  • 彩色图像面积计算一般方法及MATLAB实现
  • 电脑换固态硬盘
  • 瑞芯微开发板/主板Android配置APK默认开启性能模式方法
  • Cursor新建远程分支后,更新到本地,一步到位