Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions TraceLens/NcclAnalyser/nccl_analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def build_df_summary_long(self, agg_metrics=['mean', 'std', 'min', 'max'],
# Step 2: Build a wide table for implicit sync class
# where each row is a collective operation
# ------------------------------------------------------------------------
def build_df_nccl_implicit_sync_cat(self, detailed=False):
def build_df_nccl_implicit_sync_cat(self, detailed=False, strict_metadata_check=True):
"""
Builds a single DF with one row *per collective ID*, including per-rank ts/dur + metadata.
Ensures metadata consistency across ranks.
Expand Down Expand Up @@ -237,7 +237,9 @@ def build_df_nccl_implicit_sync_cat(self, detailed=False):
for field in metadata_fields:
unique_values = rank_events[field].unique()
if len(unique_values) > 1:
raise ValueError(f"Metadata mismatch in '{field}' for collective {cid}: {unique_values}")
if strict_metadata_check:
raise ValueError(f"Metadata mismatch in '{field}' for collective {cid}: {unique_values}")
warnings.warn(f"Metadata mismatch in '{field}' for collective {cid}: {unique_values}")

row = {'collective_id': cid, **ref_metadata}

Expand Down Expand Up @@ -301,13 +303,14 @@ def build_df_nccl_implicit_sync_cat(self, detailed=False):


def build_df_summary_nccl_implicit_sync_cat(self, agg_metrics=['mean', 'std'],
metadata_fields=["Process Group Name", "Group size", "Full msg size (MB)"]):
metadata_fields=["Process Group Name", "Group size", "Full msg size (MB)"],
strict_metadata_check=True):
"""
Builds a summary DF with one row per collective name, dtype, and msg size.
Aggregates across all collectives and ranks.
"""
if not hasattr(self, 'df_implicit_sync_cat'):
self.df_implicit_sync_cat = self.build_df_nccl_implicit_sync_cat()
self.df_implicit_sync_cat = self.build_df_nccl_implicit_sync_cat(strict_metadata_check=strict_metadata_check)

# Aggregation logic

Expand Down Expand Up @@ -350,7 +353,7 @@ def build_df_summary_nccl_implicit_sync_cat(self, agg_metrics=['mean', 'std'],
summary_df = summary_df[columns_order]
return summary_df

def build_df_nccl_all2allv(self, detailed=False):
def build_df_nccl_all2allv(self, detailed=False, strict_metadata_check=True):
# This differs from the implicit sync category.
# First, each rank can send and receive a different amount of data;
# as a result, these collectives do not respect the implicit sync category.
Expand Down Expand Up @@ -383,7 +386,9 @@ def build_df_nccl_all2allv(self, detailed=False):
for field in metadata_fields:
unique_values = rank_events[field].unique()
if len(unique_values) > 1:
raise ValueError(f"Metadata mismatch in '{field}' for collective {cid}")
if strict_metadata_check:
raise ValueError(f"Metadata mismatch in '{field}' for collective {cid}")
warnings.warn(f"Metadata mismatch in '{field}' for collective {cid}")

# **Common metadata**
row = {'collective_id': cid, **ref_metadata}
Expand Down