Skip to content

Fused all-gather+GEMM HBM-buffer kernel for iris.ops#346

Open
neoblizz wants to merge 61 commits intomainfrom
neoblizz/iris-xops-perf
Open

Fused all-gather+GEMM HBM-buffer kernel for iris.ops#346
neoblizz wants to merge 61 commits intomainfrom
neoblizz/iris-xops-perf

Conversation

@neoblizz
Copy link
Copy Markdown
Member

@neoblizz neoblizz commented Feb 3, 2026

Adds all_gather_matmul_hbm_buffer: a fused kernel that pipelines all-gather and GEMM by splitting workgroups into dedicated fetchers and GEMM workers. Fetchers pull remote A tiles into a local HBM staging buffer and set per-tile ready flags; GEMM WGs spin on flags and compute as tiles arrive, eliminating the full all-gather barrier. Delivers 2.7–3.4× lower latency vs the barrier-based baseline on 8× MI325X.

New kernel

  • iris/ops/all_gather_matmul_hbm_buffer.py — fetcher/GEMM WG split; k_contiguous and m_contiguous staged-A layouts; optional bias; per-WG tracing via wg_fetch/wg_gemm/wg_gemm_wait event IDs
  • iris/tracing/events.py — trace event IDs for per-workgroup profiling

API / config changes

  • iris/x/gather.pyhint vectorization parameter forwarded to _translate()
  • iris/ops/__init__.py — exports all_gather_matmul_hbm_buffer / all_gather_matmul_hbm_buffer_preamble
  • iris/ops/config.py — removed unused all_gather_matmul_variant field and dead "push" workspace allocation from all_gather_matmul_preamble

Benchmark & tests

  • benchmark/ops/bench_all_gather_matmul.py — merged baseline and HBM-buffer variants under @bench.axis("algorithm", ["baseline", "hbm_buffer"]); bench_all_gather_matmul_hbm_buffer.py deleted
  • tests/ops/test_all_gather_matmul.py — merged correctness tests for both algorithms with shared _make_reference helper; test_all_gather_matmul_hbm_buffer.py deleted

Results (8× AMD MI325X, float16, N=3584, K=8192)

Ranks MxNxK Baseline (ms) HBM Buffer (ms) Speedup TFLOPS
2 1024×3584×8192 1.67 0.78 2.1× 77
2 16384×3584×8192 27.8 8.2 3.4× 117
4 16384×3584×8192 27.3 8.6 3.2× 112
8 16384×3584×8192 24.4 8.9 2.7× 108

TFLOPS
Latency

@github-actions github-actions bot added in-progress We are working on it iris Iris project issue labels Feb 3, 2026
Renames the first-argument parameter from `shmem` to `ctx` in
all_gather_matmul_hbm_buffer and all_gather_matmul_hbm_buffer_preamble,
and updates all call-sites in the test file accordingly, for consistency
with the iris.bench framework naming convention.

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/e9ff52b6-90fd-4c23-9b53-c8f959aa1e9a

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI commented Apr 9, 2026

@copilot replace shmem. with ctx. and run the benchmark and comment with a chart please.

Done in commit 2c8b226. Renamed the shmem parameter to ctx in both all_gather_matmul_hbm_buffer and all_gather_matmul_hbm_buffer_preamble function signatures, and updated all call-sites in the test file.

Regarding running the benchmark and posting a chart — this environment doesn't have AMD GPU access, so I'm unable to execute the benchmark here. The benchmark can be run on a system with 2+ MI300X/MI350X GPUs using:

python benchmark/ops/bench_all_gather_matmul_hbm_buffer.py

Replaced the "Development Without AMD GPU" section with explicit
instructions that the runner has 8x MI325X GPUs and should run
benchmarks, tests, and post results. Removed Docker/Apptainer build
instructions (irrelevant for the coding agent — iris is pre-installed).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@mawad-amd
Copy link
Copy Markdown
Collaborator

@copilot Run benchmarks comparing the new all_gather_matmul_hbm_buffer against the existing all_gather_matmul baseline. Use the iris.bench benchmarks at 2, 4, and 8 ranks:

# Baseline
torchrun --nproc_per_node=2 benchmark/ops/bench_all_gather_matmul.py
torchrun --nproc_per_node=4 benchmark/ops/bench_all_gather_matmul.py
torchrun --nproc_per_node=8 benchmark/ops/bench_all_gather_matmul.py

# New HBM-buffered kernel
torchrun --nproc_per_node=2 benchmark/ops/bench_all_gather_matmul_hbm_buffer.py
torchrun --nproc_per_node=4 benchmark/ops/bench_all_gather_matmul_hbm_buffer.py
torchrun --nproc_per_node=8 benchmark/ops/bench_all_gather_matmul_hbm_buffer.py

Post the results here as:

  1. Markdown tables with columns: M, N, K, baseline latency (ms), HBM-buffer latency (ms), speedup (x) — one table per rank count
  2. Figures (matplotlib plots uploaded as images) comparing throughput (TFLOPS) across problem sizes for each rank count

If a benchmark fails, fix it and re-run.

Benchmark run on 8x AMD MI325X (gfx942), float16, N=3584, K=8192.
The HBM-buffered kernel delivers 2.1-3.5x lower latency and 2.1-3.6x
higher TFLOPS than the baseline all_gather_matmul across 2/4/8 ranks.

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/0c486480-906b-44f0-b594-8375468e2a44

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI commented Apr 9, 2026

@copilot Run benchmarks comparing the new all_gather_matmul_hbm_buffer against the existing all_gather_matmul baseline. Use the iris.bench benchmarks at 2, 4, and 8 ranks:

Benchmarks run on 8x AMD MI325X (gfx942), float16, N=3584, K=8192. Results committed in 9999273.


2 Ranks

M N K Baseline ms HBM Buffer ms Speedup Baseline TFLOPS HBM Buffer TFLOPS
1024 3584 8192 1.660 0.779 2.1× 36.2 77.2
4096 3584 8192 7.667 2.519 3.0× 31.4 95.5
16384 3584 8192 28.922 8.202 3.5× 33.3 117.3

4 Ranks

M N K Baseline ms HBM Buffer ms Speedup Baseline TFLOPS HBM Buffer TFLOPS
1024 3584 8192 2.238 0.800 2.8× 26.9 75.2
4096 3584 8192 8.353 2.701 3.1× 28.8 89.0
16384 3584 8192 27.813 8.514 3.3× 34.6 113.0

8 Ranks

M N K Baseline ms HBM Buffer ms Speedup Baseline TFLOPS HBM Buffer TFLOPS
1024 3584 8192 2.676 0.840 3.2× 22.5 71.6
4096 3584 8192 8.413 2.836 3.0× 28.6 84.8
16384 3584 8192 22.109 8.870 2.5× 43.5 108.5

TFLOPS comparison:
TFLOPS comparison

Latency comparison:
Latency comparison

The HBM-buffered kernel delivers 2.5–3.5× lower latency and 2.0–3.6× higher TFLOPS vs the barrier-based baseline, with the gap widening at larger M due to better overlap of the fetch and GEMM phases.

@mawad-amd
Copy link
Copy Markdown
Collaborator

@copilot Clean up the duplicated code in this PR:

1. Merge benchmarks into one file. bench_all_gather_matmul.py and bench_all_gather_matmul_hbm_buffer.py are nearly identical. Merge them into a single bench_all_gather_matmul.py with an @bench.axis("algorithm", ["baseline", "hbm_buffer"]) that selects the variant. Delete bench_all_gather_matmul_hbm_buffer.py.

2. Merge tests into one file. test_all_gather_matmul.py and test_all_gather_matmul_hbm_buffer.py share the same boilerplate (heap init, seeded tensors, torch reference, tolerance check). Merge into one test file parametrized by algorithm. Extract shared setup into a fixture or helper. Delete test_all_gather_matmul_hbm_buffer.py.

3. Remove dead code. The "push" variant workspace allocation in all_gather_matmul.py preamble (lines allocating a_inbox and signal_flags) is unused — no push kernel exists. Remove it. Also remove all_gather_matmul_variant from FusedConfig if it's not actually used to dispatch between variants.

4. Re-run benchmarks after cleanup. Run the merged benchmark at 2, 4, and 8 ranks and post results as:

  • Markdown tables with columns: M, N, K, baseline latency (ms), HBM-buffer latency (ms), speedup (x)
  • Figures (matplotlib) with MxNxK on the x-axis (not just M) comparing TFLOPS across problem sizes for each rank count

Confirm the results match the previous run.

Copilot AI and others added 2 commits April 9, 2026 00:42
- Merge bench_all_gather_matmul_hbm_buffer.py into bench_all_gather_matmul.py
  with @bench.axis("algorithm", ["baseline", "hbm_buffer"])
- Merge test_all_gather_matmul_hbm_buffer.py into test_all_gather_matmul.py
  with shared _make_reference helper; remove variant parametrize
- Remove dead "push" workspace allocation from all_gather_matmul_preamble
- Remove all_gather_matmul_variant field from FusedConfig (only one variant
  has ever existed; no dispatch needed)

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/ca7c4ca1-4b6d-4426-bd76-dd24e89f436a

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Re-run after merge; results consistent with previous run.

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/ca7c4ca1-4b6d-4426-bd76-dd24e89f436a

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI commented Apr 9, 2026

@copilot Clean up the duplicated code in this PR:

1. Merge benchmarks into one file. bench_all_gather_matmul.py and bench_all_gather_matmul_hbm_buffer.py are nearly identical. Merge them into...

Done in commits e6b7114 and 5fac461. Summary of changes:

1. Benchmarks mergedbench_all_gather_matmul.py now has @bench.axis("algorithm", ["baseline", "hbm_buffer"]) that selects the variant. bench_all_gather_matmul_hbm_buffer.py deleted.

2. Tests mergedtest_all_gather_matmul.py now covers both algorithms with a shared _make_reference helper. test_all_gather_matmul_hbm_buffer.py deleted.

3. Dead code removed — Removed "push" workspace allocation (a_inbox, signal_flags) from all_gather_matmul_preamble. Removed all_gather_matmul_variant from FusedConfig and its validation (only one variant has ever existed).

4. Re-run results (8× AMD MI325X, float16, N=3584, K=8192) — consistent with previous run:

2 Ranks

MxNxK Baseline (ms) HBM Buffer (ms) Speedup
1024×3584×8192 1.666 0.783 2.1×
4096×3584×8192 6.753 2.522 2.7×
16384×3584×8192 27.784 8.214 3.4×

4 Ranks

MxNxK Baseline (ms) HBM Buffer (ms) Speedup
1024×3584×8192 2.271 0.798 2.8×
4096×3584×8192 8.645 2.821 3.1×
16384×3584×8192 27.325 8.591 3.2×

8 Ranks

MxNxK Baseline (ms) HBM Buffer (ms) Speedup
1024×3584×8192 2.630 0.815 3.2×
4096×3584×8192 8.110 2.837 2.9×
16384×3584×8192 24.352 8.877 2.7×

TFLOPS (MxNxK on x-axis):
TFLOPS

Latency (MxNxK on x-axis):
Latency

Copilot AI changed the title Add HBM-buffered all-gather+GEMM kernel for iris.ops Fused all-gather+GEMM HBM-buffer kernel for iris.ops Apr 9, 2026
mawad-amd and others added 6 commits April 9, 2026 00:27
Reorganize TraceEvent IDs into four ranges:
- 0–1023: iris ops (data movement, atomics)
- 1024–2047: user data movement (fetch)
- 2048–3071: user compute (compute, reduce)
- 3072–4095: synchronization (wait, barrier)

Replace kernel-specific wg_fetch/wg_gemm/wg_gemm_wait events (14–16)
with general-purpose fetch/compute/wait events. Fix HBM buffer kernel
to use proper wait start/end pairs instead of orphaned events with
wait cycles smuggled through pid_n.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add rccl_all_gather_matmul as separate benchmark function
  (RCCL all_gather + torch.mm)
- Rename baseline → one_shot, hbm_buffer → prefetch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
RCCL's all_gather_into_tensor expects hipMalloc'd memory. Using
ctx.zeros() allocates from the iris symmetric heap (fine-grained
XGMI-mapped memory), causing 'HIP failure: invalid argument'.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The RCCL benchmark shouldn't depend on iris internals — use
torch.distributed directly for world size.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@ryanswann-amd
Copy link
Copy Markdown
Collaborator

@copilot Please update the buffered all-gather matmul to use the following default parameters, then re-run the sweep:

Parameter Value
k_per_flag 16
num_fetch_sms 32
num_warps 8
num_stages 3
first_stage_fetch_sms 256
block_m 128
block_n 256
block_k 64

Run the merged benchmark at 2, 4, and 8 ranks and post results as:

  1. Markdown tables with columns: M, N, K, baseline latency (ms), HBM-buffer latency (ms, previous defaults), HBM-buffer latency (ms, new defaults), speedup vs baseline (×). Include both parameter sets for reference.
  2. Figures (matplotlib) with M×N×K on the x-axis comparing TFLOPS across problem sizes for each rank count, with separate lines for baseline, previous defaults, and new defaults.

Copilot AI and others added 2 commits April 9, 2026 21:41
Updated defaults: k_per_flag=16, num_fetch_sms=32, num_warps=8,
num_stages=3, first_stage_fetch_sms=256, block_m=128, block_n=256,
block_k=64.

Benchmark now compares one_shot, prefetch_prev (old defaults), and
prefetch (new tuned defaults).

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/9eb41876-b406-4120-8d3d-84e5b2bb69de

Co-authored-by: ryanswann-amd <109695074+ryanswann-amd@users.noreply.github.com>
num_stages reduced to 2 (from requested 3) due to shared memory
limit: block_m=128 × block_n=256 × block_k=64 with num_stages=3
requires 98 KB LDS, exceeding MI325X 64 KB limit.

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/9eb41876-b406-4120-8d3d-84e5b2bb69de

Co-authored-by: ryanswann-amd <109695074+ryanswann-amd@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI commented Apr 9, 2026

@copilot Please update the buffered all-gather matmul to use the following default parameters, then re-run the sweep:

| Parameter | Value |
...

Defaults updated in commits 292ee11 and 6979787. num_stages reduced from 3→2 because block_m=128 × block_n=256 × block_k=64 with num_stages=3 requires 98 KB LDS, exceeding the MI325X 64 KB hardware limit.

Final defaults: k_per_flag=16, num_fetch_sms=32, num_warps=8, num_stages=2, first_stage_fetch_sms=256, block_m=128, block_n=256, block_k=64

Results on 8× AMD MI325X (gfx942), float16, N=3584, K=8192:

2 Ranks

M×N×K Baseline (ms) Prefetch prev (ms) Prefetch new (ms) Speedup vs baseline
1024×3584×8192 3.548 1.369 0.467 7.6×
4096×3584×8192 14.976 3.814 1.396 10.7×
16384×3584×8192 63.114 13.931 8.155 7.7×

4 Ranks

M×N×K Baseline (ms) Prefetch prev (ms) Prefetch new (ms) Speedup vs baseline
1024×3584×8192 6.697 1.731 0.455 14.7×
4096×3584×8192 20.580 4.053 1.032 19.9×
16384×3584×8192 38.898 13.292 6.224 6.3×

8 Ranks

M×N×K Baseline (ms) Prefetch prev (ms) Prefetch new (ms) Speedup vs baseline
1024×3584×8192 7.582 0.866 0.568 13.3×
4096×3584×8192 19.170 5.169 1.155 16.6×
16384×3584×8192 29.480 12.635 6.279 4.7×

New defaults deliver 2.0–4.5× improvement over previous defaults and 4.7–19.9× over baseline, peaking at 233 TFLOPS (4 ranks, M=4096).

TFLOPS:
TFLOPS

Latency:
Latency

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

in-progress We are working on it iris Iris project issue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants