
vulkan: bound per-workgroup KV in flash attention (candidate fix for #185 device-lost) #186

Open
TheTom wants to merge 2 commits into feature/turboquant-kv-cache from investigate/mtp-devicelost

TheTom (Owner) commented Jun 19, 2026

Candidate fix for #185 (draft-mtp device-lost at deep context on gfx1151 / Vulkan)

Addresses the GPU device-lost in the MTP catch-up decode at ~47k tokens on Strix Halo (gfx1151) / RADV.

Root cause (hypothesis)

The dense MTP head on hybrid Qwen3.5/3.6 is the only path doing full attention over the whole context (the main model is SWA + DeltaNet, so it never does full-KV attention). At deep context that becomes a single flash-attention dispatch over ~47k KV. The existing split_k heuristic in ggml_vk_flash_attn keeps split_k == 1 whenever the query/head workgroup count already fills the GPU (true for large prefill ubatches), so each workgroup iterates the entire KV serially. A multi-second FA dispatch is consistent with a RADV compute device-lost (the main decode survives the same depth because its attention is windowed).

Fix

Force a split so the per-workgroup KV span is bounded (<= 16384 elements), even when the heuristic would otherwise keep split_k == 1:

const uint32_t max_kv_per_workgroup = 16384;
const uint32_t min_split_k = CEIL_DIV(KV, max_kv_per_workgroup);
if (min_split_k > split_k) split_k = min_split_k;

split_k is the generic FA mechanism (fa_split_k_reduce) shared by the scalar / coopmat1 / coopmat2 paths, and split_k_size is still bounded by the existing maxStorageBufferRange check. At KV=47k this gives split_k = 3 (per-workgroup KV ~15.7k, comfortably below the ~43k depth where failures currently start).
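
To make the arithmetic concrete, here is a minimal standalone sketch of the clamp. The helper names (ceil_div, bounded_split_k, kv_span_per_workgroup) are hypothetical and not part of ggml; the sketch also assumes the KV range is divided evenly across splits and ignores any alignment the real shader path may apply.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Round-up integer division, mirroring what the CEIL_DIV macro in the patch does.
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) {
    return (a + b - 1) / b;
}

// Hypothetical helper: raise split_k until no workgroup covers more than
// max_kv_per_workgroup KV elements. A larger heuristic value is kept as-is.
uint32_t bounded_split_k(uint32_t kv, uint32_t heuristic_split_k,
                         uint32_t max_kv_per_workgroup = 16384) {
    assert(max_kv_per_workgroup > 0);
    const uint32_t min_split_k = ceil_div(kv, max_kv_per_workgroup);
    return std::max(heuristic_split_k, min_split_k);
}

// Largest KV span a single workgroup iterates, assuming an even split
// (an assumption of this sketch, not a statement about the shader).
uint32_t kv_span_per_workgroup(uint32_t kv, uint32_t split_k) {
    assert(split_k > 0);
    return ceil_div(kv, split_k);
}
```

With KV = 47000 and a heuristic split_k of 1, bounded_split_k returns 3 and each workgroup covers at most 15667 elements, under the 16384 bound. Lowering the bound to 8192 yields a split of 6 for the same KV.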

Validation

  • Built and ran on RX 9070 XT (gfx1201, AMD proprietary driver): large-KV full-attention FA still produces correct results, no regression. This GPU/driver does NOT reproduce the device-lost, so the crash fix itself is unverified here.
  • Needs validation on the gfx1151 / RADV repro. @reporter: please build this branch and re-run the failing agentic loop (--spec-type draft-mtp, deep context past ~47k). If it still device-losts, also try lowering max_kv_per_workgroup to 8192 (one-line change) and capture RADV_DEBUG=hang so we can localize the exact dispatch.

Notes

  • Separate observation, not addressed here: large-KV FA shows borderline precision drift (ERR ~7e-4 vs the 5e-4 test threshold at 32k+ KV) on gfx1201 with f32 accumulation. Unrelated to the device-lost; tracking separately.

Deep-context flash attention device-losts on AMD APUs / RADV (Strix Halo
gfx1151, issue #185): the byte-based submit heuristic in graph_compute only
accounts for matmul work, so flash-attention-heavy graphs (the dense MTP head
doing full attention over deep KV) batch up to nodes_per_submit=100 nodes into
one vkQueueSubmit, exceeding the GPU job watchdog (amdgpu.lockup_timeout
default 2000ms) and triggering a compute-ring reset / ErrorDeviceLost.

This is the same root cause and fix as upstream ggml-org#21724
(lowering nodes_per_submit resolves the identical device-lost with no
measurable regression). Default to frequent submits on uma/integrated devices
and add a GGML_VK_NODES_PER_SUBMIT override for tuning. Discrete GPUs keep the
existing batch size.
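
A rough sketch of how such a default plus override could be selected is shown below. The function name pick_nodes_per_submit and the parsing rules (ignore empty, non-numeric, zero, or out-of-range values) are assumptions made for illustration; only the defaults (1 on uma, 100 otherwise) and the variable name GGML_VK_NODES_PER_SUBMIT come from this thread, and the branch may parse the variable differently.

```cpp
#include <cassert>
#include <cctype>
#include <cerrno>
#include <cstdint>
#include <cstdlib>

// Hypothetical helper: choose how many graph nodes are batched per queue submit.
// env_value is the raw GGML_VK_NODES_PER_SUBMIT string, or nullptr when unset.
uint32_t pick_nodes_per_submit(bool uma, const char * env_value) {
    // Defaults quoted in the thread: frequent submits on uma, 100 on discrete.
    const uint32_t fallback = uma ? 1u : 100u;

    // Unset, empty, or not starting with a digit: keep the default.
    if (env_value == nullptr ||
        !std::isdigit(static_cast<unsigned char>(env_value[0]))) {
        return fallback;
    }

    errno = 0;
    char * end = nullptr;
    const unsigned long parsed = std::strtoul(env_value, &end, 10);

    // Reject trailing junk, zero, and values that do not fit in uint32_t.
    if (*end != '\0' || errno == ERANGE || parsed == 0 || parsed > UINT32_MAX) {
        return fallback;
    }
    return static_cast<uint32_t>(parsed);
}
```

A caller would pass std::getenv("GGML_VK_NODES_PER_SUBMIT") as env_value. Falling back silently on bad input keeps a typo in the environment from changing submit behavior.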
TheTom force-pushed the investigate/mtp-devicelost branch from 6cc8c50 to c584da0 on June 20, 2026 01:37
…rams

Logs per-FLASH_ATTN_EXT dispatch: N, KV, gqa_ratio, K/V types, mask + mask_opt
state, uma, split_k/split_kv, and workgroup grid. The last line before a
device-lost identifies the exact crashing dispatch and its single-submit GPU
shape (issue #185), to drive the KV-chunking thresholds. Env-gated, no effect
unless GGML_VK_FA_LOG is set. Pair with GGML_VK_PERF_LOGGER for per-op GPU
timing to confirm the FA submit approaching the amdgpu watchdog.

TheTom (Owner, Author) commented Jun 21, 2026

Added env-gated FA diagnostics on this branch (5cc7f0cfc): GGML_VK_FA_LOG=1 logs per-FLASH_ATTN_EXT dispatch params (N, KV, gqa, K/V types, mask/mask_opt, uma, split_k/split_kv, wg). Zero-risk (env-gated, no behavior change). Run the deep-context (>54k) repro on this branch with GGML_VK_FA_LOG=1 GGML_VK_PERF_LOGGER=1 and capture the last [FA] line before the device-lost + the FA op timing. Full instructions in #185. That data sizes the KV-chunking patch (reuse the existing split_k partial + fa_split_k_reduce, driven across per-seq submits). This branch also carries the shipped v2 fix (nodes_per_submit = uma ? 1 : 100) which holds to ~50k.

Defilan commented Jun 23, 2026

gfx1151 / RADV validation — fix holds, no device-lost at 70k.

Built the #186 branch tip (5cc7f0c, system_fingerprint: b1-5cc7f0c) and ran the repro on the Strix Halo box (AMD Radeon 8060S, gfx1151, RADV, Vulkan): --flash-attn on --cache-type-k f16 --cache-type-v f16 --spec-type draft-mtp --ubatch-size 2048 --parallel 1, single deep-context request, GGML_VK_FA_LOG=1.

Result: no device-lost. Prefilled to 70,020 prompt tokens and completed normally (finish_reason: stop) — well past the ~43k–47k where it previously crashed every run. 665 FLASH_ATTN_EXT dispatches captured, KV 256 → 70144, 0 crashes.

Deepest dispatches:

[FA] N=2048 KV=65536 gqa=1 K=f16 V=f16 mask=1 mask_opt=1 uma=1 split_k=1 split_kv=65536 wg=[2048,24,1]
[FA] N=2048 KV=67584 gqa=1 K=f16 V=f16 mask=1 mask_opt=1 uma=1 split_k=1 split_kv=67584 wg=[2048,24,1]
[FA] N=2044 KV=70144 gqa=1 K=f16 V=f16 mask=1 mask_opt=1 uma=1 split_k=1 split_kv=70144 wg=[2044,24,1]

Note split_k stayed 1 (a single full-KV dispatch) all the way to 70k — so on this box the device-lost is resolved by the c584da0 submit-frequency change, not by FA chunking. Happy to also build with max_kv_per_workgroup forcing a split if you want to compare the two approaches at depth.
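
For anyone post-processing a capture, the [FA] lines quoted in this comment can be parsed with a small helper like the sketch below. It assumes the exact field order and spelling shown in those three lines; the struct FaDispatch and the function parse_fa_line are hypothetical names, not part of the branch.

```cpp
#include <cassert>
#include <cstdio>
#include <cstring>

// Fields of one logged FLASH_ATTN_EXT dispatch, in the order they appear in the log.
struct FaDispatch {
    unsigned n = 0, kv = 0, gqa = 0;
    char k_type[16] = {0};
    char v_type[16] = {0};
    int mask = 0, mask_opt = 0, uma = 0;
    unsigned split_k = 0, split_kv = 0;
    unsigned wg[3] = {0, 0, 0};
};

// Returns true only if the line starts with "[FA] " and all 13 fields parse.
bool parse_fa_line(const char * line, FaDispatch & out) {
    assert(line != nullptr);
    if (std::strncmp(line, "[FA] ", 5) != 0) {
        return false;
    }
    const int matched = std::sscanf(
        line,
        "[FA] N=%u KV=%u gqa=%u K=%15s V=%15s mask=%d mask_opt=%d uma=%d "
        "split_k=%u split_kv=%u wg=[%u,%u,%u]",
        &out.n, &out.kv, &out.gqa, out.k_type, out.v_type,
        &out.mask, &out.mask_opt, &out.uma,
        &out.split_k, &out.split_kv,
        &out.wg[0], &out.wg[1], &out.wg[2]);
    return matched == 13;
}
```

Feeding each log line through parse_fa_line and keeping the last successful result gives the dispatch shape that immediately preceded a crash. Checking split_k == 1 on every parsed entry reproduces the observation that no split was applied up to 70k.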
