当前位置: 首页 > article >正文

pytorch torch.utils.checkpoint模块介绍

torch.utils.checkpoint 是 PyTorch 中用于实现 梯度检查点(gradient checkpointing)的模块。它通过在反向传播中 重新计算 前向传播的某些部分,以显著减少激活值的显存占用。

梯度检查点的核心原理

  • 在前向传播中,不是保存每一层的激活值,而是保存输入和部分中间结果。
  • 在反向传播时,重新计算需要的前向激活值。
  • 优势
    • 显存占用减少:适合超大模型的训练。
  • 劣势
    • 计算量增加:反向传播时需要额外的前向计算。

核心API

1. torch.utils.checkpoint.checkpoint

torch.utils.checkpoint.checkpoint 是 PyTorch 提供的一种 内存优化工具,通过 计算图重新计算 的方式来节省显存。它特别适用于深度学习中 大模型或长序列 的训练场景,能够在不降低模型性能的情况下减少显存使用。

工作原理
  1. 标准前向传播

    • 默认情况下,PyTorch 在前向传播过程中,会存储中间激活值以供反向传播使用。
    • 如果模型层数很多或者中间激活值占用大量显存,会导致显存不足。
  2. 检查点机制

    • 在前向传播时,torch.utils.checkpoint.checkpoint 会丢弃某些中间激活值(未存储在显存中)。
    • 在反向传播时,丢弃的中间激活值会通过 重新计算前向传播 来生成。
    • 通过这种方式,显存的占用降低,但会增加一些前向计算的开销。
函数签名
torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=True)

参数
  • function:
    • 前向传播的函数,必须是纯函数(只依赖输入,不依赖外部状态)。
  • *args:
    • 传递给 function 的参数。
  • use_reentrant (默认值为 True):
    • 如果设置为 True,使用旧的递归检查点实现;如果为 False,启用非递归实现,推荐设置为 False 来避免潜在问题。
优缺点

优点

节省显存

  • 丢弃中间激活值后,显存占用显著降低,适合训练大模型。

适配性强

  • 不需要修改模型结构,只需在关键的计算图中插入检查点即可。
返回值

output:

  • 前向传播的结果。
使用场景

大模型的训练

  • 模型层数较多,激活值占用大量显存时&#

http://www.kler.cn/a/459435.html

相关文章:

  • 51单片机——LED模块
  • Appium 2.0:移动自动化测试的革新之旅
  • CSS系列(45)-- Scope详解
  • 金融租赁系统的创新与发展推动行业效率提升
  • linux-26 文件管理(四)install
  • API多并发识别、C#文字识别
  • Golang协程为什么⽐线程轻量
  • o1到o3的发展历程
  • lombok-macros
  • 【Go】context标准库
  • 步进电机驱动算法——S形加减速算法原理
  • 面试经典150题——数组/字符串(二)
  • 开发运维基本功:无需复杂配置快速实现本地Nginx的公网远程访问
  • 【文献精读笔记】Explainability for Large Language Models: A Survey (大语言模型的可解释性综述)(三)
  • 打破传统,手势验证码识别
  • DP协议:PHY层
  • JS - Array Api
  • 从零开始学架构——互联网架构的演进
  • C++ hashtable
  • No.3十六届蓝桥杯备战|数据类型长度|sizeof|typedef|练习(C++)
  • 线程-4-线程库与线程封装
  • 完整的 FFmpeg 命令使用教程
  • 【PyCharm】如何把本地整个项目同步到服务器?
  • 在web.xml中配置Servlet映射
  • 【Next.js】002-路由篇|App Router
  • 冒泡排序c语言