diff --git a/TraceLens/NcclAnalyser/nccl_analyser.py b/TraceLens/NcclAnalyser/nccl_analyser.py index 35f96bd2..3305a423 100644 --- a/TraceLens/NcclAnalyser/nccl_analyser.py +++ b/TraceLens/NcclAnalyser/nccl_analyser.py @@ -208,7 +208,7 @@ def build_df_summary_long(self, agg_metrics=['mean', 'std', 'min', 'max'], # Step 2: Build a wide table for implicit sync class # where each row is a collective operation # ------------------------------------------------------------------------ - def build_df_nccl_implicit_sync_cat(self, detailed=False): + def build_df_nccl_implicit_sync_cat(self, detailed=False, strict_metadata_check=True): """ Builds a single DF with one row *per collective ID*, including per-rank ts/dur + metadata. Ensures metadata consistency across ranks. @@ -237,7 +237,9 @@ def build_df_nccl_implicit_sync_cat(self, detailed=False): for field in metadata_fields: unique_values = rank_events[field].unique() if len(unique_values) > 1: - raise ValueError(f"Metadata mismatch in '{field}' for collective {cid}: {unique_values}") + if strict_metadata_check: + raise ValueError(f"Metadata mismatch in '{field}' for collective {cid}: {unique_values}") + warnings.warn(f"Metadata mismatch in '{field}' for collective {cid}: {unique_values}") row = {'collective_id': cid, **ref_metadata} @@ -301,13 +303,14 @@ def build_df_nccl_implicit_sync_cat(self, detailed=False): def build_df_summary_nccl_implicit_sync_cat(self, agg_metrics=['mean', 'std'], - metadata_fields=["Process Group Name", "Group size", "Full msg size (MB)"]): + metadata_fields=["Process Group Name", "Group size", "Full msg size (MB)"], + strict_metadata_check=True): """ Builds a summary DF with one row per collective name, dtype, and msg size. Aggregates across all collectives and ranks. """ if not hasattr(self, 'df_implicit_sync_cat'): - self.df_implicit_sync_cat = self.build_df_nccl_implicit_sync_cat() + self.df_implicit_sync_cat = self.build_df_nccl_implicit_sync_cat(strict_metadata_check=strict_metadata_check) # Aggregation logic @@ -350,7 +353,7 @@ def build_df_summary_nccl_implicit_sync_cat(self, agg_metrics=['mean', 'std'], summary_df = summary_df[columns_order] return summary_df - def build_df_nccl_all2allv(self, detailed=False): + def build_df_nccl_all2allv(self, detailed=False, strict_metadata_check=True): # this is diff from implicit sync cat # first, each rank can send and receive different amount of data # as a result they do not respect the implicit sync cat @@ -383,7 +386,9 @@ def build_df_nccl_all2allv(self, detailed=False): for field in metadata_fields: unique_values = rank_events[field].unique() if len(unique_values) > 1: - raise ValueError(f"Metadata mismatch in '{field}' for collective {cid}") + if strict_metadata_check: + raise ValueError(f"Metadata mismatch in '{field}' for collective {cid}") + warnings.warn(f"Metadata mismatch in '{field}' for collective {cid}") # **Common metadata** row = {'collective_id': cid, **ref_metadata}