fix(models): Fix Perceiver interpolate_pos_encoding interpolating to the source size #44899
zucchini-nlp left a comment
Let's add a test if we don't have one yet.
  position_embeddings = nn.functional.interpolate(
      position_embeddings,
-     size=(new_height, new_width),
+     size=(height, width),
Sounds reasonable. I am not super familiar with Perceiver: do we not need to divide the height/width by the patch size, so that it matches the patched image features?
Sorry, I think I could have made this clearer in the PR description! I think this also explains why it was missed: the prevalent ViT-family fix, which works for most models, was applied here as well, but Perceiver has no patch_size to divide by, so the fix is incorrect here (Perceiver uses conv1x1 with spatial_downsample=1).
→ ViT/DeiT: new_height = height // self.patch_size makes the value in new_height the target.
→ Perceiver: new_height = torch_int(num_positions**0.5) makes it the source grid size, so the same size=(new_height, new_width) pattern that's correct everywhere else becomes a no-op here.
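The difference between the two conventions can be sketched without any dependencies. The snippet below is only an illustration under stated assumptions: `resize_nearest` is a hypothetical nearest-neighbour stand-in for nn.functional.interpolate, `vit_target_grid` and `perceiver_source_grid` are made-up helper names, and `math.isqrt` replaces `torch_int(num_positions**0.5)` to keep the arithmetic exact. It is not the code from modeling_perceiver.py.

```python
import math


def resize_nearest(grid, size):
    """Nearest-neighbour resize of a 2D list.

    Hypothetical stand-in for nn.functional.interpolate, used only to
    show which `size` argument actually changes the grid.
    """
    src_h, src_w = len(grid), len(grid[0])
    dst_h, dst_w = size
    return [
        [grid[r * src_h // dst_h][c * src_w // dst_w] for c in range(dst_w)]
        for r in range(dst_h)
    ]


def vit_target_grid(height, width, patch_size):
    # ViT/DeiT convention: new_height/new_width describe the TARGET patch grid.
    return height // patch_size, width // patch_size


def perceiver_source_grid(num_positions):
    # Perceiver convention: new_height/new_width describe the SOURCE grid,
    # i.e. the side length of the stored position embeddings.
    side = math.isqrt(num_positions)
    return side, side


# A 4x4 grid of stored position embeddings (one scalar per position for brevity).
stored = [[r * 4 + c for c in range(4)] for r in range(4)]
num_positions = 16
height, width = 8, 8  # input image size; there is no patch_size to divide by

new_height, new_width = perceiver_source_grid(num_positions)

# Before the fix: resizing to the source size returns the grid unchanged.
before_fix = resize_nearest(stored, (new_height, new_width))

# After the fix: resizing to the input size actually changes the grid.
after_fix = resize_nearest(stored, (height, width))
```

Here `before_fix` is identical to `stored`, while `after_fix` is an 8x8 grid, which mirrors why passing size=(new_height, new_width) is correct for ViT/DeiT but a no-op for Perceiver.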
Thank you for your time @zucchini-nlp! Checked the test coverage, this behavior is covered by
@harshaljanjani I see, thanks a lot for explaining! Do you mind adding a fast test,
[For maintainers] Suggested jobs to run (before merge): run-slow: perceiver



What does this PR do?
The following failing Perceiver use case was identified and fixed in this PR:
→ c6d2848 (🚨 Fix torch.jit.trace for interpolate_pos_encoding in all vision models) refactored the interpolate_pos_encoding methods of all vision models for torch.jit.trace. The canonical pattern used across other vision models (e.g. modeling_vit.py, modeling_deit.py) passes the target (height, width) to nn.functional.interpolate, but the Perceiver diff passed the source grid dims, practically making the interpolation a no-op. This PR fixes that.
→ I also checked whether other models have the exact same issue, and they don't: they compute new_height = height // self.patch_size (the target patch grid) and pass that.

Fixes #44898
Before the fix (feel free to cross-check; these errors are reproducible):
After the fix (feel free to cross-check):