Using PyTorch, I have a multidimensional tensor A of shape (b, 2, x, y) and a related tensor B of shape (b, 2, x, y, 3).
I want to find the index of the minimum value of A along dim=1 (which has size 2) and use that index tensor to select from B, so that I end up with a tensor of shape (b, x, y, 3).
Using A_mins, indices = torch.min(A, dim=1), I get a tensor indices of shape (b, x, y) whose entries are 0 or 1, depending on which element along dim=1 of A is the minimum. I don't know how to then apply this to B to get the desired output. I am aware that torch.index_select does a similar job, but it only accepts a 1D index vector.
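One possible approach, sketched here under the assumption that a per-element selection along dim=1 is what's wanted, is torch.gather, which (unlike torch.index_select) takes an index tensor with the same number of dimensions as the input. The indices from torch.min are reshaped to (b, 1, x, y, 1) and expanded over the last dimension of B so they line up with it. The helper name select_by_min is hypothetical and only used for illustration.

```python
import torch


def select_by_min(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for each (b, x, y), pick the slice of B along dim=1
    at the position where A is minimal along dim=1.

    A: (b, 2, x, y), B: (b, 2, x, y, 3) -> result: (b, x, y, 3)
    """
    # indices has shape (b, x, y); each entry is 0 or 1 (dtype int64)
    _, indices = torch.min(A, dim=1)

    # Reshape to (b, 1, x, y, 1), then expand the last dim to match B,
    # giving an index tensor of shape (b, 1, x, y, 3) with as many dims as B
    idx = indices.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, B.size(-1))

    # gather along dim=1: out[i, 0, j, k, l] = B[i, idx[i, 0, j, k, l], j, k, l]
    out = torch.gather(B, 1, idx)

    # Drop the singleton selection dim -> (b, x, y, 3)
    return out.squeeze(1)


if __name__ == "__main__":
    A = torch.randn(4, 2, 5, 6)
    B = torch.randn(4, 2, 5, 6, 3)
    print(select_by_min(A, B).shape)  # torch.Size([4, 5, 6, 3])
```

An equivalent alternative is advanced indexing with broadcast arange tensors, e.g. B[torch.arange(b)[:, None, None], indices, torch.arange(x)[None, :, None], torch.arange(y)[None, None, :]], which also yields shape (b, x, y, 3) because the trailing dimension of size 3 is left unindexed.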