1717from typing import Any , Callable , Optional , Union , cast
1818
1919import torch
20+ import torch .distributed as dist
2021from lightning_utilities .core .apply_func import apply_to_collection
2122from torch import Tensor
2223from torchmetrics import Metric
@@ -46,6 +47,42 @@ class _METRICS(TypedDict):
4647warning_cache = WarningCache ()
4748
4849
def _assert_sync_dist_metric_keys_consistency(keys: list[str], fx: str, group: Optional[Any]) -> None:
    """Validate that all ranks have the same metric keys for sync_dist operations.

    This function must be called at a synchronization point where ALL ranks are guaranteed
    to participate. It uses ``all_gather_object`` to collect the key list of every rank in
    ``group`` and compares each of them against the list reported by the first rank of the
    group; both the keys and their order have to match.

    The check is skipped (no collective is issued) when ``torch.distributed`` is not available,
    when ``_distributed_is_initialized()`` reports false, or when the group has a single process.

    Args:
        keys: List of metric keys that need to be synchronized on this rank.
        fx: The hook name (e.g., 'training_step') for error messages.
        group: The process group to use for the collective operation.

    Raises:
        MisconfigurationException: If ranks have different metric keys.

    """
    # Cheap availability probe first; both conditions lead to the same early exit.
    if not dist.is_available() or not _distributed_is_initialized():
        return
    world_size = dist.get_world_size(group=group)
    if world_size <= 1:
        return

    # One slot per rank; ``all_gather_object`` fills the list in place.
    gathered: list[object] = [None] * world_size
    dist.all_gather_object(gathered, keys, group=group)

    reference = gathered[0]
    if all(item == reference for item in gathered[1:]):
        return

    # Every rank sees the same gathered list, so every rank raises the same error.
    ranks = "\n".join(f"  rank={i}: {k}" for i, k in enumerate(gathered))
    raise MisconfigurationException(
        "When logging with `sync_dist=True`, all processes must log the same metric keys in the same order "
        f"within a given hook. Detected a mismatch during `{fx}`.\n"
        f"Synchronized metric keys per rank:\n{ranks}\n"
        "Either log the same keys on all ranks (for example by logging dummy values), or set `sync_dist=False` "
        "and manually synchronize (for example using `all_gather`)."
    )
84+
85+
4986@dataclass
5087class _Sync :
5188 fn : Optional [Callable ] = None
@@ -202,6 +239,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
202239 self .add_state ("cumulated_batch_size" , torch .tensor (0 ), dist_reduce_fx = torch .sum )
203240 # this is defined here only because upstream is missing the type annotation
204241 self ._forward_cache : Optional [Any ] = None
242+ self ._forward_cache_synced : bool = False
205243
206244 @override
207245 def update (self , value : _VALUE , batch_size : int ) -> None :
@@ -222,7 +260,10 @@ def update(self, value: _VALUE, batch_size: int) -> None:
222260 value = value .to (dtype )
223261
224262 if self .meta .on_step :
225- self ._forward_cache = self .meta .sync (value .clone ()) # `clone` because `sync` is in-place
263+ # Defer sync to sync_on_step_metrics() which is called at a controlled synchronization point
264+ # This allows validating that all ranks have the same metric keys before syncing
265+ self ._forward_cache = value .clone ()
266+ self ._forward_cache_synced = False
226267 # performance: no need to accumulate on values only logged on_step
227268 if not self .meta .on_epoch :
228269 self .value = self ._forward_cache
@@ -239,7 +280,7 @@ def update(self, value: _VALUE, batch_size: int) -> None:
239280 self .value = self .value + value
240281 else :
241282 value = cast (Metric , value )
242- self .value = value
283+ self .value = value # type: ignore[assignment]
243284 self ._forward_cache = value ._forward_cache
244285
245286 @override
@@ -421,6 +462,103 @@ def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
421462 result_metric .forward (value , batch_size )
422463 result_metric .has_reset = False
423464
def sync_on_step_metrics(self) -> None:
    """Run the deferred ``sync_dist`` reduction for every pending on_step tensor metric.

    Intended to be called at a point where ALL ranks are synchronized (e.g., after
    training_step/validation_step returns). The method:

    1. Selects the metrics logged with ``on_step`` that are tensors, request syncing, are not
       ``rank_zero_only``, and hold a forward cache that has not been synced yet.
    2. Checks via ``_assert_sync_dist_metric_keys_consistency`` that every rank selected the
       same keys in the same order (hook name and process group are taken from the first
       selected metric).
    3. Syncs each selected forward cache in that order and marks it as synced.

    This is meant to prevent the silent data corruption that occurs when ranks log different
    metric keys with sync_dist=True.

    See https://github.com/Lightning-AI/pytorch-lightning/issues/21409

    """
    if not _distributed_is_initialized():
        return

    pending: list[tuple[str, _ResultMetric]] = [
        (name, metric)
        for name, metric in self.valid_items()
        if metric.meta.on_step
        and metric.is_tensor
        and metric.meta.sync.should
        and not metric.meta.sync.rank_zero_only
        and not metric._forward_cache_synced
        and metric._forward_cache is not None
    ]

    # NOTE(review): a rank with nothing pending leaves here without joining the collective
    # key check below; if other ranks do have pending keys they would wait on it - confirm
    # that callers guarantee this cannot happen.
    if not pending:
        return

    names = [name for name, _ in pending]
    first_meta = pending[0][1].meta

    # Collective call: raises if the ranks disagree on the keys to sync
    _assert_sync_dist_metric_keys_consistency(names, first_meta.fx, first_meta.sync.group)

    for _, metric in pending:
        cached = metric._forward_cache
        if cached is None:
            continue
        # sync a copy of the cached value and store the result as the new cache
        reduced = metric.meta.sync(cached.clone())
        metric._forward_cache = reduced
        metric._forward_cache_synced = True
        # step-only metrics are not accumulated, so their value is the synced step value
        if not metric.meta.on_epoch:
            metric.value = reduced
516+
def sync_on_epoch_metrics(self) -> None:
    """Validate and compute every pending on_epoch tensor metric logged with ``sync_dist``.

    Intended to be called at a point where ALL ranks are synchronized (e.g., at epoch end
    before metrics are consumed). The method:

    1. Selects the metrics logged with ``on_epoch`` that are tensors, request syncing, are not
       ``rank_zero_only``, and have no computed result yet (``_computed is None``).
    2. Checks via ``_assert_sync_dist_metric_keys_consistency`` that every rank selected the
       same keys in the same order (hook name and process group are taken from the first
       selected metric).
    3. Calls ``compute()`` on each selected metric in that order; the distributed sync is
       expected to happen inside ``compute()``.

    This is meant to prevent the silent data corruption that occurs when ranks log different
    metric keys with sync_dist=True.

    See https://github.com/Lightning-AI/pytorch-lightning/issues/21409

    """
    if not _distributed_is_initialized():
        return

    pending: list[tuple[str, _ResultMetric]] = [
        (name, metric)
        for name, metric in self.valid_items()
        if metric.meta.on_epoch
        and metric.is_tensor
        and metric.meta.sync.should
        and not metric.meta.sync.rank_zero_only
        and metric._computed is None  # nothing computed/synced so far
    ]

    # NOTE(review): a rank with nothing pending leaves here without joining the collective
    # key check below; if other ranks do have pending keys they would wait on it - confirm
    # that callers guarantee this cannot happen.
    if not pending:
        return

    names = [name for name, _ in pending]
    first_meta = pending[0][1].meta

    # Collective call: raises if the ranks disagree on the keys to sync
    _assert_sync_dist_metric_keys_consistency(names, first_meta.fx, first_meta.sync.group)

    # Same iteration order on every rank once the key check has passed
    for _, metric in pending:
        metric.compute()
561+
424562 @staticmethod
425563 def _get_cache (result_metric : _ResultMetric , on_step : bool ) -> Optional [Tensor ]:
426564 cache = None
0 commit comments