当前位置: 首页 > article >正文

为什么混合精度训练中优化器参数仍然以 FP32 存储?LLaMA 2 7B 模型在混合精度下的显存需求

在混合精度训练(Mixed Precision Training)中,模型的权重以 BF16 格式存储用于前向计算和反向传播,以减少显存占用,但会保留 FP32 格式的副本用于权重更新(这是数值稳定性和更新精度的需求),优化器的参数(如 AdamW 的一阶动量和二阶动量)通常也以 FP32 存储。以下是详细分析 LLaMA 2 7B 模型在混合精度下的显存需求:


1. LLaMA 2 7B 模型的参数规模

LLaMA 2 7B 模型包含 70 亿参数(7B)
假设每个参数存储格式为 BF16 或 FP32,不同精度的存储需求为:

  • BF16:每个参数占 2 字节。
  • FP32:每个参数占 4 字节。

2. 显存需求的详细计算

在混合精度训练中,显存需求分为以下几个部分:

1)模型权重

  • 前向传播和反向传播
    模型权重以 BF16 存储。
    模型权重(BF16) = 7 × 1 0 9 × 2   bytes = 14   GB \text{模型权重(BF16)} = 7 \times 10^9 \times 2 \, \text{bytes} = 14 \, \text{GB} 模型权重(BF16)=7×109×2bytes=14GB

  • 权重更新副本
    模型权重需要一个 FP32 副本,用于优化器的权重更新。
    模型权重(FP32) = 7 × 1 0 9 × 4   bytes = 28   GB \text{模型权重(FP32)} = 7 \times 10^9 \times 4 \, \text{bytes} = 28 \, \text{GB} 模型权重(FP32)=7×109×4bytes=28GB

2)梯度

反向传播时计算的梯度与权重大小相同。梯度在计算过程中通常以 BF16 表示:
梯度(BF16) = 7 × 1 0 9 × 2   bytes = 14   GB \text{梯度(BF16)} = 7 \times 10^9 \times 2 \, \text{bytes} = 14 \, \text{GB} 梯度(BF16)=7×109×2bytes=14GB

3)优化器动量参数

AdamW 优化器需要额外存储两组与权重相同大小的动量参数:

  • 一阶动量(( m \mathbf{m} m)):以 FP32 存储。
    一阶动量(FP32) = 7 × 1 0 9 × 4   bytes = 28   GB \text{一阶动量(FP32)} = 7 \times 10^9 \times 4 \, \text{bytes} = 28 \, \text{GB} 一阶动量(FP32)=7×109×4bytes=28GB

  • 二阶动量(( v \mathbf{v} v)):以 FP32 存储。
    二阶动量(FP32) = 7 × 1 0 9 × 4   bytes = 28   GB \text{二阶动量(FP32)} = 7 \times 10^9 \times 4 \, \text{bytes} = 28 \, \text{GB} 二阶动量(FP32)=7×109×4bytes=28GB

4)激活值

在反向传播中,需要保留前向传播的激活值以计算梯度。激活值的显存需求取决于模型的结构(如 Transformer 层数、隐藏层大小等)。假设激活值约占权重显存的 30%(经验值),并以 BF16 存储:
激活值(BF16) ≈ 0.3 × 14   GB = 4.2   GB \text{激活值(BF16)} \approx 0.3 \times 14 \, \text{GB} = 4.2 \, \text{GB} 激活值(BF16)0.3×14GB=4.2GB


显存需求总结

组件存储格式显存需求
模型权重(BF16)BF1614 GB
权重更新副本(FP32)FP3228 GB
梯度(BF16)BF1614 GB
一阶动量(FP32)FP3228 GB
二阶动量(FP32)FP3228 GB
激活值(BF16)BF164.2 GB
总计116.2 GB

3. 深入分析

  1. 为什么权重需要 FP32 副本?
    尽管 BF16 格式可以减少显存需求,但它的精度(7 位有效数字)不足以支持高效的权重更新。FP32 格式(23 位有效数字)能避免误差累积,保证优化器的数值稳定性和训练效果。

  2. 动量参数显存需求为何如此高?
    AdamW 的一阶动量 ( m \mathbf{m} m) 和二阶动量 ( v \mathbf{v} v) 是与模型权重同规模的参数,各自以 FP32 存储,显存占用等于权重的 4 倍。

  3. 如何优化显存占用?

    • 使用 DeepSpeed ZeRO 技术,将优化器参数和动量参数分片,降低单张 GPU 的显存需求。
    • 使用更高效的优化器(如 LION),减少动量参数存储。

4. 代码示例:DeepSpeed 混合精度训练

以下是基于 DeepSpeed 和 LLaMA 2 7B 的训练代码:

import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载模型和分词器
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# DeepSpeed 配置
ds_config = {
    "fp16": {
        "enabled": True  # 启用混合精度训练 (BF16)
    },
    "optimizer": {
        "type": "AdamW",  # 优化器
        "params": {
            "lr": 1e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.01
        }
    },
    "zero_optimization": {
        "stage": 2,  # ZeRO Stage 2,优化器参数分片
        "contiguous_gradients": True,
        "overlap_comm": True
    },
    "gradient_accumulation_steps": 4,
    "train_micro_batch_size_per_gpu": 1,
    "gradient_clipping": 1.0
}

# 启动 DeepSpeed
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config_params=ds_config
)

# 示例数据
inputs = tokenizer("Hello, DeepSpeed!", return_tensors="pt")
outputs = model_engine(**inputs, labels=inputs["input_ids"])
loss = outputs.loss

# 反向传播和优化
model_engine.backward(loss)
model_engine.step()

5. 总结

  • 混合精度训练通过 BF16 格式大幅减少显存需求,但关键的优化器参数(权重更新副本、一阶动量、二阶动量)仍然以 FP32 存储,保证数值稳定性和训练精度。
  • 对于 LLaMA 2 7B 模型,混合精度训练的显存需求约为 116.2 GB,其中优化器参数占了大头。
  • 使用 DeepSpeed 等分布式技术可以显著优化显存分布,为大规模模型训练提供更高效的解决方案。

后记

2024年12月1日14点33分于上海,在GPT4o大模型辅助下完成。


http://www.kler.cn/a/420374.html

相关文章:

  • 主动安全和驾驶辅助模块(ASDM):未来驾驶的核心科技 随着汽车技术的不断进步,驾驶体验和安全性正经历着前所未有的变革。
  • 【从零开始的LeetCode-算法】35. 搜索插入位置
  • 3D Bounce Ball Game 有什么技巧吗?
  • 32 从前序与中序遍历序列构造二叉树
  • 【论文笔记】A Token-level Contrastive Framework for Sign Language Translation
  • iQOO Neo10系列携三大蓝科技亮相,性能与续航全面升级
  • react 父子组件通信
  • 【Qt】QDateTimeEdit控件实现清空(不保留默认时间/最小时间)
  • Pytorch使用手册- TorchVision目标检测微调Tutorial的使用指南(专题十二)
  • bash命令缓存导致命令执行失败的问题
  • 插入数据如何确保redis与数据库同步 详解
  • 单链表---链表分割
  • 基于米尔全志T527开发板的FacenetPytorch人脸识别方案
  • 【C++】深入解析 using namespace std 语句
  • npm error code ETIMEDOUT 简单排查
  • 双向长短期记忆(Bi-LSTM)神经网络介绍
  • Linux - 前端程序员常用的 Linux 命令
  • LearnOpenGL学习(光照 -- 投光物,多光源)
  • 在云上怎么样让环境更加安全?
  • SQLAlchemy
  • Spring,SpringMVC,SpringBoot,SpringCloud有什么区别和联系?
  • 汽车操作系统详解
  • dhcpd服务器的配置与管理(超详细!!!)
  • 贝叶斯统计的核心思想与基础知识:中英双语
  • 含k个3的数
  • 产品转后端?