Commit e1d6a53

Add validation for consistent metric keys when using sync_dist=True
Adds a check to ensure all processes log the same metric keys in the same order when using `sync_dist=True`. This prevents the silent errors or hangs that can occur when processes attempt to synchronize different sets of metrics.

- Add `_assert_sync_dist_metric_keys_consistency()` to validate metric keys across ranks
- Defer syncing of on_step metrics until collection time to enable validation
- Add `_forward_cache_synced` flag to track whether `_forward_cache` has been synced
1 parent 28f18dc commit e1d6a53
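For context, a minimal sketch of the kind of divergent logging this check is meant to catch; the module, the `_compute_loss` helper, and the metric names are illustrative, not taken from this commit:

```python
from lightning.pytorch import LightningModule


class DivergentLoggingModule(LightningModule):
    """Illustration of the mismatch this commit detects (not part of the diff)."""

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)  # hypothetical helper
        # Every rank logs "train_loss", so this key can be reduced safely.
        self.log("train_loss", loss, sync_dist=True)
        # Only rank 0 also logs "extra_metric": the ranks now issue different
        # collectives, which previously could hang or silently mix up values.
        if self.trainer.global_rank == 0:
            self.log("extra_metric", loss.detach(), sync_dist=True)
        return loss
```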

File tree

8 files changed: +824 −3 lines changed


src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -383,6 +383,8 @@ def _on_evaluation_epoch_end(self) -> None:
     def _store_dataloader_outputs(self) -> None:
         trainer = self.trainer
         trainer._logger_connector.epoch_end_reached()
+        # Sync on_epoch metrics across ranks and validate all ranks logged the same keys
+        trainer._logger_connector.sync_on_epoch_metrics()
         self._logged_outputs.append(trainer._logger_connector.update_eval_epoch_metrics())

     def _on_before_fetch(self) -> None:
@@ -442,6 +444,9 @@ def _evaluation_step(

         self.batch_progress.increment_processed()

+        # Sync on_step metrics across ranks and validate all ranks logged the same keys
+        trainer._logger_connector.sync_on_step_metrics()
+
         if using_dataloader_iter:
             # update the hook kwargs now that the step method might have consumed the iterator
             batch = data_fetcher._batch
```

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -479,6 +479,10 @@ def on_advance_end(self) -> None:
         call._call_lightning_module_hook(trainer, "on_train_epoch_end")
         call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True)

+        # Sync on_epoch metrics across ranks and validate all ranks logged the same keys
+        # Must be called before on_epoch_end() which computes the metrics
+        trainer._logger_connector.sync_on_epoch_metrics()
+
         trainer._logger_connector.on_epoch_end()

         if not self.restarting and self.epoch_loop._num_ready_batches_reached():
@@ -489,6 +493,7 @@ def on_advance_end(self) -> None:
         # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
         # even when the batch loop has finished
         self.epoch_loop._batches_that_stepped -= 1
+
         # log epoch metrics
         trainer._logger_connector.update_train_epoch_metrics()
         self.epoch_loop._batches_that_stepped += 1
```

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -355,6 +355,9 @@ def advance(self, data_fetcher: _DataFetcher) -> None:

         self.batch_progress.increment_processed()

+        # Sync on_step metrics across ranks and validate all ranks logged the same keys
+        trainer._logger_connector.sync_on_step_metrics()
+
         # update non-plateau LR schedulers
         # update epoch-interval ones only when we are at the end of training epoch
         self.update_lr_schedulers("step", update_plateau_schedulers=False)
```

src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -190,6 +190,36 @@ def on_batch_start(self, batch: Any, dataloader_idx: Optional[int] = None) -> None
         results.batch_size = None
         results.dataloader_idx = dataloader_idx

+    def sync_on_step_metrics(self) -> None:
+        """Synchronize on_step metrics across ranks.
+
+        This must be called at a point where ALL ranks are synchronized, typically right after
+        training_step/validation_step returns. It validates that all ranks logged the same metric keys with
+        sync_dist=True and performs the sync operations.
+
+        See https://github.com/Lightning-AI/pytorch-lightning/issues/21409
+
+        """
+        results = self.trainer._results
+        if results is not None:
+            results.sync_on_step_metrics()
+
+    def sync_on_epoch_metrics(self) -> None:
+        """Synchronize on_epoch metrics across ranks.
+
+        This must be called at a point where ALL ranks are synchronized, typically at epoch end before metrics are
+        consumed. It validates that all ranks logged the same metric keys with sync_dist=True and performs the
+        compute/sync operations.
+
+        See https://github.com/Lightning-AI/pytorch-lightning/issues/21409
+
+        """
+        results = self.trainer._results
+        if results is not None:
+            results.sync_on_epoch_metrics()
+
     def epoch_end_reached(self) -> None:
         self._first_loop_iter = None
```

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 140 additions & 2 deletions
```diff
@@ -17,6 +17,7 @@
 from typing import Any, Callable, Optional, Union, cast

 import torch
+import torch.distributed as dist
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torchmetrics import Metric
@@ -46,6 +47,42 @@ class _METRICS(TypedDict):
 warning_cache = WarningCache()


+def _assert_sync_dist_metric_keys_consistency(keys: list[str], fx: str, group: Optional[Any]) -> None:
+    """Validate that all ranks have the same metric keys for sync_dist operations.
+
+    This function must be called at a synchronization point where ALL ranks are guaranteed
+    to participate. It uses all_gather_object to collect keys from all ranks and validates
+    they are identical.
+
+    Args:
+        keys: List of metric keys that need to be synchronized on this rank.
+        fx: The hook name (e.g., 'training_step') for error messages.
+        group: The process group to use for the collective operation.
+
+    Raises:
+        MisconfigurationException: If ranks have different metric keys.
+
+    """
+    if not _distributed_is_initialized() or not dist.is_available():
+        return
+    world_size = dist.get_world_size(group=group)
+    if world_size <= 1:
+        return
+
+    gathered: list[object] = [None] * world_size
+    dist.all_gather_object(gathered, keys, group=group)
+    first = gathered[0]
+    if any(item != first for item in gathered[1:]):
+        ranks = "\n".join(f"  rank={i}: {k}" for i, k in enumerate(gathered))
+        raise MisconfigurationException(
+            "When logging with `sync_dist=True`, all processes must log the same metric keys in the same order "
+            f"within a given hook. Detected a mismatch during `{fx}`.\n"
+            f"Synchronized metric keys per rank:\n{ranks}\n"
+            "Either log the same keys on all ranks (for example by logging dummy values), or set `sync_dist=False` "
+            "and manually synchronize (for example using `all_gather`)."
+        )
+
+
 @dataclass
 class _Sync:
     fn: Optional[Callable] = None
```
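The new helper builds on `torch.distributed.all_gather_object`, which gathers an arbitrary picklable object from every rank into a list. A stripped-down sketch of the same validation pattern outside Lightning (the function name and error type are illustrative):

```python
import torch.distributed as dist


def check_metric_keys_match(keys: list[str]) -> None:
    # Gather each rank's key list; after the call, `gathered[i]` holds rank i's keys.
    world_size = dist.get_world_size()
    gathered: list[object] = [None] * world_size
    dist.all_gather_object(gathered, keys)
    # Every rank receives the same `gathered`, so every rank can raise the same error.
    if any(item != gathered[0] for item in gathered[1:]):
        raise RuntimeError(f"Metric keys differ across ranks: {gathered}")
```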
```diff
@@ -202,6 +239,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
         self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
         # this is defined here only because upstream is missing the type annotation
         self._forward_cache: Optional[Any] = None
+        self._forward_cache_synced: bool = False

     @override
     def update(self, value: _VALUE, batch_size: int) -> None:
@@ -222,7 +260,10 @@ def update(self, value: _VALUE, batch_size: int) -> None:
                 value = value.to(dtype)

             if self.meta.on_step:
-                self._forward_cache = self.meta.sync(value.clone())  # `clone` because `sync` is in-place
+                # Defer sync to sync_on_step_metrics() which is called at a controlled synchronization point
+                # This allows validating that all ranks have the same metric keys before syncing
+                self._forward_cache = value.clone()
+                self._forward_cache_synced = False
                 # performance: no need to accumulate on values only logged on_step
                 if not self.meta.on_epoch:
                     self.value = self._forward_cache
```
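Deferring the reduction matters because the previous path synced inside `update()`, at whatever point each rank happened to log. If the key sets differ, ranks issue different numbers of collectives, which either deadlocks or pairs unrelated reductions together. A toy illustration of that hazard in plain `torch.distributed` (assumes an initialized process group; not Lightning code):

```python
import torch
import torch.distributed as dist


def mismatched_reduces() -> None:
    # All ranks participate in the first all_reduce...
    t = torch.ones(1)
    dist.all_reduce(t)
    # ...but only rank 0 issues a second one. That call never finds matching
    # participants, so the job hangs, or it silently pairs with a later,
    # unrelated collective and corrupts its result.
    if dist.get_rank() == 0:
        dist.all_reduce(t)
```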
```diff
@@ -239,7 +280,7 @@ def update(self, value: _VALUE, batch_size: int) -> None:
                 self.value = self.value + value
         else:
             value = cast(Metric, value)
-            self.value = value
+            self.value = value  # type: ignore[assignment]
             self._forward_cache = value._forward_cache

     @override
@@ -421,6 +462,103 @@ def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
             result_metric.forward(value, batch_size)
             result_metric.has_reset = False

+    def sync_on_step_metrics(self) -> None:
+        """Synchronize all on_step metrics that have sync_dist=True.
+
+        This method must be called at a point where ALL ranks are synchronized (e.g., after
+        training_step/validation_step returns). It:
+
+        1. Gathers all metric keys that need syncing from all ranks
+        2. Validates that all ranks have the same keys in the same order
+        3. Performs the sync operations in a deterministic order
+
+        This approach prevents the silent data corruption that occurs when ranks log different
+        metric keys with sync_dist=True.
+
+        See https://github.com/Lightning-AI/pytorch-lightning/issues/21409
+
+        """
+        if not _distributed_is_initialized():
+            return
+
+        # Collect all metrics that need on_step sync
+        items_to_sync: list[tuple[str, _ResultMetric]] = []
+        for key, result_metric in self.valid_items():
+            if (
+                result_metric.meta.on_step
+                and result_metric.is_tensor
+                and result_metric.meta.sync.should
+                and not result_metric.meta.sync.rank_zero_only
+                and not result_metric._forward_cache_synced
+                and result_metric._forward_cache is not None
+            ):
+                items_to_sync.append((key, result_metric))
+
+        if not items_to_sync:
+            return
+
+        # Get keys in order for validation
+        keys = [key for key, _ in items_to_sync]
+        fx = items_to_sync[0][1].meta.fx
+        group = items_to_sync[0][1].meta.sync.group
+
+        # Validate all ranks have the same keys (this is a collective operation)
+        _assert_sync_dist_metric_keys_consistency(keys, fx, group)
+
+        # Now perform the actual sync for each metric in order
+        for _, result_metric in items_to_sync:
+            if result_metric._forward_cache is not None:
+                synced_value = result_metric.meta.sync(result_metric._forward_cache.clone())
+                result_metric._forward_cache = synced_value
+                result_metric._forward_cache_synced = True
+                # Also update value if this is on_step only (not accumulated for on_epoch)
+                if not result_metric.meta.on_epoch:
+                    result_metric.value = synced_value
+
+    def sync_on_epoch_metrics(self) -> None:
+        """Synchronize all on_epoch metrics that have sync_dist=True.
+
+        This method must be called at a point where ALL ranks are synchronized (e.g., at
+        epoch end before metrics are consumed). It:
+
+        1. Gathers all metric keys that need syncing from all ranks
+        2. Validates that all ranks have the same keys in the same order
+        3. Performs the compute() (which includes sync) in a deterministic order
+
+        This approach prevents the silent data corruption that occurs when ranks log different
+        metric keys with sync_dist=True.
+
+        See https://github.com/Lightning-AI/pytorch-lightning/issues/21409
+
+        """
+        if not _distributed_is_initialized():
+            return
+
+        # Collect all metrics that need on_epoch sync (not yet computed)
+        items_to_sync: list[tuple[str, _ResultMetric]] = []
+        for key, result_metric in self.valid_items():
+            if (
+                result_metric.meta.on_epoch
+                and result_metric.is_tensor
+                and result_metric.meta.sync.should
+                and not result_metric.meta.sync.rank_zero_only
+                and result_metric._computed is None  # Not yet computed/synced
+            ):
+                items_to_sync.append((key, result_metric))
+
+        if not items_to_sync:
+            return
+
+        # Get keys in order for validation
+        keys = [key for key, _ in items_to_sync]
+        fx = items_to_sync[0][1].meta.fx
+        group = items_to_sync[0][1].meta.sync.group
+
+        # Validate all ranks have the same keys (this is a collective operation)
+        _assert_sync_dist_metric_keys_consistency(keys, fx, group)
+
+        # Now perform the actual compute (which includes sync) for each metric in order
+        for _, result_metric in items_to_sync:
+            result_metric.compute()
+
     @staticmethod
     def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
         cache = None
```
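On the user side, the new error message suggests either logging the same keys on every rank (for example by logging dummy values) or disabling `sync_dist` and gathering manually. A hedged sketch of the first option (module, helpers, and metric names are illustrative):

```python
import torch
from lightning.pytorch import LightningModule


class ConsistentLoggingModule(LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)  # hypothetical helper
        self.log("val_loss", loss, sync_dist=True)
        # A metric that is only meaningful on rank 0 is still logged everywhere,
        # with a dummy value elsewhere, so all ranks run the same collectives.
        if self.trainer.global_rank == 0:
            extra = self._rank_zero_metric(batch)  # hypothetical helper
        else:
            extra = torch.tensor(0.0)
        self.log("extra_metric", extra, sync_dist=True)
```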
