Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlx/backend/cpu/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1359,4 +1359,8 @@ void QQMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}

// Placeholder for the CPU evaluation of GatherQQMM.
// No computation is performed: `inputs` and `out` are ignored and the call
// always throws std::runtime_error with the message "[GatherQQMM] NYI".
void GatherQQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[GatherQQMM] NYI");
}

} // namespace mlx::core
17 changes: 16 additions & 1 deletion mlx/backend/cuda/device/qmm_naive.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "mlx/backend/cuda/device/cute_dequant.cuh"
#include "mlx/backend/cuda/device/gemm_sm70.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cuda/cmath>

Expand All @@ -25,13 +26,15 @@ CUTE_DEVICE void qmm_naive_mainloop(
TensorS gS,
TensorZ gZ,
TensorC gC,
const float* global_scale,
int m_max_coord,
int n_max_coord,
int k_residue,
int thread_idx) {
// Get the types of operands.
using Element = typename decltype(gA)::value_type;
using Quant = typename decltype(gB)::value_type;
using Scale = typename decltype(gS)::value_type;

// Shift tensor so we handle residue of K in the 0th tile.
gA = domain_offset(make_coord(0, k_residue, 0), gA);
Expand Down Expand Up @@ -196,7 +199,17 @@ CUTE_DEVICE void qmm_naive_mainloop(
CUTE_UNROLL
for (int i = 0; i < size(tCrC); ++i) {
if ((get<0>(tCcC(i)) < m_max_coord) && (get<1>(tCcC(i)) < n_max_coord)) {
tCgC(i) = Element(tCrC(i));
if constexpr (
cuda::std::is_same_v<Quant, cutlass::float_e2m1_t> &&
cuda::std::is_same_v<Scale, cutlass::float_e4m3_t>) {
if (global_scale) {
tCgC(i) = Element(tCrC(i) * (*global_scale / (F8E4M3_MAX * F4E2M1_MAX)));
} else {
tCgC(i) = Element(tCrC(i));
}
} else {
tCgC(i) = Element(tCrC(i));
}
}
}
}
Expand Down Expand Up @@ -224,6 +237,7 @@ void qmm_naive_kernel(
const Quant* B,
const Scale* S,
const Element* Z,
const float* global_scale,
const uint32_t* lhs_indices,
const uint32_t* rhs_indices,
Element* C,
Expand Down Expand Up @@ -295,6 +309,7 @@ void qmm_naive_kernel(
gS,
gZ,
gC,
global_scale,
m_max_coord, n_max_coord, k_residue,
thread_idx);
}
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/cuda/quantized/qmm/qmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ void qmm_naive(
const array& w,
const array& scales,
const std::optional<array>& biases,
const std::optional<array>& global_scale,
const std::optional<array>& lhs_indices,
const std::optional<array>& rhs_indices,
array& out,
Expand Down
5 changes: 5 additions & 0 deletions mlx/backend/cuda/quantized/qmm/qmm_naive.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void qmm_naive(
const array& w,
const array& scales,
const std::optional<array>& biases,
const std::optional<array>& global_scale,
const std::optional<array>& lhs_indices,
const std::optional<array>& rhs_indices,
array& out,
Expand Down Expand Up @@ -75,6 +76,9 @@ void qmm_naive(
if (biases) {
encoder.set_input_array(*biases);
}
if (global_scale) {
encoder.set_input_array(*global_scale);
}
if (lhs_indices) {
encoder.set_input_array(*lhs_indices);
}
Expand Down Expand Up @@ -103,6 +107,7 @@ void qmm_naive(
gpu_ptr<void>(w),
gpu_ptr<void>(scales),
biases ? gpu_ptr<void>(*biases) : nullptr,
global_scale ? gpu_ptr<void>(*global_scale) : nullptr,
lhs_indices ? gpu_ptr<void>(*lhs_indices) : nullptr,
rhs_indices ? gpu_ptr<void>(*rhs_indices) : nullptr,
gpu_ptr<void>(out),
Expand Down
202 changes: 143 additions & 59 deletions mlx/backend/cuda/quantized/qqmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ std::tuple<array, array> quantize_input(
QuantizationMode mode,
int bits,
int group_size,
std::optional<array> global_scale = std::nullopt) {
std::optional<array> global_scale) {
const array x = ensure_contiguous(input, encoder, s);

// Compute output shapes
Expand Down Expand Up @@ -52,6 +52,27 @@ std::tuple<array, array> quantize_input(
return {std::move(x_q), std::move(scales_x)};
}

// Runs fp_quantize_dequantize over a row-contiguous version of `x_pre` and
// returns the array holding the result.
//
// The result is written in place when either the original input or its
// row-contiguous version reports itself as donatable; otherwise a fresh
// buffer of the same shape and dtype is allocated and registered with the
// encoder as a temporary.
//
// x_pre:        activation to process.
// global_scale: optional global scale forwarded to fp_quantize_dequantize.
// bits:         quantization bit width forwarded to fp_quantize_dequantize.
// group_size:   quantization group size forwarded to fp_quantize_dequantize.
// encoder, s:   command encoder and stream used for allocation and launch.
array quantize_dequantize_input(
    const array& x_pre,
    const std::optional<array>& global_scale,
    int bits,
    int group_size,
    cu::CommandEncoder& encoder,
    Stream s) {
  // Sample the caller's array before ensure_row_contiguous possibly copies it.
  const bool input_donatable = x_pre.is_donatable();
  array contiguous = ensure_row_contiguous(x_pre, encoder, s);
  // If ensure_row_contiguous produced a copy, that copy should be donatable.
  const bool copy_donatable = contiguous.is_donatable();
  const bool reuse_buffer = input_donatable || copy_donatable;

  // Choose the destination: alias the input buffer, or allocate a temporary.
  auto make_destination = [&]() -> array {
    if (reuse_buffer) {
      return contiguous;
    }
    array fresh(
        cu::malloc_async(contiguous.nbytes(), encoder),
        contiguous.shape(),
        contiguous.dtype());
    encoder.add_temporary(fresh);
    return fresh;
  };
  array result = make_destination();

  fp_quantize_dequantize(
      contiguous, result, group_size, bits, global_scale, encoder, s);
  return result;
}

GemmScalars create_nvfp4_scalars(
const array& global_scale_x,
const array& global_scale_w,
Expand All @@ -75,77 +96,81 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto& device = encoder.device();
bool w_quantized = (inputs[1].dtype() == uint32);

const array& x_pre = inputs[0];
const array& w_pre = inputs[1];

out.set_data(cu::malloc_async(out.nbytes(), encoder));

// - 2 inputs: x, w (non-quantized w)
// - 3 inputs: x, w, scales_w (quantized w)
bool w_quantized = (w_pre.dtype() == uint32);
int base_size = w_quantized ? 3 : 2;
assert(
inputs.size() == base_size ||
(mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2));

// For nvfp4, global scales are optional but must be both present or both
// absent. If present, they add 2 more inputs (global_scale_x, global_scale_w)
bool has_global_scales =
mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;
std::optional<array> global_scale_x = std::nullopt;
std::optional<array> global_scale_w = std::nullopt;
mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2;
assert(inputs.size() == base_size || has_global_scales);

std::optional<array> global_scale_x;
std::optional<array> global_scale_w;
if (has_global_scales) {
global_scale_x = inputs[inputs.size() - 2];
global_scale_w = inputs[inputs.size() - 1];
}

if (w_quantized && inputs[0].shape(-2) == 1) {
out.set_data(cu::malloc_async(out.nbytes(), encoder));

bool donate_x = inputs[0].is_donatable();
array x = ensure_row_contiguous(inputs[0], encoder, s);
// If x is a copy it should be donatable
donate_x |= x.is_donatable();
auto xhat = donate_x
? x
: array(cu::malloc_async(x.nbytes(), encoder), x.shape(), x.dtype());
if (!donate_x) {
encoder.add_temporary(xhat);
}
fp_quantize_dequantize(
x, xhat, group_size_, bits_, global_scale_x, encoder, s);

const array& w = inputs[1];
const array& scales = inputs[2];
qmv(xhat,
w,
scales,
std::nullopt,
global_scale_w,
out,
bits_,
group_size_,
mode_,
encoder);
return;
}

auto cc = device.compute_capability_major() * 100 +
device.compute_capability_minor() * 10;
if (cc < 1000) {
throw std::runtime_error(
"[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher.");
}

// Quantize inputs (or use pre-quantized)
auto [x_q, scale_x_pre] = quantize_input(
inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x);
auto [w_q, scale_w_pre] = !w_quantized
// Quantize weights.
auto [w_q, scales_w] = !w_quantized
? quantize_input(
inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w)
w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w)
: std::make_tuple(
ensure_contiguous(inputs[1], encoder, s),
ensure_contiguous(w_pre, encoder, s),
ensure_contiguous(inputs[2], encoder, s));

out.set_data(cu::malloc_async(out.nbytes(), encoder));
// Reroute to qmm when: no support in cuBLAS, or doing GEMV.
bool can_use_cublas =
(mode_ == QuantizationMode::Nvfp4 || mode_ == QuantizationMode::Mxfp8) &&
(device.compute_capability_major() >= 10);
int M = x_pre.shape(-2);
bool use_qmm = (!can_use_cublas) || (M == 1);

if (use_qmm) {
array x = quantize_dequantize_input(
x_pre, global_scale_x, bits_, group_size_, encoder, s);
if (M < 8) {
qmv(x,
w_q,
scales_w,
std::nullopt,
global_scale_w,
out,
bits_,
group_size_,
mode_,
encoder);
} else {
qmm_naive(
x,
w_q,
scales_w,
std::nullopt,
global_scale_w,
std::nullopt,
std::nullopt,
out,
true, // transpose
bits_,
group_size_,
mode_,
encoder);
}
return;
}

// Quantize activation.
auto [x_q, scales_x] = quantize_input(
x_pre, encoder, s, mode_, bits_, group_size_, global_scale_x);

int M = x_q.shape(-2);
int N = w_q.shape(-2); // transposed
int K = x_q.shape(-1) * (32 / bits_);

Expand All @@ -155,8 +180,8 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int64_t ldb = K;

// Repack scales to tiled layout for tensor cores
array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s);
array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s);
scales_x = pad_and_swizzle_scales(scales_x, encoder, s);
scales_w = pad_and_swizzle_scales(scales_w, encoder, s);

GemmScalars scalars;
if (has_global_scales) {
Expand All @@ -175,10 +200,69 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
out,
x_q,
w_q,
scale_x,
scale_w,
scales_x,
scales_w,
mode_,
scalars);
}

// Evaluates GatherQQMM on the CUDA backend.
//
// Accepted layouts of `inputs`:
//   - x, w, lhs_indices, rhs_indices             (non-quantized w)
//   - x, w, lhs_indices, rhs_indices, scales_w   (quantized w, dtype uint32)
//   - in Nvfp4 mode, either layout may be followed by
//     global_scale_x, global_scale_w (both present or both absent)
//
// Non-quantized weights are quantized here. The activation is passed through
// quantize_dequantize_input, and the gathered product is then dispatched to
// qmm_naive with transpose = true. The result is written to `out`, whose
// buffer is allocated by this function.
void GatherQQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Label the profiling range with this primitive's own name so traces are
  // not attributed to QQMatmul.
  nvtx3::scoped_range r("GatherQQMM::eval_gpu");

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  const array& x_pre = inputs[0];
  const array& w_pre = inputs[1];
  const array& lhs_indices = ensure_row_contiguous(inputs[2], encoder, s);
  const array& rhs_indices = ensure_row_contiguous(inputs[3], encoder, s);

  out.set_data(cu::malloc_async(out.nbytes(), encoder));

  // - 4 inputs: x, w, lhs_indices, rhs_indices (non-quantized w)
  // - 5 inputs: x, w, lhs_indices, rhs_indices, scales_w (quantized w)
  bool w_quantized = (w_pre.dtype() == uint32);
  int base_size = w_quantized ? 5 : 4;
  // For nvfp4, global scales are optional but must be both present or both
  // absent. If present, they add 2 more inputs (global_scale_x,
  // global_scale_w).
  bool has_global_scales =
      mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2;
  assert(inputs.size() == base_size || has_global_scales);

  std::optional<array> global_scale_x;
  std::optional<array> global_scale_w;
  if (has_global_scales) {
    // The global scales are always the trailing two inputs.
    global_scale_x = inputs[inputs.size() - 2];
    global_scale_w = inputs[inputs.size() - 1];
  }

  // Quantize weights (or use the pre-quantized weights and their scales).
  auto [w_q, scales_w] = !w_quantized
      ? quantize_input(
            w_pre, encoder, s, mode_, bits_, group_size_, global_scale_w)
      : std::make_tuple(
            ensure_contiguous(w_pre, encoder, s),
            ensure_contiguous(inputs[4], encoder, s));

  // Quantize activation.
  array x = quantize_dequantize_input(
      x_pre, global_scale_x, bits_, group_size_, encoder, s);

  // Reroute to qmm.
  qmm_naive(
      x,
      w_q,
      scales_w,
      std::nullopt, // biases
      global_scale_w,
      lhs_indices,
      rhs_indices,
      out,
      true, // transpose
      bits_,
      group_size_,
      mode_,
      encoder);
}

} // namespace mlx::core
2 changes: 2 additions & 0 deletions mlx/backend/cuda/quantized/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
biases,
std::nullopt,
std::nullopt,
std::nullopt,
out,
transpose_,
bits_,
Expand Down Expand Up @@ -211,6 +212,7 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
w,
scales,
biases,
std::nullopt,
lhs_indices,
rhs_indices,
out,
Expand Down
4 changes: 4 additions & 0 deletions mlx/backend/metal/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1667,6 +1667,10 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}

// Placeholder for the GPU evaluation of GatherQQMM in this backend.
// No computation is performed: `inputs` and `out` are ignored and the call
// always throws std::runtime_error with the message "[GatherQQMM] NYI".
void GatherQQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[GatherQQMM] NYI");
}

void fast::Quantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/no_cpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ NO_CPU(Gather)
NO_CPU(GatherAxis)
NO_CPU(GatherMM)
NO_CPU(GatherQMM)
NO_CPU(GatherQQMM)
NO_CPU(Greater)
NO_CPU(GreaterEqual)
NO_CPU(Hadamard)
Expand Down
Loading
Loading