[ROCm][W4A16] Cache dequantized bf16 weights for the linear prefill GEMM #1007
Draft
roberteg16 wants to merge 1 commit into
Please merge origin/gfx11 to fix build failure
Opt-in (VLLM_W4A16_PREFILL_BF16=1, default off): on the W4A16 linear prefill path (gfx11/gfx1151) trade VRAM for compute by dequantizing the int4 weight to a bf16/fp16 copy once at load (_hybrid_w_bf16); prefill (M > MAX_SKINNY_BATCH_SIZE) then runs a dense hipBLASLt GEMM (F.linear) on it instead of the fused in-kernel int4 unpack. Decode (M <= 5) is untouched and still uses the int4 wvSplitK path.

The bf16 copy is ~4x the int4 weight, so caching is gated per-weight on vLLM's memory budget: with budget = total * gpu_memory_utilization, a copy is kept only while there is room left in the budget (free_mem - total * (1 - util) >= bf16_bytes); otherwise the weight keeps the int4 prefill path and a message is emitted (warn-once). Free memory is read via MemorySnapshot (psutil fallback on integrated/UMA GPUs like gfx1151), matching how the KV-cache budget is sized. The check runs in process_weights_after_loading (before the memory profiler) and uses live free memory, so it is self-limiting (each copy shrinks the spendable budget until it is exhausted) and KV-cache sizing accounts for the copies.

The fused prefill kernel is compute-bound (~26 TFLOPS, 60.9% peak); the gap is the in-kernel ExLlama unshuffle + per-group dequant. A dense bf16 GEMM hits 33-36 TFLOPS at the prefill shapes.

E2E numbers below are from the pre-gating all-weights experiment (re-validation with the budget gate is pending): gfx1151, Qwen3 35B-A3B W4A16, 1280x720 image, input-len 100, output-len 1, ROCM_AITER_FA: median TTFT 633 -> 613 ms, mean 667 -> 652, prefill 1566 -> 1620 tok/s; GPU KV cache 3.48 -> 1.1 GiB. MM sanity passes.

AI assistance (Claude) was used.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Robert Esclapez Garcia <robert.garcia@amd.com>
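The per-weight budget gate in the commit message can be illustrated with a small, self-contained sketch. This is not the code from the PR: `BudgetGate`, `plan_copies`, the layer names, and the memory sizes are all hypothetical, and live free memory is simulated with a plain integer rather than read from MemorySnapshot. It only mirrors the stated condition free_mem - total * (1 - util) >= bf16_bytes and the self-limiting behavior.

```python
from dataclasses import dataclass

GIB = 1 << 30


@dataclass
class BudgetGate:
    """Hypothetical helper mirroring the gate described in the commit message."""

    total_bytes: int
    gpu_memory_utilization: float

    def spendable(self, free_bytes: int) -> float:
        # Memory outside the budget: total * (1 - util). Only free memory
        # above that floor may be spent on bf16 copies.
        reserved = self.total_bytes * (1.0 - self.gpu_memory_utilization)
        return free_bytes - reserved

    def fits(self, free_bytes: int, bf16_bytes: int) -> bool:
        # free_mem - total * (1 - util) >= bf16_bytes
        return self.spendable(free_bytes) >= bf16_bytes


def plan_copies(gate, free_bytes, weights):
    """Walk weights in load order; each kept copy lowers the simulated free memory."""
    cached, skipped = [], []
    for name, bf16_bytes in weights:
        if gate.fits(free_bytes, bf16_bytes):
            cached.append(name)
            free_bytes -= bf16_bytes  # self-limiting: the budget shrinks
        else:
            skipped.append(name)  # this weight would keep the int4 prefill path
    return cached, skipped, free_bytes


# Made-up numbers: 64 GiB total, 10 GiB free at load, util = 0.9.
gate = BudgetGate(total_bytes=64 * GIB, gpu_memory_utilization=0.9)
weights = [("layer0", 2 * GIB), ("layer1", 2 * GIB), ("layer2", 1 * GIB)]
cached, skipped, free_after = plan_copies(gate, 10 * GIB, weights)
print(cached, skipped, free_after // GIB)
```

With these made-up numbers the spendable budget starts at about 3.6 GiB, so layer0 is cached, layer1 no longer fits and is skipped, and the smaller layer2 still fits. A skipped weight does not stop later, smaller weights from being cached, which is what gating per weight implies.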
Found a bug, converting to draft
Summary
This adds an opt-in path (VLLM_W4A16_PREFILL_BF16=1, default off) that trades VRAM for compute on the ROCm W4A16 linear prefill path (gfx11/gfx1151). At load time each W4A16 weight is dequantized once to a dense bf16/fp16 copy; prefill (batch M > MAX_SKINNY_BATCH_SIZE) then runs a dense hipBLASLt GEMM (F.linear) on that copy instead of the fused in-kernel int4 unpack, while decode (M <= 5) is untouched and still uses the int4 wvSplitK skinny path. The motivation is that the fused prefill kernel is compute-bound (~26 TFLOPS, ~61% of peak) because of the in-kernel ExLlama unshuffle + per-group dequant, whereas a dense bf16 GEMM reaches ~33-39 TFLOPS at the same shapes. The feature is fully A/B-able: with the env unset nothing changes, and even when set, layers that do not fit the memory budget transparently keep the int4 prefill path.

What this adds
- vllm/envs.py: registers VLLM_W4A16_PREFILL_BF16 (bool, default off). The bf16 copy is only kept while there is room left in the gpu_memory_utilization budget; there is no separate reserve knob.
- vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py: in process_weights_after_loading, when the env is set, dequantize each W4A16 weight to a dense copy and store it as a non-persistent buffer (_hybrid_w_bf16). The copy is gated per-weight on the memory budget (see Memory behavior). The decode-vs-prefill-vs-bf16 dispatch is performed inside the existing torch.ops.vllm.hybrid_w4a16_apply custom op (which now also takes the optional w_bf16 tensor), so the M-based selection is evaluated at runtime rather than baked at trace time.

Design notes
The memory budget gate is anchored to the configured gpu_memory_utilization rather than a raw free-VRAM number. With budget = total * gpu_memory_utilization, a weight is cached only if free_mem - total * (1 - gpu_memory_utilization) >= bf16_bytes at load time; otherwise it keeps the int4 prefill path. Free memory is read via MemorySnapshot, which falls back to psutil on integrated/UMA GPUs (e.g. gfx1151 Strix Halo, where cudaMemGetInfo underreports free memory), so the gate stays consistent with how vLLM sizes the KV cache. The check uses live free memory and runs before the memory profiler, so it is self-limiting (each copy shrinks the spendable budget until it is exhausted) and the KV-cache sizing automatically accounts for the copies. A warn-once message is emitted if the budget is exhausted.

The decode/prefill selection lives inside the opaque custom op on purpose, and this is load-bearing. An earlier version branched in Python in apply_weights (if w_bf16 is not None and M > MAX_SKINNY_BATCH_SIZE: F.linear(...)). Under VLLM_COMPILE that branch is traced with a prefill example (M large), baked as taken, and then reused at decode (M=1) with dynamic_shapes_config.evaluate_guards=False, which forced the dense bf16 GEMM onto the memory-bound decode path and caused a large TPOT/E2E regression. Moving the M-check inside the custom op (which torch.compile treats as opaque) makes the selection run for real on the actual batch size every call, so decode reliably stays on int4.

Performance
End-to-end on gfx1151 (Strix Halo), Qwen3.6-35B-A3B-W4A16, single prompt, 100 input tokens + one 1280x720 image, --output-len 128, --no-cudagraph, VLLM_MOE_HIP=1, ROCM_AITER_FA. Single run each (treat small deltas as indicative).

Prefill is ~2.4% faster (TTFT) / ~2.5% higher throughput; decode (TPOT) and E2E are unchanged within run-to-run noise, confirming the int4 decode path is preserved. Profile inspection (tools/roofline_trace.py --report-by-op --kind ALL --gpu gfx1151) confirms the swap: with the env on, the W4A16 linear prefill no longer shows _triton_w4a16_skinny_fmt_kernel (replaced by dense bf16 aten::mm/addmm at the prefill shapes), while at decode the W4A16 layers (e.g. the GDN in_proj, M=1, N=12288) run _rocm_C::wvSplitK_int4_g exactly as in the baseline (227.6 ms total, ~59.7 us/call in both runs).

Memory behavior
The bf16 copy is roughly 4x the int4 weight, so it is opt-in and budget-gated. On the run above it consumed ~2.38 GiB (KV cache 7.23 -> 4.85 GiB) at gpu_memory_utilization=0.9. This is the expected cost; the feature is worthwhile only when prefill latency matters more than KV-cache capacity for the deployment.

Testing
- pre-commit run ruff-check, ruff-format, and mypy-local pass on the changed files.
- Qwen3.6-35B-A3B-W4A16 using the command above, run three ways: env off (baseline), env on, and env on after the compile-safe dispatch fix. All three complete with Failed requests: 0; the numbers and profile evidence are in the Performance section.
- tools/roofline_trace.py on the exported traces confirmed (a) decode stays on wvSplitK_int4_g, (b) prefill uses the dense bf16 GEMM, and (c) the budget gate emits no warning at gpu_memory_utilization=0.9 (all copies fit).
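
The compile-safe dispatch fix from the Design notes can be illustrated with a toy model in plain Python. It does not use torch.compile or the real custom op: `select_path`, `bake_branch`, `opaque_op`, and the path labels are hypothetical, and MAX_SKINNY_BATCH_SIZE = 5 is assumed from "decode (M <= 5)". The sketch only contrasts a decision frozen from one example batch size with a decision re-evaluated on every call.

```python
MAX_SKINNY_BATCH_SIZE = 5  # assumption, inferred from "decode (M <= 5)"


def select_path(m: int, has_bf16_copy: bool) -> str:
    """Pick a kernel label from the runtime batch size M (labels are illustrative)."""
    if m <= MAX_SKINNY_BATCH_SIZE:
        return "int4_wvSplitK"  # decode: memory-bound, stays on int4
    if has_bf16_copy:
        return "dense_bf16_gemm"  # prefill with a cached bf16 copy
    return "int4_fused_prefill"  # prefill when the copy did not fit the budget


def bake_branch(example_m: int, has_bf16_copy: bool):
    """Toy stand-in for a Python branch specialized at trace time.

    The decision is taken once for example_m and reused for every later
    call, which is the failure mode described in the Design notes.
    """
    frozen = select_path(example_m, has_bf16_copy)
    return lambda m: frozen


def opaque_op(has_bf16_copy: bool):
    """Toy stand-in for an opaque op: the M check runs on every call."""
    return lambda m: select_path(m, has_bf16_copy)


baked = bake_branch(example_m=2048, has_bf16_copy=True)
runtime = opaque_op(has_bf16_copy=True)

for m in (1, 2048):
    print(m, baked(m), runtime(m))
```

In this toy model the baked version keeps choosing the dense bf16 GEMM at M=1 because it was specialized on a prefill-sized example, while the per-call version returns to the int4 path for decode.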