ZeRO(Zero Redundancy Optimizer) 技术
1. ZeRO要解决什么问题?
训练超大模型(如GPT-3)时,内存不够用!传统数据并行(Data Parallelism)的痛点:
-
内存冗余:每个GPU都保存完整的模型、优化器状态、梯度,浪费显存。
-
通信成本高:梯度同步需要大量数据传输。
ZeRO的核心目标:消除内存冗余,同时保持计算效率。
2. ZeRO的核心思想
将模型训练所需的状态(参数、梯度、优化器状态)分割到不同GPU上,每个GPU只保留一部分,需要时再通过通信获取。
关键术语:
-
优化器状态(Optimizer States):如Adam中的动量(momentum)、方差(variance)。
-
梯度(Gradients):反向传播后的梯度。
-
模型参数(Parameters):模型的权重。
3. ZeRO的三个阶段
ZeRO-1:分割优化器状态
-
每个GPU只保存一部分优化器状态,其他部分通过通信获取。
-
内存节省:优化器状态减少到原来的 1/N(N为GPU数量)。
ZeRO-2:分割优化器状态 + 梯度
-
优化器状态和梯度都分割到不同GPU。
-
内存节省:梯度也减少到原来的 1/N。
ZeRO-3:分割优化器状态 + 梯度 + 模型参数
-
所有状态(参数、梯度、优化器状态)全部分割。
-
内存节省最大化:模型参数也减少到原来的 1/N。
-
代价:通信量增加,需权衡。
4. ZeRO如何节省内存?(以Adam优化器为例)
假设模型参数量为 M,使用Adam优化器,GPU数量为 N:
-
传统数据并行:每个GPU需要存储 M×(参数 + 梯度 + 动量 + 方差) = Ψ×16字节。
-
ZeRO-1:优化器状态分割 → 每个GPU存储 M×12 + M×4/N 字节。
-
ZeRO-2:优化器状态+梯度分割 → M×8 + M×8/N 字节。
-
ZeRO-3:全部分割 → M×16/N 字节。
5. 核心代码示例(基于PyTorch + DeepSpeed)
以下是一个使用ZeRO-2训练的简化代码示例:
环境准备
pip install deepspeed # 安装DeepSpeed库
Python代码
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import deepspeed
# 加载模型和数据集
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
dataset = [...] # 假设已准备好数据集
# DeepSpeed配置文件(zero_2.json)
ds_config = {
"train_batch_size": 16,
"gradient_accumulation_steps": 4,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 5e-5
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": 2, # 使用ZeRO-2
"allgather_partitions": True,
"allgather_bucket_size": 2e8,
"reduce_scatter": True,
"reduce_bucket_size": 2e8,
"overlap_comm": True # 重叠通信和计算
}
}
# 初始化DeepSpeed引擎
model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config_params=ds_config
)
# 训练循环
for batch in dataset:
inputs = tokenizer(batch, return_tensors="pt").to(model_engine.device)
outputs = model_engine(**inputs)
loss = outputs.loss
model_engine.backward(loss)
model_engine.step()
代码解释
-
DeepSpeed配置文件:通过
zero_optimization.stage
指定ZeRO阶段(这里是ZeRO-2)。 -
内存优化:梯度(Gradients)和优化器状态(Adam的动量、方差)被分割到不同GPU。
-
通信优化:
overlap_comm
允许通信和计算并行,减少训练时间。
6. ZeRO的优缺点
优点 | 缺点 |
---|---|
大幅降低显存占用,支持训练超大模型 | 增加通信开销(需高速GPU互联) |
兼容数据并行和模型并行 | 配置复杂(需调整通信参数) |
开源实现成熟(DeepSpeed库) | ZeRO-3可能增加代码复杂性 |
ZeRO的本质:用通信换内存,通过分割模型状态实现超大规模训练。
掌握ZeRO,你就能在面试中证明自己具备“训练千亿参数模型”的能力!