From 7db225ca6f78dc2877e30d24f549051706012294 Mon Sep 17 00:00:00 2001 From: Alessio Pollero Date: Fri, 19 Jun 2026 19:48:15 +0400 Subject: [PATCH] Add global scale support to quantized layers Store an optional globalScale on QuantizedLinear and QuantizedEmbedding so nvfp4 weights can preserve the scale needed by lower-level MLX quantize/dequantize operations. Forward the scale when creating or dequantizing weights, add a direct pre-quantized QuantizedEmbedding initializer to match QuantizedLinear, and, when a global scale is set, compute layer outputs by dequantizing the weights and calling matmul instead of quantizedMM. Note that MLX does not support globalScale dequantization on Metal; this change adds no device check for those paths. Add focused tests for the new layer state and parameter exposure without broadening the generic quantization API surface. --- Source/MLXNN/Quantized.swift | 83 ++++++++++++++++++++------ Tests/MLXTests/QuantizationTests.swift | 58 ++++++++++++++++++ 2 files changed, 122 insertions(+), 19 deletions(-) diff --git a/Source/MLXNN/Quantized.swift b/Source/MLXNN/Quantized.swift index 076e91ca..c71056ee 100644 --- a/Source/MLXNN/Quantized.swift +++ b/Source/MLXNN/Quantized.swift @@ -159,6 +159,7 @@ open class QuantizedEmbedding: Embedding, Quantized { public let mode: QuantizationMode public let scales: MLXArray public let biases: MLXArray? + @ParameterInfo(key: "global_scale") public private(set) var globalScale: MLXArray? open override var shape: (Int, Int) { let (embeddingCount, dimensions) = super.shape @@ -184,14 +185,16 @@ open class QuantizedEmbedding: Embedding, Quantized { public init( weight: MLXArray, groupSize: Int = 64, bits: Int = 4, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine, + globalScale: MLXArray? 
= nil ) { self.groupSize = groupSize self.bits = bits self.mode = mode + self._globalScale.wrappedValue = globalScale let (quantizedWeight, scales, biases) = MLX.quantized( - weight, groupSize: groupSize, bits: bits, mode: mode) + weight, groupSize: groupSize, bits: bits, mode: mode, globalScale: globalScale) self.scales = scales self.biases = biases @@ -201,20 +204,46 @@ open class QuantizedEmbedding: Embedding, Quantized { self.freeze() } + /// Initializer meant for subclasses to provide arrays directly. + public init( + weight: MLXArray, scales: MLXArray, biases: MLXArray?, + groupSize: Int, bits: Int, + mode: QuantizationMode = .affine, + globalScale: MLXArray? = nil + ) { + self.groupSize = groupSize + self.bits = bits + self.mode = mode + self.scales = scales + self.biases = biases + self._globalScale.wrappedValue = globalScale + super.init(weight: weight) + } + open override func callAsFunction(_ x: MLXArray) -> MLXArray { let s = x.shape let x = x.flattened() let out = dequantized( weight[x], scales: scales[x], biases: biases == nil ? nil : biases![x], - groupSize: groupSize, bits: bits, mode: mode) + groupSize: groupSize, bits: bits, mode: mode, globalScale: globalScale) return out.reshaped(s + [-1]) } open override func asLinear(_ x: MLXArray) -> MLXArray { - quantizedMM( + if globalScale != nil { + return matmul(x, dequantizedWeight.T) + } + + return quantizedMM( x, weight, scales: scales, biases: biases, transpose: true, groupSize: groupSize, bits: bits, mode: mode) } + + private var dequantizedWeight: MLXArray { + dequantized( + weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits, mode: mode, + globalScale: globalScale) + } } /// Applies an affine transformation to the input using a quantized weight matrix. @@ -243,6 +272,7 @@ open class QuantizedLinear: Linear, Quantized { public let mode: QuantizationMode public let scales: MLXArray public let biases: MLXArray? 
+ @ParameterInfo(key: "global_scale") public private(set) var globalScale: MLXArray? open override var shape: (Int, Int) { let shape = weight.shape2 @@ -292,14 +322,16 @@ open class QuantizedLinear: Linear, Quantized { /// Initialize a ``QuantizedLinear`` with non-quantized weights and bias. public init( weight: MLXArray, bias: MLXArray?, groupSize: Int = 64, bits: Int = 4, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine, + globalScale: MLXArray? = nil ) { self.groupSize = groupSize self.bits = bits self.mode = mode + self._globalScale.wrappedValue = globalScale let (quantizedWeight, scales, biases) = MLX.quantized( - weight, groupSize: groupSize, bits: bits, mode: mode) + weight, groupSize: groupSize, bits: bits, mode: mode, globalScale: globalScale) self.scales = scales self.biases = biases @@ -316,13 +348,15 @@ open class QuantizedLinear: Linear, Quantized { public init( weight: MLXArray, bias: MLXArray? = nil, scales: MLXArray, biases: MLXArray?, groupSize: Int, bits: Int, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine, + globalScale: MLXArray? 
= nil ) { self.groupSize = groupSize self.bits = bits self.mode = mode self.scales = scales self.biases = biases + self._globalScale.wrappedValue = globalScale super.init(weight: weight, bias: bias) } @@ -334,20 +368,31 @@ open class QuantizedLinear: Linear, Quantized { } open override func callAsFunction(_ x: MLXArray) -> MLXArray { - var x = quantizedMM( - x, - weight, - scales: scales, - biases: biases, - transpose: true, - groupSize: groupSize, - bits: bits, - mode: mode - ) + var result: MLXArray + if globalScale != nil { + result = matmul(x, dequantizedWeight.T) + } else { + result = quantizedMM( + x, + weight, + scales: scales, + biases: biases, + transpose: true, + groupSize: groupSize, + bits: bits, + mode: mode + ) + } if let bias { - x = x + bias + result = result + bias } - return x + return result + } + + private var dequantizedWeight: MLXArray { + dequantized( + weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits, mode: mode, + globalScale: globalScale) } /// Returns a QuantizedLinear layer that applies the same linear transformation up to the quantization error. 
diff --git a/Tests/MLXTests/QuantizationTests.swift b/Tests/MLXTests/QuantizationTests.swift index 0edbd545..ed810b1b 100644 --- a/Tests/MLXTests/QuantizationTests.swift +++ b/Tests/MLXTests/QuantizationTests.swift @@ -39,4 +39,62 @@ class QuantizationTests: XCTestCase { let quantized = QuantizedLinear(64, 64, groupSize: 32, bits: 4, mode: .mxfp4) XCTAssertNil(quantized.biases) } + + func testQuantizedLinearStoresGlobalScale() { + let globalScale = MLXArray(1.0, dtype: .float32) + let quantized = QuantizedLinear( + weight: MLXArray.zeros([8, 4], dtype: .uint32), + bias: nil, + scales: MLXArray.ones([8, 4], dtype: .uint8), + biases: nil, + groupSize: 16, + bits: 4, + mode: .nvfp4, + globalScale: globalScale) + + XCTAssertNotNil(quantized.globalScale) + XCTAssertEqual(quantized.globalScale?.dtype, .float32) + XCTAssertNotNil(quantized.parameters()["global_scale"]) + XCTAssertNil(quantized.parameters()["globalScale"]) + } + + func testQuantizedEmbeddingStoresGlobalScale() { + let globalScale = MLXArray(1.0, dtype: .float32) + let quantized = QuantizedEmbedding( + weight: MLXArray.zeros([8, 2], dtype: .uint32), + scales: MLXArray.ones([8, 2], dtype: .uint8), + biases: nil, + groupSize: 16, + bits: 4, + mode: .nvfp4, + globalScale: globalScale) + + XCTAssertNotNil(quantized.globalScale) + XCTAssertEqual(quantized.globalScale?.dtype, .float32) + XCTAssertNotNil(quantized.parameters()["global_scale"]) + XCTAssertNil(quantized.parameters()["globalScale"]) + } + + func testQuantizedGlobalScaleIsOptionalParameter() { + let linear = QuantizedLinear( + weight: MLXArray.zeros([8, 4], dtype: .uint32), + bias: nil, + scales: MLXArray.ones([8, 4], dtype: .uint8), + biases: nil, + groupSize: 16, + bits: 4, + mode: .nvfp4) + let embedding = QuantizedEmbedding( + weight: MLXArray.zeros([8, 2], dtype: .uint32), + scales: MLXArray.ones([8, 2], dtype: .uint8), + biases: nil, + groupSize: 16, + bits: 4, + mode: .nvfp4) + + XCTAssertNil(linear.globalScale) + 
XCTAssertNil(linear.parameters()["global_scale"]) + XCTAssertNil(embedding.globalScale) + XCTAssertNil(embedding.parameters()["global_scale"]) + } }