Is there a way to compute torch.argmax (or the corresponding maximum value) of my input while excluding a certain index in each row? For example, given:
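target = torch.tensor([1, 2])
input = torch.tensor([[0.1, 0.5, 0.2, 0.2],
                      [0.1, 0.5, 0.1, 0.3]])

I want the maximum value in each row of input, ignoring the column that target gives for that row, so that the result would be:

output = torch.tensor([[0.2], [0.5]])

Row 0 skips index 1, leaving max(0.1, 0.2, 0.2) = 0.2. Row 1 skips index 2, leaving max(0.1, 0.5, 0.3) = 0.5.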
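One possible approach, sketched here under a few assumptions, is to mask the excluded positions with -inf and then take an ordinary row-wise max. The assumptions are that input is a floating-point tensor of shape (N, C), that target holds exactly one column index per row, and that masked_row_max is a hypothetical helper name used only for illustration. Tensor.scatter_ writes -inf into the excluded column of each row. Tensor.max(dim=1, keepdim=True) then returns both the values (the requested output) and the indices (the argmax with those columns excluded).

```python
import torch


def masked_row_max(x: torch.Tensor, target: torch.Tensor):
    """Hypothetical helper: row-wise max of x, skipping column target[i] in row i.

    Assumes x is a float tensor of shape (N, C) and target is a LongTensor of shape (N,).
    """
    # Work on a copy so the caller's tensor is left unchanged.
    masked = x.clone()
    # Put -inf at (i, target[i]) for every row i so that entry can never be the max.
    masked.scatter_(1, target.unsqueeze(1), float("-inf"))
    # Reduce over columns; keepdim=True keeps the (N, 1) shape from the question.
    values, indices = masked.max(dim=1, keepdim=True)
    return values, indices


target = torch.tensor([1, 2])
input = torch.tensor([[0.1, 0.5, 0.2, 0.2],
                      [0.1, 0.5, 0.1, 0.3]])

values, indices = masked_row_max(input, target)
print(values)   # values == [[0.2], [0.5]]
print(indices)  # argmax of each row with the target column excluded
```

Masking with -inf rather than slicing the excluded column out keeps every row the same length, so the whole batch is handled in one vectorized call. The returned indices refer to the original column positions. If a row has ties (row 0 has 0.2 twice), which tied index is reported follows torch.max's own tie-breaking behavior.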