
feat: ring-based all_gather with workspace preamble#490

Closed
mawad-amd wants to merge 16 commits into main from muhaawad/allgather-ring-chunked
Conversation

@mawad-amd
Collaborator

Summary

  • Add ring-based all_gather variant using flag-based producer/consumer sync
  • Each rank forwards data around a ring (O(1) writes per step vs O(N) for persistent)
  • Add AllGatherWorkspace + all_gather_preamble() for pre-allocated scratch buffers
  • Fix Triton negative modulo bug (% uses C truncated-division, not Python floored)
  • Rename `shmem` → `ctx` in the all_gather module and ring tests

Test plan

  • All 42 existing all_gather tests pass on 8x MI355X (persistent + partitioned + ring)
  • New test_all_gather_ring covers 4 shapes × 3 dtypes = 12 test cases
  • Manual correctness test with 2, 4, and 8 ranks
  • Performance study: gist

🤖 Generated with Claude Code

mawad-amd and others added 7 commits March 27, 2026 21:45
Ring-based all-gather where each rank forwards data around a ring
instead of all-pairs writes. In each step, rank r reads a shard from
its output buffer and iris.store()s it to the next rank. After N-1
steps every rank has all shards. Uses per-tile flag-based
producer/consumer sync.

Benefits: O(1) peer writes per step (vs O(N) fan-out), avoids
memory-controller contention for large messages.

- New kernel: persistent_all_gather_ring
- Config: all_gather_variant="ring"
- Tests: test_all_gather_ring with multiple shapes and dtypes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
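The forwarding schedule above can be modeled host-side in plain Python. This is a sketch only: the real implementation is a Triton kernel using `iris.store` and flag-based sync, and the shard-index formula here is an assumption consistent with the description.

```python
# Pure-Python model of the ring all_gather schedule (illustration only).
# out[r][i] holds rank r's copy of shard i; None means "not received yet".
def ring_all_gather(world_size):
    out = [[r if i == r else None for i in range(world_size)]
           for r in range(world_size)]
    for step in range(world_size - 1):
        for r in range(world_size):
            shard = (r - step) % world_size   # shard rank r forwards this step
            nxt = (r + 1) % world_size        # ring successor
            out[nxt][shard] = out[r][shard]   # models iris.store to next rank
    return out
```

After `world_size - 1` steps every rank holds every shard, and each rank issued exactly one peer write per step, which is the O(1)-per-step property claimed above.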
Mirrors the proven all_reduce ring kernel: uses a separate ring_buffer
on the symmetric heap for receiving data, simple 0/1 flag toggling,
no step-counting. Each step: send current data to next rank's
ring_buffer, wait for predecessor to write into our ring_buffer,
copy received data to correct output slot, forward it next step.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Triton's % operator uses C truncated-division semantics where
(-1) % 4 = -1, not Python's floored-division where (-1) % 4 = 3.
This caused recv_rank_idx and source_rank_idx to be negative for
ranks where group_rank < step, writing to invalid output locations.

Fix: add world_size before the modulo to ensure non-negative operands.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
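The semantic difference is easy to reproduce in plain Python; `recv_rank_idx` below is a stand-in formula for illustration, not the kernel's exact expression.

```python
WORLD_SIZE = 4

def trunc_mod(a, n):
    # C/Triton-style remainder: truncated division, sign follows the dividend
    return a - int(a / n) * n

# Python's % uses floored division; Triton's behaves like C's.
# A negative result is an invalid rank index.
python_result = (-1) % WORLD_SIZE        # 3
triton_result = trunc_mod(-1, WORLD_SIZE)  # -1

def recv_rank_idx(group_rank, step):
    # Fix: bias by world_size so the left operand is never negative.
    # group_rank - step >= -(WORLD_SIZE - 2), so one bias term suffices.
    return (group_rank - step + WORLD_SIZE) % WORLD_SIZE
```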
Align with codebase convention (README, benchmarks, docs all use ctx).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pre-allocates ring_buffer and flags on the symmetric heap once.
Pass the workspace to all_gather to avoid per-call allocation overhead
(~13ms saved per call in benchmarks).

Follows the same pattern as AllReduceWorkspace / all_reduce_preamble.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
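The pattern looks roughly like the sketch below. The names mirror the PR (`AllGatherWorkspace` / `all_gather_preamble`), but the fields and signature are illustrative: the real buffers live on the symmetric heap and are allocated via Iris, not Python lists.

```python
from dataclasses import dataclass

@dataclass
class AllGatherWorkspace:
    ring_buffer: list  # scratch for in-flight shards (x2 if double-buffered)
    flags: list        # per-chunk producer/consumer flags, zero-initialized

def all_gather_preamble(shard_numel, num_chunks):
    # Allocate once, reuse across calls: avoids the ~13 ms per-call
    # allocation overhead reported in the benchmarks.
    return AllGatherWorkspace([0] * shard_numel, [0] * num_chunks)

ws = all_gather_preamble(shard_numel=1024, num_chunks=16)
# every subsequent all_gather(..., workspace=ws) reuses the same scratch
```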
@mawad-amd mawad-amd requested review from BKP and neoblizz as code owners March 28, 2026 06:27
Copilot AI review requested due to automatic review settings March 28, 2026 06:27
@github-actions github-actions bot added the in-progress ("We are working on it") and iris ("Iris project issue") labels Mar 28, 2026
mawad-amd and others added 9 commits March 27, 2026 23:49
Wraps remote operations (iris.atomic_cas, iris.store, iris.atomic_xchg)
with DeviceTracing record_event_start/end for per-operation timing.
Zero overhead when tracing=False via constexpr dead-code elimination.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
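The guard pattern can be sketched in plain Python. The stubs stand in for the `DeviceTracing` API; in the Triton kernel the flag is a `tl.constexpr`, so the disabled branches are removed at compile time rather than checked at runtime.

```python
events = []

def record_event_start(name):  # stub for DeviceTracing.record_event_start
    events.append(("start", name))

def record_event_end(name):    # stub for DeviceTracing.record_event_end
    events.append(("end", name))

def traced_store(store, dst, key, value, TRACING):
    # In the kernel TRACING is a tl.constexpr: when False, both branches
    # are dead-code-eliminated, so tracing costs literally nothing.
    if TRACING:
        record_event_start("iris.store")
    store(dst, key, value)
    if TRACING:
        record_event_end("iris.store")

buf = {}
traced_store(dict.__setitem__, buf, "k", 1, TRACING=False)  # no events
traced_store(dict.__setitem__, buf, "k", 2, TRACING=True)   # start + end
```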
The tracing API's record_event_start requires block pointers (for tl.min
reduction), not scalar pointers. Restructure to trace full ring steps
using the ring_buffer tile address instead of individual atomic ops.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… flags

Replace O(tiles * steps) remote atomic operations with O(steps) by using
a global barrier. All CUs write tiles in parallel for each step, then
synchronize via a local counter + single remote signal. This eliminates
the massive atomic contention that caused ~10x slowdown vs RCCL.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
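A host-side sketch of that barrier shape, with Python threads standing in for CUs (names and structure are hypothetical; the kernel uses device-side atomics, not locks):

```python
import threading

class StepBarrier:
    def __init__(self, num_cus, signal_remote):
        self._lock = threading.Lock()
        self._count = 0              # monotonic local arrival counter
        self._num_cus = num_cus
        self._signal_remote = signal_remote

    def arrive(self):
        with self._lock:
            self._count += 1
            # Only the last CU of each step sends the remote signal:
            # one remote atomic per step instead of one per tile.
            if self._count % self._num_cus == 0:
                self._signal_remote()

signals = []
bar = StepBarrier(num_cus=4, signal_remote=lambda: signals.append("step-done"))
for _ in range(4):
    bar.arrive()  # 4 local arrivals -> exactly 1 remote signal
```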
Previous global barrier had two bugs:
1. Race between step reset and CU entry (fixed: monotonic counter)
2. Write-after-read hazard on single ring_buffer (fixed: double-buffer)

Step k reads from buf[k%2], writes to next rank's buf[(k+1)%2].

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
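The double-buffer rotation can be checked with a small host-side model (a sketch, not the kernel): step 0 sends each rank's own shard, later steps forward what arrived in the previous step, and the shard-id arithmetic is an assumption consistent with the descriptions above.

```python
def ring_all_gather_db(world_size):
    # out[r][i]: rank r's copy of shard i; buf[r]: rank r's two ring slots
    out = [[r if i == r else None for i in range(world_size)]
           for r in range(world_size)]
    buf = [[None, None] for _ in range(world_size)]
    for step in range(world_size - 1):
        rd, wr = step % 2, (step + 1) % 2
        for r in range(world_size):              # send phase
            data = out[r][r] if step == 0 else buf[r][rd]
            buf[(r + 1) % world_size][wr] = data  # write next rank's buf[(k+1)%2]
        for r in range(world_size):              # receive phase
            shard = (r - 1 - step) % world_size   # shard arriving at rank r
            out[r][shard] = buf[r][wr]
    return out
```

The key property: while step k still reads slot `k % 2`, the predecessor's step-k write lands in slot `(k + 1) % 2`, so the write-after-read hazard on a single buffer disappears.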
Replace per-tile flag handshake with per-chunk handshake to amortize
atomic synchronization overhead. The shard is split into NUM_CHUNKS
row-bands, each with one flag. CUs process all tiles within a chunk
before signaling, reducing atomics from 4*total_tiles*steps to
4*num_chunks*steps. Different CUs work on different chunks at
different ring steps, creating a pipeline that keeps XGMI links busy.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
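Back-of-the-envelope arithmetic for the reduction claimed above, with assumed sizes (8 ranks, 4096 tiles per shard, 16 chunks; real numbers depend on shape and tiling):

```python
world_size = 8
steps = world_size - 1
total_tiles = 4096   # assumed tile count per shard
num_chunks = 16      # assumed row-bands per shard

per_tile = 4 * total_tiles * steps    # old: 4 atomics per tile per step
per_chunk = 4 * num_chunks * steps    # new: 4 atomics per chunk per step
reduction = per_tile // per_chunk     # 256x fewer synchronization atomics
```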
Defer local flag reset until AFTER reading ring_buffer for forwarding
in the next step's send phase. Previously, resetting the flag at the
end of the receive phase allowed the predecessor to overwrite ring_buffer
before we read it for forwarding, causing data corruption.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@mawad-amd mawad-amd closed this Mar 28, 2026