diff --git a/mlx/backend/metal/kernels/fp_quantized.h b/mlx/backend/metal/kernels/fp_quantized.h
index 8d6740db5b..5dd81b4f1e 100644
--- a/mlx/backend/metal/kernels/fp_quantized.h
+++ b/mlx/backend/metal/kernels/fp_quantized.h
@@ -440,7 +440,7 @@ METAL_FUNC void fp_qmv_impl(
         auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
         const device auto* sl = scales + row * in_vec_size_g;
 
-        uint8_t s = sl[0];
+        U s = dequantize_scale(sl[0]);
         result[row] += qdot(wl, x_thread, s);
       }
 
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 2850c7c357..bd76df42f1 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -482,6 +482,47 @@ def test_fp_qmv(self):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
+    def test_fp_qmv_small_non_multiples(self):
+        # Regression test for the out_vec_size < 8 branch of fp_qmv_impl (see
+        # fp_quantized.h): its full-block loop read the scale as a raw uint8_t
+        # (`uint8_t s = sl[0]`) instead of decoding it with dequantize_scale.
+        # Only fp-quantized matvec with output dim < 8 AND K > block_size (at
+        # least one full block) is affected, so it slipped past test_fp_qmv
+        # (output dim >= 8) and test_qmv_small_non_multiples (K = 32, below
+        # block_size, which runs only the correct remainder loop).
+        # K = 512 forces full blocks; N in {1, 2, 3, 5, 7} hits the < 8 branch.
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+        K = 512
+        for mode, group_size, bits in [
+            ("mxfp4", 32, 4),
+            ("mxfp8", 32, 8),
+            ("nvfp4", 16, 4),
+        ]:
+            for M in [1, 2]:
+                for N in [1, 2, 3, 5, 7]:
+                    with self.subTest(M=M, K=K, N=N, mode=mode):
+                        x = mx.random.normal(shape=(M, K), key=k1) / K**0.5
+                        w = mx.random.normal(shape=(N, K), key=k2) / K**0.5
+                        w_q, scales = mx.quantize(
+                            w, group_size=group_size, bits=bits, mode=mode
+                        )
+                        w_hat = mx.dequantize(
+                            w_q, scales, group_size=group_size, bits=bits, mode=mode
+                        )
+                        y_q = mx.quantized_matmul(
+                            x,
+                            w_q,
+                            scales,
+                            transpose=True,
+                            group_size=group_size,
+                            bits=bits,
+                            mode=mode,
+                        )
+                        y_hat = x @ mx.swapaxes(w_hat, -1, -2)
+                        self.assertEqual(y_q.shape, y_hat.shape)
+                        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
     def test_qmv_wide(self):
         # M in [2, vector_limit) routes to qmv_wide -- except K in {64, 128}
         # with power-of-2 bits, which stays on qmv_quad. Check both paths