Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 84 additions & 19 deletions Source/MLXNN/Quantized.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ public protocol Quantized: Module {
var mode: QuantizationMode { get }
}

private func preconditionSupportsGlobalScale() {
precondition(
Device.defaultDevice().deviceType == .cpu,
"globalScale dequantization is not supported on the Metal backend")
}

/// Quantize any ``Quantizable`` layer that is not already quantized.
public func quantizeSingle(
layer: Module, groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine
Expand Down Expand Up @@ -159,6 +165,7 @@ open class QuantizedEmbedding: Embedding, Quantized {
public let mode: QuantizationMode
public let scales: MLXArray
public let biases: MLXArray?
public let globalScale: MLXArray?

open override var shape: (Int, Int) {
let (embeddingCount, dimensions) = super.shape
Expand All @@ -184,14 +191,20 @@ open class QuantizedEmbedding: Embedding, Quantized {

public init(
weight: MLXArray, groupSize: Int = 64, bits: Int = 4,
mode: QuantizationMode = .affine
mode: QuantizationMode = .affine,
globalScale: MLXArray? = nil
) {
self.groupSize = groupSize
self.bits = bits
self.mode = mode
self.globalScale = globalScale

if globalScale != nil {
preconditionSupportsGlobalScale()
}

let (quantizedWeight, scales, biases) = MLX.quantized(
weight, groupSize: groupSize, bits: bits, mode: mode)
weight, groupSize: groupSize, bits: bits, mode: mode, globalScale: globalScale)

self.scales = scales
self.biases = biases
Expand All @@ -201,20 +214,51 @@ open class QuantizedEmbedding: Embedding, Quantized {
self.freeze()
}

/// Initializer meant for subclasses to provide arrays directly.
public init(
weight: MLXArray, scales: MLXArray, biases: MLXArray?,
groupSize: Int, bits: Int,
mode: QuantizationMode = .affine,
globalScale: MLXArray? = nil
) {
self.groupSize = groupSize
self.bits = bits
self.mode = mode
self.scales = scales
self.biases = biases
self.globalScale = globalScale
super.init(weight: weight)
}

open override func callAsFunction(_ x: MLXArray) -> MLXArray {
if globalScale != nil {
preconditionSupportsGlobalScale()
}

let s = x.shape
let x = x.flattened()
let out = dequantized(
weight[x], scales: scales[x], biases: biases == nil ? nil : biases![x],
groupSize: groupSize, bits: bits, mode: mode)
groupSize: groupSize, bits: bits, mode: mode, globalScale: globalScale)
return out.reshaped(s + [-1])
}

open override func asLinear(_ x: MLXArray) -> MLXArray {
quantizedMM(
if globalScale != nil {
preconditionSupportsGlobalScale()
return matmul(x, dequantizedWeight.T)
}

return quantizedMM(
x, weight, scales: scales, biases: biases, transpose: true, groupSize: groupSize,
bits: bits, mode: mode)
}

private var dequantizedWeight: MLXArray {
dequantized(
weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits, mode: mode,
globalScale: globalScale)
}
}

/// Applies an affine transformation to the input using a quantized weight matrix.
Expand Down Expand Up @@ -243,6 +287,7 @@ open class QuantizedLinear: Linear, Quantized {
public let mode: QuantizationMode
public let scales: MLXArray
public let biases: MLXArray?
public let globalScale: MLXArray?

open override var shape: (Int, Int) {
let shape = weight.shape2
Expand Down Expand Up @@ -292,14 +337,20 @@ open class QuantizedLinear: Linear, Quantized {
/// Initialize a ``QuantizedLinear`` with non-quantized weights and bias.
public init(
weight: MLXArray, bias: MLXArray?, groupSize: Int = 64, bits: Int = 4,
mode: QuantizationMode = .affine
mode: QuantizationMode = .affine,
globalScale: MLXArray? = nil
) {
self.groupSize = groupSize
self.bits = bits
self.mode = mode
self.globalScale = globalScale

if globalScale != nil {
preconditionSupportsGlobalScale()
}

let (quantizedWeight, scales, biases) = MLX.quantized(
weight, groupSize: groupSize, bits: bits, mode: mode)
weight, groupSize: groupSize, bits: bits, mode: mode, globalScale: globalScale)

self.scales = scales
self.biases = biases
Expand All @@ -316,13 +367,15 @@ open class QuantizedLinear: Linear, Quantized {
public init(
weight: MLXArray, bias: MLXArray? = nil, scales: MLXArray, biases: MLXArray?,
groupSize: Int, bits: Int,
mode: QuantizationMode = .affine
mode: QuantizationMode = .affine,
globalScale: MLXArray? = nil
) {
self.groupSize = groupSize
self.bits = bits
self.mode = mode
self.scales = scales
self.biases = biases
self.globalScale = globalScale
super.init(weight: weight, bias: bias)
}

Expand All @@ -334,20 +387,32 @@ open class QuantizedLinear: Linear, Quantized {
}

open override func callAsFunction(_ x: MLXArray) -> MLXArray {
var x = quantizedMM(
x,
weight,
scales: scales,
biases: biases,
transpose: true,
groupSize: groupSize,
bits: bits,
mode: mode
)
var result: MLXArray
if globalScale != nil {
preconditionSupportsGlobalScale()
result = matmul(x, dequantizedWeight.T)
} else {
result = quantizedMM(
x,
weight,
scales: scales,
biases: biases,
transpose: true,
groupSize: groupSize,
bits: bits,
mode: mode
)
}
if let bias {
x = x + bias
result = result + bias
}
return x
return result
}

private var dequantizedWeight: MLXArray {
dequantized(
weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits, mode: mode,
globalScale: globalScale)
}

/// Returns a QuantizedLinear layer that applies the same linear transformation up to the quantization error.
Expand Down
33 changes: 33 additions & 0 deletions Tests/MLXTests/QuantizationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,37 @@ class QuantizationTests: XCTestCase {
let quantized = QuantizedLinear(64, 64, groupSize: 32, bits: 4, mode: .mxfp4)
XCTAssertNil(quantized.biases)
}

func testQuantizedLinearStoresGlobalScale() {
let globalScale = MLXArray(1.0, dtype: .float32)
let quantized = QuantizedLinear(
weight: MLXArray.zeros([8, 4], dtype: .uint32),
bias: nil,
scales: MLXArray.ones([8, 4], dtype: .uint8),
biases: nil,
groupSize: 16,
bits: 4,
mode: .nvfp4,
globalScale: globalScale)

XCTAssertNotNil(quantized.globalScale)
XCTAssertEqual(quantized.globalScale?.dtype, .float32)
XCTAssertNotNil(quantized.parameters()["globalScale"])
}

func testQuantizedEmbeddingStoresGlobalScale() {
let globalScale = MLXArray(1.0, dtype: .float32)
let quantized = QuantizedEmbedding(
weight: MLXArray.zeros([8, 2], dtype: .uint32),
scales: MLXArray.ones([8, 2], dtype: .uint8),
biases: nil,
groupSize: 16,
bits: 4,
mode: .nvfp4,
globalScale: globalScale)

XCTAssertNotNil(quantized.globalScale)
XCTAssertEqual(quantized.globalScale?.dtype, .float32)
XCTAssertNotNil(quantized.parameters()["globalScale"])
}
}