Metal: FP8-packed compressed-KV cache + long-context memory optimizations#416
Open
lixiangnlp wants to merge 1 commit into
Open
Metal: FP8-packed compressed-KV cache + long-context memory optimizations#416lixiangnlp wants to merge 1 commit into
lixiangnlp wants to merge 1 commit into
Conversation
Three zero-drift memory optimizations for the Metal compressed-KV path plus an
opt-in packed FP8 comp cache:
- packed FP8 comp cache (584 vs 1024 B/row, e4m3+ue8m0+f16 rot; LUT dequant),
opt-in via DS4_METAL_FP8_KV_STORE; read directly by indexed + flash kernels.
- comp_mask stored f16 (binary -inf/0 mask, exact in half).
- indexer_scores token-tiled (comp_cap*512 vs comp_cap*prefill_cap, 8x).
DS4_DISABLE_KV_OPTS=1 reverts to the prior layout for A/B.
ds4-bench (M5 Max, macOS 27 / Metal 4.1, DeepSeek-V4-Flash, 8k-64k sweep):
bit-identical logits; KV cache ~2.3x smaller; long-context decode +4.6%..+17.7%.
See PR_FP8_KV_MEMORY.md for the full design + results.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
aledesogusbusiness-hue
pushed a commit
to aledesogusbusiness-hue/ds4
that referenced
this pull request
Jun 16, 2026
…ions Three memory optimizations for the Metal backend's compressed-KV (MLA latent) path: 1. Packed FP8 comp cache (opt-in, DS4_METAL_FP8_KV_STORE=1): stores comp rows as e4m3 + ue8m0 scale + f16 rot = 584 B/row vs 1024 B/row f16. Dequant uses a 128-entry LUT (ds4_e4m3_lut) avoiding branch+exp2. Validated bit-identical. 2. comp_mask stored as f16 (always on): binary -inf/0 mask fits exactly in f16, halving the mask buffer size at all context lengths. 3. indexer_scores token-tiling (always on): DS4_INDEXER_SCORE_TILE=512 reduces the score working buffer from comp_cap*prefill_cap to comp_cap*512 (~8x). Together: ~2.3x KV cache reduction at long context, bit-identical output, speed-neutral decode (bandwidth saving offset by per-element dequant cost). Revert with DS4_DISABLE_KV_OPTS=1. Tested on M5 Max, 8k-96k context. Fixes: antirez#416 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
PR: FP8-packed compressed-KV cache + long-context memory optimizations (Metal)
Author: lixiang
<lixiang.ict@gmail.com>Assisted by: Claude Opus 4.8 (design, implementation, and validation done in collaboration with Claude Opus 4.8)
Summary
Three always-on, zero-drift memory optimizations for the Metal backend's compressed
(MLA latent) KV path, plus an opt-in packed FP8 comp cache. Together they cut
context-dependent KV memory by ~2.3x with bit-identical model output and
speed-neutral decode (the comp-read bandwidth saved is offset by the per-element
e4m3 dequant, so decode is break-even — this is a memory optimization, not a speedup).
All three are on by default;
DS4_DISABLE_KV_OPTS=1reverts to the previous layoutfor A/B comparison.
DS4_METAL_FP8_KV_STORE=1additionally enables the packed FP8comp cache (opt-in because it changes the comp cache to e4m3 precision — validated
output-equivalent, see below).
Motivation
At long context the compressed-KV cache and its indexer scratch dominate the
context-dependent memory and the per-step decode bandwidth. Profiling the
DeepSeek-V4-Flash model showed the comp cache (f16) plus two f32 scratch buffers
(
indexer_scores,comp_mask, eachcomp_cap × prefill_cap) are the largestcontext-scaling allocations, and the repeated comp-KV read is the decode bottleneck
at long context.
What changed
Packed FP8 compressed-KV cache (opt-in,
DS4_METAL_FP8_KV_STORE).The persistent comp cache is stored in a 3-plane row layout
(
struct ds4_fp8_kv_row: e4m3 nope bytes + ue8m0 scale + f16 rot) = 584 B/rowvs 1024 B/row f16. The compressor writes packed once; the indexed-attention
kernel and a region-aware FlashAttention variant read packed directly (dequant
to half is bit-identical to the prior
(half)value). e4m3 decode uses a128-entry constant LUT (
ds4_e4m3_lut) so the per-element dequant is a tableload, not a branch+exp2.
metal/flash_attn.metal:ds4_fp8_kv_row,ds4_e4m3_lut/ds4_e4m3_mag,dequantize_fp8_kv_t4/_4x4, MIX region-aware vec + non-vec kernels.metal/dsv4_kv.metal:kernel_dsv4_kv_pack_fp8_row_f32,kernel_dsv4_kv_unpack_fp8_row_f16.ds4.c:metal_graph_attn_comp_cache_format()(0=f32/1=f16/2=packed),packed alloc/store/row_bytes, session save/restore payload sizing.
ds4_metal.m: packed pipelines,ds4_gpu_dsv4_kv_pack_comp_rows, theformat-aware comp→f16 staging unpack, indexed/flash dispatch wiring.
comp_maskstored f16 (always on; zero-drift). The top-k mask is binary(
-inf/0), both exact in f16 → halves thecomp_cap × prefill_capbuffer.metal/dsv4_misc.metal:kernel_dsv4_topk_mask/_scatterwrite f16 or f32by an
args.mask_f16flag.metal/cpy.metal:kernel_cpy_f16_f16.ds4_metal.m:ds4_gpu_comp_mask_f16(), f16/f32-aware wrapper + readers.indexer_scorestoken-tiling (always on; zero-drift). The score matrix isprocessed in token tiles of 512 so the buffer is
comp_cap × 512instead ofcomp_cap × prefill_cap(4096) = 8x smaller. Top-k is per-token independent,so tiling is exact.
ds4_metal.m:ds4_gpu_indexer_prefill_score_topk_tiled+ds4_gpu_indexer_decode_batch_score_topk_tiled(token-offset views, pos0=t0).ds4.c:DS4_INDEXER_SCORE_TILE, tiled call sites, scratch sizing.Also fixed: the session/checkpoint payload-size accounting
(
session_payload_live_tensor_bytes,ds4_session_layer_payload_bytes) now usesthe packed row stride when packed (previously assumed f32 → "trailing payload
bytes" on packed save/restore).
Experiment environment
(
-std=metal4.1).ds4-bench, LongMemEval_s prompt, context sweep 8k/16k/32k/64k(+ a 96k point), greedy decode 64 tokens. A/B via the single binary:
DS4_DISABLE_KV_OPTS=1(original) vsDS4_METAL_FP8_KV_STORE=1(optimized).The two sweeps run back-to-back with a 300 s cooldown between them so both start
from a comparable thermal state; the speed comparison was additionally repeated
with the sweep order reversed to cancel the residual run-order thermal bias (see
Performance). Memory = process
phys_footprint(viafootprint -p PID), whichexcludes the file-backed model mmap and so captures exactly the KV + scratch +
activations.
Validation results
Quality — zero difference. Frontier next-token logits are bit-identical
(
max_abs_diff = 0.0, 0/129280 vocab logits differ; argmax token + logit identical)at every frontier (8k, 16k, 32k, 64k, and 96k). e4m3-precision comp was separately
confirmed output-equivalent on 3 LongMemEval cases.
Memory — KV cache ~2.3x smaller.
The
kvcache_bytesfigures above are deterministic (computed from the layout, notmeasured). Process
phys_footprint(which also includes activations + indexerscratch and is sampled at 5 s, so it varies slightly run to run) lands at peak
~6.2 GB → ~5.4 GB (≈12–14%) and avg ~11% saved at the 64k sweep. Projected ~7.8 GB
saved at 1M context.
Performance — prefill equal, decode speed-neutral (break-even).
Decode throughput on this hardware is dominated by run-to-run thermal variance that
is larger than any effect from the change. The longest context (64k) is the last and
hottest measurement in a sweep, so whichever variant runs second loses ~15% there
purely from heat. Running the A/B in both orders and averaging cancels that bias:
The 64k decode figure flips sign with run order (−16.5% ↔ +12.5%), which by itself
proves it is a thermal artifact rather than an algorithmic cost. After averaging,
every context is within ±2%: decode is break-even. The comp-KV read bandwidth
saved (1024 → 584 B/row) is offset by the per-element e4m3 dequant (kept cheap by the
128-entry LUT), and both scale with comp-row count, so the net is context-independent
and ~zero. Prefill is equal. This change buys memory, not speed.
Notes
opt-in because it pins comp to e4m3 precision (validated output-equivalent).
ps -o rssis misleading for this model (it excludes the shared Metal mmap andreports ~91 MB); use
footprint/phys_footprint.🤖 Generated with Claude Code