当前位置: 首页 > article >正文

Tensor 基本操作5 device 管理,使用 GPU 设备 | PyTorch 深度学习实战

前一篇文章,Tensor 基本操作4 理解 indexing,加减乘除和 broadcasting 运算 | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

Tensor 基本使用

  • 检查设备
  • 创建 tensor 时声明设备
  • 更改默认设备
  • 创建 tensor 后移动 tensor.to
  • 注意事项
    • 1. 运算的发生位置
    • 2. 当两个 tensor 进行运算时,需要在同一个设备上

检查设备

  • 默认为 CPU,根据是否有 GPU 设定 device。
device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
  • 检查默认设备
device = torch.get_default_device()

创建 tensor 时声明设备

    # 声明为 CPU 设备
    device_cpu = torch.device('cpu')
    points_cpu = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]], device=device_cpu)

更改默认设备

将默认设备,设置为 GPU。

device_gpu = torch.device('cuda')
torch.set_default_device(device_gpu) 
points_default = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]]) # 此时,points_default  被定义到 GPU 上

创建 tensor 后移动 tensor.to

将一个 tensor 移动到指定设备。

    device_gpu = torch.device('cuda')
    points2 = points.to(device_gpu)  # 将 Tensor 复制到 GPU
    print(points2)

注意事项

1. 运算的发生位置

    device_gpu = torch.device('cuda')
    points2 = points.to(device_gpu)  # 将 Tensor 复制到 GPU
    points3 = points2 * 2   # points3 还是在 GPU 上
    points4 = points2 + 2  # points4 还是在 GPU 上

但是,打印 points3 或 points4 时,将会复制该值到 CPU 上输出。

2. 当两个 tensor 进行运算时,需要在同一个设备上

  File "C:\devel\Python\Python311\Lib\site-packages\torch\utils\_device.py", line 79, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

http://www.kler.cn/a/521972.html

相关文章:

  • 2025春晚刘谦魔术揭秘魔术过程
  • Vuex中的getter和mutation有什么区别
  • Git进阶之旅:Git 配置信息 Config
  • AI大模型开发原理篇-2:语言模型雏形之词袋模型
  • PostgreSQL 数据备份与恢复:掌握 pg_dump 和 pg_restore 的最佳实践
  • 使用 Redis List 和 Pub/Sub 实现简单的消息队列
  • 低代码系统-产品架构案例介绍、明道云(十一)
  • Spring中@RequestBody、@PathVariable、@RequestParam三个注解详解
  • 如何用前端技术开发一个浪漫的生日祝福网站
  • 豆包MarsCode:前缀和计算问题
  • 【flutter版本升级】【Nativeshell适配】nativeshell需要做哪些更改
  • 《深度揭秘:TPU张量计算架构如何重塑深度学习运算》
  • npm常见报错整理
  • .strip()用法
  • Nacos统一配置管理
  • read+write实现:链表放到文件+文件数据放到链表 的功能
  • 第1章 量子暗网中的血色黎明
  • 17【棋牌游戏到底有没有透视】
  • games101-(3/4)变换
  • 弹性分组环——RPR技术
  • python Fabric在自动化部署中的应用
  • 使用 Python 和 scikit-learn 实现 KNN 分类:以鸢尾花数据集为例
  • 【由浅入深认识Maven】第3部分 maven多模块管理
  • fastadmin中require-form.js的data-favisible控制显示隐藏
  • 基于Flask的哔哩哔哩综合指数UP榜单数据分析系统的设计与实现
  • S4 HANA定义税码(FTXP)