当前位置: 首页 > article >正文

Pytorch如何将嵌套的dict类型数据加载到GPU

在PyTorch中,您可以使用.to(device)方法将嵌套的字典中的所有支持的Tensor对象转移到GPU。以下是一个简单的例子 

import torch
 
# 假设您已经有了一个名为device的GPU设备对象
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
# 嵌套的字典,其中包含一些Tensors
nested_dict = {
    'a': torch.randn(2, 2),
    'b': {
        'b1': torch.randn(2, 2),
        'b2': torch.randn(2, 2)
    },
    'c': torch.randn(2, 2)
}
 
# 将嵌套字典中的所有Tensors移动到GPU
def to_gpu(data):
    if isinstance(data, dict):
        return {k: to_gpu(v) for k, v in data.items()}
    elif isinstance(data, list):
        return [to_gpu(i) for i in data]
    elif isinstance(data, tuple):
        return tuple([to_gpu(i) for i in data])
    elif torch.is_tensor(data) and data.device != device:
        return data.to(device)
    else:
        return data
 
nested_dict_gpu = to_gpu(nested_dict)
 
# 检查是否所有Tensors都已移动到GPU
for k, v in nested_dict_gpu.items():
    if torch.is_tensor(v):
        assert v.device == device

这个函数to_gpu会递归地检查字典中的每个元素,如果是Tensor类型并且不在GPU上,就会使用.to(device)方法转移它。您需要先设置device变量指向您的GPU设备。如果没有GPU可用,它会默认使用CPU。


http://www.kler.cn/a/391160.html

相关文章:

  • 边缘计算在智能交通系统中的应用
  • CommandLineParser 使用
  • 嵌入式硬件杂谈(一)-推挽 开漏 高阻态 上拉电阻
  • 数据结构与算法-前缀和数组
  • Autosar CP DDS规范导读
  • 机器学习——损失函数、代价函数、KL散度
  • 【webrtc】RTX 重传包和NACK包
  • Secure Shell(SSH) 是一种网络协议
  • RDK X3 环形麦克风板录音与播放
  • STM32 设计的较为复杂的物联网项目,包括智能家居控制系统,涵盖了硬件和软件的详细设计。
  • 屏幕解析工具——OmniParser
  • vue内置方法总结
  • Qt中MainWindow的isVisible和isActiveWindow有什么区别
  • 基本和引用数据类型以及对象字面量(day14)
  • ubuntu24.04播放语音视频
  • 启动本地开发环境(自带热启动)yarn serve
  • Pytorch学习--神经网络--完整的模型验证套路
  • MacOS编译hello_xr——记一次CMake搜索路径限制导致的ANDROID_NATIVE_APP_GLUE not found
  • 网络安全-Linux基础(2)
  • 电子应用产品设计方案-5:多功能恒温控制器设计
  • 【主机游戏】正当防卫3游戏介绍
  • uniapp和uview-plus组件在项目中向后端发起请求的封装
  • 【蓝桥等考C++真题】蓝桥杯等级考试C++组第13级L13真题原题(含答案)-统计数字
  • 用轻量云服务器搭建一个开源的商城系统,含小程序和pc端
  • Java中的不可变集合:性能与安全并重的最佳实践
  • 力扣 LeetCode 977. 有序数组的平方(Day1:数组)