当前位置: 首页 > article >正文

使用DeepSpeed进行多机多卡训练模型

在 DeepSpeed 中进行多机多卡训练时,需要进行以下几方面的配置和设置:

1. 设置主节点和节点之间的通信

为了使各节点(机器)能够相互通信,你需要设置主节点的地址和端口。DeepSpeed 使用 torch.distributed 进行进程间通信,因此要确保以下环境变量配置正确:

  • 主节点地址和端口:使用 MASTER_ADDRMASTER_PORT 环境变量指定主节点的 IP 和端口号。假设主节点 IP 为 192.168.1.1,可以这样设置:
export MASTER_ADDR=192.168.1.1
export MASTER_PORT=29500
  • 总节点数和 GPU 数量:通过 WORLD_SIZE 变量定义所有节点的 GPU 总数量。例如,有 2 台机器,每台有 4 张 GPU,那么 WORLD_SIZE 应设为 8:
export WORLD_SIZE=8

2. 启动主节点和从节点

在主节点和从节点上分别启动训练进程。主节点使用 deepspeed 启动脚本,而从节点使用 deepspeed --hostfiledeepspeed.launcher 启动进程。

例如,在主节点上运行:

deepspeed --num_gpus=4 train.py

在从节点上运行以下命令:

deepspeed --num_gpus=4 --master_addr=192.168.1.1 --master_port=29500 --node_rank=1 train.py

其中,node_rank 区分主节点和从节点;master_addrmaster_port 需要与主节点一致。

3. 代码调整

代码方面的调整较少,DeepSpeed 会自动处理多机分布式训练。确保在 deepspeed.initialize() 中正确配置参数:

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config=ds_config
)

ds_config 中可根据多机多卡环境调整批量大小、梯度累积等超参数,DeepSpeed 会自动管理训练过程。

4. 使用 hostfile 文件(可选)

多机训练时可以用一个 hostfile 文件列出所有节点的主机名和 GPU 数量,例如:

192.168.1.1 slots=4
192.168.1.2 slots=4

然后指定使用该文件:

deepspeed --hostfile=hostfile train.py

5. DeepSpeed 配置文件设置

ds_config 配置文件中,不需特别调整机器数量相关的配置,主要是根据训练环境进行优化。例如:

{
  "train_batch_size": 32,
  "gradient_accumulation_steps": 1,
  "fp16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2,
    "contiguous_gradients": true,
    "reduce_scatter": true,
    "overlap_comm": true
  }
}

6. 启动命令汇总

假设有 2 台机器,每台 4 张 GPU,启动步骤如下:

  • 主节点
deepspeed --num_gpus=4 train.py
  • 从节点
deepspeed --num_gpus=4 --master_addr=192.168.1.1 --master_port=29500 --node_rank=1 train.py

完成这些步骤即可启动多机多卡训练。


http://www.kler.cn/a/370046.html

相关文章:

  • 63,【3】buuctf web Upload-Labs-Linux 1
  • docker 部署confluence
  • 简识JVM私有内存区域栈、数据结构
  • Golang:使用DuckDB查询Parquet文件数据
  • Tensor 基本操作1 unsqueeze, squeeze, softmax | PyTorch 深度学习实战
  • Java面试专题——面向对象
  • Bug|空心病,不知道自己要干什么
  • 大语言模型数据流程源码解读(基于llama3模型)
  • 自己搭建[文本转语音]服务器
  • 2024 Rust现代实用教程:1.2编译器与包管理工具以及开发环境搭建
  • C++基于opencv的视频质量检测--图像清晰度检测
  • electron 监听窗口高端变化
  • JS | CommonJS、AMD、CMD、ES6-Module、UMD五种JS模块化规范
  • 海外发稿:探索海外外媒宣发分发渠道-大舍传媒
  • 如何使用VBA识别Excel中的“单元格中的图片”(1/2)
  • 着色器的认识
  • JMeter之JMX文件解释
  • Windows驱动开发(三)—— 驱动和应用层通信的几种方式
  • Openpyxl--学习记录
  • 【文心智能体 | AI大师工坊】如何使用智能体插件,完成一款旅游类智能体的开发,来体验一下我的智能体『​​​​​​​背包客』
  • 如何将 Excel 数据转换为 SQL 脚本:基于 Java 的全面解析
  • 问:数据库SQL优化实践整理?
  • python 相关
  • Android--简易计算器实现
  • Redis中Lua脚本的使用场景
  • 深度学习领域如何正确地读取视频