littlebullGit (Contributor) commented Dec 20, 2025

What does this PR do?

Fixes #21409

When different ranks log different metric keys with sync_dist=True, the all_reduce operations become mismatched across ranks, causing silent data corruption where values from different metrics get averaged together.
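
For illustration, a `training_step` along these lines (module and helper names are made up) reproduces the failure mode: each rank issues an `all_reduce` for a key that the other rank never logs, so reductions pair up with the wrong metrics:

```python
import lightning.pytorch as pl

class BuggyModule(pl.LightningModule):  # illustrative only, not a complete module
    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)  # hypothetical helper
        self.log("loss", loss, sync_dist=True)
        if self.global_rank == 0:
            # rank 0 asks to reduce "metric_a" while rank 1 asks to reduce "metric_b"
            self.log("metric_a", loss.detach(), sync_dist=True)
        else:
            self.log("metric_b", loss.detach(), sync_dist=True)
        return loss
```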

Solution

This PR implements key consistency validation at controlled synchronization points:

  1. Added `_assert_sync_dist_metric_keys_consistency()` function - uses `all_gather_object` to collect the metric keys from all ranks and validates that they are identical. Raises a clear `MisconfigurationException` if a mismatch is detected (a sketch of this check follows the list below).

  2. Added sync_on_step_metrics() method in _ResultCollection - Called after each training/validation step completes. For on_step=True metrics:

    • Defers the sync from update() time to a controlled sync point
    • Validates all ranks have the same keys
    • Performs sync in deterministic order
  3. Added sync_on_epoch_metrics() method in _ResultCollection - Called at epoch end before metrics are consumed. For on_epoch=True metrics:

    • Validates all ranks have the same keys
    • Performs compute() (which includes sync) in deterministic order
  4. Hook points in loops:

    • Training loop: sync_on_step_metrics() after each step, sync_on_epoch_metrics() before epoch metrics are consumed
    • Evaluation loop: sync_on_step_metrics() after each step, sync_on_epoch_metrics() before dataloader outputs are stored
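
As a minimal sketch of the check described in item 1 (the actual implementation in this PR may differ in signature and wording), the validation can be built directly on `torch.distributed.all_gather_object`:

```python
import torch.distributed as dist
from lightning.pytorch.utilities.exceptions import MisconfigurationException

def _assert_sync_dist_metric_keys_consistency(keys: list, hook_name: str) -> None:
    # No-op when there is no initialized process group (e.g. single-device runs).
    if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1:
        return
    gathered = [None] * dist.get_world_size()
    # Collective call: every rank must reach this point with its own key list.
    dist.all_gather_object(gathered, list(keys))
    if any(other != gathered[0] for other in gathered[1:]):
        per_rank = " ".join(f"rank={i}: {k}" for i, k in enumerate(gathered))
        raise MisconfigurationException(
            "When logging with `sync_dist=True`, all processes must log the same metric keys "
            f"in the same order within a given hook. Detected a mismatch during `{hook_name}`. "
            f"Synchronized metric keys per rank: {per_rank}"
        )
```

Because `all_gather_object` is itself a collective, the check can only run at points that every rank is guaranteed to reach, which is why it is hooked into the per-step and per-epoch sync methods rather than into `log()` itself.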

Key Design Decisions

  • Validation happens at controlled synchronization points where ALL ranks are guaranteed to be at the same place
  • For on_step metrics, sync is deferred via the `_forward_cache_synced` flag until after the step completes (see the sketch after this list)
  • For on_epoch metrics, validation happens before compute() is called
  • Clear error message shows exactly which keys differ on which ranks
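
To make the deferral concrete, here is a simplified sketch of the shape of `sync_on_step_metrics()`. Apart from the method name and the `_forward_cache_synced` flag mentioned above, the attribute names (`on_step`, `should_sync`, `forward_cache`, `sync_fn`) are placeholders, not the PR's actual internals:

```python
class _ResultCollectionSketch(dict):
    """Simplified stand-in for `_ResultCollection`; attribute names are assumptions."""

    def sync_on_step_metrics(self, hook_name: str) -> None:
        # Only metrics logged with on_step=True and sync_dist=True take part in the collectives.
        keys = [key for key, metric in self.items() if metric.on_step and metric.should_sync]
        _assert_sync_dist_metric_keys_consistency(keys, hook_name)  # check sketched above
        for key in keys:  # identical, validated key order on every rank => matched all_reduces
            metric = self[key]
            if metric.forward_cache is not None and not metric._forward_cache_synced:
                metric.forward_cache = metric.sync_fn(metric.forward_cache)  # e.g. mean all_reduce
                metric._forward_cache_synced = True
```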

Example Error Message

MisconfigurationException: When logging with `sync_dist=True`, all processes must log the same metric keys in the same order within a given hook. Detected a mismatch during `training_step`. Synchronized metric keys per rank:
rank=0: ['training_step.loss', 'training_step.metric_a']
rank=1: ['training_step.loss', 'training_step.metric_b']
Either log the same keys on all ranks (for example by logging dummy values), or set `sync_dist=False` and manually synchronize (for example using `all_gather`).
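
Following the advice in the message, either of the patterns below avoids the mismatch (module, helper, and metric names are illustrative):

```python
import torch
import lightning.pytorch as pl

class FixedModule(pl.LightningModule):  # illustrative only, not a complete module
    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)  # hypothetical helper
        self.log("loss", loss, sync_dist=True)

        # Option 1: every rank logs the same key; ranks without a real value log a
        # placeholder (note that the placeholder participates in the mean reduction).
        metric_a = self._metric_a(batch) if self.global_rank == 0 else torch.zeros((), device=self.device)
        self.log("metric_a", metric_a, sync_dist=True)

        # Option 2: opt out of automatic reduction and synchronize manually.
        # raw = self._metric_b(batch)
        # self.log("metric_b", raw, sync_dist=False)
        # gathered = self.all_gather(raw)  # shape: (world_size, *raw.shape)
        return loss
```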


📚 Documentation preview 📚: https://pytorch-lightning--21434.org.readthedocs.build/en/21434/

github-actions bot added the pl (Generic label for PyTorch Lightning package) label on Dec 20, 2025
Adds a check to ensure all processes log the same metric keys in the same order when using `sync_dist=True`. This prevents silent errors or hangs that can occur when processes attempt to synchronize different sets of metrics.

- Add `_assert_sync_dist_metric_keys_consistency()` to validate metric keys across ranks
- Defer syncing of on_step metrics until collection time to enable validation
- Add `_forward_cache_synced` flag to track sync

littlebullGit force-pushed the fix/21409-ddp-metric-key-mismatch branch from 6361e70 to e1d6a53 on December 20, 2025 04:16

codecov bot commented Dec 20, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 79%. Comparing base (28f18dc) to head (aa58ed8).
⚠️ Report is 3 commits behind head on master.
✅ All tests successful. No failed tests found.

❗ There is a different number of reports uploaded between BASE (28f18dc) and HEAD (aa58ed8).

HEAD has 120 uploads less than BASE:

| Flag | BASE (28f18dc) | HEAD (aa58ed8) |
|------|----------------|----------------|
| cpu | 60 | 30 |
| lightning_fabric | 15 | 0 |
| pytest | 30 | 0 |
| python3.12 | 18 | 9 |
| lightning | 30 | 15 |
| python3.11 | 12 | 6 |
| python3.10 | 6 | 3 |
| python3.12.7 | 18 | 9 |
| python | 6 | 3 |
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #21434     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         269      267      -2     
  Lines       23853    24066    +213     
=========================================
- Hits        20667    19014   -1653     
- Misses       3186     5052   +1866     

…validation

- Add tests for non-tensor (TorchMetric) path in sync methods
- Add tests for value update logic in sync_on_step_metrics
- Add tests for skip conditions (already synced, already computed, None cache)
- Add tests for multiple metrics sync
- Add DDP tests for value update and combined step+epoch sync
- Add detailed error message validation test
- Add initialization and basic behavior tests for _forward_cache_synced

Reduce the number of spawn_launch calls from 11 to 3 by combining
related DDP tests into single spawn functions. This prevents
segmentation faults that can occur when many processes are spawned
in quick succession during test runs with random order.

Development

Successfully merging this pull request may close these issues.

DDP Validation Metric Logging got Misplaced Silently When Processes Different Metric Keys on Different Devices
