Use MaterializedArray for Sendable conformance #335
| private let context: SerialAccessContainer<EmbedderModelContext> | ||
| private let context: EmbedderModelContext |
We don't need this as the Context is now Sendable
| get async { | ||
| await context.read { $0.configuration } | ||
| } | ||
| context.configuration |
These become synchronous accessors.
| public func perform<R>( | ||
| _ action: @Sendable (EmbedderModelContext) throws -> R | ||
| ) rethrows -> R { |
And we can have a synchronous perform()
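The notes above (a Sendable context, synchronous accessors, a synchronous perform()) fit together roughly as follows. This is a minimal sketch, not the actual EmbedderModelContainer code: `DemoContext` and `DemoContainer` are hypothetical stand-ins, and it assumes the context is an immutable Sendable value.

```swift
// Hypothetical stand-in for EmbedderModelContext (assumption: immutable and Sendable).
struct DemoContext: Sendable {
    let configuration: String
    let dimensions: Int
}

// Because the context is Sendable, the container can store it directly,
// with no serial-access wrapper around it.
final class DemoContainer: Sendable {
    private let context: DemoContext

    init(context: DemoContext) {
        self.context = context
    }

    // Synchronous accessor: no `get async`, no `await` at the call site.
    var configuration: String { context.configuration }

    // Synchronous perform: `rethrows` means it only throws if `action` does.
    func perform<R>(_ action: @Sendable (DemoContext) throws -> R) rethrows -> R {
        try action(context)
    }
}

let demoContainer = DemoContainer(
    context: DemoContext(configuration: "demo", dimensions: 384))
let doubled = demoContainer.perform { $0.dimensions * 2 }
print(demoContainer.configuration, doubled)  // prints "demo 768"
```

Callers that previously wrote `await container.configuration` would drop the `await` under this shape.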
| @@ -27,52 +29,52 @@ import MLXLMCommon | |||
| /// } | |||
| /// ``` | |||
| public final class EmbedderModelContainer: Sendable { | |||
Is this type still needed? Maybe. EmbedderModelContext is a struct, so this gives reference semantics -- users share the instance. Potentially EmbedderModelContext could become a class and we remove this? It would need to be immutable to do so. Keeping it for now.
| } | ||
| } | ||
| extension MaterializedModule: EmbeddingModel, BaseLanguageModel where LayerType: EmbeddingModel { |
This is how we make MaterializedModule usable as an EmbeddingModel.
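The pattern in this extension is a conditional conformance: the wrapper conforms to a protocol only when the wrapped layer does, and forwards the calls. A minimal sketch with hypothetical names (`EmbeddingLike`, `DemoWrapper`, `MeanOfTokens`), none of which come from mlx-swift:

```swift
// Hypothetical protocol standing in for EmbeddingModel.
protocol EmbeddingLike {
    func embed(_ tokens: [Int]) -> [Float]
}

// Hypothetical layer: "embeds" a token list as its mean value.
struct MeanOfTokens: EmbeddingLike {
    func embed(_ tokens: [Int]) -> [Float] {
        guard !tokens.isEmpty else { return [] }
        return [Float(tokens.reduce(0, +)) / Float(tokens.count)]
    }
}

// Hypothetical generic wrapper standing in for MaterializedModule.
final class DemoWrapper<LayerType> {
    let layer: LayerType
    init(_ layer: LayerType) { self.layer = layer }
}

// The wrapper is an EmbeddingLike only when the wrapped layer is one,
// and it forwards the requirement to that layer.
extension DemoWrapper: EmbeddingLike where LayerType: EmbeddingLike {
    func embed(_ tokens: [Int]) -> [Float] { layer.embed(tokens) }
}

let wrapped: any EmbeddingLike = DemoWrapper(MeanOfTokens())
print(wrapped.embed([1, 2, 3]))  // prints "[2.0]"
```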
| public var model: any EmbeddingModel | ||
| public var model: any EmbeddingModel & Sendable | ||
| public var tokenizer: any Tokenizer | ||
| public let pooling: Pooling |
Potentially we want this to be var
| /// `Pooling` takes the sequence of hidden states from a transformer model and collapses them | ||
| /// into a single vector using strategies like mean, max, or token selection. | ||
| open class Pooling: Module { | ||
| public struct Pooling: Sendable { |
I don't see why this should be a Module. It isn't attached to the model itself (so, e.g., update() semantics are not needed). I changed it to a struct like this and everything built cleanly.
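To illustrate why pooling needs no Module machinery, here is a rough sketch of a stateless, Sendable pooling value. It is an assumption-laden toy: `DemoPooling` is a hypothetical name, it works on plain `[[Float]]` rows rather than MLXArray, and it assumes all rows have the same length.

```swift
// Hypothetical pooling value: configuration only, no trainable parameters.
struct DemoPooling: Sendable {
    enum Strategy: Sendable { case mean, max, first }
    let strategy: Strategy

    // Collapses a sequence of hidden states (one row per token) into one vector.
    func callAsFunction(_ hidden: [[Float]]) -> [Float] {
        guard let firstRow = hidden.first else { return [] }
        switch strategy {
        case .first:
            // Token selection: keep the first token's hidden state.
            return firstRow
        case .mean:
            var sums = [Float](repeating: 0, count: firstRow.count)
            for row in hidden {
                for (i, value) in row.enumerated() where i < sums.count {
                    sums[i] += value
                }
            }
            return sums.map { $0 / Float(hidden.count) }
        case .max:
            // Element-wise maximum across all rows.
            return hidden.dropFirst().reduce(firstRow) { best, row in
                zip(best, row).map { Swift.max($0, $1) }
            }
        }
    }
}

let states: [[Float]] = [[1, 2], [3, 4]]
print(DemoPooling(strategy: .mean)(states))  // prints "[2.0, 3.0]"
print(DemoPooling(strategy: .max)(states))   // prints "[3.0, 4.0]"
```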
| ], | ||
| dependencies: [ | ||
| .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.4")), | ||
| .package(url: "https://github.com/ml-explore/mlx-swift", branch: "materialized-array"), |
Pick up that branch for now.
| try await context.read { | ||
| try await action($0) | ||
| } | ||
| try await action(context) |
This seems to drop the serialization guarantee previously provided. The old SerialAccessContainer.read held an async mutex for the full duration of the closure, including suspension points, so concurrent perform calls could not overlap on the same model/tokenizer/pooling context.
With this direct call, two tasks can now enter perform concurrently and use the same underlying context values at the same time, while the type-level documentation still says the container “guarantees single threaded access.” Even if MaterializedModule makes the wrapped model sendable, that does not by itself preserve the previous exclusive-access contract for all work reachable through the context.
This may not be a real issue beyond a documentation inconsistency.
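If the exclusive-access contract does turn out to matter, one option for a synchronous perform() is to hold a lock for the duration of the closure. The sketch below is an assumption about how that could look, not the old SerialAccessContainer: `ExclusiveContainer` and `Probe` are hypothetical names, and an NSLock cannot be held across suspension points, so this does not cover the async overload.

```swift
import Foundation

// Hypothetical container that serializes synchronous access to its context.
final class ExclusiveContainer<Context: Sendable>: @unchecked Sendable {
    private let lock = NSLock()
    private let context: Context

    init(_ context: Context) { self.context = context }

    // Only one caller at a time runs `action`; others block on the lock.
    func perform<R>(_ action: (Context) throws -> R) rethrows -> R {
        lock.lock()
        defer { lock.unlock() }
        return try action(context)
    }
}

// Hypothetical probe that records how many callers were inside at once.
// It is only ever mutated while the container's lock is held.
final class Probe: @unchecked Sendable {
    var active = 0
    var maxActive = 0
}

func maxConcurrentEntries() -> Int {
    let probe = Probe()
    let container = ExclusiveContainer(probe)
    DispatchQueue.concurrentPerform(iterations: 64) { _ in
        container.perform { (p: Probe) -> Void in
            p.active += 1
            p.maxActive = max(p.maxActive, p.active)
            p.active -= 1
        }
    }
    return probe.maxActive
}

print(maxConcurrentEntries())  // prints "1": the lock serializes entries
```

Without the lock, the same probe could observe more than one caller inside perform() at once, which is the overlap described in the comment above.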
I think that is a good point -- it won't matter as much for the Embedders because they don't have state, but the KVCache, for example, is state. It needs either exclusive access or some kind of copy-in/copy-out setup.
For example in ChatSession it had to play some games with private classes and knowledge of what was safe:
    private func streamMap<R: Sendable>() {
        try await cache.update { cache in
            let model = await model.perform { context in
                SendableBox(context.model)
            }.consume()

This would get exclusive access to the KVCache and then "borrow" the model/weights -- it treated them as Sendable, but they were not represented that way in the type system. This had to be internal implementation because it required care to make sure it was thread safe.
My hope is that providing Sendable pieces would allow anyone to write this code and do it safely.
But your point about what is serial access vs not is important.
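The split described here (exclusive access for mutable state, freely shared Sendable weights) can be sketched without any box types. Everything below is hypothetical: `DemoModel` stands in for an immutable model, `DemoCache` for a KVCache owner, and an actor is just one way to get exclusive access for a synchronous closure.

```swift
// Hypothetical immutable "weights"; safe to share because nothing mutates it.
struct DemoModel: Sendable {
    let scale: Int

    // One step: reads the weights, mutates only the caller-provided cache.
    func step(_ token: Int, cache: inout [Int]) -> Int {
        cache.append(token)
        return scale * cache.reduce(0, +)
    }
}

// Hypothetical owner of the mutable state. The closure is synchronous, so it
// runs to completion on the actor without interleaving with other callers.
actor DemoCache {
    private var tokens: [Int] = []

    func update<R: Sendable>(
        _ body: @Sendable (inout [Int]) throws -> R
    ) rethrows -> R {
        try body(&tokens)
    }
}

func runDemo() async -> (Int, Int) {
    let model = DemoModel(scale: 2)
    let cache = DemoCache()
    // The model is captured directly because it is Sendable in the type system.
    let first = await cache.update { tokens in model.step(3, cache: &tokens) }
    let second = await cache.update { tokens in model.step(4, cache: &tokens) }
    return (first, second)
}

let (first, second) = await runDemo()
print(first, second)  // prints "6 14"
```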
| ) { | ||
| self.configuration = configuration | ||
| self.model = model | ||
| self.model = MaterializedModule(model) |
context.model is still exposed as any EmbeddingModel & Sendable; EmbeddingModel inherits BaseLanguageModel, which inherits the Module mutation APIs. This now stores a MaterializedModule, whose update/apply/train/freeze overrides trap with fatalError. Existing callers mutating the model through perform { try $0.model.update(...) } can still compile through the existential API but now crash at runtime.
Yes, I would prefer it to be typed as MaterializedModule and then these functions are unavailable. That would require either the ModelContext to be generic OR MaterializedModule to not be generic (probably subclasses).
I need to play around with that and see if one works better than the other. Just getting this to compile took some effort -- perhaps it can be done better. This change has to be done in the context of integrating it!
The fact that the Materialized variants violate Liskov substitution isn't great. Still, I think the benefits of documenting and obtaining Sendable in an interoperable way are probably worth it. I would appreciate feedback from people who use it!
I should note that MaterializedModule has these methods marked as unavailable so the compile time warning would be available if it were typed like that.
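The difference between the two typings can be sketched with toy classes. `BaseModule` and `FrozenModule` are hypothetical stand-ins for Module and MaterializedModule; the lines that would fail to compile or would trap are left commented out so the example runs.

```swift
// Hypothetical base class with a mutation API, standing in for Module.
class BaseModule {
    private(set) var weights: [Float] = [0]

    func update(_ newWeights: [Float]) {
        weights = newWeights
    }
}

// Hypothetical frozen subclass: the override is marked unavailable and traps.
final class FrozenModule: BaseModule {
    @available(*, unavailable, message: "frozen modules cannot be updated")
    override func update(_ newWeights: [Float]) {
        fatalError("frozen modules cannot be updated")
    }
}

let frozen = FrozenModule()
// frozen.update([1])   // rejected at compile time: 'update' is unavailable

// Once the static type is the base (or an existential), the annotation on the
// subclass is no longer visible to the compiler.
let erasedModule: BaseModule = frozen
// erasedModule.update([1])   // compiles, but would trap at runtime

print(type(of: erasedModule), erasedModule.weights)
```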
A non-generic MaterializedModule won't work -- it has to be final for the Sendable conformance to work. I can't make it a struct because you can't subclass a struct.
I will look at making ModelContext generic, but that will change things all over the place.
We may have to live with some limitations on MaterializedModule, but perhaps if it is encapsulated inside things like ModelContext it won't be an issue.
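For comparison, this is roughly what the generic-context option trades off. It is a sketch under stated assumptions: `DemoEmbedder`, `LengthEmbedder`, `ErasedContext`, and `GenericContext` are hypothetical names, not the real ModelContext.

```swift
// Hypothetical protocol standing in for EmbeddingModel.
protocol DemoEmbedder {
    func embed(_ text: String) -> Int
}

struct LengthEmbedder: DemoEmbedder, Sendable {
    func embed(_ text: String) -> Int { text.count }
    // Extra API that only the concrete type exposes.
    var name: String { "length" }
}

// Existential storage: callers only see the protocol surface.
struct ErasedContext: Sendable {
    let model: any DemoEmbedder & Sendable
}

// Generic storage: the concrete model type stays visible to the compiler,
// so per-type annotations (such as unavailable methods) still apply.
struct GenericContext<Model: DemoEmbedder & Sendable>: Sendable {
    let model: Model
}

let erasedContext = ErasedContext(model: LengthEmbedder())
let genericContext = GenericContext(model: LengthEmbedder())

print(erasedContext.model.embed("mlx"))  // prints "3"
print(genericContext.model.name)         // prints "length", no cast needed
```

The cost is that the `Model` parameter has to appear in every signature that mentions the context, which is the "change things all over the place" concern.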
- adopt changes from ml-explore/mlx-swift#418
- we don't need private box types -- the technique becomes general
- it also opens up some potential for synchronous evaluation
da4f7d5 to 8b5c7b7
Proposed changes
Checklist
Put an x in the boxes that apply.
pre-commit run --all-files to format my code / installed pre-commit prior to committing changes