Bug description
bfloat16-mixed is incompatible with fused optimizers when gradient clipping is enabled in the Lightning framework. After some digging I found PR #15555, a response to issue #15501, which makes it seem like this is a safety net introduced by the Lightning team.
Reproduction Code:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import lightning as L


class SimpleModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3, fused=True)


def main():
    X, y = torch.randn(100, 10), torch.randn(100, 1)
    dataloader = DataLoader(TensorDataset(X, y), batch_size=16)
    trainer = L.Trainer(
        accelerator="gpu",
        devices=1,
precision="16-mixed", # Using bf16-true passes
        gradient_clip_val=1.0,
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )
    print(f"Scaler: {trainer.precision_plugin.scaler}")
    trainer.fit(SimpleModel(), dataloader)


if __name__ == "__main__":
    main()
Error:
RuntimeError: The current optimizer, AdamW, does not allow for gradient clipping because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?
Expected Solution:
I believe the key issue is in the check below. I'm new to the pytorch-lightning OSS community and its internals, but this is in the native-amp code and is what appears in the traceback. I think it needs an additional condition, essentially and trainer.precision_plugin.scaler is not None (or however the scaler is accessed in the internal API), which would skip the unnecessary guard in the bf16 case.
if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
    raise RuntimeError(
        f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
        " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
    )
The _step_supports_amp_scaling flag just indicates that the optimizer can handle unscaling internally when a GradScaler is present. With bf16 (no scaler), there is nothing to unscale, so gradient clipping works normally.
Lightning's check should be:
if clip_val > 0 and self.scaler is not None and _optimizer_handles_unscaling(optimizer):
    raise RuntimeError(...)
Not just:
if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
    raise RuntimeError(...)
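For illustration, here is a minimal standalone sketch of the proposed guard, not the actual Lightning source. check_clipping_allowed is a hypothetical stand-in for the plugin method, and _optimizer_handles_unscaling is reproduced under the assumption that it simply reads the optimizer's _step_supports_amp_scaling attribute:

from typing import Optional

import torch
from torch.optim import Optimizer


def _optimizer_handles_unscaling(optimizer: Optimizer) -> bool:
    # Assumption: mirrors Lightning's helper, which checks the private flag that
    # fused/capturable optimizers set to signal that step() unscales gradients
    # itself when a GradScaler is passed to it.
    return getattr(optimizer, "_step_supports_amp_scaling", False)


def check_clipping_allowed(
    optimizer: Optimizer,
    clip_val: float,
    scaler: Optional["torch.amp.GradScaler"],
) -> None:
    # Proposed behavior: only refuse clipping when a scaler actually exists,
    # i.e. when gradients really would be unscaled inside optimizer.step().
    # With bf16-mixed the scaler is None, so clipping proceeds as usual.
    if clip_val > 0 and scaler is not None and _optimizer_handles_unscaling(optimizer):
        raise RuntimeError(
            f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient"
            " clipping because it performs unscaling of gradients internally."
            " HINT: Are you using a 'fused' optimizer?"
        )

With this version, check_clipping_allowed(fused_adamw, 1.0, scaler=None) is a no-op, while passing an actual GradScaler still raises as before.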
To verify that this is undesired behavior, here is a PyTorch-only script that runs fine:
import torch
import torch.nn as nn


def main():
    print("PyTorch:", torch.__version__)
    print()

    # Simple model
    model = nn.Linear(10, 1).cuda()

    # Fused AdamW
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)
    print(f"_step_supports_amp_scaling: {optimizer._step_supports_amp_scaling}")

    for step in range(10):
        # Random data
        x = torch.randn(16, 10, device="cuda")
        y = torch.randn(16, 1, device="cuda")

        optimizer.zero_grad()

        # bf16 autocast (no GradScaler needed for bf16)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            y_hat = model(x)
            loss = nn.functional.mse_loss(y_hat, y)

        # Backward
        loss.backward()

        # Gradient clipping - works fine!
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Step
        optimizer.step()


if __name__ == "__main__":
    main()
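Until the check is changed, a possible interim workaround (just a sketch, untested; it reuses SimpleModel from the reproduction above and the standard on_before_optimizer_step hook) is to leave gradient_clip_val unset on the Trainer, so the plugin-level check never fires, and clip manually right before the optimizer step:

import torch


class ClippedModel(SimpleModel):  # SimpleModel from the reproduction above
    def on_before_optimizer_step(self, optimizer):
        # With bf16-mixed there is no GradScaler, so gradients arrive here in
        # real scale and can be clipped directly.
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)


# Usage (same dataloader as above; note there is no gradient_clip_val):
# trainer = L.Trainer(accelerator="gpu", devices=1, precision="bf16-mixed", max_epochs=1)
# trainer.fit(ClippedModel(), dataloader)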
Personal Setup:
RTX A5000
pytorch-lightning version: 2.5.6
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
Error messages and logs
Environment
More info
No response