[ROCm][W4A16] Cache dequantized bf16 weights for the linear prefill GEMM#1007

Draft
roberteg16 wants to merge 1 commit into gfx11 from rogarcia.w4a16-prefill-bf16-cache

@roberteg16 roberteg16 commented Jun 15, 2026

Summary

This adds an opt-in path (VLLM_W4A16_PREFILL_BF16=1, default off) that trades VRAM for compute on the ROCm W4A16 linear prefill path (gfx11/gfx1151). At load time each W4A16 weight is dequantized once to a dense bf16/fp16 copy; prefill (batch M > MAX_SKINNY_BATCH_SIZE) then runs a dense hipBLASLt GEMM (F.linear) on that copy instead of the fused in-kernel int4 unpack, while decode (M <= 5) is untouched and still uses the int4 wvSplitK skinny path. The motivation is that the fused prefill kernel is compute-bound (~26 TFLOPS, ~61% of peak) because of the in-kernel ExLlama unshuffle + per-group dequant, whereas a dense bf16 GEMM reaches ~33-39 TFLOPS at the same shapes. The feature is fully A/B-able: with the env unset nothing changes, and even when set, layers that do not fit the memory budget transparently keep the int4 prefill path.

What this adds

  • vllm/envs.py: registers VLLM_W4A16_PREFILL_BF16 (bool, default off). The bf16 copy is only kept while there is room left in the gpu_memory_utilization budget; there is no separate reserve knob.
  • vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py: in process_weights_after_loading, when the env is set, dequantize each W4A16 weight to a dense copy and store it as a non-persistent buffer (_hybrid_w_bf16). The copy is gated per-weight on the memory budget (see Memory behavior). The decode-vs-prefill-vs-bf16 dispatch is performed inside the existing torch.ops.vllm.hybrid_w4a16_apply custom op (which now also takes the optional w_bf16 tensor), so the M-based selection is evaluated at runtime rather than baked at trace time.
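A minimal sketch of how a default-off boolean flag like this could be registered and read. The `environment_variables` registry and the `prefill_bf16_enabled` helper are illustrative names, not copied from the PR's diff; only the variable name and its default come from the description above.

```python
import os
from typing import Any, Callable

# Hypothetical registry: maps an env-var name to a getter that parses it lazily.
# Assumption: the flag is parsed as an integer ("0"/"1") and defaults to off.
environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_W4A16_PREFILL_BF16": lambda: bool(
        int(os.getenv("VLLM_W4A16_PREFILL_BF16", "0"))
    ),
}


def prefill_bf16_enabled() -> bool:
    """Return True only when the opt-in flag is explicitly set to a non-zero value."""
    return environment_variables["VLLM_W4A16_PREFILL_BF16"]()
```

Because the getter runs on each call rather than at import time, unsetting the variable restores the baseline behavior, which is what makes the A/B comparison straightforward.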

Design notes

The memory budget gate is anchored to the configured gpu_memory_utilization rather than a raw free-VRAM number. With budget = total * gpu_memory_utilization, a weight is cached only if free_mem - total * (1 - gpu_memory_utilization) >= bf16_bytes at load time, otherwise it keeps the int4 prefill path. Free memory is read via MemorySnapshot, which falls back to psutil on integrated/UMA GPUs (e.g. gfx1151 Strix Halo, where cudaMemGetInfo underreports free memory), so the gate stays consistent with how vLLM sizes the KV cache. The check uses live free memory and runs before the memory profiler, so it is self-limiting (each copy shrinks the spendable budget until it is exhausted) and the KV-cache sizing automatically accounts for the copies. A warn-once message is emitted if the budget is exhausted.
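The gate and its self-limiting behavior can be illustrated with plain arithmetic. This is a sketch under stated assumptions: `fits_budget` and `plan_copies` are hypothetical helpers, memory is modeled as simple numbers rather than read from MemorySnapshot, and free memory is assumed to drop by exactly the size of each cached copy.

```python
def fits_budget(free_mem: float, total: float, util: float, bf16_bytes: float) -> bool:
    """Gate from the description: free_mem - total * (1 - util) >= bf16_bytes."""
    reserve = total * (1.0 - util)  # memory outside the gpu_memory_utilization budget
    return free_mem - reserve >= bf16_bytes


def plan_copies(free_mem: float, total: float, util: float, sizes: list[float]) -> list[bool]:
    """Walk the weights in load order and decide per weight whether to cache a copy.

    Each accepted copy lowers the live free memory, so the spendable budget
    shrinks until later weights no longer fit (the self-limiting property).
    """
    decisions = []
    for size in sizes:
        cached = fits_budget(free_mem, total, util, size)
        if cached:
            free_mem -= size  # assumption: the copy consumes exactly `size`
        decisions.append(cached)
    return decisions


# Toy numbers in arbitrary units: 100 total, 50 free, util 0.9 -> 40 spendable.
print(plan_copies(free_mem=50, total=100, util=0.9, sizes=[15, 15, 15, 5]))
```

In this toy run the third weight is rejected while the smaller fourth one still fits, which matches the per-weight gating: a rejected layer keeps the int4 prefill path without blocking later layers.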

The decode/prefill selection lives inside the opaque custom op on purpose, and this is load-bearing. An earlier version branched in Python in apply_weights (if w_bf16 is not None and M > MAX_SKINNY_BATCH_SIZE: F.linear(...)). Under VLLM_COMPILE that branch is traced with a prefill example (M large), baked as taken, and then reused at decode (M=1) with dynamic_shapes_config.evaluate_guards=False, which forced the dense bf16 GEMM onto the memory-bound decode path and caused a large TPOT/E2E regression. Moving the M-check inside the custom op (which torch.compile treats as opaque) makes the selection run for real on the actual batch size every call, so decode reliably stays on int4.
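The difference between a decision baked at trace time and one evaluated per call can be shown with a toy model. This sketch does not use torch.compile or the real custom op; `select_path`, `make_baked`, and the path labels are hypothetical, and `MAX_SKINNY_BATCH_SIZE = 5` is inferred from the "decode (M <= 5)" wording in the summary.

```python
MAX_SKINNY_BATCH_SIZE = 5  # assumption, inferred from "decode (M <= 5)"


def select_path(m: int, has_bf16_copy: bool) -> str:
    """Runtime selection, as it would run inside the opaque op on every call."""
    if m <= MAX_SKINNY_BATCH_SIZE:
        return "int4_skinny"    # decode: memory-bound, stays on int4
    if has_bf16_copy:
        return "bf16_dense"     # prefill with a cached dense copy
    return "int4_prefill"       # prefill for layers that did not fit the budget


def make_baked(example_m: int, has_bf16_copy: bool):
    """Toy stand-in for tracing: the branch is decided once from the example input."""
    chosen = select_path(example_m, has_bf16_copy)
    return lambda m: chosen     # later calls ignore the real batch size


baked = make_baked(example_m=1024, has_bf16_copy=True)
print(baked(1))                 # baked decision reused at decode
print(select_path(1, True))     # per-call decision at decode
```

The baked closure keeps returning the prefill choice for M=1, which mirrors the regression described above, while the per-call selection returns the int4 decode path.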

Performance

End-to-end on gfx1151 (Strix Halo), Qwen3.6-35B-A3B-W4A16, single prompt, 100 input tokens + one 1280x720 image, --output-len 128, --no-cudagraph, VLLM_MOE_HIP=1, ROCM_AITER_FA. Single run each, so treat small deltas as indicative only.

| metric | baseline (off) | bf16 cache on |
|---|---|---|
| TTFT (ms) | 598.1 | 583.5 |
| prefill (tok/s) | 1662 | 1703 |
| TPOT (ms) | 30.86 | 31.45 |
| E2E latency (ms) | 4517 | 4577 |
| KV cache | 7.23 GiB / 210,066 tok | 4.85 GiB / 140,434 tok |

Prefill is ~2.4% faster (TTFT) with ~2.5% higher throughput. Decode (TPOT, +1.9%) and E2E latency (+1.3%) are slightly higher; with a single run each, this is treated as run-to-run noise and is consistent with the int4 decode path being preserved. Profile inspection (tools/roofline_trace.py --report-by-op --kind ALL --gpu gfx1151) confirms the swap: with the env on, the W4A16 linear prefill no longer shows _triton_w4a16_skinny_fmt_kernel (replaced by dense bf16 aten::mm/addmm at the prefill shapes), while at decode the W4A16 layers (e.g. the GDN in_proj, M=1, N=12288) run _rocm_C::wvSplitK_int4_g exactly as in the baseline (227.6 ms total, ~59.7 us/call in both runs).
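The relative changes can be recomputed from the measured values in the Performance table. The snippet below is only a convenience check; `pct_change` and the `results` dictionary are illustrative names, and the numbers are the single-run measurements reported in this section.

```python
def pct_change(baseline: float, new: float) -> float:
    """Signed relative change of `new` versus `baseline`, in percent."""
    return (new - baseline) / baseline * 100.0


# (baseline with the env off, value with the bf16 cache on)
results = {
    "TTFT (ms)": (598.1, 583.5),
    "prefill (tok/s)": (1662.0, 1703.0),
    "TPOT (ms)": (30.86, 31.45),
    "E2E latency (ms)": (4517.0, 4577.0),
}

for name, (off, on) in results.items():
    print(f"{name}: {pct_change(off, on):+.2f}%")
```

Lower is better for TTFT, TPOT, and E2E latency, while higher is better for prefill throughput, so the sign has to be read per metric.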

Memory behavior

The bf16 copy is roughly 4x the int4 weight, so it is opt-in and budget-gated. On the run above it consumed ~2.38 GiB (KV cache 7.23 -> 4.85 GiB) at gpu_memory_utilization=0.9. This is the expected cost; the feature is worthwhile only when prefill latency matters more than KV-cache capacity for the deployment.
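The 4x factor follows directly from the element widths (16 bits versus 4 bits). A small worked example, assuming a weight of shape N x K and ignoring quantization scales and zero-points; N=12288 appears in the profile notes, while K=2048 is a hypothetical value chosen only for illustration.

```python
MIB = 1024 ** 2


def dense_bytes(n: int, k: int, bits_per_element: int) -> int:
    """Bytes needed for an n x k weight at the given element width."""
    return n * k * bits_per_element // 8


n, k = 12288, 2048                   # N from the PR text; K is hypothetical
int4_bytes = dense_bytes(n, k, 4)    # packed int4 weight, scales/zeros ignored
bf16_bytes = dense_bytes(n, k, 16)   # dense bf16 copy kept for prefill

print(f"int4: {int4_bytes / MIB:.0f} MiB, bf16: {bf16_bytes / MIB:.0f} MiB, "
      f"ratio: {bf16_bytes / int4_bytes:.0f}x")

# KV-cache capacity given up in the reported run (GiB).
kv_delta_gib = 7.23 - 4.85
print(f"KV cache reduction: {kv_delta_gib:.2f} GiB")
```

The int4 weight stays resident for decode, so the bf16 copy is additional memory rather than a replacement; the KV-cache reduction above is the cost observed in the reported run.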

Testing

  • Lint/type: the pre-commit hooks ruff-check, ruff-format, and mypy-local pass on the changed files.
  • End-to-end smoke + A/B on gfx1151 against Qwen3.6-35B-A3B-W4A16 using the command above, run three ways: env off (baseline), env on, and env on after the compile-safe dispatch fix. All three complete with Failed requests: 0; the numbers and profile evidence are in the Performance section.
  • Profile verification via tools/roofline_trace.py on the exported traces confirmed (a) decode stays on wvSplitK_int4_g, (b) prefill uses the dense bf16 GEMM, and (c) the budget gate emits no warning at gpu_memory_utilization=0.9 (all copies fit).

@roberteg16 roberteg16 force-pushed the rogarcia.w4a16-prefill-bf16-cache branch 2 times, most recently from 75b317b to 71a1c0b Compare June 25, 2026 21:09
@roberteg16 roberteg16 changed the title [EXPERIMENT][ROCm][W4A16] Cache dequantized bf16 weights for prefill GEMM [ROCm][W4A16] Cache dequantized bf16 weights for the linear prefill GEMM Jun 25, 2026
@roberteg16 roberteg16 marked this pull request as ready for review June 25, 2026 21:16
@roberteg16 roberteg16 requested a review from mgehre-amd June 25, 2026 21:18

@mgehre-amd mgehre-amd left a comment

Looks good!

@mgehre-amd

Please merge origin/gfx11 to fix build failure

Opt-in (VLLM_W4A16_PREFILL_BF16=1, default off): on the W4A16 linear
prefill path (gfx11/gfx1151) trade VRAM for compute by dequantizing the
int4 weight to a bf16/fp16 copy once at load (_hybrid_w_bf16); prefill
(M > MAX_SKINNY_BATCH_SIZE) then runs a dense hipBLASLt GEMM (F.linear) on
it instead of the fused in-kernel int4 unpack. Decode (M <= 5) is
untouched and still uses the int4 wvSplitK path.

The bf16 copy is ~4x the int4 weight, so caching is gated per-weight on
vLLM's memory budget: with budget = total * gpu_memory_utilization, a copy
is kept only while there is room left in the budget
(free_mem - total * (1 - util) >= bf16_bytes); otherwise the weight keeps
the int4 prefill path and a message is emitted (warn-once). Free memory is
read via MemorySnapshot (psutil fallback on integrated/UMA GPUs like
gfx1151), matching how the KV-cache budget is sized. The check runs in
process_weights_after_loading (before the memory profiler) and uses live
free memory, so it is self-limiting -- each copy shrinks the spendable
budget until it is exhausted -- and KV-cache sizing accounts for the copies.

The fused prefill kernel is compute-bound (~26 TFLOPS, 60.9% peak); the
gap is the in-kernel ExLlama unshuffle + per-group dequant. A dense bf16
GEMM hits 33-36 TFLOPS at the prefill shapes.

E2E numbers below are from the pre-gating all-weights experiment
(re-validation with the budget gate is pending): gfx1151, Qwen3 35B-A3B
W4A16, 1280x720 image, input-len 100, output-len 1, ROCM_AITER_FA: median
TTFT 633 -> 613 ms, mean 667 -> 652, prefill 1566 -> 1620 tok/s; GPU KV
cache 3.48 -> 1.1 GiB. MM sanity passes. AI assistance (Claude) was used.

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Robert Esclapez Garcia <robert.garcia@amd.com>
@roberteg16 roberteg16 force-pushed the rogarcia.w4a16-prefill-bf16-cache branch from 71a1c0b to af9ef38 Compare June 26, 2026 15:21
@roberteg16 roberteg16 marked this pull request as draft June 26, 2026 17:04
@roberteg16

Found a bug, converting to draft
