def naive_a2a_dispatch(
    tokens, topk_idx, topk_weight, w_local, b_local, expt_assignment, n_expts_tot, n_expts_act, rank, world_size
):
    """Baseline MoE dispatch + expert matmul + combine built on torch.distributed.

    Despite the name, no ``all_to_all`` is issued here: every rank all-gathers
    the tokens and the routing tensors, sorts the global batch by expert, runs
    ``grouped_matmul`` with its local expert view, scatters the result back to
    token order and finally keeps only the rows belonging to its own tokens.

    Args:
        tokens: (n_tokens_local, d_model) activations owned by this rank.
        topk_idx: (n_tokens_local, n_expts_act) expert ids selected per token.
        topk_weight: (n_tokens_local, n_expts_act) gating weights.
        w_local: expert weights forwarded to ``grouped_matmul``.
        b_local: expert biases forwarded to ``grouped_matmul``.
        expt_assignment: object exposing ``expt_map``; row ``rank`` is used to
            remap the ragged metadata to this rank's experts.
        n_expts_tot: total number of experts across all ranks.
        n_expts_act: number of experts activated per token.
        rank: index of this rank; selects its slice of the global batch.
        world_size: number of ranks taking part in the all-gathers.

    Returns:
        The first value returned by ``reduce`` over the expert axis for this
        rank's ``n_tokens_local`` tokens.
    """
    n_tokens_local, d_model = tokens.shape
    n_tokens_global = n_tokens_local * world_size
    device = tokens.device

    # All-gather tokens
    x_gathered = torch.empty(n_tokens_global, d_model, dtype=tokens.dtype, device=device)
    dist.all_gather_into_tensor(x_gathered, tokens.contiguous())

    # All-gather routing info
    topk_idx_i32 = topk_idx.contiguous().to(torch.int32)
    topk_weight_f32 = topk_weight.contiguous().to(torch.float32)
    idx_gathered = torch.empty(n_tokens_global, n_expts_act, dtype=torch.int32, device=device)
    val_gathered = torch.empty(n_tokens_global, n_expts_act, dtype=torch.float32, device=device)
    dist.all_gather_into_tensor(idx_gathered, topk_idx_i32)
    # The gathered weights are not consumed below; the collective is kept so
    # the baseline issues the same three all-gathers as before.
    dist.all_gather_into_tensor(val_gathered, topk_weight_f32)

    # Build routing metadata
    mask_metadata = _make_bitmatrix_metadata(idx_gathered, n_expts_tot)
    dispatch_indx = mask_metadata.row_sorted_indx
    combine_indx = mask_metadata.col_sorted_indx
    expt_sizes = mask_metadata.col_sum
    n_active = int(expt_sizes.sum().item())

    # Gather tokens into expert-sorted order
    active_slots = combine_indx[:n_active]
    gather_idx = torch.div(active_slots, n_expts_act, rounding_mode="trunc")
    x_sorted = torch.zeros(n_active, d_model, dtype=tokens.dtype, device=device)
    valid_gather = gather_idx >= 0
    x_sorted[valid_gather] = x_gathered[gather_idx[valid_gather].long()]

    # Build ragged metadata and remap to local view
    ragged_meta = make_ragged_tensor_metadata(expt_sizes, n_active)
    expt_map = expt_assignment.expt_map[rank, :].contiguous()
    local_meta = remap_ragged_tensor_metadata(ragged_meta, expt_map)

    # Expert computation
    expert_out = grouped_matmul(x_sorted, w_local, b_local, local_meta)

    # Scatter back using combine indices.  One indexed assignment replaces the
    # per-slot Python loop, which forced a host/device sync through ``.item()``
    # for every active row and dominated the measured baseline latency.
    y_flat = torch.zeros(n_tokens_global * n_expts_act, d_model, dtype=tokens.dtype, device=device)
    valid_slots = active_slots >= 0
    # Indexed assignment needs matching dtypes, so cast explicitly (the
    # row-by-row copy used to cast implicitly).
    y_flat[active_slots[valid_slots].long()] = expert_out[:n_active][valid_slots].to(y_flat.dtype)

    # Reduce
    y_mask = (dispatch_indx != -1).view(n_tokens_global, n_expts_act, 1)
    local_mask = y_mask[rank * n_tokens_local : (rank + 1) * n_tokens_local]
    y_3d = y_flat[rank * n_tokens_local * n_expts_act : (rank + 1) * n_tokens_local * n_expts_act]
    y_3d = y_3d.view(n_tokens_local, n_expts_act, d_model)
    local_mask = local_mask.expand_as(y_3d).contiguous()
    z_local, _ = reduce(y_3d, dim=1, mask=local_mask)
    return z_local
+ (128, 4096, world_size * 8, 2), # Large decode + (256, 4096, world_size * 8, 2), # Prefill-like + (512, 4096, world_size * 8, 2), # Large prefill + (128, 2048, world_size * 4, 2), # Smaller model + (128, 4096, world_size * 8, 1), # topk=1 + (128, 4096, world_size * 8, 4), # topk=4 + ] + + if rank == 0: + print("# MoE Dispatch/Combine Benchmark — iris vs naive all_to_all") + print("## Hardware") + print(f"- GPUs: {world_size}x MI300X (or similar)") + print(f"- dtype: {dtype}") + print("- Warmup: 20 iterations, Measured: 100 iterations") + print() + print("| T_local | H | E_tot | k | iris (us) | naive (us) | Speedup |") + print("|---------|------|-------|---|-----------|------------|---------|") + + for n_tokens_local, d_model, n_expts_tot, n_expts_act in configs: + n_tokens = n_tokens_local * world_size + + torch.manual_seed(0) + x_global = torch.randn(n_tokens, d_model, device=device, dtype=dtype) + l_global = torch.rand(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + w_global = torch.randn(n_expts_tot, d_model, d_model, device=device, dtype=dtype) + b_global = torch.randn(n_expts_tot, d_model, device=device, dtype=torch.float32) + dist.broadcast(x_global, src=0) + dist.broadcast(l_global, src=0) + dist.broadcast(w_global, src=0) + dist.broadcast(b_global, src=0) + + expt_dict = make_expt_dict_uniform(world_size, n_expts_tot) + expt_assignment = make_expt_assignment(world_size, n_expts_tot, expt_dict, device) + + first = rank * n_tokens_local + last = first + n_tokens_local + x_local = x_global[first:last].contiguous() + l_local = l_global[first:last].contiguous() + w_local = w_global[expt_assignment.expt_boolmask[rank]].contiguous() + b_local = b_global[expt_assignment.expt_boolmask[rank]].contiguous() + + topk_result = topk(l_local, n_expts_act, apply_softmax=True) + + # Create dispatcher + dispatcher = MoEDispatcher( + ctx, + d_model, + n_expts_tot, + n_expts_act, + n_tokens_local, + dtype=dtype, + expt_assignment=expt_assignment, + ) + 
ctx.barrier() + + # Benchmark iris + def _iris_fn(x=x_local, idx=topk_result.indx, vals=topk_result.vals, w=w_local, b=b_local, d=dispatcher): + return iris_dispatch_combine(x, idx, vals, w, b, d, None) + + iris_time, _, _ = benchmark(_iris_fn, warmup=20, measured=100) + + # Benchmark naive + naive_time, _, _ = benchmark( + lambda: naive_a2a_dispatch( + x_local, + topk_result.indx, + topk_result.vals, + w_local, + b_local, + expt_assignment, + n_expts_tot, + n_expts_act, + rank, + world_size, + ), + warmup=20, + measured=100, + ) + + speedup = naive_time / iris_time if iris_time > 0 else float("inf") + + if rank == 0: + print( + f"| {n_tokens_local:>7} | {d_model:>4} | {n_expts_tot:>5} | {n_expts_act} | {iris_time:>9.0f} | {naive_time:>10.0f} | {speedup:>6.2f}x |" + ) + + # Cleanup + del dispatcher + gc.collect() + ctx.barrier() + + if rank == 0: + print() + print("## Analysis") + print("iris MoEDispatcher uses direct symmetric heap scatter (iris.store)") + print("which avoids the all_to_all pack/unpack overhead. 
The naive approach") + print("uses all_gather + host-side sorting, which involves extra copies.") + + del ctx + gc.collect() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/iris/ccl/__init__.py b/iris/ccl/__init__.py index 687e171a2..d1d63f724 100644 --- a/iris/ccl/__init__.py +++ b/iris/ccl/__init__.py @@ -12,5 +12,24 @@ from .config import Config from .utils import ReduceOp +from .moe_dispatch import MoEDispatcher, MoEDispatchConfig, DispatchHandle +from .moe_utils import ( + ExptAssignment, + make_expt_dict_uniform, + make_expt_assignment, + topk, + TopkResult, +) -__all__ = ["Config", "ReduceOp"] +__all__ = [ + "Config", + "ReduceOp", + "MoEDispatcher", + "MoEDispatchConfig", + "DispatchHandle", + "ExptAssignment", + "make_expt_dict_uniform", + "make_expt_assignment", + "topk", + "TopkResult", +] diff --git a/iris/ccl/moe_dispatch.py b/iris/ccl/moe_dispatch.py new file mode 100644 index 000000000..64ddee154 --- /dev/null +++ b/iris/ccl/moe_dispatch.py @@ -0,0 +1,441 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +MoE token dispatch/combine for expert-parallel inference via iris symmetric heap. + +Provides ``MoEDispatcher`` — a pre-allocated, handle-based API for routing +tokens to expert-owning ranks (dispatch) and sending results back with +aggregation (combine). + +Dispatch uses direct iris.store scatter (not AllToAll/AllToAllv) for sparse, +routing-dependent token movement. Buffers are allocated once in __init__ +and sliced per-call to amortize allocation overhead. + +Kernels are the same as examples/31_expert_sharded_moe/{dispatch,combine}.py, +promoted here for production use. 
+""" + +from dataclasses import dataclass + +import torch +import triton +import triton.language as tl +import iris + +from .moe_utils import ( + ExptAssignment, + RaggedTensorMetadata, + _make_bitmatrix_metadata, + make_ragged_tensor_metadata, + remap_ragged_tensor_metadata, + reduce, +) + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +@dataclass +class MoEDispatchConfig: + """Tuning knobs for MoE dispatch/combine kernels.""" + + dispatch_block_size: int = 512 # Tile size for dispatch kernel + combine_block_size: int = 512 # Tile size for combine kernel + + +# --------------------------------------------------------------------------- +# DispatchHandle — opaque state passed from dispatch() to combine() +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class DispatchHandle: + """Opaque handle returned by ``dispatch()`` and consumed by ``combine()``. + + Carries routing metadata and buffer references so that ``combine()`` + does not need to recompute them. 
+ """ + + expt_assignment: ExptAssignment + expt_indx_global: torch.Tensor # (T_global, k) int32 + dispatch_indx: torch.Tensor # (T_global * k,) row_sorted_indx + combine_indx: torch.Tensor # (T_global * k,) col_sorted_indx + topk_vals: torch.Tensor # (T_global, k) gating weights + ragged_meta_global: RaggedTensorMetadata + expt_sizes: torch.Tensor # (n_expts,) per-expert counts + dispatch_buffer: torch.Tensor # (T_global * k, H) on shmem heap + n_tokens_local: int + n_tokens_global: int + hidden_dim: int + topk: int + + +# --------------------------------------------------------------------------- +# Triton kernels (from examples/31_expert_sharded_moe/) +# --------------------------------------------------------------------------- + + +@triton.jit +def _convert_dp_to_ep( + dst_ptr, + dst_stride_m, + src_ptr, + src_stride_m, + src_shape_n, + expt_filter_ptr, + expt_filter_stride_m, + expt_indx_ptr, + expt_indx_stride_m, + dst_row_indx_ptr, + dst_row_indx_stride_m, + n_tokens_local, + heap_bases, + SRC_RANK: tl.constexpr, + N_EXPT_ACT: tl.constexpr, + N_RANKS: tl.constexpr, + BLOCK: tl.constexpr, +): + pid_m = tl.program_id(0) + off_m_global = pid_m + n_tokens_local * SRC_RANK + off_m_local = pid_m + + offs_n = tl.arange(0, BLOCK) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK) + + for act in tl.static_range(N_EXPT_ACT): + dst_row = tl.load(dst_row_indx_ptr + off_m_global * dst_row_indx_stride_m + act) + if dst_row >= 0: + expt_id = tl.load(expt_indx_ptr + off_m_global * expt_indx_stride_m + act) + + dst_rank = 0 + for r in tl.static_range(N_RANKS): + word = expt_id // 32 + bit = expt_id % 32 + filt = tl.load(expt_filter_ptr + r * expt_filter_stride_m + word) + if (filt >> bit) & 1: + dst_rank = r + + for start_n in range(0, src_shape_n, BLOCK): + mask_n = start_n + offs_n < src_shape_n + src = tl.load( + src_ptr + off_m_local * src_stride_m + start_n + offs_n, + mask=mask_n, + other=0.0, + ) + dst_off = dst_row * dst_stride_m + start_n + offs_n + 
class MoEDispatcher:
    """Pre-allocated MoE token dispatch/combine for expert-parallel inference.

    All communication buffers are created once in ``__init__`` and sliced per
    call.  ``expt_indx_global``, ``topk_vals`` and ``dispatch_buffer`` inside a
    ``DispatchHandle`` are slices of those buffers, so a handle must be
    consumed by ``combine()`` before the next ``dispatch()`` on the same
    dispatcher overwrites them.

    Usage::

        dispatcher = ctx.ccl.moe_dispatcher(hidden_dim=4096, num_experts=64,
                                            topk=2, max_tokens=4096)
        # In forward pass:
        recv_tokens, local_meta, handle = dispatcher.dispatch(tokens, topk_idx, topk_weight)
        expert_out = run_experts(recv_tokens, local_meta)  # grouped matmul
        combined = dispatcher.combine(expert_out, handle)
    """

    def __init__(
        self,
        ctx,
        hidden_dim: int,
        num_experts: int,
        topk: int,
        max_tokens: int,
        dtype=None,
        group=None,
        config: MoEDispatchConfig | None = None,
        expt_assignment: ExptAssignment | None = None,
    ):
        """Pre-allocate buffers for dispatch/combine.

        Args:
            ctx: iris.Iris instance.
            hidden_dim: model hidden dimension (H).
            num_experts: total number of experts across all ranks.
            topk: number of experts activated per token (k).
            max_tokens: maximum tokens per rank per call.
            dtype: data type for dispatch/combine buffers (default: bfloat16).
            group: reserved for future process group support; currently ignored.
            config: kernel tuning config.
            expt_assignment: expert-to-rank mapping. If None, uses uniform
                contiguous assignment.
        """
        self._ctx = ctx
        self._hidden_dim = hidden_dim
        self._num_experts = num_experts
        self._topk = topk
        self._max_tokens = max_tokens
        self._dtype = torch.bfloat16 if dtype is None else dtype
        self._config = config or MoEDispatchConfig()

        self._rank = ctx.get_rank()
        self._world_size = ctx.get_num_ranks()
        # NOTE(review): uses the iris rank as the CUDA device ordinal, which
        # presumably only holds for single-node jobs; confirm for multi-node.
        self._device = torch.device(f"cuda:{self._rank}")

        # Expert assignment
        if expt_assignment is not None:
            self._expt_assignment = expt_assignment
        else:
            from .moe_utils import make_expt_dict_uniform, make_expt_assignment

            expt_dict = make_expt_dict_uniform(self._world_size, num_experts)
            self._expt_assignment = make_expt_assignment(self._world_size, num_experts, expt_dict, self._device)

        # Pre-allocate buffers sized for the largest permitted call.
        max_T_global = max_tokens * self._world_size
        max_slots = max_T_global * topk

        self._dispatch_buf = ctx.zeros((max_slots, hidden_dim), dtype=self._dtype)
        self._combine_buf = ctx.zeros((max_tokens * topk, hidden_dim), dtype=self._dtype)
        self._ag_indx_buf = ctx.zeros((max_T_global, topk), dtype=torch.int32)
        self._ag_vals_buf = ctx.zeros((max_T_global, topk), dtype=torch.float32)

    def dispatch(
        self,
        tokens: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weight: torch.Tensor,
    ) -> tuple[torch.Tensor, RaggedTensorMetadata, DispatchHandle]:
        """Dispatch local tokens to expert-owning ranks.

        The global token count is derived as ``T_local * world_size``, so every
        rank is expected to call this with the same ``T_local``.

        Args:
            tokens: (T_local, H) local token activations.
            topk_idx: (T_local, k) int expert indices from gating.
            topk_weight: (T_local, k) float gating weights.

        Returns:
            (dispatch_buffer, local_ragged_meta, handle) where:
            - dispatch_buffer: (T_global * k, H) slice of the pre-allocated
              buffer holding tokens in expert-sorted order; rows not written
              by any rank are zero.
            - local_ragged_meta: RaggedTensorMetadata for this rank's experts.
            - handle: DispatchHandle to pass to combine().

        Raises:
            ValueError: if ``tokens`` is not (T_local, H), if ``T_local``
                exceeds ``max_tokens``, or if the routing tensors are not
                (T_local, k).
        """
        ctx = self._ctx
        rank = self._rank
        world_size = self._world_size
        k = self._topk
        hidden_dim = self._hidden_dim
        config = self._config

        # Validate before touching any buffer: slicing past the pre-allocated
        # extent silently truncates, while the kernel would still write
        # T_global * k rows into remote heaps.
        if tokens.ndim != 2 or tokens.shape[1] != hidden_dim:
            raise ValueError(
                f"tokens must have shape (T_local, {hidden_dim}) matching hidden_dim, got {tuple(tokens.shape)}"
            )
        n_tokens_local = tokens.shape[0]
        if n_tokens_local > self._max_tokens:
            raise ValueError(f"got {n_tokens_local} tokens but dispatcher was built with max_tokens={self._max_tokens}")
        routing_shape = (n_tokens_local, k)
        if tuple(topk_idx.shape) != routing_shape or tuple(topk_weight.shape) != routing_shape:
            raise ValueError(
                f"topk_idx and topk_weight must have shape {routing_shape}, "
                f"got {tuple(topk_idx.shape)} and {tuple(topk_weight.shape)}"
            )
        # The kernel addresses the hidden axis with unit stride.
        if tokens.stride(1) != 1:
            tokens = tokens.contiguous()

        n_tokens_global = n_tokens_local * world_size

        # Step 1: Promote routing tensors to the dtypes of the all-gather buffers.
        topk_idx_i32 = topk_idx.contiguous().to(torch.int32)
        topk_weight_f32 = topk_weight.contiguous().to(torch.float32)

        # Step 2: All-gather topk_idx and topk_weight via ctx.ccl.all_gather
        indx_global = self._ag_indx_buf[:n_tokens_global, :k]  # (T_global, k) int32
        vals_global = self._ag_vals_buf[:n_tokens_global, :k]  # (T_global, k) float32
        ctx.ccl.all_gather(indx_global, topk_idx_i32)
        ctx.ccl.all_gather(vals_global, topk_weight_f32)

        # Step 3: Build BitmatrixMetadata from global indices
        mask_metadata = _make_bitmatrix_metadata(indx_global, self._num_experts)
        dispatch_indx = mask_metadata.row_sorted_indx  # (T_global * k,)
        combine_indx = mask_metadata.col_sorted_indx  # (T_global * k,)
        expt_sizes = mask_metadata.col_sum  # (n_expts,)

        # Step 4: Build RaggedTensorMetadata
        n_active = int(expt_sizes.sum().item())
        ragged_meta_global = make_ragged_tensor_metadata(expt_sizes, n_active)

        # Step 5: Zero dispatch buffer, barrier
        n_total_slots = n_tokens_global * k
        dispatch_buf = self._dispatch_buf[:n_total_slots, :hidden_dim]
        dispatch_buf.zero_()
        ctx.barrier()

        # Step 6: Launch _convert_dp_to_ep kernel (one program per local token)
        BLOCK = min(triton.next_power_of_2(hidden_dim), config.dispatch_block_size)
        grid = (n_tokens_local,)

        expt_bitmask = self._expt_assignment.expt_bitmask

        _convert_dp_to_ep[grid](
            dispatch_buf,
            dispatch_buf.stride(0),
            tokens,
            tokens.stride(0),
            tokens.shape[1],
            expt_bitmask,
            expt_bitmask.stride(0),
            indx_global,
            indx_global.stride(0),
            dispatch_indx,
            k,  # dispatch_indx is flat, so consecutive tokens are k entries apart
            n_tokens_local,
            ctx.get_heap_bases(),
            SRC_RANK=rank,
            N_EXPT_ACT=k,
            N_RANKS=world_size,
            BLOCK=BLOCK,
        )

        # Step 7: Barrier (all stores must complete before reads)
        ctx.barrier()

        # Step 8: Remap ragged metadata to local expert view
        expt_map = self._expt_assignment.expt_map[rank, :].contiguous()
        local_ragged_meta = remap_ragged_tensor_metadata(ragged_meta_global, expt_map)

        # Build handle
        handle = DispatchHandle(
            expt_assignment=self._expt_assignment,
            expt_indx_global=indx_global,
            dispatch_indx=dispatch_indx,
            combine_indx=combine_indx,
            topk_vals=vals_global,
            ragged_meta_global=ragged_meta_global,
            expt_sizes=expt_sizes,
            dispatch_buffer=dispatch_buf,
            n_tokens_local=n_tokens_local,
            n_tokens_global=n_tokens_global,
            hidden_dim=hidden_dim,
            topk=k,
        )

        return dispatch_buf, local_ragged_meta, handle

    def combine(
        self,
        expert_output: torch.Tensor,
        handle: DispatchHandle,
    ) -> torch.Tensor:
        """Combine expert results back to token-owning ranks.

        The k expert slots of each token are summed with a validity mask; the
        gating weights carried in ``handle.topk_vals`` are not applied here.

        Args:
            expert_output: (n_total_slots, H) expert-sorted matmul output.
                These are the results after the grouped expert computation.
            handle: DispatchHandle from dispatch().

        Returns:
            combined: (T_local, H) combined output for this rank's tokens.

        Raises:
            ValueError: if ``handle`` does not match this dispatcher's
                ``hidden_dim``/``topk``/``max_tokens`` or ``expert_output`` is
                not a 2-D tensor with ``H`` columns.
        """
        ctx = self._ctx
        rank = self._rank
        world_size = self._world_size
        config = self._config

        n_tokens_local = handle.n_tokens_local
        n_tokens_global = handle.n_tokens_global
        hidden_dim = handle.hidden_dim
        k = handle.topk

        # A handle produced with different sizes would index the combine
        # buffer with the wrong row stride or past its end.
        if hidden_dim != self._hidden_dim or k != self._topk or n_tokens_local > self._max_tokens:
            raise ValueError(
                "handle does not match this dispatcher "
                f"(hidden_dim={hidden_dim}, topk={k}, n_tokens_local={n_tokens_local})"
            )
        if expert_output.ndim != 2 or expert_output.shape[1] != hidden_dim:
            raise ValueError(
                f"expert_output must have shape (n_total_slots, {hidden_dim}), got {tuple(expert_output.shape)}"
            )
        # The kernel addresses the hidden axis with unit stride.
        if expert_output.stride(1) != 1:
            expert_output = expert_output.contiguous()

        expt_bitmask = handle.expt_assignment.expt_bitmask
        flat_expt_indx = handle.expt_indx_global.reshape(-1)
        combine_indx = handle.combine_indx

        # Step 1: Zero combine buffer, barrier
        n_local_slots = n_tokens_local * k
        combine_buf = self._combine_buf[:n_local_slots, :hidden_dim]
        combine_buf.zero_()
        ctx.barrier()

        # Step 2: Launch _convert_ep_to_dp kernel
        # n_slots_per_rank for the combine kernel: n_tokens_local * k
        # because the flat dispatch_indx has n_tokens_global * k entries,
        # and each rank's portion is n_tokens_local * k
        n_slots_per_rank = n_local_slots
        n_total_slots = n_tokens_global * k

        BLOCK = min(triton.next_power_of_2(hidden_dim), config.combine_block_size)
        grid = (n_total_slots,)

        _convert_ep_to_dp[grid](
            combine_buf,
            combine_buf.stride(0),
            expert_output,
            expert_output.stride(0),
            expert_output.shape[1],
            expt_bitmask,
            expt_bitmask.stride(0),
            flat_expt_indx,
            combine_indx,
            n_slots_per_rank,
            ctx.get_heap_bases(),
            BLOCK=BLOCK,
            SRC_RANK=rank,
            N_RANKS=world_size,
        )

        # Step 3: Barrier
        ctx.barrier()

        # Step 4: Reshape combine buffer to (T_local, k, H)
        combine_3d = combine_buf.view(n_tokens_local, k, hidden_dim)

        # Step 5: Masked reduce over dim=1
        y_mask = (handle.dispatch_indx != -1).view(n_tokens_global, k, 1)
        local_mask = y_mask[rank * n_tokens_local : (rank + 1) * n_tokens_local]
        local_mask = local_mask.expand_as(combine_3d).contiguous()
        combined, _ = reduce(combine_3d, dim=1, mask=local_mask)

        return combined
SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +MoE routing utilities for expert-parallel dispatch/combine. + +Promoted from examples/31_expert_sharded_moe/ into iris.ccl for production use. +Provides expert assignment, ragged tensor metadata, top-k routing, and reduce. + +Ported from triton_kernels: + - distributed.py (ExptAssignment) + - tensor_details/ragged_tensor.py (RaggedTensorMetadata) + - topk.py / bitmatrix.py (TopkResult, BitmatrixMetadata) + - reduce.py (masked sum-reduce) +""" + +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + + +# --------------------------------------------------------------------------- +# Expert Assignment +# --------------------------------------------------------------------------- + + +@dataclass +class ExptAssignment: + """Expert-to-rank assignment for expert-parallel MoE. + + Attributes: + expt_bitmask: (n_shards, ceil(n_expts_tot / 32)) packed int32 bitmask. + (expt_bitmask[i, j//32] >> j%32) & 1 == 1 iff expert j is owned by shard i. + expt_boolmask: (n_shards, n_expts_tot) boolean mask. + expt_map: (n_shards, n_expts_tot) local expert id or -1. + n_expts_per_shard: list of expert counts per shard. 
def make_expt_assignment(
    n_shards: int,
    n_expts_tot: int,
    expt_dict: dict[int, list[int]],
    device,
) -> ExptAssignment:
    """Build bitmask, boolmask, and local-id map from an expert ownership dict.

    Args:
        n_shards: number of shards (ranks) experts are distributed over.
        n_expts_tot: total number of experts.
        expt_dict: maps shard id to the list of global expert ids it owns.
        device: device the resulting tensors are moved to.

    Returns:
        ExptAssignment whose tensors live on ``device``.  Local expert ids
        follow the ascending order of each shard's global ids.

    Raises:
        ValueError: if a shard id or expert id is out of range, a shard owns
            no experts (empty list or absent from ``expt_dict``), or some
            expert is not owned by exactly one shard.
    """
    # Validate the whole mapping before building any tensor.  The range check
    # must come before the ownership count, otherwise an out-of-range id fails
    # on the counter lookup instead of producing the intended ValueError.
    owner_count = [0] * n_expts_tot
    for shard, experts in expt_dict.items():
        if not (0 <= shard < n_shards):
            raise ValueError(f"shard {shard} out of range [0, {n_shards})")
        if len(experts) == 0:
            raise ValueError(f"shard {shard} has no experts")
        for e in experts:
            if not (0 <= e < n_expts_tot):
                raise ValueError(f"expert id {e} out of range [0, {n_expts_tot})")
            owner_count[e] += 1

    for shard in range(n_shards):
        if shard not in expt_dict:
            raise ValueError(f"shard {shard} has no experts")

    if any(c != 1 for c in owner_count):
        raise ValueError("each expert must be owned by exactly one shard")

    words = (n_expts_tot + 31) // 32
    expt_bitmask = torch.zeros((n_shards, words), dtype=torch.int32)
    expt_boolmask = torch.zeros((n_shards, n_expts_tot), dtype=torch.bool)
    expt_map = torch.full((n_shards, n_expts_tot), -1, dtype=torch.int32)

    for shard, experts in expt_dict.items():
        for local_id, global_id in enumerate(sorted(experts)):
            # Expert j occupies bit (j % 32) of word (j // 32).
            expt_bitmask[shard, global_id >> 5] |= 1 << (global_id & 31)
            expt_boolmask[shard, global_id] = True
            expt_map[shard, global_id] = local_id

    n_expts_per_shard = [len(expt_dict[s]) for s in range(n_shards)]
    return ExptAssignment(
        expt_bitmask.to(device),
        expt_boolmask.to(device),
        expt_map.to(device),
        n_expts_per_shard,
    )
+ """ + valid = expt_map != -1 + local_ids = expt_map[valid] + n_local = int(local_ids.max().item()) + 1 if local_ids.numel() > 0 else 0 + device = metadata.slice_sizes.device + local_sizes = torch.zeros(n_local, dtype=torch.int32, device=device) + local_offs = torch.zeros(n_local + 1, dtype=torch.int32, device=device) + for g in range(expt_map.shape[0]): + lid = expt_map[g].item() + if lid >= 0: + local_sizes[lid] = metadata.slice_sizes[g] + local_offs[lid] = metadata.slice_offs[g] + if n_local > 0: + local_offs[n_local] = local_offs[n_local - 1] + local_sizes[n_local - 1] + return RaggedTensorMetadata(local_sizes, local_offs) + + +# --------------------------------------------------------------------------- +# Top-k Routing / Bitmatrix Metadata +# --------------------------------------------------------------------------- + + +@dataclass +class BitmatrixMetadata: + """Routing indices derived from the top-k selection. + + col_sum: (n_expts,) histogram: tokens per expert + row_sorted_indx: (n_tokens * k,) flat token-expert slots grouped by expert (dispatch order) + col_sorted_indx: (n_tokens * k,) inverse permutation (combine order) + """ + + col_sum: torch.Tensor + row_sorted_indx: torch.Tensor + col_sorted_indx: torch.Tensor + + +@dataclass +class TopkResult: + vals: torch.Tensor # (n_tokens, k) softmax gating weights + indx: torch.Tensor # (n_tokens, k) expert indices (int16) + mask_metadata: BitmatrixMetadata + + +def _make_bitmatrix_metadata(indx: torch.Tensor, n_expts: int) -> BitmatrixMetadata: + """Build dispatch/combine indices from the (n_tokens, k) expert-index tensor. + + Follows triton_kernels/tensor_details/bitmatrix.py (optimised convention): + col_sorted_indx[expert_sorted_pos] = original flat index + row_sorted_indx[original_flat_idx] = expert_sorted_pos + + Handles -1 (invalid) entries correctly. 
+ """ + device = indx.device + flat_indx = indx.reshape(-1).to(torch.int32) + n_elements = flat_indx.numel() + + valid = flat_indx >= 0 + n_valid = valid.sum().item() + + col_sum = torch.histc( + flat_indx[valid].float(), + bins=n_expts, + min=0, + max=n_expts - 1, + ).to(torch.int32) + + col_sorted_indx = torch.full((n_elements,), -1, dtype=torch.int32, device=device) + row_sorted_indx = torch.full((n_elements,), -1, dtype=torch.int32, device=device) + + sort_keys = flat_indx.clone().long() + sort_keys[~valid] = n_expts + sorted_order = torch.argsort(sort_keys, stable=True).to(torch.int32) + + col_sorted_indx[:n_valid] = sorted_order[:n_valid] + expert_positions = torch.arange(n_valid, device=device, dtype=torch.int32) + row_sorted_indx.scatter_(0, sorted_order[:n_valid].long(), expert_positions) + + return BitmatrixMetadata( + col_sum=col_sum, + col_sorted_indx=col_sorted_indx, + row_sorted_indx=row_sorted_indx, + ) + + +def topk( + x: torch.Tensor, + k: int, + apply_softmax: bool = True, +) -> TopkResult: + """Compute top-k routing over expert logits. + + Uses PyTorch ops (matches upstream topk_torch reference). + + Args: + x: (n_tokens, n_expts) float32 logit tensor. + k: number of experts to activate per token. + apply_softmax: whether to softmax the selected values. + + Returns: + TopkResult with vals, indx, and mask_metadata. 
def reduce(
    y: torch.Tensor,
    dim: int = 1,
    mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, None]:
    """Sum-reduce over *dim* with optional boolean mask.

    Matches the upstream ``reduce(y, dim=1, mask=mask)`` signature.

    Args:
        y: (n_tokens, k, d_model) expert outputs.
        dim: reduction dimension (must be 1).
        mask: bool/int mask broadcastable to ``y``'s shape; zero = skip.

    Returns:
        (z, None) where z has shape (n_tokens, d_model).
    """
    assert dim == 1 and y.ndim == 3
    n_tokens, k, d_model = y.shape

    z = torch.zeros((n_tokens, d_model), dtype=y.dtype, device=y.device)
    if n_tokens == 0 or d_model == 0:
        # Nothing to reduce; avoid launching a kernel over an empty grid.
        return z, None

    if mask is not None:
        # The kernel reads the mask as a dense row-major (n_tokens, k, d_model)
        # array without using its strides, so broadcast or strided masks must
        # be materialised first.
        mask = mask.expand_as(y).contiguous()

    BLOCK_D = min(triton.next_power_of_2(d_model), 512)
    grid = (n_tokens, triton.cdiv(d_model, BLOCK_D))

    _reduce_kernel[grid](
        y,
        y.stride(0),
        y.stride(1),
        y.stride(2),
        z,
        z.stride(0),
        z.stride(1),
        # With HAS_MASK false the kernel never loads the mask; y is only a
        # placeholder pointer.
        mask if mask is not None else y,
        n_tokens,
        d_model,
        N_EXPTS_ACT=k,
        BLOCK_D=BLOCK_D,
        HAS_MASK=(mask is not None),
    )
    return z, None
topk=2, max_tokens=4096) + >>> recv_tokens, local_meta, handle = dispatcher.dispatch(tokens, topk_idx, topk_weight) + >>> expert_out = run_experts(recv_tokens, local_meta) + >>> combined = dispatcher.combine(expert_out, handle) + """ + from iris.ccl.moe_dispatch import MoEDispatcher + + return MoEDispatcher( + self._iris, + hidden_dim, + num_experts, + topk, + max_tokens, + dtype=dtype, + group=group, + config=config, + expt_assignment=expt_assignment, + ) + @triton.jit def __translate(ptr, from_rank, to_rank, heap_bases, hint: tl.constexpr = None): diff --git a/tests/ccl/test_moe_dispatch.py b/tests/ccl/test_moe_dispatch.py new file mode 100644 index 000000000..fb6e99935 --- /dev/null +++ b/tests/ccl/test_moe_dispatch.py @@ -0,0 +1,469 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Test suite for MoEDispatcher (dispatch/combine via iris symmetric heap). + +Uses ``mixture_of_expt_nosharded`` from the MoE example as the ground-truth +reference. Tests cover end-to-end correctness, dispatch-only, combine-only, +buffer reuse across varying batch sizes, topk=1 routing, and handle immutability. 
+ +Run with: + torchrun --nproc_per_node=8 -m pytest tests/ccl/test_moe_dispatch.py -v +""" + +import gc +import importlib.util +import sys +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist + +import iris +from iris.ccl import MoEDispatcher +from iris.ccl.moe_utils import ( + make_expt_dict_uniform, + make_expt_assignment, + topk, +) + +# --------------------------------------------------------------------------- +# Load grouped_matmul from the example directory (not promoted to ccl yet) +# --------------------------------------------------------------------------- + +PROJECT_ROOT = Path(__file__).resolve() +while not (PROJECT_ROOT / "tests").is_dir() or not (PROJECT_ROOT / "examples").is_dir(): + if PROJECT_ROOT == PROJECT_ROOT.parent: + raise FileNotFoundError("Could not find project root") + PROJECT_ROOT = PROJECT_ROOT.parent + +EXAMPLE_DIR = PROJECT_ROOT / "examples" / "31_expert_sharded_moe" +sys.path.insert(0, str(EXAMPLE_DIR)) + + +def _load_module(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +GROUPED_MATMUL_MOD = _load_module("grouped_matmul_test", EXAMPLE_DIR / "grouped_matmul.py") +grouped_matmul = GROUPED_MATMUL_MOD.grouped_matmul + +MOE_MOD = _load_module("moe_test", EXAMPLE_DIR / "moe.py") +mixture_of_expt_nosharded = MOE_MOD.mixture_of_expt_nosharded + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _setup(heap_size=2**33): + """Common setup: check dist, create iris context, return (ctx, rank, world_size, device).""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + ctx = iris.iris(heap_size) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + device = 
torch.device(f"cuda:{rank}") + return ctx, rank, world_size, device + + +def _make_global_data(n_tokens, d_model, n_expts_tot, dtype, device): + """Generate shared global data (same seed on all ranks, broadcast).""" + torch.manual_seed(0) + x_global = torch.randn(n_tokens, d_model, device=device, dtype=dtype) + l_global = torch.rand(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + w_global = torch.randn(n_expts_tot, d_model, d_model, device=device, dtype=dtype) + b_global = torch.randn(n_expts_tot, d_model, device=device, dtype=torch.float32) + dist.broadcast(x_global, src=0) + dist.broadcast(l_global, src=0) + dist.broadcast(w_global, src=0) + dist.broadcast(b_global, src=0) + return x_global, l_global, w_global, b_global + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("n_tokens_local", [32, 128]) +@pytest.mark.parametrize("d_model", [64, 256]) +@pytest.mark.parametrize("n_expts_act", [1, 2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_dispatch_combine_e2e(n_tokens_local, d_model, n_expts_act, dtype): + """Full dispatch + expert matmul + combine pipeline matches single-device reference.""" + ctx = None + try: + ctx, rank, world_size, device = _setup() + n_expts_tot = world_size * 2 + n_tokens = n_tokens_local * world_size + + x_global, l_global, w_global, b_global = _make_global_data(n_tokens, d_model, n_expts_tot, dtype, device) + + # Reference: single-device MoE + y_ref = mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act) + + # Expert assignment + expt_dict = make_expt_dict_uniform(world_size, n_expts_tot) + expt_assignment = make_expt_assignment(world_size, n_expts_tot, expt_dict, device) + + # Local slices + first = rank * n_tokens_local + last = first + n_tokens_local + x_local = x_global[first:last].contiguous() + l_local = 
@pytest.mark.parametrize("n_tokens_local", [64])
@pytest.mark.parametrize("d_model", [128])
def test_dispatch_only(n_tokens_local, d_model):
    """Dispatch alone returns a buffer and local metadata of the expected size.

    Checks the dispatch buffer's two dimensions, that the local metadata has
    one slice per expert assigned to this rank, and that the buffer is not
    all zeros.
    """
    ctx = None
    try:
        ctx, rank, world_size, device = _setup()
        dtype = torch.bfloat16
        n_expts_act = 2
        n_expts_tot = world_size * 2
        n_tokens = n_tokens_local * world_size

        x_global, l_global, _, _ = _make_global_data(n_tokens, d_model, n_expts_tot, dtype, device)

        expt_assignment = make_expt_assignment(
            world_size,
            n_expts_tot,
            make_expt_dict_uniform(world_size, n_expts_tot),
            device,
        )

        # This rank's contiguous share of the global token rows.
        rows = slice(rank * n_tokens_local, (rank + 1) * n_tokens_local)
        x_local = x_global[rows].contiguous()
        l_local = l_global[rows].contiguous()

        routing = topk(l_local, n_expts_act, apply_softmax=True)

        dispatcher = MoEDispatcher(
            ctx,
            d_model,
            n_expts_tot,
            n_expts_act,
            n_tokens_local,
            dtype=dtype,
            expt_assignment=expt_assignment,
        )

        ctx.barrier()

        dispatch_buf, local_meta, _handle = dispatcher.dispatch(x_local, routing.indx, routing.vals)

        # Buffer shape: one row per (global token, active expert) pair.
        assert dispatch_buf.shape[1] == d_model
        assert dispatch_buf.shape[0] == n_tokens * n_expts_act

        # The local metadata covers exactly the experts assigned to this rank.
        expected_slices = expt_assignment.n_expts_per_shard[rank]
        assert local_meta.n_slices == expected_slices

        # At least some token data must have been written.
        assert dispatch_buf.abs().sum() > 0
    finally:
        if ctx is not None:
            # Best-effort sync; a failing barrier must not mask the test result.
            try:
                ctx.barrier()
            except Exception:
                pass
        del ctx
        gc.collect()
def test_buffer_reuse():
    """One dispatcher built for 128 tokens handles batches of 32 and then 64 tokens per rank.

    For each batch size the gathered dispatch -> grouped matmul -> combine
    output is compared against ``mixture_of_expt_nosharded``.
    """
    ctx = None
    try:
        ctx, rank, world_size, device = _setup()
        dtype = torch.bfloat16
        d_model, max_tokens, n_expts_act = 64, 128, 2
        n_expts_tot = world_size * 2

        expt_assignment = make_expt_assignment(
            world_size,
            n_expts_tot,
            make_expt_dict_uniform(world_size, n_expts_tot),
            device,
        )

        # Built once, outside the loop, so both batches share its buffers.
        dispatcher = MoEDispatcher(
            ctx,
            d_model,
            n_expts_tot,
            n_expts_act,
            max_tokens,
            dtype=dtype,
            expt_assignment=expt_assignment,
        )

        for n_tokens_local in (32, 64):
            n_tokens = n_tokens_local * world_size

            x_global, l_global, w_global, b_global = _make_global_data(n_tokens, d_model, n_expts_tot, dtype, device)
            y_ref = mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act)

            rows = slice(rank * n_tokens_local, (rank + 1) * n_tokens_local)
            owned = expt_assignment.expt_boolmask[rank]
            x_local = x_global[rows].contiguous()
            l_local = l_global[rows].contiguous()
            w_local = w_global[owned].contiguous()
            b_local = b_global[owned].contiguous()

            routing = topk(l_local, n_expts_act, apply_softmax=True)
            ctx.barrier()

            dispatch_buf, local_meta, handle = dispatcher.dispatch(x_local, routing.indx, routing.vals)
            expert_out = grouped_matmul(dispatch_buf, w_local, b_local, local_meta)
            z_local = dispatcher.combine(expert_out, handle)

            # Reassemble the global output from every rank's local rows.
            z_gathered = torch.empty_like(y_ref)
            dist.all_gather_into_tensor(z_gathered, z_local.contiguous())

            torch.testing.assert_close(y_ref, z_gathered, atol=1e-2, rtol=1e-2)
    finally:
        if ctx is not None:
            # Best-effort sync; a failing barrier must not mask the test result.
            try:
                ctx.barrier()
            except Exception:
                pass
        del ctx
        gc.collect()
def test_handle_frozen():
    """Assigning to an attribute of the handle returned by dispatch() raises AttributeError."""
    ctx = None
    try:
        # The rank is not needed here: every rank dispatches the same tokens.
        ctx, _rank, world_size, device = _setup()
        d_model = 64
        n_expts_tot = world_size * 2
        n_expts_act = 2
        n_tokens_local = 32
        dtype = torch.bfloat16

        torch.manual_seed(0)
        x_local = torch.randn(n_tokens_local, d_model, device=device, dtype=dtype)
        l_local = torch.rand(n_tokens_local, n_expts_tot, device=device, dtype=torch.float32)
        dist.broadcast(x_local, src=0)
        dist.broadcast(l_local, src=0)

        expt_dict = make_expt_dict_uniform(world_size, n_expts_tot)
        expt_assignment = make_expt_assignment(world_size, n_expts_tot, expt_dict, device)

        topk_result = topk(l_local, n_expts_act, apply_softmax=True)

        dispatcher = MoEDispatcher(
            ctx,
            d_model,
            n_expts_tot,
            n_expts_act,
            n_tokens_local,
            dtype=dtype,
            expt_assignment=expt_assignment,
        )

        ctx.barrier()

        _, _, handle = dispatcher.dispatch(
            x_local,
            topk_result.indx,
            topk_result.vals,
        )

        # Attribute assignment on the handle must be rejected.
        with pytest.raises(AttributeError):
            handle.n_tokens_local = 999
    finally:
        if ctx is not None:
            # Best-effort sync; a failing barrier must not mask the test result.
            try:
                ctx.barrier()
            except Exception:
                pass
        del ctx
        gc.collect()