在 Cloud TPU Pod 上训练 PyTorch 模型
准备工作
在 Cloud TPU Pod 上开始分布式训练之前,请验证您的模型可在单个 v2-8 或 v3-8 Cloud TPU 设备上正常训练。如果您的模型在单个设备上出现明显的性能问题,请参阅最佳做法和问题排查指南。
单个 TPU 设备成功训练后,请执行以下步骤,在 Cloud TPU Pod 上设置和训练:
-
配置
gcloud
命令。 -
【可选】 将虚拟机磁盘映像捕获到虚拟机映像中。
-
通过虚拟机映像创建实例模板。
-
通过实例模板创建实例组。
-
通过 SSH 连接到您的 Compute Engine 虚拟机
-
验证防火墙规则以允许虚拟机之间进行通信。
-
创建 Cloud TPU Pod。
-
在 Pod 上运行分布式训练。
-
清理。
配置 gcloud
命令
使用 gcloud
配置 Google Cloud 项目:
为项目 ID 创建一个变量。
export PROJECT_ID=project-id
将项目 ID 设置为 gcloud
中的默认项目
gcloud config set project ${PROJECT_ID}
当您第一次在新的 Cloud Shell 虚拟机中运行此命令时,系统会显示 Authorize Cloud Shell
页面。点击页面底部的 Authorize
,以允许 gcloud
使用您的凭据进行 API 调用。
使用 gcloud
配置默认区域:
gcloud config set compute/zone europe-west4-a
[可选] 捕获虚拟机磁盘映像
您可以使用您用于训练单个 TPU 的虚拟机中的磁盘映像(已安装数据集、软件包等)。在创建映像之前,请使用 gcloud
命令停止虚拟机:
gcloud compute instances stop vm-name
接下来,使用 gcloud
命令创建虚拟机映像:
gcloud compute images create image-name \
--source-disk instance-name \
--source-disk-zone europe-west4-a \
--family=torch-xla \
--storage-location europe-west4
通过虚拟机映像创建实例模板
创建默认实例模板。创建实例模板时,您可以使用您在上述步骤中创建的虚拟机映像,或者您可以使用 Google 提供的公开 PyTorch/XLA 映像。如需创建实例模板,请使用 gcloud
命令:
gcloud compute instance-templates create instance-template-name \
--machine-type n1-standard-16 \
--image-project=${PROJECT_ID} \
--image=image-name \
--scopes=https://www.googleapis.com/auth/cloud-platform
通过实例模板创建实例组
gcloud compute instance-groups managed create instance-group-name \
--size 4 \
--template template-name \
--zone europe-west4-a
通过 SSH 连接到您的 Compute Engine 虚拟机
创建实例组后,请通过 SSH 连接到实例组中的某个实例(虚拟机)。使用 gcloud
命令列出实例组中的所有实例:
gcloud compute instance-groups list-instances instance-group-name
通过 SSH 连接到使用 list-instances
命令列出的某个实例。
gcloud compute ssh instance-name --zone=europe-west4-a
验证实例组中的虚拟机是否可以相互通信
使用 nmap
命令验证实例组中的虚拟机是否可以相互通信。从您连接的虚拟机运行 nmap
命令,将 instance-name 替换为实例组中另一个虚拟机的实例名称。
(vm)$ nmap -Pn -p 8477 instance-name
Starting Nmap 7.40 ( https://nmap.org ) at 2019-10-02 21:35 UTC
Nmap scan report for pytorch-20190923-n4tx.c.jysohntpu.internal (10.164.0.3)
Host is up (0.00034s latency).
PORT STATE SERVICE
8477/tcp closed unknown
只要 STATE 字段未显示 filtered,就表示防火墙规则设置正确。
创建 Cloud TPU Pod
gcloud compute tpus create tpu-name \
--zone=europe-west4-a \
--network=default \
--accelerator-type=v2-32 \
--version=pytorch-1.13
在 Pod 上运行分布式训练
- 从虚拟机会话窗口中,导出 Cloud TPU 名称并激活 conda 环境。
(vm)$ export TPU_NAME=tpu-name
(vm)$ conda activate torch-xla-1.13
- 运行训练脚本:
(torch-xla-1.13)$ python -m torch_xla.distributed.xla_dist \
--tpu=$TPU_NAME \
--conda-env=torch-xla-1.13 \
--env XLA_USE_BF16=1 \
--env ANY_OTHER=ENV_VAR \
-- python /usr/share/torch-xla-1.13/pytorch/xla/test/test_train_mp_imagenet.py \
--fake_data
运行完上述命令后,您应该会看到如下所示的输出(注意,这使用的是 --fake_data
)。在 v3-32 TPU Pod 上,训练需要大约 1/2 小时。
2020-08-06 02:38:29 [] Command to distribute: "python" "/usr/share/torch-xla-nightly/pytorch/xla/test/test_train_mp_imagenet.py" "--fake_data"
2020-08-06 02:38:29 [] Cluster configuration: {client_workers: [{10.164.0.43, n1-standard-96, europe-west4-a, my-instance-group-hm88}, {10.164.0.109, n1-standard-96, europe-west4-a, my-instance-group-n3q2}, {10.164.0.46, n1-standard-96, europe-west4-a, my-instance-group-s0xl}, {10.164.0.49, n1-standard-96, europe-west4-a, my-instance-group-zp14}], service_workers: [{10.131.144.61, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.59, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.58, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.60, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}]}
2020-08-06 02:38:31 10.164.0.43 [0] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:31 10.164.0.43 [0] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2757 0 --:--:-- --:--:-- --:--:-- 3166
2020-08-06 02:38:34 10.164.0.43 [0] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:34 10.164.0.43 [0] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2623 0 --:--:-- --:--:-- --:--:-- 2714
2020-08-06 02:38:37 10.164.0.46 [2] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:37 10.164.0.46 [2] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2583 0 --:--:-- --:--:-- --:--:-- 2714
2020-08-06 02:38:37 10.164.0.49 [3] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:37 10.164.0.49 [3] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2530 0 --:--:-- --:--:-- --:--:-- 2714
2020-08-06 02:38:37 10.164.0.109 [1] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:37 10.164.0.109 [1] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2317 0 --:--:-- --:--:-- --:--:-- 2375
2020-08-06 02:38:40 10.164.0.46 [2] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:40 10.164.0.49 [3] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:40 10.164.0.46 [2] Dload Upload Total Spent Left Speed
2020-08-06 02:38:40 10.164.0.49 [3] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2748 0 --:--:-- --:--:-- --:--:-- 3166
100 19 100 19 0 0 2584 0 --:--:-- --:--:-- --:--:-- 2714
2020-08-06 02:38:40 10.164.0.109 [1] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:40 10.164.0.109 [1] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2495 0 --:--:-- --:--:-- --:--:-- 2714
2020-08-06 02:38:43 10.164.0.49 [3] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:43 10.164.0.49 [3] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2654 0 --:--:-- --:--:-- --:--:-- 2714
2020-08-06 02:38:43 10.164.0.43 [0] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:43 10.164.0.43 [0] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2784 0 --:--:-- --:--:-- --:--:-- 3166
2020-08-06 02:38:43 10.164.0.46 [2] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:43 10.164.0.46 [2] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2691 0 --:--:-- --:--:-- --:--:-- 3166
2020-08-06 02:38:43 10.164.0.109 [1] % Total % Received % Xferd Average Speed Time Time Time Current
2020-08-06 02:38:43 10.164.0.109 [1] Dload Upload Total Spent Left Speed
100 19 100 19 0 0 2589 0 --:--:-- --:--:-- --:--:-- 2714
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/14 Epoch=1 Step=0 Loss=6.87500 Rate=258.47 GlobalRate=258.47 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/15 Epoch=1 Step=0 Loss=6.87500 Rate=149.45 GlobalRate=149.45 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] Epoch 1 train begin 02:38:52
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:1/0 Epoch=1 Step=0 Loss=6.87500 Rate=25.72 GlobalRate=25.72 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.87500 Rate=89.01 GlobalRate=89.01 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.87500 Rate=64.15 GlobalRate=64.15 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.87500 Rate=93.19 GlobalRate=93.19 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/7 Epoch=1 Step=0 Loss=6.87500 Rate=58.78 GlobalRate=58.78 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] Epoch 1 train begin 02:38:56
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:1/8 Epoch=1 Step=0 Loss=6.87500 Rate=100.43 GlobalRate=100.43 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/13 Epoch=1 Step=0 Loss=6.87500 Rate=66.83 GlobalRate=66.83 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/11 Epoch=1 Step=0 Loss=6.87500 Rate=64.28 GlobalRate=64.28 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/10 Epoch=1 Step=0 Loss=6.87500 Rate=73.17 GlobalRate=73.17 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/9 Epoch=1 Step=0 Loss=6.87500 Rate=27.29 GlobalRate=27.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/12 Epoch=1 Step=0 Loss=6.87500 Rate=110.29 GlobalRate=110.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/20 Epoch=1 Step=0 Loss=6.87500 Rate=100.85 GlobalRate=100.85 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/22 Epoch=1 Step=0 Loss=6.87500 Rate=93.52 GlobalRate=93.52 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/23 Epoch=1 Step=0 Loss=6.87500 Rate=165.86 GlobalRate=165.86 Time=02:38:57
清理
为避免因本教程中使用的资源导致您的 Google Cloud 帐号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
- 与 Compute Engine 虚拟机断开连接:
exit
- 删除实例组:
gcloud compute instance-groups managed delete instance-group-name
- 删除 TPU Pod:
gcloud compute tpus delete ${TPU_NAME} --zone=europe-west4-a
- 删除实例组模板:
gcloud compute instance-templates delete instance-template-name
- [可选] 删除您的虚拟机磁盘映像:
gcloud compute images delete image-name
后续步骤
试用 PyTorch Colab:
- 在 Cloud TPU 上开始使用 PyTorch
- 在 TPU 上训练 MNIST
- 使用 Cifar10 数据集在 TPU 上训练 ResNet18
- 使用预训练的 ResNet50 模型进行推理
- 快速神经风格转移
- 在 Fashion MNIST 上使用多核心训练 AlexNet
- 在 Fashion MNIST 上使用单核心训练 AlexNet
文章来源:google cloud
推荐阅读
- 在 Cloud TPU 上训练 NCF (TF 2.x)
- 在 Cloud TPU 上训练 DLRM 和 DCN (TF 2.x)
- 在 Cloud TPU 上训练 Mask RCNN (TF 2.x)
更多芯擎AI开发板干货请关注芯擎AI开发板专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。