vulkan: bound per-workgroup KV in flash attention (candidate fix for #185 device-lost)#186
vulkan: bound per-workgroup KV in flash attention (candidate fix for #185 device-lost)#186TheTom wants to merge 2 commits into
Conversation
Deep-context flash attention device-losts on AMD APUs / RADV (Strix Halo gfx1151, issue #185): the byte-based submit heuristic in graph_compute only accounts for matmul work, so flash-attention-heavy graphs (the dense MTP head doing full attention over deep KV) batch up to nodes_per_submit=100 nodes into one vkQueueSubmit, exceeding the GPU job watchdog (amdgpu.lockup_timeout default 2000ms) and triggering a compute-ring reset / ErrorDeviceLost. This is the same root cause and fix as upstream ggml-org#21724 (lowering nodes_per_submit resolves the identical device-lost with no measurable regression). Default to frequent submits on uma/integrated devices and add a GGML_VK_NODES_PER_SUBMIT override for tuning. Discrete GPUs keep the existing batch size.
6cc8c50 to
c584da0
Compare
…rams Logs per-FLASH_ATTN_EXT dispatch: N, KV, gqa_ratio, K/V types, mask + mask_opt state, uma, split_k/split_kv, and workgroup grid. The last line before a device-lost identifies the exact crashing dispatch and its single-submit GPU shape (issue #185), to drive the KV-chunking thresholds. Env-gated, no effect unless GGML_VK_FA_LOG is set. Pair with GGML_VK_PERF_LOGGER for per-op GPU timing to confirm the FA submit approaching the amdgpu watchdog.
|
Added env-gated FA diagnostics on this branch ( |
|
gfx1151 / RADV validation — fix holds, no device-lost at 70k. Built the #186 branch tip ( Result: no device-lost. Prefilled to 70,020 prompt tokens and completed normally ( Deepest dispatches: Note |
Candidate fix for #185 (draft-mtp device-lost at deep context on gfx1151 / Vulkan)
Addresses the GPU device-lost in the MTP catch-up decode at ~47k tokens on Strix Halo (gfx1151) / RADV.
Root cause (hypothesis)
The dense MTP head on hybrid Qwen3.5/3.6 is the only path doing full attention over the whole context (the main model is SWA + DeltaNet, so it never does full-KV attention). At deep context that becomes a single flash-attention dispatch over ~47k KV. The existing
split_kheuristic inggml_vk_flash_attnkeepssplit_k == 1whenever the query/head workgroup count already fills the GPU (true for large prefill ubatches), so each workgroup iterates the entire KV serially. A multi-second FA dispatch is consistent with a RADV compute device-lost (the main decode survives the same depth because its attention is windowed).Fix
Force a split so the per-workgroup KV span is bounded (<= 16384 elements), even when the heuristic would otherwise keep
split_k == 1:split_kis the generic FA mechanism (fa_split_k_reduce) shared by the scalar / coopmat1 / coopmat2 paths, andsplit_k_sizeis still bounded by the existingmaxStorageBufferRangecheck. At KV=47k this makessplit_k= 3 (per-workgroup KV ~15.6k, comfortably below the ~43k depth where it currently starts failing).Validation
--spec-type draft-mtp, deep context past ~47k). If it still device-losts, also try loweringmax_kv_per_workgroupto 8192 (one-line change) and captureRADV_DEBUG=hangso we can localize the exact dispatch.Notes