
Using PyTorch, I have a multidimensional tensor A of size (b, 2, x, y), and another related tensor B of size (b, 2, x, y, 3).

I want to get the index of the minimum value across dim=1 in A (this dimension is size 2), and apply this index tensor to B so that I would end up with a tensor of shape (b, x, y, 3).

By using A_mins, indices = torch.min(A, dim=1) I am able to get a tensor indices of shape (b, x, y) where the value is either 0 or 1 depending on which is the minimum value across dim=1 in A. I don't know how to then apply this to B to get the desired output. I am aware that torch.index_select does a similar job but only for 1D index vectors.
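To make the shapes concrete, here is a minimal setup. The sizes b=4, x=5, y=6 and the random data are arbitrary assumptions chosen only for illustration. It shows that torch.min with dim=1 returns a (values, indices) pair, with dim=1 reduced away:

```python
import torch

torch.manual_seed(0)
b, x, y = 4, 5, 6  # hypothetical sizes, for illustration only
A = torch.randn(b, 2, x, y)
B = torch.randn(b, 2, x, y, 3)

# torch.min(..., dim=1) returns (values, indices); dim=1 (size 2) is removed.
A_mins, indices = torch.min(A, dim=1)
print(A_mins.shape)   # torch.Size([4, 5, 6])
print(indices.shape)  # torch.Size([4, 5, 6]); every entry is 0 or 1
```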

Answer


I think a more appropriate function here is torch.gather. First apply torch.Tensor.argmin (or, equivalently, torch.Tensor.min) with the keepdim option set to True. Then broadcast the resulting indexer (the reduced A) over the last dimension, since the indexed tensor B has one more dimension than A:

>>> indexer = A.argmin(1, True).unsqueeze(-1).expand(*(-1,)*A.ndim, 3)
>>> out = torch.gather(B, 1, indexer)[:, 0]
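Below is a self-contained sketch of the two lines above, checked against a brute-force loop. The sizes, the random data, and the loop-based reference are illustrative assumptions, not part of the original answer:

```python
import torch

torch.manual_seed(0)
b, x, y = 4, 5, 6  # hypothetical sizes, for illustration only
A = torch.randn(b, 2, x, y)
B = torch.randn(b, 2, x, y, 3)

# (b, 1, x, y): position of the minimum along dim=1, keeping that dim.
idx = A.argmin(1, True)
# (b, 1, x, y, 3): add a trailing axis and broadcast it to B's last dim.
indexer = idx.unsqueeze(-1).expand(*(-1,) * A.ndim, 3)
# gather along dim=1: out[i, 0, j, k, c] = B[i, indexer[i, 0, j, k, c], j, k, c]
out = torch.gather(B, 1, indexer)[:, 0]  # (b, x, y, 3)

# Brute-force reference built with explicit loops.
ref = torch.empty(b, x, y, 3)
for i in range(b):
    for j in range(x):
        for k in range(y):
            ref[i, j, k] = B[i, A[i, :, j, k].argmin(), j, k]

print(out.shape)  # torch.Size([4, 5, 6, 3])
```

The loop is only there to confirm that gather selects, for every (i, j, k), the 3-vector of B at the position where A is smallest.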

In terms of shapes:

  • the indexer tensor will have a shape of (b, 1, x, y, 3), where the last dimension is just a broadcast view of the same index values (we expanded it from a singleton to three channels with torch.Tensor.expand, which does not copy any data).

  • the resulting tensor out will have a shape of (b, x, y, 3) after squeezing the singleton dim=1, either with squeeze(1) or, equivalently, with the [:, 0] indexing.
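The equivalences mentioned above (torch.Tensor.min instead of argmin, squeeze(1) instead of [:, 0]) can be checked directly. As an extra cross-check, the sketch also uses a different technique that is not part of the original answer: advanced indexing with broadcast torch.arange grids, applied to the (b, x, y) indices from the question. Sizes and data are again arbitrary assumptions:

```python
import torch

torch.manual_seed(0)
b, x, y = 4, 5, 6  # hypothetical sizes, for illustration only
A = torch.randn(b, 2, x, y)
B = torch.randn(b, 2, x, y, 3)

# Variant 1: Tensor.min(dim, keepdim=True) returns (values, indices); keep indices.
_, idx = A.min(1, keepdim=True)                          # (b, 1, x, y)
indexer = idx.unsqueeze(-1).expand(-1, -1, -1, -1, 3)    # (b, 1, x, y, 3)
out_min = torch.gather(B, 1, indexer).squeeze(1)         # (b, x, y, 3)

# Variant 2: argmin + [:, 0], exactly as in the answer.
indexer2 = A.argmin(1, True).unsqueeze(-1).expand(*(-1,) * A.ndim, 3)
out_argmin = torch.gather(B, 1, indexer2)[:, 0]

# Cross-check with advanced indexing (a different technique): index the
# first four dims of B with tensors that broadcast to (b, x, y).
_, indices = torch.min(A, dim=1)                         # (b, x, y)
bi = torch.arange(b).view(b, 1, 1)
xi = torch.arange(x).view(1, x, 1)
yi = torch.arange(y).view(1, 1, y)
out_adv = B[bi, indices, xi, yi]                         # (b, x, y, 3)
```

With advanced indexing, the indexed dimensions are replaced by their broadcast shape (b, x, y) and B's untouched last dimension is appended, which gives (b, x, y, 3) directly without any squeeze.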
