当前位置: 首页 > article >正文

在 Cloud TPU Pod 上训练 PyTorch 模型

准备工作

在 Cloud TPU Pod 上开始分布式训练之前,请验证您的模型可在单个 v2-8 或 v3-8 Cloud TPU 设备上正常训练。如果您的模型在单个设备上出现明显的性能问题,请参阅最佳做法和问题排查指南。

单个 TPU 设备成功训练后,请执行以下步骤,在 Cloud TPU Pod 上设置和训练:

  1. 配置 gcloud 命令。

  2. 【可选】 将虚拟机磁盘映像捕获到虚拟机映像中。

  3. 通过虚拟机映像创建实例模板。

  4. 通过实例模板创建实例组。

  5. 通过 SSH 连接到您的 Compute Engine 虚拟机

  6. 验证防火墙规则以允许虚拟机之间进行通信。

  7. 创建 Cloud TPU Pod。

  8. 在 Pod 上运行分布式训练。

  9. 清理。

配置 gcloud 命令

使用 gcloud 配置 Google Cloud 项目:

为项目 ID 创建一个变量。

export PROJECT_ID=project-id

将项目 ID 设置为 gcloud 中的默认项目

gcloud config set project ${PROJECT_ID}

当您第一次在新的 Cloud Shell 虚拟机中运行此命令时,系统会显示 Authorize Cloud Shell 页面。点击页面底部的 Authorize,以允许 gcloud 使用您的凭据进行 API 调用。

使用 gcloud 配置默认区域:

gcloud config set compute/zone europe-west4-a

[可选] 捕获虚拟机磁盘映像

您可以使用您用于训练单个 TPU 的虚拟机中的磁盘映像(已安装数据集、软件包等)。在创建映像之前,请使用 gcloud 命令停止虚拟机:

gcloud compute instances stop vm-name

接下来,使用 gcloud 命令创建虚拟机映像:

gcloud compute images create image-name  \
    --source-disk instance-name \
    --source-disk-zone europe-west4-a \
    --family=torch-xla \
    --storage-location europe-west4

通过虚拟机映像创建实例模板

创建默认实例模板。创建实例模板时,您可以使用您在上述步骤中创建的虚拟机映像,或者您可以使用 Google 提供的公开 PyTorch/XLA 映像。如需创建实例模板,请使用 gcloud 命令:

gcloud compute instance-templates create instance-template-name \
    --machine-type n1-standard-16 \
    --image-project=${PROJECT_ID} \
    --image=image-name \
    --scopes=https://www.googleapis.com/auth/cloud-platform

通过实例模板创建实例组

gcloud compute instance-groups managed create instance-group-name \
    --size 4 \
    --template template-name \
    --zone europe-west4-a

通过 SSH 连接到您的 Compute Engine 虚拟机

创建实例组后,请通过 SSH 连接到实例组中的某个实例(虚拟机)。使用 gcloud 命令列出实例组中的所有实例:

gcloud compute instance-groups list-instances instance-group-name

通过 SSH 连接到使用 list-instances 命令列出的某个实例。

gcloud compute ssh instance-name --zone=europe-west4-a

验证实例组中的虚拟机是否可以相互通信

使用 nmap 命令验证实例组中的虚拟机是否可以相互通信。从您连接的虚拟机运行 nmap 命令,将 instance-name 替换为实例组中另一个虚拟机的实例名称。

(vm)$ nmap -Pn -p 8477 instance-name

Starting Nmap 7.40 ( https://nmap.org ) at 2019-10-02 21:35 UTC
Nmap scan report for pytorch-20190923-n4tx.c.jysohntpu.internal (10.164.0.3)
Host is up (0.00034s latency).
PORT     STATE  SERVICE
8477/tcp closed unknown

只要 STATE 字段未显示 filtered,就表示防火墙规则设置正确。

创建 Cloud TPU Pod

gcloud compute tpus create tpu-name \
    --zone=europe-west4-a \
    --network=default \
    --accelerator-type=v2-32 \
    --version=pytorch-1.13

在 Pod 上运行分布式训练

  1. 从虚拟机会话窗口中,导出 Cloud TPU 名称并激活 conda 环境。
 (vm)$ export TPU_NAME=tpu-name
    
    (vm)$ conda activate torch-xla-1.13
  1. 运行训练脚本:
    (torch-xla-1.13)$ python -m torch_xla.distributed.xla_dist \
          --tpu=$TPU_NAME \
          --conda-env=torch-xla-1.13 \
          --env XLA_USE_BF16=1 \
          --env ANY_OTHER=ENV_VAR \
          -- python /usr/share/torch-xla-1.13/pytorch/xla/test/test_train_mp_imagenet.py \
          --fake_data

运行完上述命令后,您应该会看到如下所示的输出(注意,这使用的是 --fake_data)。在 v3-32 TPU Pod 上,训练需要大约 1/2 小时。

2020-08-06 02:38:29  [] Command to distribute: "python" "/usr/share/torch-xla-nightly/pytorch/xla/test/test_train_mp_imagenet.py" "--fake_data"
2020-08-06 02:38:29  [] Cluster configuration: {client_workers: [{10.164.0.43, n1-standard-96, europe-west4-a, my-instance-group-hm88}, {10.164.0.109, n1-standard-96, europe-west4-a, my-instance-group-n3q2}, {10.164.0.46, n1-standard-96, europe-west4-a, my-instance-group-s0xl}, {10.164.0.49, n1-standard-96, europe-west4-a, my-instance-group-zp14}], service_workers: [{10.131.144.61, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.59, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.58, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.60, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}]}
2020-08-06 02:38:31 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:31 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2757      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:34 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:34 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2623      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2583      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2530      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2317      0 --:--:-- --:--:-- --:--:--  2375
2020-08-06 02:38:40 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
2020-08-06 02:38:40 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2748      0 --:--:-- --:--:-- --:--:--  3166
100    19  100    19    0     0   2584      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:40 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2495      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2654      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2784      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2691      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2589      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/14 Epoch=1 Step=0 Loss=6.87500 Rate=258.47 GlobalRate=258.47 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/15 Epoch=1 Step=0 Loss=6.87500 Rate=149.45 GlobalRate=149.45 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] Epoch 1 train begin 02:38:52
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:1/0 Epoch=1 Step=0 Loss=6.87500 Rate=25.72 GlobalRate=25.72 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.87500 Rate=89.01 GlobalRate=89.01 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.87500 Rate=64.15 GlobalRate=64.15 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.87500 Rate=93.19 GlobalRate=93.19 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/7 Epoch=1 Step=0 Loss=6.87500 Rate=58.78 GlobalRate=58.78 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] Epoch 1 train begin 02:38:56
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:1/8 Epoch=1 Step=0 Loss=6.87500 Rate=100.43 GlobalRate=100.43 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/13 Epoch=1 Step=0 Loss=6.87500 Rate=66.83 GlobalRate=66.83 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/11 Epoch=1 Step=0 Loss=6.87500 Rate=64.28 GlobalRate=64.28 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/10 Epoch=1 Step=0 Loss=6.87500 Rate=73.17 GlobalRate=73.17 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/9 Epoch=1 Step=0 Loss=6.87500 Rate=27.29 GlobalRate=27.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/12 Epoch=1 Step=0 Loss=6.87500 Rate=110.29 GlobalRate=110.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/20 Epoch=1 Step=0 Loss=6.87500 Rate=100.85 GlobalRate=100.85 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/22 Epoch=1 Step=0 Loss=6.87500 Rate=93.52 GlobalRate=93.52 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/23 Epoch=1 Step=0 Loss=6.87500 Rate=165.86 GlobalRate=165.86 Time=02:38:57

清理

为避免因本教程中使用的资源导致您的 Google Cloud 帐号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

  1. 与 Compute Engine 虚拟机断开连接:
exit
  1. 删除实例组:
gcloud compute instance-groups managed delete instance-group-name
  1. 删除 TPU Pod:
gcloud compute tpus delete ${TPU_NAME} --zone=europe-west4-a
  1. 删除实例组模板:
gcloud compute instance-templates delete instance-template-name
  1. [可选] 删除您的虚拟机磁盘映像:
gcloud compute images delete image-name

后续步骤

试用 PyTorch Colab:

  • 在 Cloud TPU 上开始使用 PyTorch
  • 在 TPU 上训练 MNIST
  • 使用 Cifar10 数据集在 TPU 上训练 ResNet18
  • 使用预训练的 ResNet50 模型进行推理
  • 快速神经风格转移
  • 在 Fashion MNIST 上使用多核心训练 AlexNet
  • 在 Fashion MNIST 上使用单核心训练 AlexNet

文章来源:google cloud

推荐阅读

  • 在 Cloud TPU 上训练 NCF (TF 2.x)
  • 在 Cloud TPU 上训练 DLRM 和 DCN (TF 2.x)
  • 在 Cloud TPU 上训练 Mask RCNN (TF 2.x)

更多芯擎AI开发板干货请关注芯擎AI开发板专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。


http://www.kler.cn/a/292696.html

相关文章:

  • C 语言 【模拟实现内存库函数】
  • 【大数据学习 | flume】flume的概述与组件的介绍
  • 比ChatGPT更酷的AI工具
  • 数据集标注txt文件读取小工具
  • Bugku CTF_Web——文件上传
  • python魔术方法的学习
  • Java重修笔记 第四十八天 TreeSet 类、TreeMap 类
  • 计算机毕设选题推荐-基于python的剧本杀预约服务平台【python-爬虫-大数据定制】
  • 人工智能在网络安全中的重要性
  • 一文讲懂扩散模型
  • 安装opengauss企业版单机流程
  • 【GD32】---- 使用GD32调试串口并实现printf打印输出
  • 修改服务器DNS解析及修改自动对时时区
  • 【Motion Forecasting】SIMPL:简单且高效的自动驾驶运动预测Baseline
  • AI时代来临,AI基础数据服务行业未来发展有哪些变化
  • 产品经理的学习笔记(全集)-持续更新
  • 基础算法题————散列/哈希/Hash
  • ElasticSearch-倒排索引 文档映射
  • 深入理解JavaScript闭包:避免常见的内存泄漏问题
  • 深度学习|模型推理:端到端任务处理
  • 【Netty】自定义网络通信协议
  • FFmpeg源码:avpriv_set_pts_info函数分析
  • SpringBoot 实战:SpringBoot整合Flink CDC,实时追踪mysql数据变动
  • Java简单实现服务器客户端通信
  • 0to1使用JWT实现登录认证
  • ubuntu24下安装pytorch3d