
bf16-mixed causing issues with fused AdamW #21435

@vsandwar-sumer


Bug description

bf16-mixed is incompatible with fused optimizers in the Lightning framework when gradient clipping is enabled. After some digging I found PR #15555, opened in response to issue #15501, which makes it seem like this is a safety net introduced by the Lightning team.

Reproduction Code:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import lightning as L

class SimpleModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3, fused=True)


def main():
    X, y = torch.randn(100, 10), torch.randn(100, 1)
    dataloader = DataLoader(TensorDataset(X, y), batch_size=16)

    trainer = L.Trainer(
        accelerator="gpu",
        devices=1,
        precision="16-mixed",  # Using bf16-true passes
        gradient_clip_val=1.0,
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )

    print(f"Scaler: {trainer.precision_plugin.scaler}")

    trainer.fit(SimpleModel(), dataloader)

if __name__ == "__main__":
    main()

Error:

RuntimeError: The current optimizer, AdamW, does not allow for gradient clipping because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?

Expected Solution:

I believe the key error is in the handling below. I'm a newbie to the pytorch-lightning OSS community and its internals, but this check lives in the native-amp code and is what appears in the traceback. I think it needs an additional condition, essentially `and trainer.precision_plugin.scaler is not None` (or however the scaler is accessed in the internal API), which would skip the unnecessary check in the bf16 case.

if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
    raise RuntimeError(
        f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
        " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
    )

The _step_supports_amp_scaling flag only indicates that the optimizer can handle unscaling internally when a GradScaler is present. With bf16 there is no scaler, so there is nothing to unscale and gradient clipping works normally.
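
For reference, the helper gating the error is (to my reading of the native-amp code; paraphrased, not an exact copy) just a check of that attribute:

def _optimizer_handles_unscaling(optimizer) -> bool:
    # True for optimizers like fused AdamW that set _step_supports_amp_scaling;
    # note it does not look at whether a GradScaler actually exists.
    return getattr(optimizer, "_step_supports_amp_scaling", False)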

Lightning's check should be:

if clip_val > 0 and self.scaler is not None and _optimizer_handles_unscaling(optimizer):
    raise RuntimeError(...)

Not just:

if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
    raise RuntimeError(...)
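
In context, the patched check might look roughly like this (a sketch only; the surrounding class and method names are my reading of where the native-amp check lives, not a verified patch):

class MixedPrecision(Precision):  # Lightning's native-amp precision plugin (assumed location)
    def clip_gradients(self, optimizer, clip_val=0.0, gradient_clip_algorithm=None):
        # Proposed: only refuse clipping when a GradScaler is actually in use (fp16).
        # With bf16-mixed, self.scaler is None, so clipping can proceed normally.
        if clip_val > 0 and self.scaler is not None and _optimizer_handles_unscaling(optimizer):
            raise RuntimeError(
                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
            )
        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)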

To verify that this is undesired behavior, here is a PyTorch-only script with the same setup (fused AdamW, bf16 autocast, gradient clipping) that works fine:

import torch
import torch.nn as nn


def main():
    print("PyTorch:", torch.__version__)
    print()

    # Simple model
    model = nn.Linear(10, 1).cuda()

    # Fused AdamW
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

    print(f"_step_supports_amp_scaling: {optimizer._step_supports_amp_scaling}")

    for step in range(10):
        # Random data
        x = torch.randn(16, 10, device="cuda")
        y = torch.randn(16, 1, device="cuda")

        optimizer.zero_grad()

        # bf16 autocast (no GradScaler needed for bf16)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            y_hat = model(x)
            loss = nn.functional.mse_loss(y_hat, y)

        # Backward
        loss.backward()

        # Gradient clipping - works fine!
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Step
        optimizer.step()

if __name__ == "__main__":
    main()
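
Until the check is relaxed, a workaround sketch that I believe sidesteps it (assuming the error is only raised from the precision plugin's clip_gradients path shown above) is to drop gradient_clip_val from the Trainer and clip manually via the configure_gradient_clipping hook:

class SimpleModel(L.LightningModule):
    # ... __init__ / forward / training_step / configure_optimizers as in the repro above ...

    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        # Clip directly instead of calling self.clip_gradients(), so the precision
        # plugin's fused-optimizer check is never reached. With bf16 there is no
        # GradScaler, so the gradients are already unscaled.
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)

with gradient_clip_val=1.0 removed from the Trainer(...) call.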

Personal Setup:
RTX A5000
pytorch-lightning version: 2.5.6

What version are you seeing the problem on?

v2.5


cc @ethanwharris

Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x
