当前位置: 首页 > article >正文

AMP 混合精度训练中的动态缩放机制: grad_scaler.py函数解析( torch._amp_update_scale_)

AMP 混合精度训练中的动态缩放机制

在深度学习中,混合精度训练(AMP, Automatic Mixed Precision)是一种常用的技术,它利用半精度浮点(FP16)计算来加速训练,同时使用单精度浮点(FP32)来保持数值稳定性。为了在混合精度训练中避免数值溢出,PyTorch 提供了一种动态缩放机制来调整 “loss scale”(损失缩放值)。本文将详细解析动态缩放机制的实现原理,并通过代码展示其内部逻辑。


动态缩放机制简介

动态缩放机制的核心思想是通过一个可动态调整的缩放因子(scale factor)放大 FP16 的梯度,从而降低舍入误差对训练的影响。当检测到数值不稳定(例如 NaN 或无穷大)时,缩放因子会被降低;当连续多步未检测到数值问题时,缩放因子会被提高。其调整策略基于以下两个参数:

  • growth_factor: 连续成功步骤后用于增加缩放因子的乘数(通常大于 1,如 2.0)。
  • backoff_factor: 检测到数值溢出时用于减少缩放因子的乘数(通常小于 1,如 0.5)。

此外,动态缩放还使用 growth_interval 参数控制连续成功步骤的计数阈值。当达到这个阈值时,缩放因子才会增加。


AMP 缩放更新核心代码解析

PyTorch 实现了一个用于更新缩放因子的 CUDA 核函数以及相关的 Python 包装函数。以下是核心代码解析:

CUDA 核函数实现

// amp_update_scale_cuda_kernel 核函数实现
__global__ void amp_update_scale_cuda_kernel(float* current_scale,
                                             int* growth_tracker,
                                             const float* found_inf,
                                             double growth_factor,
                                             double backoff_factor,
                                             int growth_interval) {
  if (*found_inf) {
    // 如果发现梯度中存在 NaN 或 Inf,缩放因子乘以 backoff_factor,并重置 growth_tracker。
    *current_scale = (*current_scale) * backoff_factor;
    *growth_tracker = 0;
  } else {
    // 未发现数值问题,增加 growth_tracker 的计数。
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      // 当 growth_tracker 达到 growth_interval,尝试增长缩放因子。
      auto new_scale = static_cast<float>((*current_scale) * growth_factor);
      if (isfinite_ensure_cuda_math(new_scale)) {
        *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }
}
核函数逻辑
  1. 发现数值溢出(found_inf > 0):

    • 缩放因子 current_scale 乘以 backoff_factor
    • 重置成功计数器 growth_tracker 为 0。
  2. 未发现数值溢出:

    • 增加成功计数器 growth_tracker
    • 如果 growth_tracker 达到 growth_interval,则将缩放因子乘以 growth_factor
    • 保证缩放因子不会超过 FP32 的数值上限。

C++ 包装函数实现

在 PyTorch 中,这一 CUDA 核函数通过 C++ 包装函数 _amp_update_scale_cuda_ 被调用。以下是实现代码:

Tensor& _amp_update_scale_cuda_(Tensor& current_scale,
                                Tensor& growth_tracker,
                                const Tensor& found_inf,
                                double growth_factor,
                                double backoff_factor,
                                int64_t growth_interval) {
  TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
  TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  
  // 核函数调用
  amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
    current_scale.mutable_data_ptr<float>(),
    growth_tracker.mutable_data_ptr<int>(),
    found_inf.const_data_ptr<float>(),
    growth_factor,
    backoff_factor,
    growth_interval);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return current_scale;
}

Python 调用入口

AMP 的 GradScaler 类通过 _amp_update_scale_ 函数更新缩放因子,以下是相关代码:
代码来源:anaconda3/envs/xxxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py

具体调用过程可以参考笔者的另一篇博文:PyTorch到C++再到 CUDA 的调用链(C++ ATen 层) :以torch._amp_update_scale_调用为例

def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
    """更新缩放因子"""
    if not self._enabled:
        return

    _scale, _growth_tracker = self._check_scale_growth_tracker("update")

    if new_scale is not None:
        # 设置用户定义的新缩放因子。
        self._scale.fill_(new_scale)
    else:
        # 收集所有优化器中的 found_inf 数据。
        found_infs = [
            found_inf.to(device=_scale.device, non_blocking=True)
            for state in self._per_optimizer_states.values()
            for found_inf in state["found_inf_per_device"].values()
        ]

        found_inf_combined = found_infs[0]
        if len(found_infs) > 1:
            for i in range(1, len(found_infs)):
                found_inf_combined += found_infs[i]

        # 更新缩放因子。
        torch._amp_update_scale_(
            _scale,
            _growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

总结

PyTorch 的动态缩放机制通过 CUDA 核函数和 Python 包装函数协作完成。其核心逻辑是:

  1. 检测数值不稳定(如 NaN 或 Inf),通过缩小缩放因子提高数值稳定性。
  2. 当连续多次未出现数值不稳定时,逐步增大缩放因子以充分利用 FP16 的动态范围。
  3. 所有更新操作都在 GPU 上异步完成,最大限度地减少同步开销。

通过动态调整缩放因子,AMP 有效地加速了深度学习模型的训练,同时避免了梯度溢出等数值问题。


推荐阅读

  • PyTorch 官方文档
  • 混合精度训练介绍

后记

2025年1月2日15点38分于上海,在GPT4o大模型辅助下完成。


http://www.kler.cn/a/466362.html

相关文章:

  • springboot550乐乐农产品销售系统(论文+源码)_kaic
  • QT----------QT Data Visualzation
  • 基于Arduino的FPV头部追踪相机系统
  • C#—SynchronizationContext类详解 (同步上下文)
  • 机器人对物体重定向操作的发展简述
  • 如何使用SparkSql
  • Android 网络判断
  • Couchbase 的 OLAP 能力现状以及提升 OLAP 能力的方法
  • Android:动态去掉RecyclerView动画导致时长累加问题解决
  • 【蓝桥杯比赛-C++组-经典题目汇总】
  • cka考试-03-k8s版本升级
  • SpringBootWeb案例-2
  • 图形 3.5 Early-z和Z-prepass
  • Mysql监视器搭建
  • FPGA、STM32、ESP32、RP2040等5大板卡,结合AI,更突出模拟+数字+控制+算法
  • 仓储机器人底盘的研究
  • 在Microsoft Windows上安装MySQL
  • 2025年第五届控制理论与应用国际会议 | Ei Scopus双检索
  • 「Mac畅玩鸿蒙与硬件53」UI互动应用篇30 - 打卡提醒小应用
  • Chapter2 文本规范化
  • #C02L02P01. C02.L02.一维数组最值问题.知识点1.求最大值
  • Elasticsearch:利用 AutoOps 检测长时间运行的搜索查询
  • 【2025最新计算机毕业设计】基于SpringBoot+Vue智慧养老医护系统(高质量源码,提供文档,免费部署到本地)【提供源码+答辩PPT+文档+项目部署】
  • unity学习2:关于最近github的2FA(two-factor authentication)新认证
  • 深入理解正则表达式及基本使用教程
  • 图像转换 VM与其他格式互转