当前位置: 首页 > article >正文

混合精度-基于torch内部

定义

混合精度训练是一种在深度学习模型训练过程中,同时使用不同精度数据类型(主要是单精度 FP32 和半精度 FP16)来进行计算和存储的技术。以下是具体介绍:
数据类型

  • 单精度(FP32):具有较高的数值精度和动态范围,能够准确表示各种数值,但占用内存空间较大,计算速度相对较慢。在深度学习中,传统的训练方式通常使用 FP32 来存储模型参数、梯度和中间计算结果。
  • 半精度(FP16):占用内存空间仅为 FP32 的一半,数据传输和计算速度更快,但数值精度和动态范围相对较低。它适用于一些对精度要求不是特别高,但对计算速度和内存占用较为敏感的场景
    实现方式
    在混合精度训练中,通常会将模型的权重参数以 FP32 格式存储,以保证模型的准确性和稳定性。而在实际的计算过程中,将部分计算(如矩阵乘法、卷积运算等)使用 FP16 进行,以充分利用硬件对 FP16 的优化,提高计算速度和减少内存占用。同时,为了避免因 FP16 精度不足而导致的梯度消失或模型精度下降等问题,还会采用一些特殊的技术,如动态损失缩放(Dynamic Loss Scaling)。通过动态调整损失函数的缩放因子,使得在 FP16 计算中能够更好地保留梯度信息,确保模型的正常训练。

实战

-必要的库:from torch.cuda.amp import GradScaler, autocast

  • GradScaler 是 PyTorch 提供的一个工具,用于在混合精度训练中自动调整梯度的缩放比例。它可以帮助防止梯度下溢(underflow)和溢出(overflow),从而提高训练的稳定性和效率
  • autocast :autocast 上下文管理器会自动将计算图中的某些操作转换为半精度(FP16),而将某些操作保留为单精度(FP32)。这样可以在不损失精度的情况下,显著减少显存占用,并提高计算速度

步骤:
1.初始化梯度缩放器
scaler = GradScaler()
2.在前向传播中使用 autocast

with autocast():
    output = model(images, texts)
    total_loss = criterion(output, label_id)

在 autocast 上下文中,模型的前向传播和损失计算会自动使用混合精度。
3.在反向传播的时候用scaler.scale 来缩放梯度
scaler.scale(total_loss).backward()
反向传播时,scaler.scale 会将梯度放大,以防止梯度下溢。这有助于在使用混合精度训练时保持梯度的数值稳定性
4.更新参数:在 train_one_epoch 函数中,使用 scaler.step 和 scaler.update 来更新参数:

if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

作用:
scaler.step(optimizer):调用优化器的 step 方法来更新参数。
scaler.update():更新梯度缩放比例,以适应当前的梯度大小。
optimizer.zero_grad():清零梯度,为下一个批次的训练做准备。


http://www.kler.cn/a/595120.html

相关文章:

  • 尝试在软考65天前开始成为软件设计师-计算机网络
  • 【vLLM 学习】使用 XPU 安装
  • (C语言)sizeof与strlen的区别,以及有关习题练习
  • YOLO可视化界面,目标检测前端QT页面。
  • 深度解析超线程技术:一核多用的奥秘
  • 深入理解MySQL中的MVCC机制
  • 使用Qdrant等其他向量数据库时需要将将numpy 数组转换为列表 确保数据能被正确处理和序列化,避免类型不兼容的问题。
  • 正则表达式:文本处理的瑞士军刀
  • 动态生成 CSS 工具类:CSS函数实现 `pad20-top`、`pad40-bottom` 等灵活样式
  • tailwindcss如何修改elementplus的内部样式
  • 深度学习与传统算法在人脸识别领域的演进:从Eigenfaces到ArcFace
  • JsonAutoDetect.Visibility
  • Camera2 API拍照失败问题实录:从错误码到格式转换的排坑之旅
  • Langchain 提示词(Prompt)
  • 解锁C++编程能力:基础语法解析
  • DeepSeek面试——模型架构和主要创新点
  • 如何在Linux环境下编译文件
  • 【群晖NAS】git常见问题解决方法
  • NIO入门
  • VSCode中搜索插件显示“提取扩展时出错。Failed to fetch”问题解决!