From 6c48e536740572bd446066bdc953fed4d71b702c Mon Sep 17 00:00:00 2001 From: David Koski Date: Sun, 7 Jun 2026 22:18:45 -0700 Subject: [PATCH 1/6] MaterializedArray is a Sendable MLXArray - once an array has been evaluated we can use it as Sendable - a subtype of MLXArray that is Sendable --- Package.swift | 2 +- Source/MLX/MLXArray+Indexing.swift | 3 +- Source/MLX/MLXArray+Init.swift | 36 ----- Source/MLX/MLXArray.swift | 80 ++++++++--- Source/MLX/MaterializedArray.swift | 75 ++++++++++ Source/MLX/Transforms+Eval.swift | 16 +++ Source/MLXNN/Linear.swift | 4 +- Source/MLXNN/MaterializedModule.swift | 78 +++++++++++ Source/MLXNN/Module.swift | 93 +++++++++++-- Tests/MLXTests/MaterializedTests.swift | 186 +++++++++++++++++++++++++ 10 files changed, 501 insertions(+), 72 deletions(-) create mode 100644 Source/MLX/MaterializedArray.swift create mode 100644 Source/MLXNN/MaterializedModule.swift create mode 100644 Tests/MLXTests/MaterializedTests.swift diff --git a/Package.swift b/Package.swift index 17a4178f..2a5a8c95 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version: 5.12 +// swift-tools-version: 6.2 // The swift-tools-version declares the minimum version of Swift required to build this package. // Copyright © 2024 Apple Inc. diff --git a/Source/MLX/MLXArray+Indexing.swift b/Source/MLX/MLXArray+Indexing.swift index 9a891b5d..bca17018 100644 --- a/Source/MLX/MLXArray+Indexing.swift +++ b/Source/MLX/MLXArray+Indexing.swift @@ -388,8 +388,7 @@ extension MLXArray { var result = mlx_array_new() mlx_scatter( &result, self.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx) - mlx_array_set(&self.ctx, result) - mlx_array_free(result) + self._updateInternal(MLXArray(result)) return } else { self._updateInternal(update) diff --git a/Source/MLX/MLXArray+Init.swift b/Source/MLX/MLXArray+Init.swift index 01a8daa2..632c69e5 100644 --- a/Source/MLX/MLXArray+Init.swift +++ b/Source/MLX/MLXArray+Init.swift @@ -593,39 +593,3 @@ extension MLXArray { } } - -// MARK: - Expressible by literals - -extension MLXArray: ExpressibleByArrayLiteral { - - // Note: MLXArray does not implement ExpressibleByFloatLiteral etc. because - // we want to create arrays in the context of the other arrays. For example: - // - // let x = MLXArray(1.5, dtype: .float16) - // let r = x + 2.5 - // - // We expect r to have a dtype of float16. See ``ScalarOrArray``. - - /// Initializer allowing creation of 1d `MLXArray` from an array literal. - /// - /// ```swift - /// let a: MLXArray = [1, 2, 3] - /// ``` - /// - /// This is convenient for methods that have `MLXArray` parameters: - /// - /// ```swift - /// print(array.take([1, 2, 3], axis: 0)) - /// ``` - /// - /// ### See Also - /// - - public convenience init(arrayLiteral elements: Int32...) { - let ctx = elements.withUnsafeBufferPointer { ptr in - let shape = [Int32(elements.count)] - return mlx_array_new_data( - ptr.baseAddress!, shape, Int32(shape.count), Int32.dtype.cmlxDtype) - } - self.init(ctx) - } -} diff --git a/Source/MLX/MLXArray.swift b/Source/MLX/MLXArray.swift index 9e6148e0..cf21aad1 100644 --- a/Source/MLX/MLXArray.swift +++ b/Source/MLX/MLXArray.swift @@ -4,7 +4,7 @@ import Cmlx import Foundation import Numerics -public final class MLXArray { +public class MLXArray: ExpressibleByArrayLiteral { /// Internal pointer to the mlx-c wrapper on `mlx::core::array`, used with `Cmlx` interop. public internal(set) var ctx: mlx_array @@ -19,6 +19,37 @@ public final class MLXArray { self.ctx = ctx } + // Note: MLXArray does not implement ExpressibleByFloatLiteral etc. because + // we want to create arrays in the context of the other arrays. For example: + // + // let x = MLXArray(1.5, dtype: .float16) + // let r = x + 2.5 + // + // We expect r to have a dtype of float16. See ``ScalarOrArray``. + + /// Initializer allowing creation of 1d `MLXArray` from an array literal. + /// + /// ```swift + /// let a: MLXArray = [1, 2, 3] + /// ``` + /// + /// This is convenient for methods that have `MLXArray` parameters: + /// + /// ```swift + /// print(array.take([1, 2, 3], axis: 0)) + /// ``` + /// + /// ### See Also + /// - + required public convenience init(arrayLiteral elements: Int32...) { + let ctx = elements.withUnsafeBufferPointer { ptr in + let shape = [Int32(elements.count)] + return mlx_array_new_data( + ptr.baseAddress!, shape, Int32(shape.count), Int32.dtype.cmlxDtype) + } + self.init(ctx) + } + /// return the equivalent of a `.none` MLXArray (for the C API). /// /// Not called `.none` to avoid ambiguity with `Optional`. This can be used @@ -37,7 +68,7 @@ public final class MLXArray { } /// Number of bytes per element - public var itemSize: Int { mlx_array_itemsize(ctx) } + final public var itemSize: Int { mlx_array_itemsize(ctx) } /// Total number of elements in the array /// @@ -46,7 +77,7 @@ public final class MLXArray { /// print(array.size) /// // 12 /// ``` - public var size: Int { mlx_array_size(ctx) } + final public var size: Int { mlx_array_size(ctx) } /// Number of elements in the 0th dimension. /// @@ -62,10 +93,10 @@ public final class MLXArray { /// ... /// } /// ``` - public var count: Int { dim(0) } + final public var count: Int { dim(0) } /// Number of bytes in the array. - public var nbytes: Int { mlx_array_nbytes(ctx) } + final public var nbytes: Int { mlx_array_nbytes(ctx) } /// Number of dimensions in the array. /// @@ -74,7 +105,7 @@ public final class MLXArray { /// print(array.ndim) /// // 2 /// ``` - public var ndim: Int { mlx_array_ndim(ctx) } + final public var ndim: Int { mlx_array_ndim(ctx) } /// Data type of the elements in the array. /// @@ -83,7 +114,7 @@ public final class MLXArray { /// print(array.dtype) /// // .int64 (aka Int.dtype) /// ``` - public var dtype: DType { DType(mlx_array_dtype(ctx)) } + final public var dtype: DType { DType(mlx_array_dtype(ctx)) } /// Dimensions of the array. /// @@ -92,7 +123,7 @@ public final class MLXArray { /// print(array.shape) /// // [3, 4] /// ``` - public var shape: [Int] { + final public var shape: [Int] { let ndim = mlx_array_ndim(ctx) guard ndim > 0 else { return [] } let cShape = mlx_array_shape(ctx)! @@ -104,7 +135,7 @@ public final class MLXArray { /// ```swift /// let (w, h) = array.shape2 /// ``` - public var shape2: (Int, Int) { + final public var shape2: (Int, Int) { let ndim = mlx_array_ndim(ctx) precondition(ndim == 2) let cShape = mlx_array_shape(ctx)! @@ -116,7 +147,7 @@ public final class MLXArray { /// ```swift /// let (w, h, c) = array.shape3 /// ``` - public var shape3: (Int, Int, Int) { + final public var shape3: (Int, Int, Int) { let ndim = mlx_array_ndim(ctx) precondition(ndim == 3) let cShape = mlx_array_shape(ctx)! @@ -128,7 +159,7 @@ public final class MLXArray { /// ```swift /// let (b, w, h, c) = array.shape4 /// ``` - public var shape4: (Int, Int, Int, Int) { + final public var shape4: (Int, Int, Int, Int) { let ndim = mlx_array_ndim(ctx) precondition(ndim == 4) let cShape = mlx_array_shape(ctx)! @@ -149,7 +180,7 @@ public final class MLXArray { /// Strides of the array backing. /// /// Note: this is only stable once the array is evaluated. - var internalStrides: [Int] { + final var internalStrides: [Int] { let ndim = mlx_array_ndim(ctx) guard ndim > 0 else { return [] } let strides = mlx_array_strides(ctx)! @@ -167,7 +198,7 @@ public final class MLXArray { /// // 4.5 /// let value: Float = array[1].item() /// ``` - public func item() -> T { + final public func item() -> T { item(T.self) } @@ -328,7 +359,7 @@ public final class MLXArray { /// // 4.5 /// let value = array[1].item(Float.self) /// ``` - public func item(_ type: T.Type) -> T { + final public func item(_ type: T.Type) -> T { precondition(self.size == 1) eval() @@ -466,7 +497,7 @@ public final class MLXArray { /// print(array.dim(1)) /// // 4 /// ``` - public func dim(_ dim: Int) -> Int { + final public func dim(_ dim: Int) -> Int { Int(mlx_array_dim(ctx, MLX.resolve(axis: dim, ndim: mlx_array_ndim(ctx)).int32)) } @@ -481,7 +512,7 @@ public final class MLXArray { /// print(array.dim(index)) /// // 4 /// ``` - func dim(_ dim: Int32) -> Int32 { + final func dim(_ dim: Int32) -> Int32 { mlx_array_dim(ctx, MLX.resolve(axis: Int(dim), ndim: mlx_array_ndim(ctx)).int32) } @@ -492,7 +523,7 @@ public final class MLXArray { /// /// ### See Also /// - - public func asType(_ type: DType, stream: StreamOrDevice = .default) -> MLXArray { + final public func asType(_ type: DType, stream: StreamOrDevice = .default) -> MLXArray { guard type != self.dtype else { return self } var result = mlx_array_new() mlx_astype(&result, ctx, type.cmlxDtype, stream.ctx) @@ -506,8 +537,9 @@ public final class MLXArray { /// /// ### See Also /// - - public func asType(_ type: (some HasDType).Type, stream: StreamOrDevice = .default) -> MLXArray - { + final public func asType( + _ type: (some HasDType).Type, stream: StreamOrDevice = .default + ) -> MLXArray { asType(type.dtype, stream: stream) } @@ -524,7 +556,7 @@ public final class MLXArray { /// - ``realPart(stream:)`` /// - ``imaginaryPart(stream:)`` /// - - public func asImaginary(stream: StreamOrDevice = .default) -> MLXArray { + final public func asImaginary(stream: StreamOrDevice = .default) -> MLXArray { precondition(!dtype.isComplex) let i = MLXArray(real: 0, imaginary: 1) return self * i @@ -534,7 +566,7 @@ public final class MLXArray { /// /// ### See Also /// - - public func realPart(stream: StreamOrDevice = .default) -> MLXArray { + final public func realPart(stream: StreamOrDevice = .default) -> MLXArray { precondition(dtype.isComplex) return asType(Float.self) } @@ -543,7 +575,7 @@ public final class MLXArray { /// /// ### See Also /// - - public func imaginaryPart(stream: StreamOrDevice = .default) -> MLXArray { + final public func imaginaryPart(stream: StreamOrDevice = .default) -> MLXArray { precondition(dtype.isComplex) let i = MLXArray(real: 0, imaginary: 1) return (self / i).asType(.float32) @@ -559,6 +591,10 @@ public final class MLXArray { } } + public func materialized() -> MaterializedArray { + MLX.materialize(self) + } + /// Replace the contents with a reference to a new array (INTERNAL). /// /// Note: this is an implementation detail and only visible because of the need to call it from diff --git a/Source/MLX/MaterializedArray.swift b/Source/MLX/MaterializedArray.swift new file mode 100644 index 00000000..03fcd211 --- /dev/null +++ b/Source/MLX/MaterializedArray.swift @@ -0,0 +1,75 @@ +// Copyright © 2026 Apple Inc. + +import Cmlx +import Foundation +import Numerics + +public final class MaterializedArray: MLXArray, @unchecked Sendable { + + init(materialized ctx: consuming mlx_array) { + super.init(ctx) + } + + @available(*, unavailable, message: "MaterializedArray can only be created via materialize()") + required public convenience init(arrayLiteral elements: Int32...) { + fatalError("unavailable") + } + + final public override func materialized() -> MaterializedArray { + self + } + + // MARK: - Update sealing + + @available(*, unavailable) + override public func _updateInternal(_ array: MLXArray) { + // Note that this might be called via: + // + // a[1] = b + // a += b + fatalError("unavailable") + } + + // MARK: - In place operators + + @available(*, unavailable) + public static func += (lhs: inout MaterializedArray, rhs: MLXArray) { + fatalError("unavailable") + } + + @available(*, unavailable) + public static func += (lhs: inout MaterializedArray, rhs: some ScalarOrArray) { + fatalError("unavailable") + } + + @available(*, unavailable) + public static func -= (lhs: inout MaterializedArray, rhs: MLXArray) { + fatalError("unavailable") + } + + @available(*, unavailable) + public static func -= (lhs: inout MaterializedArray, rhs: some ScalarOrArray) { + fatalError("unavailable") + } + + @available(*, unavailable) + public static func *= (lhs: inout MaterializedArray, rhs: MLXArray) { + fatalError("unavailable") + } + + @available(*, unavailable) + public static func *= (lhs: inout MaterializedArray, rhs: some ScalarOrArray) { + fatalError("unavailable") + } + + @available(*, unavailable) + public static func /= (lhs: inout MaterializedArray, rhs: MLXArray) { + fatalError("unavailable") + } + + @available(*, unavailable) + public static func /= (lhs: inout MaterializedArray, rhs: some ScalarOrArray) { + fatalError("unavailable") + } + +} diff --git a/Source/MLX/Transforms+Eval.swift b/Source/MLX/Transforms+Eval.swift index 46fe9c59..73e4d365 100644 --- a/Source/MLX/Transforms+Eval.swift +++ b/Source/MLX/Transforms+Eval.swift @@ -8,6 +8,22 @@ import Foundation /// call back into eval. let evalLock = NSRecursiveLock() +public func materialize(_ array: MLXArray) -> MaterializedArray { + eval(array) + var m = mlx_array_new() + mlx_array_set(&m, array.ctx) + return MaterializedArray(materialized: m) +} + +public func materialize(_ arrays: [MLXArray]) -> [MaterializedArray] { + eval(arrays) + return arrays.map { + var m = mlx_array_new() + mlx_array_set(&m, $0.ctx) + return MaterializedArray(materialized: m) + } +} + /// Evaluate one or more `MLXArray` /// /// ### See Also diff --git a/Source/MLXNN/Linear.swift b/Source/MLXNN/Linear.swift index dd52c422..df246766 100644 --- a/Source/MLXNN/Linear.swift +++ b/Source/MLXNN/Linear.swift @@ -69,8 +69,8 @@ public class Identity: Module, UnaryLayer { /// - ``Bilinear`` open class Linear: Module, UnaryLayer, Quantizable { - public let weight: MLXArray - public let bias: MLXArray? + @ParameterInfo public var weight: MLXArray + @ParameterInfo public var bias: MLXArray? open var shape: (Int, Int) { weight.shape2 diff --git a/Source/MLXNN/MaterializedModule.swift b/Source/MLXNN/MaterializedModule.swift new file mode 100644 index 00000000..3dc45890 --- /dev/null +++ b/Source/MLXNN/MaterializedModule.swift @@ -0,0 +1,78 @@ +// Copyright © 2026 Apple Inc. + +import MLX + +open class MaterializedModule: Module, @unchecked Sendable { + + let base: LayerType + + public init(_ base: consuming LayerType) throws { + self.base = base + try self.base.materialize() + + // force caching of accessors (buildCaches) + _ = base.items() + } + + override func materialize() throws { + // NOP + } + + @available(*, unavailable) + @discardableResult + open override func update( + parameters: ModuleParameters, verify: VerifyUpdate, path: [String] = [], + modulePath: [String] = [] + ) throws -> Self { + fatalError("unavailable") + } + + @available(*, unavailable) + @discardableResult + open override func apply( + filter: (Module, String, ModuleItem) -> Bool = Module.filterValidParameters, + map: @escaping (MLXArray) -> MLXArray + ) -> Self { + fatalError("unavailable") + } + + @available(*, unavailable) + @discardableResult + open override func update( + modules: ModuleChildren, verify: VerifyUpdate, path: [String] = [], + modulePath: [String] = [] + ) throws -> Self { + fatalError("unavailable") + } + + @available(*, unavailable) + open override func updateModule(key: String, _ value: Any) throws { + fatalError("unavailable") + } + + @available(*, unavailable) + public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) + throws + { + fatalError("unavailable") + } + + @available(*, unavailable) + public override func unfreeze( + recursive: Bool = true, keys: [String]? = nil, strict: Bool = false + ) throws { + fatalError("unavailable") + } + + @available(*, unavailable) + public override func train(_ mode: Bool = true) { + fatalError("unavailable") + } + +} + +extension MaterializedModule where LayerType: UnaryLayer { + public func callAsFunction(_ x: MLXArray) -> MLXArray { + base(x) + } +} diff --git a/Source/MLXNN/Module.swift b/Source/MLXNN/Module.swift index 7cba9b12..466258a2 100644 --- a/Source/MLXNN/Module.swift +++ b/Source/MLXNN/Module.swift @@ -123,6 +123,8 @@ open class Module { if let (_, _, setter) = isModuleInfo(c.value) { setters[key] = setter + } else if let provider = c.value as? TypeErasedSetterProvider { + setters[key] = provider.typeErasedSetter() } } @@ -449,7 +451,18 @@ open class Module { parameters: ModuleParameters, verify: VerifyUpdate, path: [String] = [], modulePath: [String] = [] ) throws -> Self { + try update(parameters: parameters, verify: verify, path: path, modulePath: modulePath) { + m, k, a, v in + a._updateInternal(v) + } + return self + } + func update( + parameters: ModuleParameters, verify: VerifyUpdate, path: [String] = [], + modulePath: [String] = [], + mutate: (Module, String, MLXArray, MLXArray) -> Void + ) throws { let modulePath = modulePath + [describeType(self)] func apply( @@ -471,7 +484,7 @@ open class Module { path: path, modules: modulePath, expectedShape: p.shape, actualShape: newArray.shape) } - p._updateInternal(newArray) + mutate(self, key, p, newArray) case (.value(.parameters), .none): if Self.parameterIsValid(key) { @@ -517,12 +530,12 @@ open class Module { case (.value(.module(let module)), .dictionary(let values)): try module.update( parameters: NestedDictionary(values: values), verify: verify, path: path, - modulePath: modulePath) + modulePath: modulePath, mutate: mutate) case (.value(.module(let module)), .none): try module.update( parameters: NestedDictionary(), verify: verify, path: path, - modulePath: modulePath) + modulePath: modulePath, mutate: mutate) case (.none, .none), (.value(.none), .none), (.value(.other(_)), .none): break @@ -548,8 +561,34 @@ open class Module { throw UpdateError.unhandledKeys( path: path, modules: modulePath, keys: processed.sorted()) } + } - return self + func materialize() throws { + // bulk eval the parameters + eval(self.parameters()) + + // now convert to MaterializedArray (where possible) + let newParameters = filterMap( + filter: { _, _, _ in true }, + map: Self.mapParameters(map: { $0.materialized() as MLXArray })) + + try update(parameters: newParameters, verify: .none) { m, k, a, v in + if let setter = m._setters?[k] { + do { + // use the setter to replace the array + try setter.update(v) + } catch { + a._updateInternal(v) + } + } else { + a._updateInternal(v) + } + } + + // some of the properties were updated so rebuild the properties cache + visit { key, m in + m.buildCaches() + } } /// Called from ``update(parameters:verify:path:modulePath:)`` if a required parameter @@ -797,7 +836,7 @@ open class Module { if let setter = _setters?[key] { do { - try setter.updateModule(value) + try setter.update(value) } catch { throw UpdateError.needModuleInfo( "Unable to set modules for \(describeType(self)).\(key) -- maybe type mismatch: \(describeType(value)), \(error)" @@ -1400,7 +1439,7 @@ public enum ModuleValue { /// ### See Also /// - /// - ``ModuleInfo`` -@propertyWrapper public class ParameterInfo { +@propertyWrapper public class ParameterInfo: TypeErasedSetterProvider { var value: T? let key: String? @@ -1446,11 +1485,47 @@ public enum ModuleValue { // cannot check via unwapProperty -- see wrappedValue.set } + + struct Setter: TypeErasedSetter { + unowned var info: ParameterInfo + + func update(_ value: Any) throws { + if let value = value as? T { + info.value = value + } else if let value = value as? [MLXArray] { + // try to recast as a tuple, e.g. + // @ParameterInfo var x: (MLXArray, MLXArray) + + if value.count == 2, let values = (value[0], value[1]) as? T { + info.value = values + } else if value.count == 3, let values = (value[0], value[1], value[2]) as? T { + info.value = values + } else if value.count == 4, + let values = (value[0], value[1], value[2], value[3]) as? T + { + info.value = values + } else if value.count == 5, + let values = (value[0], value[1], value[2], value[3], value[4]) as? T + { + info.value = values + } else { + throw UpdateError.unableToCast(String(describing: T.self)) + } + } else { + throw UpdateError.unableToCast(String(describing: T.self)) + } + } + } + + fileprivate func typeErasedSetter() -> TypeErasedSetter { + Setter(info: self) + } + } /// Helper protocol for writing back through ``ModuleInfo``, e.g. via ``Module/update(modules:)`` private protocol TypeErasedSetter { - func updateModule(_ value: Any) throws + func update(_ value: Any) throws } private protocol TypeErasedSetterProvider { @@ -1561,7 +1636,7 @@ private protocol TypeErasedSetterProvider { struct Setter: TypeErasedSetter { unowned var info: ModuleInfo - func updateModule(_ value: Any) throws { + func update(_ value: Any) throws { if let value = value as? T { info.module = value } else if let value = value as? [Module] { @@ -1577,7 +1652,7 @@ private protocol TypeErasedSetterProvider { { info.module = values } else if value.count == 5, - let values = (value[0], value[1], value[2], value[4], value[5]) as? T + let values = (value[0], value[1], value[2], value[3], value[4]) as? T { info.module = values } else { diff --git a/Tests/MLXTests/MaterializedTests.swift b/Tests/MLXTests/MaterializedTests.swift new file mode 100644 index 00000000..db83b435 --- /dev/null +++ b/Tests/MLXTests/MaterializedTests.swift @@ -0,0 +1,186 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX +import MLXNN +import Testing + +@Test +func testMaterialize() async { + // a materialized array is Sendable + let x = MLXArray(10) + 5 + let m = materialize(x) + let t = Task { + print(m + 3) + } + _ = await t.result +} + +@Test +func testCompileMaterialized() async { + // compile manipulates the input arrays (tracers) + // make sure MaterializedArray doesn't run into problems here + func f(_ a: MLXArray, _ b: MLXArray) -> MLXArray { + square(a * b) + } + + let c = compile(f) + + let i1 = MLXRandom.normal([20, 20]).materialized() + let i2 = MLXRandom.normal([20, 20]).materialized() + + let t = Task { + let s = sum(c(i1, i2)) + let s2 = sum(f(i1, i2)) + #expect(s.allClose(s2).item(Bool.self)) + print(s, s2) + } + _ = await t.result +} + +@Test +func testMaterializedLinear() async throws { + let l = Linear(10, 10) + let lm = try MaterializedModule(l) + + // this will have been materialized in the call + #expect(l.weight is MaterializedArray) + + let t = Task { + let i = MLXRandom.normal([10, 10]) + let r = lm(i) + print(sum(r)) + print(lm) + } + _ = await t.result + +} + +@Test +func testMaterializedMultithreadedEval() async throws { + // Exercise concurrent evaluation: produce MaterializedArrays on the + // main task, fan them out to many child tasks, and have each task do + // its own work (creates new arrays, evals, reads scalars) at the same + // time. This stresses the evalLock and the Sendable contract of + // MaterializedArray. + + let taskCount = 16 + let iterations = 8 + let shape = [32, 32] + + // shared inputs created up-front and materialized so they can cross + // task boundaries + let a = MLXRandom.normal(shape).materialized() + let b = MLXRandom.normal(shape).materialized() + + // expected reference value computed serially on the main task + let expected = sum(square(a * b)).item(Float.self) + + try await withThrowingTaskGroup(of: Float.self) { group in + for i in 0 ..< taskCount { + group.addTask { + var last: Float = 0 + for _ in 0 ..< iterations { + // mix the shared inputs with task-local arrays so that + // each task is producing fresh graphs and evaluating + // them concurrently + let local = MLXRandom.normal(shape) + let r = sum(square(a * b) + (local - local)) + last = r.item(Float.self) + + // also exercise materialize from inside a task + let m = (a + MLXArray(Float(i))).materialized() + _ = (m - MLXArray(Float(i))).sum().item(Float.self) + } + return last + } + } + + for try await value in group { + #expect(abs(value - expected) < 1e-2) + } + } +} + +@Test +func testMaterializedHighContention() async throws { + // High-contention variant: many tasks producing and consuming + // MaterializedArrays through a shared actor-protected pool. Every + // task in a tight loop pulls two arrays from the pool, computes a + // new one, materializes it (forcing an eval), reads a scalar, and + // pushes the result back into the pool. Lots of arrays flowing + // between tasks, lots of concurrent eval calls hammering the + // evalLock. + + actor Pool { + var arrays: [MaterializedArray] + init(_ initial: [MaterializedArray]) { self.arrays = initial } + + func take() -> MaterializedArray { + arrays.randomElement()! + } + + func replace(_ a: MaterializedArray) { + arrays[Int.random(in: 0 ..< arrays.count)] = a + } + + func snapshot() -> [MaterializedArray] { + arrays + } + } + + let shape = [16, 16] + let poolSize = 16 + let taskCount = 32 + let iterations = 50 + + let initial = (0 ..< poolSize).map { _ in + MLXRandom.normal(shape).materialized() + } + let pool = Pool(initial) + + // sum + count tracker so we can assert no task silently dropped work + actor Counter { + var n = 0 + func bump() { n += 1 } + func value() -> Int { n } + } + let counter = Counter() + + try await withThrowingTaskGroup(of: Void.self) { group in + for t in 0 ..< taskCount { + group.addTask { + for k in 0 ..< iterations { + // pull two arrays from the shared pool — these may be + // referenced concurrently by other tasks at the same time + let a = await pool.take() + let b = await pool.take() + + // build a new graph, materialize it (forces eval inside + // the task), and read a scalar (forces another eval). + // tanh keeps values bounded to [-1, 1] so the pool can + // recycle results across iterations without blowing up. + let mixin = MLXArray(Float(t * 31 + k) * 1e-3) + let r = tanh((a * b) + (a - b) + mixin).materialized() + let s = r.sum().item(Float.self) + #expect(s.isFinite) + + // put the result back so other tasks see fresh arrays + await pool.replace(r) + await counter.bump() + } + } + } + try await group.waitForAll() + } + + #expect(await counter.value() == taskCount * iterations) + + // every entry in the final pool should still be a valid, finite array + let final = await pool.snapshot() + #expect(final.count == poolSize) + for a in final { + #expect(a.shape == shape) + #expect(sum(a).item(Float.self).isFinite) + } +} From 18106c8b266f77d2935278ae3ad01b21ca2b5809 Mon Sep 17 00:00:00 2001 From: David Koski Date: Sun, 7 Jun 2026 22:34:35 -0700 Subject: [PATCH 2/6] add documentation --- Source/MLX/MLXArray.swift | 13 ++++++ Source/MLX/MaterializedArray.swift | 33 +++++++++++++ Source/MLX/Transforms+Eval.swift | 24 ++++++++++ Source/MLXNN/MaterializedModule.swift | 67 +++++++++++++++++++++++++++ 4 files changed, 137 insertions(+) diff --git a/Source/MLX/MLXArray.swift b/Source/MLX/MLXArray.swift index cf21aad1..794f0ad1 100644 --- a/Source/MLX/MLXArray.swift +++ b/Source/MLX/MLXArray.swift @@ -591,6 +591,19 @@ public class MLXArray: ExpressibleByArrayLiteral { } } + /// Force evaluation and return a ``MaterializedArray`` snapshot of `self`. + /// + /// The returned array is fully evaluated, immutable, and `Sendable` — see + /// ``MaterializedArray`` for the guarantees and trade-offs. Use this when + /// you need to hand an array across task or actor boundaries, or when you + /// need a stable reference that cannot be mutated by other code holding + /// the original `MLXArray`. + /// + /// This is a thin convenience over ``materialize(_:)->MaterializedArray``. + /// + /// ### See Also + /// - ``MaterializedArray`` + /// - ``materialize(_:)->MaterializedArray`` public func materialized() -> MaterializedArray { MLX.materialize(self) } diff --git a/Source/MLX/MaterializedArray.swift b/Source/MLX/MaterializedArray.swift index 03fcd211..c902cb74 100644 --- a/Source/MLX/MaterializedArray.swift +++ b/Source/MLX/MaterializedArray.swift @@ -4,6 +4,39 @@ import Cmlx import Foundation import Numerics +/// A fully-evaluated, immutable ``MLXArray`` that is safe to share across +/// concurrency domains. +/// +/// `MLXArray` is normally lazy: it is a handle to a node in a computation +/// graph that is not realized until something forces it (a scalar read, +/// an ``eval(_:)-(MLXArray...)`` call, etc.). These unrealized arrays are +/// not thread safe -- they require mutation for their evaluation. +/// +/// `MaterializedArray` is a snapshot that closes that gap: +/// +/// - The contents are evaluated at the moment of construction, so no further +/// graph work is pending. +/// - You can only create an instance via ``MLXArray/materialized()`` +/// or ``materialize(_:)->MaterializedArray`` +/// - Mutation methods are marked as unavailable and will `fatalError` +/// if you somehow manage to call them. +/// - It is declared `@unchecked Sendable` and may be passed freely between +/// tasks, actors, and other concurrency boundaries. +/// +/// Construction is intentionally narrow. Obtain one via: +/// +/// ```swift +/// let m1 = a.materialized() +/// let m2 = materialize(a) +/// ``` +/// +/// A `MaterializedArray` is itself an ``MLXArray`` and can be used anywhere +/// an `MLXArray` is accepted. Operations involving one still produce +/// ordinary (lazy) `MLXArray` results — only the snapshot itself is frozen. +/// +/// ### See Also +/// - ``MLXArray/materialized()`` +/// - ``materialize(_:)->MaterializedArray`` public final class MaterializedArray: MLXArray, @unchecked Sendable { init(materialized ctx: consuming mlx_array) { diff --git a/Source/MLX/Transforms+Eval.swift b/Source/MLX/Transforms+Eval.swift index 73e4d365..86a06263 100644 --- a/Source/MLX/Transforms+Eval.swift +++ b/Source/MLX/Transforms+Eval.swift @@ -8,6 +8,20 @@ import Foundation /// call back into eval. let evalLock = NSRecursiveLock() +/// Evaluate `array` and return a ``MaterializedArray`` snapshot of its contents. +/// +/// The returned array is fully evaluated, immutable, and `Sendable`, so it can +/// be passed across task and actor boundaries. See ``MaterializedArray`` for +/// the full set of guarantees. +/// +/// `array` itself is unaffected — it remains the same lazy ``MLXArray`` it was +/// before the call — but because evaluation is forced, any pending graph work +/// behind it has been realized as a side effect. +/// +/// ### See Also +/// - ``MaterializedArray`` +/// - ``MLXArray/materialized()`` +/// - ``materialize(_:)`` public func materialize(_ array: MLXArray) -> MaterializedArray { eval(array) var m = mlx_array_new() @@ -15,6 +29,16 @@ public func materialize(_ array: MLXArray) -> MaterializedArray { return MaterializedArray(materialized: m) } +/// Evaluate `arrays` and return a ``MaterializedArray`` snapshot for each one. +/// +/// Equivalent to calling ``materialize(_:)`` on each element, but evaluates the +/// whole batch in a single ``eval(_:)-(Sequence)`` call so MLX can schedule the +/// work together. Prefer this overload when you need to materialize several +/// arrays at once — for example, the parameters of a model. +/// +/// ### See Also +/// - ``MaterializedArray`` +/// - ``materialize(_:)`` public func materialize(_ arrays: [MLXArray]) -> [MaterializedArray] { eval(arrays) return arrays.map { diff --git a/Source/MLXNN/MaterializedModule.swift b/Source/MLXNN/MaterializedModule.swift index 3dc45890..bee6f4cd 100644 --- a/Source/MLXNN/MaterializedModule.swift +++ b/Source/MLXNN/MaterializedModule.swift @@ -2,6 +2,73 @@ import MLX +/// A `Module` whose parameters have been materialized so that the whole +/// module is safe to share across concurrency domains. +/// +/// A normal ``Module`` is not `Sendable`: its parameters are ``MLXArray`` +/// instances, which are lazy and may be mutated in place during evaluation +/// or training. `MaterializedModule` wraps a base module, evaluates every +/// parameter, replaces each with a ``MaterializedArray``, and seals the +/// mutation surface so the wrapped module cannot be modified through this +/// reference. +/// +/// ## Construction and the `consuming` contract +/// +/// The initializer takes its base module as `consuming`: +/// +/// ```swift +/// let lm = try MaterializedModule(Linear(10, 10)) +/// ``` +/// +/// `consuming` expresses intent — **the caller must not retain or use the +/// original `base` reference after passing it in**. Because `Module` is a +/// reference type, Swift cannot enforce this for you: a caller who keeps a +/// reference and later mutates it will violate the `Sendable` invariant +/// that `MaterializedModule` relies on. In other words, `Sendable` here is +/// a contract you can follow rather than a guarantee the compiler proves. +/// The recommended pattern is to construct the base module inline at the +/// `MaterializedModule(...)` call site, as shown above, so no other +/// reference exists. +/// +/// ## What is sealed +/// +/// The following `Module` operations are marked `@available(*, unavailable)` +/// on `MaterializedModule` and will trap if called: +/// +/// - `update(parameters:...)` and `update(modules:...)` +/// - `updateModule(key:_:)` +/// - `apply(filter:map:)` +/// - `freeze(...)` / `unfreeze(...)` +/// - `train(_:)` +/// +/// ## Calling the wrapped module +/// +/// `MaterializedModule` does not itself know how to invoke `base`; that is +/// added per-layer-shape via an extension that constrains `LayerType`. +/// For example, every ``UnaryLayer`` already supports being called with a +/// single `MLXArray`, so the package ships: +/// +/// ```swift +/// extension MaterializedModule where LayerType: UnaryLayer { +/// public func callAsFunction(_ x: MLXArray) -> MLXArray { +/// base(x) +/// } +/// } +/// ``` +/// +/// Layers with different call signatures can be wrapped the same way. For a +/// transformer-style block that takes a tensor and an attention mask: +/// +/// ```swift +/// extension MaterializedModule where LayerType: AttentionBlock { +/// public func callAsFunction(_ x: MLXArray, mask: MLXArray?) -> MLXArray { +/// base(x, mask: mask) +/// } +/// } +/// ``` +/// +/// The pattern is always the same: constrain `LayerType` to the protocol or +/// concrete type that exposes the call you want, then forward to `base`. open class MaterializedModule: Module, @unchecked Sendable { let base: LayerType From ba24c6dd00ee398310cdcf59b5125f21596801b5 Mon Sep 17 00:00:00 2001 From: David Koski Date: Sun, 7 Jun 2026 22:41:34 -0700 Subject: [PATCH 3/6] fix build issue --- Source/MLXNN/MaterializedModule.swift | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Source/MLXNN/MaterializedModule.swift b/Source/MLXNN/MaterializedModule.swift index bee6f4cd..90978638 100644 --- a/Source/MLXNN/MaterializedModule.swift +++ b/Source/MLXNN/MaterializedModule.swift @@ -78,7 +78,9 @@ open class MaterializedModule: Module, @unchecked Sendable { try self.base.materialize() // force caching of accessors (buildCaches) - _ = base.items() + _ = self.base.items() + super.init() + _ = self.items() } override func materialize() throws { From 0ceeada3691aabd4b257e93f6fbcf603e11b9227 Mon Sep 17 00:00:00 2001 From: David Koski Date: Sun, 7 Jun 2026 22:44:28 -0700 Subject: [PATCH 4/6] no cross module links --- Source/MLXNN/MaterializedModule.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Source/MLXNN/MaterializedModule.swift b/Source/MLXNN/MaterializedModule.swift index 90978638..26009d55 100644 --- a/Source/MLXNN/MaterializedModule.swift +++ b/Source/MLXNN/MaterializedModule.swift @@ -5,10 +5,10 @@ import MLX /// A `Module` whose parameters have been materialized so that the whole /// module is safe to share across concurrency domains. /// -/// A normal ``Module`` is not `Sendable`: its parameters are ``MLXArray`` +/// A normal ``Module`` is not `Sendable`: its parameters are `MLXArray` /// instances, which are lazy and may be mutated in place during evaluation /// or training. `MaterializedModule` wraps a base module, evaluates every -/// parameter, replaces each with a ``MaterializedArray``, and seals the +/// parameter, replaces each with a `MaterializedArray`, and seals the /// mutation surface so the wrapped module cannot be modified through this /// reference. /// From 500ef2b6e5b898dc5b8b3455542100d60bba5c51 Mon Sep 17 00:00:00 2001 From: David Koski Date: Sun, 7 Jun 2026 23:24:23 -0700 Subject: [PATCH 5/6] base needs to be public so other packages can extend --- Source/MLXNN/MaterializedModule.swift | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/Source/MLXNN/MaterializedModule.swift b/Source/MLXNN/MaterializedModule.swift index 26009d55..0b16a817 100644 --- a/Source/MLXNN/MaterializedModule.swift +++ b/Source/MLXNN/MaterializedModule.swift @@ -51,7 +51,7 @@ import MLX /// ```swift /// extension MaterializedModule where LayerType: UnaryLayer { /// public func callAsFunction(_ x: MLXArray) -> MLXArray { -/// base(x) +/// _base(x) /// } /// } /// ``` @@ -71,14 +71,16 @@ import MLX /// concrete type that exposes the call you want, then forward to `base`. open class MaterializedModule: Module, @unchecked Sendable { - let base: LayerType + /// Usable by extensions to implement `callAsFunction()` -- anyone else, DO NOT USE. + public let _base: LayerType public init(_ base: consuming LayerType) throws { - self.base = base - try self.base.materialize() + self._base = base + try self._base.materialize() - // force caching of accessors (buildCaches) - _ = self.base.items() + // force caching of accessors (buildCaches) as + // these are not thread safe + _ = self._base.items() super.init() _ = self.items() } @@ -142,6 +144,6 @@ open class MaterializedModule: Module, @unchecked Sendable { extension MaterializedModule where LayerType: UnaryLayer { public func callAsFunction(_ x: MLXArray) -> MLXArray { - base(x) + _base(x) } } From a3fde72ad9bd254f2f08f8d5128d9c597eef595b Mon Sep 17 00:00:00 2001 From: David Koski Date: Tue, 9 Jun 2026 15:41:23 -0700 Subject: [PATCH 6/6] relax the throw --- Source/MLXNN/MaterializedModule.swift | 6 +++--- Source/MLXNN/Module.swift | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/Source/MLXNN/MaterializedModule.swift b/Source/MLXNN/MaterializedModule.swift index 0b16a817..f393365c 100644 --- a/Source/MLXNN/MaterializedModule.swift +++ b/Source/MLXNN/MaterializedModule.swift @@ -74,9 +74,9 @@ open class MaterializedModule: Module, @unchecked Sendable { /// Usable by extensions to implement `callAsFunction()` -- anyone else, DO NOT USE. public let _base: LayerType - public init(_ base: consuming LayerType) throws { + public init(_ base: consuming LayerType) { self._base = base - try self._base.materialize() + self._base.materialize() // force caching of accessors (buildCaches) as // these are not thread safe @@ -85,7 +85,7 @@ open class MaterializedModule: Module, @unchecked Sendable { _ = self.items() } - override func materialize() throws { + override func materialize() { // NOP } diff --git a/Source/MLXNN/Module.swift b/Source/MLXNN/Module.swift index 466258a2..7ca087c1 100644 --- a/Source/MLXNN/Module.swift +++ b/Source/MLXNN/Module.swift @@ -563,7 +563,7 @@ open class Module { } } - func materialize() throws { + func materialize() { // bulk eval the parameters eval(self.parameters()) @@ -572,7 +572,9 @@ open class Module { filter: { _, _, _ in true }, map: Self.mapParameters(map: { $0.materialized() as MLXArray })) - try update(parameters: newParameters, verify: .none) { m, k, a, v in + // not verifying and setting with same value -- any + // errors are programming errors + try! update(parameters: newParameters, verify: .none) { m, k, a, v in if let setter = m._setters?[k] { do { // use the setter to replace the array