torch.distributed.all_reduce
是 PyTorch 中分布式通信
的一部分,通常用于分布式训练场景下的梯度汇总
。在分布式训练中,每个参与的进程都有自己的一部分数据
和模型
,并行计算
其梯度
或更新参数
。为了确保这些进程中的模型能够同步,需要将不同进程中的梯度汇总
,all_reduce
是实现这一过程的常用操作。
注:reduce在英文中也有归纳、简化的意思。
函数原型
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp