Conversation
…radScaler("cuda", args...)`` instead.
There was a problem hiding this comment.
Pull Request Overview
This PR updates the codebase to use the new torch.amp.GradScaler API instead of the deprecated torch.cuda.amp.GradScaler API, following PyTorch's deprecation guidance.
Key Changes:
- Replaced deprecated
torch.cuda.amp.GradScalerimports withtorch.amp.GradScalerusing apartialfunction to set the device to 'cuda' - Applied this change to both production code (
amp_optimizer_wrapper.py) and test code (test_optimizer_wrapper.py)
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| mmengine/optim/optimizer/amp_optimizer_wrapper.py | Updated GradScaler import from deprecated torch.cuda.amp to torch.amp with device parameter |
| tests/test_optim/test_optimizer/test_optimizer_wrapper.py | Updated GradScaler import in tests to use the new torch.amp API |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| from functools import partial | ||
|
|
||
| import unittest | ||
| from unittest import TestCase | ||
| from unittest.mock import MagicMock |
There was a problem hiding this comment.
The blank line after from functools import partial creates inconsistent import grouping. Move the functools import to be with other standard library imports (os, unittest) before the blank line that separates standard library imports from third-party imports.
| from functools import partial | |
| import unittest | |
| from unittest import TestCase | |
| from unittest.mock import MagicMock | |
| import unittest | |
| from unittest import TestCase | |
| from unittest.mock import MagicMock | |
| from functools import partial |
| from parameterized import parameterized | ||
| from torch.cuda.amp import GradScaler | ||
| from torch.amp import GradScaler as amp_GradScaler | ||
| GradScaler = partial(amp_GradScaler, device='cuda') |
There was a problem hiding this comment.
[nitpick] Creating a module-level variable GradScaler through partial assignment makes the code less maintainable and harder to understand. Consider either: (1) using amp_GradScaler('cuda', ...) directly at call sites, or (2) creating a proper wrapper function with a docstring explaining the device binding.
| GradScaler = partial(amp_GradScaler, device='cuda') | |
| def get_cuda_grad_scaler(*args, **kwargs): | |
| """Return a torch.amp.GradScaler instance bound to the 'cuda' device. | |
| Args: | |
| *args: Positional arguments for torch.amp.GradScaler. | |
| **kwargs: Keyword arguments for torch.amp.GradScaler. | |
| Returns: | |
| amp_GradScaler: An instance of GradScaler with device='cuda'. | |
| """ | |
| return amp_GradScaler(*args, device='cuda', **kwargs) |
| GradScaler = partial(amp_GradScaler, device='cuda') | ||
|
|
||
|
|
There was a problem hiding this comment.
[nitpick] Creating a module-level variable GradScaler through partial assignment makes the code less maintainable and harder to understand. Consider either: (1) using amp_GradScaler('cuda', ...) directly at call sites, or (2) creating a proper wrapper function with a docstring explaining the device binding.
| GradScaler = partial(amp_GradScaler, device='cuda') | |
| def get_grad_scaler(*args, **kwargs): | |
| """Create a torch.amp.GradScaler instance bound to device='cuda'. | |
| Args: | |
| *args: Positional arguments passed to torch.amp.GradScaler. | |
| **kwargs: Keyword arguments passed to torch.amp.GradScaler. | |
| Returns: | |
| amp_GradScaler: An instance of torch.amp.GradScaler with device='cuda'. | |
| """ | |
| return amp_GradScaler(*args, device='cuda', **kwargs) |
|
@HAOCHENYE This one is ready to be reviewed. The copilot review suggested a different implementation which also LGTM. You may choose one to merge~~ |
…dScaler` to validate instead.
|
This workaround LGTM, I'll merge it after fixing the lint |
This is a sub-PR of #1665
Brief
According to PyTorch:
This includes two related replacement:
amp_optimizer_wrappertest_optimizer_wrapperPyTest Result
pytest tests/test_optim/test_optimizer/test_optimizer_wrapper.py