
fix: replace deprecated torch.cuda.amp with torch.amp #445

Open
haoyu-haoyu wants to merge 1 commit into RosettaCommons:main from haoyu-haoyu:fix/replace-deprecated-torch-cuda-amp

Conversation

@haoyu-haoyu

Summary

  • Replace all `torch.cuda.amp.autocast` with `torch.amp.autocast('cuda', ...)` (3 instances)
  • Replace `torch.cuda.amp.GradScaler` with `torch.amp.GradScaler('cuda', ...)` (1 instance)

The `torch.cuda.amp.*` APIs were deprecated in PyTorch 1.13 (see the migration guide) and emit a `FutureWarning` in PyTorch 2.4+. The new `torch.amp.*` equivalents require an explicit `device_type` argument.
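For codebases that must support both PyTorch generations, the replacement can be guarded behind a small compatibility shim rather than an unconditional call. The sketch below is hypothetical and not part of this PR; `pick_autocast` is an illustrative helper, and the demo uses stub modules in place of real PyTorch so the version-selection logic is visible on its own.

```python
# Hypothetical compatibility shim (not part of this PR): prefer the new
# torch.amp.autocast('cuda', ...) API when it exists, and fall back to the
# deprecated torch.cuda.amp.autocast otherwise.
from types import SimpleNamespace

def pick_autocast(torch_module):
    """Return a callable that builds a CUDA autocast context for this torch."""
    amp = getattr(torch_module, "amp", None)
    if amp is not None and hasattr(amp, "autocast"):
        # New API: the device type is an explicit argument.
        return lambda **kwargs: amp.autocast("cuda", **kwargs)
    # Old API: the device is implied by the torch.cuda.amp namespace.
    return torch_module.cuda.amp.autocast

# Demo with stubs standing in for the two PyTorch generations.
new_torch = SimpleNamespace(
    amp=SimpleNamespace(autocast=lambda device_type, **kw: ("new", device_type, kw))
)
old_torch = SimpleNamespace(
    amp=None,
    cuda=SimpleNamespace(amp=SimpleNamespace(autocast=lambda **kw: ("old", kw))),
)

print(pick_autocast(new_torch)(enabled=False))  # ('new', 'cuda', {'enabled': False})
print(pick_autocast(old_torch)(enabled=False))  # ('old', {'enabled': False})
```

The PR itself takes the simpler route of replacing the calls outright, which is reasonable since RFdiffusion already pins a recent PyTorch.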

Files changed

| File | Change |
| --- | --- |
| rfdiffusion/Track_module.py:236 | `@torch.cuda.amp.autocast(enabled=False)` → `@torch.amp.autocast('cuda', enabled=False)` |
| env/SE3Transformer/.../inference.py:52 | context manager |
| env/SE3Transformer/.../training.py:93 | context manager |
| env/SE3Transformer/.../training.py:130 | GradScaler constructor |

Test plan

  • Verify no FutureWarning with python -W error::FutureWarning
  • Run existing test suite (tests/test_diffusion.py)
  • Confirm inference output is unchanged
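The first check in the test plan relies on Python's warning filters: `-W error::FutureWarning` promotes any `FutureWarning` to an exception, so a surviving deprecated call fails loudly. A minimal sketch of that mechanism, using only the `warnings` module (no PyTorch required; `call_deprecated_api` is a stand-in for a `torch.cuda.amp.*` call):

```python
# Sketch of how `-W error::FutureWarning` turns deprecation warnings into hard
# failures. No PyTorch needed: we emit the same FutureWarning category that
# torch.cuda.amp.* emits on PyTorch 2.4+.
import warnings

def call_deprecated_api():
    # Stand-in for a leftover torch.cuda.amp.autocast call.
    warnings.warn("torch.cuda.amp.autocast(args...) is deprecated.", FutureWarning)
    return "result"

with warnings.catch_warnings():
    # Equivalent to running: python -W error::FutureWarning script.py
    warnings.simplefilter("error", FutureWarning)
    try:
        call_deprecated_api()
        escaped = True   # the deprecated call went unnoticed
    except FutureWarning:
        escaped = False  # the filter caught it

print(escaped)  # False: the deprecated call now raises instead of warning
```

With the PR applied, the same run should complete without raising, since no `torch.cuda.amp.*` call remains to emit the warning.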
Commit message

`torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler` were deprecated in PyTorch 1.13 and will be removed in a future release. Replace them with the device-explicit `torch.amp.autocast('cuda', ...)` and `torch.amp.GradScaler('cuda', ...)` equivalents.

Files changed:
  • rfdiffusion/Track_module.py (decorator on Str2Str.forward)
  • env/SE3Transformer/se3_transformer/runtime/inference.py
  • env/SE3Transformer/se3_transformer/runtime/training.py (2 instances)
