Given a tensor A of shape (N, C) and an index tensor Idx of shape (N,), I'd like to sum the elements of each row of A, excluding the element in the column given by the corresponding entry of Idx. For example:

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
Idx = torch.tensor([0, 2])
# result: torch.tensor([[5], [9]])

I already know how to do this with a loop; I'm looking for a solution without one.
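For reference, a loop-based version might look like the sketch below. The asker's actual loop isn't shown, so this is only an assumed baseline; the function name `sum_excluding_loop` is hypothetical. It computes each row's total and subtracts the one excluded element.

```python
import torch

def sum_excluding_loop(A: torch.Tensor, Idx: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference implementation: one Python iteration per row.
    out = torch.empty(A.shape[0], 1, dtype=A.dtype, device=A.device)
    for i in range(A.shape[0]):
        # Row total minus the single excluded element A[i, Idx[i]].
        out[i, 0] = A[i].sum() - A[i, Idx[i]]
    return out

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
Idx = torch.tensor([0, 2])
print(sum_excluding_loop(A, Idx))  # tensor([[5], [9]])
```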

1 Answer


You can set the excluded elements to zero with advanced indexing (note that this modifies A in place):

A[torch.arange(A.shape[0]), Idx] = 0  # pairs row i with column Idx[i]

and then sum each row:

b = A.sum(dim=1, keepdim=True)  # b = torch.tensor([[5], [9]])
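If the original A must be preserved, a minimal sketch is to apply the same zero-and-sum idea to a copy. A second option, not part of the answer above, subtracts the excluded element (picked out with `torch.gather`) from the full row sum. Both function names are hypothetical.

```python
import torch

def sum_excluding_masked(A: torch.Tensor, Idx: torch.Tensor) -> torch.Tensor:
    # Same idea as the answer, but on a copy so the caller's A is untouched.
    B = A.clone()
    B[torch.arange(B.shape[0], device=B.device), Idx] = 0  # zero A[i, Idx[i]] per row
    return B.sum(dim=1, keepdim=True)

def sum_excluding_gather(A: torch.Tensor, Idx: torch.Tensor) -> torch.Tensor:
    # Alternative: full row sum minus the excluded element, selected with gather.
    excluded = A.gather(1, Idx.unsqueeze(1))  # shape (N, 1); Idx must be int64
    return A.sum(dim=1, keepdim=True) - excluded

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
Idx = torch.tensor([0, 2])
print(sum_excluding_masked(A, Idx))  # tensor([[5], [9]])
print(sum_excluding_gather(A, Idx))  # tensor([[5], [9]])
print(A)                             # unchanged
```

The subtraction variant avoids copying A, but if an excluded entry is `inf` or `NaN`, the row result becomes `NaN`. The masking variant simply drops that entry.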