diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index c0f1a3c315..3469d99788 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -1359,4 +1359,8 @@ void QQMatmul::eval_cpu(const std::vector& inputs, array& out) { } } +void GatherQQMM::eval_cpu(const std::vector& inputs, array& out) { + throw std::runtime_error("[GatherQQMM] NYI"); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/device/qmm_naive.cuh b/mlx/backend/cuda/device/qmm_naive.cuh index 01e5f444d5..f5c73c260b 100644 --- a/mlx/backend/cuda/device/qmm_naive.cuh +++ b/mlx/backend/cuda/device/qmm_naive.cuh @@ -2,6 +2,7 @@ #include "mlx/backend/cuda/device/cute_dequant.cuh" #include "mlx/backend/cuda/device/gemm_sm70.cuh" +#include "mlx/backend/cuda/device/utils.cuh" #include @@ -25,6 +26,7 @@ CUTE_DEVICE void qmm_naive_mainloop( TensorS gS, TensorZ gZ, TensorC gC, + const float* global_scale, int m_max_coord, int n_max_coord, int k_residue, @@ -32,6 +34,7 @@ CUTE_DEVICE void qmm_naive_mainloop( // Get the types of operands. using Element = typename decltype(gA)::value_type; using Quant = typename decltype(gB)::value_type; + using Scale = typename decltype(gS)::value_type; // Shift tensor so we handle residue of K in the 0th tile. 
gA = domain_offset(make_coord(0, k_residue, 0), gA); @@ -196,7 +199,17 @@ CUTE_DEVICE void qmm_naive_mainloop( CUTE_UNROLL for (int i = 0; i < size(tCrC); ++i) { if ((get<0>(tCcC(i)) < m_max_coord) && (get<1>(tCcC(i)) < n_max_coord)) { - tCgC(i) = Element(tCrC(i)); + if constexpr ( + cuda::std::is_same_v && + cuda::std::is_same_v) { + if (global_scale) { + tCgC(i) = Element(tCrC(i) * (*global_scale / (F8E4M3_MAX * F4E2M1_MAX))); + } else { + tCgC(i) = Element(tCrC(i)); + } + } else { + tCgC(i) = Element(tCrC(i)); + } } } } @@ -224,6 +237,7 @@ void qmm_naive_kernel( const Quant* B, const Scale* S, const Element* Z, + const float* global_scale, const uint32_t* lhs_indices, const uint32_t* rhs_indices, Element* C, @@ -295,6 +309,7 @@ void qmm_naive_kernel( gS, gZ, gC, + global_scale, m_max_coord, n_max_coord, k_residue, thread_idx); } diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index 8d998cda40..64cfecfd5a 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -74,6 +74,7 @@ void qmm_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, const std::optional& lhs_indices, const std::optional& rhs_indices, array& out, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_naive.cu b/mlx/backend/cuda/quantized/qmm/qmm_naive.cu index cb47d7f1aa..c7c5c6049a 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_naive.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm_naive.cu @@ -29,6 +29,7 @@ void qmm_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, const std::optional& lhs_indices, const std::optional& rhs_indices, array& out, @@ -75,6 +76,9 @@ void qmm_naive( if (biases) { encoder.set_input_array(*biases); } + if (global_scale) { + encoder.set_input_array(*global_scale); + } if (lhs_indices) { encoder.set_input_array(*lhs_indices); } @@ -103,6 +107,7 @@ void qmm_naive( gpu_ptr(w), 
gpu_ptr(scales), biases ? gpu_ptr(*biases) : nullptr, + global_scale ? gpu_ptr(*global_scale) : nullptr, lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, rhs_indices ? gpu_ptr(*rhs_indices) : nullptr, gpu_ptr(out), diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index eaec2ac8f4..06cb9f3452 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -21,7 +21,7 @@ std::tuple quantize_input( QuantizationMode mode, int bits, int group_size, - std::optional global_scale = std::nullopt) { + std::optional global_scale) { const array x = ensure_contiguous(input, encoder, s); // Compute output shapes @@ -52,6 +52,27 @@ std::tuple quantize_input( return {std::move(x_q), std::move(scales_x)}; } +array quantize_dequantize_input( + const array& x_pre, + const std::optional& global_scale, + int bits, + int group_size, + cu::CommandEncoder& encoder, + Stream s) { + bool donate_x = x_pre.is_donatable(); + array x = ensure_row_contiguous(x_pre, encoder, s); + // If x is a copy it should be donatable + donate_x |= x.is_donatable(); + auto xhat = donate_x + ? x + : array(cu::malloc_async(x.nbytes(), encoder), x.shape(), x.dtype()); + if (!donate_x) { + encoder.add_temporary(xhat); + } + fp_quantize_dequantize(x, xhat, group_size, bits, global_scale, encoder, s); + return xhat; +} + GemmScalars create_nvfp4_scalars( const array& global_scale_x, const array& global_scale_w, @@ -75,77 +96,81 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& encoder = cu::get_command_encoder(s); auto& device = encoder.device(); - bool w_quantized = (inputs[1].dtype() == uint32); + + const array& x_pre = inputs[0]; + const array& w_pre = inputs[1]; + + out.set_data(cu::malloc_async(out.nbytes(), encoder)); // - 2 inputs: x, w (non-quantized w) // - 3 inputs: x, w, scales_w (quantized w) + bool w_quantized = (w_pre.dtype() == uint32); int base_size = w_quantized ? 
3 : 2; - assert( - inputs.size() == base_size || - (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2)); - // For nvfp4, global scales are optional but must be both present or both // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) bool has_global_scales = - mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; - std::optional global_scale_x = std::nullopt; - std::optional global_scale_w = std::nullopt; + mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2; + assert(inputs.size() == base_size || has_global_scales); + + std::optional global_scale_x; + std::optional global_scale_w; if (has_global_scales) { global_scale_x = inputs[inputs.size() - 2]; global_scale_w = inputs[inputs.size() - 1]; } - if (w_quantized && inputs[0].shape(-2) == 1) { - out.set_data(cu::malloc_async(out.nbytes(), encoder)); - - bool donate_x = inputs[0].is_donatable(); - array x = ensure_row_contiguous(inputs[0], encoder, s); - // If x is a copy it should be donatable - donate_x |= x.is_donatable(); - auto xhat = donate_x - ? x - : array(cu::malloc_async(x.nbytes(), encoder), x.shape(), x.dtype()); - if (!donate_x) { - encoder.add_temporary(xhat); - } - fp_quantize_dequantize( - x, xhat, group_size_, bits_, global_scale_x, encoder, s); - - const array& w = inputs[1]; - const array& scales = inputs[2]; - qmv(xhat, - w, - scales, - std::nullopt, - global_scale_w, - out, - bits_, - group_size_, - mode_, - encoder); - return; - } - - auto cc = device.compute_capability_major() * 100 + - device.compute_capability_minor() * 10; - if (cc < 1000) { - throw std::runtime_error( - "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher."); - } - - // Quantize inputs (or use pre-quantized) - auto [x_q, scale_x_pre] = quantize_input( - inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x); - auto [w_q, scale_w_pre] = !w_quantized + // Quantize weights. + auto [w_q, scales_w] = !w_quantized ? 
quantize_input( - inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w) + w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w) : std::make_tuple( - ensure_contiguous(inputs[1], encoder, s), + ensure_contiguous(w_pre, encoder, s), ensure_contiguous(inputs[2], encoder, s)); - out.set_data(cu::malloc_async(out.nbytes(), encoder)); + // Reroute to qmm when: no support in cuBLAS, or doing GEMV. + bool can_use_cublas = + (mode_ == QuantizationMode::Nvfp4 || mode_ == QuantizationMode::Mxfp8) && + (device.compute_capability_major() >= 10); + int M = x_pre.shape(-2); + bool use_qmm = (!can_use_cublas) || (M == 1); + + if (use_qmm) { + array x = quantize_dequantize_input( + x_pre, global_scale_x, bits_, group_size_, encoder, s); + if (M < 8) { + qmv(x, + w_q, + scales_w, + std::nullopt, + global_scale_w, + out, + bits_, + group_size_, + mode_, + encoder); + } else { + qmm_naive( + x, + w_q, + scales_w, + std::nullopt, + global_scale_w, + std::nullopt, + std::nullopt, + out, + true, // transpose + bits_, + group_size_, + mode_, + encoder); + } + return; + } + + // Quantize activation. 
+ auto [x_q, scales_x] = quantize_input( + x_pre, encoder, s, mode_, bits_, group_size_, global_scale_x); - int M = x_q.shape(-2); int N = w_q.shape(-2); // transposed int K = x_q.shape(-1) * (32 / bits_); @@ -155,8 +180,8 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { int64_t ldb = K; // Repack scales to tiled layout for tensor cores - array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s); - array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s); + scales_x = pad_and_swizzle_scales(scales_x, encoder, s); + scales_w = pad_and_swizzle_scales(scales_w, encoder, s); GemmScalars scalars; if (has_global_scales) { @@ -175,10 +200,69 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { out, x_q, w_q, - scale_x, - scale_w, + scales_x, + scales_w, mode_, scalars); } +void GatherQQMM::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("GatherQQMM::eval_gpu"); + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + const array& x_pre = inputs[0]; + const array& w_pre = inputs[1]; + const array& lhs_indices = ensure_row_contiguous(inputs[2], encoder, s); + const array& rhs_indices = ensure_row_contiguous(inputs[3], encoder, s); + + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + + // - 4 inputs: x, w, lhs_indices, rhs_indices (non-quantized w) + // - 5 inputs: x, w, lhs_indices, rhs_indices, scales_w (quantized w) + bool w_quantized = (w_pre.dtype() == uint32); + int base_size = w_quantized ? 
5 : 4; + // For nvfp4, global scales are optional but must be both present or both + // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) + bool has_global_scales = + mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2; + assert(inputs.size() == base_size || has_global_scales); + + std::optional global_scale_x; + std::optional global_scale_w; + if (has_global_scales) { + global_scale_x = inputs[inputs.size() - 2]; + global_scale_w = inputs[inputs.size() - 1]; + } + + // Quantize weights. + auto [w_q, scales_w] = !w_quantized + ? quantize_input( + w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w) + : std::make_tuple( + ensure_contiguous(w_pre, encoder, s), + ensure_contiguous(inputs[4], encoder, s)); + + // Quantize activation. + array x = quantize_dequantize_input( + x_pre, global_scale_x, bits_, group_size_, encoder, s); + + // Reroute to qmm. + qmm_naive( + x, + w_q, + scales_w, + std::nullopt, + global_scale_w, + lhs_indices, + rhs_indices, + out, + true, // transpose + bits_, + group_size_, + mode_, + encoder); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 2a1a268c91..4d25f3c3e0 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -72,6 +72,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { biases, std::nullopt, std::nullopt, + std::nullopt, out, transpose_, bits_, @@ -211,6 +212,7 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { w, scales, biases, + std::nullopt, lhs_indices, rhs_indices, out, diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index cfba4c0f8d..0e9a1c694c 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -1667,6 +1667,10 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { } } +void GatherQQMM::eval_gpu(const std::vector& inputs, 
array& out) { + throw std::runtime_error("[GatherQQMM] NYI"); +} + void fast::Quantize::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index ae51dd9b2f..faaeb0c7c4 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -71,6 +71,7 @@ NO_CPU(Gather) NO_CPU(GatherAxis) NO_CPU(GatherMM) NO_CPU(GatherQMM) +NO_CPU(GatherQQMM) NO_CPU(Greater) NO_CPU(GreaterEqual) NO_CPU(Hadamard) diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 4819ed2724..0e05e9d19f 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -98,6 +98,7 @@ NO_GPU(Gather) NO_GPU(GatherAxis) NO_GPU(GatherMM) NO_GPU(GatherQMM) +NO_GPU(GatherQQMM) NO_GPU(Greater) NO_GPU(GreaterEqual) NO_GPU(Hadamard) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 143c394ece..34b4f24144 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -65,11 +65,26 @@ Dtype at_least_float(const Dtype& d) { } array indices_or_default( - std::optional indices, + std::string_view tag, + const std::optional& indices, const array& x, StreamOrDevice s) { + if (x.ndim() < 2) { + std::ostringstream msg; + msg << tag + << " Input must have at least two dimensions but got input with shape " + << x.shape() << "."; + throw std::invalid_argument(msg.str()); + } + if (indices.has_value()) { - return indices.value(); + if (!issubdtype(indices->dtype(), integer)) { + std::ostringstream msg; + msg << tag + << " Got indices with invalid dtype. 
Indices must be integral."; + throw std::invalid_argument(msg.str()); + } + return astype(indices.value(), uint32, s); } Shape shape(x.shape().begin(), x.shape().end() - 2); @@ -4547,10 +4562,10 @@ void validate_global_scale( } array quantized_matmul( - array x, - array w, - array scales, - std::optional biases /* = std::nullopt */, + const array& x, + const array& w, + const array& scales, + const std::optional& biases /* = std::nullopt */, bool transpose /* = true */, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, @@ -4599,13 +4614,13 @@ array quantized_matmul( } void validate_qqmm_inputs( - array x, - array w, - std::optional scales_w, + const array& x, + const array& w, + const std::optional& scales_w, int group_size, int bits, - std::optional global_scale_x, - std::optional global_scale_w, + const std::optional& global_scale_x, + const std::optional& global_scale_w, QuantizationMode qmode) { // check 2D (for now) if (x.ndim() > 2 || w.ndim() > 2) { @@ -4659,9 +4674,9 @@ void validate_qqmm_inputs( } std::pair extract_qqmm_dims( - array x, - array w, - std::optional scales_w, + const array& x, + const array& w, + const std::optional& scales_w, int group_size, int bits) { if (w.dtype() != uint32) { @@ -4689,24 +4704,17 @@ std::pair extract_qqmm_dims( } array qqmm( - array in_x, - array w, - std::optional scales_w, + const array& in_x, + const array& w, + const std::optional& scales_w, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "nvfp4" */, - const std::optional global_scale_x /* = std::nullopt */, - const std::optional global_scale_w /* = std::nullopt */, + const std::optional& global_scale_x /* = std::nullopt */, + const std::optional& global_scale_w /* = std::nullopt */, StreamOrDevice s /* = {} */) { auto stream = to_stream(s); auto qmode = string_to_quantization_mode(mode, "qqmm"); - // cuBLAS block scaled matmul only supports nvfp4 and mxfp8 - 
if (qmode != QuantizationMode::Nvfp4 && qmode != QuantizationMode::Mxfp8) { - std::ostringstream msg; - msg << "[qqmm] Only 'nvfp4' and 'mxfp8' quantization modes are supported but '" - << mode << "' was provided."; - throw std::invalid_argument(msg.str()); - } // we need to check 2 cases: // 1. w is quantized, scales is provided // 2. w is not quantized, scales is not provided @@ -5373,30 +5381,11 @@ array gather_qmm( } // Extract indices and broadcast them - array lhs_indices = indices_or_default(lhs_indices_, x, s); - array rhs_indices = indices_or_default(rhs_indices_, w, s); + array lhs_indices = indices_or_default("[gather_qmm]", lhs_indices_, x, s); + array rhs_indices = indices_or_default("[gather_qmm]", rhs_indices_, w, s); std::tie(lhs_indices, rhs_indices) = broadcast_arrays(lhs_indices, rhs_indices, s); - if (!issubdtype(lhs_indices.dtype(), integer)) { - throw std::invalid_argument( - "[gather_qmm] Got lhs_indices with invalid dtype. Indices must be integral."); - } - - if (!issubdtype(rhs_indices.dtype(), integer)) { - throw std::invalid_argument( - "[gather_qmm] Got rhs_indices with invalid dtype. 
Indices must be integral."); - } - if (x.ndim() < 2) { - std::ostringstream msg; - msg << "[gather_qmm] Non-quantized input must have at least two" - << " dimensions but got input with shape " << x.shape() << "."; - throw std::invalid_argument(msg.str()); - } - - lhs_indices = astype(lhs_indices, uint32, s); - rhs_indices = astype(rhs_indices, uint32, s); - // Compute the full output shape auto out_shape = lhs_indices.shape(); out_shape.push_back(x.shape(-2)); @@ -5432,6 +5421,56 @@ array gather_qmm( std::move(inputs)); } +array gather_qqmm( + const array& x, + const array& w, + const std::optional& scales_w, + const std::optional& lhs_indices_, + const std::optional& rhs_indices_, + std::optional group_size_, + std::optional bits_, + const std::string& mode, + const std::optional& global_scale_x, + const std::optional& global_scale_w, + bool sorted_indices, + StreamOrDevice s) { + auto stream = to_stream(s); + auto qmode = string_to_quantization_mode(mode, "gather_qqmm"); + auto [group_size, bits] = + quantization_params_from_mode(qmode, group_size_, bits_); + + // Extract indices and broadcast them + array lhs_indices = indices_or_default("[gather_qqmm]", lhs_indices_, x, s); + array rhs_indices = indices_or_default("[gather_qqmm]", rhs_indices_, w, s); + std::tie(lhs_indices, rhs_indices) = + broadcast_arrays(lhs_indices, rhs_indices, s); + + std::vector inputs = { + x, + w, + lhs_indices, + rhs_indices, + }; + if (scales_w.has_value()) { + inputs.push_back(*scales_w); + } + if (global_scale_x.has_value() && global_scale_w.has_value()) { + inputs.push_back(*global_scale_x); + inputs.push_back(*global_scale_w); + } + + auto [w_inner_dims, w_outer_dims] = + extract_qqmm_dims(x, w, scales_w, group_size, bits); + auto out_shape = lhs_indices.shape(); + out_shape.push_back(x.shape(-2)); + out_shape.push_back(w_outer_dims); + return array( + std::move(out_shape), + x.dtype(), + std::make_shared(stream, group_size, bits, qmode), + std::move(inputs)); +} + array 
tensordot( const array& a, const array& b, @@ -5908,28 +5947,14 @@ array gather_mm( b = astype(b, out_type, s); // Handle broadcasting - array lhs_indices = indices_or_default(lhs_indices_, a, s); - array rhs_indices = indices_or_default(rhs_indices_, b, s); - - if (!issubdtype(lhs_indices.dtype(), integer)) { - throw std::invalid_argument( - "[gather_mm] Got lhs_indices with invalid dtype. Indices must be integral."); - } - - if (!issubdtype(rhs_indices.dtype(), integer)) { - throw std::invalid_argument( - "[gather_mm] Got rhs_indices with invalid dtype. Indices must be integral."); - } - - lhs_indices = astype(lhs_indices, uint32, s); - rhs_indices = astype(rhs_indices, uint32, s); + array lhs_indices = indices_or_default("[gather_mm]", lhs_indices_, a, s); + array rhs_indices = indices_or_default("[gather_mm]", rhs_indices_, b, s); + std::tie(lhs_indices, rhs_indices) = + broadcast_arrays(lhs_indices, rhs_indices, s); int M = a.shape(-2); int N = b.shape(-1); - std::tie(lhs_indices, rhs_indices) = - broadcast_arrays(lhs_indices, rhs_indices, s); - auto out_shape = lhs_indices.shape(); out_shape.push_back(M); out_shape.push_back(N); diff --git a/mlx/ops.h b/mlx/ops.h index 3bcc97fe09..ffd53f6ca4 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1490,10 +1490,10 @@ MLX_API array conv_transpose3d( /** Quantized matmul multiplies x with a quantized matrix w*/ MLX_API array quantized_matmul( - array x, - array w, - array scales, - std::optional biases = std::nullopt, + const array& x, + const array& w, + const array& scales, + const std::optional& biases = std::nullopt, bool transpose = true, std::optional group_size = std::nullopt, std::optional bits = std::nullopt, @@ -1522,15 +1522,15 @@ MLX_API array dequantize( StreamOrDevice s = {}); MLX_API array qqmm( - array x, // input activations - array w, // maybe quantized weights - const std::optional w_scales = std::nullopt, // optional scales if w - // is quantized + const array& x, // input activations + const array& w, // 
maybe quantized weights + const std::optional& w_scales = std::nullopt, // optional scales if w + // is quantized std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "nvfp4", - const std::optional global_scale_x = std::nullopt, - const std::optional global_scale_w = std::nullopt, + const std::optional& global_scale_x = std::nullopt, + const std::optional& global_scale_w = std::nullopt, StreamOrDevice s = {}); /** Convert an E4M3 float8 to the given floating point dtype. */ @@ -1554,6 +1554,20 @@ MLX_API array gather_qmm( bool sorted_indices = false, StreamOrDevice s = {}); +MLX_API array gather_qqmm( + const array& x, + const array& w, + const std::optional& scales_w = std::nullopt, + const std::optional& lhs_indices = std::nullopt, + const std::optional& rhs_indices = std::nullopt, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "nvfp4", + const std::optional& global_scale_x = std::nullopt, + const std::optional& global_scale_w = std::nullopt, + bool sorted_indices = false, + StreamOrDevice s = {}); + /** Returns a contraction of a and b over multiple dimensions. 
*/ MLX_API array tensordot( const array& a, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 75bae1ba87..b4320ec97b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3813,6 +3813,22 @@ std::vector GatherQMM::output_shapes(const std::vector& inputs) { return {out_shape}; } +bool GatherQQMM::is_equivalent(const Primitive& other) const { + const GatherQQMM& qm_other = static_cast(other); + return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && + mode_ == qm_other.mode_; +} + +std::vector GatherQQMM::output_shapes(const std::vector& inputs) { + const auto& x = inputs[0]; + const auto& w = inputs[1]; + const auto& lhs_indices = inputs[2]; + auto out_shape = lhs_indices.shape(); + out_shape.push_back(x.shape(-2)); + out_shape.push_back(w.shape(-2)); + return {out_shape}; +} + std::pair, std::vector> RandomBits::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 5b8517c56d..403490dbe8 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1716,6 +1716,41 @@ class GatherQMM : public UnaryPrimitive { bool right_sorted_; }; +class GatherQQMM : public UnaryPrimitive { + public: + explicit GatherQQMM( + Stream stream, + int group_size, + int bits, + QuantizationMode mode, + bool left_sorted = false, + bool right_sorted = false) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(GatherQQMM) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple( + group_size_, bits_, mode_, left_sorted_, right_sorted_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; + bool left_sorted_; + bool right_sorted_; 
+}; + class RandomBits : public UnaryPrimitive { public: explicit RandomBits(Stream stream, const Shape& shape, int width) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 67dd5ac4f3..f6f382dd76 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4657,6 +4657,58 @@ void init_ops(nb::module_& m) { array: The result of the multiplication of ``x`` with ``w`` after gathering using ``lhs_indices`` and ``rhs_indices``. )pbdoc"); + m.def( + "gather_qqmm", + &mx::gather_qqmm, + nb::arg(), + nb::arg(), + "scales"_a = nb::none(), + "lhs_indices"_a = nb::none(), + "rhs_indices"_a = nb::none(), + "group_size"_a = nb::none(), + "bits"_a = nb::none(), + "mode"_a = "nvfp4", + "global_scale_x"_a = nb::none(), + "global_scale_w"_a = nb::none(), + nb::kw_only(), + "sorted_indices"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def gather_qqmm(x: array, w: array, /, scales: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', global_scale_x: Optional[array] = None, global_scale_w: Optional[array] = None, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Fused :func:`qqmm` with matrix-level gather. + + Similar to :func:`gather_mm`, the indices ``lhs_indices`` and + ``rhs_indices`` contain flat indices along the batch dimensions (i.e. + all but the last two dimensions) of ``x`` and ``w`` respectively. + + Args: + x (array): Input array. + w (array): Weight matrix. If quantized, it is packed in unsigned integers. + scales (array, optional): The scales to use per ``group_size`` elements of + ``w`` if ``w`` is quantized. Default: ``None``. + lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. + rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. + group_size (int, optional): Number of elements in ``x`` and ``w`` that + share a scale. 
See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + bits (int, optional): Number of bits used to represent each element of + ``x`` and ``w``. See supported values and defaults in the + :ref:`table of quantization modes `. Default: ``None``. + mode (str, optional): The quantization mode. Default: ``"nvfp4"``. + Supported modes are ``nvfp4`` and ``mxfp8``. See the + :ref:`table of quantization modes ` for details. + global_scale_x (array, optional): The per-input float32 scale used for x + with ``"nvfp4"`` quantization. Default: ``None``. + global_scale_w (array, optional): The per-input float32 scale used for w + with ``"nvfp4"`` quantization. Default: ``None``. + sorted_indices (bool, optional): May allow a faster implementation + if the passed indices are sorted. Default: ``False``. + + Returns: + array: The result of the multiplication of quantized ``x`` with quantized ``w`` + after gathering using ``lhs_indices`` and ``rhs_indices``. + )pbdoc"); m.def( "segmented_mm", &mx::segmented_mm, diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index f30170d44d..3ba7fdfaf6 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -224,6 +224,55 @@ def test_qqmv(self): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) + @unittest.skipIf( + not mx.cuda.is_available() or mx.default_device() == mx.cpu, + "Only implemented in CUDA", + ) + def test_qqmm(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + tests = product( + [8, 32, 33, 64], # M + [128, 256], # N + [128, 256], # K + ["nvfp4", "mxfp8"], # mode + ) + for M, N, K, mode in tests: + with self.subTest(shape=(M, N, K), mode=mode): + x_shape = (M, K) + w_shape = (N, K) + + x = mx.random.normal(shape=x_shape, key=k1) + global_scale_x = mx.max(mx.abs(x)) if mode == "nvfp4" else None + x_hat = mx.dequantize( + *mx.quantize(x, mode=mode, global_scale=global_scale_x), + mode=mode, + dtype=mx.float32, + 
global_scale=global_scale_x, + ) + + w = mx.random.normal(shape=w_shape, key=k2) + global_scale_w = mx.max(mx.abs(w)) if mode == "nvfp4" else None + w_q, scales = mx.quantize(w, mode=mode, global_scale=global_scale_w) + w_hat = mx.dequantize( + w_q, + scales, + mode=mode, + global_scale=global_scale_w, + dtype=mx.float32, + ) + y_q = mx.qqmm( + x, + w_q, + scales, + mode=mode, + global_scale_x=global_scale_x, + global_scale_w=global_scale_w, + ) + y_hat = x_hat @ mx.swapaxes(w_hat, -1, -2) + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_qqmm_metal_global_scale_rejected(self): # Tensor-scale nvfp4 (global_scale_x / global_scale_w) is not # implemented in the Metal qqmm kernels. mx.qqmm must reject the @@ -995,6 +1044,100 @@ def test_shape( test_shape(32, 512, 32, transpose=False, **kwargs) test_shape(1, 512, 32, transpose=False, **kwargs) + @unittest.skipIf( + not mx.cuda.is_available() or mx.default_device() == mx.cpu, + "Only implemented in CUDA", + ) + def test_gather_qqmm(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + batches = ( + { + "batch_A": (1,), + "lhs_indices": (0,), + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (1,), + "lhs_indices": None, + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (2,), + "lhs_indices": None, + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (3,), + "lhs_indices": (0, 2), + "batch_B": (1,), + "rhs_indices": (0,), + }, + { + "batch_A": (5,), + "lhs_indices": (0, 2), + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + ) + tests = product( + batches, + [1, 32], # M + [32, 256], # N + [32, 256], # K + ["nvfp4", "mxfp8"], # mode + ) + + for batch, M, N, K, mode in tests: + with self.subTest(shape=(M, N, K), mode=mode, **batch): + batch_A, lhs_indices, batch_B, rhs_indices = batch.values() + x_shape = (*batch_A, M, K) + w_shape = (*batch_B, N, K) + + x = mx.random.normal(shape=x_shape, key=k1) + global_scale_x 
= mx.max(mx.abs(x)) if mode == "nvfp4" else None + x_hat = mx.dequantize( + *mx.quantize(x, mode=mode, global_scale=global_scale_x), + mode=mode, + dtype=mx.float32, + global_scale=global_scale_x, + ) + + w = mx.random.normal(shape=w_shape, key=k2) + global_scale_w = mx.max(mx.abs(w)) if mode == "nvfp4" else None + w_q, scales = mx.quantize(w, mode=mode, global_scale=global_scale_w) + w_hat = mx.dequantize( + w_q, + scales, + mode=mode, + global_scale=global_scale_w, + dtype=mx.float32, + ) + + if lhs_indices is not None: + lhs_indices = mx.array(lhs_indices) + if rhs_indices is not None: + rhs_indices = mx.array(rhs_indices) + + y_q = mx.gather_qqmm( + x, + w_q, + scales, + lhs_indices, + rhs_indices, + mode=mode, + global_scale_x=global_scale_x, + global_scale_w=global_scale_w, + ) + y_hat = mx.gather_mm( + x_hat, mx.swapaxes(w_hat, -1, -2), lhs_indices, rhs_indices + ) + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_qmm_fp_type(self): indices = mx.array([[2], [0], [1]], dtype=mx.uint32)