PyTorch基本使用-张量的索引操作
在操作张量时,经常要去获取某些元素进行处理或者修改操作,在这里需要了解torch中的索引操作。
准备数据:
data = torch.randint(0,10,[4,5])
print('data--->',data)
输出结果:
data---> tensor([[3, 9, 4, 0, 5],
[7, 5, 9, 9, 7],
[5, 9, 8, 9, 7],
[9, 2, 6, 7, 7]])
-
简单行、列索引
print('第一行:',data[0]) print('第一列:',data[:,0])
输出结果:
第一行: tensor([3, 9, 4, 0, 5]) 第一列: tensor([3, 7, 5, 9])
-
列表索引
print('-----------------返回(0,1)、(1,2) 2个位置的元素------------------') print(data[[0,1],[1,2]]) print('-----------------返回0、1 行的1、2 列共4个元素------------------') print(data[[[0],[1]],[1,2]])
输出结果:
-----------------返回(0,1)、(1,2) 2个位置的元素------------------ tensor([9, 9]) -----------------返回0、1 行的1、2 列共4个元素------------------ tensor([[9, 4], [5, 9]])
-
范围索引
print('-----------------前3行、前2列的数据------------------') print(data[:3,:2]) print('-----------------第2行到最后的前2列数据------------------') print(data[2:,:2])
输出结果:
-----------------前3行、前2列的数据------------------ tensor([[3, 9], [7, 5], [5, 9]]) -----------------第2行到最后的前2列数据------------------ tensor([[5, 9], [9, 2]])
-
布尔索引
print('-----------------第三列大于5的行数据------------------') print(data[data[:,2] > 5]) print('-----------------第二行大于5的行数据------------------') print(data[:,data[1] > 5])
输出结果:
-----------------第三列大于5的行数据------------------ tensor([[7, 5, 9, 9, 7], [5, 9, 8, 9, 7], [9, 2, 6, 7, 7]]) -----------------第二行大于5的行数据------------------ tensor([[3, 4, 0, 5], [7, 9, 9, 7], [5, 8, 9, 7], [9, 6, 7, 7]])
-
多维索引
data = torch.randint(0,10,[3,4,5]) print(data) # 获取0轴上的第一个数据 print(data[0,:,:]) # 获取1轴上的第一个数据 print(data[:,0,:]) # 获取2轴上的第一个数据 print(data[:,:,0])
输出结果:
tensor([[[8, 3, 6, 1, 5], [5, 0, 4, 3, 8], [8, 3, 3, 5, 0], [6, 4, 0, 8, 4]], [[7, 2, 3, 8, 5], [6, 2, 9, 5, 0], [4, 2, 7, 1, 1], [5, 4, 4, 1, 1]], [[2, 4, 7, 2, 5], [6, 1, 4, 5, 6], [9, 2, 3, 1, 0], [2, 1, 2, 7, 9]]]) tensor([[8, 3, 6, 1, 5], [5, 0, 4, 3, 8], [8, 3, 3, 5, 0], [6, 4, 0, 8, 4]]) tensor([[8, 3, 6, 1, 5], [7, 2, 3, 8, 5], [2, 4, 7, 2, 5]]) tensor([[8, 5, 8, 6], [7, 6, 4, 5], [2, 6, 9, 2]])