MaterializedArray is a Sendable MLXArray#418
Conversation
| import Foundation | ||
| import Numerics | ||
|
|
||
| public final class MaterializedArray: MLXArray, @unchecked Sendable { |
There was a problem hiding this comment.
A new subtype of MLXArray that is Sendable
|
|
||
| public final class MaterializedArray: MLXArray, @unchecked Sendable { | ||
|
|
||
| init(materialized ctx: consuming mlx_array) { |
There was a problem hiding this comment.
Create using materialize(x) or x.materialized()
|
|
||
| import Cmlx | ||
| import Foundation | ||
| import Numerics |
There was a problem hiding this comment.
Ignore this -- I had to move them to be part of MLXArray (not an extension) so that I could override them
| import Numerics | ||
|
|
||
| public final class MLXArray { | ||
| public class MLXArray: ExpressibleByArrayLiteral { |
There was a problem hiding this comment.
- remove
final-- it is still closed outside the package - I don't think we were seeing any particular benefit from the final
- add ExpressibleByArrayLiteral to the main type so that subclasses can override
| self.init(ctx) | ||
| } | ||
|
|
||
| // MARK: - Inits |
There was a problem hiding this comment.
This is the body of MLXArray+Init.swift -- it is moved here because a subclass cannot override an init (or maybe even a method) that is not in the main class (vs extensions).
| public let weight: MLXArray | ||
| public let bias: MLXArray? | ||
| @ParameterInfo public var weight: MLXArray | ||
| @ParameterInfo public var bias: MLXArray? |
There was a problem hiding this comment.
By making these @ParameterInfo we can replace the MLXArray with a MaterializedArray -- see the tests.
Note: this isn't required but it seemed like it might be useful.
|
|
||
| import MLX | ||
|
|
||
| open class MaterializedModule<LayerType: Module>: Module, @unchecked Sendable { |
There was a problem hiding this comment.
A container for a Module that is Sendable.
| } | ||
| } | ||
|
|
||
| extension MaterializedModule where LayerType: UnaryLayer { |
There was a problem hiding this comment.
We can use this pattern e.g. for LanguageModel in mlx-swift-lm
| func update( | ||
| parameters: ModuleParameters, verify: VerifyUpdate, path: [String] = [], | ||
| modulePath: [String] = [], | ||
| mutate: (Module, String, MLXArray, MLXArray) -> Void |
There was a problem hiding this comment.
Normally we update the backing point in an MLXArray. For materialize() I want to replace the MLXArray instance with a MaterializedArray.
| let x = MLXArray(10) + 5 | ||
| let m = materialize(x) | ||
| let t = Task { | ||
| print(m + 3) |
There was a problem hiding this comment.
Tada, pass MLXArray (materialized) between isolation contexts! The compiler would give an error if we did this with x.
|
|
||
| let t = Task { | ||
| let i = MLXRandom.normal([10, 10]) | ||
| let r = lm(i) |
There was a problem hiding this comment.
The same thing with a Module.
| @@ -1,4 +1,4 @@ | |||
| // swift-tools-version: 5.12 | |||
| // swift-tools-version: 6.2 | |||
There was a problem hiding this comment.
Time to turn on the stricter concurrency checks.
f18e7c2 to
2e0965e
Compare
- once an array has been evaluated we can use it as Sendable - a subtype of MLXArray that is Sendable
2e0965e to
18106c8
Compare
- adopt changes from ml-explore/mlx-swift#418 - we don't need private box types -- the technique becomes general - it also opens up some potential for synchronous evaluation
| open class MaterializedModule<LayerType: Module>: Module, @unchecked Sendable { | ||
|
|
||
| /// Usable by extensions to implement `callAsFunction()` -- anyone else, DO NOT USE. | ||
| public let _base: LayerType |
There was a problem hiding this comment.
Question: would it make sense to avoid exposing this as public or at least rename it to something like _unsafeBase? Since callers can mutate the wrapper module through this reference.
There was a problem hiding this comment.
Good question! It has to be public because:
- the class can't be subclassed as it is final/Sendable (the former being a requirement for the latter)
- since it can't be subclassed it has to be extended and the extensions can only access public members
But renaming it _unsafeBase is a good idea!
I have a staged commit where it marks the held model as immutable -- that will give additional runtime protection.
The integration in mlx-swift-lm for Embedders was good but perhaps too simple. I need to do the same for LLM/VLM and see what pops up.
| @Test | ||
| func testMaterializedLinear() async throws { | ||
| let l = Linear(10, 10) | ||
| let lm = try MaterializedModule(l) |
There was a problem hiding this comment.
Nit: MaterializedModule's init is not throwing, so this test can drop both try and throws.
There was a problem hiding this comment.
Yes, it was at one point but not now.
- remove try
| import MLX | ||
| import MLXNN | ||
| import Testing | ||
|
|
There was a problem hiding this comment.
Suggestion: It would be nice to add a unit test for MaterializedModule.parameters(), so it verify the wrapper still exposes the materialized parameter tree.
- adopt changes from ml-explore/mlx-swift#418 - we don't need private box types -- the technique becomes general - it also opens up some potential for synchronous evaluation
Proposed changes
Looking for feedback on this. I will add documentation once I test it a bit. Check out the tests to see how it works.
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes