当前位置: 首页 > article >正文

PyTorch DDP快速上手附代码

如标题所示,博主有任务 想要快速上手DDP以测试RDMA介入下的延迟带宽.

环境准备:

验证PyTorch:

python -c "import torch; print(torch.__version__, torch.cuda.device_count())"

核心概念:

进程组:每个GPU对应一个独立进程
​AllReduce:跨卡聚合数据(如梯度)
​torchrun:官方推荐启动工具
在这里插入图片描述

测试两卡all-reduce

看看博主的卡:
在这里插入图片描述

import torch
import torch.distributed as dist

dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()

print(f'I am rank: {rank}  of world {world_size}')

input_tensor = torch.rand([1024, 1024, 10], dtype=torch.float).to('cuda:%d' % rank)
input_tensor.fill_(1.0)

dist.all_reduce(input_tensor)

print(input_tensor[0][0])

结果:
在这里插入图片描述

还想进一步测试不同buffer size下的延迟和带宽情况:

import os
import torch
import torch.distributed as dist
import time

os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'

dist.init_process_group(backend='nccl')
rank = dist.get_rank()
torch.cuda.set_device(rank)
world_size = dist.get_world_size()

buffer_sizes = [1, 1024, 2048, 4096, 8192, 
                256 * 1024, 512 * 1024, 
                1024 * 1024, 4 * 1024 * 1024, 
                256 * 1024 * 1024]

def benchmark(size):
    input_tensor = torch.ones(size, dtype=torch.float32).to(f'cuda:{rank}')
    
    torch.cuda.synchronize()
    start = time.time()
    
    dist.all_reduce(input_tensor)
    
    torch.cuda.synchronize()
    elapsed = time.time() - start
    
    if rank == 0:
        data_size = size * 4  # float32占4字节
        bandwidth = (data_size / elapsed) / (1024**2)  # MB/s
        print(f"Size: {size/1024:.1f} KB\tLatency: {elapsed*1000:.3f} ms\tBandwidth: {bandwidth:.2f} MB/s")

if rank == 0:
    print("Buffer Size (KB)\tLatency (ms)\tBandwidth (MB/s)")
    print("------------------------------------------------")

for size in buffer_sizes:
    benchmark(size)

dist.destroy_process_group()

结果:
在这里插入图片描述

至此 单机两卡h20测试成功 带宽延迟符合预期


http://www.kler.cn/a/610950.html

相关文章:

  • 【大模型开发】将vocab解码
  • webpackVSVite热更新本质区别
  • 测试 SpatialLM 空间语义识别
  • AI之山,鸿蒙之水,画一幅未来之家
  • Python正则表达式(一)
  • 鸿蒙系统起飞!Flutter 完全适配指南CSDN2021-01-23 02:47
  • 深度解析:4G路由器CPE性能测试的五大关键指标
  • 影刀魔法指令3.0:开启自动化新篇章
  • 编写简单的小程序
  • SpringCloud入门、搭建、调试、源代码
  • Flink 常用及优化参数
  • Serverless架构的应用场景
  • 文件上传的小点总结
  • 自然语言处理(11:RNN(RNN的前置知识和引入)
  • 学习爬虫的第二天——分页爬取并存入表中
  • NO.58十六届蓝桥杯备战|基础算法-枚举|普通枚举|二进制枚举|铺地毯|回文日期|扫雷|子集|费解的开关|Even Parity(C++)
  • Spring MVC 配置详解与入门案例
  • 3ds Max 2026 新功能全面解析
  • husky的简介以及如果想要放飞自我的解决方案
  • Linux centos 7 vsftp本地部署脚本