diff --git a/Source/MLXNN/Quantized.swift b/Source/MLXNN/Quantized.swift index 076e91ca..e529d324 100644 --- a/Source/MLXNN/Quantized.swift +++ b/Source/MLXNN/Quantized.swift @@ -27,6 +27,12 @@ public protocol Quantized: Module { var mode: QuantizationMode { get } } +private func preconditionSupportsGlobalScale() { + precondition( + Device.defaultDevice().deviceType == .cpu, + "globalScale dequantization is not supported on the Metal backend") +} + /// Quantize any ``Quantizable`` layer that is not already quantized. public func quantizeSingle( layer: Module, groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine @@ -159,6 +165,7 @@ open class QuantizedEmbedding: Embedding, Quantized { public let mode: QuantizationMode public let scales: MLXArray public let biases: MLXArray? + public let globalScale: MLXArray? open override var shape: (Int, Int) { let (embeddingCount, dimensions) = super.shape @@ -184,14 +191,20 @@ open class QuantizedEmbedding: Embedding, Quantized { public init( weight: MLXArray, groupSize: Int = 64, bits: Int = 4, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine, + globalScale: MLXArray? = nil ) { self.groupSize = groupSize self.bits = bits self.mode = mode + self.globalScale = globalScale + + if globalScale != nil { + preconditionSupportsGlobalScale() + } let (quantizedWeight, scales, biases) = MLX.quantized( - weight, groupSize: groupSize, bits: bits, mode: mode) + weight, groupSize: groupSize, bits: bits, mode: mode, globalScale: globalScale) self.scales = scales self.biases = biases @@ -201,20 +214,51 @@ open class QuantizedEmbedding: Embedding, Quantized { self.freeze() } + /// Initializer meant for subclasses to provide arrays directly. + public init( + weight: MLXArray, scales: MLXArray, biases: MLXArray?, + groupSize: Int, bits: Int, + mode: QuantizationMode = .affine, + globalScale: MLXArray? = nil + ) { + self.groupSize = groupSize + self.bits = bits + self.mode = mode + self.scales = scales + self.biases = biases + self.globalScale = globalScale + super.init(weight: weight) + } + open override func callAsFunction(_ x: MLXArray) -> MLXArray { + if globalScale != nil { + preconditionSupportsGlobalScale() + } + let s = x.shape let x = x.flattened() let out = dequantized( weight[x], scales: scales[x], biases: biases == nil ? nil : biases![x], - groupSize: groupSize, bits: bits, mode: mode) + groupSize: groupSize, bits: bits, mode: mode, globalScale: globalScale) return out.reshaped(s + [-1]) } open override func asLinear(_ x: MLXArray) -> MLXArray { - quantizedMM( + if globalScale != nil { + preconditionSupportsGlobalScale() + return matmul(x, dequantizedWeight.T) + } + + return quantizedMM( x, weight, scales: scales, biases: biases, transpose: true, groupSize: groupSize, bits: bits, mode: mode) } + + private var dequantizedWeight: MLXArray { + dequantized( + weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits, mode: mode, + globalScale: globalScale) + } } /// Applies an affine transformation to the input using a quantized weight matrix. @@ -243,6 +287,7 @@ open class QuantizedLinear: Linear, Quantized { public let mode: QuantizationMode public let scales: MLXArray public let biases: MLXArray? + public let globalScale: MLXArray? open override var shape: (Int, Int) { let shape = weight.shape2 @@ -292,14 +337,20 @@ open class QuantizedLinear: Linear, Quantized { /// Initialize a ``QuantizedLinear`` with non-quantized weights and bias. public init( weight: MLXArray, bias: MLXArray?, groupSize: Int = 64, bits: Int = 4, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine, + globalScale: MLXArray? = nil ) { self.groupSize = groupSize self.bits = bits self.mode = mode + self.globalScale = globalScale + + if globalScale != nil { + preconditionSupportsGlobalScale() + } let (quantizedWeight, scales, biases) = MLX.quantized( - weight, groupSize: groupSize, bits: bits, mode: mode) + weight, groupSize: groupSize, bits: bits, mode: mode, globalScale: globalScale) self.scales = scales self.biases = biases @@ -316,13 +367,15 @@ open class QuantizedLinear: Linear, Quantized { public init( weight: MLXArray, bias: MLXArray? = nil, scales: MLXArray, biases: MLXArray?, groupSize: Int, bits: Int, - mode: QuantizationMode = .affine + mode: QuantizationMode = .affine, + globalScale: MLXArray? = nil ) { self.groupSize = groupSize self.bits = bits self.mode = mode self.scales = scales self.biases = biases + self.globalScale = globalScale super.init(weight: weight, bias: bias) } @@ -334,20 +387,32 @@ open class QuantizedLinear: Linear, Quantized { } open override func callAsFunction(_ x: MLXArray) -> MLXArray { - var x = quantizedMM( - x, - weight, - scales: scales, - biases: biases, - transpose: true, - groupSize: groupSize, - bits: bits, - mode: mode - ) + var result: MLXArray + if globalScale != nil { + preconditionSupportsGlobalScale() + result = matmul(x, dequantizedWeight.T) + } else { + result = quantizedMM( + x, + weight, + scales: scales, + biases: biases, + transpose: true, + groupSize: groupSize, + bits: bits, + mode: mode + ) + } if let bias { - x = x + bias + result = result + bias } - return x + return result + } + + private var dequantizedWeight: MLXArray { + dequantized( + weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits, mode: mode, + globalScale: globalScale) } /// Returns a QuantizedLinear layer that applies the same linear transformation up to the quantization error. diff --git a/Tests/MLXTests/QuantizationTests.swift b/Tests/MLXTests/QuantizationTests.swift index 0edbd545..58c70ea5 100644 --- a/Tests/MLXTests/QuantizationTests.swift +++ b/Tests/MLXTests/QuantizationTests.swift @@ -39,4 +39,37 @@ class QuantizationTests: XCTestCase { let quantized = QuantizedLinear(64, 64, groupSize: 32, bits: 4, mode: .mxfp4) XCTAssertNil(quantized.biases) } + + func testQuantizedLinearStoresGlobalScale() { + let globalScale = MLXArray(1.0, dtype: .float32) + let quantized = QuantizedLinear( + weight: MLXArray.zeros([8, 4], dtype: .uint32), + bias: nil, + scales: MLXArray.ones([8, 4], dtype: .uint8), + biases: nil, + groupSize: 16, + bits: 4, + mode: .nvfp4, + globalScale: globalScale) + + XCTAssertNotNil(quantized.globalScale) + XCTAssertEqual(quantized.globalScale?.dtype, .float32) + XCTAssertNotNil(quantized.parameters()["globalScale"]) + } + + func testQuantizedEmbeddingStoresGlobalScale() { + let globalScale = MLXArray(1.0, dtype: .float32) + let quantized = QuantizedEmbedding( + weight: MLXArray.zeros([8, 2], dtype: .uint32), + scales: MLXArray.ones([8, 2], dtype: .uint8), + biases: nil, + groupSize: 16, + bits: 4, + mode: .nvfp4, + globalScale: globalScale) + + XCTAssertNotNil(quantized.globalScale) + XCTAssertEqual(quantized.globalScale?.dtype, .float32) + XCTAssertNotNil(quantized.parameters()["globalScale"]) + } }