当前位置: 首页 > article >正文

torch.embedding 报错 IndexError: index out of range in self

文章目录

  • 1. 报错
  • 2. 原因
  • 3. 解决方法


1. 报错

torch.embedding 报错:

IndexError: index out of range in self

2. 原因

首先看下正常情况:

import torch
import torch.nn.functional as F

inputs = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
embedding_matrix = torch.rand(10, 3)
print(embedding_matrix)
print(F.embedding(inputs, embedding_matrix))

输出:

tensor([[0.8832, 0.2487, 0.7640],
        [0.8973, 0.5747, 0.8496],
        [0.2269, 0.2961, 0.7951],
        [0.7736, 0.9914, 0.9448],
        [0.4134, 0.7143, 0.4455],
        [0.3482, 0.1837, 0.3179],
        [0.4071, 0.9485, 0.1735],
        [0.7494, 0.8119, 0.7899],
        [0.3922, 0.2944, 0.4924],
        [0.2391, 0.8299, 0.3299]])
tensor([[[0.8973, 0.5747, 0.8496],
         [0.2269, 0.2961, 0.7951],
         [0.4134, 0.7143, 0.4455],
         [0.3482, 0.1837, 0.3179]],

        [[0.4134, 0.7143, 0.4455],
         [0.7736, 0.9914, 0.9448],
         [0.2269, 0.2961, 0.7951],
         [0.2391, 0.8299, 0.3299]]])

在这里,embedding_matrix 是一个全量的权重表,需要根据 inputs 的列表来选择权重列表的第几行。

例如:inputs[0] = [1, 2, 4, 5],注意这里是从0开始的,那么依次选择 embedding_matrix 的第2行、第3行、第5行、第6行,则对应的列表组成为:

[[0.8973, 0.5747, 0.8496],
[0.2269, 0.2961, 0.7951],
[0.4134, 0.7143, 0.4455],
[0.3482, 0.1837, 0.3179]]

这就是输出的第一部分,对于inputs[1] = [4, 3, 2, 9] 同样如此。

到这里,应该就清楚了出现 IndexError: index out of range in self 报错的原因了。如果 inputs 中出现超过权重表长度的数,就会报错。

例如上面例子权重表有10行,所以 inputs 最大数应该为9,如果 inputs[1] = [4, 3, 2, 10],如下:

import torch
import torch.nn.functional as F

inputs = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 10]])
embedding_matrix = torch.rand(10, 3)
print(embedding_matrix)
print(F.embedding(inputs, embedding_matrix))

那么报错如下:
在这里插入图片描述

3. 解决方法

知道报错原因之后,那么需要弄清楚的就是 inputs 中为什么会出现超过 embedding_matrix 权重表长度的数,这里就需要具体分析了。


http://www.kler.cn/news/316067.html

相关文章:

  • 数据结构之二叉树遍历
  • 【Linux系统编程】第二十一弹---进程的地址空间
  • 《概率论与数理统计》学渣笔记
  • uni-app功能 1. 实现点击置顶,滚动吸顶2.swiper一个轮播显示一个半内容且实现无缝滚动3.穿透修改uni-ui的样式
  • 美团测开OC!
  • 【论文串烧】多媒体推荐中的模态平衡学习 | 音视频语音识别中丢失导致的模态偏差对丢失视频帧鲁棒性的影响
  • erlang学习:Linux常用命令2
  • Github 2024-09-23 开源项目周报 Top15
  • Kubernetes集群架构、安装和配置全面指南
  • 目标检测-数据集
  • 【MySQL】获取最近7天和最近14天的订单数量,使用MySQL详细写出,使用不同的方法
  • 想学习下Python和深度学习,Python需要学习到什么程度呢?
  • C++入门——(类的默认成员函数)析构函数
  • 数据库基础知识---------------------------(3)
  • 早期病毒和反病毒技术(网络安全小知识)
  • MATLAB系列08:输入/输入函数
  • SSCMS 插件示例 一插件创建及插件菜单
  • 大厂面试真题:SpringBoot的核心注解
  • FastAPI 的隐藏宝石:自动生成 TypeScript 客户端
  • Golang | Leetcode Golang题解之第423题从英文中重建数字
  • C++学习
  • 机器学习——Bagging
  • String类和String类常用方法
  • LinuxC高级作业1
  • css边框修饰
  • 代码随想录:打家劫舍||
  • 鸿蒙OpenHarmony【轻量系统内核扩展组件(CPU占用率)】子系统开发
  • 【C++】面向对象编程的三大特性:深入解析继承机制
  • Open3D(C++) 基于点云的曲率提取特征点(自定义阈值法)
  • Unity DOTS系列之IJobChunk来迭代处理数据