当前位置: 首页 > article >正文

深度学习优化-Gradient Checkpointing

数学原理参考:

梯度检查点技术(Gradient Checkpointing)详细介绍:中英双语-CSDN博客

视频讲解参考:

用梯度检查点来节省显存 gradient checkpointing_哔哩哔哩_bilibili

Gradient Checkpointing(梯度检查点

Gradient Checkpointing 是一种用于优化深度学习模型训练的技术,旨在减少训练过程中显存的占用。在深度神经网络训练中,通常需要存储每一层的激活值以用于反向传播计算梯度。然而,对于层数较多或参数量较大的模型,这些激活值会占用大量显存。

Gradient Checkpointing 的核心思想是在前向传播时选择性地保存部分激活值(称为检查点),而丢弃其他激活值。在反向传播时,如果需要这些被丢弃的激活值,则重新计算它们。通过这种方式,显存使用量可以从 O(L) 降低到 O(K),其中 L 是网络层数,K 是选择的检查点层数。

工作原理

  1. 选择检查点:在前向传播时,选择某些层作为检查点,保存这些层的激活值。

  2. 丢弃激活值:对于未被选为检查点的层,丢弃其激活值。

  3. 反向传播时重新计算:在反向传播时,如果需要被丢弃的激活值,则通过重新计算它们来获取,从而计算梯度。

a1和a3被丢弃,反向传播时,如果需要被丢弃的激活值,则需要重新计算

a1 = x * w1,

a3 = a2 * w3

优点与缺点

优点

  • 显著减少显存占用,使训练更大规模的模型成为可能。

  • 在显存受限的环境中,可以提高训练效率。

  • 允许使用更大的批量大小,从而加速训练。

缺点

  • 增加了计算开销,因为需要在反向传播时重新计算激活值。

  • 实现复杂度增加,需要修改代码来管理检查点。

  • 可能导致训练时间延长。

实现方法

在 PyTorch 中,可以通过 torch.utils.checkpoint 模块实现 Gradient Checkpointing。例如:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(256, 256)
        self.layer2 = nn.Linear(256, 256)
        self.layer3 = nn.Linear(256, 10)

    def forward(self, x):
        x = checkpoint.checkpoint(self.layer1, x)  # 应用梯度检查点
        x = checkpoint.checkpoint(self.layer2, x)
        x = self.layer3(x)  # 最后一层不需要检查点
        return x

在 DeepSpeed 中,可以通过配置文件启用 Gradient Checkpointing:

{
    "train_batch_size": 16,
    "gradient_accumulation_steps": 4,
    "zero_optimization": {
        "stage": 2,
        "contiguous_gradients": true
    },
    "gradient_checkpointing": true
}

应用场景

Gradient Checkpointing 广泛应用于以下场景:

  • 训练大规模深度学习模型,如 7B 或 10B 参数的模型。

  • 在 GPU 显存有限的环境中优化训练。

  • 提高训练效率,同时减少硬件成本。

通过合理使用 Gradient Checkpointing,可以在有限的硬件资源下训练更大规模的模型,同时平衡显存和计算开销。


http://www.kler.cn/a/586197.html

相关文章:

  • ORACLE 19.8版本遭遇ORA-600 [kqrHashTableRemove: X lock].宕机的问题分析
  • CSS:不设定高度的情况,如何让flex下的两个元素的高度一致
  • 历次科技泡沫对人工智能发展的启示与规避措施
  • Python----计算机视觉处理(opencv:图片灰度化)
  • Unity屏幕适配——立项时设置
  • Python使用FastAPI结合Word2vec来向量化200维的语言向量数值
  • 缓存使用的具体场景有哪些?缓存的一致性问题如何解决?缓存使用常见问题有哪些?
  • 蓝思科技冲刺港股上市,双重上市的意欲何为?
  • TI的Doppler-Azimuth架构(TI文档)
  • 山东省新一代信息技术创新应用大赛-计算机网络管理赛项(样题)
  • LSTM方法实践——基于LSTM的汽车销量时序建模与预测分析
  • 华为OD机试-测试用例执行计划(Java 2024 D卷 100分)
  • MIFNet (论文阅读笔记)
  • 【实战篇】MySQL 时间字段的处理
  • java学习笔记1
  • 【eNSP实战】三层交换机使用ACL实现网络安全
  • C++中的单例模式及具体应用示例
  • 阿里云dataworks入门操作
  • Go语言单元测试和基准测试
  • 一个网络安全产品设计文档