深度学习02-pytorch-05-张量的索引操作
在 PyTorch 中,张量的索引操作是非常灵活且强大的,允许你访问和修改张量中的特定元素或子集。与 NumPy 的数组操作类似,PyTorch 的张量支持多种索引方法,包括单纯的基础索引、高级索引和切片等。下面详细介绍 PyTorch 中常见的张量索引操作。
1. 基础索引
单个元素索引
你可以通过指定行、列来访问张量中的某个元素:
import torch
data = torch.tensor([[2, 8, 6, 9, 2],
[8, 7, 0, 8, 2],
[2, 7, 4, 0, 4],
[8, 9, 4, 3, 6]])
# 访问第0行,第1列的元素
print(data[0, 1]) # 输出 8
负索引
你可以使用负数索引来从末尾开始访问张量中的元素:
# 访问倒数第一行,倒数第二列的元素
print(data[-1, -2]) # 输出 3
2. 切片操作
你可以对张量的某个维度进行切片操作,获取其中的一部分数据。
切片示例
# 获取第0行到第2行(不包括第2行)和第1列到第3列(不包括第3列)的子张量
print(data[0:2, 1:3])
# 输出:
# tensor([[8, 6],
# [7, 0]])
使用 :
选择所有行或列
# 获取所有行的第1列
print(data[:, 1])
# 输出 tensor([8, 7, 7, 9])
3. 布尔索引
布尔索引允许你通过条件来选择符合条件的元素。例如,选取所有大于 5 的元素:
mask = data > 5
print(mask)
# 输出:
# tensor([[False, True, True, True, False],
# [ True, True, False, True, False],
# [False, True, False, False, False],
# [ True, True, False, False, True]])
# 选取大于 5 的所有元素
print(data[mask])
# 输出 tensor([8, 6, 9, 8, 7, 7, 8, 8, 9, 9, 6])
4. 高级索引
高级索引允许你通过张量、列表或其他数组来索引特定的位置。
多个索引
你可以使用两个列表/张量来同时索引行和列:
rows = [0, 1]
cols = [2, 3]
# 访问第 0 行第 2 列和第 1 行第 3 列的元素
print(data[rows, cols])
# 输出 tensor([6, 8])
结合切片与索引
# 获取第0行和第1行,并选取这些行的第2列和第3列的元素
print(data[[0, 1], 2:4])
# 输出:
# tensor([[6, 9],
# [0, 8]])
5. None 和 ellipsis (...
) 索引
None 索引
None 索引可以用于增加一个新的维度,类似于 NumPy 中的 np.newaxis:
data = torch.tensor([1, 2, 3])
expanded_data = data[:, None]
print(expanded_data)
# 输出:
# tensor([[1],
# [2],
# [3]])
...
(省略号)操作符
...
操作符可以用于简化多维张量的索引,它会选择所有没有明确指定索引的维度。例如:
data = torch.randn(3, 4, 5)
# 使用省略号,访问所有行和列,并获取第2维的第0列
print(data[..., 0].shape) # 输出形状是 (3, 4)
6. 张量赋值操作
通过索引操作,你还可以修改张量中的某些元素:
# 将第0行第1列的元素设为10
data[0, 1] = 10
print(data)
你也可以使用布尔索引或条件选择元素并赋值:
# 将所有大于5的元素设置为5
data[data > 5] = 5
print(data)
7. 高级操作:索引与广播
广播索引
当你使用多个不同形状的索引时,PyTorch 会自动应用广播机制,使索引适应不同维度的张量:
data = torch.tensor([[2, 8, 6, 9, 2],
[8, 7, 0, 8, 2],
[2, 7, 4, 0, 4],
[8, 9, 4, 3, 6]])
# 使用广播机制,选择第2行第2列、第2行第3列和第3行第2列、第3行第3列的元素
print(data[[[2], [3]], [2, 3]])
# 输出:
# tensor([[4, 0],
# [4, 3]])
8. 张量选择操作
PyTorch 提供了 torch.index_select
等函数用于更加灵活的选择操作:
# 选择第 1 和第 2 行
indices = torch.tensor([1, 2])
selected_rows = torch.index_select(data, 0, indices)
print(selected_rows)
9. Gather 函数
torch.gather
函数允许你在指定的维度上根据索引选择张量中的元素:
data = torch.tensor([[2, 8, 6, 9, 2],
[8, 7, 0, 8, 2],
[2, 7, 4, 0, 4],
[8, 9, 4, 3, 6]])
index = torch.tensor([[0, 2], [3, 1], [1, 0], [2, 3]])
# gather 操作在第 1 维度上根据给定的 index 选择数据
result = torch.gather(data, 1, index)
print(result)
# 输出:
# tensor([[2, 6],
# [8, 7],
# [7, 2],
# [4, 3]])
总结
PyTorch 提供了非常灵活和强大的张量索引机制。你可以使用基础索引、切片、布尔索引、以及高级索引来访问或修改张量中的数据。此外,PyTorch 的广播机制允许你对不同形状的张量进行操作,从而简化代码并提高计算效率。