pytorch 笔记:张量索引的维度扩展规则
1 基本原理
在PyTorch中,张量索引的维度扩展规则遵循以下原则:
索引操作的核心规则:
当使用索引数组访问张量时:
- 索引数组的每个元素对应选取原张量的一个子张量
- 结果形状 = 索引数组形状 + 原张量剩余维度形状
这么说可能不清楚,我们用例子说明
2 举例
假如我们有这样的两个张量
mlp_weights.shape = (num_poi, mer_size, mer_dim)
top2_indices.shape = (batch, cur_seq_len, 2)
selected_weights = mlp_weights[top2_indices]
步骤 | 操作 | 维度变化说明 |
---|---|---|
1 | 索引数组作用在维度0 | 原张量第0维度(num_poi)被索引数组替换 |
2 | 索引数组形状保持 | 继承索引数组的(batch, cur_seq_len, 2)形状 |
3 | 附加后续维度 | 保留原张量的(mer_size, mer_dim)维度 |
——>selected_weights的维度是 (batch, cur_seq_len, 2,mer_size, mer_dim)
2.1 索引过程
-
每个索引值对应一个(mer_size,mer_dim)的矩阵:
top2_indices[0,0,0] = 1
→ 选取mlp_weights[1,:,:]
top2_indices[0,0,1] = 2
→ 选取mlp_weights[2,:,:]