当前位置: 首页 > article >正文

如何阅读PyTorch文档及常见PyTorch错误

如何阅读PyTorch文档及常见PyTorch错误

文章目录

  • 如何阅读PyTorch文档及常见PyTorch错误
    • 阅读PyTorch文档示例
    • 常见Pytorch错误
      • Tensor在不同设备上
      • 维度不匹配
      • cuda内存不足
      • 张量类型不匹配
    • 参考

PyTorch文档查看https://pytorch.org/docs/stable/

image-20240904161104184

image-20240904161329441

torch.nn -> 定义神经网络
torch.optim -> 优化算法
torch.utils.data -> 数据加载 dataset, dataloader类

阅读PyTorch文档示例

torch.max为例

image-20240904161651956

有些函数对不同的输入有不同的行为

Parameters(位置参数):不需要指定参数的名称

Keyword Arguments(关键字参数):必须指定参数的名称

他们通过 * 隔开

带默认值的参数:有些参数有默认值(keepdim=False),所以传递这个参数的值是可选的

image-20240904161820050

三种torch.max的不同输入

  1. 返回整个张量的最大值(torch.max(input) → Tensor)
# 1. max of entire tensor (torch.max(input) → Tensor)
m = torch.max(x)
print(m)

image-20240904214513806

  1. 沿一个维度的最大值 (torch.max(input, dim, keepdim=False, *, out=None) → (Tensor, LongTensor))

    # 2. max along a dimension (torch.max(input, dim, keepdim=False, *, out=None) → (Tensor, LongTensor))
    m, idx = torch.max(x,0)
    print(m)
    print(idx)
    

    image-20240904214724819

    位置参数可以不指定参数的名字,关键字参数必须指定参数名字。以 * 隔开,(位置参数 * 关键字参数)

    # 2-2 位置参数可以不指定参数的名字,关键字参数必须指定参数名字。以 * 隔开,(位置参数 * 关键字参数)
    m, idx = torch.max(input=x,dim=0)
    print(m)
    print(idx)
    

    image-20240904214845204

    # 2-3
    m, idx = torch.max(x,0,False)
    print(m)
    print(idx)
    
    # 2-4
    m, idx = torch.max(x,dim=0,keepdim=True)
    print(m)
    print(idx)
    
    # 2-5
    p = (m,idx)
    torch.max(x,0,False,out=p)
    print(p[0])
    print(p[1])
    

    位置参数可以不指定参数的名字,关键字参数必须指定参数名字。

    image-20240904215101006

  2. 两个张量上的选择最大的(torch.max(input, other, *, out=None) → Tensor)

    # 3. max(choose max) operators on two tensors (torch.max(input, other, *, out=None) → Tensor)
    t = torch.max(x,y)
    print(t)
    

    image-20240904215223304

常见Pytorch错误

Tensor在不同设备上

报错信息:RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_mm)

解决方案:将张量移动到GPU

image-20240904215745830

# 1. different device error (fixed)
x = torch.randn(5).to("cuda:0")
y = model(x)
print(y.shape)

维度不匹配

报错信息:RuntimeError: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1

解决办法:张量的形状不正确,使用transpose,squeeze, unsqueeze来对齐尺寸

image-20240904220743294

# 2. mismatched dimensions error 1 (fixed by transpose)
y = y.transpose(0,1)
z = x + y
print(z.shape)

cuda内存不足

报错信息:RuntimeError: CUDA out of memory. Tried to allocate 7.27 GiB (GPU 0; 4.00 GiB total capacity; 8.67 GiB already allocated; 0 bytes free; 8.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

解决方法:数据的批量大小太大,无法装入GPU。减小批量大小。如果对数据进行迭代(batch size = 1),问题就会得到解决。你也可以使用DataLoader

image-20240904221307600

# 3. cuda out of memory error (fixed, but it might take some time to execute)
for d in data:
	out = resnet18(d.to("cuda:0").unsqueeze(0))
print(out.shape)

张量类型不匹配

报错信息:RuntimeError: expected scalar type Long but found Float

解决方法:标签张量类型必须是Long,将其转换为“long”以解决此问题

image-20240904221529335

# 4. mismatched tensor type (fixed)
labels = labels.long()
lossval = L(outs,labels)
print(lossval)

参考

torch.max — PyTorch 2.4 documentation

Hongyi_Lee_dl_homeworks/Warmup/Pytorch_Tutorial_2.pdf at master · huaiyuechusan/Hongyi_Lee_dl_homeworks (github.com)

orial_2.pdf at master · huaiyuechusan/Hongyi_Lee_dl_homeworks (github.com)](https://github.com/huaiyuechusan/Hongyi_Lee_dl_homeworks/blob/master/Warmup/Pytorch_Tutorial_2.pdf)


http://www.kler.cn/a/291484.html

相关文章:

  • 01_MinIO部署(Windows单节点部署/Docker化部署)
  • 若依权限控制
  • 搭建MC服务器
  • MySQL45讲 第二十四讲 MySQL是怎么保证主备一致的?——阅读总结
  • GNN初探
  • django的model时间怎么转时间戳
  • MLM:多模态大型语言模型的简介、微调方法、发展历史及其代表性模型、案例应用之详细攻略
  • JavaEE(2):前后端项目之间的交互
  • King’s LIMS 实验室信息管理系统:引领实验室数字化转型的创新力量
  • plc1200 weiluntong001
  • Tomato靶机通关攻略
  • sci文章录用后能要求删除其中一位作者吗?
  • 【Linux】在 bash shell 环境下,当一命令正在执行时,按下 control-Z 会?
  • [Java]SpringBoot业务代码增强
  • # 利刃出鞘_Tomcat 核心原理解析(十)-- Tomcat 性能调优--1
  • 微信公众号《GIS 数据工程:开始您的 ETL 之旅 》 文章删除及原因
  • okhttp,retrofit,rxjava 是如何配合工作的 作用分别是什么
  • Eureka:Spring Cloud中的服务注册与发现如何实现?
  • 数据结构(邓俊辉)学习笔记】串 16——Karp-Rabin算法:串即是数
  • 9:00面试,9:08就出来了,问的问题有点变态。。。
  • 九、制作卡牌预制体
  • 【深度学习】yolov8的微调
  • Android framework 编程之 - Binder调用方UID
  • CSS基础 --- % 相对于谁
  • 斯坦福UE4 C++课学习补充21:击败动画
  • Snipaste:一款强大的截图与贴图工具