图神经网络学习笔记—纯 PyTorch 中的多 GPU 训练(专题十二)
对于许多大规模的真实数据集,可能需要在多个 GPU 上进行扩展训练。本教程将介绍如何通过 torch.nn.parallel.DistributedDataParallel
在 PyG 和 PyTorch 中设置多 GPU 训练管道,而无需任何其他第三方库(如 PyTorch Lightning)。请注意,此方法基于数据并行。这意味着每个 GPU 运行模型的相同副本;如果您希望跨设备扩展模型,可能需要研究 PyTorch FSDP。数据并行允许您通过跨 GPU 聚合梯度来增加模型的批量大小,然后在每个模型副本中共享相同的优化器步骤。普林斯顿大学的这篇 DDP+MNIST 教程 提供了一些关于该过程的精彩图示。
具体来说,本教程展示了如何在 Reddit 数据集上训练 GraphSAGE GNN 模型。为此,我们将使用 torch.nn.parallel.DistributedDataParallel
在所有可用 GPU 上进行扩展训练。我们将通过从 Python 代码中生成多个进程来实现这一点,这些进程都将执行相同的函数。在每个进程中,我们设置模型实例并通过 NeighborLoader
将数据输入模型。通过将模型包装在 torch.nn.parallel.DistributedDataParallel
中(如其官方教程所述)来同步