
bench: report cross-rank timing stats#493

Open
mawad-amd wants to merge 3 commits into main from muhaawad/bench-cross-rank-stats

Conversation

@mawad-amd
Collaborator

Summary

  • Report timing from all ranks instead of only rank 0. Each rank computes the median of its iteration times, then dist.all_gather shares every rank's median with all ranks (rank 0 does the reporting).
  • Headline GPU Time is now max across ranks (the true collective latency — slowest rank determines when the operation is done).
  • Two new counter columns: min_ms (fastest rank) and skew_% ((max - min) / max * 100, straggler penalty as a percentage).
  • Supports both NCCL (CUDA tensors) and gloo (CPU tensors) backends for the gather.
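The two-step aggregation above can be sketched in plain Python (a minimal sketch of the statistics only; the actual cross-rank exchange happens via dist.all_gather, and `summarize` is a hypothetical name, not the PR's function):

```python
import statistics

def summarize(per_rank_iter_times_ms):
    # Step 1: each rank summarizes its own iterations with the median
    # (robust to outliers). In the PR this happens locally on each rank.
    rank_medians = [statistics.median(times) for times in per_rank_iter_times_ms]
    # Step 2: cross-rank aggregation over the gathered medians.
    max_ms = max(rank_medians)                     # headline GPU Time
    min_ms = min(rank_medians)                     # fastest rank
    skew_pct = (max_ms - min_ms) / max_ms * 100.0  # straggler penalty
    return max_ms, min_ms, skew_pct
```

For example, rank medians of 0.25 ms and 0.22 ms give a skew of (0.25 - 0.22) / 0.25 * 100 = 12%.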

Example output on MI308X ×8:

num_ranks     M     N     dtype   variant  GPU Time (ms)  BW (GB/s)  min_ms  skew_%
        8  1024  1024  bfloat16  two_shot          0.217       16.9   0.215   0.953
        8  4096  1024  bfloat16  two_shot          0.272       53.9   0.270   0.758
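The BW column appears consistent with the usual all-reduce bus-bandwidth convention, 2(n-1)/n * bytes / time. This formula is an inference from the table's numbers, not something stated in the PR:

```python
def allreduce_bus_bw_gbps(M, N, bytes_per_elem, num_ranks, gpu_time_ms):
    # Bus bandwidth for an all-reduce: each rank moves roughly
    # 2*(n-1)/n of the buffer (assumed convention, not from the PR).
    size_bytes = M * N * bytes_per_elem
    factor = 2 * (num_ranks - 1) / num_ranks
    return factor * size_bytes / (gpu_time_ms * 1e-3) / 1e9
```

With M = N = 1024, bf16 (2 bytes), 8 ranks, and 0.217 ms this yields about 16.9 GB/s, matching the first table row.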

Test plan

  • Ran bench_all_reduce.py on MI308X ×8 with --axis_num_ranks=8 --axis_M=1024,4096 --axis_N=1024 --axis_dtype=bf16
  • Verified that the JSON and CSV output formats include the new counters
  • Tested with --axis_num_ranks=2,4,8 to confirm multi-launch still works

🤖 Generated with Claude Code

mawad-amd and others added 2 commits March 31, 2026 23:03
Instead of only reporting rank 0's mean, gather timing from all ranks
and report:
- gpu_time_ms: max across ranks (true collective latency)
- min_ms: fastest rank's median
- skew_%: (max - min) / max * 100 — straggler penalty as percentage

Per-rank summary uses median across iterations (robust to outliers).
Cross-rank aggregation uses all_gather to collect every rank's median.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Use CPU tensors for all_gather when backend is gloo instead of
assuming CUDA.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@mawad-amd mawad-amd requested review from BKP and neoblizz as code owners April 1, 2026 06:30
Copilot AI review requested due to automatic review settings April 1, 2026 06:30
@github-actions github-actions bot added in-progress We are working on it iris Iris project issue labels Apr 1, 2026
Contributor

Copilot AI left a comment


Pull request overview

Updates the benchmark runner to report cross-rank timing statistics so the headline “GPU Time” reflects true collective latency (slowest rank), and exposes cross-rank variability via new counters.

Changes:

  • Compute per-rank median iteration time and aggregate timing stats across ranks.
  • Define “GPU Time” as max rank median, and compute bandwidth/TFLOPs from that value.
  • Add min_ms and skew_% counters to quantify straggler impact.

Comment on lines +433 to +439
# Cross-rank: gather every rank's median to compute
# max (true collective latency), min, and skew.
gather_device = "cuda" if backend == "nccl" else "cpu"
local_t = torch.tensor([local_median_ms], device=gather_device)
gathered = [torch.zeros(1, device=gather_device) for _ in range(world_size)]
dist.all_gather(gathered, local_t)
rank_medians = [t.item() for t in gathered]


Comment on lines +433 to +442
# Cross-rank: gather every rank's median to compute
# max (true collective latency), min, and skew.
gather_device = "cuda" if backend == "nccl" else "cpu"
local_t = torch.tensor([local_median_ms], device=gather_device)
gathered = [torch.zeros(1, device=gather_device) for _ in range(world_size)]
dist.all_gather(gathered, local_t)
rank_medians = [t.item() for t in gathered]

max_ms = max(rank_medians)
min_ms = min(rank_medians)
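The excerpt ends at min_ms; per the PR description, the skew then follows as (max - min) / max * 100. A self-contained sketch of that final step (`compute_skew_pct` is a hypothetical name, and the zero guard is an addition, not from the PR):

```python
def compute_skew_pct(max_ms, min_ms):
    # Straggler penalty: how much slower the slowest rank is than the
    # fastest, as a percentage of the slowest rank's time.
    if max_ms == 0:
        return 0.0  # defensive guard (assumption, not in the PR)
    return (max_ms - min_ms) / max_ms * 100.0
```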



counters = dict(state._counters)
counters["min_ms"] = min_ms
counters["skew_%"] = skew_pct




Identifier-friendly key name for programmatic access (CSV/JSON).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
