分布式训练中的 rank 和 local_rank
在分布式训练中,rank
和 local_rank
是两个不同的概念,它们分别表示不同层次的进程标识符。理解这两者的区别和关系对于正确设置分布式训练环境至关重要。
rank
和 local_rank
的定义
-
rank
:- 全局唯一标识符(Global Rank),表示当前进程在整个分布式训练集群中的全局编号。
- 范围是从
0
到world_size - 1
,其中world_size
是参与分布式训练的所有进程的总数。 - 每个进程都有一个唯一的
rank
,用于区分不同的进程。
-
local_rank
:- 本地设备标识符(Local Rank),表示当前进程在单个节点(即一台机器)上的设备编号。
- 范围是从
0
到num_devices_per_node - 1
,其中num_devices_per_node
是该节点上的设备数量(如 GPU 或 NPU 数量)。 - 用于指定当前进程使用的具体设备(如 GPU 或 NPU)。
关系与区别
-
关系:
rank
是全局的,标识了每个进程在整个集群中的位置。local_rank
是局部的,标识了每个进程在其所在节点上的设备位置。- 可以通过
rank
和world_size
计算出local_rank
,公式如下:local_rank = rank % num_devices_per_node
- 其中
num_devices_per_node
是每个节点上的设备数量(例如,在一个有4个GPU的节点上,num_devices_per_node
就是4)。
-
区别:
rank
是整个集群范围内的唯一标识符,而local_rank
是节点范围内的标识符。rank
通常用于控制全局操作(如同步所有进程),而local_rank
用于选择特定的设备进行计算。
示例说明
假设你有一个由两台机器组成的集群,每台机器上有两个NPU:
- 第一台机器:NPU 0 和 NPU 1
- 第二台机器:NPU 0 和 NPU 1
在这种情况下:
world_size
是 4(因为总共有4个进程)num_devices_per_node
是 2(因为每台机器有两个NPU)
进程分配
机器 | NPU 设备 | rank | local_rank |
---|---|---|---|
1 | NPU 0 | 0 | 0 |
1 | NPU 1 | 1 | 1 |
2 | NPU 0 | 2 | 0 |
2 | NPU 1 | 3 | 1 |
代码解释
import os
import torch.distributed as dist
import torch_npu
# 初始化过程组
dist.init_process_group("nccl", init_method=f"tcp://{addr}:{port}",
rank=rank_id,
world_size=world_size)
# 获取 local_rank
local_rank = int(os.environ.get('LOCAL_RANK', 0))
# 设置 NPU 设备
torch_npu.npu.set_device(local_rank)
解释
-
dist.init_process_group
:- 使用
rank=rank_id
来指定当前进程的全局编号。 world_size
是总的进程数,包括所有节点上的所有进程。
- 使用
-
获取
local_rank
:- 通常通过环境变量
LOCAL_RANK
获取当前进程在本地节点上的设备编号。 - 如果环境变量不存在,默认值为
0
。
- 通常通过环境变量
-
设置设备:
- 使用
torch_npu.npu.set_device(local_rank)
来指定当前进程使用的 NPU 设备。
- 使用
示例代码
以下是一个完整的示例,展示了如何在分布式训练中使用 rank
和 local_rank
:
import os
import torch
import torch.distributed as dist
import torch_npu
def setup(rank, world_size, addr, port):
# 初始化进程组
dist.init_process_group(
backend="nccl", # 使用 nccl 后端,适用于 GPU/NPU
init_method=f"tcp://{addr}:{port}",
rank=rank,
world_size=world_size
)
# 获取 local_rank
local_rank = int(os.environ.get('LOCAL_RANK', 0))
# 设置 NPU 设备
torch_npu.npu.set_device(local_rank)
def cleanup():
dist.destroy_process_group()
def train(rank, world_size, addr, port):
setup(rank, world_size, addr, port)
# 设置设备
device = torch.device(f"npu:{local_rank}")
# 创建模型并移动到相应的设备
model = SimpleModel().to(device)
ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
# 其他训练逻辑...
cleanup()
if __name__ == "__main__":
# 配置信息
world_size = 4 # 总共4个进程
addr = "192.168.1.1" # 主节点的IP地址
port = "12355" # 通信使用的端口
# 启动每个进程
import multiprocessing as mp
processes = []
for rank in range(world_size):
p = mp.Process(target=train, args=(rank, world_size, addr, port))
p.start()
processes.append(p)
for p in processes:
p.join()