-
Notifications
You must be signed in to change notification settings - Fork 43
feat(provider): adopt Gemma4 vMLX decode stack #470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a3c9725
ce47be6
98956ca
3bb004d
fb47375
e0b67c5
6819a45
8e4f59a
1231a12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| +3 −1 | Source/MLX/Device.swift | |
| +3 −1 | Source/MLX/MLXArray.swift | |
| +3 −1 | Source/MLX/Stream.swift | |
| +23 −5 | Source/MLX/Transforms+Compile.swift | |
| +61 −0 | Tests/MLXTests/TransformTests.swift |
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,7 +49,8 @@ extension BatchScheduler { | |
| topK: Int? = nil, | ||
| seed: UInt64? = nil, | ||
| requestId: String? = nil, | ||
| cacheScope: String = "" | ||
| cacheScope: String = "", | ||
| allowFastPath: Bool = true | ||
| ) async -> AsyncStream<GenerationEvent> { | ||
| let id = requestId ?? "req-\(UUID().uuidString.prefix(12))" | ||
| let (stream, continuation) = AsyncStream<GenerationEvent>.makeStream() | ||
|
|
@@ -89,6 +90,34 @@ extension BatchScheduler { | |
| continuation.finish() | ||
| return stream | ||
| } | ||
| // B=1 greedy fast path eligibility — MUST be decided BEFORE inserting our | ||
| // bridge (it checks `activeBridges.isEmpty` for exclusivity). When taken, | ||
| // this bypasses the batched engine + planner for a single greedy request; | ||
| // otherwise the unchanged batched-engine path below runs. | ||
| let useFastPath = b1FastPathEligible( | ||
| temperature: temperature, | ||
| topP: topP, | ||
| topK: topK, | ||
| seed: seed, | ||
| promptTokenCount: promptTokens.count, | ||
| maxTokens: maxTokens, | ||
| cacheScope: cacheScope, | ||
| allowFastPath: allowFastPath | ||
| ) | ||
| // Concurrency gate: never run the batched engine concurrently with an | ||
| // in-flight B=1 fast-path task. The fast path assumes it is the sole GPU / | ||
| // KV consumer (single-row), so a request that is NOT itself taking the | ||
| // fast path while one is active is rejected with a retryable signal | ||
| // (`token_budget_exhausted` ⇒ 503 upstream) so it reroutes / retries | ||
| // rather than overlapping. Inert when the fast path is OFF (the default): | ||
| // `fastPathTasks` is then always empty, so the engine path is unchanged. | ||
| if !useFastPath && !fastPathTasks.isEmpty { | ||
| noteAdmissionReject() | ||
| continuation.yield(.error( | ||
| "token_budget_exhausted: a single-request fast path is active; retry shortly")) | ||
| continuation.finish() | ||
| return stream | ||
| } | ||
| let bridge = BridgeState( | ||
| requestId: id, | ||
| promptTokens: promptTokens.count, | ||
|
|
@@ -97,6 +126,50 @@ extension BatchScheduler { | |
| ) | ||
| activeBridges[id] = bridge | ||
|
|
||
| if useFastPath { | ||
| // Reserve KV bytes (cold; no restore, no planner, no engine enqueue) | ||
| // so the global KV budget and concurrent admissions still account for | ||
| // this request exactly as the engine path would. | ||
| let kvOutcome = await reserveKVForRequest( | ||
| requestId: id, | ||
| requestTokens: requestBudget, | ||
| reservationTokens: requestBudget, | ||
| restorePlanned: false | ||
|
Comment on lines +133 to +137
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
When the B=1 path is enabled on a KV-quantized model, this branch reserves through Useful? React with 👍 / 👎. |
||
| ) | ||
| guard kvOutcome != .failed else { | ||
| await dropBridge(requestId: id) | ||
| noteAdmissionReject() | ||
| continuation.yield(.error( | ||
| "token_budget_exhausted: insufficient global KV cache headroom")) | ||
| continuation.finish() | ||
| return stream | ||
| } | ||
| // Re-check the captured engine is still current and the container is | ||
| // live after the reserve await — a reload/unload may have run. Use | ||
| // releaseRequestResources (not bare dropBridge) so the reservation | ||
| // made above is not leaked if a cancel dropped the bridge meanwhile. | ||
| guard engineStillCurrent(submitEpoch, engine), let container = self.modelContainer else { | ||
| await releaseRequestResources(id) | ||
| continuation.yield(.error("model reloaded during submit; please retry")) | ||
| continuation.finish() | ||
| return stream | ||
| } | ||
| runGreedyFastPath( | ||
| requestId: id, | ||
| container: container, | ||
| promptTokens: promptTokens, | ||
| maxTokens: maxTokens, | ||
| continuation: continuation | ||
| ) | ||
| let scheduler = self | ||
| continuation.onTermination = { @Sendable termination in | ||
| if case .cancelled = termination { | ||
| Task { await scheduler.cancel(requestId: id) } | ||
| } | ||
| } | ||
| return stream | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This returns before Useful? React with 👍 / 👎. |
||
| } | ||
|
|
||
| if let planner = self.planner { | ||
| await refreshPlannerPolicy(activeTokenBudget: tokenBudgetMax) | ||
| let result = await planner.admit( | ||
|
|
@@ -222,6 +295,19 @@ extension BatchScheduler { | |
| // Pin the load epoch with the captured engine (see submitTokenized). | ||
| let submitEpoch = generationEpoch | ||
|
|
||
| // Concurrency gate (see submitTokenized): this ChatCompletionRequest path | ||
| // always uses the batched engine, which must not overlap an in-flight B=1 | ||
| // fast-path task. Reject with a retryable signal while one is active. | ||
| // Inert when the fast path is OFF (`fastPathTasks` always empty), so the | ||
| // engine path is unchanged by default. | ||
| if !fastPathTasks.isEmpty { | ||
| noteAdmissionReject() | ||
| continuation.yield(.error( | ||
| "token_budget_exhausted: a single-request fast path is active; retry shortly")) | ||
| continuation.finish() | ||
| return stream | ||
| } | ||
|
|
||
| // Pre-tokenize so chat-template errors surface as `.error` events; | ||
| // engine's internal `buildPrompt` silently falls back to role:content. | ||
| let messages: [[String: any Sendable]] = request.messages.map { msg in | ||
|
|
@@ -405,6 +491,14 @@ extension BatchScheduler { | |
| } | ||
|
|
||
| public func cancel(requestId: String) async { | ||
| // B=1 fast-path request: it runs off-engine, so the engine abort below | ||
| // can't reach it. Cancel its task; the task observes the cancellation, | ||
| // runs its own finish bookkeeping (KV release + bridge removal + terminal | ||
| // events) and clears its handle. (If it already finished and self-removed, | ||
| // this is a no-op and we fall through to the harmless engine/local path.) | ||
| if cancelFastPathTask(requestId) { | ||
| return | ||
| } | ||
| if let engine = self.engine { | ||
| // Engine delivers a terminal RequestOutput synchronously; the | ||
| // streaming Task handles `recordFinish` + KV release. | ||
|
|
@@ -436,6 +530,11 @@ extension BatchScheduler { | |
| } | ||
|
|
||
| public func cancelAll() async { | ||
| // Cancel any off-engine B=1 fast-path tasks first; each self-removes and | ||
| // releases its KV/bridge. The bridge-id KV release + removeAll below then | ||
| // covers the engine-path bridges (and is an idempotent no-op for any | ||
| // fast-path bridge a racing task already tore down). | ||
| cancelAllFastPathTasks() | ||
| if let engine = self.engine { | ||
| _ = engine.core.abortAllRequests() | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When the fast path is enabled, eligibility is decided without the prompt length or
`maxContextLength`, and this branch returns before the planner's `maxTokensPerBatch`/context-window rejection runs. A greedy request whose prompt is longer than the model context but still fits the memory token budget can therefore be sent directly to `ModelContainer.generate` instead of producing the deterministic context error the engine path emits, risking runtime failures or malformed output for over-context prompts. Useful? React with 👍 / 👎.