# Demo 1: torch.gather along dim=1, plus three equivalent reconstructions.
# For dim=1, gather computes output[i][j] = input[i][index[i][j]].
import torch

dim = 1
input = torch.tensor([[1, 2, 3],
                      [4, 5, 6]])
index = torch.tensor([[0, 1],
                      [1, 2]])
output_gather = torch.gather(input, dim, index)
print('torch.gather result\n', output_gather)

# Approach 1: advanced indexing with an explicit row-index tensor.
# rows[i][j] == i, so input[rows, index][i][j] == input[i][index[i][j]].
# The row range is sized from index (not input): gather only requires
# index.size(0) <= input.size(0), and expand_as cannot shrink a dimension.
#   torch.arange(index.size(0))              -> tensor([0, 1])
#   torch.arange(index.size(0)).unsqueeze(1) -> tensor([[0],
#                                                       [1]])
#   rows                                     -> tensor([[0, 0],
#                                                       [1, 1]])
rows = torch.arange(index.size(0)).unsqueeze(1).expand_as(index)
output_index = input[rows, index]
print('index result\n', output_index)

# Approach 2: torch.take on the row-major flattened input.
# Element (i, c) of a row-major (R, C) tensor sits at flat offset i * C + c.
#   row offsets: tensor([[0],
#                        [1]]) * 3 -> tensor([[0],
#                                             [3]])
#   linear_index = [[0, 1], [1, 2]] + [[0], [3]] -> tensor([[0, 1],
#                                                           [4, 5]])
linear_index = index + torch.arange(index.size(0)).unsqueeze(1) * input.size(1)
output = torch.take(input, linear_index)
# torch.take already returns a tensor shaped like its indices; the view only
# makes the intended shape explicit.
output_3 = output.view(index.shape)
print('take and flatten result\n', output_3)

# Approach 3: gather each row separately and stack the rows back together.
output_4 = torch.stack([input[i, index[i]] for i in range(index.size(0))])
print('stack result\n', output_4)
# Demo 2: torch.gather along dim=0, plus an equivalent column-wise reconstruction.
# For dim=0, gather computes output[i][j] = input[index[i][j]][j].
import torch

dim = 0
input = torch.tensor([[1, 2, 3],
                      [4, 5, 6]])
index = torch.tensor([[0, 1],
                      [1, 0]])
output_gather = torch.gather(input, dim, index)
print('torch.gather result\n', output_gather)

# Column j of the result is input[index[:, j], j].
# Two details keep this correct for any valid index, not just this
# symmetric example:
#   - It reads the index COLUMN index[:, j]. Using the row index[j] would
#     transpose the index.
#   - It iterates over the result's columns, index.size(1), not over
#     input's row count.
a = [input[index[:, j], j] for j in range(index.size(1))]
# Stacking the columns along dim=1 rebuilds the gathered matrix.
output_4 = torch.stack(a, dim=1)
print('torch.stack result\n', output_4)