From 0741c51bc5d8dc1043e5e5e4b6a2d83df6319dfb Mon Sep 17 00:00:00 2001 From: Kyle Wang Date: Mon, 12 Jan 2026 01:26:31 -0600 Subject: [PATCH 1/7] add GEMM+ReduceScatter w/ workgroup specialization --- .../benchmark.py | 320 ++++++++++++++++++ .../gemm_reduce_scatter.py | 213 ++++++++++++ .../matmul_wrapper.py | 166 +++++++++ examples/common/validation.py | 63 ++++ 4 files changed, 762 insertions(+) create mode 100644 examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py create mode 100644 examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py create mode 100644 examples/21_gemm_reduce_scatter_wg_specialization/matmul_wrapper.py diff --git a/examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py b/examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py new file mode 100644 index 000000000..85877bc78 --- /dev/null +++ b/examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# GEMM + ReduceScatter Benchmark with Workgroup Specialization +# Reference: ByteDance Triton-distributed +# https://github.com/ByteDance-Seed/Triton-distributed/blob/main/tutorials/10-AMD-overlapping-gemm-reduce-scatter.py + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import random +import argparse +import math + +from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set +from examples.common.validation import validate_reduce_scatter + +import iris +from matmul_wrapper import matmul_rs + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="GEMM + ReduceScatter Benchmark with Workgroup Specialization", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=8192, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=4096, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=12288, help="Common dimension (K), will be split across ranks") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of computation", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + parser.add_argument("--BLK_M", type=int, default=128, help="Block size M") + parser.add_argument("--BLK_N", type=int, default=256, help="Block size N") + parser.add_argument("--BLK_K", type=int, default=32, help="Block size K") + parser.add_argument("--gsize_m", type=int, default=1, help="L2-cache locality swizzle parameter") + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument( + "--num_sms", + type=int, + default=None, + help="Number of total SMs (default: auto-detected)", + ) + parser.add_argument( + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for GEMM (default: auto-detected as power of 2)", + ) + parser.add_argument("--num_stages", type=int, default=2, help="Number of stages") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Set default SM values if not provided + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + if args["num_sms"] is None: + args["num_sms"] = cu_count + if args["gemm_sms"] is None: + # Use next smaller power of 2 for GEMM SMs + args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 + + # Datatype + datatype = torch.float16 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M, N, K = args["m"], args["n"], args["k"] + + assert M % world_size == 0, f"M ({M}) must be divisible by world size ({world_size})" + assert K % world_size == 0, f"K ({K}) must be divisible by world size ({world_size})" + assert (M // world_size) % args["BLK_M"] == 0, f"M_per_rank ({M // world_size}) must be divisible by BLK_M ({args['BLK_M']})" + + local_K = K // world_size + M_per_rank = M // world_size + + # Generate full matrices for reference calculation + A_full = shmem.randn(M, K, device="cuda", dtype=datatype) + B_full = shmem.randn(K, N, device="cuda", dtype=datatype) + + # Each rank gets a portion of K dimension + local_A = A_full[:, rank * local_K : (rank + 1) * local_K].clone() + local_B = B_full[rank * local_K : (rank + 1) * local_K, :].clone() + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("M", M) + json_writer.add_field("N", N) + json_writer.add_field("K", K) + json_writer.add_field("local_K", local_K) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Local buffer for GEMM result [M, N] + local_buf = shmem.zeros((M, N), device="cuda", dtype=datatype) + + # Global output buffer for ReduceScatter result [M_per_rank, N] + # This is where each rank accumulates its final result + output_buf = shmem.zeros((M_per_rank, N), device="cuda", dtype=datatype) + + total_blocks_M = triton.cdiv(M, args["BLK_M"]) + total_blocks_N = triton.cdiv(N, args["BLK_N"]) + total_tiles = total_blocks_M * total_blocks_N + + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) + + gemm_stream = torch.cuda.Stream() + + json_writer.add_field("num_sms", args["num_sms"]) + json_writer.add_field("gemm_sms", args["gemm_sms"]) + + kernel_timing = { + "gemm_rs": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + # Allocate Timestamps + timestamps = Timestamps(num_tiles=total_tiles) + + def run_experiment(): + nonlocal local_buf, output_buf + + # Reset buffers + local_buf.zero_() + output_buf.zero_() + locks.zero_() + shmem.barrier() + + if args["trace_tiles"]: + timestamps.reset() + shmem.barrier() + + torch.cuda.nvtx.range_push("GEMM + ReduceScatter") + with torch.cuda.stream(gemm_stream): + kernel_timing["gemm_rs"]["start_event"].record() + matmul_rs.apply( + local_A, + local_B, + local_buf, + output_buf, + locks, + rank, + world_size, + args["gemm_sms"], + args["num_sms"], + args["BLK_M"], + args["BLK_N"], + args["BLK_K"], + args["gsize_m"], + args["num_stages"], + shmem.get_heap_bases(), + "gfx942", + args["trace_tiles"], + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + kernel_timing["gemm_rs"]["end_event"].record() + kernel_timing["gemm_rs"]["experiments"] += 1 + + torch.cuda.nvtx.range_pop() + # Ensure kernel completion before barrier + gemm_stream.synchronize() + shmem.barrier() + + for k in ["gemm_rs"]: + ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"]) + kernel_timing[k]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + # Warmup + run_experiment() + shmem.barrier() + + for k in ["gemm_rs"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + if args["validate"]: + shmem.info("Validating...") + matmul_rs.set_debug(True) + + # Run one more time for validation + run_experiment() + # Additional barrier to ensure all remote writes are complete + torch.cuda.synchronize() + shmem.barrier() + + # Get the GEMM result (input to reduce_scatter) and final output + local_gemm = local_buf.clone() + local_output = output_buf.clone() + + # Create process group for validation + tp_group = dist.new_group(ranks=list(range(world_size))) + + # For fp16 with atomic_add across multiple ranks, allow larger tolerance + # The 0.5 max_diff comes from accumulated rounding errors in atomic operations + # Relative error is ~0.08% which is acceptable for distributed computation + atol = 1.0 if datatype == torch.float16 else 0.5 + + # Validate reduce_scatter using the common validation function + success = validate_reduce_scatter(local_gemm, local_output, shmem, tp_group, atol=atol) + + if success: + shmem.info("✅ Triton and Torch match") + else: + shmem.info("❌ Triton and Torch differ") + + json_writer.add_field("success", success) + + if not is_triton_interpret_set(): + gemm_registers = matmul_rs.get_matmul_registers() + gemm_spills = matmul_rs.get_matmul_spills() + json_writer.add_field("gemm_registers", gemm_registers) + json_writer.add_field("gemm_spills", gemm_spills) + + shmem.barrier() + shmem.info("Validation completed") + + if args["benchmark"]: + matmul_rs.set_debug(False) + shmem.info("Benchmarking...") + + # Performance calculation: + # Each rank computes [M, N] partial result from [M, local_K] x [local_K, N] + # FLOPs = 2 * M * N * local_K + perf = lambda ms: 2 * M * N * local_K * 1e-12 / (ms * 1e-3) + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + triton_tflops = perf(triton_ms) + + shmem.info( + f"GEMM + ReduceScatter (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops" + ) + + json_writer.add_field("tflops", triton_tflops) + json_writer.add_field("total_ms", triton_ms) + + for k in ["gemm_rs"]: + json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) + json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"]) + + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + if args["trace_tiles"] and rank == 0: + gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3 + filename = f"gemm_tiles_reduce_scatter_trace_rank{rank}.json" + timestamps.to_json(filename, gpu_freq) + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py b/examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py new file mode 100644 index 000000000..455941011 --- /dev/null +++ b/examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# GEMM + ReduceScatter implementation using iris +# Reference: ByteDance Triton-distributed tutorial +# https://github.com/ByteDance-Seed/Triton-distributed/blob/main/tutorials/10-AMD-overlapping-gemm-reduce-scatter.py + +import triton +import triton.language as tl +from examples.common.utils import read_realtime + +import iris + + +@triton.jit() +def persistent_gemm_reduce_scatter_wg_specialized( + A, + B, + C, # local buffer [M, N] + C_global, # global output buffer [M, N] on each rank + locks, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_cg_m, + stride_cg_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + GEMM_SMS: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + EVEN_K: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + COLLECT_TIMESTAMPS: tl.constexpr = False, + mm_begin_timestamp_ptr: tl.tensor = None, + mm_end_timestamp_ptr: tl.tensor = None, +): + """ + GEMM + ReduceScatter with Workgroup Specialization + + Split SMs into two groups: + - GEMM SMs: Perform matrix multiplication computation + - Communication SMs: Handle data communication (scatter to target ranks) + + This approach enables overlapping computation and communication. + + Data partitioning (ReduceScatter): + - A: [M, local_K] - Each rank has a portion of K dimension + - B: [local_K, N] - Each rank has a portion of K dimension + - Each rank computes partial C = A @ B of shape [M, N] + - ReduceScatter: Split C along M dimension into world_size chunks, + send chunk i to rank i, accumulate with atomic_add + - Output: Each rank ends up with [M/world_size, N] + """ + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + M_per_rank = M // world_size + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + + # Workgroup specialization: GEMM path + if pid < GEMM_SMS: + for tile_id in range(pid, total_tiles, GEMM_SMS): + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) + + # Standard tile mapping with GROUP_SIZE_M for L2 cache locality + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + rk = tl.arange(0, BLOCK_SIZE_K) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + # GEMM computation + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if not EVEN_K: + loop_k -= 1 + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, loop_k): + a = tl.load(tl.multiple_of(A_BASE, (1, 16))) + b = tl.load(tl.multiple_of(B_BASE, (16, 1))) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + if not EVEN_K: + k = loop_k + rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + A_BASE = tl.multiple_of(A_BASE, (1, 16)) + B_BASE = tl.multiple_of(B_BASE, (16, 1)) + a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0) + b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0) + acc += tl.dot(a, b) + + c = acc.to(C.type.element_ty) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + sub_mask = (rm[:, None] < M) & (rn[None, :] < N) + + # Store to local buffer with write-through cache + local_offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn + + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) + + tl.store(C + local_offset, c, mask=sub_mask, cache_modifier=".wt") + tl.debug_barrier() + tl.store(locks + tile_id, 1, cache_modifier=".wt") + + else: # Communication path: pid >= GEMM_SMS + COMM_SMS = NUM_SMS - GEMM_SMS + comm_pid = pid - GEMM_SMS + + for tile_id in range(comm_pid, total_tiles, COMM_SMS): + # Calculate tile position (same mapping as GEMM path) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + sub_mask = (rm[:, None] < M) & (rn[None, :] < N) + + local_offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn + + # Wait for GEMM to finish this tile + while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1: + pass + + # Load computed data from local buffer + c = tl.load(C + local_offset, mask=sub_mask) + + # Determine target rank based on M position + # ReduceScatter: chunk i of M dimension goes to rank i + tile_m_start = pid_m * BLOCK_SIZE_M + target_rank = tile_m_start // M_per_rank + + # Calculate offset within target rank's output region + # target_m is the row offset within C_global[M_per_rank, N] + target_m = tile_m_start % M_per_rank + offs_cm = target_m + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + global_offset = offs_cm[:, None] * stride_cg_m + offs_cn[None, :] * stride_cg_n + + # Mask for valid elements within the target region + global_mask = (offs_cm[:, None] < M_per_rank) & (offs_cn[None, :] < N) + + # Send to target rank using atomic add + if target_rank == cur_rank: + # Local atomic add + tl.atomic_add(C_global + global_offset, c, mask=global_mask) + else: + # Remote atomic add using iris + iris.atomic_add( + C_global + global_offset, + c, + cur_rank, + target_rank, + heap_bases, + mask=global_mask, + ) diff --git a/examples/21_gemm_reduce_scatter_wg_specialization/matmul_wrapper.py b/examples/21_gemm_reduce_scatter_wg_specialization/matmul_wrapper.py new file mode 100644 index 000000000..ac6fac259 --- /dev/null +++ b/examples/21_gemm_reduce_scatter_wg_specialization/matmul_wrapper.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +from gemm_reduce_scatter import persistent_gemm_reduce_scatter_wg_specialized +from examples.common.utils import is_triton_interpret_set +import iris + +gemm_kernel = persistent_gemm_reduce_scatter_wg_specialized + + +class matmul_rs(torch.autograd.Function): + """ + GEMM + ReduceScatter matmul wrapper with Workgroup Specialization + + Split SMs into two groups: + - GEMM SMs: Perform matrix multiplication computation + - Communication SMs: Handle data communication + """ + _debug = False + _registers = None + _spills = None + _num_xcds = iris.hip.get_num_xcc() + + @staticmethod + def set_debug(debug: bool): + matmul_rs._debug = debug + + @staticmethod + def get_matmul_registers(): + if matmul_rs._debug: + return matmul_rs._registers + else: + raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") + + @staticmethod + def get_matmul_spills(): + if matmul_rs._debug: + return matmul_rs._spills + else: + raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") + + @staticmethod + def _call( + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, # local buffer [M, N] + c_global: torch.Tensor, # global output [M_per_rank, N] + locks: torch.Tensor, + rank: int, + world_size: int, + gemm_sms: int, + num_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + num_stages: int, + heap_bases_ptr: torch.Tensor, + arch: str = "gfx942", + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + num_xcds = matmul_rs._num_xcds + num_warps = 8 + waves_per_eu = 0 + mfma = 16 + kpack = 1 + even_k = K % BLK_K == 0 + + grid = (num_sms,) + + kk = gemm_kernel[grid]( + a, + b, + c, + c_global, + locks, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + c_global.stride(0), + c_global.stride(1), + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=gsize_m, + GEMM_SMS=gemm_sms, + NUM_SMS=num_sms, + NUM_XCDS=num_xcds, + EVEN_K=even_k, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfma, + kpack=kpack, + heap_bases=heap_bases_ptr, + cur_rank=rank, + world_size=world_size, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp_ptr=mm_begin_timestamp, + mm_end_timestamp_ptr=mm_end_timestamp, + ) + + if matmul_rs._debug and not is_triton_interpret_set(): + matmul_rs._registers = kk.n_regs + matmul_rs._spills = kk.n_spills + + return c_global + + @staticmethod + def forward( + ctx, + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + c_global: torch.Tensor, + locks: torch.Tensor, + rank: int, + world_size: int, + gemm_sms: int, + num_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + num_stages: int, + heap_bases_ptr: torch.Tensor, + arch: str = "gfx942", + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + return matmul_rs._call( + a=a, + b=b, + c=c, + c_global=c_global, + locks=locks, + rank=rank, + world_size=world_size, + gemm_sms=gemm_sms, + num_sms=num_sms, + BLK_M=BLK_M, + BLK_N=BLK_N, + BLK_K=BLK_K, + gsize_m=gsize_m, + num_stages=num_stages, + heap_bases_ptr=heap_bases_ptr, + arch=arch, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp=mm_begin_timestamp, + mm_end_timestamp=mm_end_timestamp, + ) diff --git a/examples/common/validation.py b/examples/common/validation.py index 8046e92d8..4dade8a69 100644 --- a/examples/common/validation.py +++ b/examples/common/validation.py @@ -109,3 +109,66 @@ def validate_all_reduce(local_tensor, global_tensor, shmem, atol=1): return False return True + + +def validate_reduce_scatter(local_tensor, output_tensor, shmem, tp_group, atol=1): + """ + Validate reduce-scatter operation where each rank's local tensor is reduced (summed) + and the result is scattered across ranks along the first dimension. + + Args: + local_tensor: The local tensor on this rank before reduce-scatter [M, N] + output_tensor: The result tensor after reduce-scatter [M/world_size, N] + shmem: Iris shmem object + tp_group: torch.distributed process group for communication + atol: Absolute tolerance for comparison + + Returns: + bool: True if validation passes, False otherwise + + ReduceScatter semantics: + - Each rank has input tensor of shape [M, N] + - All inputs are reduced (summed) element-wise to get [M, N] + - The result is scattered: rank i gets rows [i*M_per_rank : (i+1)*M_per_rank] + - Output shape is [M/world_size, N] + """ + import torch.distributed as dist + + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + M, N = local_tensor.shape + M_per_rank = M // world_size + + # Verify output shape + expected_output_shape = (M_per_rank, N) + if output_tensor.shape != expected_output_shape: + shmem.error(f"Output tensor shape mismatch: expected {expected_output_shape}, got {output_tensor.shape}") + return False + + # Gather all local tensors to compute expected result + all_local_tensors = [torch.zeros_like(local_tensor) for _ in range(world_size)] + dist.all_gather(all_local_tensors, local_tensor, group=tp_group) + + # Compute expected: sum of all local tensors, then take this rank's slice + total_sum = sum(all_local_tensors) + expected = total_sum[rank * M_per_rank : (rank + 1) * M_per_rank, :] + + # Compare + diff_mask = ~torch.isclose(output_tensor, expected, atol=atol) + breaking_indices = torch.nonzero(diff_mask, as_tuple=False) + + if not torch.allclose(output_tensor, expected, atol=atol): + max_diff = (output_tensor - expected).abs().max().item() + mean_diff = (output_tensor - expected).abs().mean().item() + shmem.info(f"Reduce-scatter validation: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") + for idx in breaking_indices[:5]: # Show up to 5 mismatches + idx = tuple(idx.tolist()) + computed_val = output_tensor[idx] + expected_val = expected[idx] + shmem.error( + f"Reduce-scatter mismatch at rank {rank}, index {idx}: got={computed_val}, expected={expected_val}" + ) + return False + + return True From 6da76b4657dfecda640f6fac516ff3fd19b89b64 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 12 Jan 2026 07:26:48 +0000 Subject: [PATCH 2/7] Apply Ruff auto-fixes --- .../21_gemm_reduce_scatter_wg_specialization/benchmark.py | 8 ++++---- .../gemm_reduce_scatter.py | 4 ++-- .../matmul_wrapper.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py b/examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py index 85877bc78..66630a84c 100644 --- a/examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py +++ b/examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py @@ -111,7 +111,9 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): assert M % world_size == 0, f"M ({M}) must be divisible by world size ({world_size})" assert K % world_size == 0, f"K ({K}) must be divisible by world size ({world_size})" - assert (M // world_size) % args["BLK_M"] == 0, f"M_per_rank ({M // world_size}) must be divisible by BLK_M ({args['BLK_M']})" + assert (M // world_size) % args["BLK_M"] == 0, ( + f"M_per_rank ({M // world_size}) must be divisible by BLK_M ({args['BLK_M']})" + ) local_K = K // world_size M_per_rank = M // world_size @@ -277,9 +279,7 @@ def run_experiment(): triton_ms = iris.do_bench(run_experiment, shmem.barrier) triton_tflops = perf(triton_ms) - shmem.info( - f"GEMM + ReduceScatter (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops" - ) + shmem.info(f"GEMM + ReduceScatter (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops") json_writer.add_field("tflops", triton_tflops) json_writer.add_field("total_ms", triton_ms) diff --git a/examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py b/examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py index 455941011..209388aec 100644 --- a/examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py +++ b/examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py @@ -16,8 +16,8 @@ def persistent_gemm_reduce_scatter_wg_specialized( A, B, - C, # local buffer [M, N] - C_global, # global output buffer [M, N] on each rank + C, # local buffer [M, N] + C_global, # global output buffer [M, N] on each rank locks, M, N, diff --git a/examples/21_gemm_reduce_scatter_wg_specialization/matmul_wrapper.py b/examples/21_gemm_reduce_scatter_wg_specialization/matmul_wrapper.py index ac6fac259..52ab0bd90 100644 --- a/examples/21_gemm_reduce_scatter_wg_specialization/matmul_wrapper.py +++ b/examples/21_gemm_reduce_scatter_wg_specialization/matmul_wrapper.py @@ -2,7 +2,6 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import torch -import triton from gemm_reduce_scatter import persistent_gemm_reduce_scatter_wg_specialized from examples.common.utils import is_triton_interpret_set import iris @@ -18,6 +17,7 @@ class matmul_rs(torch.autograd.Function): - GEMM SMs: Perform matrix multiplication computation - Communication SMs: Handle data communication """ + _debug = False _registers = None _spills = None @@ -45,8 +45,8 @@ def get_matmul_spills(): def _call( a: torch.Tensor, b: torch.Tensor, - c: torch.Tensor, # local buffer [M, N] - c_global: torch.Tensor, # global output [M_per_rank, N] + c: torch.Tensor, # local buffer [M, N] + c_global: torch.Tensor, # global output [M_per_rank, N] locks: torch.Tensor, rank: int, world_size: int, From b468024eb694521ba6a48ab371cda7a1fce9ba34 Mon Sep 17 00:00:00 2001 From: Kyle Wang Date: Tue, 13 Jan 2026 11:43:37 -0600 Subject: [PATCH 3/7] cleanup --- .../benchmark.py | 34 +++---------------- .../gemm_reduce_scatter.py | 16 +-------- .../matmul_wrapper.py | 0 examples/common/validation.py | 3 -- 4 files changed, 5 insertions(+), 48 deletions(-) rename examples/{21_gemm_reduce_scatter_wg_specialization => 22_gemm_one_shot_reduce_scatter_wg_specialization}/benchmark.py (87%) rename examples/{21_gemm_reduce_scatter_wg_specialization => 22_gemm_one_shot_reduce_scatter_wg_specialization}/gemm_reduce_scatter.py (90%) rename examples/{21_gemm_reduce_scatter_wg_specialization => 22_gemm_one_shot_reduce_scatter_wg_specialization}/matmul_wrapper.py (100%) diff --git a/examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py similarity index 87% rename from examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py rename to examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py index 66630a84c..c3c2a5a48 100644 --- a/examples/21_gemm_reduce_scatter_wg_specialization/benchmark.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py @@ -1,10 +1,6 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. -# -# GEMM + ReduceScatter Benchmark with Workgroup Specialization -# Reference: ByteDance Triton-distributed -# https://github.com/ByteDance-Seed/Triton-distributed/blob/main/tutorials/10-AMD-overlapping-gemm-reduce-scatter.py import torch import torch.distributed as dist @@ -20,8 +16,8 @@ import iris from matmul_wrapper import matmul_rs -torch.manual_seed(123) -random.seed(123) +torch.manual_seed(0) +random.seed(0) def parse_args(): @@ -95,7 +91,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Use next smaller power of 2 for GEMM SMs args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1 - # Datatype datatype = torch.float16 if args["datatype"] == "fp16": datatype = torch.float16 @@ -118,7 +113,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): local_K = K // world_size M_per_rank = M // world_size - # Generate full matrices for reference calculation A_full = shmem.randn(M, K, device="cuda", dtype=datatype) B_full = shmem.randn(K, N, device="cuda", dtype=datatype) @@ -136,11 +130,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): for key, value in args.items(): json_writer.add_field(key, value) - # Local buffer for GEMM result [M, N] local_buf = shmem.zeros((M, N), device="cuda", dtype=datatype) - # Global output buffer for ReduceScatter result [M_per_rank, N] - # This is where each rank accumulates its final result output_buf = shmem.zeros((M_per_rank, N), device="cuda", dtype=datatype) total_blocks_M = triton.cdiv(M, args["BLK_M"]) @@ -163,13 +154,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): }, } - # Allocate Timestamps timestamps = Timestamps(num_tiles=total_tiles) def run_experiment(): nonlocal local_buf, output_buf - # Reset buffers local_buf.zero_() output_buf.zero_() locks.zero_() @@ -207,19 +196,17 @@ def run_experiment(): kernel_timing["gemm_rs"]["experiments"] += 1 torch.cuda.nvtx.range_pop() - # Ensure kernel completion before barrier - gemm_stream.synchronize() shmem.barrier() for k in ["gemm_rs"]: ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"]) kernel_timing[k]["ms"] += ms - # Synchronize across all GPUs shmem.barrier() # Warmup run_experiment() + shmem.barrier() for k in ["gemm_rs"]: @@ -230,25 +217,15 @@ def run_experiment(): shmem.info("Validating...") matmul_rs.set_debug(True) - # Run one more time for validation - run_experiment() - # Additional barrier to ensure all remote writes are complete - torch.cuda.synchronize() - shmem.barrier() - - # Get the GEMM result (input to reduce_scatter) and final output local_gemm = local_buf.clone() local_output = output_buf.clone() # Create process group for validation tp_group = dist.new_group(ranks=list(range(world_size))) - # For fp16 with atomic_add across multiple ranks, allow larger tolerance - # The 0.5 max_diff comes from accumulated rounding errors in atomic operations - # Relative error is ~0.08% which is acceptable for distributed computation + # Allow larger tolerance for fp16 due to accumulated rounding errors in atomic operations atol = 1.0 if datatype == torch.float16 else 0.5 - # Validate reduce_scatter using the common validation function success = validate_reduce_scatter(local_gemm, local_output, shmem, tp_group, atol=atol) if success: @@ -271,9 +248,6 @@ def run_experiment(): matmul_rs.set_debug(False) shmem.info("Benchmarking...") - # Performance calculation: - # Each rank computes [M, N] partial result from [M, local_K] x [local_K, N] - # FLOPs = 2 * M * N * local_K perf = lambda ms: 2 * M * N * local_K * 1e-12 / (ms * 1e-3) triton_ms = iris.do_bench(run_experiment, shmem.barrier) diff --git a/examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py similarity index 90% rename from examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py rename to examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py index 209388aec..ad9a8cd67 100644 --- a/examples/21_gemm_reduce_scatter_wg_specialization/gemm_reduce_scatter.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py @@ -1,9 +1,5 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. -# -# GEMM + ReduceScatter implementation using iris -# Reference: ByteDance Triton-distributed tutorial -# https://github.com/ByteDance-Seed/Triton-distributed/blob/main/tutorials/10-AMD-overlapping-gemm-reduce-scatter.py import triton import triton.language as tl @@ -81,14 +77,12 @@ def persistent_gemm_reduce_scatter_wg_specialized( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 - # Workgroup specialization: GEMM path if pid < GEMM_SMS: for tile_id in range(pid, total_tiles, GEMM_SMS): if COLLECT_TIMESTAMPS: timestamp = read_realtime() tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) - # Standard tile mapping with GROUP_SIZE_M for L2 cache locality num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M @@ -108,7 +102,6 @@ def persistent_gemm_reduce_scatter_wg_specialized( A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - # GEMM computation loop_k = tl.cdiv(K, BLOCK_SIZE_K) if not EVEN_K: loop_k -= 1 @@ -151,12 +144,11 @@ def persistent_gemm_reduce_scatter_wg_specialized( tl.debug_barrier() tl.store(locks + tile_id, 1, cache_modifier=".wt") - else: # Communication path: pid >= GEMM_SMS + else: COMM_SMS = NUM_SMS - GEMM_SMS comm_pid = pid - GEMM_SMS for tile_id in range(comm_pid, total_tiles, COMM_SMS): - # Calculate tile position (same mapping as GEMM path) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M @@ -175,11 +167,9 @@ def persistent_gemm_reduce_scatter_wg_specialized( local_offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn - # Wait for GEMM to finish this tile while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1: pass - # Load computed data from local buffer c = tl.load(C + local_offset, mask=sub_mask) # Determine target rank based on M position @@ -194,15 +184,11 @@ def persistent_gemm_reduce_scatter_wg_specialized( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) global_offset = offs_cm[:, None] * stride_cg_m + offs_cn[None, :] * stride_cg_n - # Mask for valid elements within the target region global_mask = (offs_cm[:, None] < M_per_rank) & (offs_cn[None, :] < N) - # Send to target rank using atomic add if target_rank == cur_rank: - # Local atomic add tl.atomic_add(C_global + global_offset, c, mask=global_mask) else: - # Remote atomic add using iris iris.atomic_add( C_global + global_offset, c, diff --git a/examples/21_gemm_reduce_scatter_wg_specialization/matmul_wrapper.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py similarity index 100% rename from examples/21_gemm_reduce_scatter_wg_specialization/matmul_wrapper.py rename to examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py diff --git a/examples/common/validation.py b/examples/common/validation.py index 4dade8a69..d26f0f2f5 100644 --- a/examples/common/validation.py +++ b/examples/common/validation.py @@ -123,9 +123,6 @@ def validate_reduce_scatter(local_tensor, output_tensor, shmem, tp_group, atol=1 tp_group: torch.distributed process group for communication atol: Absolute tolerance for comparison - Returns: - bool: True if validation passes, False otherwise - ReduceScatter semantics: - Each rank has input tensor of shape [M, N] - All inputs are reduced (summed) element-wise to get [M, N] From 83ca44062424747c19e66bde4267eda2bf5ddcf1 Mon Sep 17 00:00:00 2001 From: Kyle Wang Date: Wed, 14 Jan 2026 16:32:08 -0600 Subject: [PATCH 4/7] address comment --- .../benchmark.py | 4 ++-- .../gemm_reduce_scatter.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py index c3c2a5a48..d681a9df8 100644 --- a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py @@ -138,7 +138,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_blocks_N = triton.cdiv(N, args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N - locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) gemm_stream = torch.cuda.Stream() @@ -187,7 +187,7 @@ def run_experiment(): args["gsize_m"], args["num_stages"], shmem.get_heap_bases(), - "gfx942", + torch.cuda.get_device_properties(rank).name, args["trace_tiles"], timestamps.mm_begin_timestamp, timestamps.mm_end_timestamp, diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py index ad9a8cd67..65bcba38f 100644 --- a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py @@ -141,8 +141,7 @@ def persistent_gemm_reduce_scatter_wg_specialized( tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) tl.store(C + local_offset, c, mask=sub_mask, cache_modifier=".wt") - tl.debug_barrier() - tl.store(locks + tile_id, 1, cache_modifier=".wt") + iris.atomic_cas(locks + tile_id, 0, 1, cur_rank, cur_rank, heap_bases, sem="release", scope="sys") else: COMM_SMS = NUM_SMS - GEMM_SMS @@ -167,8 +166,11 @@ def persistent_gemm_reduce_scatter_wg_specialized( local_offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn - while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1: - pass + done = 0 + while done == 0: + done = iris.atomic_cas( + locks + tile_id, 1, 0, cur_rank, cur_rank, heap_bases, sem="acquire", scope="sys" + ) c = tl.load(C + local_offset, mask=sub_mask) From df1bd9dc81737f36ca6b1cd8688a0d11b5b04cc9 Mon Sep 17 00:00:00 2001 From: Kyle Wang Date: Tue, 20 Jan 2026 03:04:39 -0600 Subject: [PATCH 5/7] clean up --- .../benchmark.py | 7 ++----- .../gemm_reduce_scatter.py | 14 +++++--------- .../matmul_wrapper.py | 12 ++---------- examples/common/validation.py | 6 ------ 4 files changed, 9 insertions(+), 30 deletions(-) diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py index d681a9df8..de371e546 100644 --- a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py @@ -83,7 +83,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): rank = shmem.get_rank() world_size = shmem.get_num_ranks() - # Set default SM values if not provided cu_count = torch.cuda.get_device_properties(rank).multi_processor_count if args["num_sms"] is None: args["num_sms"] = cu_count @@ -116,7 +115,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): A_full = shmem.randn(M, K, device="cuda", dtype=datatype) B_full = shmem.randn(K, N, device="cuda", dtype=datatype) - # Each rank gets a portion of K dimension + # Each rank gets a portion of K dimension as input local_A = A_full[:, rank * local_K : (rank + 1) * local_K].clone() local_B = B_full[rank * local_K : (rank + 1) * local_K, :].clone() @@ -220,12 +219,10 @@ def run_experiment(): local_gemm = local_buf.clone() local_output = output_buf.clone() - # Create process group for validation - tp_group = dist.new_group(ranks=list(range(world_size))) - # Allow larger tolerance for fp16 due to accumulated rounding errors in atomic operations atol = 1.0 if datatype == torch.float16 else 0.5 + tp_group = dist.new_group(ranks=list(range(world_size))) success = validate_reduce_scatter(local_gemm, local_output, shmem, tp_group, atol=atol) if success: diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py index 65bcba38f..2fe1461f7 100644 --- a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py @@ -127,13 +127,9 @@ def persistent_gemm_reduce_scatter_wg_specialized( c = acc.to(C.type.element_ty) - rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) - rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) sub_mask = (rm[:, None] < M) & (rn[None, :] < N) - # Store to local buffer with write-through cache + # Store to local buffer local_offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn if COLLECT_TIMESTAMPS: @@ -174,13 +170,11 @@ def persistent_gemm_reduce_scatter_wg_specialized( c = tl.load(C + local_offset, mask=sub_mask) - # Determine target rank based on M position - # ReduceScatter: chunk i of M dimension goes to rank i + # chunk i of M dimension goes to rank i tile_m_start = pid_m * BLOCK_SIZE_M target_rank = tile_m_start // M_per_rank - # Calculate offset within target rank's output region - # target_m is the row offset within C_global[M_per_rank, N] + # offset within target rank's output target_m = tile_m_start % M_per_rank offs_cm = target_m + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -198,4 +192,6 @@ def persistent_gemm_reduce_scatter_wg_specialized( target_rank, heap_bases, mask=global_mask, + sem="relaxed", + scope="sys" ) diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py index 52ab0bd90..712fbe14d 100644 --- a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py @@ -10,14 +10,6 @@ class matmul_rs(torch.autograd.Function): - """ - GEMM + ReduceScatter matmul wrapper with Workgroup Specialization - - Split SMs into two groups: - - GEMM SMs: Perform matrix multiplication computation - - Communication SMs: Handle data communication - """ - _debug = False _registers = None _spills = None @@ -45,8 +37,8 @@ def get_matmul_spills(): def _call( a: torch.Tensor, b: torch.Tensor, - c: torch.Tensor, # local buffer [M, N] - c_global: torch.Tensor, # global output [M_per_rank, N] + c: torch.Tensor, + c_global: torch.Tensor, locks: torch.Tensor, rank: int, world_size: int, diff --git a/examples/common/validation.py b/examples/common/validation.py index d26f0f2f5..d405dcf6d 100644 --- a/examples/common/validation.py +++ b/examples/common/validation.py @@ -122,12 +122,6 @@ def validate_reduce_scatter(local_tensor, output_tensor, shmem, tp_group, atol=1 shmem: Iris shmem object tp_group: torch.distributed process group for communication atol: Absolute tolerance for comparison - - ReduceScatter semantics: - - Each rank has input tensor of shape [M, N] - - All inputs are reduced (summed) element-wise to get [M, N] - - The result is scattered: rank i gets rows [i*M_per_rank : (i+1)*M_per_rank] - - Output shape is [M/world_size, N] """ import torch.distributed as dist From 30d02d10d7425d779738e987ef90bb61b5dc4d47 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 20 Jan 2026 09:05:00 +0000 Subject: [PATCH 6/7] Apply Ruff auto-fixes --- .../gemm_reduce_scatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py index 2fe1461f7..60312b6b8 100644 --- a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py @@ -193,5 +193,5 @@ def persistent_gemm_reduce_scatter_wg_specialized( heap_bases, mask=global_mask, sem="relaxed", - scope="sys" + scope="sys", ) From a5d420309a2a3a2cea7615f84663c7c89d78bd41 Mon Sep 17 00:00:00 2001 From: Kyle Wang Date: Tue, 20 Jan 2026 16:23:12 -0600 Subject: [PATCH 7/7] address comments --- .../benchmark.py | 14 ++++++------ .../matmul_wrapper.py | 22 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py index de371e546..47728618f 100644 --- a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py @@ -14,7 +14,7 @@ from examples.common.validation import validate_reduce_scatter import iris -from matmul_wrapper import matmul_rs +from matmul_wrapper import MatMulReduceScatterWgSpecialized torch.manual_seed(0) random.seed(0) @@ -170,7 +170,7 @@ def run_experiment(): torch.cuda.nvtx.range_push("GEMM + ReduceScatter") with torch.cuda.stream(gemm_stream): kernel_timing["gemm_rs"]["start_event"].record() - matmul_rs.apply( + MatMulReduceScatterWgSpecialized.apply( local_A, local_B, local_buf, @@ -214,7 +214,7 @@ def run_experiment(): if args["validate"]: shmem.info("Validating...") - matmul_rs.set_debug(True) + MatMulReduceScatterWgSpecialized.set_debug(True) local_gemm = local_buf.clone() local_output = output_buf.clone() @@ -233,8 +233,8 @@ def run_experiment(): json_writer.add_field("success", success) if not is_triton_interpret_set(): - gemm_registers = matmul_rs.get_matmul_registers() - gemm_spills = matmul_rs.get_matmul_spills() + gemm_registers = MatMulReduceScatterWgSpecialized.get_matmul_registers() + gemm_spills = MatMulReduceScatterWgSpecialized.get_matmul_spills() json_writer.add_field("gemm_registers", gemm_registers) json_writer.add_field("gemm_spills", gemm_spills) @@ -242,10 +242,10 @@ def run_experiment(): shmem.info("Validation completed") if args["benchmark"]: - matmul_rs.set_debug(False) + MatMulReduceScatterWgSpecialized.set_debug(False) shmem.info("Benchmarking...") - perf = lambda ms: 2 * M * N * local_K * 1e-12 / (ms * 1e-3) + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) triton_ms = iris.do_bench(run_experiment, shmem.barrier) triton_tflops = perf(triton_ms) diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py index 712fbe14d..bf75a5dfb 100644 --- a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py @@ -9,7 +9,7 @@ gemm_kernel = persistent_gemm_reduce_scatter_wg_specialized -class matmul_rs(torch.autograd.Function): +class MatMulReduceScatterWgSpecialized(torch.autograd.Function): _debug = False _registers = None _spills = None @@ -17,19 +17,19 @@ class matmul_rs(torch.autograd.Function): @staticmethod def set_debug(debug: bool): - matmul_rs._debug = debug + MatMulReduceScatterWgSpecialized._debug = debug @staticmethod def get_matmul_registers(): - if matmul_rs._debug: - return matmul_rs._registers + if MatMulReduceScatterWgSpecialized._debug: + return MatMulReduceScatterWgSpecialized._registers else: raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") @staticmethod def get_matmul_spills(): - if matmul_rs._debug: - return matmul_rs._spills + if MatMulReduceScatterWgSpecialized._debug: + return MatMulReduceScatterWgSpecialized._spills else: raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") @@ -59,7 +59,7 @@ def _call( M, K = a.shape _, N = b.shape - num_xcds = matmul_rs._num_xcds + num_xcds = MatMulReduceScatterWgSpecialized._num_xcds num_warps = 8 waves_per_eu = 0 mfma = 16 @@ -106,9 +106,9 @@ def _call( mm_end_timestamp_ptr=mm_end_timestamp, ) - if matmul_rs._debug and not is_triton_interpret_set(): - matmul_rs._registers = kk.n_regs - matmul_rs._spills = kk.n_spills + if MatMulReduceScatterWgSpecialized._debug and not is_triton_interpret_set(): + MatMulReduceScatterWgSpecialized._registers = kk.n_regs + MatMulReduceScatterWgSpecialized._spills = kk.n_spills return c_global @@ -135,7 +135,7 @@ def forward( mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): - return matmul_rs._call( + return MatMulReduceScatterWgSpecialized._call( a=a, b=b, c=c,