diff --git a/CHANGELOG.md b/CHANGELOG.md index f8635a23..5d0a0239 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ from version 5.0.0 onward. Pre-fork releases (`1.x`–`4.2.0`) were authored by - Real-model tool-calling integration tests for blocking and streaming required tool calls (`ToolCallingIntegrationTest`, Qwen2.5-1.5B-Instruct), wired into CI and `validate-models`. - End-to-end vision input across blocking, typed `ChatRequest`, streaming, and OpenAI-compatible request mapping; real-model tests verify that distinct red and blue images produce the correct semantic answers. - Explicit `setMmprojAuto(boolean)` and `setMmprojOffload(boolean)` controls, including the upstream `--no-mmproj-auto` and `--no-mmproj-offload` flags. +- Per-request KV controls: `InferenceParameters.withSlotId(int)` and `withCacheReuse(int)`. +- Typed cache observability through `Usage.getCachedTokens()`, `Usage.getProcessedPromptTokens()`, `SlotMetrics`, and `ServerMetrics.getSlotMetrics()`. +- Authenticated JSON `GET /metrics` and `GET /slots` endpoints on the embedded server. ### Changed - Unified `CONTRIBUTING.md` and `SECURITY.md` structure with sibling repositories in the project family. @@ -30,6 +33,8 @@ from version 5.0.0 onward. Pre-fork releases (`1.x`–`4.2.0`) were authored by - Preserved decoded image buffers across the JNI chat boundary and submitted media requests through llama.cpp's upstream multimodal task path instead of silently tokenizing them as text-only prompts. - Preserved multipart image content when using the typed `ChatRequest` serializer. - The standalone OpenAI-compatible server now advertises vision only when the loaded model confirms usable vision support. +- `Session` now pins every inference request to its configured slot, so generation and slot save/restore/erase target the same KV state. +- Cached-token usage is preserved through typed Java responses and OpenAI Responses/Anthropic blocking and streaming adapters. 
### Added - Reasoning-budget tests (Qwen3-0.6B). diff --git a/README.md b/README.md index a844027d..4b049cad 100644 --- a/README.md +++ b/README.md @@ -473,6 +473,23 @@ a JSON response, matching the HTTP server's contract: Server state is exposed via `getMetrics()`, `eraseSlot(int)`, `saveSlot(int, String)`, `restoreSlot(int, String)`, and `getModelMeta()`. +### Prompt and KV Cache Reuse + +Prompt-prefix reuse is enabled by default in llama.cpp and can be controlled per request with +`InferenceParameters.withCachePrompt(boolean)`. `withCacheReuse(int)` enables non-prefix chunk reuse, +while `withSlotId(int)` pins a request to a specific server slot. `Session` applies its slot id to every +request, so generation and `save`/`restore` operate on the same KV state. + +Typed results expose logical prompt, generated, cached prompt, and evaluated prompt counts through +`Usage`. Per-request timing also remains available through `Timings.getCacheN()`. +`LlamaModel.getMetricsTyped().getSlotMetrics()` reports each slot's logical, processed, cached, +decoded, and remaining token counts. + +The embedded HTTP server exposes the same native JSON at authenticated `GET /metrics`, with the slot +array alone at `GET /slots`. OpenAI responses preserve +`usage.prompt_tokens_details.cached_tokens`; Responses API output uses +`usage.input_tokens_details.cached_tokens`; Anthropic output uses `cache_read_input_tokens`. 
+ ### OpenAI-compatible HTTP server `net.ladenthin.llama.server.OpenAiCompatServer` turns a loaded model into a local @@ -488,6 +505,8 @@ serves: | `POST /v1/rerank` (requires `--reranking`) | `LlamaModel.handleRerank` (reshaped to `results`/`data`) | | `POST /infill` | `LlamaModel.handleInfill` (fill-in-the-middle autocomplete) | | `GET /v1/models` | the configured model id | +| `GET /metrics` | native server and per-slot token/cache counters (JSON) | +| `GET /slots` | native per-slot token/cache counters (JSON array) | | `GET /health` | static `{"status":"ok"}` (unauthenticated) | Chat completions support **streaming via Server-Sent Events** and non-streaming, forwarding diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 087cd91c..e61c35d3 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -205,6 +205,7 @@ static void populate_completion_task(server_task &task, jllama_context *jctx, in } } task.params = server_schema::eval_llama_cmpl_schema(jctx->vocab, jctx->params, n_ctx_slot, logit_bias_eog, data); + configure_task_slot_impl(task, data); } [[nodiscard]] static jint dispatch_streaming_completion(JNIEnv *env, jllama_context *jctx, const json &data, diff --git a/src/main/cpp/jni_helpers.hpp b/src/main/cpp/jni_helpers.hpp index 425c3fc6..d29805dc 100644 --- a/src/main/cpp/jni_helpers.hpp +++ b/src/main/cpp/jni_helpers.hpp @@ -14,7 +14,7 @@ // require_json_field_impl, jint_array_to_tokens_impl // // Layer B — JNI + server orchestration: -// configure_multimodal_task_impl, +// configure_multimodal_task_impl, configure_task_slot_impl, // json_to_jstring_impl, results_to_jstring_impl, // embedding_to_jfloat_array_impl, tokens_to_jint_array_impl // @@ -175,6 +175,12 @@ inline void erase_reader(jllama_context *jctx, int id_task) { return true; } +// Match server_routes::handle_completions_impl(): slot selection is task +// metadata, not part of task_params, so eval_llama_cmpl_schema() does not set it. 
+inline void configure_task_slot_impl(server_task &task, const json &data) { + task.id_slot = json_value(data, "id_slot", -1); +} + // --------------------------------------------------------------------------- // json_to_jstring_impl // diff --git a/src/main/java/net/ladenthin/llama/Session.java b/src/main/java/net/ladenthin/llama/Session.java index 9efe15e7..c6603468 100644 --- a/src/main/java/net/ladenthin/llama/Session.java +++ b/src/main/java/net/ladenthin/llama/Session.java @@ -185,7 +185,14 @@ public void close() { * @return inference parameters carrying the system message + wire messages */ private InferenceParameters buildParams(@Nullable String systemMessage, List> wireMessages) { - InferenceParameters params = InferenceParameters.empty().withMessages(systemMessage, wireMessages); - return paramsCustomizer == null ? params : paramsCustomizer.apply(params); + InferenceParameters params = InferenceParameters.empty() + .withMessages(systemMessage, wireMessages) + .withCachePrompt(true); + if (paramsCustomizer != null) { + params = paramsCustomizer.apply(params); + } + // Apply last: a Session must never drift away from the slot used by + // save(), restore(), and close(), even if a customizer supplies another id. 
+ return params.withSlotId(slotId); } } diff --git a/src/main/java/net/ladenthin/llama/json/ChatResponseParser.java b/src/main/java/net/ladenthin/llama/json/ChatResponseParser.java index 72d2dd44..8b66e8fa 100644 --- a/src/main/java/net/ladenthin/llama/json/ChatResponseParser.java +++ b/src/main/java/net/ladenthin/llama/json/ChatResponseParser.java @@ -150,9 +150,14 @@ public ChatResponse parseResponse(String json) { JsonNode node = OBJECT_MAPPER.readTree(json); String id = node.path("id").asText(""); List choices = parseChoices(node.path("choices")); + JsonNode usageNode = node.path("usage"); Usage usage = new Usage( - node.path("usage").path("prompt_tokens").asLong(0L), - node.path("usage").path("completion_tokens").asLong(0L)); + usageNode.path("prompt_tokens").asLong(0L), + usageNode.path("completion_tokens").asLong(0L), + usageNode + .path("prompt_tokens_details") + .path("cached_tokens") + .asLong(0L)); Timings timings = Timings.fromJson(node.path("timings")); TimingsLogger.log(timings); return new ChatResponse(id, choices, usage, timings, json); diff --git a/src/main/java/net/ladenthin/llama/json/CompletionResponseParser.java b/src/main/java/net/ladenthin/llama/json/CompletionResponseParser.java index b0ce96b0..dd1686c5 100644 --- a/src/main/java/net/ladenthin/llama/json/CompletionResponseParser.java +++ b/src/main/java/net/ladenthin/llama/json/CompletionResponseParser.java @@ -187,10 +187,11 @@ public CompletionResult parseCompletionResult(String json) { try { JsonNode node = OBJECT_MAPPER.readTree(json); String text = extractContent(node); + Timings timings = Timings.fromJson(node.path("timings")); Usage usage = new Usage( node.path("tokens_evaluated").asLong(0L), - node.path("tokens_predicted").asLong(0L)); - Timings timings = Timings.fromJson(node.path("timings")); + node.path("tokens_predicted").asLong(0L), + Math.max(0, timings.getCacheN())); TimingsLogger.log(timings); List logprobs = parseLogprobs(node); StopReason stopReason = diff --git 
a/src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java b/src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java index b10a6b87..dd2b8fe4 100644 --- a/src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java +++ b/src/main/java/net/ladenthin/llama/parameters/InferenceParameters.java @@ -58,6 +58,8 @@ public final class InferenceParameters extends JsonParameters { private static final String PARAM_INPUT_PREFIX = "input_prefix"; private static final String PARAM_INPUT_SUFFIX = "input_suffix"; private static final String PARAM_CACHE_PROMPT = "cache_prompt"; + private static final String PARAM_CACHE_REUSE = "n_cache_reuse"; + private static final String PARAM_SLOT_ID = "id_slot"; private static final String PARAM_STREAM_OPTIONS = "stream_options"; private static final String PARAM_RESPONSE_FORMAT = "response_format"; private static final String PARAM_N_PREDICT = "n_predict"; @@ -204,6 +206,36 @@ public InferenceParameters withCachePrompt(boolean cachePrompt) { return withScalar(PARAM_CACHE_PROMPT, cachePrompt); } + /** + * Returns a new request with the minimum reusable KV-cache chunk size replaced. + * A value of {@code 0} disables non-prefix chunk reuse. Ordinary common-prefix + * reuse remains controlled by {@link #withCachePrompt(boolean)}. + * + * @param cacheReuse minimum reusable chunk size, or {@code 0} to disable + * @return a new instance; this instance is unchanged + */ + public InferenceParameters withCacheReuse(int cacheReuse) { + if (cacheReuse < 0) { + throw new IllegalArgumentException("cacheReuse must be non-negative"); + } + return withScalar(PARAM_CACHE_REUSE, cacheReuse); + } + + /** + * Returns a new request pinned to a llama.cpp server slot. Pinning is useful + * for deterministic multi-turn KV reuse and for matching inference with + * {@code saveSlot}/{@code restoreSlot} operations. 
+ * + * @param slotId non-negative slot identifier + * @return a new instance; this instance is unchanged + */ + public InferenceParameters withSlotId(int slotId) { + if (slotId < 0) { + throw new IllegalArgumentException("slotId must be non-negative"); + } + return withScalar(PARAM_SLOT_ID, slotId); + } + /** * Returns a new request with the number of tokens to predict replaced * (default: -1, -1 = infinity, -2 = until context filled). diff --git a/src/main/java/net/ladenthin/llama/parameters/ModelParameters.java b/src/main/java/net/ladenthin/llama/parameters/ModelParameters.java index e605e37a..65278887 100644 --- a/src/main/java/net/ladenthin/llama/parameters/ModelParameters.java +++ b/src/main/java/net/ladenthin/llama/parameters/ModelParameters.java @@ -1398,10 +1398,10 @@ public ModelParameters setKvUnified(boolean kvUnified) { /** * Set the maximum RAM cache size in MiB used to store saved slot KV state. *

- * Requires {@link #setKvUnified(boolean) unified KV} to be enabled. * Set to {@code -1} for no limit, {@code 0} to disable (default: 8192 MiB). - * Together with {@link #setClearIdle} this allows idle slots to be evicted - * from GPU/CPU memory and restored quickly on the next matching request. + * Together with {@link #setClearIdle}, idle slot states are copied into this + * RAM cache and restored on a matching request. Unified KV is required only + * when those idle slots should also be cleared from the active KV buffer. * * @param cacheRamMib maximum cache size in MiB, or {@code -1} for unlimited * @return this builder @@ -1414,14 +1414,13 @@ public ModelParameters setCacheRamMib(int cacheRamMib) { * Enable or disable saving and clearing idle slots when a new task starts. *

* When enabled (the default), idle slots have their KV state saved to the - * RAM cache ({@link #setCacheRamMib}) and are then cleared, freeing GPU/CPU - * memory for the active request. The saved state is transparently restored - * on the next request that shares the same prompt prefix, so cache-hit - * latency is preserved. + * RAM cache ({@link #setCacheRamMib}). With unified KV enabled, the active + * slot state is also cleared, freeing KV-buffer capacity for other requests. + * Without unified KV the RAM-cache copy is still created, but the active + * slot remains allocated. *

- * Requires {@link #setKvUnified(boolean) unified KV} and a non-zero - * {@link #setCacheRamMib RAM cache}. If either dependency is absent the - * server logs a warning and silently disables the feature. + * Requires a non-zero {@link #setCacheRamMib RAM cache}. Unified KV is + * required only for active-buffer eviction. * * @param clearIdle {@code true} to enable idle-slot eviction (default), {@code false} to disable * @return this builder diff --git a/src/main/java/net/ladenthin/llama/server/AnthropicApiSupport.java b/src/main/java/net/ladenthin/llama/server/AnthropicApiSupport.java index c48f8dc7..ea314181 100644 --- a/src/main/java/net/ladenthin/llama/server/AnthropicApiSupport.java +++ b/src/main/java/net/ladenthin/llama/server/AnthropicApiSupport.java @@ -254,7 +254,13 @@ static String toAnthropicResponse(String openAiCompletionJson, String model) { stopReason = anthropicStopReason(choice.path("finish_reason").asText("stop")); JsonNode openAiUsage = completion.path("usage"); if (openAiUsage.isObject()) { - usage.put("input_tokens", openAiUsage.path("prompt_tokens").asInt(0)); + int promptTokens = openAiUsage.path("prompt_tokens").asInt(0); + int cachedTokens = openAiUsage + .path("prompt_tokens_details") + .path("cached_tokens") + .asInt(0); + usage.put("input_tokens", Math.max(0, promptTokens - cachedTokens)); + usage.put("cache_read_input_tokens", cachedTokens); usage.put("output_tokens", openAiUsage.path("completion_tokens").asInt(0)); } } catch (IOException e) { @@ -391,12 +397,20 @@ static String blockStopEvent(int index) { /** {@code message_delta} event carrying the final stop reason. */ static String messageDeltaEvent(String stopReason) { + return messageDeltaEvent(stopReason, 0, 0, 0); + } + + /** Final message delta carrying token usage collected from the trailing OpenAI usage chunk. 
*/ + static String messageDeltaEvent(String stopReason, int inputTokens, int outputTokens, int cachedTokens) { ObjectNode data = OBJECT_MAPPER.createObjectNode(); data.put("type", "message_delta"); ObjectNode delta = data.putObject("delta"); delta.put("stop_reason", stopReason); delta.putNull("stop_sequence"); - data.putObject("usage").put("output_tokens", 0); + ObjectNode usage = data.putObject("usage"); + usage.put("input_tokens", inputTokens); + usage.put("output_tokens", outputTokens); + usage.put("cache_read_input_tokens", cachedTokens); return sseEvent("message_delta", data.toString()); } diff --git a/src/main/java/net/ladenthin/llama/server/AnthropicStreamTranslator.java b/src/main/java/net/ladenthin/llama/server/AnthropicStreamTranslator.java index f5cfc2ff..188365fd 100644 --- a/src/main/java/net/ladenthin/llama/server/AnthropicStreamTranslator.java +++ b/src/main/java/net/ladenthin/llama/server/AnthropicStreamTranslator.java @@ -32,6 +32,9 @@ final class AnthropicStreamTranslator { private int textBlockIndex = -1; private int nextIndex; private String finishReason = "stop"; + private int inputTokens; + private int outputTokens; + private int cachedTokens; AnthropicStreamTranslator(String id, String model) { this.id = id; @@ -60,6 +63,15 @@ String onChunk(String openAiChunkJson) { try { JsonNode chunk = OBJECT_MAPPER.readTree(openAiChunkJson); accumulator.accept(chunk); + JsonNode usage = chunk.path("usage"); + if (usage.isObject()) { + int promptTokens = usage.path("prompt_tokens").asInt(0); + cachedTokens = usage.path("prompt_tokens_details") + .path("cached_tokens") + .asInt(0); + inputTokens = Math.max(0, promptTokens - cachedTokens); + outputTokens = usage.path("completion_tokens").asInt(0); + } JsonNode choice = chunk.path("choices").path(0); JsonNode content = choice.path("delta").path("content"); if (content.isTextual() && !content.asText().isEmpty()) { @@ -102,7 +114,8 @@ String end() { out.append(AnthropicApiSupport.blockStopEvent(index)); } } - 
out.append(AnthropicApiSupport.messageDeltaEvent(AnthropicApiSupport.anthropicStopReason(finishReason))); + out.append(AnthropicApiSupport.messageDeltaEvent( + AnthropicApiSupport.anthropicStopReason(finishReason), inputTokens, outputTokens, cachedTokens)); out.append(AnthropicApiSupport.messageStopEvent()); return out.toString(); } diff --git a/src/main/java/net/ladenthin/llama/server/LlamaModelBackend.java b/src/main/java/net/ladenthin/llama/server/LlamaModelBackend.java index de289d41..70b74a0c 100644 --- a/src/main/java/net/ladenthin/llama/server/LlamaModelBackend.java +++ b/src/main/java/net/ladenthin/llama/server/LlamaModelBackend.java @@ -40,6 +40,11 @@ final class LlamaModelBackend implements OpenAiBackend { this.mapper = mapper; } + @Override + public String metrics() { + return model.getMetrics(); + } + @Override public String complete(JsonNode request) { return model.chatComplete(mapper.toInferenceParameters(request)); diff --git a/src/main/java/net/ladenthin/llama/server/OpenAiBackend.java b/src/main/java/net/ladenthin/llama/server/OpenAiBackend.java index 5080ff2a..edc5cec5 100644 --- a/src/main/java/net/ladenthin/llama/server/OpenAiBackend.java +++ b/src/main/java/net/ladenthin/llama/server/OpenAiBackend.java @@ -22,6 +22,17 @@ */ interface OpenAiBackend { + /** + * Return llama.cpp server metrics, including per-slot cache counters. + * Test backends may rely on the empty default. + * + * @return metrics JSON + * @throws IOException if metrics cannot be read + */ + default String metrics() throws IOException { + return "{\"slots\":[]}"; + } + /** * Run a non-streaming chat completion ({@code POST /v1/chat/completions}). 
* diff --git a/src/main/java/net/ladenthin/llama/server/OpenAiCompatServer.java b/src/main/java/net/ladenthin/llama/server/OpenAiCompatServer.java index ef6a5700..03b62f90 100644 --- a/src/main/java/net/ladenthin/llama/server/OpenAiCompatServer.java +++ b/src/main/java/net/ladenthin/llama/server/OpenAiCompatServer.java @@ -44,6 +44,8 @@ *

  • {@code POST /v1/embeddings} — embeddings (requires the model to be loaded in embedding * mode).
  • *
  • {@code GET /v1/models} — advertises the single configured model.
  • + *
  • {@code GET /metrics} — server and per-slot token/cache counters as JSON.
  • + *
  • {@code GET /slots} — the per-slot metrics array as JSON.
  • *
  • {@code GET /health} — liveness probe returning {@code {"status":"ok"}} (no authentication).
  • * * @@ -99,6 +101,12 @@ public final class OpenAiCompatServer implements AutoCloseable { /** The llama.cpp-native server-properties route (context length + modalities). */ public static final String PATH_PROPS = "/props"; + /** llama.cpp server metrics as JSON, including per-slot token/cache counters. */ + public static final String PATH_METRICS = "/metrics"; + + /** llama.cpp slot state array as JSON. */ + public static final String PATH_SLOTS = "/slots"; + /** Ollama-native discovery route (version). */ public static final String PATH_OLLAMA_VERSION = "/api/version"; @@ -166,6 +174,8 @@ public OpenAiCompatServer(LlamaModel model, OpenAiServerConfig config) throws IO register("/", this::handleNotFound); register(PATH_HEALTH, this::handleHealth); register(PATH_PROPS, this::handleProps); + register(PATH_METRICS, this::handleMetrics); + register(PATH_SLOTS, this::handleSlots); // Each route is registered under its canonical path and a bare alias (clients disagree on // whether to include the /v1 prefix), so both forms resolve to the same handler. register(PATH_MODELS, this::handleModels); @@ -573,7 +583,7 @@ private void streamAnthropic(HttpExchange exchange, JsonNode openAiRequest, Stri config.getHeartbeatMillis(), TimeUnit.MILLISECONDS); out.writeStrict(translator.begin()); - backend.stream(openAiRequest, chunkJson -> { + backend.stream(withUsageChunk(openAiRequest), chunkJson -> { String events = translator.onChunk(chunkJson); if (!events.isEmpty()) { out.writeStrict(events); @@ -595,6 +605,15 @@ private void streamAnthropic(HttpExchange exchange, JsonNode openAiRequest, Stri } } + /** Ensure protocol translators receive the native stream's trailing usage chunk. */ + private static JsonNode withUsageChunk(JsonNode request) { + ObjectNode copy = request.deepCopy(); + JsonNode existing = copy.path("stream_options"); + ObjectNode streamOptions = existing.isObject() ? 
(ObjectNode) existing : copy.putObject("stream_options"); + streamOptions.put("include_usage", true); + return copy; + } + private static String anthropicError(String message) { ObjectNode root = OBJECT_MAPPER.createObjectNode(); root.put("type", "error"); @@ -652,7 +671,7 @@ private void streamResponses(HttpExchange exchange, JsonNode openAiRequest, Stri config.getHeartbeatMillis(), TimeUnit.MILLISECONDS); out.writeStrict(translator.begin()); - backend.stream(openAiRequest, chunkJson -> { + backend.stream(withUsageChunk(openAiRequest), chunkJson -> { String events = translator.onChunk(chunkJson); if (!events.isEmpty()) { out.writeStrict(events); @@ -705,6 +724,39 @@ private void handleHealth(HttpExchange exchange) throws IOException { } } + private void handleMetrics(HttpExchange exchange) throws IOException { + handleMetricsView(exchange, false); + } + + private void handleSlots(HttpExchange exchange) throws IOException { + handleMetricsView(exchange, true); + } + + private void handleMetricsView(HttpExchange exchange, boolean slotsOnly) throws IOException { + try { + if (!"GET".equalsIgnoreCase(exchange.getRequestMethod())) { + sendError(exchange, HTTP_METHOD_NOT_ALLOWED, ERROR_TYPE_REQUEST, "Only GET is supported"); + return; + } + if (!authorized(exchange)) { + sendError(exchange, HTTP_UNAUTHORIZED, ERROR_TYPE_REQUEST, "Missing or invalid API key"); + return; + } + String metrics = backend.metrics(); + if (slotsOnly) { + // A payload without a "slots" array would serialize as an empty body; answer with an empty array instead. + JsonNode slots = OBJECT_MAPPER.readTree(metrics).path("slots"); + metrics = slots.isArray() ? slots.toString() : "[]"; + } + sendJson(exchange, HTTP_OK, metrics); + } catch (IOException | RuntimeException e) { + LOG.warn("metrics request failed", e); + sendError(exchange, HTTP_SERVER_ERROR, ERROR_TYPE_SERVER, message(e)); + } finally { + exchange.close(); + } + } + private void handleProps(HttpExchange exchange) throws IOException { try { if (!"GET".equalsIgnoreCase(exchange.getRequestMethod())) { diff --git a/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java 
b/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java index fd12a617..a45d6dd4 100644 --- a/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java +++ b/src/main/java/net/ladenthin/llama/server/OpenAiRequestMapper.java @@ -48,6 +48,21 @@ InferenceParameters toInferenceParameters(JsonNode request) { .withMessagesJson(messages.toString()) .withCachePrompt(true); + // Preserve llama.cpp extensions when advanced clients opt into them. + JsonNode cachePrompt = request.path("cache_prompt"); + if (cachePrompt.isBoolean()) { + params = params.withCachePrompt(cachePrompt.asBoolean()); + } + JsonNode cacheReuse = request.path("n_cache_reuse"); + if (cacheReuse.isIntegralNumber()) { + params = params.withCacheReuse(cacheReuse.asInt()); + } + JsonNode slotId = request.path("id_slot"); + // id_slot = -1 is llama.cpp's "use any idle slot" default and withSlotId rejects negatives, so pin only ids >= 0. + if (slotId.isIntegralNumber() && slotId.asInt() >= 0) { + params = params.withSlotId(slotId.asInt()); + } + JsonNode tools = request.path("tools"); if (tools.isArray() && tools.size() > 0) { params = params.withToolsJson(tools.toString()).withUseChatTemplate(true); diff --git a/src/main/java/net/ladenthin/llama/server/ResponsesApiSupport.java b/src/main/java/net/ladenthin/llama/server/ResponsesApiSupport.java index 65abf7a2..5dd6a4e0 100644 --- a/src/main/java/net/ladenthin/llama/server/ResponsesApiSupport.java +++ b/src/main/java/net/ladenthin/llama/server/ResponsesApiSupport.java @@ -178,6 +178,8 @@ static String toResponsesResponse(String openAiCompletionJson, String model, Str usage.put("input_tokens", 0); usage.put("output_tokens", 0); usage.put("total_tokens", 0); + ObjectNode inputTokenDetails = usage.putObject("input_tokens_details"); + inputTokenDetails.put("cached_tokens", 0); try { JsonNode completion = OBJECT_MAPPER.readTree(openAiCompletionJson); JsonNode message = completion.path("choices").path(0).path("message"); @@ -205,6 +207,12 @@ static String toResponsesResponse(String openAiCompletionJson, String model, Str usage.put("input_tokens", in); usage.put("output_tokens", out); 
usage.put("total_tokens", in + out); + inputTokenDetails.put( + "cached_tokens", + openAiUsage + .path("prompt_tokens_details") + .path("cached_tokens") + .asInt(0)); } } catch (IOException e) { // Defensive: an unexpected body still yields a valid, empty completed response. diff --git a/src/main/java/net/ladenthin/llama/server/ResponsesStreamTranslator.java b/src/main/java/net/ladenthin/llama/server/ResponsesStreamTranslator.java index 17d928f7..eb50e8f0 100644 --- a/src/main/java/net/ladenthin/llama/server/ResponsesStreamTranslator.java +++ b/src/main/java/net/ladenthin/llama/server/ResponsesStreamTranslator.java @@ -39,6 +39,9 @@ final class ResponsesStreamTranslator { private boolean messageOpen; private int nextOutputIndex; private int messageOutputIndex = -1; + private int inputTokens; + private int outputTokens; + private int cachedTokens; ResponsesStreamTranslator(String model, String responseId) { this.model = model; @@ -70,6 +73,7 @@ String onChunk(String openAiChunkJson) { try { JsonNode chunk = OBJECT_MAPPER.readTree(openAiChunkJson); accumulator.accept(chunk); + captureUsage(chunk.path("usage")); JsonNode content = chunk.path("choices").path(0).path("delta").path("content"); if (content.isTextual() && !content.asText().isEmpty()) { if (!messageOpen) { @@ -142,11 +146,25 @@ String end() { ObjectNode completed = ResponsesApiSupport.dataObject(); ObjectNode response = ResponsesApiSupport.newResponseShell(model, responseId, "completed"); response.set("output", output); + ObjectNode usage = response.putObject("usage"); + usage.put("input_tokens", inputTokens); + usage.put("output_tokens", outputTokens); + usage.put("total_tokens", inputTokens + outputTokens); + usage.putObject("input_tokens_details").put("cached_tokens", cachedTokens); completed.set("response", response); out.append(ResponsesApiSupport.sseEvent("response.completed", sequence++, completed)); return out.toString(); } + private void captureUsage(JsonNode usage) { + if (!usage.isObject()) { + 
return; + } + inputTokens = usage.path("prompt_tokens").asInt(0); + outputTokens = usage.path("completion_tokens").asInt(0); + cachedTokens = usage.path("prompt_tokens_details").path("cached_tokens").asInt(0); + } + private ObjectNode messageItemShell() { ObjectNode item = OBJECT_MAPPER.createObjectNode(); item.put("type", "message"); diff --git a/src/main/java/net/ladenthin/llama/value/ServerMetrics.java b/src/main/java/net/ladenthin/llama/value/ServerMetrics.java index e07afb6e..b8870e09 100644 --- a/src/main/java/net/ladenthin/llama/value/ServerMetrics.java +++ b/src/main/java/net/ladenthin/llama/value/ServerMetrics.java @@ -5,6 +5,9 @@ package net.ladenthin.llama.value; import com.fasterxml.jackson.databind.JsonNode; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import lombok.EqualsAndHashCode; /** @@ -103,9 +106,9 @@ public int getTokensMax() { } /** - * Cumulative server-wide token usage since startup. Prompt tokens come from - * {@code n_prompt_tokens_processed_total} and completion tokens from - * {@code n_tokens_predicted_total}. + * Cumulative server-wide compute counters since startup. The prompt value is + * the number actually evaluated by the model, excluding cache hits; upstream + * does not currently expose a cumulative logical or cached prompt total. * * @return cumulative {@link Usage} across all completions since server start */ @@ -115,6 +118,24 @@ public Usage getCumulativeUsage() { node.path("n_tokens_predicted_total").asLong(0L)); } + /** + * Returns the cumulative number of prompt tokens actually evaluated since startup. + * + * @return processed prompt-token total, excluding cache hits + */ + public long getCumulativeProcessedPromptTokens() { + return node.path("n_prompt_tokens_processed_total").asLong(0L); + } + + /** + * Returns the cumulative number of generated tokens since startup. 
+ * + * @return generated-token total + */ + public long getCumulativeGeneratedTokens() { + return node.path("n_tokens_predicted_total").asLong(0L); + } + /** * Usage counters from the most recent measurement window (current bucket) — * {@code n_prompt_tokens_processed} and {@code n_tokens_predicted}. @@ -127,6 +148,15 @@ public Usage getWindowUsage() { node.path("n_tokens_predicted").asLong(0L)); } + /** + * Returns prompt tokens actually evaluated in the current metrics window. + * + * @return processed prompt-token count, excluding cache hits + */ + public long getWindowProcessedPromptTokens() { + return node.path("n_prompt_tokens_processed").asLong(0L); + } + /** * Cumulative throughput derived from the totals fields. Returns {@code 0.0} for * any rate where the corresponding ms total is zero. @@ -152,6 +182,23 @@ public JsonNode getSlots() { return node.path("slots"); } + /** + * Typed slot metrics, including per-slot processed and cached prompt counts. + * + * @return immutable slot list, empty when the native response has no slots array + */ + public List getSlotMetrics() { + JsonNode slots = node.path("slots"); + if (!slots.isArray()) { + return Collections.emptyList(); + } + List result = new ArrayList<>(); + for (JsonNode slot : slots) { + result.add(new SlotMetrics(slot)); + } + return Collections.unmodifiableList(result); + } + /** * Raw passthrough escape hatch. 
* @return underlying JSON for direct access to fields not yet exposed by typed getters diff --git a/src/main/java/net/ladenthin/llama/value/SlotMetrics.java b/src/main/java/net/ladenthin/llama/value/SlotMetrics.java new file mode 100644 index 00000000..4ebfc8e4 --- /dev/null +++ b/src/main/java/net/ladenthin/llama/value/SlotMetrics.java @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: 2026 Bernard Ladenthin +// +// SPDX-License-Identifier: MIT + +package net.ladenthin.llama.value; + +import com.fasterxml.jackson.databind.JsonNode; +import lombok.EqualsAndHashCode; + +/** Typed view of one entry in llama.cpp's server-metrics {@code slots} array. */ +@EqualsAndHashCode +public final class SlotMetrics { + + private final JsonNode node; + + /** + * Wrap a raw slot metrics object. + * + * @param node slot JSON emitted by llama.cpp; kept by reference, not copied or validated + */ + public SlotMetrics(JsonNode node) { + this.node = node; + } + + /** + * Returns the zero-based server slot identifier. + * @return slot identifier, or {@code -1} when the {@code id} field is absent + */ + public int getId() { + return node.path("id").asInt(-1); + } + + /** + * Returns the context capacity assigned to this slot. + * @return context capacity read from {@code n_ctx}, or {@code 0} when absent + */ + public int getContextSize() { + return node.path("n_ctx").asInt(0); + } + + /** + * Reports whether this slot is currently processing a task. + * @return {@code true} while processing; {@code false} when {@code is_processing} is absent + */ + public boolean isProcessing() { + return node.path("is_processing").asBoolean(false); + } + + /** + * Returns the logical prompt-token count for the current or most recent task. + * @return logical prompt-token count from {@code n_prompt_tokens}, or {@code 0} when absent + */ + public long getPromptTokens() { + return node.path("n_prompt_tokens").asLong(0L); + } + + /** + * Returns prompt tokens evaluated by the model for the current or most recent task. + * @return evaluated prompt-token count from {@code n_prompt_tokens_processed}, or {@code 0} when absent + */ + public long getProcessedPromptTokens() { + return node.path("n_prompt_tokens_processed").asLong(0L); + } + + /** + * Returns prompt tokens reused from KV cache for the current or most recent task. 
+ * @return cached prompt-token count from {@code n_prompt_tokens_cache}, or {@code 0} when absent + */ + public long getCachedPromptTokens() { + return node.path("n_prompt_tokens_cache").asLong(0L); + } + + /** + * Returns tokens decoded for the current or most recent task. + * @return decoded-token count from {@code next_token.n_decoded}, or {@code 0} when absent + */ + public long getDecodedTokens() { + return node.path("next_token").path("n_decoded").asLong(0L); + } + + /** + * Returns tokens remaining under the current generation limit. + * @return remaining-token count from {@code next_token.n_remain}, or {@code 0} when absent + */ + public long getRemainingTokens() { + return node.path("next_token").path("n_remain").asLong(0L); + } + + /** + * Returns raw slot JSON for fields not represented by typed accessors. + * @return raw slot JSON; the same node instance this view wraps, not a copy + */ + public JsonNode asJson() { + return node; + } + + @Override + public String toString() { + return node.toString(); + } +} diff --git a/src/main/java/net/ladenthin/llama/value/Usage.java b/src/main/java/net/ladenthin/llama/value/Usage.java index 7921634b..ce0b9a74 100644 --- a/src/main/java/net/ladenthin/llama/value/Usage.java +++ b/src/main/java/net/ladenthin/llama/value/Usage.java @@ -10,11 +10,11 @@ /** * Token-usage counters, modeled after the OpenAI / Llama Stack {@code usage} block. *

    <p> - * Used by {@link ServerMetrics} to expose cumulative server-wide token totals and - * (in a future {@code ChatResponse}) per-completion counts. + * Used by {@link ChatResponse}, {@link CompletionResult}, and {@link ServerMetrics} + * to expose per-request and cumulative token counts. * </p> * - * <p>Value equality / {@code toString} are generated by Lombok over the two stored + * <p>Value equality / {@code toString} are generated by Lombok over the stored * counters. The derived {@link #getTotalTokens()} sum is included in {@code toString} * via {@link ToString.Include @ToString.Include} so the rendered output retains the * convenience field that the handwritten version exposed.</p>

    @@ -25,6 +25,7 @@ public final class Usage { private final long promptTokens; private final long completionTokens; + private final long cachedTokens; /** * Construct a usage record. @@ -33,8 +34,20 @@ public final class Usage { * @param completionTokens number of completion tokens */ public Usage(long promptTokens, long completionTokens) { + this(promptTokens, completionTokens, 0L); + } + + /** + * Construct a usage record including the reused subset of prompt tokens; values are stored as given, without validation. + * + * @param promptTokens logical prompt token count, including cached tokens + * @param completionTokens number of completion tokens + * @param cachedTokens prompt tokens served from KV cache without re-evaluation + */ + public Usage(long promptTokens, long completionTokens, long cachedTokens) { this.promptTokens = promptTokens; this.completionTokens = completionTokens; + this.cachedTokens = cachedTokens; } /** @@ -53,6 +66,22 @@ public long getCompletionTokens() { return completionTokens; } + /** + * Prompt tokens reused from the KV cache. + * @return cached subset of {@link #getPromptTokens()}; {@code 0} when built via the two-argument constructor + */ + public long getCachedTokens() { + return cachedTokens; + } + + /** + * Prompt tokens that required model evaluation. + * @return prompt tokens minus the cached subset, clamped to zero when the cached count exceeds the prompt count (malformed upstream data) + */ + public long getProcessedPromptTokens() { + return Math.max(0L, promptTokens - cachedTokens); + } + /** * Convenience sum of the prompt and completion counts. 
* @return sum of prompt and completion tokens diff --git a/src/test/cpp/test_jni_helpers.cpp b/src/test/cpp/test_jni_helpers.cpp index 211323bc..cafc924d 100644 --- a/src/test/cpp/test_jni_helpers.cpp +++ b/src/test/cpp/test_jni_helpers.cpp @@ -14,7 +14,8 @@ // get_jllama_context_impl, require_json_field_impl, jint_array_to_tokens_impl // // Layer B tests (need upstream server headers + mock JNIEnv): -// configure_multimodal_task_impl, json_to_jstring_impl, results_to_jstring_impl, +// configure_multimodal_task_impl, configure_task_slot_impl, +// json_to_jstring_impl, results_to_jstring_impl, // embedding_to_jfloat_array_impl, tokens_to_jint_array_impl // // JNIEnv is mocked via a zero-filled JNINativeInterface_ table with only the @@ -617,3 +618,15 @@ TEST(ConfigureMultimodalTask, NonStringPromptThrows) { EXPECT_THROW((void)configure_multimodal_task_impl(task, true, {{"prompt", json::array({1, 2})}}, {{0x01}}), std::invalid_argument); } + +TEST(ConfigureTaskSlot, MissingIdUsesAutomaticSelection) { + server_task task(SERVER_TASK_TYPE_COMPLETION); + configure_task_slot_impl(task, json::object()); + EXPECT_EQ(task.id_slot, -1); +} + +TEST(ConfigureTaskSlot, ExplicitIdPinsTask) { + server_task task(SERVER_TASK_TYPE_COMPLETION); + configure_task_slot_impl(task, {{"id_slot", 3}}); + EXPECT_EQ(task.id_slot, 3); +} diff --git a/src/test/java/net/ladenthin/llama/json/ChatResponseParserTest.java b/src/test/java/net/ladenthin/llama/json/ChatResponseParserTest.java index 0a7a875c..147b05b9 100644 --- a/src/test/java/net/ladenthin/llama/json/ChatResponseParserTest.java +++ b/src/test/java/net/ladenthin/llama/json/ChatResponseParserTest.java @@ -227,7 +227,8 @@ public void testParseResponse_fullResponse() { String json = "{\"id\":\"chatcmpl-abc\",\"choices\":[{\"index\":0," + "\"message\":{\"role\":\"assistant\",\"content\":\"Hi there\"}," + "\"finish_reason\":\"stop\"}]," - + "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}}"; + + 
"\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3," + + "\"prompt_tokens_details\":{\"cached_tokens\":5}}}"; ChatResponse r = parser.parseResponse(json); assertEquals("chatcmpl-abc", r.getId()); @@ -239,6 +240,8 @@ public void testParseResponse_fullResponse() { assertEquals("stop", c.getFinishReason()); assertEquals(7L, r.getUsage().getPromptTokens()); assertEquals(3L, r.getUsage().getCompletionTokens()); + assertEquals(5L, r.getUsage().getCachedTokens()); + assertEquals(2L, r.getUsage().getProcessedPromptTokens()); assertEquals(json, r.getRawJson()); } diff --git a/src/test/java/net/ladenthin/llama/json/CompletionResponseParserTest.java b/src/test/java/net/ladenthin/llama/json/CompletionResponseParserTest.java index 1d7e7149..fef3f2f9 100644 --- a/src/test/java/net/ladenthin/llama/json/CompletionResponseParserTest.java +++ b/src/test/java/net/ladenthin/llama/json/CompletionResponseParserTest.java @@ -295,6 +295,7 @@ public void testParseLogprobs_emptyArray_returnsEmptyList() throws Exception { public void testParseCompletionResult_fullResult() { String json = "{\"content\":\"final answer\"," + "\"tokens_evaluated\":11,\"tokens_predicted\":4," + + "\"timings\":{\"cache_n\":8}," + "\"stop_type\":\"eos\"," + "\"completion_probabilities\":[{\"token\":\"final\",\"id\":1,\"prob\":0.7}]}"; CompletionResult r = parser.parseCompletionResult(json); @@ -302,6 +303,8 @@ public void testParseCompletionResult_fullResult() { assertEquals("final answer", r.getText()); assertEquals(11L, r.getUsage().getPromptTokens()); assertEquals(4L, r.getUsage().getCompletionTokens()); + assertEquals(8L, r.getUsage().getCachedTokens()); + assertEquals(3L, r.getUsage().getProcessedPromptTokens()); assertEquals(StopReason.EOS, r.getStopReason()); assertEquals(1, r.getLogprobs().size()); assertEquals("final", r.getLogprobs().get(0).getToken()); diff --git a/src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java 
b/src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java index df0d4cef..dbcf8257 100644 --- a/src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/InferenceParametersTest.java @@ -73,6 +73,20 @@ public void testSetNPredict() { assertThat(params.parameters.get("n_predict"), is("42")); } + @Test + public void testSetCacheReuse() { + InferenceParameters params = InferenceParameters.empty().withCacheReuse(256); + assertThat(params.parameters.get("n_cache_reuse"), is("256")); + assertThrows(IllegalArgumentException.class, () -> params.withCacheReuse(-1)); + } + + @Test + public void testSetSlotId() { + InferenceParameters params = InferenceParameters.empty().withSlotId(2); + assertThat(params.parameters.get("id_slot"), is("2")); + assertThrows(IllegalArgumentException.class, () -> params.withSlotId(-1)); + } + @Test public void testSetParallelToolCalls() { InferenceParameters params = new InferenceParameters("").withParallelToolCalls(false); diff --git a/src/test/java/net/ladenthin/llama/parameters/JsonEndpointParametersTest.java b/src/test/java/net/ladenthin/llama/parameters/JsonEndpointParametersTest.java index 3972d36e..c311011f 100644 --- a/src/test/java/net/ladenthin/llama/parameters/JsonEndpointParametersTest.java +++ b/src/test/java/net/ladenthin/llama/parameters/JsonEndpointParametersTest.java @@ -8,6 +8,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; @@ -15,6 +16,7 @@ import net.ladenthin.llama.ClaudeGenerated; import net.ladenthin.llama.LlamaModel; import net.ladenthin.llama.TestConstants; +import net.ladenthin.llama.value.CompletionResult; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import 
org.junit.jupiter.api.BeforeAll; @@ -60,6 +62,7 @@ public static void setup() { model = new LlamaModel(new ModelParameters() .setCtxSize(256) .setModel(TestConstants.MODEL_PATH) + .setParallel(2) .setGpuLayers(gpuLayers) .setFit(false)); } @@ -245,11 +248,24 @@ public void testNDiscardAccepted() { @Test public void testIdSlotSelection() { - String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + ",\"id_slot\":0}"; + String json = "{\"prompt\":\"" + PROMPT + "\",\"n_predict\":" + N_PREDICT + DETERMINISTIC + ",\"id_slot\":1}"; String result = model.handleCompletions(json); assertThat(result, is(notNullValue())); assertThat(result, containsString("\"content\"")); - assertThat("Response should contain 'id_slot' field", result, containsString("\"id_slot\"")); + assertThat("Response must identify the requested slot", result, containsString("\"id_slot\":1")); + } + + @Test + public void testTypedSlotSelection() { + InferenceParameters params = InferenceParameters.of(PROMPT) + .withSlotId(1) + .withNPredict(N_PREDICT) + .withTemperature(0.0f) + .withSeed(42); + model.completeWithStats(params); // warm the selected slot's prompt cache + CompletionResult result = model.completeWithStats(params); + assertThat(result.getRawJson(), containsString("\"id_slot\":1")); + assertThat("repeated prompt must expose a cache hit", result.getUsage().getCachedTokens(), greaterThan(0L)); } // ------------------------------------------------------------------------- diff --git a/src/test/java/net/ladenthin/llama/server/AnthropicApiSupportTest.java b/src/test/java/net/ladenthin/llama/server/AnthropicApiSupportTest.java index 29f79ed3..dd76e468 100644 --- a/src/test/java/net/ladenthin/llama/server/AnthropicApiSupportTest.java +++ b/src/test/java/net/ladenthin/llama/server/AnthropicApiSupportTest.java @@ -125,7 +125,8 @@ public void responseEmitsTextAndToolUseBlocksAndStopReason() throws IOException String openAi = 
"{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"hi\"," + "\"tool_calls\":[{\"id\":\"c1\",\"type\":\"function\",\"function\":{\"name\":\"f\"," + "\"arguments\":\"{\\\"a\\\":1}\"}}]},\"finish_reason\":\"tool_calls\"}]," - + "\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2}}"; + + "\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2," + + "\"prompt_tokens_details\":{\"cached_tokens\":3}}}"; JsonNode out = read(AnthropicApiSupport.toAnthropicResponse(openAi, "m")); assertThat(out.path("type").asText(), is("message")); assertThat(out.path("role").asText(), is("assistant")); @@ -137,7 +138,8 @@ public void responseEmitsTextAndToolUseBlocksAndStopReason() throws IOException assertThat(toolUse.path("input").path("a").asInt(), is(1)); // finish_reason "tool_calls" -> stop_reason "tool_use". assertThat(out.path("stop_reason").asText(), is("tool_use")); - assertThat(out.path("usage").path("input_tokens").asInt(), is(5)); + assertThat(out.path("usage").path("input_tokens").asInt(), is(2)); + assertThat(out.path("usage").path("cache_read_input_tokens").asInt(), is(3)); assertThat(out.path("usage").path("output_tokens").asInt(), is(2)); } diff --git a/src/test/java/net/ladenthin/llama/server/AnthropicStreamTranslatorTest.java b/src/test/java/net/ladenthin/llama/server/AnthropicStreamTranslatorTest.java index e77f271d..28846467 100644 --- a/src/test/java/net/ladenthin/llama/server/AnthropicStreamTranslatorTest.java +++ b/src/test/java/net/ladenthin/llama/server/AnthropicStreamTranslatorTest.java @@ -39,11 +39,16 @@ public void firstTextDeltaOpensBlockThenSubsequentDeltasAppend() { public void endClosesTextBlockAndEmitsStopReasonAndMessageStop() { AnthropicStreamTranslator translator = new AnthropicStreamTranslator("msg_1", "m"); translator.onChunk("{\"choices\":[{\"delta\":{\"content\":\"hi\"},\"finish_reason\":\"stop\"}]}"); + translator.onChunk("{\"choices\":[],\"usage\":{\"prompt_tokens\":12,\"completion_tokens\":3," + + 
"\"prompt_tokens_details\":{\"cached_tokens\":8}}}"); String end = translator.end(); assertThat(end, containsString("event: content_block_stop")); assertThat(end, containsString("event: message_delta")); assertThat(end, containsString("\"stop_reason\":\"end_turn\"")); assertThat(end, containsString("event: message_stop")); + assertThat(end, containsString("\"input_tokens\":4")); + assertThat(end, containsString("\"output_tokens\":3")); + assertThat(end, containsString("\"cache_read_input_tokens\":8")); } @Test diff --git a/src/test/java/net/ladenthin/llama/server/OpenAiCompatServerHttpTest.java b/src/test/java/net/ladenthin/llama/server/OpenAiCompatServerHttpTest.java index e74652ff..3a18a12f 100644 --- a/src/test/java/net/ladenthin/llama/server/OpenAiCompatServerHttpTest.java +++ b/src/test/java/net/ladenthin/llama/server/OpenAiCompatServerHttpTest.java @@ -88,6 +88,18 @@ public void embeddingsRouteReturnsEmbeddingList() throws IOException { } } + @Test + public void metricsAndSlotsExposeCacheCounters() throws IOException { + try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeBackend(), config()).start()) { + Response metrics = get(server.getPort(), "/metrics", ""); + assertThat(metrics.code, is(200)); + assertThat(metrics.body, containsString("n_prompt_tokens_cache")); + Response slots = get(server.getPort(), "/slots", ""); + assertThat(slots.code, is(200)); + assertThat(slots.body, containsString("\"id\":0")); + } + } + @Test public void infillRouteReturnsContent() throws IOException { try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeBackend(), config()).start()) { @@ -177,6 +189,7 @@ public void anthropicMessagesNonStreamingReturnsMessage() throws IOException { assertThat(response.code, is(200)); assertThat(response.body, containsString("\"type\":\"message\"")); assertThat(response.body, containsString("hello")); // FakeBackend.complete text + assertThat(response.body, containsString("\"cache_read_input_tokens\":8")); } } @@ -190,6 
+203,7 @@ public void anthropicMessagesStreamingEmitsEventSequence() throws IOException { assertThat(response.body, containsString("event: message_start")); assertThat(response.body, containsString("event: content_block_delta")); assertThat(response.body, containsString("event: message_stop")); + assertThat(response.body, containsString("\"cache_read_input_tokens\":8")); } } @@ -202,6 +216,7 @@ public void responsesNonStreamingReturnsResponseObject() throws IOException { assertThat(response.body, containsString("\"object\":\"response\"")); assertThat(response.body, containsString("output_text")); assertThat(response.body, containsString("hello")); + assertThat(response.body, containsString("\"cached_tokens\":8")); } } @@ -214,6 +229,7 @@ public void responsesStreamingEmitsEventSequence() throws IOException { assertThat(response.body, containsString("event: response.created")); assertThat(response.body, containsString("event: response.output_text.delta")); assertThat(response.body, containsString("event: response.completed")); + assertThat(response.body, containsString("\"cached_tokens\":8")); } } @@ -376,12 +392,38 @@ public void healthEndpointIsUnauthenticated() throws IOException { } } + @Test + public void metricsAndSlotsRequireApiKeyWhenConfigured() throws IOException { + OpenAiServerConfig cfg = OpenAiServerConfig.builder() + .host("127.0.0.1") + .port(0) + .apiKey("secret") + .build(); + try (OpenAiCompatServer server = new OpenAiCompatServer(new FakeBackend(), cfg).start()) { + int port = server.getPort(); + // /metrics and /slots expose slot state and token counters, so they must be gated. 
+ assertThat(get(port, "/metrics", "").code, is(401)); + assertThat(get(port, "/metrics", "Bearer wrong").code, is(401)); + assertThat(get(port, "/metrics", "Bearer secret").code, is(200)); + assertThat(get(port, "/slots", "").code, is(401)); + assertThat(get(port, "/slots", "Bearer wrong").code, is(401)); + assertThat(get(port, "/slots", "Bearer secret").code, is(200)); + } + } + /** Deterministic backend that returns canned OpenAI shapes for every operation. */ static final class FakeBackend implements OpenAiBackend { + @Override + public String metrics() { + return "{\"idle\":1,\"slots\":[{\"id\":0,\"n_prompt_tokens_cache\":8}]}"; + } + @Override public String complete(JsonNode request) { return "{\"object\":\"chat.completion\",\"choices\":[{\"index\":0," - + "\"message\":{\"role\":\"assistant\",\"content\":\"hello\"}}]}"; + + "\"message\":{\"role\":\"assistant\",\"content\":\"hello\"}}]," + + "\"usage\":{\"prompt_tokens\":12,\"completion_tokens\":3," + + "\"prompt_tokens_details\":{\"cached_tokens\":8}}}"; } @Override @@ -389,6 +431,11 @@ public void stream(JsonNode request, ChunkSink sink) throws IOException { sink.accept("{\"object\":\"chat.completion.chunk\",\"choices\":[{\"delta\":{\"content\":\"he\"}}]}"); sink.accept("{\"object\":\"chat.completion.chunk\"," + "\"choices\":[{\"delta\":{\"content\":\"llo\"},\"finish_reason\":\"stop\"}]}"); + if (request.path("stream_options").path("include_usage").asBoolean(false)) { + sink.accept("{\"object\":\"chat.completion.chunk\",\"choices\":[]," + + "\"usage\":{\"prompt_tokens\":12,\"completion_tokens\":3," + + "\"prompt_tokens_details\":{\"cached_tokens\":8}}}"); + } } @Override diff --git a/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java b/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java index ee931a7e..f0506e63 100644 --- a/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java +++ b/src/test/java/net/ladenthin/llama/server/OpenAiRequestMapperTest.java @@ 
-154,6 +154,15 @@ public void cachePromptDefaultedTrue() throws IOException { assertThat(out.path("cache_prompt").asBoolean(), is(true)); } + @Test + public void cacheAndSlotExtensionsForwarded() throws IOException { + JsonNode out = mapAndSerialize("{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]," + + "\"cache_prompt\":false,\"n_cache_reuse\":128,\"id_slot\":2}"); + assertThat(out.path("cache_prompt").asBoolean(), is(false)); + assertThat(out.path("n_cache_reuse").asInt(), is(128)); + assertThat(out.path("id_slot").asInt(), is(2)); + } + @Test public void unknownFieldsIgnored() throws IOException { JsonNode out = mapAndSerialize( diff --git a/src/test/java/net/ladenthin/llama/server/ResponsesApiSupportTest.java b/src/test/java/net/ladenthin/llama/server/ResponsesApiSupportTest.java index 213b711c..75e82bd6 100644 --- a/src/test/java/net/ladenthin/llama/server/ResponsesApiSupportTest.java +++ b/src/test/java/net/ladenthin/llama/server/ResponsesApiSupportTest.java @@ -96,7 +96,8 @@ public void requestMapsAssistantMessageItemAndSkipsNonFunctionTools() throws IOE @Test public void responseWrapsOutputMessageWithOutputTextAndUsage() throws IOException { String openAi = "{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"hello\"}," - + "\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":4,\"completion_tokens\":1}}"; + + "\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":4,\"completion_tokens\":1," + + "\"prompt_tokens_details\":{\"cached_tokens\":3}}}"; JsonNode out = read(ResponsesApiSupport.toResponsesResponse(openAi, "m", "resp_1")); assertThat(out.path("object").asText(), is("response")); assertThat(out.path("status").asText(), is("completed")); @@ -106,6 +107,12 @@ public void responseWrapsOutputMessageWithOutputTextAndUsage() throws IOExceptio assertThat(messageItem.path("content").get(0).path("text").asText(), is("hello")); assertThat(out.path("usage").path("input_tokens").asInt(), is(4)); 
assertThat(out.path("usage").path("total_tokens").asInt(), is(5)); + assertThat( + out.path("usage") + .path("input_tokens_details") + .path("cached_tokens") + .asInt(), + is(3)); } @Test diff --git a/src/test/java/net/ladenthin/llama/server/ResponsesStreamTranslatorTest.java b/src/test/java/net/ladenthin/llama/server/ResponsesStreamTranslatorTest.java index d4288370..b104f56f 100644 --- a/src/test/java/net/ladenthin/llama/server/ResponsesStreamTranslatorTest.java +++ b/src/test/java/net/ladenthin/llama/server/ResponsesStreamTranslatorTest.java @@ -42,12 +42,17 @@ public void firstTextDeltaOpensItemAndPartThenStreamsDelta() { public void endEmitsDoneEventsAndCompleted() { ResponsesStreamTranslator translator = new ResponsesStreamTranslator("m", "resp_1"); translator.onChunk("{\"choices\":[{\"delta\":{\"content\":\"hi\"},\"finish_reason\":\"stop\"}]}"); + translator.onChunk("{\"choices\":[],\"usage\":{\"prompt_tokens\":12,\"completion_tokens\":3," + + "\"prompt_tokens_details\":{\"cached_tokens\":8}}}"); String end = translator.end(); assertThat(end, containsString("event: response.output_text.done")); assertThat(end, containsString("event: response.content_part.done")); assertThat(end, containsString("event: response.output_item.done")); assertThat(end, containsString("event: response.completed")); assertThat(end, containsString("\"text\":\"hi\"")); + assertThat(end, containsString("\"input_tokens\":12")); + assertThat(end, containsString("\"output_tokens\":3")); + assertThat(end, containsString("\"cached_tokens\":8")); } @Test diff --git a/src/test/java/net/ladenthin/llama/value/ServerMetricsTest.java b/src/test/java/net/ladenthin/llama/value/ServerMetricsTest.java index 6d5989af..10670e34 100644 --- a/src/test/java/net/ladenthin/llama/value/ServerMetricsTest.java +++ b/src/test/java/net/ladenthin/llama/value/ServerMetricsTest.java @@ -28,7 +28,10 @@ private ServerMetrics parse(String json) throws Exception { + 
"\"n_prompt_tokens_processed\":10,\"t_prompt_processing\":5," + "\"n_tokens_predicted\":20,\"t_tokens_generation\":8," + "\"n_decode_total\":300,\"n_busy_slots_total\":4,\"n_tokens_max\":4096," - + "\"slots\":[{\"id\":0},{\"id\":1}]}"; + + "\"slots\":[{\"id\":0,\"n_ctx\":4096,\"is_processing\":true," + + "\"n_prompt_tokens\":100,\"n_prompt_tokens_processed\":20," + + "\"n_prompt_tokens_cache\":80,\"next_token\":{\"n_decoded\":7,\"n_remain\":9}}," + + "{\"id\":1}]}"; @Test public void slotCountsAndTimestamp() throws Exception { @@ -54,6 +57,8 @@ public void cumulativeUsage() throws Exception { assertEquals(100L, u.getPromptTokens()); assertEquals(200L, u.getCompletionTokens()); assertEquals(300L, u.getTotalTokens()); + assertEquals(100L, m.getCumulativeProcessedPromptTokens()); + assertEquals(200L, m.getCumulativeGeneratedTokens()); } @Test @@ -62,6 +67,7 @@ public void windowUsage() throws Exception { Usage u = m.getWindowUsage(); assertEquals(10L, u.getPromptTokens()); assertEquals(20L, u.getCompletionTokens()); + assertEquals(10L, m.getWindowProcessedPromptTokens()); } @Test @@ -89,6 +95,23 @@ public void slotsArrayExposed() throws Exception { assertEquals(2, m.getSlots().size()); } + @Test + public void typedSlotMetricsExposeCacheCounts() throws Exception { + ServerMetrics m = parse(SAMPLE); + assertEquals(2, m.getSlotMetrics().size()); + SlotMetrics slot = m.getSlotMetrics().get(0); + assertEquals(0, slot.getId()); + assertEquals(4096, slot.getContextSize()); + assertTrue(slot.isProcessing()); + assertEquals(100L, slot.getPromptTokens()); + assertEquals(20L, slot.getProcessedPromptTokens()); + assertEquals(80L, slot.getCachedPromptTokens()); + assertEquals(7L, slot.getDecodedTokens()); + assertEquals(9L, slot.getRemainingTokens()); + assertEquals(0, slot.asJson().path("id").asInt()); + assertTrue(slot.toString().contains("n_prompt_tokens_cache")); + } + @Test public void missingFieldsDefaultToZero() throws Exception { ServerMetrics m = parse("{}"); diff --git 
a/src/test/java/net/ladenthin/llama/value/UsageTest.java b/src/test/java/net/ladenthin/llama/value/UsageTest.java index fd430809..8785e595 100644 --- a/src/test/java/net/ladenthin/llama/value/UsageTest.java +++ b/src/test/java/net/ladenthin/llama/value/UsageTest.java @@ -27,6 +27,14 @@ public void zeroIsZero() { assertEquals(0, u.getTotalTokens()); } + @Test + public void cachedTokensExposeProcessedSubset() { + Usage u = new Usage(10, 7, 6); + assertEquals(6, u.getCachedTokens()); + assertEquals(4, u.getProcessedPromptTokens()); + assertEquals(17, u.getTotalTokens()); + } + @Test public void equalsAndHashCode() { assertEquals(new Usage(3, 4), new Usage(3, 4));