深度学习优化-Gradient Checkpointing
数学原理参考:
梯度检查点技术(Gradient Checkpointing)详细介绍:中英双语-CSDN博客
视频讲解参考:
用梯度检查点来节省显存 gradient checkpointing_哔哩哔哩_bilibili
Gradient Checkpointing(梯度检查点
Gradient Checkpointing 是一种用于优化深度学习模型训练的技术,旨在减少训练过程中显存的占用。在深度神经网络训练中,通常需要存储每一层的激活值以用于反向传播计算梯度。然而,对于层数较多或参数量较大的模型,这些激活值会占用大量显存。
Gradient Checkpointing 的核心思想是在前向传播时选择性地保存部分激活值(称为检查点),而丢弃其他激活值。在反向传播时,如果需要这些被丢弃的激活值,则重新计算它们。通过这种方式,显存使用量可以从 O(L) 降低到 O(K),其中 L 是网络层数,K 是选择的检查点层数。
工作原理
-
选择检查点:在前向传播时,选择某些层作为检查点,保存这些层的激活值。
-
丢弃激活值:对于未被选为检查点的层,丢弃其激活值。
-
反向传播时重新计算:在反向传播时,如果需要被丢弃的激活值,则通过重新计算它们来获取,从而计算梯度。
a1和a3被丢弃,反向传播时,如果需要被丢弃的激活值,则需要重新计算
a1 = x * w1,
a3 = a2 * w3
优点与缺点
优点:
-
显著减少显存占用,使训练更大规模的模型成为可能。
-
在显存受限的环境中,可以提高训练效率。
-
允许使用更大的批量大小,从而加速训练。
缺点:
-
增加了计算开销,因为需要在反向传播时重新计算激活值。
-
实现复杂度增加,需要修改代码来管理检查点。
-
可能导致训练时间延长。
实现方法
在 PyTorch 中,可以通过 torch.utils.checkpoint
模块实现 Gradient Checkpointing。例如:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.layer1 = nn.Linear(256, 256)
self.layer2 = nn.Linear(256, 256)
self.layer3 = nn.Linear(256, 10)
def forward(self, x):
x = checkpoint.checkpoint(self.layer1, x) # 应用梯度检查点
x = checkpoint.checkpoint(self.layer2, x)
x = self.layer3(x) # 最后一层不需要检查点
return x
在 DeepSpeed 中,可以通过配置文件启用 Gradient Checkpointing:
{
"train_batch_size": 16,
"gradient_accumulation_steps": 4,
"zero_optimization": {
"stage": 2,
"contiguous_gradients": true
},
"gradient_checkpointing": true
}
应用场景
Gradient Checkpointing 广泛应用于以下场景:
-
训练大规模深度学习模型,如 7B 或 10B 参数的模型。
-
在 GPU 显存有限的环境中优化训练。
-
提高训练效率,同时减少硬件成本。
通过合理使用 Gradient Checkpointing,可以在有限的硬件资源下训练更大规模的模型,同时平衡显存和计算开销。