
I wonder whether I can get the torch.argmax of my input while excluding certain indices. For example,

target = torch.tensor([1, 2])
input = torch.tensor([[0.1, 0.5, 0.2, 0.2],
                      [0.1, 0.5, 0.1, 0.3]])

I want the maximum value of each row of input, excluding the column index that target gives for that row, so that the result would be

output = torch.tensor([[0.2],[0.5]]) 
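For reference, the intended result can be spelled out with a plain (non-vectorized) loop. This is a sketch assuming target holds one 0-based column index per row of input; the helper name max_excluding_loop is hypothetical.

```python
import torch

def max_excluding_loop(input, target):
    # Hypothetical helper: for each row, drop column target[i] and take the max of the rest.
    rows = []
    for i, row in enumerate(input):
        keep = [j for j in range(row.numel()) if j != int(target[i])]
        rows.append(row[keep].max())
    # Stack the per-row maxima and add a trailing dim to get shape (N, 1).
    return torch.stack(rows).unsqueeze(1)

target = torch.tensor([1, 2])
input = torch.tensor([[0.1, 0.5, 0.2, 0.2],
                      [0.1, 0.5, 0.1, 0.3]])
print(max_excluding_loop(input, target))  # values 0.2 and 0.5, shape (2, 1)
```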

2 Answers


You can try this:

  • Set the target indices to negative infinity in a temporary copy of the tensor.
  • Then use torch.max or torch.argmax.
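tmp_input = input.clone()  # copy so input itself is not modified
tmp_input[range(len(input)), target] = float("-Inf")  # mask column target[i] in row i
torch.max(tmp_input, dim=1).values
# tensor([0.2000, 0.5000])
torch.max(tmp_input, dim=1).indices
# tensor([3, 1])  (row 0 ties at 0.2, so the index reported may depend on the PyTorch version)
torch.argmax(tmp_input, dim=1)
# tensor([3, 1])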
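A self-contained sketch of the same masking idea, wrapped in a hypothetical helper max_excluding. It assumes input is a floating-point tensor (so it can hold -inf) and that target holds one 0-based column index per row. Passing keepdim=True to torch.max keeps the reduced dimension, which gives the (N, 1) shape shown in the question.

```python
import torch

def max_excluding(input, target):
    # Work on a copy so the caller's tensor is left untouched.
    masked = input.clone()
    # Advanced indexing pairs row i with column target[i].
    rows = torch.arange(input.size(0))
    masked[rows, target] = float("-inf")
    # keepdim=True returns values and indices with shape (N, 1).
    values, indices = torch.max(masked, dim=1, keepdim=True)
    return values, indices

target = torch.tensor([1, 2])
input = torch.tensor([[0.1, 0.5, 0.2, 0.2],
                      [0.1, 0.5, 0.1, 0.3]])
values, indices = max_excluding(input, target)
print(values)   # holds 0.2 and 0.5, shape (2, 1)
print(indices)  # row 0 ties at 0.2, so its index may be 2 or 3
```

Cloning costs one extra copy of the tensor; if input is not needed afterwards, the masking can be done in place instead.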

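# Mask column target[i] in row i; target already holds 0-based column indices.
input[torch.arange(len(input)), target] = -1  # or float("-inf"); -1 only works if all values exceed -1
output = torch.max(input, dim=1)  # (values, indices); note that input is modified in place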
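Why -inf is the safer fill value: a masked entry only stays out of the max if its fill value is smaller than every real value in the row. This sketch uses hypothetical all-negative data (not from the question) where a -1 fill wins the max and -inf does not.

```python
import torch

# Hypothetical rows whose values are all below -1.
x = torch.tensor([[-3.0, -2.0, -5.0],
                  [-4.0, -6.0, -2.5]])
target = torch.tensor([1, 2])
rows = torch.arange(x.size(0))

with_minus_one = x.clone()
with_minus_one[rows, target] = -1            # larger than every real value here
with_neg_inf = x.clone()
with_neg_inf[rows, target] = float("-inf")   # can never be the maximum

print(torch.max(with_minus_one, dim=1).indices)  # tensor([1, 2]): the masked columns win
print(torch.max(with_neg_inf, dim=1).indices)    # tensor([0, 0])
```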
