当前位置: 首页 > article >正文

PyTorch 2.0: 新特性与升级指南

什么是 PyTorch 2.0?

PyTorch 2.0 是 PyTorch 的最新版本,它保留了之前版本的即时执行模式(eager mode),同时引入了一个全新的编译模式。这个编译模式通过 torch.compile 函数实现,有潜力显著提升模型的训练和推理速度。

为什么是 2.0 而不是 1.14?

PyTorch 团队认为这个版本引入的新特性足以改变用户使用 PyTorch 的方式,因此决定将其命名为 2.0 而不是 1.14。

如何安装 PyTorch 2.0?

你可以通过 pip 安装最新的 nightly 版本。根据你的 CUDA 版本或是否使用 CPU,选择相应的安装命令:

# CUDA 11.8
pip3 install numpy --pre torch torchvision torchaudio --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cu118

# CUDA 11.7
pip3 install numpy --pre torch torchvision torchaudio --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cu117

# CPU
pip3 install numpy --pre torch torchvision torchaudio --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cpu

2.0 版本的兼容性如何?

PyTorch 2.0 完全向后兼容 1.x 版本。你无需修改现有的 PyTorch 工作流程。只需添加一行代码 model = torch.compile(model) 就可以优化你的模型以使用 2.0 的新特性。

如何迁移到 PyTorch 2.0?

大多数情况下,你的代码无需任何改动就可以在 PyTorch 2.0 上运行。如果你想使用新的编译模式特性,只需要在你的模型上调用 torch.compile

import torch

def train(model, dataloader):
    model = torch.compile(model)
    for batch in dataloader:
        run_epoch(model, batch)

def infer(model, input):
    model = torch.compile(model)
    return model(**input)

PyTorch 2.0 的工作原理

当你使用 torch.compile(model) 包装你的模型时,模型会经历以下三个步骤:

  1. 图获取:模型被重写为子图块。
  2. 图降低:PyTorch 操作被分解为特定后端的核心操作。
  3. 图编译:核心操作调用相应的低级设备特定操作。

PyTorch 2.0 的新组件

  1. TorchDynamo:从 Python 字节码生成 FX 图。
  2. AOTAutograd:为 TorchDynamo 捕获的前向图生成对应的反向图。
  3. PrimTorch:将复杂的 PyTorch 操作分解为更简单和基本的操作。
  4. 后端:与 TorchDynamo 集成,将图编译为可在加速器上运行的 IR。

分布式训练

在编译模式下,DDP 和 FSDP 可以比即时执行模式快 15%(FP32)到 80%(AMP 精度)。使用 DDP 时,请确保设置 static_graph=False

遇到问题怎么办?

如果你的代码在编译模式下运行变慢或崩溃,很可能是由于图断裂(graph breaks)导致的。你可以参考 PyTorch 官方文档 来诊断和解决这些问题。

PyTorch 2.0 带来了显著的性能提升和新特性,同时保持了与旧版本的兼容性。通过简单的一行代码,你就可以享受到这些优化带来的好处.


http://www.kler.cn/a/389518.html

相关文章:

  • Visual Studio Community 2022(VS2022)安装方法
  • 当设置dialog中有el-table时,并设置el-table区域的滚动,看到el-table中多了一条横线
  • vue+高德API搭建前端Echarts图表页面
  • Tomcat下载配置
  • 【LeetCode: 215. 数组中的第K个最大元素 + 快速选择排序】
  • Web前端开发技术之HTMLCSS知识点总结
  • SwiftUI开发教程系列 - 第2章:基础布局与视图
  • 微服务之多机部署,负载均衡-LoadBalance
  • 卷积神经网络基础
  • 前缀和 so easy! 力扣.128 最长连续序列 leetcode longest-consecutive-sequence
  • 【动手学电机驱动】 STM32-FOC(2)STM32 导入和创建项目
  • 中兴光猫修改SN,MAC,修改地区,异地注册,改桥接,路由拨号
  • 今日 AI 简报|苹果推出的新框架,智源开源千万级多模态数据集,字节推出图像编辑模型,开源大语言模型和实时对话系统等
  • 24/11/7 算法笔记 PCA主成分分析
  • 【前端】JavaScript 方法速查大全-函数、正则、格式化、转换、进制、 XSS 转义(四)
  • ArkTS--应用状态
  • Linux服务器使用ps和top命令查看进程
  • 加载与存储指令及算数指令
  • HarmonyOS Next 实战卡片开发 01
  • Android CCodec Codec2 (二十)C2Buffer与Codec2Buffer
  • 深度学习中的 Dropout:原理、公式与实现解析
  • [Linux] 共享内存
  • 使用 IDEA 创建 Java 项目(二)
  • Hive:UDTF 函数
  • 优化时钟网络之时钟偏移
  • leetcode01 --- 环形链表判定