Tensor做为索引
文章目录
- 1、tensor作为行索引
- 2、tensor作为列索引
- 3、切片作为索引
- 4、布尔张量作为索引
1、tensor作为行索引
假设我们有一个二维张量 A,它的形状是 4×4:
A = torch.tensor([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]
])
如果我们想用一个索引张量 index 来选择特定的行和列:
index = torch.tensor([0, 2]) # 选取第0行和第2行
result = A[index]
result 将包含 A 的第0行和第2行:
tensor([
[1, 2, 3, 4],
[9, 10, 11, 12]
])
2、tensor作为列索引
import torch
# 创建一个 3x4 的随机张量
A = torch.tensor([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
])
# 选择特定的行
rows = torch.tensor([0, 1, 2]) # 所有行
# 使用索引获取第 1 列元素
column_1 = A[rows, 1] # 指定列索引为 1
print("第 1 列元素:", column_1)
# 指定行和列的索引
rows = torch.tensor([0, 1, 2]) # 所有行
columns = torch.tensor([0, 2]) # 选择第 0 列和第 2 列
# 使用索引获取指定列的元素
result = A[rows.unsqueeze(1), columns] # 需要增加一个维度来广播
print("第 0 列和第 2 列元素:\n", result)
第 1 列元素: tensor([ 2, 6, 10])
第 0 列和第 2 列元素:
tensor([[ 1, 3],
[ 5, 7],
[ 9, 11]])
3、切片作为索引
import torch
# 二维张量 A,它的形状是 4*4
A = torch.tensor([[
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
],
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]
],
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]
]],
[
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
],
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]
],
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]
]]])
print(A.shape)
# 选取第0行和第2行
index = torch.tensor([0, 1,2,3])
index2 = torch.tensor([0,1,2,3])
result = A[:,:,index2, index]
print(result)
'''
torch.Size([2, 3, 4, 4])
tensor([[[ 1, 6, 11, 16],
[ 1, 6, 11, 16],
[ 1, 6, 11, 16]],
[[ 1, 6, 11, 16],
[ 1, 6, 11, 16],
[ 1, 6, 11, 16]]])
'''
4、布尔张量作为索引
A = torch.tensor([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
])
# 创建一个布尔张量,选择所有大于 5 的元素
mask = A > 5
# 使用布尔张量作为索引
filtered_elements = A[mask]
print("原始张量 A:")
print(A)
print("\n布尔张量 mask:")
print(mask)
print("\n过滤后的元素:")
print(filtered_elements)
原始张量 A:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
布尔张量 mask:
tensor([[False, False, False, False],
[False, True, True, True],
[ True, True, True, True]])
过滤后的元素:
tensor([ 6, 7, 8, 9, 10, 11, 12])