[ROCm][MoE] Custom W4A16 MoE prefill WMMA GEMM for gfx11 (default-on) #1015
Open
roberteg16 wants to merge 2 commits into
57e38ca to
d87b530
Can you show the before/after GEMM table from vllm-bench profiling for Qwen3.6?
mgehre-amd
reviewed
Jun 23, 2026
b6879d0 to
7f41b26
Author
Before: After
d793760 to
a899393
…ault-on)

Adds a producer/consumer WMMA kernel (rdna_moe_gemm) for the W4A16 MoE prefill GEMM1 on AMD RDNA3 (gfx11), a faster alternative to the Triton fused_moe_kernel_gptq_awq. Enabled by default on gfx11 (VLLM_MOE_HIP=0 forces Triton, =1 forces on).

On the rdna_moe_gemm path gemm1 (up/gate proj, top_k>1) runs the rdna_moe_gemm WMMA kernel and gemm2 (down proj, top_k=1) runs Triton, both at a single block_m=32 moe_align alignment. The per-routed-token activations are stored compact (flat-topk indexed -- the kernels write C[sorted_token_ids[slot]]), so gemm2 gathers them via the shared sorted_token_ids with no re-permute. gemm1 at block_m=32 is faster than at 16 on gfx1151 (the larger WMMA tile more than pays for the extra alignment padding), so there is a single alignment for both gemms.

- csrc/rocm/moe_gemm_w4a16_wmma.cu: the WMMA kernel with K, N and the weight N-row stride as runtime args (one instantiation per tile family handles any compliant shape and any weight padding). Compiled into _rocm_C; the WMMA body is gfx11-only with a stub on other arches. Registered as torch.ops._rocm_C.moe_gemm_w4a16; an unsupported shape TORCH_CHECKs.
- vllm/envs.py: VLLM_MOE_HIP tri-state (unset = default-on on gfx11).
- vllm/model_executor/layers/fused_moe/moe_hip_w4a16.py: host-side shape predicate (prefill_uses_rdna_moe_gemm) plus a graph-safe vLLM custom op wrapper with a no-op fake, so the path is torch.compile-safe.
- vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py: apply() dispatch -- rdna_moe_gemm gemm1 when the shape is supported, else Triton; gemm2 always Triton.
- tests/kernels/quantization/test_moe_gemm_w4a16.py: compiled op vs Triton reference across shapes and block sizes.

AI assistance (Claude) was used.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Robert Esclapez Garcia <robert.garcia@amd.com>
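The dispatch rule in this commit (env override, arch gate, tile constraints, Triton fallback) can be sketched in a few lines. This is only an illustration: `choose_gemm1_backend` and its `block_n` parameter are hypothetical names, not the PR's `prefill_uses_rdna_moe_gemm`, and the sketch assumes that an unset variable and `VLLM_MOE_HIP=1` both still require gfx11 and a compliant shape.

```python
from typing import Optional


def choose_gemm1_backend(
    env_value: Optional[str],
    on_gfx11: bool,
    k: int,
    n: int,
    group_size: int,
    block_n: int,
) -> str:
    """Pick the gemm1 backend (hypothetical helper, not vLLM's actual API)."""
    # An explicit disable always wins.
    if env_value == "0":
        return "triton"
    # Assumption: unset and "1" both still need gfx11 hardware,
    # because the WMMA body is a stub on other arches.
    if not on_gfx11:
        return "triton"
    # Tile constraints quoted in the PR: K % G == 0 and N % BN == 0.
    if k % group_size != 0 or n % block_n != 0:
        return "triton"
    return "rdna_moe_gemm"
```

Called as `choose_gemm1_backend(None, True, 2048, 1024, 128, 64)` this selects the custom kernel, while passing `"0"` as the first argument, `False` as the second, or a `k` that is not a multiple of `group_size` falls back to Triton. The tile width 64 here is an arbitrary illustrative value.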
Wrap the rdna_moe_gemm gemm1 launch (torch.ops.vllm.moe_gemm_w4a16) in apply() with an apply()-level record_function scope carrying the dims the roofline tool needs: M, N, K, E, top_k, the quant group size g, block_m, valid_blocks (= num_tokens_post_padded // block_m) and n_routed (M*top_k), plus per-expert tok_hist / vtok_hist histograms describing the routing skew that drives the block_m padding. g is the real self._group_size (the per-group scale bytes and dequant FLOPs scale with K/g), not an assumed 128.

The scope is gated on VLLM_CUSTOM_SCOPES_FOR_PROFILING / VLLM_NVTX_SCOPES_FOR_PROFILING -- the valid_blocks/histogram reads force a device->host sync, taken only when profiling; production gets a nullcontext.

AI assistance (Claude) was used.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Robert Esclapez Garcia <robert.garcia@amd.com>
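To make the gating and the derived dims concrete, here is a minimal pure-Python sketch. It assumes moe_align pads each expert's token count up to a multiple of block_m. `routing_stats`, `recording_scope` and `maybe_scope` are hypothetical names, and `recording_scope` merely stands in for `record_function`. The point is that the expensive read (here `read_tok_hist`) happens only when profiling is enabled.

```python
import contextlib


def routing_stats(tok_hist, block_m):
    """Derive profiling dims from a per-expert token histogram.

    Assumption: each expert's count is padded up to a multiple of block_m.
    """
    padded = [-(-t // block_m) * block_m for t in tok_hist]  # ceil to block_m
    num_tokens_post_padded = sum(padded)
    return {
        "n_routed": sum(tok_hist),  # equals M * top_k
        "num_tokens_post_padded": num_tokens_post_padded,
        "valid_blocks": num_tokens_post_padded // block_m,
    }


@contextlib.contextmanager
def recording_scope(label, log):
    """Stand-in for a profiler scope: records enter/exit in a list."""
    log.append(("enter", label))
    try:
        yield
    finally:
        log.append(("exit", label))


def maybe_scope(enabled, read_tok_hist, block_m, log):
    """Return a labelled scope when profiling, else a free nullcontext."""
    if not enabled:
        return contextlib.nullcontext()
    # Only this branch pays for the (device->host) histogram read.
    stats = routing_stats(read_tok_hist(), block_m)
    label = (
        f"moe_gemm_w4a16 block_m={block_m} "
        f"valid_blocks={stats['valid_blocks']} n_routed={stats['n_routed']}"
    )
    return recording_scope(label, log)
```

For a skewed histogram such as `[5, 40, 0, 19]`, `routing_stats(..., 32)` gives 128 padded rows in 4 blocks for 64 routed rows, while block_m=16 gives 96 padded rows in 6 blocks. This is the padding cost that the block_m choice trades against the larger WMMA tile.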
a899393 to
f50fdea
Author
@mgehre-amd polite ping
Summary
This adds a producer/consumer WMMA kernel (rdna_moe_gemm) for the W4A16 MoE prefill GEMM1 on AMD RDNA3 (gfx11), a faster alternative to the Triton fused_moe_kernel_gptq_awq. It is enabled by default on gfx11 (set VLLM_MOE_HIP=0 to force the Triton path, or VLLM_MOE_HIP=1 to force-enable). On the rdna_moe_gemm path gemm1 (up/gate proj, top_k>1) runs the custom WMMA kernel and gemm2 (down proj, top_k=1) runs Triton, both at a single block_m=32 moe_align alignment.

K, N and the weight row stride are runtime kernel arguments, so a single instantiation per tile family accepts any shape meeting the tile constraints (K % G == 0, N % BN == 0) and any weight N-row padding. Any other shape, group size, top_k, arch, or an explicit disable falls back to Triton unchanged, so the change is a no-op outside the supported gfx11 path and remains fully A/B-able.

What this adds
- csrc/rocm/moe_gemm_w4a16_wmma.cu: the WMMA kernel, with K, N and the weight N-row stride as runtime args so one instantiation per tile family handles any compliant shape and any weight padding. Compiled into _rocm_C unconditionally; the WMMA body is guarded to gfx11 device passes with an empty stub on other arches (mirroring the existing q_gemm_rdna3_wmma pattern) so multi-arch / CDNA builds still link. The host does constraint-based tile selection keyed by top_k; an unsupported shape TORCH_CHECKs (callers gate first). Registered as torch.ops._rocm_C.moe_gemm_w4a16 in ops.h / torch_bindings.cpp. The A producer load is vectorized to one global_load_b128 per lane (av4u = 8 bf16) instead of a per-element 16-bit gather, with a matching wide LDS store when the padded row stride allows it.
- vllm/envs.py: registers VLLM_MOE_HIP as a tri-state env var (unset = default-on on gfx11, 1 forces on, 0 forces the Triton path).
- vllm/model_executor/layers/fused_moe/moe_hip_w4a16.py: a host-side shape predicate (prefill_uses_rdna_moe_gemm) gates on on_gfx11() plus the same shape/alignment constraints as the kernel, and a registered vLLM custom op torch.ops.vllm.moe_gemm_w4a16 with a no-op fake wraps the call so the path is graph-safe under torch.compile (no data-dependent branch on a device result).
- vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py: dispatch in HybridW4A16MoEExperts.apply(): gemm1 runs the rdna_moe_gemm kernel when its predicate accepts the shape, else the existing Triton invoke; gemm2 always runs Triton. Both gemms share one block_m=32 alignment, so there is no second moe_align.
- tests/kernels/quantization/test_moe_gemm_w4a16.py: validates the compiled op against the Triton reference across multiple shapes and block sizes.

A second commit adds an apply()-level record_function profiling scope around the rdna_moe_gemm gemm1 launch (moe_gemm_w4a16 …) carrying the dims the roofline tooling needs: M, N, K, E, top_k, the quant group size g, block_m, valid_blocks (= num_tokens_post_padded // block_m), n_routed, and per-expert routing histograms. It is gated on VLLM_CUSTOM_SCOPES_FOR_PROFILING / VLLM_NVTX_SCOPES_FOR_PROFILING so the device→host sync it needs runs only when profiling; production gets a nullcontext and pays nothing.

Design notes
The activations between the two gemms are stored compact (flat-topk indexed: the kernels write C[sorted_token_ids[slot]]), so gemm2 gathers them via the shared sorted_token_ids with no re-permute, and the whole MoE runs on a single alignment. We use block_m=32 for both gemms because gemm1 at 32 is faster than at 16 on gfx1151: the larger WMMA M-tile more than pays for the extra alignment padding (measured ~6% faster per useful row), and gemm2 (Triton) is already fastest at 32. An earlier dual-alignment variant (gemm1 at 16, gemm2 at 32 via a second moe_align) was both slower and more complex, so it was dropped.

Performance
End-to-end on gfx1151 against Qwen3.6-35B-A3B-W4A16 (prefill, 100 input tokens + one 1280×720 image, output_len=1), median time-to-first-token, rdna_moe_gemm path (the default) vs the Triton path (VLLM_MOE_HIP=0):

[Table: median TTFT at 1 and 10 prompts, rdna_moe_gemm (default) vs Triton (VLLM_MOE_HIP=0)]

The kernel is faster in both cases (~40 ms / ~6.5% at 1 prompt, ~31 ms / ~4.8% at 10 prompts). The end-to-end gain is much smaller than the kernel speedup because the MoE gemm1 is only one component of prefill, and single-prompt latency carries run-to-run variance.

A second run with a smaller 640×480 image (fewer vision tokens, so prefill is shorter and the MoE gemm1 is a smaller slice of TTFT), median time-to-first-token, rdna_moe_gemm path (the default) vs the Triton path (VLLM_MOE_HIP=0):

[Table: median TTFT, rdna_moe_gemm (default) vs Triton (VLLM_MOE_HIP=0)]

Testing
Unit test (compiled op vs Triton reference, relative error < 0.01): pytest tests/kernels/quantization/test_moe_gemm_w4a16.py passes on gfx1151, covering the tuned shapes plus a non-tuned K×N, each across block_m 16 and 32; it skips automatically off gfx11 (where the kernel body is a stub).

End-to-end smoke on gfx1151 against Qwen3.6-35B-A3B-W4A16: with the default (kernel active) the engine starts, gemm1 fires through the registered op at block_m=32, and generation succeeds (multimodal sanity check passes); a run with VLLM_MOE_HIP=0 confirms the Triton path still works and is selected.

Lint/format/type checks (ruff, ruff-format, clang-format, mypy, typos) pass.
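For readers unfamiliar with this style of check, the comparison against the reference can be sketched as below. The PR does not state which relative-error definition the test uses; this sketch assumes the maximum absolute difference normalized by the maximum reference magnitude, and `relative_error` is a hypothetical helper operating on plain Python lists rather than tensors.

```python
def relative_error(out, ref):
    """Max abs difference normalized by the largest reference magnitude.

    Assumed metric for illustration; the actual test may use another one.
    """
    if len(out) != len(ref):
        raise ValueError("out and ref must have the same length")
    max_diff = max(abs(o - r) for o, r in zip(out, ref))
    max_ref = max(abs(r) for r in ref)
    # Guard against an all-zero reference.
    return max_diff / max(max_ref, 1e-12)


def matches_reference(out, ref, tol=0.01):
    """True when the kernel output is within tol of the reference."""
    return relative_error(out, ref) < tol
```

A normalized metric like this tolerates the small rounding differences expected between two bf16 kernels with different accumulation orders, while still catching indexing or dequantization bugs, which typically produce errors far above 1%.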