pytorch张量高级索引介绍
PyTorch 中,张量索引操作可以使用高级索引(advanced indexing),其中索引可以是另一个张量。使用这种索引方式时,返回值的维度由索引张量的形状和原始张量的形状共同决定。以下是具体的规则和解释:
1. 基本概念
假设我们有一个张量 x
和索引张量 indices
,我们通过 x[indices]
进行高级索引操作。
规则:
- 索引张量的形状将决定返回值的形状。
- 返回值的维度由索引张量的维度代替索引位置后的张量维度。
2. 示例讲解
示例 1:一维索引
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
indices = torch.tensor([0, 1])
result = x[indices]
x
的形状是(2, 3)
。indices
是一维张量,形状是(2,)
。- 索引
x[indices]
的结果:- 取出
x
的第 0 行和第 1 行。 - 返回值的形状是
(2, 3)
。
- 取出
示例 2:多维索引
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
indices = torch.tensor([[0, 1], [1, 0]])
result = x[indices]
print(f"x.shape:{x.shape}")
print(f"index.shape:{index.shape}")
print(f"result.shape:{result.shape}")
print(result)
输出:
x.shape:torch.Size([2, 3])
index.shape:torch.Size([2, 2])
result.shape:torch.Size([2, 2, 3])
tensor([[[10, 20, 30],
[40, 50, 60]],
[[40, 50, 60],
[10, 20, 30]]])
示例 3:多维组合索引
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
rows = torch.tensor([0, 1])
cols = torch.tensor([1, 2])
result = x[rows, cols]
x
的形状是(2, 3)
。rows
和cols
都是一维张量,形状为(2,)
。- 索引
x[rows, cols]
:- 分别取出
x[0, 1]
和x[1, 2]
。 - 返回值是
(20, 60)
,形状为(2,)
。
- 分别取出
示例 4:广播索引
x = torch.tensor([[10, 20, 30], [40, 50, 60]])
rows = torch.tensor([[0], [1]])
cols = torch.tensor([0, 2])
result = x[rows, cols]
x
的形状是(2, 3)
。rows
的形状是(2, 1)
,cols
的形状是(2,)
。- 索引
x[rows, cols]
:rows
和cols
会广播成(2, 2)
。- 返回值的形状是
(2, 2)
。
示例 5:更复杂的张量索引操作
AF3 AtomAttentionEncoder类的init_pair_repr方法解读-CSDN博客中的 张量的高级索引
总结:
- 索引张量的形状直接决定了返回张量的形状。
- 当多个索引张量时,它们会广播以匹配维度,然后返回广播后形状的张量。