Adam 和 AdamW 优化器详解及其训练显存需求分析:以LLaMA-2 7B为例(中英双语)
中文版
Adam 和 AdamW 优化器详解及其显存需求分析
在训练大规模神经网络时,优化器的选择和其在显存中的消耗是至关重要的,特别是像 LLaMA-2 7B 这样的大模型。今天我们将详细分析 Adam 优化器 和 AdamW 优化器,并结合 float32 和 bfloat16 精度的情况,探讨它们在显存消耗方面的表现。
1. Adam 优化器简介
Adam(Adaptive Moment Estimation)是一种常用的优化算法,它结合了动量(Momentum)和RMSProp优化器的优点。Adam通过维护每个参数的一阶矩估计(动量)和二阶矩估计(梯度平方的指数加权平均)来对参数进行更新。Adam优化器通过以下公式更新模型的权重:
1.1 Adam 优化器数学公式
假设我们有一个损失函数 ( L ( θ ) L(\theta) L(θ) ) 和参数向量 ( θ \theta θ ),则 Adam 优化器的更新规则如下:
-
计算梯度:
g t = ∇ θ L ( θ t ) g_t = \nabla_\theta L(\theta_t) gt=∇θL(θt)
其中 ( g t g_t gt ) 是当前时间步 ( t t t ) 的梯度。 -
一阶矩估计 (动量):
m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t mt=β1mt−1+(1−β1)gt
其中 ( β 1 \beta_1 β1 ) 是一阶矩的衰减率,通常取值接近 1(如 0.9)。 -
二阶矩估计:
v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 vt=β2vt−1+(1−β2)gt2
其中 ( β 2 \beta_2 β2 ) 是二阶矩的衰减率,通常取值接近 1(如 0.999)。 -
偏差校正(为了修正初始时刻 ( m_t ) 和 ( v_t ) 的偏差):
m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} m^t=1−β1tmt,v^t=1−β2tvt -
更新参数:
θ t + 1 = θ t − α v ^ t + ϵ m ^ t \theta_{t+1} = \theta_t - \frac{\alpha}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t θt+1=θt−v^t+ϵαm^t
其中 ( α \alpha α ) 是学习率,( ϵ \epsilon ϵ ) 是防止除零的常数。
1.2 Adam 优化器的内存消耗
Adam 优化器的内存消耗比传统的 SGD 优化器更高,因为它需要为每个参数维护 一阶矩 和 二阶矩,即两个额外的变量。假设 LLaMA-2 7B 模型有 70 亿个参数,每个参数需要存储 两个额外的矩,因此优化器的内存需求是模型参数内存的两倍。
- 参数内存:假设使用 float32 精度,每个参数占用 4 字节,那么 7B 模型的参数内存为:
参数内存 = 7 × 1 0 9 × 4 字节 = 28 GB \text{参数内存} = 7 \times 10^9 \times 4 \, \text{字节} = 28 \, \text{GB} 参数内存=7×109×4字节=28GB - 优化器内存:由于需要为每个参数维护两个矩,因此优化器的内存需求是模型参数的两倍:
优化器内存 = 2 × 28 GB = 56 GB \text{优化器内存} = 2 \times 28 \, \text{GB} = 56 \, \text{GB} 优化器内存=2×28GB=56GB
总的来说,Adam 优化器在 LLaMA-2 7B 模型的训练中,总显存消耗大约为 模型参数显存 + 梯度显存 + 优化器显存,即 28 GB + 28 GB + 56 GB = 112 GB,对于 24GB 显卡显然是无法容纳的。
下面的python代码测试7B大模型本身的参数量:以float32计算。进位采用1024,计算得出:7B大模型的参数量为26.08 GB;当进位采用1000时,计算得出28.00 GB。为什么尝试1000,是因为在其他博文中看到28GB这个数字,自己测试一下,发现他们是在以1000为进位的时候测试得出的。参考文章:https://cuiyuhao.com/posts/c87c0f5d/
# 定义参数
num_parameters = 7 * 10**9 # 70 亿个参数
bytes_per_param = 4 # 每个参数占用 4 字节(32 位浮动数)
# 计算显存需求(单位:字节)
memory_in_bytes = num_parameters * bytes_per_param
# 将字节转换为 GB
memory_in_GB = memory_in_bytes / (1024 ** 3) # 转换为 GB, 可调为1000
print(f"模型需要的显存为: {memory_in_GB:.2f} GB")
以bf16为例,由于它是float32的一半,所以它的参数量为 26.08GB / 2 = 13.04GB (以1024为进位),当以1000进位的时候,28GB / 2 = 14GB
2. AdamW 优化器简介
AdamW 是 Adam 优化器的一种变体,它对权重衰减(weight decay)做了改进。Adam 在更新参数时直接将权重衰减项添加到梯度中,而 AdamW 通过将衰减项从一阶矩和二阶矩的更新中分离出来,使得优化过程更加稳定。AdamW 的更新公式与 Adam 类似,但衰减项被单独处理:
2.1 AdamW 优化器数学公式
- 权重衰减(weight decay)项:
θ t + 1 = θ t − α v ^ t + ϵ m ^ t − λ θ t \theta_{t+1} = \theta_t - \frac{\alpha}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t - \lambda \theta_t θt+1=θt−v^t+ϵαm^t−λθt
其中 ( λ \lambda λ ) 是权重衰减系数。
2.2 AdamW 优化器的内存消耗
与 Adam 优化器一样,AdamW 也需要为每个参数维护 一阶矩 和 二阶矩。因此,它的内存消耗与 Adam 优化器相同,差异主要体现在梯度更新时的计算过程,而不是内存需求。
所以,AdamW 优化器对显存的占用与 Adam 优化器是相同的,依然是 模型参数的两倍。
- 参数内存:28 GB
- 优化器内存:56 GB
因此,AdamW 优化器在 LLaMA-2 7B 模型上的显存消耗与 Adam 优化器一致。
3. 在 float32 和 bfloat16 下的显存需求
3.1 使用 float32 精度
- 模型参数内存:28 GB
- 梯度内存:28 GB
- 优化器内存:56 GB
- 总显存需求:28 GB + 28 GB + 56 GB = 112 GB
3.2 使用 bfloat16 精度
使用 bfloat16 精度时,每个参数、梯度和优化器状态的内存需求将减半。假设 LLaMA-2 7B 使用 bfloat16 精度,则:
- 模型参数内存:14 GB
- 梯度内存:14 GB
- 优化器内存:28 GB
- 总显存需求:14 GB + 14 GB + 28 GB = 56 GB
因此,使用 bfloat16 精度时,显存需求比使用 float32 精度时减少了约一半。
4. 总结
- Adam 优化器 和 AdamW 优化器 都需要为每个参数维护一阶矩和二阶矩,因此它们的内存消耗是 模型参数内存的两倍。
- float32 精度:在使用 float32 精度时,LLaMA-2 7B 模型的总显存需求大约为 112 GB。
- bfloat16 精度:在使用 bfloat16 精度时,LLaMA-2 7B 模型的总显存需求为 56 GB。
通过选择合适的优化器和精度,尤其是在资源有限的情况下,可以大大减少显存消耗,确保大模型的训练可以在较小的 GPU 上完成。
英文版
Detailed Analysis of Adam and AdamW Optimizers and Their Memory Consumption with float32 and bfloat16 Precision
When training large-scale neural networks, especially models like LLaMA-2 7B, the choice of optimizer and the associated memory consumption are crucial factors. In this post, we’ll delve into the Adam optimizer and AdamW optimizer, explain their memory consumption in both float32 and bfloat16 precision, and provide a detailed example using the LLaMA-2 7B model.
1. Adam Optimizer Overview
Adam (Adaptive Moment Estimation) is a widely used optimizer that combines the benefits of both momentum and RMSProp optimizers. It maintains first-order momentum (moving averages of gradients) and second-order momentum (moving averages of squared gradients) to adaptively adjust the learning rate for each parameter. The update rule for Adam is as follows:
1.1 Adam Optimizer Mathematical Formulas
Assume we have a loss function ( L ( θ ) L(\theta) L(θ) ) and parameter vector ( θ \theta θ ). The Adam optimizer’s update rule can be written as:
-
Compute the gradient:
g t = ∇ θ L ( θ t ) g_t = \nabla_\theta L(\theta_t) gt=∇θL(θt)
where ( g t g_t gt ) is the gradient at time step ( t t t ). -
First moment estimate (momentum):
m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t mt=β1mt−1+(1−β1)gt
where ( β 1 \beta_1 β1 ) is the decay rate for the first moment, typically close to 1 (e.g., 0.9). -
Second moment estimate:
v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 vt=β2vt−1+(1−β2)gt2
where ( β 2 \beta_2 β2 ) is the decay rate for the second moment, typically close to 1 (e.g., 0.999). -
Bias correction (to adjust for the initial bias):
m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} m^t=1−β1tmt,v^t=1−β2tvt -
Parameter update:
θ t + 1 = θ t − α v ^ t + ϵ m ^ t \theta_{t+1} = \theta_t - \frac{\alpha}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t θt+1=θt−v^t+ϵαm^t
where ( α \alpha α ) is the learning rate and ( ϵ \epsilon ϵ ) is a small constant to prevent division by zero.
1.2 Memory Consumption of Adam Optimizer
The Adam optimizer requires extra memory compared to traditional stochastic gradient descent (SGD) because it stores both first moment (m) and second moment ( v v v) for each parameter. So, for each parameter, Adam needs to store two additional variables.
For example, with the LLaMA-2 7B model having 7 billion parameters, the memory required for the optimizer would be twice the size of the model parameters.
- Model Parameters Memory: Assuming float32 precision (4 bytes per parameter):
Model Parameters Memory = 7 × 1 0 9 × 4 bytes = 28 GB \text{Model Parameters Memory} = 7 \times 10^9 \times 4 \, \text{bytes} = 28 \, \text{GB} Model Parameters Memory=7×109×4bytes=28GB - Optimizer Memory: Since Adam maintains two variables for each parameter, the memory needed for the optimizer is:
Optimizer Memory = 2 × 28 GB = 56 GB \text{Optimizer Memory} = 2 \times 28 \, \text{GB} = 56 \, \text{GB} Optimizer Memory=2×28GB=56GB
In total, the memory required for training with Adam would be the sum of the model parameters, gradients, and optimizer states:
Total Memory = 28 GB (model) + 28 GB (gradients) + 56 GB (optimizer) = 112 GB \text{Total Memory} = 28 \, \text{GB (model)} + 28 \, \text{GB (gradients)} + 56 \, \text{GB (optimizer)} = 112 \, \text{GB} Total Memory=28GB (model)+28GB (gradients)+56GB (optimizer)=112GB
This total memory requirement makes it clear that for a 24GB GPU, training this model would not fit without further optimizations.
2. AdamW Optimizer Overview
AdamW is a variant of the Adam optimizer that decouples the weight decay from the gradient update. In the standard Adam optimizer, weight decay is incorporated into the gradient update, while in AdamW, the weight decay is applied separately to the parameters.
2.1 AdamW Optimizer Mathematical Formulas
The update rule for AdamW is similar to that of Adam but includes a decoupled weight decay term:
- Weight decay term:
θ t + 1 = θ t − α v ^ t + ϵ m ^ t − λ θ t \theta_{t+1} = \theta_t - \frac{\alpha}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t - \lambda \theta_t θt+1=θt−v^t+ϵαm^t−λθt
where ( λ \lambda λ ) is the weight decay coefficient.
2.2 Memory Consumption of AdamW Optimizer
Since AdamW maintains the same first and second moment estimates as Adam, its memory consumption is identical to Adam’s. Thus, the memory requirements for AdamW in terms of the model parameters and optimizer states are the same as those for Adam:
- Model Parameters Memory: 28 GB
- Optimizer Memory: 56 GB
Thus, AdamW’s memory consumption in LLaMA-2 7B training is also:
Total Memory = 28 GB (model) + 28 GB (gradients) + 56 GB (optimizer) = 112 GB \text{Total Memory} = 28 \, \text{GB (model)} + 28 \, \text{GB (gradients)} + 56 \, \text{GB (optimizer)} = 112 \, \text{GB} Total Memory=28GB (model)+28GB (gradients)+56GB (optimizer)=112GB
3. Memory Consumption with float32 and bfloat16 Precision
3.1 Using float32 Precision
In float32 precision, each parameter, gradient, and optimizer state requires 4 bytes of memory. Therefore, for LLaMA-2 7B, the memory consumption is:
- Model Parameters: 28 GB
- Gradients: 28 GB
- Optimizer States: 56 GB
- Total Memory: 28 GB + 28 GB + 56 GB = 112 GB
3.2 Using bfloat16 Precision
Using bfloat16 precision (16-bit floating point), each parameter, gradient, and optimizer state requires only 2 bytes of memory. For LLaMA-2 7B, the memory consumption with bfloat16 would be:
- Model Parameters: ( 7 × 1 0 9 × 2 bytes = 14 GB 7 \times 10^9 \times 2 \, \text{bytes} = 14 \, \text{GB} 7×109×2bytes=14GB )
- Gradients: 14 GB
- Optimizer States: ( 2 × 14 GB = 28 GB 2 \times 14 \, \text{GB} = 28 \, \text{GB} 2×14GB=28GB )
- Total Memory: 14 GB + 14 GB + 28 GB = 56 GB
By using bfloat16 precision, the memory consumption is reduced by half compared to float32, which is a significant advantage for training large models on GPUs with limited memory.
4. Summary
- Adam Optimizer and AdamW Optimizer both require additional memory for maintaining the first and second moment estimates, leading to twice the memory requirement of model parameters for the optimizer.
- float32 Precision: With float32 precision, the memory requirement for training LLaMA-2 7B with Adam or AdamW is approximately 112 GB.
- bfloat16 Precision: With bfloat16 precision, the memory requirement is reduced to 56 GB.
By choosing the appropriate optimizer and precision, you can significantly reduce memory usage and ensure the training of large models on GPUs with limited memory, which is essential for scaling up deep learning experiments.
后记
2024年11月29日18点38分于上海,在GPT4o大模型辅助下完成。