# diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py (new file)
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
"""Benchmark driver: fused GEMM + one-shot ReduceScatter with workgroup specialization.

Spawns one process per rank via ``torch.multiprocessing``, allocates operands on the
Iris symmetric heap, runs the fused Triton kernel through
``MatMulReduceScatterWgSpecialized``, and optionally validates and/or benchmarks.
"""

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import triton
import random
import argparse
import math

from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set
from examples.common.validation import validate_reduce_scatter

import iris
from matmul_wrapper import MatMulReduceScatterWgSpecialized

torch.manual_seed(0)
random.seed(0)

# CLI datatype name -> torch dtype. argparse `choices` restricts --datatype to these
# keys, so a plain lookup replaces the old if/elif chain with its dead
# `print(...); exit(1)` fallback (which relied on the interactive-only `exit` builtin).
_DTYPE_MAP = {
    "fp16": torch.float16,
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
}


def parse_args():
    """Parse command-line arguments and return them as a plain dict."""
    parser = argparse.ArgumentParser(
        description="GEMM + ReduceScatter Benchmark with Workgroup Specialization",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("-m", type=int, default=8192, help="Number of rows in matrix A (M)")
    parser.add_argument("-n", type=int, default=4096, help="Number of columns in matrix B (N)")
    parser.add_argument("-k", type=int, default=12288, help="Common dimension (K), will be split across ranks")
    parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode")
    parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode")
    parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode")
    parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode")
    parser.add_argument(
        "--datatype",
        type=str,
        default="fp16",
        choices=["fp16", "fp32", "bf16"],
        help="Datatype of computation",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="log.json",
        help="Output file",
    )
    parser.add_argument("--BLK_M", type=int, default=128, help="Block size M")
    parser.add_argument("--BLK_N", type=int, default=256, help="Block size N")
    parser.add_argument("--BLK_K", type=int, default=32, help="Block size K")
    parser.add_argument("--gsize_m", type=int, default=1, help="L2-cache locality swizzle parameter")
    parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
    parser.add_argument(
        "--num_sms",
        type=int,
        default=None,
        help="Number of total SMs (default: auto-detected)",
    )
    parser.add_argument(
        "--gemm_sms",
        type=int,
        default=None,
        help="Number of SMs for GEMM (default: auto-detected as power of 2)",
    )
    parser.add_argument("--num_stages", type=int, default=2, help="Number of stages")
    parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

    return vars(parser.parse_args())


def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
    """Worker function for PyTorch distributed execution.

    Args:
        local_rank: Rank index supplied by ``mp.spawn`` (also used as CUDA device index).
        world_size: Number of spawned processes.
        init_url: ``init_method`` URL for ``dist.init_process_group``.
        args: Parsed CLI arguments (mutated in place to fill in auto-detected SM counts).
    """
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(
        backend=backend,
        init_method=init_url,
        world_size=world_size,
        rank=local_rank,
        device_id=torch.device(f"cuda:{local_rank}"),
    )

    shmem = iris.iris(args["heap_size"])
    rank = shmem.get_rank()
    # Iris is the source of truth for rank/world size from here on.
    world_size = shmem.get_num_ranks()

    cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
    if args["num_sms"] is None:
        args["num_sms"] = cu_count
    if args["gemm_sms"] is None:
        # Use next smaller power of 2 for GEMM SMs; the remainder handles communication.
        args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

    # argparse `choices` guarantees the key exists.
    datatype = _DTYPE_MAP[args["datatype"]]

    M, N, K = args["m"], args["n"], args["k"]

    assert M % world_size == 0, f"M ({M}) must be divisible by world size ({world_size})"
    assert K % world_size == 0, f"K ({K}) must be divisible by world size ({world_size})"
    assert (M // world_size) % args["BLK_M"] == 0, (
        f"M_per_rank ({M // world_size}) must be divisible by BLK_M ({args['BLK_M']})"
    )

    local_K = K // world_size
    M_per_rank = M // world_size

    A_full = shmem.randn(M, K, device="cuda", dtype=datatype)
    B_full = shmem.randn(K, N, device="cuda", dtype=datatype)

    # Each rank gets a portion of K dimension as input
    local_A = A_full[:, rank * local_K : (rank + 1) * local_K].clone()
    local_B = B_full[rank * local_K : (rank + 1) * local_K, :].clone()

    json_writer = JSONWriter(args["output_file"])
    json_writer.add_field("world_size", world_size)
    json_writer.add_field("M", M)
    json_writer.add_field("N", N)
    json_writer.add_field("K", K)
    json_writer.add_field("local_K", local_K)

    for key, value in args.items():
        json_writer.add_field(key, value)

    # Per-rank partial result C = local_A @ local_B, shape [M, N].
    local_buf = shmem.zeros((M, N), device="cuda", dtype=datatype)

    # Reduce-scatter output: this rank's [M/world_size, N] chunk.
    output_buf = shmem.zeros((M_per_rank, N), device="cuda", dtype=datatype)

    total_blocks_M = triton.cdiv(M, args["BLK_M"])
    total_blocks_N = triton.cdiv(N, args["BLK_N"])
    total_tiles = total_blocks_M * total_blocks_N

    # One lock per output tile; GEMM workgroups release, comm workgroups acquire.
    locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)

    gemm_stream = torch.cuda.Stream()

    json_writer.add_field("num_sms", args["num_sms"])
    json_writer.add_field("gemm_sms", args["gemm_sms"])

    kernel_timing = {
        "gemm_rs": {
            "start_event": torch.cuda.Event(enable_timing=True),
            "end_event": torch.cuda.Event(enable_timing=True),
            "ms": 0,
            "experiments": 0,
        },
    }

    timestamps = Timestamps(num_tiles=total_tiles)

    def run_experiment():
        """Run one fused GEMM + reduce-scatter iteration and accumulate event timing."""
        nonlocal local_buf, output_buf

        local_buf.zero_()
        output_buf.zero_()
        locks.zero_()
        shmem.barrier()

        if args["trace_tiles"]:
            timestamps.reset()
            shmem.barrier()

        torch.cuda.nvtx.range_push("GEMM + ReduceScatter")
        with torch.cuda.stream(gemm_stream):
            kernel_timing["gemm_rs"]["start_event"].record()
            MatMulReduceScatterWgSpecialized.apply(
                local_A,
                local_B,
                local_buf,
                output_buf,
                locks,
                rank,
                world_size,
                args["gemm_sms"],
                args["num_sms"],
                args["BLK_M"],
                args["BLK_N"],
                args["BLK_K"],
                args["gsize_m"],
                args["num_stages"],
                shmem.get_heap_bases(),
                torch.cuda.get_device_properties(rank).name,
                args["trace_tiles"],
                timestamps.mm_begin_timestamp,
                timestamps.mm_end_timestamp,
            )
            kernel_timing["gemm_rs"]["end_event"].record()
            kernel_timing["gemm_rs"]["experiments"] += 1

        torch.cuda.nvtx.range_pop()
        shmem.barrier()

        for k in ["gemm_rs"]:
            ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"])
            kernel_timing[k]["ms"] += ms

        shmem.barrier()

    # Warmup
    run_experiment()

    shmem.barrier()

    # Discard warmup timing.
    for k in ["gemm_rs"]:
        kernel_timing[k]["ms"] = 0
        kernel_timing[k]["experiments"] = 0

    if args["validate"]:
        shmem.info("Validating...")
        MatMulReduceScatterWgSpecialized.set_debug(True)

        local_gemm = local_buf.clone()
        local_output = output_buf.clone()

        # Allow larger tolerance for fp16 due to accumulated rounding errors in atomic operations
        atol = 1.0 if datatype == torch.float16 else 0.5

        tp_group = dist.new_group(ranks=list(range(world_size)))
        success = validate_reduce_scatter(local_gemm, local_output, shmem, tp_group, atol=atol)

        if success:
            shmem.info("✅ Triton and Torch match")
        else:
            shmem.info("❌ Triton and Torch differ")

        json_writer.add_field("success", success)

        if not is_triton_interpret_set():
            gemm_registers = MatMulReduceScatterWgSpecialized.get_matmul_registers()
            gemm_spills = MatMulReduceScatterWgSpecialized.get_matmul_spills()
            json_writer.add_field("gemm_registers", gemm_registers)
            json_writer.add_field("gemm_spills", gemm_spills)

        shmem.barrier()
        shmem.info("Validation completed")

    if args["benchmark"]:
        MatMulReduceScatterWgSpecialized.set_debug(False)
        shmem.info("Benchmarking...")

        # Aggregate FLOPs across ranks: world_size * 2*M*N*local_K == 2*M*N*K.
        perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)

        triton_ms = iris.do_bench(run_experiment, shmem.barrier)
        triton_tflops = perf(triton_ms)

        shmem.info(f"GEMM + ReduceScatter (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops")

        json_writer.add_field("tflops", triton_tflops)
        json_writer.add_field("total_ms", triton_ms)

        for k in ["gemm_rs"]:
            json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"])
            json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"])

        shmem.barrier()

    # NOTE(review): line structure was lost in the mangled source — flush is placed at
    # function level so validate-only runs also persist the log; confirm against upstream.
    if rank == 0:
        json_writer.flush()
        json_writer.display()

    if args["trace_tiles"] and rank == 0:
        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
        filename = f"gemm_tiles_reduce_scatter_trace_rank{rank}.json"
        timestamps.to_json(filename, gpu_freq)

    shmem.barrier()
    dist.destroy_process_group()


def main():
    """Spawn `num_ranks` workers over a local TCP rendezvous."""
    args = parse_args()
    num_ranks = args["num_ranks"]

    init_url = "tcp://127.0.0.1:29500"
    mp.spawn(
        fn=_worker,
        args=(num_ranks, init_url, args),
        nprocs=num_ranks,
        join=True,
    )


if __name__ == "__main__":
    main()


# diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py (new file)
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import triton
import triton.language as tl
from examples.common.utils import read_realtime

import iris


@triton.jit()
def persistent_gemm_reduce_scatter_wg_specialized(
    A,
    B,
    C,  # local partial-result buffer [M, N] on this rank's heap
    C_global,  # reduce-scatter output buffer; indexed with rows < M_per_rank, i.e. [M/world_size, N] per rank
    locks,  # one int32 flag per tile: 0 = not ready, 1 = GEMM result stored
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_cg_m,
    stride_cg_n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    GEMM_SMS: tl.constexpr,  # first GEMM_SMS program ids do the matmul
    NUM_SMS: tl.constexpr,  # total programs launched; the rest do communication
    NUM_XCDS: tl.constexpr,
    EVEN_K: tl.constexpr,  # True when K % BLOCK_SIZE_K == 0 (no masked tail iteration)
    heap_bases: tl.tensor,
    cur_rank: tl.constexpr,
    world_size: tl.constexpr,
    COLLECT_TIMESTAMPS: tl.constexpr = False,
    mm_begin_timestamp_ptr: tl.tensor = None,
    mm_end_timestamp_ptr: tl.tensor = None,
):
    """
    GEMM + ReduceScatter with Workgroup Specialization

    Split SMs into two groups:
    - GEMM SMs: Perform matrix multiplication computation
    - Communication SMs: Handle data communication (scatter to target ranks)

    This approach enables overlapping computation and communication.

    Data partitioning (ReduceScatter):
    - A: [M, local_K] - Each rank has a portion of K dimension
    - B: [local_K, N] - Each rank has a portion of K dimension
    - Each rank computes partial C = A @ B of shape [M, N]
    - ReduceScatter: Split C along M dimension into world_size chunks,
      send chunk i to rank i, accumulate with atomic_add
    - Output: Each rank ends up with [M/world_size, N]
    """
    pid = tl.program_id(0)

    # Remap program ids so consecutive tiles land on the same XCD (chiplet) for
    # better cache locality on multi-XCD GPUs.
    if NUM_XCDS != 1:
        pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)

    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    total_tiles = num_pid_m * num_pid_n
    M_per_rank = M // world_size

    # Compiler hints: all strides are positive.
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    # Accumulate in fp32 (or int32 for int8 inputs) regardless of storage dtype.
    acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32

    if pid < GEMM_SMS:
        # ---- Producer: persistent GEMM over tiles strided by GEMM_SMS ----
        for tile_id in range(pid, total_tiles, GEMM_SMS):
            if COLLECT_TIMESTAMPS:
                timestamp = read_realtime()
                tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp)

            # Grouped (swizzled) tile ordering for L2 locality.
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            tl.assume(pid_m >= 0)
            tl.assume(pid_n >= 0)

            rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
            rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

            rk = tl.arange(0, BLOCK_SIZE_K)
            rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
            rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
            A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
            B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn

            # For uneven K, the last (masked) K-iteration is peeled out of the main loop.
            loop_k = tl.cdiv(K, BLOCK_SIZE_K)
            if not EVEN_K:
                loop_k -= 1

            acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
            for k in range(0, loop_k):
                a = tl.load(tl.multiple_of(A_BASE, (1, 16)))
                b = tl.load(tl.multiple_of(B_BASE, (16, 1)))
                acc += tl.dot(a, b)
                A_BASE += BLOCK_SIZE_K * stride_ak
                B_BASE += BLOCK_SIZE_K * stride_bk

            if not EVEN_K:
                # Peeled final iteration with K-boundary masking.
                k = loop_k
                rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
                A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
                B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
                A_BASE = tl.multiple_of(A_BASE, (1, 16))
                B_BASE = tl.multiple_of(B_BASE, (16, 1))
                a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0)
                b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0)
                acc += tl.dot(a, b)

            c = acc.to(C.type.element_ty)

            sub_mask = (rm[:, None] < M) & (rn[None, :] < N)

            # Store to local buffer
            local_offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn

            if COLLECT_TIMESTAMPS:
                timestamp = read_realtime()
                tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)

            # Write-through store, then publish the tile with a release CAS (0 -> 1)
            # so consumers that acquire the lock observe the stored data.
            tl.store(C + local_offset, c, mask=sub_mask, cache_modifier=".wt")
            iris.atomic_cas(locks + tile_id, 0, 1, cur_rank, cur_rank, heap_bases, sem="release", scope="sys")

    else:
        # ---- Consumer: communication workgroups scatter finished tiles ----
        COMM_SMS = NUM_SMS - GEMM_SMS
        comm_pid = pid - GEMM_SMS

        for tile_id in range(comm_pid, total_tiles, COMM_SMS):
            # Same swizzled tile decomposition as the GEMM side.
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            tl.assume(pid_m >= 0)
            tl.assume(pid_n >= 0)

            rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
            rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
            rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
            rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
            sub_mask = (rm[:, None] < M) & (rn[None, :] < N)

            local_offset = rm[:, None] * stride_cm + rn[None, :] * stride_cn

            # Spin until the producer publishes this tile; acquire CAS (1 -> 0)
            # also resets the lock for the next experiment.
            done = 0
            while done == 0:
                done = iris.atomic_cas(
                    locks + tile_id, 1, 0, cur_rank, cur_rank, heap_bases, sem="acquire", scope="sys"
                )

            c = tl.load(C + local_offset, mask=sub_mask)

            # chunk i of M dimension goes to rank i
            tile_m_start = pid_m * BLOCK_SIZE_M
            target_rank = tile_m_start // M_per_rank

            # offset within target rank's output
            target_m = tile_m_start % M_per_rank
            offs_cm = target_m + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            global_offset = offs_cm[:, None] * stride_cg_m + offs_cn[None, :] * stride_cg_n

            global_mask = (offs_cm[:, None] < M_per_rank) & (offs_cn[None, :] < N)

            if target_rank == cur_rank:
                # Local accumulation: plain device-scope atomic add.
                tl.atomic_add(C_global + global_offset, c, mask=global_mask)
            else:
                # Remote accumulation into target_rank's heap via Iris.
                iris.atomic_add(
                    C_global + global_offset,
                    c,
                    cur_rank,
                    target_rank,
                    heap_bases,
                    mask=global_mask,
                    sem="relaxed",
                    scope="sys",
                )


# diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py (new file)
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import torch
from gemm_reduce_scatter import persistent_gemm_reduce_scatter_wg_specialized
from examples.common.utils import is_triton_interpret_set
import iris

gemm_kernel = persistent_gemm_reduce_scatter_wg_specialized


class MatMulReduceScatterWgSpecialized(torch.autograd.Function):
    """Autograd-function wrapper launching the fused GEMM + reduce-scatter kernel.

    Forward-only: launches `persistent_gemm_reduce_scatter_wg_specialized` on a grid
    of `num_sms` programs, of which the first `gemm_sms` compute and the rest
    communicate. Optionally records register/spill counts when debug mode is on.
    """

    _debug = False
    _registers = None  # n_regs of the last compiled launch (debug mode only)
    _spills = None  # n_spills of the last compiled launch (debug mode only)
    _num_xcds = iris.hip.get_num_xcc()

    @staticmethod
    def set_debug(debug: bool):
        """Enable/disable capturing of kernel register and spill counts."""
        MatMulReduceScatterWgSpecialized._debug = debug

    @staticmethod
    def get_matmul_registers():
        """Return the register count of the last launch; requires debug mode."""
        if MatMulReduceScatterWgSpecialized._debug:
            return MatMulReduceScatterWgSpecialized._registers
        else:
            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")

    @staticmethod
    def get_matmul_spills():
        """Return the spill count of the last launch; requires debug mode."""
        if MatMulReduceScatterWgSpecialized._debug:
            return MatMulReduceScatterWgSpecialized._spills
        else:
            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")

    @staticmethod
    def _call(
        a: torch.Tensor,
        b: torch.Tensor,
        c: torch.Tensor,
        c_global: torch.Tensor,
        locks: torch.Tensor,
        rank: int,
        world_size: int,
        gemm_sms: int,
        num_sms: int,
        BLK_M: int,
        BLK_N: int,
        BLK_K: int,
        gsize_m: int,
        num_stages: int,
        heap_bases_ptr: torch.Tensor,
        arch: str = "gfx942",  # currently unused; kept for interface stability
        COLLECT_TIMESTAMPS: bool = False,
        mm_begin_timestamp: torch.Tensor = None,
        mm_end_timestamp: torch.Tensor = None,
    ):
        """Launch the fused kernel; returns `c_global` (the reduce-scatter output)."""
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape

        num_xcds = MatMulReduceScatterWgSpecialized._num_xcds
        # Fixed launch configuration tuned for MI300-class GPUs.
        num_warps = 8
        waves_per_eu = 0
        mfma = 16
        kpack = 1
        even_k = K % BLK_K == 0  # skip the masked K-tail when K divides evenly

        grid = (num_sms,)

        kk = gemm_kernel[grid](
            a,
            b,
            c,
            c_global,
            locks,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            c_global.stride(0),
            c_global.stride(1),
            BLOCK_SIZE_M=BLK_M,
            BLOCK_SIZE_N=BLK_N,
            BLOCK_SIZE_K=BLK_K,
            GROUP_SIZE_M=gsize_m,
            GEMM_SMS=gemm_sms,
            NUM_SMS=num_sms,
            NUM_XCDS=num_xcds,
            EVEN_K=even_k,
            num_stages=num_stages,
            num_warps=num_warps,
            waves_per_eu=waves_per_eu,
            matrix_instr_nonkdim=mfma,
            kpack=kpack,
            heap_bases=heap_bases_ptr,
            cur_rank=rank,
            world_size=world_size,
            COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS,
            mm_begin_timestamp_ptr=mm_begin_timestamp,
            mm_end_timestamp_ptr=mm_end_timestamp,
        )

        # Capture compiler stats from the launch handle (unavailable in interpret mode).
        if MatMulReduceScatterWgSpecialized._debug and not is_triton_interpret_set():
            MatMulReduceScatterWgSpecialized._registers = kk.n_regs
            MatMulReduceScatterWgSpecialized._spills = kk.n_spills

        return c_global

    @staticmethod
    def forward(
        ctx,
        a: torch.Tensor,
        b: torch.Tensor,
        c: torch.Tensor,
        c_global: torch.Tensor,
        locks: torch.Tensor,
        rank: int,
        world_size: int,
        gemm_sms: int,
        num_sms: int,
        BLK_M: int,
        BLK_N: int,
        BLK_K: int,
        gsize_m: int,
        num_stages: int,
        heap_bases_ptr: torch.Tensor,
        arch: str = "gfx942",
        COLLECT_TIMESTAMPS: bool = False,
        mm_begin_timestamp: torch.Tensor = None,
        mm_end_timestamp: torch.Tensor = None,
    ):
        """Forward pass: delegate to `_call`. No backward is defined."""
        return MatMulReduceScatterWgSpecialized._call(
            a=a,
            b=b,
            c=c,
            c_global=c_global,
            locks=locks,
            rank=rank,
            world_size=world_size,
            gemm_sms=gemm_sms,
            num_sms=num_sms,
            BLK_M=BLK_M,
            BLK_N=BLK_N,
            BLK_K=BLK_K,
            gsize_m=gsize_m,
            num_stages=num_stages,
            heap_bases_ptr=heap_bases_ptr,
            arch=arch,
            COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS,
            mm_begin_timestamp=mm_begin_timestamp,
            mm_end_timestamp=mm_end_timestamp,
        )


# diff --git a/examples/common/validation.py (appended after the existing
# validate_all_reduce, i.e. after its final `return True`)

def validate_reduce_scatter(local_tensor, output_tensor, shmem, tp_group, atol=1):
    """
    Validate reduce-scatter operation where each rank's local tensor is reduced (summed)
    and the result is scattered across ranks along the first dimension.

    Args:
        local_tensor: The local tensor on this rank before reduce-scatter [M, N]
        output_tensor: The result tensor after reduce-scatter [M/world_size, N]
        shmem: Iris shmem object
        tp_group: torch.distributed process group for communication
        atol: Absolute tolerance for comparison

    Returns:
        True when `output_tensor` matches the reference within tolerance.
    """
    import torch.distributed as dist

    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    M, N = local_tensor.shape
    M_per_rank = M // world_size

    # Verify output shape
    expected_output_shape = (M_per_rank, N)
    if output_tensor.shape != expected_output_shape:
        shmem.error(f"Output tensor shape mismatch: expected {expected_output_shape}, got {output_tensor.shape}")
        return False

    # Gather all local tensors to compute expected result
    all_local_tensors = [torch.zeros_like(local_tensor) for _ in range(world_size)]
    dist.all_gather(all_local_tensors, local_tensor, group=tp_group)

    # Compute expected: sum of all local tensors, then take this rank's slice
    total_sum = sum(all_local_tensors)
    expected = total_sum[rank * M_per_rank : (rank + 1) * M_per_rank, :]

    if not torch.allclose(output_tensor, expected, atol=atol):
        # Only pay for the element-wise mismatch diagnostics on the failure path
        # (the original computed them unconditionally, even when validation passed).
        diff_mask = ~torch.isclose(output_tensor, expected, atol=atol)
        breaking_indices = torch.nonzero(diff_mask, as_tuple=False)
        max_diff = (output_tensor - expected).abs().max().item()
        mean_diff = (output_tensor - expected).abs().mean().item()
        shmem.info(f"Reduce-scatter validation: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
        for idx in breaking_indices[:5]:  # Show up to 5 mismatches
            idx = tuple(idx.tolist())
            computed_val = output_tensor[idx]
            expected_val = expected[idx]
            shmem.error(
                f"Reduce-scatter mismatch at rank {rank}, index {idx}: got={computed_val}, expected={expected_val}"
            )
        return False

    return True