Summary
In the Metal fp_qmv_impl kernel, the full-block loop of the out_vec_size < 8 (small output dimension) branch loads the fp8 scale as a raw byte and passes it straight to qdot without decoding it through dequantize_scale. Every other fp scale-load site in the same file decodes it. As a result, fp-quantized (mxfp4 / fp8) mat-vec with output dimension N < 8 multiplies by the raw integer value (0–255) of the e8m0/e4m3 byte instead of the decoded fp8 scale, producing grossly wrong output.
mlx/backend/metal/kernels/fp_quantized.h (current main, blob f4bf438df2679161917414ff02c9870589c6eb26):
```cpp
// fp_qmv_impl, out_vec_size < 8 branch, full-block loop
441  const device auto* sl = scales + row * in_vec_size_g;
442
443  uint8_t s = sl[0];  // <-- raw scale byte
444  result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s);
```
Compare the sibling sites, including the remainder loop of this same branch at :464:
```cpp
306:  U s = dequantize_scale<U, group_size>(sl[0]);
368:  U s = dequantize_scale<U, group_size>(sl[0]);
464:  U s = dequantize_scale<U, group_size>(sl[0]);  // remainder loop of the same <8 branch
495:  U s = dequantize_scale<U, group_size>(sl[0]);
514:  U s = dequantize_scale<U, group_size>(sl[0]);
```
qdot (fp_quantized.h:82) takes the scale as a parameter and applies it directly (scale * accum) — it does not decode the scale internally (the reproduction below confirms this). dequantize_scale<U, group_size> reinterprets the byte as fp8_e4m3 (group_size == 16) or fp8_e8m0 and converts to float. Skipping it means qdot multiplies by the literal 0–255 byte value instead of the decoded fp8 scale.
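To make the difference concrete, here is a minimal sketch of the two decodings dequantize_scale is described as selecting between (E4M3 when group_size == 16, E8M0 otherwise). The decoders are written independently from the OCP FP8 and Microscaling (MX) format definitions; they are not MLX's fp8_e4m3 / fp8_e8m0 types, and MLX's handling of special values (NaN, subnormals) may differ.

```python
import math


def decode_e8m0(byte: int) -> float:
    """Decode an unsigned E8M0 scale: a bare power-of-two exponent, bias 127, 0xFF = NaN."""
    if byte == 0xFF:
        return math.nan
    return 2.0 ** (byte - 127)


def decode_e4m3(byte: int) -> float:
    """Decode an FP8 E4M3 value (OCP 'FN' variant): 1 sign bit, 4 exponent bits (bias 7),
    3 mantissa bits, no infinities; exponent and mantissa all ones encodes NaN."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


# What the buggy load multiplies by (the byte itself) vs. the scale the byte encodes.
for b in (125, 126, 127):
    print(f"byte={b}: raw multiplier={b}, decoded e8m0 scale={decode_e8m0(b)}")
```

For example, an e8m0 byte of 126 encodes a scale of 0.5, but line 443 hands qdot the value 126.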
Reproduction (mlx 0.31.2, Apple M5 Max, macOS 26.5.1)
```python
import mlx.core as mx

def check(mode, bits, gs, K, Ns, seed=0):
    print(f"# mode={mode} bits={bits} group_size={gs} K={K}")
    for N in Ns:
        x = mx.random.normal((1, K), key=mx.random.key(seed))
        w = mx.random.normal((N, K), key=mx.random.key(seed + 100 + N))
        q = mx.quantize(w, bits=bits, group_size=gs, mode=mode)  # (wq, scales[, biases])
        yq = mx.quantized_matmul(x, *q, transpose=True, group_size=gs, bits=bits, mode=mode)
        wdeq = mx.dequantize(*q, group_size=gs, bits=bits, mode=mode)
        yref = x @ wdeq.T
        mx.eval(yq, yref)
        rel = (mx.max(mx.abs(yq - yref)) / (mx.max(mx.abs(yref)) + 1e-9)).item()
        branch = "fast(N%8==0)" if (N % 8 == 0 and K % 512 == 0) else ("non-fast <8" if N < 8 else "non-fast >=8")
        print(f" N={N:>2} {branch:<14} rel_err={rel:.3e} {'WRONG' if rel > 0.05 else 'ok'}")

check("mxfp4", 4, 32, 512, [5, 7, 12, 8, 16])
check("affine", 4, 32, 512, [5, 7, 12, 8, 16])  # integer path = control, clean everywhere
```
Output:
```text
# mode=mxfp4 bits=4 group_size=32 K=512
 N= 5 non-fast <8 rel_err=2.180e+02 WRONG
 N= 7 non-fast <8 rel_err=1.833e+02 WRONG
 N=12 non-fast >=8 rel_err=7.835e-08 ok
 N= 8 fast(N%8==0) rel_err=2.033e-07 ok
 N=16 fast(N%8==0) rel_err=6.738e-08 ok
# mode=affine bits=4 group_size=32 K=512
 N= 5 non-fast <8 rel_err=2.518e-07 ok
 N= 7 non-fast <8 rel_err=3.587e-07 ok
 N=12 non-fast >=8 rel_err=2.422e-07 ok
 N= 8 fast(N%8==0) rel_err=4.346e-07 ok
 N=16 fast(N%8==0) rel_err=1.992e-07 ok
```
The error appears only for mxfp4 with N < 8 (relative error 218× at N=5 and 183× at N=7; max|yq| ≈ 6900 vs. a reference of ≈ 32), which is gross and consistent with multiplying by the raw exponent byte. The N=12 non-fast case (which takes the >=8 branch and decodes the scale correctly) is clean, the fast path is clean, and the integer affine path is clean at every N. This isolates the defect to the scale load in the full-block loop of the out_vec_size < 8 branch.
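As a rough plausibility check on the magnitude: multiplying by the raw byte b instead of the decoded e8m0 scale 2^(b − 127) overstates that group's contribution by b / 2^(b − 127). The sketch below tabulates that factor for a few bytes near 127; it assumes the mxfp4 scale bytes for these N(0,1) weights sit in that neighborhood, which the report does not measure.

```python
def raw_over_decoded(byte: int) -> float:
    """Factor by which multiplying by the raw e8m0 byte overstates the scale 2**(byte - 127)."""
    return byte / 2.0 ** (byte - 127)


for b in (125, 126, 127, 128):
    print(b, raw_over_decoded(b))
```

Factors of 127 to 252 for bytes 126 to 127 bracket the observed 183× to 218× errors. The exact figure also depends on the mix of per-group scales and on how much of K lands in the remainder loop, which decodes correctly.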
Why this is a bug, not an intentional fp-vs-int difference
The integer quantized.h qmv_impl legitimately does U s = sl[0]; because there scales is const device T* (a real float array). In fp_quantized.h, fp_qmv_impl's scales is const device uint8_t* (packed e8m0/e4m3), so it must be decoded — which is why every other fp site here calls dequantize_scale. Line 443 looks like a copy-paste from the integer path.
Reachability
Host dispatch qmv() in mlx/backend/metal/quantized.cpp selects the non-fast kernel when N % 8 == 0 && K % 512 == 0 is false. N % 8 != 0 → fp_qmv → fp_qmv_impl; out_vec_size < 8 enters the affected branch; in_vec_size > block_size exercises the full-block loop (line 443). So it bites fp-quantized mat-vecs (M=1) whose output dim is < 8.
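The dispatch conditions above can be restated as a small predicate for fp modes. This is a sketch, not MLX code: the function name is hypothetical, and block_size is left as a parameter because its value in fp_qmv_impl is not given here (128 in the usage line is an arbitrary placeholder).

```python
def hits_undecoded_scale(N: int, K: int, block_size: int) -> bool:
    """Hypothetical helper: does an fp-quantized mat-vec (M=1) with output dim N and
    input dim K reach the undecoded scale load at fp_quantized.h:443?"""
    # Host qmv() picks the fast kernel only when both conditions hold.
    if N % 8 == 0 and K % 512 == 0:
        return False
    # Non-fast path: fp_qmv -> fp_qmv_impl. out_vec_size < 8 selects the affected branch,
    # and its full-block loop only runs when in_vec_size exceeds one block.
    return N < 8 and K > block_size


print([n for n in range(1, 17) if hits_undecoded_scale(n, 512, 128)])
```

For the K=512 reproduction this singles out N = 1 through 7, matching the N=5 and N=7 failures and the clean N=8, 12, 16 rows.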
Proposed fix (1 line)
```diff
// fp_quantized.h:443
- uint8_t s = sl[0];
+ U s = dequantize_scale<U, group_size>(sl[0]);
```
This makes the full-block loop identical to the remainder loop of the same branch (:464) and the >=8 branch loops (:495/:514). Happy to open a PR with a regression test (a quantized_matmul vs dequantize+matmul check parametrized over N ∈ {1..7, 8, 12, 16} for the fp modes) if that's useful — just let me know if you'd prefer to take the kernel change directly.
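To see why the one-line change is sufficient, here is a deliberately simplified pure-Python model of one output row in this branch: a full-block loop followed by a remainder loop, with a flag choosing whether the full-block loop decodes the scale. It is an illustration under stated assumptions (work split per group rather than per thread, e8m0 scales only, fp4 element values already decoded, hypothetical names row_dot and reference), not a transcription of fp_qmv_impl.

```python
def e8m0(byte: int) -> float:
    return 2.0 ** (byte - 127)


def row_dot(w, scale_bytes, x, group_size, block_size, decode_full_blocks):
    """Simplified model of one output row: full-block loop, then remainder loop."""
    K = len(x)
    acc = 0.0
    k = 0
    # Full-block loop (where line 443 lives): runs while a whole block precedes the tail.
    while k < K - block_size:
        for g in range(k, k + block_size, group_size):
            b = scale_bytes[g // group_size]
            s = e8m0(b) if decode_full_blocks else float(b)  # False reproduces the bug
            acc += s * sum(w[i] * x[i] for i in range(g, g + group_size))
        k += block_size
    # Remainder loop (as at :464): already decodes the scale.
    for g in range(k, K, group_size):
        s = e8m0(scale_bytes[g // group_size])
        acc += s * sum(w[i] * x[i] for i in range(g, min(g + group_size, K)))
    return acc


def reference(w, scale_bytes, x, group_size):
    """Dequantize-then-dot reference."""
    return sum(e8m0(scale_bytes[i // group_size]) * w[i] * x[i] for i in range(len(x)))


K, gs, bs = 8, 2, 4
w, x, sb = [1.0] * K, [1.0] * K, [126] * (K // gs)
print(reference(w, sb, x, gs), row_dot(w, sb, x, gs, bs, True), row_dot(w, sb, x, gs, bs, False))
```

With decoding enabled the model matches the reference; with it disabled only the blocks handled by the full-block loop are inflated, and when K equals block_size that loop never runs, consistent with the in_vec_size > block_size condition for reaching line 443.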
Introduced by the original fp-qmv kernel (the integer path's U s = sl[0] pattern carried over). Verified present in current main (blob f4bf438) and reproduced on the released mlx==0.31.2.