1F1B 非交错式调度模式与 GPipe 策略的内存节省优势
1F1B 非交错式调度模式在节省内存方面表现更好,主要体现在以下几个方面:
- 减少中间变量的缓存
在 GPipe 策略中,每个设备需要缓存多个 micro-batch 的中间变量(activations),因为每个 micro-batch 的前向计算结果需要在后向计算中使用。这导致显存占用显著增加。例如,如果有4个 micro-batch,每个设备需要缓存4份中间变量。
而在 1F1B 非交错式调度模式中,每个设备在完成一个 micro-batch 的前向计算后,立即进行后向计算,从而可以及时释放前向计算的中间变量。这样,每个设备只需要缓存当前正在处理的 micro-batch 的中间变量,显著减少了显存占用。 - 显存峰值的降低
研究表明,1F1B 方式相比于 F-then-B 方式(如 GPipe),峰值显存可以节省 37.5%。这意味着在相同的显存条件下,1F1B 可以训练更大的模型。 - 资源利用率的提升
1F1B 非交错式调度模式通过及时释放中间变量,使得每个设备可以更高效地利用显存,从而提高了整体资源利用率。在 GPipe 中,由于需要缓存多个 micro-batch 的中间变量,显存利用率较低,导致设备空闲时间增加。 - 减少流水线刷新(Pipeline Flush)
在 GPipe 中,将 mini-batch 切分成多个 micro-batch 后,会带来更频繁的流水线刷新,这降低了硬件效率,导致空闲时间的增加。而 1F1B 非交错式调度模式通过交错执行前向和后向计算,减少了流水线刷新的频率,从而提高了硬件效率。
具体示例
假设我们有4个设备(设备0、设备1、设备2、设备3),每个设备负责模型的一部分层。模型被分成4个部分,每个设备负责一个部分。具体来说:
设备0:负责层1-4
设备1:负责层5-8
设备2:负责层9-12
设备3:负责层13-16
在 GPipe 中,每个设备需要缓存4个 micro-batch 的中间变量,如下图所示:
复制
设备0: [F1, F2, F3, F4] -> [B1, B2, B3, B4]
设备1: [F1, F2, F3, F4] -> [B1, B2, B3, B4]
设备2: [F1, F2, F3, F4] -> [B1, B2, B3, B4]
设备3: [F1, F2, F3, F4] -> [B1, B2, B3, B4]
在 1F1B 非交错式调度模式中,每个设备在完成一个 micro-batch 的前向计算后,立即进行后向计算,如下图所示:
复制
设备0: [F1 -> B1, F2 -> B2, F3 -> B3, F4 -> B4]
设备1: [F1 -> B1, F2 -> B2, F3 -> B3, F4 -> B4]
设备2: [F1 -> B1, F2 -> B2, F3 -> B3, F4 -> B4]
设备3: [F1 -> B1, F2 -> B2, F3 -> B3, F4 -> B4]
在 1F1B 模式下,每个设备只需要缓存当前正在处理的 micro-batch 的中间变量,显著减少了显存占用。
总结
1F1B 非交错式调度模式通过及时释放中间变量,减少了显存占用,提高了资源利用率,从而在内存节省方面表现更好。虽然 1F1B 非交错式调度模式在完成一轮计算的时间上与 GPipe 相同,但它在内存管理和资源利用方面具有显著优势。