torch.embedding 报错 IndexError: index out of range in self
文章目录
- 1. 报错
- 2. 原因
- 3. 解决方法
1. 报错
torch.embedding
报错:
IndexError: index out of range in self
2. 原因
首先看下正常情况:
import torch
import torch.nn.functional as F
inputs = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
embedding_matrix = torch.rand(10, 3)
print(embedding_matrix)
print(F.embedding(inputs, embedding_matrix))
输出:
tensor([[0.8832, 0.2487, 0.7640],
[0.8973, 0.5747, 0.8496],
[0.2269, 0.2961, 0.7951],
[0.7736, 0.9914, 0.9448],
[0.4134, 0.7143, 0.4455],
[0.3482, 0.1837, 0.3179],
[0.4071, 0.9485, 0.1735],
[0.7494, 0.8119, 0.7899],
[0.3922, 0.2944, 0.4924],
[0.2391, 0.8299, 0.3299]])
tensor([[[0.8973, 0.5747, 0.8496],
[0.2269, 0.2961, 0.7951],
[0.4134, 0.7143, 0.4455],
[0.3482, 0.1837, 0.3179]],
[[0.4134, 0.7143, 0.4455],
[0.7736, 0.9914, 0.9448],
[0.2269, 0.2961, 0.7951],
[0.2391, 0.8299, 0.3299]]])
在这里,embedding_matrix
是一个全量的权重表,需要根据 inputs
的列表来选择权重列表的第几行。
例如:inputs[0] = [1, 2, 4, 5]
,注意这里是从0开始的,那么依次选择 embedding_matrix
的第2行、第3行、第5行、第6行,则对应的列表组成为:
[[0.8973, 0.5747, 0.8496],
[0.2269, 0.2961, 0.7951],
[0.4134, 0.7143, 0.4455],
[0.3482, 0.1837, 0.3179]]
这就是输出的第一部分,对于inputs[1] = [4, 3, 2, 9]
同样如此。
到这里,应该就清楚了出现 IndexError: index out of range in self
报错的原因了。如果 inputs
中出现超过权重表长度的数,就会报错。
例如上面例子权重表有10行,所以 inputs
最大数应该为9,如果 inputs[1] = [4, 3, 2, 10]
,如下:
import torch
import torch.nn.functional as F
inputs = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 10]])
embedding_matrix = torch.rand(10, 3)
print(embedding_matrix)
print(F.embedding(inputs, embedding_matrix))
那么报错如下:
3. 解决方法
知道报错原因之后,那么需要弄清楚的就是 inputs
中为什么会出现超过 embedding_matrix
权重表长度的数,这里就需要具体分析了。