[ROCm][MoE] Custom W4A16 MoE prefill WMMA GEMM for gfx11 (default-on)#1015

Open
roberteg16 wants to merge 2 commits into gfx11 from rogarcia.moe_gemm1_new_hip_kernel

roberteg16 commented Jun 22, 2026

Summary

This adds a producer/consumer WMMA kernel (rdna_moe_gemm) for the W4A16 MoE prefill GEMM1 on AMD RDNA3 (gfx11), a faster alternative to the Triton fused_moe_kernel_gptq_awq. It is enabled by default on gfx11 (set VLLM_MOE_HIP=0 to force the Triton path, or VLLM_MOE_HIP=1 to force-enable). On the rdna_moe_gemm path gemm1 (up/gate proj, top_k>1) runs the custom WMMA kernel and gemm2 (down proj, top_k=1) runs Triton, both at a single block_m=32 moe_align alignment. K, N and the weight row stride are runtime kernel arguments, so a single instantiation per tile family accepts any shape meeting the tile constraints (K % G == 0, N % BN == 0) and any weight N-row padding. Any other shape, group size, top_k or arch, or an explicit disable, falls back to Triton unchanged, so the change is a no-op outside the supported gfx11 path and remains fully A/B-able.

What this adds

  • csrc/rocm/moe_gemm_w4a16_wmma.cu: the WMMA kernel, with K, N and the weight N-row stride as runtime args so one instantiation per tile family handles any compliant shape and any weight padding. Compiled into _rocm_C unconditionally; the WMMA body is guarded to gfx11 device passes with an empty stub on other arches (mirroring the existing q_gemm_rdna3_wmma pattern) so multi-arch / CDNA builds still link. The host does constraint-based tile selection keyed by top_k; an unsupported shape TORCH_CHECKs (callers gate first). Registered as torch.ops._rocm_C.moe_gemm_w4a16 in ops.h / torch_bindings.cpp. The A producer load is vectorized to one global_load_b128 per lane (a v4u = 8 bf16) instead of a per-element 16-bit gather, with a matching wide LDS store when the padded row stride allows it.
  • vllm/envs.py: registers VLLM_MOE_HIP as a tri-state env var (unset = default-on on gfx11, 1 forces on, 0 forces the Triton path).
  • vllm/model_executor/layers/fused_moe/moe_hip_w4a16.py: a host-side shape predicate (prefill_uses_rdna_moe_gemm) gates on on_gfx11() plus the same shape/alignment constraints as the kernel, and a registered vLLM custom op torch.ops.vllm.moe_gemm_w4a16 with a no-op fake wraps the call so the path is graph-safe under torch.compile (no data-dependent branch on a device result).
  • vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py: dispatch in HybridW4A16MoEExperts.apply() — gemm1 runs the rdna_moe_gemm kernel when its predicate accepts the shape, else the existing Triton invoke; gemm2 always runs Triton. Both gemms share one block_m=32 alignment, so there is no second moe_align.
  • tests/kernels/quantization/test_moe_gemm_w4a16.py: validates the compiled op against the Triton reference across multiple shapes and block sizes.
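
The gating described in the list above can be sketched in plain Python. This is an illustration only, not the code in moe_hip_w4a16.py: the names `parse_tri_state` and `uses_custom_gemm1`, the `block_n` parameter, and the strict "0"/"1" parsing are assumptions. The description does not spell out how VLLM_MOE_HIP=1 differs from unset, so the sketch treats both as "on when gfx11 and the shape qualify".

```python
from typing import Optional


def parse_tri_state(raw: Optional[str]) -> Optional[bool]:
    """Map an env-var string to True ("1"), False ("0") or None (unset)."""
    if raw is None or raw == "":
        return None
    if raw not in ("0", "1"):
        raise ValueError(f"expected '0' or '1', got {raw!r}")
    return raw == "1"


def uses_custom_gemm1(
    *,
    K: int,
    N: int,
    group_size: int,
    block_n: int,  # hypothetical tile width, standing in for BN
    top_k: int,
    on_gfx11: bool,
    env_value: Optional[str],
) -> bool:
    """Hypothetical gate: True means 'run the WMMA kernel', False means Triton."""
    if parse_tri_state(env_value) is False:  # VLLM_MOE_HIP=0: explicit disable
        return False
    if not on_gfx11:  # the kernel body is a stub on other arches
        return False
    if top_k <= 1:  # gemm1 is the top_k > 1 projection
        return False
    # Tile constraints quoted in the summary: K % G == 0 and N % BN == 0.
    return K % group_size == 0 and N % block_n == 0
```

Because the predicate depends only on host-side integers and flags, the choice between the two paths is fixed before any kernel is launched, which is what keeps the dispatch free of data-dependent branches.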

A second commit adds an apply()-level record_function profiling scope around the rdna_moe_gemm gemm1 launch (moe_gemm_w4a16 …) carrying the dims the roofline tooling needs — M, N, K, E, top_k, the quant group size g, block_m, valid_blocks (= num_tokens_post_padded // block_m), n_routed, and per-expert routing histograms — gated on VLLM_CUSTOM_SCOPES_FOR_PROFILING / VLLM_NVTX_SCOPES_FOR_PROFILING so the device→host sync it needs runs only when profiling; production gets a nullcontext and pays nothing.

Design notes

The activations between the two gemms are stored compact (flat-topk indexed — the kernels write C[sorted_token_ids[slot]]), so gemm2 gathers them via the shared sorted_token_ids with no re-permute, and the whole MoE runs on a single alignment. We use block_m=32 for both gemms because gemm1 at 32 is faster than at 16 on gfx1151: the larger WMMA M-tile more than pays for the extra alignment padding (measured ~6% faster per useful row), and gemm2 (Triton) is already fastest at 32. An earlier dual-alignment variant (gemm1 at 16, gemm2 at 32 via a second moe_align) was both slower and more complex, so it was dropped.
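
To make the "extra alignment padding" concrete, here is a small sketch of how per-expert rounding to block_m inflates the row count. The routing histogram in the usage note is made up, and `padded_rows` / `padding_overhead` are hypothetical helpers, not functions from the PR. The sketch assumes an expert with zero routed tokens contributes no padded rows.

```python
from typing import Sequence


def padded_rows(tokens_per_expert: Sequence[int], block_m: int) -> int:
    """Rows after rounding each expert's routed-token count up to block_m."""
    # -(-n // b) is ceil(n / b) for non-negative integers.
    return sum(-(-n // block_m) * block_m for n in tokens_per_expert)


def padding_overhead(tokens_per_expert: Sequence[int], block_m: int) -> float:
    """Fraction of padded rows that carry no real token."""
    total = padded_rows(tokens_per_expert, block_m)
    useful = sum(tokens_per_expert)
    return 0.0 if total == 0 else (total - useful) / total
```

For an illustrative histogram `[5, 20, 33, 0]` (58 routed tokens), block_m=16 pads to 96 rows and block_m=32 pads to 128 rows, so valid_blocks (padded rows // block_m) is 6 and 4 respectively. The claim in the text is that the larger tile is still faster per useful row despite the larger padded total.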

Performance

End-to-end on gfx1151 against Qwen3.6-35B-A3B-W4A16 (prefill, 100 input tokens + one 1280×720 image, output_len=1), median time-to-first-token, rdna_moe_gemm path (the default) vs the Triton path (VLLM_MOE_HIP=0):

--num-prompts    rdna_moe_gemm    Triton (VLLM_MOE_HIP=0)
-------------  ---------------  -------------------------
1                       574 ms                     614 ms
10                      619 ms                     650 ms

The kernel is faster in both cases (~40 ms / ~6.5% at 1 prompt, ~31 ms / ~4.8% at 10 prompts). The end-to-end gain is much smaller than the kernel speedup because the MoE gemm1 is only one component of prefill, and single-prompt latency carries run-to-run variance.

A second run with a smaller 640×480 image (fewer vision tokens, so prefill is shorter and the MoE gemm1 is a smaller slice of TTFT), median time-to-first-token, rdna_moe_gemm path (the default) vs the Triton path (VLLM_MOE_HIP=0):

--num-prompts    rdna_moe_gemm    Triton (VLLM_MOE_HIP=0)
-------------  ---------------  -------------------------
1                       279 ms                     285 ms
10                      278 ms                     290 ms
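
The gains quoted for the first run, and the corresponding figures for this second run, follow from simple arithmetic on the two tables. A short sketch (the helper name `ttft_gain` is hypothetical; the numbers are the medians reported above):

```python
def ttft_gain(custom_ms: float, triton_ms: float) -> tuple[float, float]:
    """Return (ms saved, saving as a percentage of the Triton time)."""
    saved = triton_ms - custom_ms
    return saved, 100.0 * saved / triton_ms


# (image size, --num-prompts) -> (rdna_moe_gemm ms, Triton ms)
runs = {
    ("1280x720", 1): (574, 614),
    ("1280x720", 10): (619, 650),
    ("640x480", 1): (279, 285),
    ("640x480", 10): (278, 290),
}

gains = {key: ttft_gain(*pair) for key, pair in runs.items()}
```

With these inputs the 640×480 run comes out at roughly 6 ms (about 2.1%) for 1 prompt and 12 ms (about 4.1%) for 10 prompts, consistent with the MoE gemm1 being a smaller slice of TTFT when there are fewer vision tokens.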

Testing

Unit test (compiled op vs Triton reference, relative error < 0.01): pytest tests/kernels/quantization/test_moe_gemm_w4a16.py passes on gfx1151, covering the tuned shapes plus a non-tuned K×N, each across block_m 16 and 32; it skips automatically off gfx11 (where the kernel body is a stub). End-to-end smoke on gfx1151 against Qwen3.6-35B-A3B-W4A16: with the default (kernel active) the engine starts, gemm1 fires through the registered op at block_m=32, and generation succeeds (multimodal sanity check passes); a run with VLLM_MOE_HIP=0 confirms the Triton path still works and is selected. Lint/format/type checks (ruff, ruff-format, clang-format, mypy, typos) pass.
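
The description gives the tolerance (relative error < 0.01) but not the exact metric. One common choice, shown here purely as an assumption, is the L2 norm of the difference divided by the L2 norm of the reference. `relative_error` is a hypothetical helper over flat lists; the real test presumably operates on tensors.

```python
import math
from typing import Sequence


def relative_error(out: Sequence[float], ref: Sequence[float]) -> float:
    """||out - ref||_2 / ||ref||_2 (assumed metric, not confirmed by the PR)."""
    if len(out) != len(ref):
        raise ValueError("length mismatch")
    diff = math.sqrt(sum((o - r) ** 2 for o, r in zip(out, ref)))
    norm = math.sqrt(sum(r * r for r in ref))
    if norm == 0.0:
        # Reference is all zeros: only an exact match counts as zero error.
        return 0.0 if diff == 0.0 else math.inf
    return diff / norm
```

A check in the spirit of the unit test would then be `assert relative_error(kernel_out, triton_out) < 0.01`.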

@mgehre-amd

Can you show the before/after GEMM table from vllm-bench profiling for Qwen3.6?

Comment thread vllm/model_executor/layers/fused_moe/moe_hip_w4a16.py Outdated
@roberteg16 roberteg16 force-pushed the rogarcia.moe_gemm1_new_hip_kernel branch 3 times, most recently from b6879d0 to 7f41b26 Compare June 24, 2026 14:02
roberteg16 commented Jun 25, 2026

Before:

Kernel               Shape                       DType    CUDA%    Calls    Avg CUDA    CUDA Total    Weight    BW (GiB/s)    TFLOPS
-------------------  ------------------------  -------  -------  -------  ----------  ------------  --------  ------------  --------
hybrid_triton_moe    994x1024x2048 E=256 k=8      int4    21.4%       40     2.96 ms     118.20 ms    268 MB          84.6     11.29
hybrid_triton_moe    15888x2048x512 E=256 k=1     int4     9.6%       40     1.33 ms      53.16 ms    134 MB          94.1     25.07
[...]

After:

Kernel               Shape                       DType    CUDA%    Calls    Avg CUDA    CUDA Total    Weight    BW (GiB/s)    TFLOPS
-------------------  ------------------------  -------  -------  -------  ----------  ------------  --------  ------------  --------
moe_gemm_w4a16       994x1024x2048 E=256 k=8      int4    18.5%       39     2.50 ms      97.36 ms    268 MB         100.1     13.36
hybrid_triton_moe    15888x2048x512 E=256 k=1     int4     9.8%       40     1.30 ms      52.20 ms    134 MB          95.8     25.53
[...]
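
The TFLOPS and bandwidth columns in these tables can be reproduced with straightforward roofline arithmetic. The formulas below are inferred from the reported numbers, not taken from the profiling tool: they read the shape as M×N×K, count 2·M·top_k·N·K FLOPs, count E·N·K 4-bit weights as the bytes moved, and use CUDA Total / Calls as the average time. Function names are hypothetical.

```python
def moe_gemm_tflops(m: int, n: int, k: int, top_k: int,
                    total_ms: float, calls: int) -> float:
    """TFLOPS assuming 2 * M * top_k * N * K FLOPs per call."""
    flops = 2 * m * top_k * n * k
    seconds = (total_ms / calls) * 1e-3
    return flops / seconds / 1e12


def weight_bw_gib_s(num_experts: int, n: int, k: int,
                    total_ms: float, calls: int, bits: int = 4) -> float:
    """GiB/s assuming all expert weights (E * N * K values) are read once."""
    weight_bytes = num_experts * n * k * bits // 8
    seconds = (total_ms / calls) * 1e-3
    return weight_bytes / seconds / 2**30


# gemm1 row, shape 994x1024x2048, E=256, k=8
before = moe_gemm_tflops(994, 1024, 2048, 8, total_ms=118.20, calls=40)
after = moe_gemm_tflops(994, 1024, 2048, 8, total_ms=97.36, calls=39)
```

Under these assumptions the gemm1 row works out to about 11.29 TFLOPS before and 13.36 TFLOPS after, matching the tables, and the average gemm1 time drops from roughly 2.96 ms to 2.50 ms (about 1.18× faster at the kernel level).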

@roberteg16 roberteg16 requested a review from mgehre-amd June 25, 2026 08:27
@roberteg16 roberteg16 force-pushed the rogarcia.moe_gemm1_new_hip_kernel branch from d793760 to a899393 Compare June 25, 2026 08:50
…ault-on)

Adds a producer/consumer WMMA kernel (rdna_moe_gemm) for the W4A16 MoE prefill
GEMM1 on AMD RDNA3 (gfx11), a faster alternative to the Triton
fused_moe_kernel_gptq_awq. Enabled by default on gfx11 (VLLM_MOE_HIP=0 forces
Triton, =1 forces on).

On the rdna_moe_gemm path gemm1 (up/gate proj, top_k>1) runs the rdna_moe_gemm
WMMA kernel and gemm2 (down proj, top_k=1) runs Triton, both at a single
block_m=32 moe_align alignment. The per-routed-token activations are stored
compact (flat-topk indexed -- the kernels write C[sorted_token_ids[slot]]), so
gemm2 gathers them via the shared sorted_token_ids with no re-permute. gemm1 at
block_m=32 is faster than at 16 on gfx1151 (the larger WMMA tile more than pays
for the extra alignment padding), so there is a single alignment for both gemms.

- csrc/rocm/moe_gemm_w4a16_wmma.cu: the WMMA kernel with K, N and the weight
  N-row stride as runtime args (one instantiation per tile family handles any
  compliant shape and any weight padding). Compiled into _rocm_C; the WMMA body
  is gfx11-only with a stub on other arches. Registered as
  torch.ops._rocm_C.moe_gemm_w4a16; an unsupported shape TORCH_CHECKs.
- vllm/envs.py: VLLM_MOE_HIP tri-state (unset = default-on on gfx11).
- vllm/model_executor/layers/fused_moe/moe_hip_w4a16.py: host-side shape
  predicate (prefill_uses_rdna_moe_gemm) plus a graph-safe vLLM custom op
  wrapper with a no-op fake, so the path is torch.compile-safe.
- vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py: apply() dispatch --
  rdna_moe_gemm gemm1 when the shape is supported, else Triton; gemm2 always
  Triton.
- tests/kernels/quantization/test_moe_gemm_w4a16.py: compiled op vs Triton
  reference across shapes and block sizes.

AI assistance (Claude) was used.

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Robert Esclapez Garcia <robert.garcia@amd.com>

Wrap the rdna_moe_gemm gemm1 launch (torch.ops.vllm.moe_gemm_w4a16) in apply()
with an apply()-level record_function scope carrying the dims the roofline tool
needs: M, N, K, E, top_k, the quant group size g, block_m, valid_blocks
(= num_tokens_post_padded // block_m) and n_routed (M*top_k), plus per-expert
tok_hist / vtok_hist histograms describing the routing skew that drives the
block_m padding. g is the real self._group_size (the per-group scale bytes and
dequant FLOPs scale with K/g), not an assumed 128.

The scope is gated on VLLM_CUSTOM_SCOPES_FOR_PROFILING /
VLLM_NVTX_SCOPES_FOR_PROFILING -- the valid_blocks/histogram reads force a
device->host sync, taken only when profiling; production gets a nullcontext.

AI assistance (Claude) was used.

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Robert Esclapez Garcia <robert.garcia@amd.com>

@roberteg16 roberteg16 force-pushed the rogarcia.moe_gemm1_new_hip_kernel branch from a899393 to f50fdea Compare June 25, 2026 13:10
@roberteg16

@mgehre-amd polite ping
