I'm struggling with dimensions and matrix multiplication in PyTorch. I want to multiply matrix A
```
tensor([[[104.7500, 111.3750, 138.2500, 144.8750],
         [104.2500, 110.8750, 137.7500, 144.3750]],

        [[356.8750, 363.5000, 390.3750, 397.0000],
         [356.3750, 363.0000, 389.8750, 396.5000]]])
```
with matrix B
```
tensor([[[[  0.,   1.,   2.,   5.,   6.,   7.,  10.,  11.,  12.],
          [  2.,   3.,   4.,   7.,   8.,   9.,  12.,  13.,  14.],
          [ 10.,  11.,  12.,  15.,  16.,  17.,  20.,  21.,  22.],
          [ 12.,  13.,  14.,  17.,  18.,  19.,  22.,  23.,  24.]],

         [[ 25.,  26.,  27.,  30.,  31.,  32.,  35.,  36.,  37.],
          [ 27.,  28.,  29.,  32.,  33.,  34.,  37.,  38.,  39.],
          [ 35.,  36.,  37.,  40.,  41.,  42.,  45.,  46.,  47.],
          [ 37.,  38.,  39.,  42.,  43.,  44.,  47.,  48.,  49.]],

         [[ 50.,  51.,  52.,  55.,  56.,  57.,  60.,  61.,  62.],
          [ 52.,  53.,  54.,  57.,  58.,  59.,  62.,  63.,  64.],
          [ 60.,  61.,  62.,  65.,  66.,  67.,  70.,  71.,  72.],
          [ 62.,  63.,  64.,  67.,  68.,  69.,  72.,  73.,  74.]]],


        [[[ 75.,  76.,  77.,  80.,  81.,  82.,  85.,  86.,  87.],
          [ 77.,  78.,  79.,  82.,  83.,  84.,  87.,  88.,  89.],
          [ 85.,  86.,  87.,  90.,  91.,  92.,  95.,  96.,  97.],
          [ 87.,  88.,  89.,  92.,  93.,  94.,  97.,  98.,  99.]],

         [[100., 101., 102., 105., 106., 107., 110., 111., 112.],
          [102., 103., 104., 107., 108., 109., 112., 113., 114.],
          [110., 111., 112., 115., 116., 117., 120., 121., 122.],
          [112., 113., 114., 117., 118., 119., 122., 123., 124.]],

         [[125., 126., 127., 130., 131., 132., 135., 136., 137.],
          [127., 128., 129., 132., 133., 134., 137., 138., 139.],
          [135., 136., 137., 140., 141., 142., 145., 146., 147.],
          [137., 138., 139., 142., 143., 144., 147., 148., 149.]]]])
```
However, using the simple @ to multiply them doesn't lead me to the desired result. What I want is something like: multiply the first two rows of A by the first 3 4x9 submatrices of B (let's say B[:,:,0,:]) so that I have two results, then in the same way multiply the third and fourth rows of A with the second 3 4x9 submatrices of B, so that I again have two results; then I want to sum the first results of each multiplication and the second results of each. I know I have to work with some kind of reshapes, but I find it so confusing. Can you help me with a fairly generalizable solution?
You can do this in a single call with torch.einsum.
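Reading your description as "multiply each row of A[i] with each of the three 4x9 blocks of B[i], then sum the per-row results over the two batches", here is a sketch using dummy tensors with the same shapes as yours. The subscript string 'ijk,iakl->jal' encodes that interpretation, so treat it as an assumption and adjust the output indices if you want a different layout:

```python
import torch

# Same shapes as in the question:
# A: (2, 2, 4)    -- 2 batches, each with 2 rows of 4 values
# B: (2, 3, 4, 9) -- 2 batches, each with 3 blocks of shape 4x9
A = torch.arange(2 * 2 * 4, dtype=torch.float32).reshape(2, 2, 4)
B = torch.arange(2 * 3 * 4 * 9, dtype=torch.float32).reshape(2, 3, 4, 9)

# out[j, a, l] = sum over i (batch) and k (columns of A) of A[i, j, k] * B[i, a, k, l]
# i.e. row j of A[i] is multiplied with every 4x9 block B[i, a],
# and the two per-batch results are summed over i.
out = torch.einsum('ijk,iakl->jal', A, B)
print(out.shape)  # torch.Size([2, 3, 9])

# Equivalent with a broadcasting matmul plus a sum, if einsum feels opaque:
out2 = (A.unsqueeze(1) @ B).sum(dim=0)          # (3, 2, 9)
assert torch.allclose(out, out2.permute(1, 0, 2))
```

The same subscript string keeps working for any batch size, number of blocks, or block shape, so it generalizes without extra reshapes.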