
Metal fp_qmv_impl: out_vec_size < 8 branch uses raw scale byte instead of dequantize_scale → wrong mxfp4 matvec for output dim < 8 #3762

Description

@jax-0n-git

Summary

In the Metal fp_qmv_impl kernel, the full-block loop of the out_vec_size < 8 (small output-dim) branch loads the fp8 scale as a raw byte and passes it straight to qdot without decoding it through dequantize_scale. Every other fp scale-load site in the same file decodes it. As a result, fp-quantized (mxfp4 / fp8) mat-vec with output dimension N < 8 multiplies by the raw byte value (0–255) of the e8m0/e4m3 encoding instead of the actual fp8 scale, producing grossly wrong output.

mlx/backend/metal/kernels/fp_quantized.h (current main, blob f4bf438df2679161917414ff02c9870589c6eb26):

// fp_qmv_impl, out_vec_size < 8 branch, full-block loop
441    const device auto* sl = scales + row * in_vec_size_g;
442
443    uint8_t s = sl[0];                                                   // <-- raw scale byte
444    result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s);

Compare the sibling sites, including the remainder loop of this same branch at :464:

306:      U s = dequantize_scale<U, group_size>(sl[0]);
368:      U s = dequantize_scale<U, group_size>(sl[0]);
464:        U s = dequantize_scale<U, group_size>(sl[0]);   // remainder loop of the same <8 branch
495:        U s = dequantize_scale<U, group_size>(sl[0]);
514:        U s = dequantize_scale<U, group_size>(sl[0]);

qdot (fp_quantized.h:82) takes the scale as a parameter and applies it directly (scale * accum); it does not decode the scale internally, which the reproduction below confirms. dequantize_scale<U, group_size> reinterprets the byte as fp8_e4m3 when group_size == 16 and as fp8_e8m0 otherwise, then converts it to float. Skipping it means qdot multiplies by the literal 0–255 byte value instead of the decoded fp8 scale.
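To make the decode step concrete, here is a minimal pure-Python sketch of what dequantize_scale is described as doing. It assumes the standard OCP encodings (e8m0: value = 2^(byte - 127), 0xFF = NaN; e4m3fn: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, 0x7F/0xFF = NaN). The function names are hypothetical; this is not the Metal code.

```python
import math

def decode_e8m0(byte: int) -> float:
    """Unsigned power-of-two scale: 2**(byte - 127); 0xFF encodes NaN (OCP MX assumption)."""
    if byte == 0xFF:
        return math.nan
    return 2.0 ** (byte - 127)

def decode_e4m3(byte: int) -> float:
    """OCP fp8 e4m3fn: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan                          # e4m3fn has no infinities, only NaN
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6    # subnormal range
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

def dequantize_scale_sketch(byte: int, group_size: int) -> float:
    # Hypothetical analogue of the selection described above:
    # e4m3 for group_size == 16, e8m0 otherwise.
    return decode_e4m3(byte) if group_size == 16 else decode_e8m0(byte)

# Raw byte vs decoded scale for a typical mxfp4 (group_size = 32) scale byte.
raw = 126
print(raw, dequantize_scale_sketch(raw, 32))   # 126 0.5
```

The gap between the two printed values (126 vs 0.5) is exactly what the line-443 loop loses by skipping the decode.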

Reproduction (mlx 0.31.2, Apple M5 Max, macOS 26.5.1)

import mlx.core as mx

def check(mode, bits, gs, K, Ns, seed=0):
    print(f"# mode={mode} bits={bits} group_size={gs} K={K}")
    for N in Ns:
        x = mx.random.normal((1, K), key=mx.random.key(seed))
        w = mx.random.normal((N, K), key=mx.random.key(seed + 100 + N))
        q = mx.quantize(w, bits=bits, group_size=gs, mode=mode)   # (wq, scales[, biases])
        yq   = mx.quantized_matmul(x, *q, transpose=True, group_size=gs, bits=bits, mode=mode)
        wdeq = mx.dequantize(*q, group_size=gs, bits=bits, mode=mode)
        yref = x @ wdeq.T
        mx.eval(yq, yref)
        rel = (mx.max(mx.abs(yq - yref)) / (mx.max(mx.abs(yref)) + 1e-9)).item()
        branch = "fast(N%8==0)" if (N % 8 == 0 and K % 512 == 0) else ("non-fast <8" if N < 8 else "non-fast >=8")
        print(f"  N={N:>2} {branch:<14} rel_err={rel:.3e}  {'WRONG' if rel > 0.05 else 'ok'}")

check("mxfp4",  4, 32, 512, [5, 7, 12, 8, 16])
check("affine", 4, 32, 512, [5, 7, 12, 8, 16])   # integer path = control, clean everywhere

Output:

# mode=mxfp4 bits=4 group_size=32 K=512
  N= 5 non-fast <8    rel_err=2.180e+02  WRONG
  N= 7 non-fast <8    rel_err=1.833e+02  WRONG
  N=12 non-fast >=8   rel_err=7.835e-08  ok
  N= 8 fast(N%8==0)   rel_err=2.033e-07  ok
  N=16 fast(N%8==0)   rel_err=6.738e-08  ok

# mode=affine bits=4 group_size=32 K=512
  N= 5 non-fast <8    rel_err=2.518e-07  ok
  N= 7 non-fast <8    rel_err=3.587e-07  ok
  N=12 non-fast >=8   rel_err=2.422e-07  ok
  N= 8 fast(N%8==0)   rel_err=4.346e-07  ok
  N=16 fast(N%8==0)   rel_err=1.992e-07  ok

The error fires only for mxfp4 with N < 8 (relative errors of 218 and 183; max|yq| ≈ 6900 vs. ≈ 32 for the reference, a gross error consistent with multiplying by the raw exponent byte). The N=12 non-fast case (which takes the >=8 branch and decodes the scale correctly) is clean, the fast path is clean, and the integer affine path is clean at every N. This isolates the defect to the scale load in the full-block loop of the out_vec_size < 8 branch.
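A back-of-the-envelope check of "consistent with the raw exponent byte": if a group's e8m0 byte is b, the buggy loop multiplies by b instead of 2^(b - 127), so that group's contribution is inflated by b / 2^(b - 127). The sketch below tabulates that factor. The mx_scale_byte helper is hypothetical and assumes the OCP MX shared-exponent rule floor(log2(amax)) - 2 for fp4 e2m1 elements; MLX's exact rounding may differ.

```python
import math

def mx_scale_byte(amax: float) -> int:
    """Hypothetical e8m0 byte for a group with max |w| = amax (OCP MX rule, emax(e2m1) = 2)."""
    exp = math.floor(math.log2(amax)) - 2
    return max(0, min(254, exp + 127))   # clamp to the finite e8m0 range

def blowup(byte: int) -> float:
    """Factor by which using the raw byte overstates the decoded scale 2**(byte - 127)."""
    return byte / 2.0 ** (byte - 127)

# Plausible group maxima for N(0, 1) weights with group_size = 32.
for amax in (1.5, 2.5, 4.5):
    b = mx_scale_byte(amax)
    print(f"amax={amax}: byte={b}, blow-up={blowup(b):.0f}x")
```

Under these assumptions the per-group factors are 500x, 252x and 127x, the same order of magnitude as the observed relative errors of 183 and 218 (the exact figure depends on how groups with different bytes mix).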

Why this is a bug, not an intentional fp-vs-int difference

The integer quantized.h qmv_impl legitimately does U s = sl[0]; because there scales is const device T* (a real float array). In fp_quantized.h, fp_qmv_impl's scales is const device uint8_t* (packed e8m0/e4m3), so it must be decoded — which is why every other fp site here calls dequantize_scale. Line 443 looks like a copy-paste from the integer path.
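The type mismatch can be modeled in a few lines. This is a toy model, not the Metal kernel: qdot_model stands in for qdot (scale multiplied in as-is, no decode), int_style_load mimics the integer path's U s = sl[0], and fp_style_load mimics dequantize_scale for the e8m0 case. All names here are hypothetical.

```python
def decode_e8m0(byte: int) -> float:
    return 2.0 ** (byte - 127)

def qdot_model(w, x, scale):
    # Like qdot as described in this issue: the scale is applied directly.
    return scale * sum(wi * xi for wi, xi in zip(w, x))

def int_style_load(scales, g):
    # Integer path: scales is already a float array, so the value is used directly.
    return scales[g]

def fp_style_load(scales, g):
    # fp path: scales holds packed e8m0 bytes, so it must be decoded first.
    return decode_e8m0(scales[g])

def matvec_row(w_groups, x_groups, scales, load):
    # One output row: sum of per-group qdot results, each with its own scale.
    return sum(qdot_model(w, x, load(scales, g))
               for g, (w, x) in enumerate(zip(w_groups, x_groups)))

w_groups = [[1.0, -2.0], [0.5, 3.0]]   # fp4 (e2m1) element values, per group
x_groups = [[1.0, 1.0], [2.0, 1.0]]
byte_scales = [126, 127]               # e8m0 bytes -> scales 0.5 and 1.0

fixed = matvec_row(w_groups, x_groups, byte_scales, fp_style_load)
buggy = matvec_row(w_groups, x_groups, byte_scales, int_style_load)  # line-443 pattern
print(fixed, buggy)   # 3.5 382.0
```

Applying the integer-path load to a byte array silently "works" (it type-checks as a number) but scales every group by its raw byte, which is why the mistake is easy to carry over.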

Reachability

Host dispatch qmv() in mlx/backend/metal/quantized.cpp selects the non-fast kernel when N % 8 == 0 && K % 512 == 0 is false. The path is then N % 8 != 0 → fp_qmv → fp_qmv_impl; out_vec_size < 8 enters the affected branch, and in_vec_size > block_size exercises the full-block loop (line 443). So it affects fp-quantized mat-vecs (M=1) whose output dim is < 8.
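The dispatch conditions above can be summarized as a small decision function. This is a hypothetical model for M = 1 and the fp modes, mirroring the branch labels used in the reproduction script; it does not model the in_vec_size > block_size condition, since block_size is not stated here.

```python
def qmv_path(N: int, K: int) -> str:
    """Hypothetical model of the host dispatch plus the in-kernel branch choice."""
    if N % 8 == 0 and K % 512 == 0:
        return "fast"                # fast kernel (clean in the reproduction)
    if N < 8:
        return "fp_qmv_impl <8"      # affected branch (full-block loop, line 443)
    return "fp_qmv_impl >=8"         # decodes the scale (:495/:514)

for N in (5, 7, 8, 12, 16):
    print(N, qmv_path(N, 512))
```

Every N in 1..7 lands in the affected branch regardless of K, because N % 8 != 0 always holds there, so the fast path can never mask the bug for small output dims.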

Proposed fix (1 line)

// fp_quantized.h:443
-        uint8_t s = sl[0];
+        U s = dequantize_scale<U, group_size>(sl[0]);

This makes the full-block loop identical to the remainder loop of the same branch (:464) and to the >=8 branch loops (:495/:514). I'm happy to open a PR with a regression test (a quantized_matmul vs. dequantize+matmul check parametrized over N ∈ {1..7, 8, 12, 16} for the fp modes) if that's useful; just let me know if you'd prefer to take the kernel change directly.

Introduced by the original fp-qmv kernel (the integer path's U s = sl[0] pattern carried over). Verified present in current main (blob f4bf438) and reproduced on the released mlx==0.31.2.
