diff --git a/iris/ccl/__init__.py b/iris/ccl/__init__.py index 687e171a2..29bb6cba9 100644 --- a/iris/ccl/__init__.py +++ b/iris/ccl/__init__.py @@ -11,6 +11,7 @@ """ from .config import Config +from .all_gather import AllGatherWorkspace from .utils import ReduceOp -__all__ = ["Config", "ReduceOp"] +__all__ = ["Config", "AllGatherWorkspace", "ReduceOp"] diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 6b6c53dbe..97c1e36d9 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -6,12 +6,35 @@ Gathers tensors from all ranks and concatenates them along the last dimension. """ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch import triton import triton.language as tl import iris from .config import Config from .utils import extract_group_info + +@dataclass +class AllGatherWorkspace: + """ + Holds reusable workspace allocations for ring-based all-gather. + + Pre-allocate via ``all_gather_preamble`` and pass to ``all_gather`` + to avoid per-call heap allocation overhead. 
+ """ + + shape: Tuple[int, int] = () + dtype: Optional[torch.dtype] = None + ring_buffer: Optional[torch.Tensor] = None + flags: Optional[torch.Tensor] = None + num_chunks: int = 0 + chunk_rows: int = 0 + prepared: bool = False + + # Conditional import for Gluon try: from triton.experimental import gluon @@ -441,20 +464,301 @@ def persistent_all_gather_gluon( gl.store(remote_ptrs, data, mask=mask) +@triton.jit() +def persistent_all_gather_ring( + input_ptr, + output_ptr, + ring_buffer, + flags, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases: tl.tensor, + group_rank: tl.constexpr, + iris_rank: tl.constexpr, + world_size: tl.constexpr, + rank_start: tl.constexpr, + rank_stride: tl.constexpr, + next_rank: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + COMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + CHUNK_SIZE: tl.constexpr, + NUM_CHUNKS: tl.constexpr, + CHUNK_ROWS: tl.constexpr, +): + """ + Ring-based all-gather with chunked pipelining and per-chunk handshake. + + The shard (M rows x N cols) is split into NUM_CHUNKS row-bands (chunks). + Each chunk is CHUNK_ROWS rows x N cols. One flag per chunk controls the + producer/consumer handshake, amortizing synchronization cost across all + tiles within a chunk. + + CUs take chunks in a COMM_SMS-strided loop: chunk_id = pid, pid + COMM_SMS, ... < NUM_CHUNKS. + Each CU processes all tiles in its assigned chunk across W-1 ring steps. + Different CUs work on different chunks at different ring steps, creating + a pipeline that keeps XGMI links continuously busy. + + Pipeline visualization (4 ranks, 4 chunks): + Time → + CU0: [chunk0 step0] [chunk0 step1] [chunk0 step2] + CU1: [chunk1 step0] [chunk1 step1] [chunk1 step2] + CU2: [chunk2 step0] [chunk2 step1] [chunk2 step2] + CU3: [chunk3 step0] [chunk3 step1] [chunk3 step2] + + All CUs send/receive concurrently, saturating the ring bandwidth. + + Flags layout: one int32 per chunk on symmetric heap. 
+ flag=0 means ring_buffer chunk region is free (producer can write) + flag=1 means ring_buffer chunk region has data (consumer can read) + """ + pid = tl.program_id(0) + + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + tiles_per_chunk_m = tl.cdiv(CHUNK_ROWS, BLOCK_SIZE_M) + tiles_per_chunk = tiles_per_chunk_m * num_pid_n + + # Each CU handles one or more chunks round-robin + for chunk_id in range(pid, NUM_CHUNKS, COMM_SMS): + chunk_row_start = chunk_id * CHUNK_ROWS + + # Flag pointers for this chunk + remote_flag_ptr = flags + chunk_id + local_flag_ptr = flags + chunk_id + + # Step 0: Copy local shard chunk to output + for tile_in_chunk in range(tiles_per_chunk): + tile_m = tile_in_chunk // num_pid_n + tile_n = tile_in_chunk % num_pid_n + + rm_base = chunk_row_start + tile_m * BLOCK_SIZE_M + rn_base = tile_n * BLOCK_SIZE_N + rm = rm_base + tl.arange(0, BLOCK_SIZE_M) + rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + mask = (rm[:, None] < M) & (rn[None, :] < N) + in_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n + data = tl.load(input_ptr + in_offset, mask=mask, other=0) + + # Write to output[group_rank * M + row, col] + rm_out = rm + group_rank * M + out_offset = rm_out[:, None] * stride_out_m + rn[None, :] * stride_out_n + tl.store(output_ptr + out_offset, data, mask=mask, cache_modifier=".wt") + + # Ring steps: forward data around the ring. + # + # Critical invariant: we must NOT reset the local flag (allowing the + # predecessor to overwrite ring_buffer) until AFTER we've read ring_buffer + # for forwarding in the send phase. The sequence per step is: + # + # 1. SEND: read data (input at step 0, ring_buffer at step>0), + # write to next rank's ring_buffer, signal next rank + # 2. Release ring_buffer from previous step (reset flag, step>0 only) + # 3. 
RECV: wait for predecessor's data, read ring_buffer, write to output + # 4. Do NOT reset flag yet — ring_buffer data needed for next step's send + # + # After the last step, reset the flag for cleanup. + for _step in range(0, world_size - 1): + # === SEND PHASE === + # Wait for next rank's chunk flag to be 0 (ring_buffer is free) + while ( + iris.atomic_cas( + remote_flag_ptr, + 0, + 0, + iris_rank, + next_rank, + heap_bases, + sem="acquire", + scope="sys", + ) + != 0 + ): + pass + + # Write all tiles in this chunk to next rank's ring_buffer. + # Step 0: read from input (own shard) + # Step k>0: read from local ring_buffer (received from predecessor, + # still valid because we haven't reset the local flag yet) + for tile_in_chunk in range(tiles_per_chunk): + tile_m = tile_in_chunk // num_pid_n + tile_n = tile_in_chunk % num_pid_n + + rm_base = chunk_row_start + tile_m * BLOCK_SIZE_M + rn_base = tile_n * BLOCK_SIZE_N + rm = rm_base + tl.arange(0, BLOCK_SIZE_M) + rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + mask = (rm[:, None] < M) & (rn[None, :] < N) + buf_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n + + if _step == 0: + send_data = tl.load(input_ptr + buf_offset, mask=mask, other=0) + else: + send_data = tl.load(ring_buffer + buf_offset, mask=mask, other=0) + + iris.store( + ring_buffer + buf_offset, + send_data, + iris_rank, + next_rank, + heap_bases, + mask=mask, + hint=(1, BLOCK_SIZE_N), + ) + + tl.debug_barrier() + # Signal next rank: all tiles in chunk are ready + iris.atomic_xchg( + remote_flag_ptr, + 1, + iris_rank, + next_rank, + heap_bases, + sem="release", + scope="sys", + ) + + # === RELEASE PREVIOUS STEP'S RING BUFFER === + # We've finished reading ring_buffer for forwarding. Now safe to + # let the predecessor overwrite it for the next step. 
+ if _step > 0: + tl.debug_barrier() + tl.atomic_xchg(local_flag_ptr, 0, sem="release", scope="sys") + + # === RECEIVE PHASE === + # Wait for local chunk flag to be 1 (predecessor wrote data) + while tl.atomic_cas(local_flag_ptr, 0, 0, sem="acquire", scope="sys") != 1: + pass + + # Read received data from local ring_buffer and copy to output + recv_rank_idx = (group_rank + world_size - _step - 1) % world_size + + for tile_in_chunk in range(tiles_per_chunk): + tile_m = tile_in_chunk // num_pid_n + tile_n = tile_in_chunk % num_pid_n + + rm_base = chunk_row_start + tile_m * BLOCK_SIZE_M + rn_base = tile_n * BLOCK_SIZE_N + rm = rm_base + tl.arange(0, BLOCK_SIZE_M) + rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + mask = (rm[:, None] < M) & (rn[None, :] < N) + buf_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n + + recv_data = tl.load(ring_buffer + buf_offset, mask=mask, other=0) + + # Write to output at the correct rank slot + rm_recv = rm + recv_rank_idx * M + out_offset_recv = rm_recv[:, None] * stride_out_m + rn[None, :] * stride_out_n + tl.store(output_ptr + out_offset_recv, recv_data, mask=mask, cache_modifier=".wt") + + # Do NOT reset flag here — ring_buffer data is needed for next step's send. + + # Final cleanup: release ring_buffer from last step + tl.debug_barrier() + tl.atomic_xchg(local_flag_ptr, 0, sem="release", scope="sys") + + +def all_gather_preamble( + output_tensor, + input_tensor, + ctx, + config=None, + workspace=None, +): + """ + Pre-allocate reusable workspace for ring-based all-gather. + + Call before each ``all_gather`` (pass the previous workspace back in to reuse its + buffers): ``all_gather`` clears ``prepared`` after use and would otherwise reallocate. + + Args: + output_tensor: Output tensor of shape (world_size * M, N). + input_tensor: Input tensor of shape (M, N). + ctx: Iris context. 
+ config: Config instance (default: None → default Config). + workspace: Existing workspace to reuse (default: None → create new). + + Returns: + AllGatherWorkspace ready for the next ``all_gather`` call. + """ + if config is None: + config = Config(block_size_m=32, block_size_n=64) + + M, N = input_tensor.shape[:2] + dtype = input_tensor.dtype + + if workspace is None: + workspace = AllGatherWorkspace() + + workspace.shape = (M, N) + workspace.dtype = dtype + workspace.prepared = False + + if config.all_gather_variant == "ring": + # Single ring buffer for per-chunk handshake + if ( + workspace.ring_buffer is None + or workspace.ring_buffer.shape != (M, N) + or workspace.ring_buffer.dtype != dtype + ): + workspace.ring_buffer = ctx.zeros((M, N), dtype=dtype) + else: + workspace.ring_buffer.zero_() + + # Chunk-based flags: split shard into num_chunks row-bands. + # Each chunk gets one flag. More chunks = deeper pipeline but more + # flag overhead. Default: use comm_sms chunks (one per CU) clamped + # to a reasonable range based on problem size. + num_chunks = min(config.comm_sms, max(1, M // config.block_size_m)) + chunk_rows = (M + num_chunks - 1) // num_chunks + # Round up to block boundary for clean tiling + chunk_rows = ((chunk_rows + config.block_size_m - 1) // config.block_size_m) * config.block_size_m + # Recompute num_chunks based on rounded chunk_rows + num_chunks = (M + chunk_rows - 1) // chunk_rows + + workspace.num_chunks = num_chunks + workspace.chunk_rows = chunk_rows + + if workspace.flags is None or workspace.flags.numel() != num_chunks: + workspace.flags = ctx.zeros((num_chunks,), dtype=torch.int32) + else: + workspace.flags.zero_() + + ctx.barrier() + + workspace.prepared = True + return workspace + + def all_gather( output_tensor, input_tensor, - shmem, + ctx, group=None, async_op=False, config=None, + workspace=None, ): """ Internal all-gather collective operation implementation. - This function is called internally by shmem.ccl.all_gather(). 
+ This function is called internally by ctx.ccl.all_gather(). Users should use the Iris instance method instead: - >>> shmem.ccl.all_gather(output_tensor, input_tensor) + >>> ctx.ccl.all_gather(output_tensor, input_tensor) Each rank sends its input tensor to all ranks, and all ranks receive and concatenate all input tensors along dimension 0 (rows), matching @@ -463,7 +767,7 @@ def all_gather( Args: output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs input_tensor: Input tensor of shape (M, N) - local rank's data to send - shmem: Iris shmem context + ctx: Iris context group: ProcessGroup or None. If None, uses all ranks in `iris` context. Default: None. async_op: If False, performs a barrier at the end. If True, returns immediately. @@ -479,7 +783,7 @@ def all_gather( # Extract group information # rank_in_group: position within the ProcessGroup (0, 1, 2, ...) - passed as group_rank to kernel # rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations - rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem) + rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, ctx) M, N = input_tensor.shape[:2] expected_output_shape = (world_size * M, N) @@ -495,12 +799,12 @@ def all_gather( # Choose between Triton and Gluon implementation if config.use_gluon and GLUON_AVAILABLE: - # Check if shmem is Iris Gluon (has get_device_context method) - if not hasattr(shmem, "get_device_context"): + # Check if ctx is Iris Gluon (has get_device_context method) + if not hasattr(ctx, "get_device_context"): raise ValueError("use_gluon=True requires Iris Gluon context. 
Use iris.experimental.iris_gluon.iris()") # Gluon only supports the persistent variant - if config.all_gather_variant != "persistent": + if config.all_gather_variant not in ("persistent",): raise ValueError( f"Gluon all_gather only supports all_gather_variant='persistent', got '{config.all_gather_variant}'." ) @@ -534,7 +838,7 @@ def all_gather( f"Recommended: block_size_m=8, block_size_n=256." ) - context_tensor = shmem.get_device_context() + context_tensor = ctx.get_device_context() persistent_all_gather_gluon[(config.comm_sms,)]( IrisDeviceCtx, @@ -573,41 +877,99 @@ def all_gather( f"Please adjust config.comm_sms to be a multiple of {world_size}." ) - heap_bases = shmem.get_heap_bases() + heap_bases = ctx.get_heap_bases() - # Dispatch to the appropriate kernel based on variant - if config.all_gather_variant == "persistent": - kernel_fn = persistent_all_gather - elif config.all_gather_variant == "partitioned": - kernel_fn = persistent_all_gather_partitioned + if config.all_gather_variant == "ring": + # Ring variant: use workspace if provided, otherwise allocate + if workspace is not None and workspace.prepared: + ring_buffer = workspace.ring_buffer + flags = workspace.flags + num_chunks = workspace.num_chunks + chunk_rows = workspace.chunk_rows + workspace.prepared = False + else: + ring_buffer = ctx.zeros((M, N), dtype=input_tensor.dtype) + num_chunks = min(config.comm_sms, max(1, M // config.block_size_m)) + chunk_rows = (M + num_chunks - 1) // num_chunks + chunk_rows = ((chunk_rows + config.block_size_m - 1) // config.block_size_m) * config.block_size_m + num_chunks = (M + chunk_rows - 1) // chunk_rows + flags = ctx.zeros((num_chunks,), dtype=torch.int32) + ctx.barrier() + + # Calculate next rank in the ring + if group is None: + next_rank = (rank_in_group + 1) % world_size + else: + import torch.distributed as dist + + group_ranks = dist.get_process_group_ranks(group) + next_rank_in_group = (rank_in_group + 1) % world_size + next_rank = 
group_ranks[next_rank_in_group] + + persistent_all_gather_ring[(config.comm_sms,)]( + input_tensor, + output_tensor, + ring_buffer, + flags, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases, + rank_in_group, + rank_global, + world_size, + rank_start, + rank_stride, + next_rank, + config.block_size_m, + config.block_size_n, + config.swizzle_size, + config.comm_sms, + config.num_xcds, + config.chunk_size, + num_chunks, + chunk_rows, + num_stages=config.num_stages, + num_warps=config.num_warps, + waves_per_eu=config.waves_per_eu, + ) else: - raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}") - - kernel_fn[(config.comm_sms,)]( - input_tensor, - output_tensor, - M, - N, - stride_in_m, - stride_in_n, - stride_out_m, - stride_out_n, - heap_bases, - rank_in_group, - rank_global, - world_size, - rank_start, - rank_stride, - config.block_size_m, - config.block_size_n, - config.swizzle_size, - config.comm_sms, - config.num_xcds, - config.chunk_size, - num_stages=config.num_stages, - num_warps=config.num_warps, - waves_per_eu=config.waves_per_eu, - ) + # Dispatch to the appropriate kernel based on variant + if config.all_gather_variant == "persistent": + kernel_fn = persistent_all_gather + elif config.all_gather_variant == "partitioned": + kernel_fn = persistent_all_gather_partitioned + else: + raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}") + + kernel_fn[(config.comm_sms,)]( + input_tensor, + output_tensor, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases, + rank_in_group, + rank_global, + world_size, + rank_start, + rank_stride, + config.block_size_m, + config.block_size_n, + config.swizzle_size, + config.comm_sms, + config.num_xcds, + config.chunk_size, + num_stages=config.num_stages, + num_warps=config.num_warps, + waves_per_eu=config.waves_per_eu, + ) if not async_op: - shmem.barrier() + ctx.barrier() diff --git a/iris/ccl/config.py 
b/iris/ccl/config.py index 1084de063..5e84cda1a 100644 --- a/iris/ccl/config.py +++ b/iris/ccl/config.py @@ -32,9 +32,10 @@ class Config: use_gluon: If True, use Gluon-based implementation (default: False) Gluon provides better control over warp-level traffic shaping all_gather_variant: Variant for all-gather operation (default: "persistent") - Options: "persistent", "partitioned" + Options: "persistent", "partitioned", "ring" - "persistent": Each PID handles multiple tiles and sends to all ranks - "partitioned": PIDs partitioned across ranks, eliminates inner loop + - "ring": Ring-based forwarding with flag sync, O(1) writes per step all_reduce_variant: Variant for all-reduce operation (default: "atomic") Options: "atomic", "ring", "two_shot", "one_shot", "spinlock" all_reduce_distribution: Distribution for two-shot all-reduce (default: 0) @@ -84,7 +85,7 @@ class Config: num_xcds: int | None = None chunk_size: int | None = None use_gluon: bool = False - all_gather_variant: str = "persistent" + all_gather_variant: str = "persistent" # "persistent", "partitioned", or "ring" all_reduce_variant: str = "two_shot" all_reduce_distribution: int = 1 all_reduce_num_rings: int = 1 @@ -114,9 +115,9 @@ def __post_init__(self): raise ValueError(f"comm_sms must be positive, got {self.comm_sms}") if self.num_xcds <= 0: raise ValueError(f"num_xcds must be positive, got {self.num_xcds}") - if self.all_gather_variant not in ["persistent", "partitioned"]: + if self.all_gather_variant not in ["persistent", "partitioned", "ring"]: raise ValueError( - f"all_gather_variant must be one of: 'persistent', 'partitioned', got {self.all_gather_variant}" + f"all_gather_variant must be one of: 'persistent', 'partitioned', 'ring', got {self.all_gather_variant}" ) if self.all_reduce_variant not in ["atomic", "ring", "two_shot", "one_shot", "spinlock"]: raise ValueError( diff --git a/iris/iris.py b/iris/iris.py index 8c750ba67..80bbb1c24 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1159,7 
+1159,7 @@ def all_to_all(self, output_tensor, input_tensor, group=None, async_op=False, co _all_to_all(output_tensor, input_tensor, self._iris, group=group, async_op=async_op, config=config) - def all_gather(self, output_tensor, input_tensor, group=None, async_op=False, config=None): + def all_gather(self, output_tensor, input_tensor, group=None, async_op=False, config=None, workspace=None): """ All-gather collective operation. @@ -1170,29 +1170,54 @@ def all_gather(self, output_tensor, input_tensor, group=None, async_op=False, co Args: output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs input_tensor: Input tensor of shape (M, N) - local rank's data to send - group: ProcessGroup or None. If None, uses all ranks in shmem context. + group: ProcessGroup or None. If None, uses all ranks in iris context. Default: None. async_op: If False, performs a barrier at the end. If True, returns immediately. Default: False. config: Config instance with kernel parameters (default: None). If None, uses default Config values. + workspace: Optional AllGatherWorkspace from ``all_gather_preamble``. + Avoids per-call heap allocation for ring variant. 
Example: >>> ctx = iris.iris() >>> # Input: (M, N), Output: (world_size * M, N) >>> ctx.ccl.all_gather(output_tensor, input_tensor) - >>> # Custom configuration + >>> # Ring variant with pre-allocated workspace >>> from iris.ccl import Config - >>> config = Config(block_size_m=128, block_size_n=32) - >>> ctx.ccl.all_gather(output_tensor, input_tensor, config=config) - - >>> # Async operation (no barrier) - >>> ctx.ccl.all_gather(output_tensor, input_tensor, async_op=True) + >>> config = Config(all_gather_variant="ring") + >>> ws = ctx.ccl.all_gather_preamble(out, inp, config=config) + >>> ctx.ccl.all_gather(out, inp, config=config, workspace=ws) """ from iris.ccl.all_gather import all_gather as _all_gather - _all_gather(output_tensor, input_tensor, self._iris, group=group, async_op=async_op, config=config) + _all_gather( + output_tensor, + input_tensor, + self._iris, + group=group, + async_op=async_op, + config=config, + workspace=workspace, + ) + + def all_gather_preamble(self, output_tensor, input_tensor, config=None, workspace=None): + """ + Pre-allocate reusable workspace for ring-based all-gather. + + Args: + output_tensor: Output tensor of shape (world_size * M, N). + input_tensor: Input tensor of shape (M, N). + config: Optional Config describing variant parameters. + workspace: Optional existing workspace to update/reuse. + + Returns: + AllGatherWorkspace that can be passed to ``all_gather``. 
+ """ + from iris.ccl.all_gather import all_gather_preamble as _all_gather_preamble + + return _all_gather_preamble(output_tensor, input_tensor, self._iris, config=config, workspace=workspace) def all_reduce_preamble(self, output_tensor, input_tensor, config=None, workspace=None): """ diff --git a/tests/ccl/test_all_gather.py b/tests/ccl/test_all_gather.py index 7858ed18d..c5a0bc827 100644 --- a/tests/ccl/test_all_gather.py +++ b/tests/ccl/test_all_gather.py @@ -90,6 +90,70 @@ def test_all_gather(dtype, M, N, block_size_m, block_size_n): gc.collect() +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.float32, + torch.bfloat16, + ], +) +@pytest.mark.parametrize( + "M, N, block_size_m, block_size_n", + [ + (128, 64, 32, 64), # Small + (128, 128, 32, 32), # Multi-block per rank + (1024, 256, 32, 64), # Medium + (8192, 8192, 32, 64), # Large + ], +) +def test_all_gather_ring(dtype, M, N, block_size_m, block_size_n): + """Test all-gather with ring variant by comparing against PyTorch's implementation.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 # 8GB + ctx = iris.iris(heap_size) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + # Reference: PyTorch all_gather_into_tensor + pytorch_input_tensor = torch.randn(M, N, dtype=dtype, device=f"cuda:{rank}") + pytorch_input_tensor.fill_(float(rank + 1)) + + pytorch_output_tensor = torch.zeros(world_size * M, N, dtype=dtype, device=f"cuda:{rank}") + + ctx.barrier() + dist.all_gather_into_tensor(pytorch_output_tensor, pytorch_input_tensor) + torch.cuda.synchronize() + + # Iris all_gather with ring variant + iris_input_tensor = ctx.zeros((M, N), dtype=dtype) + iris_input_tensor.copy_(pytorch_input_tensor) + + iris_output_tensor = ctx.zeros((world_size * M, N), dtype=dtype) + + ctx.barrier() + config = Config(block_size_m=block_size_m, block_size_n=block_size_n, all_gather_variant="ring") + ctx.ccl.all_gather(iris_output_tensor, 
iris_input_tensor, config=config) + torch.cuda.synchronize() + + atol = 1e-3 if dtype == torch.float16 else 1e-5 + max_diff = torch.abs(iris_output_tensor - pytorch_output_tensor).max().item() + + try: + assert torch.allclose(iris_output_tensor, pytorch_output_tensor, atol=atol), ( + f"Max difference: {max_diff}, expected < {atol}\n" + f"Rank {rank}: Iris output (ring) doesn't match PyTorch's all_gather_into_tensor" + ) + finally: + ctx.barrier() + del ctx + import gc + + gc.collect() + + @pytest.mark.parametrize( "dtype", [