diff --git a/libs/mlx-swift b/libs/mlx-swift index 3c50ad69..e20ea3dd 160000 --- a/libs/mlx-swift +++ b/libs/mlx-swift @@ -1 +1 @@ -Subproject commit 3c50ad693a7a3fbfbcc1d7ddf834a37e9ee8cf14 +Subproject commit e20ea3dd85456910e61f7fd20e2de9285ea51903 diff --git a/libs/mlx-swift-lm b/libs/mlx-swift-lm index b1df0f22..791c17f8 160000 --- a/libs/mlx-swift-lm +++ b/libs/mlx-swift-lm @@ -1 +1 @@ -Subproject commit b1df0f22424f94005685baf79192ff08e9c43eb9 +Subproject commit 791c17f86dc0dac0571a7d933f62bb181948443b diff --git a/provider-swift/Sources/ProviderCore/Inference/BatchScheduler+B1FastPath.swift b/provider-swift/Sources/ProviderCore/Inference/BatchScheduler+B1FastPath.swift new file mode 100644 index 00000000..e68973c3 --- /dev/null +++ b/provider-swift/Sources/ProviderCore/Inference/BatchScheduler+B1FastPath.swift @@ -0,0 +1,370 @@ +// Copyright © 2026 Eigen Labs. +// +// BatchScheduler B=1 greedy fast path: an env-gated bypass of the +// continuous-batching engine for a single exclusive, greedy (temperature == 0) +// request. +// +// The batched engine carries continuous-batching overhead a single-row decode +// does not need: batch tensor (de)allocation per step, the scheduler step loop, +// the output collector, and cross-thread `RequestOutput` streaming. On Gemma-4 +// that overhead shows up as ~63 TPS through `BatchedEngine` vs ~75 TPS for the +// raw single-sequence loop (see `Tests/.../Gemma4DecodeProfileTests.swift`). +// +// When exactly one request is in flight and it is pure greedy, we run that +// single-sequence decode through `ModelContainer.generate` — the SAME +// concurrency-safe path the VLM media path (`VLMRequestInference`) already uses +// alongside the engine. `ModelContainer.generate` holds the container +// exclusively only for the prefill, then streams the decode (asyncEval +// pipelined) on its own task. We translate its `Generation` events to our +// `GenerationEvent` stream. +// +// Safety posture: +// * OFF by default; opt in with an env flag. +// * Conservative gate — anything that isn't a single exclusive greedy text +// request falls back to the batched engine, so the engine path's behavior +// is never altered. +// * KV byte budget is still reserved/released, and bridge bookkeeping +// (heartbeats, decode/prefill EWMA, billing-safe usage) is preserved via +// the SAME `recordAdmission` / `recordFirstToken` / `recordFinish` methods +// the engine bridge uses. +// * The in-flight task is tracked so `cancel` / `cancelAll` / +// `stopCurrentEngine` tear it down deterministically. + +import Foundation +import MLX +import MLXLMCommon + +extension BatchScheduler { + + // MARK: - Env gate + + /// True when the operator opted into the B=1 greedy fast path. Two flags are + /// accepted: `DARKBLOOM_B1_GREEDY_FAST_PATH` (generic) and + /// `DARKBLOOM_GEMMA_B1_FAST_PATH` (Gemma-targeted alias). Either set to `"1"` + /// enables it. Read per-call (cheap) so tests can toggle it via the + /// environment without restarting the scheduler. + static func b1GreedyFastPathEnabled() -> Bool { + let env = ProcessInfo.processInfo.environment + return env["DARKBLOOM_B1_GREEDY_FAST_PATH"] == "1" + || env["DARKBLOOM_GEMMA_B1_FAST_PATH"] == "1" + } + + // MARK: - Eligibility + + /// Whether this request can take the single-exclusive greedy fast path. + /// + /// MUST be evaluated BEFORE the request's own bridge is inserted into + /// `activeBridges` — the exclusivity check reads `activeBridges.count`. + /// Every condition is conservative: a miss simply defers to the batched + /// engine, so this can only shrink the set of requests the fast path serves, + /// never change the engine path's correctness. The decision itself is a pure + /// function (`b1FastPathEligiblePure`) so it can be unit-tested exhaustively + /// without a loaded model. + func b1FastPathEligible( + temperature: Float, + topP: Float?, + topK: Int?, + seed: UInt64?, + promptTokenCount: Int, + maxTokens: Int, + cacheScope: String, + allowFastPath: Bool + ) -> Bool { + Self.b1FastPathEligiblePure( + // Test override wins when set; otherwise consult the env flags. + enabled: _forceB1FastPathForTest ?? Self.b1GreedyFastPathEnabled(), + allowFastPath: allowFastPath, + modelId: modelId, + kvQuantEnabled: kvQuantEnabled, + temperature: temperature, + topP: topP, + topK: topK, + seed: seed, + promptTokenCount: promptTokenCount, + maxTokens: maxTokens, + maxContextLength: maxContextLength, + cacheScope: cacheScope, + activeBridgeCount: activeBridges.count, + pendingRequestCount: pendingRequestCount, + fastPathActive: !fastPathTasks.isEmpty, + hasContainer: modelContainer != nil + ) + } + + /// Pure eligibility policy for the B=1 greedy fast path. No actor state — all + /// inputs are parameters — so it is fully unit-testable. Order is irrelevant + /// to the result (all conditions must hold), but kept cheapest-first. + static func b1FastPathEligiblePure( + enabled: Bool, + allowFastPath: Bool, + modelId: String, + kvQuantEnabled: Bool, + temperature: Float, + topP: Float?, + topK: Int?, + seed: UInt64?, + promptTokenCount: Int, + maxTokens: Int, + maxContextLength: Int, + cacheScope: String, + activeBridgeCount: Int, + pendingRequestCount: Int, + fastPathActive: Bool, + hasContainer: Bool + ) -> Bool { + guard enabled else { return false } + // Caller opt-in. The engine consumer clears this for tool-bearing + // requests: the fast path is greedy text-only and cannot reproduce the + // engine's raw-text tool-call contract (`container.generate` may parse a + // call into a `.toolCall` event, CONSUMING the text — see the runner's + // `.toolCall` handling), so tool requests must stay on the engine path. + guard allowFastPath else { return false } + // Family gate: only Gemma-4 is profiled + validated for this bypass, and + // its greedy / EOS behavior is only known-good there. Every other family + // (different EOS sets, tool/stop conventions) defers to the batched engine. + guard modelId.lowercased().contains("gemma") else { return false } + // KV quantization: batched-engine admission reserves at the REDUCED + // (quantized) per-token KV rate, but `ModelContainer.generate` allocates a + // full fp16 KV cache. A fast-path reservation sized at the quantized rate + // would under-count ~2x and risk a unified-memory OOM, so whenever KV + // quant is active we defer to the engine (which owns the quantized cache). + guard !kvQuantEnabled else { return false } + // Pure greedy only: temperature 0 and no nucleus / top-k truncation. + // (minP / repetition / presence / frequency penalties are not part of + // the tokenized submit surface, so temperature + topP + topK fully + // characterize "greedy" here.) + guard temperature == 0 else { return false } + guard topP == nil || topP == 0 else { return false } + guard topK == nil || topK == 0 else { return false } + // A seed implies sampling intent; greedy ignores it, but treat its + // presence as "not the simple greedy case" and defer to the engine. + guard seed == nil else { return false } + guard maxTokens > 0 else { return false } + // Need a real prompt to prefill (a 0-token prompt has no greedy seed). + guard promptTokenCount > 0 else { return false } + // Context window: the fast path runs a cold prefill of the WHOLE prompt + // and decodes up to `maxTokens` against one fresh cache. If that span + // exceeds the model's context window, defer to the engine path — it + // enforces context limits and emits the precise context-overflow + // rejection. `maxContextLength == 0` ⇒ context unknown ⇒ skip this gate + // (the remaining gates, incl. the token-budget guard upstream, still apply). + if maxContextLength > 0 { + guard promptTokenCount + maxTokens <= maxContextLength else { return false } + } + // No prefix-cache scope: the fast path runs a cold prefill against a + // fresh cache and does not participate in the checkpoint / engine prefix + // tiers, so a scoped request keeps the engine path to retain cache reuse. + guard cacheScope.isEmpty else { return false } + // Exclusive: no other in-flight or queued work. Concurrent batched work + // would defeat the single-row assumption (shared GPU + KV headroom). + guard activeBridgeCount == 0 else { return false } + guard pendingRequestCount == 0 else { return false } + // And no OTHER fast-path task already running (explicit single-row gate; + // belt-and-suspenders with the activeBridgeCount check, since a running + // fast path also holds a bridge). + guard !fastPathActive else { return false } + // Need a live container to generate against. + guard hasContainer else { return false } + return true + } + + // MARK: - Runner + + /// Drive a single greedy request through `ModelContainer.generate` and + /// translate its `Generation` events onto the scheduler's `GenerationEvent` + /// stream. Mirrors `runBridge`'s lifecycle (admission / first-token / finish + /// bookkeeping and terminal `.info` / `.error` mapping) but sources tokens + /// from the single-sequence generator instead of `engine.core.streamOutputs`. + /// + /// The spawned task is tracked in `fastPathTasks[id]` so `cancel` / + /// `cancelAll` / `stopCurrentEngine` can tear it down; it removes its own + /// handle on completion. The caller (`submitTokenized`) is responsible for + /// having inserted the bridge and reserved KV before this runs, and for + /// wiring `continuation.onTermination`. + func runGreedyFastPath( + requestId id: String, + container: ModelContainer, + promptTokens: [Int], + maxTokens: Int, + continuation: AsyncStream.Continuation + ) { + let scheduler = self + let promptCount = promptTokens.count + let task = Task { + // Token-only input (no media). `MLXArray(promptTokens)` is a cheap + // host-side copy; the GPU work happens inside `generate`. + let lmInput = LMInput(tokens: MLXArray(promptTokens)) + // temperature 0 ⇒ ArgMaxSampler. topP/topK/penalties left at their + // defaults are inert under greedy. maxTokens bounds the decode. + let params = GenerateParameters(maxTokens: maxTokens, temperature: 0) + + // Admission ≈ now: prefill is about to begin. Drives the + // pending-timeout predicate and starts the prefill-EWMA window. + await scheduler.recordAdmission(requestId: id, at: .now) + + let genStream: AsyncStream + do { + genStream = try await container.generate(input: lmInput, parameters: params) + } catch { + _ = await scheduler.recordFinish( + requestId: id, promptTokens: promptCount, + completionTokens: 0, success: false) + continuation.yield(.error( + "fast path generation failed: \(error.localizedDescription)")) + continuation.finish() + await scheduler.clearFastPathTask(id) + return + } + + var sawFirstToken = false + // Count every streamed chunk as >= 1 completion token. The terminal + // `.info` carries the EXACT generation count, but it only arrives on a + // clean finish; on cancellation the loop breaks before it, so without + // this running tally `recordFinish` would settle at 0 completion tokens + // and the coordinator would bill $0 for work already streamed to the + // client. `recordFinish` takes max(observed, terminal), so a clean + // finish still uses the exact `.info` count (>= the chunk tally). + var streamedTokens = 0 + var terminalCompletion: Int? = nil + var reportedPrompt = promptCount + // Defensive: the greedy text-only fast path should never see a parsed + // tool call (tool requests are kept on the engine path by the caller's + // `allowFastPath` gate). If one is surfaced anyway we cannot faithfully + // reproduce the engine's raw-text behavior, so we FAIL rather than drop. + var sawToolCall = false + + for await gen in genStream { + // Cooperative cancellation: a client cancel / model reload cancels + // this task; break and let the finish bookkeeping below run. + if Task.isCancelled { break } + switch gen { + case .chunk(let text): + if !sawFirstToken { + sawFirstToken = true + await scheduler.recordFirstToken(requestId: id, at: .now) + } + streamedTokens += 1 + if !text.isEmpty { + continuation.yield(.chunk(text)) + } + case .info(let info): + reportedPrompt = info.promptTokenCount + terminalCompletion = info.generationTokenCount + case .toolCall: + // `container.generate` parsed a tool call (and may have + // CONSUMED its text rather than emitting it as `.chunk`s). + // Silently dropping it would lose the call; the engine path + // emits raw text and never `.toolCall`, so we cannot match it + // here. Mark failure and stop. + sawToolCall = true + } + if sawToolCall { break } + } + + let cancelled = Task.isCancelled + // Billing-safe completion count: terminal exact count when present, + // otherwise the streamed-chunk lower bound (covers cancel + tool-call + // failure, where no `.info` arrived). + let completionTokens = max(terminalCompletion ?? 0, streamedTokens) + let succeeded = !cancelled && !sawToolCall + // Reuse the engine bridge's finish bookkeeping: removes the bridge, + // updates the decode + prefill EWMA, releases the KV reservation, and + // returns billing-safe usage counts (max of observed vs. terminal). + let usage = await scheduler.recordFinish( + requestId: id, + promptTokens: reportedPrompt, + completionTokens: completionTokens, + success: succeeded) + + // Emit delivered usage (so a listener can bill partial work) before + // any terminal error, mirroring the engine bridge. + if !succeeded, usage.promptTokens > 0 || usage.completionTokens > 0 { + continuation.yield(.info( + promptTokens: usage.promptTokens, + completionTokens: usage.completionTokens, + tokensPerSecond: usage.tps)) + } + if cancelled { + continuation.yield(.error("request cancelled")) + } else if sawToolCall { + continuation.yield(.error( + "fast path does not support tool calls; please retry")) + } else { + continuation.yield(.info( + promptTokens: usage.promptTokens, + completionTokens: usage.completionTokens, + tokensPerSecond: usage.tps)) + } + continuation.finish() + await scheduler.clearFastPathTask(id) + } + fastPathTasks[id] = task + } + + // MARK: - Task tracking / teardown + + /// Remove a finished fast-path task handle. Called by the task itself on + /// completion. Safe for an unknown id. + func clearFastPathTask(_ id: String) { + fastPathTasks.removeValue(forKey: id) + } + + /// Cancel the in-flight fast-path task for `id`, if any. Returns true when a + /// task existed and was cancelled. The task observes `Task.isCancelled`, + /// runs its finish bookkeeping (KV release, bridge removal, terminal events) + /// and clears its own handle. + @discardableResult + func cancelFastPathTask(_ id: String) -> Bool { + guard let task = fastPathTasks[id] else { return false } + task.cancel() + return true + } + + /// Cancel every in-flight fast-path task (model reload / `cancelAll`). Each + /// task self-removes; callers that also clear `fastPathTasks` (e.g. + /// `stopCurrentEngine`) make late `clearFastPathTask` calls harmless no-ops. + func cancelAllFastPathTasks() { + for task in fastPathTasks.values { task.cancel() } + } + + /// Cancel AND fence every in-flight fast-path task — used by + /// `stopCurrentEngine` before it nil's `modelContainer` and clears the MLX + /// cache. Unlike the engine (which is fenced by `stopAndWait`), a fast-path + /// task runs off-engine inside `ModelContainer.generate`, holding and running + /// GPU work against the model + its KV cache. If teardown freed that state + /// while a task were still mid-`generate`, it could touch released model/MLX + /// state. Awaiting each task's value blocks until it has observed + /// cancellation, run its finish bookkeeping (KV release + bridge removal + + /// terminal events) and dropped its model/iterator references. + /// + /// The handles are snapshotted first so a self-removing `clearFastPathTask` + /// during the awaits cannot mutate the collection being iterated. The await + /// suspends the actor so those actor-isolated callbacks make progress; no NEW + /// fast path can start meanwhile because `stopCurrentEngine` has already + /// nil'd `engine` (every submit path short-circuits on a nil engine). + /// Idempotent: a no-op when nothing is in flight. + func waitForFastPathTasks() async { + let inflight = Array(fastPathTasks.values) + guard !inflight.isEmpty else { return } + for task in inflight { task.cancel() } + for task in inflight { await task.value } + fastPathTasks.removeAll() + } +} + +// MARK: - Test support +// +// Internal + `@testable`-only; dead-code-stripped from production binaries. + +extension BatchScheduler { + /// Force the B=1 fast-path enablement gate on/off, bypassing the env flags. + /// `nil` restores env-driven behavior. Lets a benchmark A/B the fast path vs. + /// the batched engine in a single process (mutating `ProcessInfo`'s cached + /// environment mid-run is unreliable). + func _setForceB1FastPathForTest(_ value: Bool?) { + _forceB1FastPathForTest = value + } + + /// Test accessor: number of in-flight fast-path tasks currently tracked. + func _fastPathTaskCountForTest() -> Int { fastPathTasks.count } +} diff --git a/provider-swift/Sources/ProviderCore/Inference/BatchScheduler+Submit.swift b/provider-swift/Sources/ProviderCore/Inference/BatchScheduler+Submit.swift index 55f15b75..7fb28dd3 100644 --- a/provider-swift/Sources/ProviderCore/Inference/BatchScheduler+Submit.swift +++ b/provider-swift/Sources/ProviderCore/Inference/BatchScheduler+Submit.swift @@ -49,7 +49,8 @@ extension BatchScheduler { topK: Int? = nil, seed: UInt64? = nil, requestId: String? = nil, - cacheScope: String = "" + cacheScope: String = "", + allowFastPath: Bool = true ) async -> AsyncStream { let id = requestId ?? "req-\(UUID().uuidString.prefix(12))" let (stream, continuation) = AsyncStream.makeStream() @@ -89,6 +90,34 @@ extension BatchScheduler { continuation.finish() return stream } + // B=1 greedy fast path eligibility — MUST be decided BEFORE inserting our + // bridge (it checks `activeBridges.isEmpty` for exclusivity). When taken, + // this bypasses the batched engine + planner for a single greedy request; + // otherwise the unchanged batched-engine path below runs. + let useFastPath = b1FastPathEligible( + temperature: temperature, + topP: topP, + topK: topK, + seed: seed, + promptTokenCount: promptTokens.count, + maxTokens: maxTokens, + cacheScope: cacheScope, + allowFastPath: allowFastPath + ) + // Concurrency gate: never run the batched engine concurrently with an + // in-flight B=1 fast-path task. The fast path assumes it is the sole GPU / + // KV consumer (single-row), so a request that is NOT itself taking the + // fast path while one is active is rejected with a retryable signal + // (`token_budget_exhausted` ⇒ 503 upstream) so it reroutes / retries + // rather than overlapping. Inert when the fast path is OFF (the default): + // `fastPathTasks` is then always empty, so the engine path is unchanged. + if !useFastPath && !fastPathTasks.isEmpty { + noteAdmissionReject() + continuation.yield(.error( + "token_budget_exhausted: a single-request fast path is active; retry shortly")) + continuation.finish() + return stream + } let bridge = BridgeState( requestId: id, promptTokens: promptTokens.count, @@ -97,6 +126,50 @@ extension BatchScheduler { ) activeBridges[id] = bridge + if useFastPath { + // Reserve KV bytes (cold; no restore, no planner, no engine enqueue) + // so the global KV budget and concurrent admissions still account for + // this request exactly as the engine path would. + let kvOutcome = await reserveKVForRequest( + requestId: id, + requestTokens: requestBudget, + reservationTokens: requestBudget, + restorePlanned: false + ) + guard kvOutcome != .failed else { + await dropBridge(requestId: id) + noteAdmissionReject() + continuation.yield(.error( + "token_budget_exhausted: insufficient global KV cache headroom")) + continuation.finish() + return stream + } + // Re-check the captured engine is still current and the container is + // live after the reserve await — a reload/unload may have run. Use + // releaseRequestResources (not bare dropBridge) so the reservation + // made above is not leaked if a cancel dropped the bridge meanwhile. + guard engineStillCurrent(submitEpoch, engine), let container = self.modelContainer else { + await releaseRequestResources(id) + continuation.yield(.error("model reloaded during submit; please retry")) + continuation.finish() + return stream + } + runGreedyFastPath( + requestId: id, + container: container, + promptTokens: promptTokens, + maxTokens: maxTokens, + continuation: continuation + ) + let scheduler = self + continuation.onTermination = { @Sendable termination in + if case .cancelled = termination { + Task { await scheduler.cancel(requestId: id) } + } + } + return stream + } + if let planner = self.planner { await refreshPlannerPolicy(activeTokenBudget: tokenBudgetMax) let result = await planner.admit( @@ -222,6 +295,19 @@ extension BatchScheduler { // Pin the load epoch with the captured engine (see submitTokenized). let submitEpoch = generationEpoch + // Concurrency gate (see submitTokenized): this ChatCompletionRequest path + // always uses the batched engine, which must not overlap an in-flight B=1 + // fast-path task. Reject with a retryable signal while one is active. + // Inert when the fast path is OFF (`fastPathTasks` always empty), so the + // engine path is unchanged by default. + if !fastPathTasks.isEmpty { + noteAdmissionReject() + continuation.yield(.error( + "token_budget_exhausted: a single-request fast path is active; retry shortly")) + continuation.finish() + return stream + } + // Pre-tokenize so chat-template errors surface as `.error` events; // engine's internal `buildPrompt` silently falls back to role:content. let messages: [[String: any Sendable]] = request.messages.map { msg in @@ -405,6 +491,14 @@ extension BatchScheduler { } public func cancel(requestId: String) async { + // B=1 fast-path request: it runs off-engine, so the engine abort below + // can't reach it. Cancel its task; the task observes the cancellation, + // runs its own finish bookkeeping (KV release + bridge removal + terminal + // events) and clears its handle. (If it already finished and self-removed, + // this is a no-op and we fall through to the harmless engine/local path.) + if cancelFastPathTask(requestId) { + return + } if let engine = self.engine { // Engine delivers a terminal RequestOutput synchronously; the // streaming Task handles `recordFinish` + KV release. @@ -436,6 +530,11 @@ extension BatchScheduler { } public func cancelAll() async { + // Cancel any off-engine B=1 fast-path tasks first; each self-removes and + // releases its KV/bridge. The bridge-id KV release + removeAll below then + // covers the engine-path bridges (and is an idempotent no-op for any + // fast-path bridge a racing task already tore down). + cancelAllFastPathTasks() if let engine = self.engine { _ = engine.core.abortAllRequests() } diff --git a/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift b/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift index a03d1cb4..2c1e9e4c 100644 --- a/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift +++ b/provider-swift/Sources/ProviderCore/Inference/BatchScheduler.swift @@ -261,6 +261,22 @@ public actor BatchScheduler { /// (vs. "request cancelled" for client-initiated aborts). var timedOutBridges: Set = [] + /// In-flight B=1 greedy fast-path tasks, keyed by request id. The fast path + /// (see `BatchScheduler+B1FastPath.swift`) bypasses the batched engine for a + /// single exclusive greedy request and runs `ModelContainer.generate` + /// directly, so it is NOT registered with `engine.core` — `cancel` / + /// `cancelAll` / `stopCurrentEngine` must cancel these tasks here (the + /// engine abort path can't reach them). Each task removes its own entry on + /// completion via `clearFastPathTask`. + var fastPathTasks: [String: Task] = [:] + + /// Test-only override for the B=1 fast-path enablement gate. `nil` (the + /// production default) defers to the env flags via `b1GreedyFastPathEnabled()`. + /// Set via `_setForceB1FastPathForTest` so a benchmark can A/B the fast path + /// against the batched engine in one process without relying on mutating the + /// (often cached) `ProcessInfo` environment mid-run. @testable-only. + internal var _forceB1FastPathForTest: Bool? = nil + // MARK: - Telemetry state (read by `backendCapacity`) var observedDecodeTpsEwma: Double = 0 @@ -398,6 +414,17 @@ public actor BatchScheduler { // submits fail the guard and reject/retry against the next model instead. var stoppingEngine = self.engine self.engine = nil + // Tear down any in-flight B=1 fast-path tasks: they run off-engine via + // `ModelContainer.generate`, so the engine abort below can't reach them. + // We must FENCE (not just cancel) them here — before this teardown nil's + // `modelContainer` and runs `MLX.Memory.clearCache()` below — because a + // task still inside its `generate` loop holds and runs GPU work against + // the model + its KV. `waitForFastPathTasks` cancels each task and awaits + // its unwind (cancellation observation + finish bookkeeping: KV release, + // bridge removal, terminal events). The await suspends this actor so those + // callbacks make progress; no new fast path can start because `engine` is + // already nil above (every submit path short-circuits on a nil engine). + await waitForFastPathTasks() pendingTimeoutTask?.cancel() pendingTimeoutTask = nil // Stop the backend-liveness watchdog; a recovery restart re-arms it via diff --git a/provider-swift/Sources/ProviderCore/Inference/MultiModelBatchSchedulerEngine.swift b/provider-swift/Sources/ProviderCore/Inference/MultiModelBatchSchedulerEngine.swift index e6f78c46..04c8fd00 100644 --- a/provider-swift/Sources/ProviderCore/Inference/MultiModelBatchSchedulerEngine.swift +++ b/provider-swift/Sources/ProviderCore/Inference/MultiModelBatchSchedulerEngine.swift @@ -350,7 +350,11 @@ public struct MultiModelBatchSchedulerEngine: MLXServerEngine, Sendable { topP: request.topP, topK: request.topK, requestId: requestId, - cacheScope: cacheScope + cacheScope: cacheScope, + // Keep tool-bearing requests off the greedy text-only B=1 fast path: + // it cannot reproduce the engine's raw-text tool-call contract. No + // tools ⇒ fast path may apply (subject to the scheduler's gates). + allowFastPath: toolHandler == nil ) return AsyncThrowingStream { continuation in diff --git a/provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift b/provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift index 96517b0e..e7307652 100644 --- a/provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift +++ b/provider-swift/Sources/ProviderCore/Inference/VLMRequestInference.swift @@ -105,15 +105,33 @@ public enum VLMRequestInference { request.maxTokens ?? defaultMaxTokens } - /// Conservative per-image soft-token allotment for the KV-token estimate. - /// Gemma-4 pools every image to a FIXED `vision_soft_tokens_per_image` (256) - /// regardless of resolution; other VLMs run higher. 1024 (4× Gemma) is a - /// generous model-agnostic upper bound that is still bounded by the model's - /// context window via the clamp in `projectedKVTokens`. + /// Conservative per-image (and per-video-frame) soft-token allotment for the + /// KV-token estimate. Gemma-4 pools every image/frame to a FIXED soft-token + /// block (`image_seq_length`, default 280) regardless of resolution and wraps + /// it with 2 `boi`/`eoi` delimiter tokens; other VLMs run higher. 1024 is a + /// generous model-agnostic upper bound that over-covers that whole + /// `boi + soft_tokens + eoi` per-frame span, and is still bounded by the + /// model's context window via the clamp in `projectedKVTokens`. static let visionTokensPerImage = 1024 - /// A video samples multiple frames, each contributing image-like soft tokens. - /// Charge a larger fixed allotment per video; still clamped to the context. - static let visionTokensPerVideo = 4096 + /// Max frames Gemma-4 samples from a single video. Mirrors `maxFrames: 32` + /// in the Gemma4 video processor (`Gemma4Processor.prepare`), which samples + /// up to 32 frames spread uniformly across the clip and expands EACH into its + /// own image-like `boi + soft_token*image_seq_length + eoi` block. A video's + /// KV footprint therefore scales with the sampled frame count — it is NOT a + /// flat per-video allotment. + static let maxVideoFramesSampled = 32 + /// A video reserves KV for EVERY sampled frame: the processor emits one + /// image-like soft-token block per sampled frame (up to + /// `maxVideoFramesSampled`), so the worst case is `maxVideoFramesSampled × + /// visionTokensPerImage`. Because `visionTokensPerImage` already over-covers + /// a single frame's soft tokens AND its `boi`/`eoi` delimiters, this product + /// bounds the full `32 × (soft tokens + delimiters)` span the prefill + /// actually writes into KV. The previous flat 4096 covered only ~4 frames and + /// badly under-reserved a full clip (Gemma-4's real worst case is + /// 32 × (280 + 2) = 9024 soft tokens). Still clamped to the model's context + /// window via the clamp in `projectedKVTokens`, so over-reservation never + /// projects past a request the context could actually hold. + static let visionTokensPerVideo = maxVideoFramesSampled * visionTokensPerImage /// Conservative chars→tokens divisor for the text prompt estimate. Real /// tokenizers average ~4 chars/token; dividing by 3 OVER-estimates the token /// count (the safe direction for a reservation). diff --git a/provider-swift/Tests/ProviderCoreTests/B1GreedyFastPathTests.swift b/provider-swift/Tests/ProviderCoreTests/B1GreedyFastPathTests.swift new file mode 100644 index 00000000..85b0b634 --- /dev/null +++ b/provider-swift/Tests/ProviderCoreTests/B1GreedyFastPathTests.swift @@ -0,0 +1,345 @@ +// B1GreedyFastPathTests -- unit coverage for the B=1 greedy fast-path +// eligibility policy, plus an opt-in live A/B benchmark that compares the +// fast path against the batched engine on Gemma-4. +// +// The eligibility tests are pure and run in CI (no model, no GPU). The +// benchmark is gated exactly like `Gemma4DecodeProfileTests`: +// +// DARKBLOOM_LIVE_MLX_TESTS=1 DARKBLOOM_LIVE_MLX_GEMMA=1 \ +// DARKBLOOM_GEMMA_MODEL=mlx-community/gemma-4-26b-a4b-it-8bit \ +// swift test --filter B1GreedyFastPathBenchmark +// +// Set DARKBLOOM_GEMMA_PRINT_TEXT=1 to print a short decoded sample from each +// path so a reviewer can eyeball output parity. + +import Foundation +import Testing +import MLX +import MLXLMCommon +import MLXVLM +@testable import ProviderCore + +// MARK: - Pure eligibility policy (CI-safe, no model) + +@Suite("B=1 fast-path eligibility") +struct B1GreedyFastPathEligibilityTests { + + /// All conditions satisfied for a single exclusive greedy Gemma-4 request. + private func eligible( + enabled: Bool = true, + allowFastPath: Bool = true, + modelId: String = "mlx-community/gemma-4-26b-a4b-it-8bit", + kvQuantEnabled: Bool = false, + temperature: Float = 0, + topP: Float? = nil, + topK: Int? = nil, + seed: UInt64? = nil, + promptTokenCount: Int = 16, + maxTokens: Int = 128, + maxContextLength: Int = 8192, + cacheScope: String = "", + activeBridgeCount: Int = 0, + pendingRequestCount: Int = 0, + fastPathActive: Bool = false, + hasContainer: Bool = true + ) -> Bool { + BatchScheduler.b1FastPathEligiblePure( + enabled: enabled, + allowFastPath: allowFastPath, + modelId: modelId, + kvQuantEnabled: kvQuantEnabled, + temperature: temperature, + topP: topP, + topK: topK, + seed: seed, + promptTokenCount: promptTokenCount, + maxTokens: maxTokens, + maxContextLength: maxContextLength, + cacheScope: cacheScope, + activeBridgeCount: activeBridgeCount, + pendingRequestCount: pendingRequestCount, + fastPathActive: fastPathActive, + hasContainer: hasContainer + ) + } + + @Test("the canonical single greedy request is eligible") + func canonicalEligible() { + #expect(eligible()) + } + + @Test("disabled gate short-circuits everything") + func disabledIsIneligible() { + #expect(!eligible(enabled: false)) + } + + @Test("non-greedy sampling is ineligible") + func samplingIsIneligible() { + #expect(!eligible(temperature: 0.7)) + #expect(!eligible(topP: 0.9)) + #expect(!eligible(topK: 40)) + #expect(!eligible(seed: 42)) + // Inert/disabled sampling knobs do NOT disqualify a greedy request. + #expect(eligible(topP: 0)) + #expect(eligible(topK: 0)) + } + + @Test("zero / negative maxTokens is ineligible") + func badMaxTokensIneligible() { + #expect(!eligible(maxTokens: 0)) + #expect(!eligible(maxTokens: -5)) + } + + @Test("an empty prompt is ineligible") + func emptyPromptIneligible() { + #expect(!eligible(promptTokenCount: 0)) + } + + @Test("the caller can force the request onto the engine path") + func callerOptOutIsIneligible() { + // Tool-bearing requests clear this so they never take the text-only path. + #expect(!eligible(allowFastPath: false)) + } + + @Test("only Gemma-family models are eligible") + func nonGemmaIsIneligible() { + #expect(!eligible(modelId: "mlx-community/Qwen3.5-30B-8bit")) + #expect(!eligible(modelId: "")) + // Case-insensitive family match. + #expect(eligible(modelId: "google/Gemma-4-it")) + } + + @Test("KV quantization disqualifies the fast path (fp16 KV under-reserve)") + func kvQuantIsIneligible() { + #expect(!eligible(kvQuantEnabled: true)) + } + + @Test("a prompt+generation span over the context window defers to the engine") + func overContextIsIneligible() { + // prompt + maxTokens must fit the model context window. + #expect(!eligible(promptTokenCount: 8000, maxTokens: 512, maxContextLength: 8192)) + // Exactly at the limit is fine. + #expect(eligible(promptTokenCount: 8064, maxTokens: 128, maxContextLength: 8192)) + // Unknown context (0) skips the gate — other gates still apply. + #expect(eligible(promptTokenCount: 100000, maxTokens: 4096, maxContextLength: 0)) + } + + @Test("a prefix-cache scope defers to the engine") + func scopedIsIneligible() { + #expect(!eligible(cacheScope: "tenant-abc")) + } + + @Test("any concurrent or queued work disqualifies the exclusive fast path") + func nonExclusiveIsIneligible() { + #expect(!eligible(activeBridgeCount: 1)) + #expect(!eligible(pendingRequestCount: 1)) + } + + @Test("an already-running fast-path task disqualifies a second one") + func fastPathActiveIsIneligible() { + #expect(!eligible(fastPathActive: true)) + } + + @Test("a missing container is ineligible") + func noContainerIneligible() { + #expect(!eligible(hasContainer: false)) + } + + @Test("env flags are off by default in this process") + func envFlagDefaultOff() { + // The CI runner does not set these, so the static gate is false. (If a + // developer exported them, this documents the expectation rather than + // asserting a hard false.) + let env = ProcessInfo.processInfo.environment + let expected = env["DARKBLOOM_B1_GREEDY_FAST_PATH"] == "1" + || env["DARKBLOOM_GEMMA_B1_FAST_PATH"] == "1" + #expect(BatchScheduler.b1GreedyFastPathEnabled() == expected) + } +} + +// MARK: - Live A/B benchmark (opt-in) + +/// One measured generation through the scheduler's tokenized submit path. +private struct FastPathRun { + var text: String = "" + var promptTokens: Int = 0 + var completionTokens: Int = 0 + var tokensPerSecond: Double = 0 + var error: String? + var wallSeconds: Double = 0 +} + +@Suite("B=1 fast-path benchmark", .serialized) +struct B1GreedyFastPathBenchmark { + + /// Submit a pre-tokenized greedy request and collect the full event stream. + private func runTokenized( + scheduler: BatchScheduler, + promptTokens: [Int], + maxTokens: Int + ) async -> FastPathRun { + let start = ContinuousClock.now + let stream = await scheduler.submitTokenized( + promptTokens: promptTokens, + maxTokens: maxTokens, + temperature: 0.0, + requestId: "b1-bench-\(UUID().uuidString.prefix(8))" + ) + var run = FastPathRun() + var chunks: [String] = [] + for await event in stream { + switch event { + case .chunk(let text): + chunks.append(text) + case .info(let prompt, let completion, let tps): + run.promptTokens = prompt + run.completionTokens = completion + run.tokensPerSecond = tps + case .error(let message): + run.error = message + } + } + run.text = chunks.joined() + run.wallSeconds = (ContinuousClock.now - start).asSeconds + return run + } + + /// Median tokens/sec of `iterations` measured runs after one warmup, for a + /// given fast-path mode. + private func measure( + scheduler: BatchScheduler, + promptTokens: [Int], + maxTokens: Int, + fastPath: Bool, + warmups: Int, + iterations: Int + ) async -> (median: FastPathRun, all: [FastPathRun]) { + await scheduler._setForceB1FastPathForTest(fastPath) + for _ in 0 ..< warmups { + _ = await runTokenized( + scheduler: scheduler, promptTokens: promptTokens, maxTokens: maxTokens) + } + var runs: [FastPathRun] = [] + for _ in 0 ..< iterations { + runs.append(await runTokenized( + scheduler: scheduler, promptTokens: promptTokens, maxTokens: maxTokens)) + } + let sorted = runs.sorted { $0.tokensPerSecond < $1.tokensPerSecond } + return (sorted[sorted.count / 2], runs) + } + + @Test( + "fast path vs batched engine decode TPS (Gemma-4)", + .enabled( + if: ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_TESTS"] != nil + && ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_GEMMA"] != nil + ) + ) + func fastPathVsEngine() async throws { + if LiveInferenceFixtures.ensureMetallibColocated() == nil { + Issue.record("mlx.metallib not found near test bundle or in MLX_METALLIB_PATH/SOURCE") + return + } + MLX.GPU.set(memoryLimit: 96 * 1024 * 1024 * 1024) + + let modelID = ProcessInfo.processInfo.environment["DARKBLOOM_GEMMA_MODEL"] + ?? "mlx-community/gemma-4-26b-a4b-it-8bit" + guard let modelDir = ModelScanner.resolveLocalPath(modelID: modelID) else { + Issue.record("model '\(modelID)' is not in the local cache") + return + } + + let container = try await VLMModelFactory.shared.loadContainer( + from: modelDir, using: LocalTokenizerLoader()) + + let scheduler = BatchScheduler( + maxConcurrentRequests: 1, + pendingTimeout: .seconds(120), + defaultMaxTokens: 512 + ) + await scheduler.loadModel(container: container, modelId: modelID) + defer { Task { await scheduler.unloadModel() } } + + let prompt = "Write a detailed technical explanation of sparse " + + "mixture-of-experts inference on Apple Silicon." + let promptTokens: [Int] = try await container.perform { ctx in + try ctx.tokenizer.applyChatTemplate( + messages: [["role": "user", "content": prompt]], + tools: nil, + additionalContext: nil) + } + + let maxTokens = 256 + let warmups = 1 + let iterations = 3 + + // Engine (batched) baseline first, then the fast path. Order is fixed so + // both pay an equal share of any thermal drift across the run. + let engine = await measure( + scheduler: scheduler, promptTokens: promptTokens, maxTokens: maxTokens, + fastPath: false, warmups: warmups, iterations: iterations) + let fast = await measure( + scheduler: scheduler, promptTokens: promptTokens, maxTokens: maxTokens, + fastPath: true, warmups: warmups, iterations: iterations) + await scheduler._setForceB1FastPathForTest(nil) + + let e = engine.median + let f = fast.median + let ratio = e.tokensPerSecond > 0 ? f.tokensPerSecond / e.tokensPerSecond : 0 + let commonPrefix = sharedPrefixCount(e.text, f.text) + + print(""" + [b1-fastpath-benchmark] model=\(modelID) prompt_tokens=\(promptTokens.count) max_tokens=\(maxTokens) + [b1-fastpath-benchmark] engine_median_tps=\(fmt(e.tokensPerSecond)) completion=\(e.completionTokens) + [b1-fastpath-benchmark] fast_median_tps=\(fmt(f.tokensPerSecond)) completion=\(f.completionTokens) + [b1-fastpath-benchmark] speedup=\(fmt(ratio))x shared_text_prefix_chars=\(commonPrefix) + """) + if ProcessInfo.processInfo.environment["DARKBLOOM_GEMMA_PRINT_TEXT"] != nil { + print("[b1-fastpath-benchmark] engine_sample=\(e.text.prefix(160).debugDescription)") + print("[b1-fastpath-benchmark] fast_sample=\(f.text.prefix(160).debugDescription)") + } + + // Correctness: both paths must produce real output. + #expect(e.error == nil, "engine path errored: \(e.error ?? "")") + #expect(f.error == nil, "fast path errored: \(f.error ?? "")") + #expect(!e.text.isEmpty, "engine path produced empty text") + #expect(!f.text.isEmpty, "fast path produced empty text") + #expect(e.completionTokens > 0 && f.completionTokens > 0) + // Greedy decode from the same prompt should agree on the opening text + // (the first argmax over the same prefill logits). A long shared prefix + // is strong evidence the fast path is byte-compatible; FP differences in + // the batched vs. single-row kernels can diverge later, so we only + // require a non-trivial shared opening rather than full equality. + #expect(commonPrefix >= 8, "fast/engine greedy text diverged immediately (shared=\(commonPrefix))") + + // Throughput: the fast path must not regress materially. The whole point + // is that it is faster; allow generous noise so a busy CI box can't flake. + #expect( + f.tokensPerSecond >= e.tokensPerSecond * 0.85, + "fast path regressed vs engine (\(fmt(f.tokensPerSecond)) < 0.85 * \(fmt(e.tokensPerSecond)))" + ) + + // Cleanup invariant: no reservations / bridges left behind. + let cap = await scheduler.capacity() + #expect(cap.activeRequests == 0, "left \(cap.activeRequests) active requests") + #expect(cap.pendingRequests == 0, "left \(cap.pendingRequests) pending requests") + let fastTasks = await scheduler._fastPathTaskCountForTest() + #expect(fastTasks == 0, "left \(fastTasks) fast-path tasks tracked") + } + + private func fmt(_ value: Double) -> String { String(format: "%.1f", value) } + + /// Number of leading characters shared by two strings. + private func sharedPrefixCount(_ a: String, _ b: String) -> Int { + let ca = Array(a), cb = Array(b) + var i = 0 + while i < ca.count, i < cb.count, ca[i] == cb[i] { i += 1 } + return i + } +} + +private extension Duration { + var asSeconds: Double { + Double(components.seconds) + Double(components.attoseconds) / 1e18 + } +} diff --git a/provider-swift/Tests/ProviderCoreTests/ContinuousBatchingLiveTests.swift b/provider-swift/Tests/ProviderCoreTests/ContinuousBatchingLiveTests.swift index ccb45a7c..f59fbbc0 100644 --- a/provider-swift/Tests/ProviderCoreTests/ContinuousBatchingLiveTests.swift +++ b/provider-swift/Tests/ProviderCoreTests/ContinuousBatchingLiveTests.swift @@ -87,7 +87,7 @@ struct ContinuousBatchingLiveTests { ) ) func gemma4VLMMixedLengthCoherent() async throws { - try ensureMetallibAvailable() + guard ensureMetallibAvailable() else { return } MLX.GPU.set(memoryLimit: 96 * 1024 * 1024 * 1024) let modelID = ProcessInfo.processInfo.environment["DARKBLOOM_GEMMA_MODEL"] @@ -174,6 +174,173 @@ struct ContinuousBatchingLiveTests { return false } + /// Long-context, mixed-length B=3 correctness for the optimized + /// `BatchRotatingKVCache` decode ring + per-row RoPE offset. Row 2 is a + /// ~1k-token prompt, so prompt + generation crosses the Gemma sliding + /// window (`slidingWindow == 1024`) DURING decode — the regime that + /// exercises the fast path's window slide / ring compaction and, crucially, + /// the per-row `batchOffset` past the window (the scalar `cache.offset` + /// caps at the window and mis-positions every post-window query, which the + /// `gemma4VLMGraphOffsetArray` fix corrects). Each row's batched greedy + /// output is compared against its single-stream (`RotatingKVCache`) + /// reference and must (a) never degenerate into repetition and (b) track + /// the reference at least past the deterministic floor. ≥64 tokens are + /// generated so the long row decodes well past the window edge. + @Test( + "Gemma 4 VLM long-context (~1k) mixed-length B=3 tracks solo + stays coherent", + .enabled(if: + ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_TESTS"] != nil + && ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_GEMMA"] != nil + ) + ) + func gemma4VLMLongContextMixedB3() async throws { + guard ensureMetallibAvailable() else { return } + MLX.GPU.set(memoryLimit: 96 * 1024 * 1024 * 1024) + + let modelID = ProcessInfo.processInfo.environment["DARKBLOOM_GEMMA_MODEL"] + ?? "mlx-community/gemma-4-26B-A4B-it-qat-4bit" + guard let modelDir = ModelScanner.resolveLocalPath(modelID: modelID) else { + Issue.record("model '\(modelID)' is not in the local cache") + return + } + let container = try await VLMModelFactory.shared.loadContainer( + from: modelDir, using: LocalTokenizerLoader()) + + // Row 0: short, low-entropy (deterministic floor). Row 1: medium. + // Row 2: ~1k tokens — long enough that prompt + 64 generated tokens + // crosses the 1024 sliding window mid-decode. + let longBody = String( + repeating: "Renewable energy has reshaped the global grid in many ways. ", + count: 130) + let prompts = [ + "Reply with the single word 'ocean'.", + "Briefly, what is photosynthesis?", + "Read the following notes and then summarize them in one paragraph.\n\n" + + longBody, + ] + let encoded: [[Int]] = try await container.perform { ctx in + try prompts.map { + try ctx.tokenizer.applyChatTemplate( + messages: [["role": "user", "content": $0]], + tools: nil, additionalContext: nil) + } + } + let lengths = encoded.map { $0.count } + print("[gemma4-vlm-longctx] prompt token lengths: \(lengths)") + #expect( + lengths[2] >= 900, + Comment(rawValue: "long row only \(lengths[2]) tokens; want ~1k to cross the window")) + + let maxTokens = 80 // ≥ 64; long row decodes past the 1024 window edge + let batched = try await runBatchedEngine( + container: container, modelID: modelID, prompts: encoded, maxTokens: maxTokens) + let single = await singleStreamGreedy( + container: container, prompts: encoded, maxTokens: maxTokens) + + let eos = 106 + for row in 0 ..< prompts.count { + let toks = batched[row] + let text = await container.decode(tokenIds: toks) + print("[gemma4-vlm-longctx] batched row \(row) (\(toks.count) toks): \(text.prefix(160))") + #expect( + !Self.hasDegenerateRepetition(toks), + Comment(rawValue: "batched row \(row) degenerates into repetition: \(toks)")) + + let singleHead = Array(single[row].prefix(while: { $0 != eos })) + let batchedHead = Array(toks.prefix(while: { $0 != eos })) + let match = zip(batchedHead, singleHead).prefix(while: ==).count + print( + "[gemma4-vlm-longctx] row \(row): batched/solo prefix match = \(match) " + + "(batchedHead=\(batchedHead.count), soloHead=\(singleHead.count))") + // Row 0 is low-entropy → strong deterministic agreement. Rows 1/2 + // are higher-entropy MoE continuations where bf16 argmax flips can + // diverge after the floor; the regression we guard is repetition / + // immediate divergence, so require a small positive prefix match. + let required = row == 0 ? 3 : 1 + #expect( + match >= required, + Comment(rawValue: + "row \(row) diverges below floor \(required) (match=\(match), " + + "batched=\(batchedHead.prefix(8)), solo=\(singleHead.prefix(8)))")) + } + } + + /// Provider-stack decode throughput at B=1/2/3 for the optimized + /// `BatchRotatingKVCache`. Drives the same `BatchedEngine` the provider + /// uses (via `runBatchedEngine`) with a ~1k-token prompt per row and times + /// the full prefill+decode, reporting aggregate tok/s per batch width. + /// + /// OLD vs NEW comparison (the optimization is a runtime gate, so no rebuild + /// is needed): + /// ``` + /// # NEW (in-place decode ring, default): + /// DARKBLOOM_LIVE_MLX_TESTS=1 DARKBLOOM_LIVE_MLX_GEMMA=1 DARKBLOOM_BENCH=1 \ + /// swift test --filter gemma4DecodeRingBenchmarkB1B2B3 + /// # OLD (legacy concat+trim path): + /// DARKBLOOM_FAST_BATCH_ROTATING_KV=0 DARKBLOOM_LIVE_MLX_TESTS=1 \ + /// DARKBLOOM_LIVE_MLX_GEMMA=1 DARKBLOOM_BENCH=1 \ + /// swift test --filter gemma4DecodeRingBenchmarkB1B2B3 + /// ``` + /// `DARKBLOOM_BENCH_OUTPUT` overrides the generated-token count (default + /// 256; set 512 to match the production decode benchmark). Diagnostic + /// (prints tok/s); the only assertion is that every row produced output. + @Test( + "BENCHMARK: gemma4DecodeRingBenchmarkB1B2B3 (BatchRotatingKVCache decode TPS)", + .enabled(if: + ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_TESTS"] != nil + && ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_GEMMA"] != nil + && ProcessInfo.processInfo.environment["DARKBLOOM_BENCH"] != nil + ) + ) + func gemma4DecodeRingBenchmarkB1B2B3() async throws { + guard ensureMetallibAvailable() else { return } + MLX.GPU.set(memoryLimit: 96 * 1024 * 1024 * 1024) + + let env = ProcessInfo.processInfo.environment + let fastGate = env["DARKBLOOM_FAST_BATCH_ROTATING_KV"].map { + !["0", "false", "no", "off"].contains($0.lowercased()) + } ?? true + let outputTokens = env["DARKBLOOM_BENCH_OUTPUT"].flatMap { Int($0) } ?? 256 + + let modelID = env["DARKBLOOM_GEMMA_MODEL"] + ?? "mlx-community/gemma-4-26B-A4B-it-qat-4bit" + guard let modelDir = ModelScanner.resolveLocalPath(modelID: modelID) else { + Issue.record("model '\(modelID)' is not in the local cache") + return + } + let container = try await VLMModelFactory.shared.loadContainer( + from: modelDir, using: LocalTokenizerLoader()) + + // ~973-token prompt to match the production decode benchmark; with + // outputTokens decode steps this crosses the 1024 sliding window. + let body = String( + repeating: "Renewable energy reshaped the global grid in many ways. ", count: 120) + let encoded: [Int] = try await container.perform { ctx in + try ctx.tokenizer.applyChatTemplate( + messages: [["role": "user", "content": "Summarize:\n\n" + body]], + tools: nil, additionalContext: nil) + } + print( + "[gemma4-bench] fast_path=\(fastGate) prompt_tokens=\(encoded.count) " + + "output_tokens=\(outputTokens) model=\(modelID)") + + for B in [1, 2, 3] { + let prompts = Array(repeating: encoded, count: B) + let t0 = Date() + let toks = try await runBatchedEngine( + container: container, modelID: modelID, prompts: prompts, maxTokens: outputTokens) + let dt = Date().timeIntervalSince(t0) + let generated = toks.reduce(0) { $0 + $1.count } + let aggregateTPS = dt > 0 ? Double(generated) / dt : 0 + for t in toks { + #expect(!t.isEmpty, Comment(rawValue: "B=\(B): a row produced no tokens")) + } + print(String( + format: "[gemma4-bench] fast=%@ B=%d: %d tokens / %.2fs = %.1f tok/s aggregate", + "\(fastGate)", B, generated, dt, aggregateTPS)) + } + } + @Test( "Qwen 3.5 0.8B-MLX-4bit (hybrid SSM+attention), B=2", .enabled(if: ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_TESTS"] != nil) @@ -197,7 +364,7 @@ struct ContinuousBatchingLiveTests { .enabled(if: ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_TESTS"] != nil) ) func samePromptDeterministicAcrossBatchPositions() async throws { - try ensureMetallibAvailable() + guard ensureMetallibAvailable() else { return } let modelID = "mlx-community/Qwen3-0.6B-8bit" guard let modelDir = ModelScanner.resolveLocalPath(modelID: modelID) else { @@ -243,7 +410,7 @@ struct ContinuousBatchingLiveTests { maxTokens: Int, wiredMemoryGB: Int? = nil ) async throws { - try ensureMetallibAvailable() + guard ensureMetallibAvailable() else { return } if let wiredMemoryGB { MLX.GPU.set(memoryLimit: wiredMemoryGB * 1024 * 1024 * 1024) } @@ -424,12 +591,21 @@ struct ContinuousBatchingLiveTests { /// Place the matching `mlx.metallib` next to the test runner so the MLX /// C++ runtime's `dladdr` lookup finds it on the first GPU call. - private func ensureMetallibAvailable() throws { + /// + /// Returns `true` when the metallib is colocated (GPU work can proceed) and + /// `false` — after recording an issue — when it is missing, so callers do + /// `guard ensureMetallibAvailable() else { return }` and STOP before the + /// first GPU call. This mirrors `Gemma4DecodeProfileTests`, which returns on + /// a missing metallib instead of pressing on into a hard crash on the first + /// MLX kernel dispatch. + private func ensureMetallibAvailable() -> Bool { if LiveInferenceFixtures.ensureMetallibColocated() == nil { let msg = "mlx.metallib not found near test bundle or in MLX_METALLIB_PATH/SOURCE; " + "run scripts/fetch-metallib.sh debug to install it for local runs" Issue.record(Comment(rawValue: msg)) + return false } + return true } // MARK: - Eviction-and-admission @@ -443,7 +619,7 @@ struct ContinuousBatchingLiveTests { .enabled(if: ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_TESTS"] != nil) ) func evictionAndAdmissionMatchesSolo() async throws { - try ensureMetallibAvailable() + guard ensureMetallibAvailable() else { return } let modelID = "mlx-community/Qwen3-0.6B-8bit" guard let modelDir = ModelScanner.resolveLocalPath(modelID: modelID) else { @@ -605,7 +781,7 @@ struct ContinuousBatchingLiveTests { ) ) func resourceCountTrajectoryProbeSmall() async throws { - try ensureMetallibAvailable() + guard ensureMetallibAvailable() else { return } // Mimic the big-RAM box: BOTH the cache-size trim (cacheLimit / // max_pool_size_) AND the byte-pressure reclaim (memoryLimit, which drives // gc_limit_ in MetalAllocator::malloc) must be lifted, or the byte path @@ -676,7 +852,7 @@ struct ContinuousBatchingLiveTests { ) ) func resourceCountTrajectoryProbe() async throws { - try ensureMetallibAvailable() + guard ensureMetallibAvailable() else { return } // Mimic the 128 GB box: lift BOTH the cache-size trim (cacheLimit) AND the // byte-pressure reclaim (memoryLimit -> gc_limit_ in diff --git a/provider-swift/Tests/ProviderCoreTests/Gemma4DecodeProfileTests.swift b/provider-swift/Tests/ProviderCoreTests/Gemma4DecodeProfileTests.swift new file mode 100644 index 00000000..7108668c --- /dev/null +++ b/provider-swift/Tests/ProviderCoreTests/Gemma4DecodeProfileTests.swift @@ -0,0 +1,110 @@ +import Foundation +import Testing +import MLX +import MLXLMCommon +import MLXVLM +@testable import ProviderCore + +/// Live Gemma4 26B-A4B decode profile. +/// +/// Gated because it loads the local 26B model. +/// Run with: +/// DARKBLOOM_LIVE_MLX_TESTS=1 DARKBLOOM_LIVE_MLX_GEMMA=1 \ +/// DARKBLOOM_GEMMA_MODEL=mlx-community/gemma-4-26b-a4b-it-8bit \ +/// swift test --filter Gemma4DecodeProfileTests +/// +/// Set `DARKBLOOM_GEMMA_PRINT_TEXT=1` to print a short decoded sample. +@Suite("Gemma4 decode profile", .serialized) +struct Gemma4DecodeProfileTests { + @Test( + "B=1 raw decode TPS", + .enabled(if: + ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_TESTS"] != nil + && ProcessInfo.processInfo.environment["DARKBLOOM_LIVE_MLX_GEMMA"] != nil + ) + ) + func rawDecodeB1() async throws { + if LiveInferenceFixtures.ensureMetallibColocated() == nil { + Issue.record("mlx.metallib not found near test bundle or in MLX_METALLIB_PATH/SOURCE") + return + } + MLX.GPU.set(memoryLimit: 96 * 1024 * 1024 * 1024) + + let modelID = ProcessInfo.processInfo.environment["DARKBLOOM_GEMMA_MODEL"] + ?? "mlx-community/gemma-4-26b-a4b-it-8bit" + guard let modelDir = ModelScanner.resolveLocalPath(modelID: modelID) else { + Issue.record("model '\(modelID)' is not in the local cache") + return + } + + let container = try await VLMModelFactory.shared.loadContainer( + from: modelDir, using: LocalTokenizerLoader()) + + let prompt = "Write a detailed technical explanation of sparse mixture-of-experts inference on Apple Silicon." + let encoded: [Int] = try await container.perform { ctx in + try ctx.tokenizer.applyChatTemplate( + messages: [["role": "user", "content": prompt]], + tools: nil, + additionalContext: nil) + } + + let tps: Double = await container.perform { ctx in + let cache = ctx.model.newCache(parameters: nil) + let promptArray = MLXArray(encoded.map { Int32($0) }).reshaped([1, encoded.count]) + + var logits = ctx.model.callAsFunction(promptArray, cache: cache) + var nextToken = argMax(logits[0..., -1, 0...], axis: -1) + asyncEval(nextToken) + + let warmups = 8 + for _ in 0 ..< warmups { + logits = ctx.model.callAsFunction(nextToken.reshaped([1, 1]), cache: cache) + let sampled = argMax(logits[0..., -1, 0...], axis: -1) + asyncEval(sampled) + eval(nextToken) + nextToken = sampled + } + + let tokens = 128 + let start = DispatchTime.now().uptimeNanoseconds + for _ in 0 ..< tokens { + logits = ctx.model.callAsFunction(nextToken.reshaped([1, 1]), cache: cache) + let sampled = argMax(logits[0..., -1, 0...], axis: -1) + asyncEval(sampled) + eval(nextToken) + nextToken = sampled + } + eval(nextToken) + + let elapsedMs = Double(DispatchTime.now().uptimeNanoseconds - start) / 1_000_000.0 + let msPerToken = elapsedMs / Double(tokens) + return 1_000.0 / msPerToken + } + + print("[gemma4-decode-profile] model=\(modelID) prompt_tokens=\(encoded.count) raw_b1_tps=\(String(format: "%.1f", tps))") + + if ProcessInfo.processInfo.environment["DARKBLOOM_GEMMA_PRINT_TEXT"] != nil { + let text: String = await container.perform { ctx in + let cache = ctx.model.newCache(parameters: nil) + let promptArray = MLXArray(encoded.map { Int32($0) }).reshaped([1, encoded.count]) + var logits = ctx.model.callAsFunction(promptArray, cache: cache) + var nextToken = argMax(logits[0..., -1, 0...], axis: -1) + asyncEval(nextToken) + + var generated: [Int] = [] + generated.reserveCapacity(64) + for _ in 0 ..< 64 { + eval(nextToken) + let token = nextToken.item(Int.self) + generated.append(token) + logits = ctx.model.callAsFunction(nextToken.reshaped([1, 1]), cache: cache) + let sampled = argMax(logits[0..., -1, 0...], axis: -1) + asyncEval(sampled) + nextToken = sampled + } + return ctx.tokenizer.decode(tokenIds: generated, skipSpecialTokens: true) + } + print("[gemma4-decode-profile] sample=\(text)") + } + } +} diff --git a/provider-swift/Tests/ProviderCoreTests/VLMCapReservationTests.swift b/provider-swift/Tests/ProviderCoreTests/VLMCapReservationTests.swift index b1eb24c0..12c4f24b 100644 --- a/provider-swift/Tests/ProviderCoreTests/VLMCapReservationTests.swift +++ b/provider-swift/Tests/ProviderCoreTests/VLMCapReservationTests.swift @@ -151,6 +151,30 @@ private let gib: UInt64 = 1024 * 1024 * 1024 #expect(tokens > 100) } +@Test func projectedKVTokensReservesEverySampledVideoFrame() { + // A video is NOT a flat per-video allotment: the Gemma4 processor samples up + // to `maxVideoFramesSampled` frames and expands EACH into its own image-like + // soft-token block, so the KV reservation must cover every sampled frame + // (32 × per-frame), not a single frame. The old flat 4096 covered only ~4 + // frames and badly under-reserved a full clip — the bug this guards. + let videoURI = "data:video/mp4;base64,AAAAAA" + let req = OpenAIChatCompletionRequest( + model: "m", + messages: [.init(role: .user, content: .parts([.text("describe"), .videoURL(videoURI)]))], + maxTokens: 25) + // contextLength 0 = "unknown" → no clamp, so the full per-frame span shows. + let tokens = VLMRequestInference.projectedKVTokens( + req, defaultMaxTokens: 1024, contextLength: 0) + // The per-video charge scales with every sampled frame (32 × per-frame). + #expect( + VLMRequestInference.visionTokensPerVideo + == VLMRequestInference.maxVideoFramesSampled * VLMRequestInference.visionTokensPerImage) + // KV span covers every sampled frame's soft tokens + the 25 output tokens… + #expect(tokens >= VLMRequestInference.visionTokensPerVideo + 25) + // …and strictly out-reserves a single-frame charge (the under-reservation bug). + #expect(tokens > VLMRequestInference.visionTokensPerImage + 25) +} + @Test func projectedKVTokensClampsPromptPlusVisionToContext() { // Many images would project a huge prompt+vision span; it must clamp to the // model's context window (the cache can't hold more input tokens than that),