Pytorch分布式训练杂记1
-
dist.broadcast_object_list
- 功能:这个函数用于在分布式训练中,将某个进程上的对象列表广播给所有其他进程。它可以确保所有进程在训练开始前或训练过程中共享相同的数据,尤其是在多进程场景中每个进程都需要访问相同的对象时。
- 用法:
import torch.distributed as dist # 假设我们有一个对象列表 obj_list obj_list = [some_object] # 从 rank 0 广播对象列表给所有其他进程 dist.broadcast_object_list(obj_list, src=0)
obj_list
:要广播的对象列表。src
:源进程的 rank(一般是 0),该进程上的对象会广播给其他进程。
-
dist.barrier()
- 功能:这个函数是一个同步机制。它会让所有进程在调用此函数后,必须等待其他进程也到达这个屏障点才能继续执行。通常用于确保所有进程都完成了某些步骤(比如数据加载)后再继续训练,避免进程之间的不同步问题。
- 用法:
import torch.distributed as dist # 阻塞所有进程,直到每个进程都到达这个屏障 dist.barrier()
-
dist.get_rank()
- 功能:这个函数返回当前进程的 rank,即进程在分布式训练中的编号。每个进程在初始化时都会被分配一个唯一的 rank,rank 为 0 的进程通常被称为主进程(主节点),负责某些全局操作,比如保存模型。
- 用法:
import torch.distributed as dist # 获取当前进程的 rank rank = dist.get_rank() print(f"Current rank: {rank}")
-
ddp = int(os.environ.get("RANK", -1)) != -1
- 功能:这一行代码通过检查环境变量
RANK
来判断当前是否处于分布式环境中。RANK
环境变量在分布式训练中通常由框架自动设置,表示进程的编号。如果RANK
存在且不是-1
,则表明当前进程正在参与分布式训练。 - 解释:
import os ddp = int(os.environ.get("RANK", -1)) != -1 if ddp: print("We are in a distributed environment.") else: print("Not in a distributed environment.")
这段代码通过检查
RANK
是否存在(并且不是-1
)来判断是否处于分布式训练模式下,ddp
变量的值为True
表示是分布式环境,False
表示不是。 - 功能:这一行代码通过检查环境变量