diff --git a/.gitignore b/.gitignore index 03591943..1e77e82a 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,4 @@ playwright-report/ *.sln *.sw? -# Other .env -*.db diff --git a/backend/omni/.gitignore b/backend/omni/.gitignore new file mode 100644 index 00000000..ddf686d8 --- /dev/null +++ b/backend/omni/.gitignore @@ -0,0 +1,4 @@ +config*.yaml +!config.sample.yaml +*.db +.env diff --git a/backend/omni/CHANGELOG.md b/backend/omni/CHANGELOG.md index b57be4e8..49a425a3 100644 --- a/backend/omni/CHANGELOG.md +++ b/backend/omni/CHANGELOG.md @@ -10,10 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Config can include other config files +- OpenAPI tool registry now supports `tool_servers` and auto-registers all `operationId` endpoints from each OpenAPI spec as tools ### Chnaged - Config to be used can be defined via CONFIG_PATH environment variable +- Refactored tools contract to use request-aware registry methods: `get_tools(request)` and `run_tool(request, params)` +- Tool definitions now use OpenAI Responses API `FunctionToolParam` shape directly +- `GET /api/tools` now returns the raw tool-definition list (no `{ "tools": ... }` envelope) ## [0.0.1] - 2026-02-12 diff --git a/backend/omni/config-simple-dev.yaml b/backend/omni/config.sample.yaml similarity index 72% rename from backend/omni/config-simple-dev.yaml rename to backend/omni/config.sample.yaml index 33d7f57c..72a583bc 100644 --- a/backend/omni/config-simple-dev.yaml +++ b/backend/omni/config.sample.yaml @@ -2,6 +2,7 @@ includes: - path: ./src/modai/default_config.yaml modules: + # Deactivate authentication completely for development purposes. auth_oidc: collision_strategy: drop session: @@ -11,8 +12,3 @@ modules: user_id: "dev-user" email: "dev@example.com" name: "Dev User" - openapi_tool_registry: - config: - tools: - - url: http://localhost:8001/roll - method: POST diff --git a/backend/omni/docs/architecture/tools.md b/backend/omni/docs/architecture/tools.md index 40523699..48831c8e 100644 --- a/backend/omni/docs/architecture/tools.md +++ b/backend/omni/docs/architecture/tools.md @@ -1,12 +1,13 @@ # Tools Architecture ## 1. Overview -- **Architecture Style**: Microservice-based tool system with a generic tool abstraction and a web layer that serves tools in OpenAI format +- **Architecture Style**: Microservice-based tool system with OpenAI-formatted tool definitions - **Design Principles**: - - Tools are LLM-agnostic — a tool has a definition (name, description, parameters) and a run capability; neither is tied to any LLM API - - OpenAPI as one registry implementation — the OpenAPI registry fetches specs from microservices and creates tool instances that handle HTTP invocation internally - - Registry encapsulates invocation — callers do not need to know about URLs, methods, or HTTP; they just run a tool with parameters - - Web layer transforms definitions — the Tools Web Module converts tool definitions to OpenAI function-calling format + - Tool definitions are OpenAI-compatible at the registry boundary (`FunctionToolParam`) + - OpenAPI as one registry implementation — the OpenAPI registry fetches specs from microservices and executes operations over HTTP + - Registry encapsulates invocation — callers do not need to know about URLs, methods, or HTTP; they call `run_tool(request, params)` + - `ToolRegistryModule` is a plain interface (non-web) + - Central aggregation — `ToolsRouterModule` exposes `GET /api/tools`, returns tools without renaming, and dispatches `run_tool` by tool name - Extensible — new registry implementations can plug in any tool backend (HTTP, gRPC, in-process functions, etc.) without changing callers - **Quality Attributes**: Decoupled, language-agnostic, independently deployable, discoverable @@ -118,151 +119,59 @@ The registry will build a tool definition that includes `user_id` and `order_id` ```mermaid flowchart TD - FE[Frontend] -->|GET /api/tools| TW[Tools Web Module] - TW -->|get_tools| TR[Tool Registry Module] - TR -->|GET /openapi.json| TS1[Tool Service A] - TR -->|GET /openapi.json| TS2[Tool Service B] + FE[Frontend] -->|GET /api/tools| CTR[ToolsRouterModule] + CTR -->|aggregate get_tools| TR1[OpenAPIToolRegistry] + TR1 -->|GET /openapi.json| TS1[Tool Service A] + TR1 -->|GET /openapi.json| TS2[Tool Service B] FE -->|POST /api/responses with tool names| CR[Chat Router] CR --> CA[Chat Agent Module] - CA -->|get_tool_by_name| TR - CA -->|tool.run params | TS1 - CA -->|tool.run params | TS2 + CA -->|run_tool request+params| CTR + CTR -->|route by tool name| TR1 + TR1 -->|HTTP invoke| TS1 + TR1 -->|HTTP invoke| TS2 ``` **Flow**: 1. Frontend calls `GET /api/tools` to discover all available tools -2. Tools Web Module asks the Tool Registry for all tools and converts their definitions to OpenAI format +2. Tool Registry returns tools as-is (already OpenAI format) 3. User selects which tools to enable for a chat session 4. Frontend sends `POST /api/responses` with tool names (as received from `GET /api/tools`) -5. When the LLM emits a `tool_call`, the Chat Agent looks up the tool by name in the registry -6. The Chat Agent runs the tool with the LLM-supplied parameters — the tool directly invokes the microservice; no registry involvement at invocation time +5. When the LLM emits a `tool_call`, the Chat Agent calls `run_tool` in the central router with the tool name +6. The central router resolves the registry by matching the tool name and delegates to the selected registry +7. The target registry resolves and invokes the operation and returns the result to the Chat Agent -## 4. Module Architecture +## 4. API Endpoints -### 4.1 Core Abstractions +- `GET /api/tools` — List all available tools across registries (tool names are returned unchanged) -**Tool Definition** — a value object with three fields: -- `name` — unique identifier derived from the OpenAPI `operationId` -- `description` — human-readable text describing what the tool does -- `parameters` — a fully-resolved JSON Schema (all `$ref` pointers inlined) describing the input - -A tool definition contains enough information to construct an LLM tool call but is not tied to any specific LLM API. - -**Tool** — pairs a definition with execution capability: -- Exposes its `definition` (read-only) -- Provides a `run(params)` operation that executes the tool with the given parameters and returns the result - -#### Reserved `_`-prefixed keys in `params` - -Callers may inject caller-supplied metadata into the `params` dict using keys prefixed with `_`. These keys are **never** forwarded to the tool microservice's JSON body — tool implementations must extract and consume them before sending the request. - -Currently defined reserved keys: - -| Key | Type | Description | -|---|---|---| -| `_bearer_token` | `str \| None` | Forwarded as `Authorization: Bearer ` HTTP header | - -This convention keeps the `Tool.run` interface stable while allowing callers to pass through transport-level concerns (auth, tracing, etc.) without requiring interface changes. - -### 4.2 Tool Registry Module (Plain Module) - -**Purpose**: Aggregates tools from all configured sources and provides lookup by name. - -**Responsibilities**: -- Return all available tools via `get_tools` -- Look up a tool by name via `get_tool_by_name` -- Handle unavailable tool services gracefully (skip with warning, don't fail) - -**No module dependencies**: The registry does not depend on other modAI modules. - -### 4.3 OpenAPI Tool Registry (concrete implementation) - -**Purpose**: Concrete registry implementation that harvests OpenAPI specs from configured HTTP microservices. - -**How it works**: -- On each call to `get_tools`, fetches `/openapi.json` from each configured service -- Extracts the tool definition from the spec: - - `operationId` → name - - `summary`/`description` → description - - Request body schema → parameters (all `$ref` resolved inline) - - Path parameters (`in: path`) from the `parameters` array are merged into the schema's `properties` and `required` lists so the LLM is told to supply them -- Each resulting tool's `run` operation: - 1. Resolves `{param_name}` placeholders in the configured URL by substituting values from the supplied `params` dict - 2. Sends the remaining parameters as the JSON request body - 3. Makes an HTTP call to the resolved URL using the configured method - -**Configuration** — each tool entry specifies: -- `url`: The full trigger endpoint URL of the tool microservice -- `method`: The HTTP method to use when invoking the tool (e.g. POST, PUT, GET) - -### 4.4 Tools Web Module (Web Module) - -**Purpose**: Exposes `GET /api/tools` endpoint. Transforms tool definitions into OpenAI function-calling format. - -**Dependencies**: Tool Registry Module - -**Responsibilities**: -- Expose `GET /api/tools` endpoint -- Call the Tool Registry to get all available tools -- Convert each tool definition to OpenAI function-calling format -- Return the transformed tool definitions to the frontend - -### 4.5 Chat Agent Module - -The Chat Agent Module receives a tool registry dependency. When the LLM emits a `tool_call`: -1. Extract the function name from the tool call -2. Look up the tool by name in the registry -3. Run the tool with the LLM-supplied parameters — no HTTP knowledge needed in the chat module -4. Return the result to the LLM - -## 5. API Endpoints - -- `GET /api/tools` — List all available tools in OpenAI function-calling format - -### 5.1 List Available Tools +### 4.1 List Available Tools **Endpoint**: `GET /api/tools` -**Purpose**: Returns all available tools in OpenAI function-calling format. +**Purpose**: Returns all available tools in OpenAI function-calling format aggregated from all configured registries. -**Tool Definition → OpenAI Transformation**: -- `name` → `function.name` -- `description` → `function.description` -- `parameters` → `function.parameters` (already resolved, no `$ref`) +The endpoint returns tool definitions in OpenAI function-tool format directly as a JSON list. **Response Format (200 OK)**: ```json -{ - "tools": [ - { - "type": "function", - "function": { - "name": "calculate", - "description": "Evaluate a math expression", - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Math expression to evaluate" - } - }, - "required": ["expression"] +[ + { + "type": "function", + "name": "calculate", + "description": "Evaluate a math expression", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Math expression to evaluate" } - } - } - ] -} + }, + "required": ["expression"] + }, + "strict": true + } +] ``` If a tool service is unreachable, it is omitted from the response and a warning is logged. - -## 6. Configuration - -The tool registry is configured with a list of tool microservice endpoints. Each entry has: -- `url`: The full trigger endpoint URL of the tool microservice -- `method`: The HTTP method used to invoke the tool (e.g. PUT, POST, GET) - -The registry derives the base URL from `url` (strips the path) and appends `/openapi.json` to fetch the spec. - -See `config.yaml` and `default_config.yaml` for concrete configuration examples. diff --git a/backend/omni/justfile b/backend/omni/justfile index fb98d70f..0528b8e3 100644 --- a/backend/omni/justfile +++ b/backend/omni/justfile @@ -21,6 +21,6 @@ check: uv run ruff check src # Fix code style and linting issues -check-write: +format: uv run ruff format src uv run ruff check --fix src diff --git a/backend/omni/src/modai/default_config.yaml b/backend/omni/src/modai/default_config.yaml index 81fbaca6..33910d31 100644 --- a/backend/omni/src/modai/default_config.yaml +++ b/backend/omni/src/modai/default_config.yaml @@ -79,15 +79,13 @@ modules: module_dependencies: http_client: "http_client" config: - tools: [] + tool_servers: [] # Example: - # tools: - # - url: http://calculator-service:8000/calculate - # method: POST - # - url: http://web-search-service:8000/search - # method: PUT + # tool_servers: + # - url: http://calculator-service:8000/openapi.json + # - url: http://web-search-service:8000/openapi.json - tool_registry: + predefined_tool_registry: class: modai.modules.tools.tool_registry_predefined_vars.PredefinedVariablesToolRegistryModule module_dependencies: delegate_registry: "openapi_tool_registry" @@ -98,10 +96,10 @@ modules: # variable_mappings: # X-Session-Id: session_id - tools_web: - class: modai.modules.tools.tools_web_module.OpenAIToolsWebModule + tool_registry: + class: modai.modules.tools.tool_router.ToolsRouterModule module_dependencies: - tool_registry: tool_registry + predefined: "predefined_tool_registry" full_reset: class: modai.modules.reset.reset_web_module.ResetWebModule diff --git a/backend/omni/src/modai/modules/chat/__tests__/test_chat_llm_modules.py b/backend/omni/src/modai/modules/chat/__tests__/test_chat_llm_modules.py index 01fd4632..38a4eb8b 100644 --- a/backend/omni/src/modai/modules/chat/__tests__/test_chat_llm_modules.py +++ b/backend/omni/src/modai/modules/chat/__tests__/test_chat_llm_modules.py @@ -41,7 +41,6 @@ ModelProviderResponse, ModelProvidersListResponse, ) -from modai.modules.tools.module import Tool, ToolDefinition working_dir = Path.cwd() load_dotenv(find_dotenv(str(working_dir / ".env"))) @@ -94,6 +93,10 @@ # Both module classes, llmock backend only (no OpenAI key required). _LLMOCK_ONLY_PARAMS = [_AGENTIC_LLMOCK, _NON_AGENTIC_LLMOCK] +# Agentic module with llmock only — for tool-call execution tests that require +# deterministic LLM behaviour (llmock ToolCallStrategy always calls the tool). +_AGENTIC_LLMOCK_ONLY_PARAMS = [_AGENTIC_LLMOCK] + # --------------------------------------------------------------------------- # Module factory # --------------------------------------------------------------------------- @@ -180,6 +183,14 @@ def llmock_only_factory( return _build_module_factory(request.param, llmock_base_url) +@pytest.fixture(params=_AGENTIC_LLMOCK_ONLY_PARAMS) +def agentic_llmock_factory( + request: pytest.FixtureRequest, llmock_base_url: str +) -> ModuleFactory: + """Agentic module with llmock only — tool call behaviour is deterministic.""" + return _build_module_factory(request.param, llmock_base_url) + + @pytest.fixture( params=[ pytest.param(StrandsAgentChatModule, id="agentic"), @@ -568,35 +579,22 @@ class TestAgenticLoop: @pytest.mark.asyncio async def test_tool_is_executed_during_non_streaming_loop( - self, agentic_factory: ModuleFactory + self, agentic_llmock_factory: ModuleFactory ): captured_calls: list[dict] = [] - class _CapturingTool(Tool): - @property - def definition(self) -> ToolDefinition: - return ToolDefinition( - name="calculate", - description="Evaluate a math expression", - parameters={ - "type": "object", - "properties": {"expression": {"type": "string"}}, - "required": ["expression"], - }, - ) - - async def run(self, params: dict[str, Any]) -> Any: - captured_calls.append(dict(params)) - return "42" + async def _run_tool(request: Any, params: dict[str, Any]) -> str: + captured_calls.append(dict(params)) + return "42" registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=_CapturingTool()) + registry.run_tool = AsyncMock(side_effect=_run_tool) - module = agentic_factory.create(tool_registry=registry) + module = agentic_llmock_factory.create(tool_registry=registry) result = await module.generate_response( _make_request(), { - "model": agentic_factory.model, + "model": agentic_llmock_factory.model, "input": "call tool 'calculate' with '{\"expression\": \"6*7\"}'", "tools": [_AGENTIC_CALCULATE_TOOL], }, @@ -607,35 +605,22 @@ async def run(self, params: dict[str, Any]) -> Any: @pytest.mark.asyncio async def test_tool_is_executed_during_streaming_loop( - self, agentic_factory: ModuleFactory + self, agentic_llmock_factory: ModuleFactory ): captured_calls: list[dict] = [] - class _CapturingTool(Tool): - @property - def definition(self) -> ToolDefinition: - return ToolDefinition( - name="calculate", - description="Evaluate a math expression", - parameters={ - "type": "object", - "properties": {"expression": {"type": "string"}}, - "required": ["expression"], - }, - ) - - async def run(self, params: dict[str, Any]) -> Any: - captured_calls.append(dict(params)) - return "42" + async def _run_tool(request: Any, params: dict[str, Any]) -> str: + captured_calls.append(dict(params)) + return "42" registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=_CapturingTool()) + registry.run_tool = AsyncMock(side_effect=_run_tool) - module = agentic_factory.create(tool_registry=registry) + module = agentic_llmock_factory.create(tool_registry=registry) gen = await module.generate_response( _make_request(), { - "model": agentic_factory.model, + "model": agentic_llmock_factory.model, "input": "call tool 'calculate' with '{\"expression\": \"6*7\"}'", "tools": [_AGENTIC_CALCULATE_TOOL], "stream": True, @@ -814,7 +799,9 @@ async def test_unknown_tool_is_silently_skipped( agent continues to a final text response. """ registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=None) + registry.run_tool = AsyncMock( + side_effect=ValueError("Tool 'nonexistent_tool' not found") + ) module = agentic_factory.create(tool_registry=registry) body = { "model": agentic_factory.model, @@ -833,20 +820,8 @@ async def test_tool_run_error_is_handled_gracefully( ): """A tool whose run() raises an error does not crash the agent.""" - class _FailingTool(Tool): - @property - def definition(self) -> ToolDefinition: - return ToolDefinition( - name="broken_tool", - description="Broken", - parameters={"type": "object", "properties": {}}, - ) - - async def run(self, params: dict[str, Any]) -> Any: - raise RuntimeError("tool exploded") - registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=_FailingTool()) + registry.run_tool = AsyncMock(side_effect=RuntimeError("tool exploded")) module = agentic_factory.create(tool_registry=registry) body = { "model": agentic_factory.model, @@ -870,9 +845,7 @@ async def test_tool_registry_error_handled_gracefully( result so the agent can continue to a final text response. """ registry = Mock() - registry.get_tool_by_name = AsyncMock( - side_effect=RuntimeError("Registry unavailable") - ) + registry.run_tool = AsyncMock(side_effect=RuntimeError("Registry unavailable")) module = agentic_factory.create(tool_registry=registry) body = { "model": agentic_factory.model, @@ -890,19 +863,10 @@ async def test_tool_invocation_http_error_handled_gracefully( self, agentic_factory: ModuleFactory ): """When a tool URL is unreachable the agent receives a tool error and completes.""" - definition = ToolDefinition( - name="calculate", - description="Evaluate a math expression", - parameters={ - "type": "object", - "properties": {"expression": {"type": "string"}}, - }, - ) - tool = _make_tool( - definition, run_url="http://localhost:1/calculate", run_method="POST" - ) registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=tool) + registry.run_tool = AsyncMock( + side_effect=httpx_lib.ConnectError("Connection refused") + ) module = agentic_factory.create(tool_registry=registry) body = { "model": agentic_factory.model, @@ -918,19 +882,18 @@ async def test_tool_invocation_success_request_sent_to_tool( ): """The tool HTTP endpoint receives the call forwarded by the agent.""" httpserver.expect_oneshot_request("/calculate").respond_with_json({"result": 4}) - definition = ToolDefinition( - name="calculate", - description="Evaluate a math expression", - parameters={ - "type": "object", - "properties": {"expression": {"type": "string"}}, - }, - ) - tool = _make_tool( - definition, run_url=httpserver.url_for("/calculate"), run_method="POST" - ) + + async def _run_tool_http(request: Any, params: dict[str, Any]) -> str: + async with httpx_lib.AsyncClient(timeout=30.0) as client: + resp = await client.post( + httpserver.url_for("/calculate"), + json=params.get("arguments", {}), + ) + resp.raise_for_status() + return resp.text + registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=tool) + registry.run_tool = AsyncMock(side_effect=_run_tool_http) module = agentic_factory.create(tool_registry=registry) body = { "model": agentic_factory.model, @@ -947,19 +910,15 @@ async def test_partial_tools_resolved_when_some_missing( ): """All client tool specs are registered with Strands; missing registry tools return an error result at execution time without preventing other tools.""" - calc_definition = ToolDefinition( - name="calculate", - description="Evaluate a math expression", - parameters={ - "type": "object", - "properties": {"expression": {"type": "string"}}, - }, - ) - calc_tool = _make_tool(calc_definition) + + async def _run_tool_partial(request: Any, params: dict[str, Any]) -> Any: + name = params.get("name") + if name == "calculate": + return "42" + raise ValueError(f"Tool '{name}' not found") + registry = Mock() - registry.get_tool_by_name = AsyncMock( - side_effect=lambda name, **_: calc_tool if name == "calculate" else None - ) + registry.run_tool = AsyncMock(side_effect=_run_tool_partial) module = agentic_factory.create(tool_registry=registry) body = { "model": agentic_factory.model, @@ -999,20 +958,8 @@ def clear_history(self, llmock_base_url: str) -> None: async def test_client_description_forwarded_to_llm(self, llmock_base_url: str): """Tool description from client spec reaches LLM instead of registry's.""" - class _RegistryTool(Tool): - @property - def definition(self) -> ToolDefinition: - return ToolDefinition( - name="calculate", - description="REGISTRY description — must NOT reach LLM", - parameters={"type": "object", "properties": {}}, - ) - - async def run(self, params: dict[str, Any]) -> Any: - return "42" - registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=_RegistryTool()) + registry.run_tool = AsyncMock(return_value="42") provider = _make_provider(base_url=llmock_base_url, api_key=LLMOCK_API_KEY) module = StrandsAgentChatModule( @@ -1072,20 +1019,8 @@ async def run(self, params: dict[str, Any]) -> Any: async def test_client_parameters_forwarded_to_llm(self, llmock_base_url: str): """Tool parameters schema from client spec reaches LLM instead of registry's.""" - class _RegistryTool(Tool): - @property - def definition(self) -> ToolDefinition: - return ToolDefinition( - name="calculate", - description="Registry tool", - parameters={"type": "object", "properties": {}}, - ) - - async def run(self, params: dict[str, Any]) -> Any: - return "42" - registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=_RegistryTool()) + registry.run_tool = AsyncMock(return_value="42") provider = _make_provider(base_url=llmock_base_url, api_key=LLMOCK_API_KEY) module = StrandsAgentChatModule( @@ -1152,9 +1087,8 @@ async def test_client_only_name_sends_empty_description_to_llm( LLM receives a tool with empty description — the registry is never consulted for spec information.""" - # Registry that finds no tool (returns None) — must not affect the spec. registry = Mock() - registry.get_tool_by_name = AsyncMock(return_value=None) + registry.run_tool = AsyncMock(return_value="42") provider = _make_provider(base_url=llmock_base_url, api_key=LLMOCK_API_KEY) module = StrandsAgentChatModule( @@ -1301,37 +1235,6 @@ def _real_model() -> str: return f"myopenai/{model}" -def _make_tool( - definition: ToolDefinition, run_url: str = "", run_method: str = "POST" -) -> Tool: - """Create a Tool stub for testing. - - If run_url is provided the tool makes a real HTTP call when run() is called; - otherwise run() returns an empty string. - """ - url = run_url - method = run_method - - class _TestTool(Tool): - @property - def definition(self) -> ToolDefinition: - return definition - - async def run(self, params: dict[str, Any]) -> Any: - if url: - import httpx - - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.request( - method=method.upper(), url=url, json=params - ) - response.raise_for_status() - return response.text - return "" - - return _TestTool() - - def _wait_for_health(base_url: str, timeout: float = 30.0) -> None: """Poll the llmock health endpoint until it responds.""" deadline = time.time() + timeout diff --git a/backend/omni/src/modai/modules/chat/openai_agent_chat.py b/backend/omni/src/modai/modules/chat/openai_agent_chat.py index be5ceb46..df148e39 100644 --- a/backend/omni/src/modai/modules/chat/openai_agent_chat.py +++ b/backend/omni/src/modai/modules/chat/openai_agent_chat.py @@ -68,11 +68,10 @@ async def generate_response( ) -> OpenAIResponse | AsyncGenerator[OpenAIResponseStreamEvent, None]: provider_name, actual_model = _parse_model(body_json.get("model", "")) provider = await self._resolve_provider(request, provider_name) - additional_tool_properties = _extract_additional_tool_properties(request) tools = await _resolve_request_tools( body_json, self.tool_registry, - additional_tool_properties=additional_tool_properties, + request=request, ) agent = _create_agent(provider, actual_model, body_json, tools) user_message = _extract_last_user_message(body_json) @@ -112,23 +111,6 @@ def _parse_model(model: str) -> tuple[str, str]: return parts[0], parts[1] -def _extract_additional_tool_properties(request: Request) -> dict[str, Any]: - """Extract caller-supplied metadata from the request to inject into tool calls. - - Returns a dict of ``_``-prefixed keys that are merged into every tool - invocation's ``params`` dict. Tool implementations consume these reserved - keys (e.g. for HTTP headers) without forwarding them to the payload. - - Currently extracted properties: - - ``_bearer_token``: raw token from the ``Authorization: Bearer`` header. - """ - properties: dict[str, Any] = {} - auth_header = request.headers.get("Authorization", "") - if auth_header.startswith("Bearer "): - properties["_bearer_token"] = auth_header[len("Bearer ") :] - return properties - - def _create_agent( provider: ModelProviderResponse, model_id: str, @@ -235,7 +217,7 @@ def _extract_tools(body_json: dict[str, Any]) -> list[dict[str, Any]]: async def _resolve_request_tools( body_json: dict[str, Any], tool_registry: ToolRegistryModule | None, - additional_tool_properties: dict[str, Any] | None = None, + request: Request, ) -> list[PythonAgentTool]: """Resolve requested tools from the request body into Strands agent tools. @@ -243,10 +225,6 @@ async def _resolve_request_tools( is looked up in the registry and wrapped as a ``PythonAgentTool`` that invokes the tool microservice over HTTP. - ``additional_tool_properties`` is a dict of ``_``-prefixed keys extracted - from the request (see ``_extract_additional_tool_properties``) that are - merged into every tool invocation's params dict. - Returns an empty list when no registry is configured or no tools are requested. """ @@ -261,7 +239,7 @@ async def _resolve_request_tools( _create_strands_tool( client_spec=client_spec, tool_registry=tool_registry, - additional_tool_properties=additional_tool_properties, + request=request, ) for client_spec in tool_specs ] @@ -270,7 +248,7 @@ async def _resolve_request_tools( def _create_strands_tool( client_spec: dict[str, Any], tool_registry: ToolRegistryModule, - additional_tool_properties: dict[str, Any] | None = None, + request: Request, ) -> PythonAgentTool: """Wrap a client-provided tool spec as a Strands ``PythonAgentTool``. @@ -279,10 +257,6 @@ def _create_strands_tool( verbatim. Execution is handled lazily: when Strands invokes the tool, the registry is queried by name and ``tool.run`` is called. - ``additional_tool_properties`` (a dict of ``_``-prefixed keys) is merged - into every invocation's params dict so that tool implementations can pick - up transport-level concerns (auth, tracing, etc.) without the interface - carrying extra args. """ name: str = client_spec.get("name", "") @@ -294,23 +268,14 @@ def _create_strands_tool( }, } - async def _handler(tool_use: ToolUse, **kwargs: Any) -> ToolResult: + async def _run_tool_handler(tool_use: ToolUse, **kwargs: Any) -> ToolResult: """Invoke the tool and wrap the result for Strands.""" - tool = await tool_registry.get_tool_by_name( - name, predefined_params=additional_tool_properties - ) - if tool is None: - logger.warning("Tool '%s' not found in registry at execution time", name) - return { - "toolUseId": tool_use["toolUseId"], - "status": "error", - "content": [{"text": f"Tool '{name}' is not available"}], - } - params: dict[str, Any] = dict(tool_use["input"]) - if additional_tool_properties: - params.update(additional_tool_properties) + params: dict[str, Any] = { + "name": name, + "arguments": dict(tool_use["input"]), + } try: - result = await tool.run(params) + result = await tool_registry.run_tool(request, params) return { "toolUseId": tool_use["toolUseId"], "status": "success", @@ -327,7 +292,7 @@ async def _handler(tool_use: ToolUse, **kwargs: Any) -> ToolResult: return PythonAgentTool( tool_name=name, tool_spec=tool_spec, - tool_func=_handler, + tool_func=_run_tool_handler, ) diff --git a/backend/omni/src/modai/modules/tools/README.md b/backend/omni/src/modai/modules/tools/README.md new file mode 100644 index 00000000..c3997d67 --- /dev/null +++ b/backend/omni/src/modai/modules/tools/README.md @@ -0,0 +1,71 @@ +# Tools Module + +## Interface + +The abstract interface is defined in `module.py`. + +- Module types: + - `ToolRegistryModule` (plain module): resolves available tools and executes tool calls. +- Public contract for callers: + - `get_tools(request: Request) -> list[ToolDefinition]` + - `run_tool(request: Request, params: dict[str, Any]) -> Any` + - `ToolDefinition` uses OpenAI Responses API format (`FunctionToolParam`). + +## OpenAPIToolRegistryModule + +Purpose: discovers tools from OpenAPI servers and invokes each operation over HTTP. + +Class used in config: +- `modai.modules.tools.tool_registry_openapi.OpenAPIToolRegistryModule` + +```yaml +modules: + openapi_tool_registry: + class: modai.modules.tools.tool_registry_openapi.OpenAPIToolRegistryModule + module_dependencies: + http_client: "http_client" + config: + tool_servers: + - url: http://localhost:8001/openapi.json +``` + +Supported config keys: +- `tool_servers` (required list): list of objects with `url` (OpenAPI spec URL). + +Module dependencies: +- `http_client` + +## ToolsRouterModule + +Purpose: aggregate multiple `ToolRegistryModule` implementations behind one public endpoint and dispatch runtime tool calls. + +Class used in config: +- `modai.modules.tools.tool_router.ToolsRouterModule` + +```yaml +modules: + openapi_tool_registry: + class: modai.modules.tools.tool_registry_openapi.OpenAPIToolRegistryModule + module_dependencies: + http_client: "http_client" + config: + tool_servers: + - url: http://localhost:8001/openapi.json + + tool_registry: + class: modai.modules.tools.tool_router.ToolsRouterModule + module_dependencies: + openapi: "openapi_tool_registry" +``` + +Supported config keys: +- none + +Module dependencies: +- one or more `ToolRegistryModule` implementations (named freely in `module_dependencies`) + +Behavior: +- Exposes `GET /api/tools`. +- Returns tool names as provided by underlying registries (no prefixing). +- On `run_tool`, routes to the registry that provides the requested tool name. +- If multiple registries provide the same tool name, invocation fails with an ambiguity error. diff --git a/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_openapi.py b/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_openapi.py index ba0fa290..74315919 100644 --- a/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_openapi.py +++ b/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_openapi.py @@ -4,21 +4,19 @@ import httpx import pytest +from fastapi import Request from modai.module import ModuleDependencies from modai.modules.http_client.module import HttpClientModule -from modai.modules.tools.module import Tool, ToolDefinition from modai.modules.tools.tool_registry_openapi import ( OpenAPIToolRegistryModule, - _build_tool_definition, + _build_operation_specs, _derive_base_url, _fetch_openapi_spec, ) class _StubHttpClientFactory(HttpClientModule): - """Test factory that yields clients in sequence; reuses the last one when exhausted.""" - def __init__(self, *clients: httpx.AsyncClient): super().__init__(ModuleDependencies(), {}) self._clients = list(clients) @@ -34,8 +32,23 @@ async def _ctx(): return _ctx() +def _request( + path: str = "/api/tools", headers: dict[str, str] | None = None +) -> Request: + raw_headers = [ + (name.lower().encode("latin-1"), value.encode("latin-1")) + for name, value in (headers or {}).items() + ] + scope: dict[str, Any] = { + "type": "http", + "method": "GET", + "path": path, + "headers": raw_headers, + } + return Request(scope) + + def _mock_response(spec: dict | None = None, text: str = "") -> MagicMock: - """Build a minimal mock httpx response.""" resp = MagicMock() resp.raise_for_status = MagicMock() if spec is not None: @@ -46,14 +59,12 @@ def _mock_response(spec: dict | None = None, text: str = "") -> MagicMock: SAMPLE_OPENAPI_SPEC = { "openapi": "3.1.0", - "info": {"title": "Calculator Tool", "version": "1.0.0"}, "paths": { "/calculate": { "post": { "summary": "Evaluate a math expression", "operationId": "calculate", "requestBody": { - "required": True, "content": { "application/json": { "schema": { @@ -62,7 +73,7 @@ def _mock_response(spec: dict | None = None, text: str = "") -> MagicMock: "required": ["expression"], } } - }, + } }, } } @@ -71,7 +82,6 @@ def _mock_response(spec: dict | None = None, text: str = "") -> MagicMock: PATH_PARAMS_SPEC = { "openapi": "3.1.0", - "info": {"title": "User Tool", "version": "1.0.0"}, "paths": { "/users/{user_id}/orders/{order_id}": { "get": { @@ -100,7 +110,6 @@ def _mock_response(spec: dict | None = None, text: str = "") -> MagicMock: HEADER_PARAMS_SPEC = { "openapi": "3.1.0", - "info": {"title": "Session Tool", "version": "1.0.0"}, "paths": { "/data": { "get": { @@ -129,7 +138,6 @@ def _mock_response(spec: dict | None = None, text: str = "") -> MagicMock: HEADER_AND_BODY_SPEC = { "openapi": "3.1.0", - "info": {"title": "Submit Tool", "version": "1.0.0"}, "paths": { "/submit": { "post": { @@ -163,7 +171,6 @@ def _mock_response(spec: dict | None = None, text: str = "") -> MagicMock: PATH_PARAMS_WITH_BODY_SPEC = { "openapi": "3.1.0", - "info": {"title": "Update Tool", "version": "1.0.0"}, "paths": { "/items/{item_id}": { "put": { @@ -197,7 +204,6 @@ def _mock_response(spec: dict | None = None, text: str = "") -> MagicMock: DICE_ROLLER_SPEC = { "openapi": "3.1.0", - "info": {"title": "Dice Roller Tool", "version": "1.0.0"}, "paths": { "/roll": { "post": { @@ -236,72 +242,59 @@ def _mock_response(spec: dict | None = None, text: str = "") -> MagicMock: } -class TestBuildToolDefinition: - def test_openapi_with_inline_schema(self): - definition, header_names = _build_tool_definition(SAMPLE_OPENAPI_SPEC) - assert definition == ToolDefinition( - name="calculate", - description="Evaluate a math expression", - parameters={ +class TestBuildOperationSpecs: + def test_builds_openai_function_tool_definition(self): + specs = _build_operation_specs(SAMPLE_OPENAPI_SPEC) + assert len(specs) == 1 + assert specs[0].definition == { + "type": "function", + "name": "calculate", + "description": "Evaluate a math expression", + "parameters": { "type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"], }, - ) - assert header_names == frozenset() - - def test_openapi_with_ref_schema(self): - definition, header_names = _build_tool_definition(DICE_ROLLER_SPEC) - assert definition == ToolDefinition( - name="roll_dice", - description="Roll dice and return the results", - parameters={ - "type": "object", - "properties": { - "count": { - "type": "integer", - "default": 1, - "description": "Number of dice to roll", - }, - "sides": { - "type": "integer", - "default": 6, - "description": "Number of sides per die", - }, - }, - }, - ) - assert header_names == frozenset() - - def test_path_parameters_only(self): - definition, header_names = _build_tool_definition(PATH_PARAMS_SPEC) - assert definition is not None - assert definition.name == "get_user_order" - assert definition.description == "Get a specific user order" - params = definition.parameters - assert params["type"] == "object" + "strict": True, + } + + def test_ref_schema_is_resolved(self): + specs = _build_operation_specs(DICE_ROLLER_SPEC) + assert len(specs) == 1 + params = specs[0].definition["parameters"] + assert "count" in params["properties"] + assert "sides" in params["properties"] + assert params["properties"]["count"]["type"] == "integer" + assert specs[0].header_param_names == frozenset() + + def test_path_parameters_become_properties(self): + specs = _build_operation_specs(PATH_PARAMS_SPEC) + assert len(specs) == 1 + definition = specs[0].definition + assert definition["name"] == "get_user_order" + params = definition["parameters"] assert "user_id" in params["properties"] assert "order_id" in params["properties"] assert params["properties"]["user_id"]["type"] == "string" assert params["properties"]["user_id"]["description"] == "The user's ID" assert params["properties"]["order_id"]["type"] == "integer" assert set(params["required"]) == {"user_id", "order_id"} - assert header_names == frozenset() + assert specs[0].header_param_names == frozenset() def test_path_parameters_merged_with_request_body(self): - definition, header_names = _build_tool_definition(PATH_PARAMS_WITH_BODY_SPEC) - assert definition is not None - params = definition.parameters + specs = _build_operation_specs(PATH_PARAMS_WITH_BODY_SPEC) + assert len(specs) == 1 + params = specs[0].definition["parameters"] assert "item_id" in params["properties"] assert "name" in params["properties"] assert "item_id" in params["required"] assert "name" in params["required"] - assert header_names == frozenset() + assert specs[0].header_param_names == frozenset() def test_header_parameters_in_definition(self): - definition, header_names = _build_tool_definition(HEADER_PARAMS_SPEC) - assert definition is not None - params = definition.parameters + specs = _build_operation_specs(HEADER_PARAMS_SPEC) + assert len(specs) == 1 + params = specs[0].definition["parameters"] assert "X-Session-Id" in params["properties"] assert "X-Tenant" in params["properties"] assert params["properties"]["X-Session-Id"]["type"] == "string" @@ -311,43 +304,66 @@ def test_header_parameters_in_definition(self): ) assert "X-Session-Id" in params["required"] assert "X-Tenant" not in params.get("required", []) - assert header_names == {"X-Session-Id", "X-Tenant"} + assert specs[0].header_param_names == frozenset({"X-Session-Id", "X-Tenant"}) def test_header_parameters_merged_with_request_body(self): - definition, header_names = _build_tool_definition(HEADER_AND_BODY_SPEC) - assert definition is not None - params = definition.parameters + specs = _build_operation_specs(HEADER_AND_BODY_SPEC) + assert len(specs) == 1 + params = specs[0].definition["parameters"] assert "X-Request-Id" in params["properties"] assert "payload" in params["properties"] assert "X-Request-Id" in params["required"] assert "payload" in params["required"] - assert header_names == {"X-Request-Id"} + assert specs[0].header_param_names == frozenset({"X-Request-Id"}) - def test_no_operation_id_returns_none(self): + def test_no_operation_id_skips_operation(self): spec = {"paths": {"/run": {"post": {"summary": "no id"}}}} - definition, header_names = _build_tool_definition(spec) - assert definition is None - assert header_names == frozenset() + assert _build_operation_specs(spec) == [] + +class TestGetTools: def _make_module( - self, tools: list[dict], factory=None + self, + factory: HttpClientModule, + tool_servers: list[dict] | None = None, ) -> OpenAPIToolRegistryModule: - if factory is None: - # Provide a factory that yields a no-op async client by default - factory = _StubHttpClientFactory(AsyncMock()) deps = ModuleDependencies({"http_client": factory}) - return OpenAPIToolRegistryModule(deps, {"tools": tools}) + return OpenAPIToolRegistryModule( + deps, + { + "tool_servers": tool_servers + or [{"url": "http://calc:8000/openapi.json"}] + }, + ) + + @pytest.mark.asyncio + async def test_get_tools_returns_openai_function_tools(self): + spec_client = AsyncMock() + spec_client.request = AsyncMock( + return_value=_mock_response(spec=SAMPLE_OPENAPI_SPEC) + ) + module = self._make_module(_StubHttpClientFactory(spec_client)) + + result = await module.get_tools(_request()) + + assert len(result) == 1 + assert result[0]["type"] == "function" + assert result[0]["name"] == "calculate" + assert result[0]["strict"] is True @pytest.mark.asyncio async def test_get_tools_empty_config(self): - module = self._make_module([]) - result = await module.get_tools() + deps = ModuleDependencies({"http_client": _StubHttpClientFactory(AsyncMock())}) + module = OpenAPIToolRegistryModule(deps, {}) + + result = await module.get_tools(_request()) + assert result == [] @pytest.mark.asyncio async def test_get_tools_returns_tools_from_all_services(self): search_spec = { - **SAMPLE_OPENAPI_SPEC, + "openapi": "3.1.0", "paths": { "/search": { "put": { @@ -376,40 +392,19 @@ async def mock_request(method, url, **kwargs): mock_client = AsyncMock() mock_client.request = mock_request module = self._make_module( - [ - {"url": "http://calc:8000/calculate", "method": "POST"}, - {"url": "http://search:8000/search", "method": "PUT"}, + _StubHttpClientFactory(mock_client), + tool_servers=[ + {"url": "http://calc:8000/openapi.json"}, + {"url": "http://search:8000/openapi.json"}, ], - factory=_StubHttpClientFactory(mock_client), ) - result = await module.get_tools() + result = await module.get_tools(_request()) assert len(result) == 2 - assert isinstance(result[0], Tool) - assert isinstance(result[1], Tool) - names = {tool.definition.name for tool in result} + names = {tool["name"] for tool in result} assert names == {"calculate", "web_search"} - @pytest.mark.asyncio - async def test_tool_definition_extracted_from_spec(self): - mock_client = AsyncMock() - mock_client.request = AsyncMock( - return_value=_mock_response(spec=SAMPLE_OPENAPI_SPEC) - ) - module = self._make_module( - [{"url": "http://calc:8000/calculate", "method": "POST"}], - factory=_StubHttpClientFactory(mock_client), - ) - - result = await module.get_tools() - - assert len(result) == 1 - definition = result[0].definition - assert definition.name == "calculate" - assert definition.description == "Evaluate a math expression" - assert "expression" in definition.parameters["properties"] - @pytest.mark.asyncio async def test_get_tools_skips_unavailable_service(self): async def mock_request(method, url, **kwargs): @@ -420,82 +415,94 @@ async def mock_request(method, url, **kwargs): mock_client = AsyncMock() mock_client.request = mock_request module = self._make_module( - [ - {"url": "http://good:8000/run", "method": "POST"}, - {"url": "http://bad:8000/run", "method": "POST"}, + _StubHttpClientFactory(mock_client), + tool_servers=[ + {"url": "http://good:8000/openapi.json"}, + {"url": "http://bad:8000/openapi.json"}, ], - factory=_StubHttpClientFactory(mock_client), ) - result = await module.get_tools() + result = await module.get_tools(_request()) assert len(result) == 1 - assert result[0].definition.name == "calculate" + assert result[0]["name"] == "calculate" @pytest.mark.asyncio async def test_get_tools_skips_spec_without_operation_id(self): no_op_spec = {"paths": {"/run": {"post": {"summary": "No operationId"}}}} mock_client = AsyncMock() mock_client.request = AsyncMock(return_value=_mock_response(spec=no_op_spec)) - module = self._make_module( - [{"url": "http://tool:8000/run", "method": "POST"}], - factory=_StubHttpClientFactory(mock_client), - ) + module = self._make_module(_StubHttpClientFactory(mock_client)) - result = await module.get_tools() + result = await module.get_tools(_request()) assert result == [] + @pytest.mark.asyncio + async def test_get_tools_returns_all_operations_from_single_server(self): + multi_op_spec = { + "openapi": "3.1.0", + "paths": { + "/calculate": { + "post": {"summary": "Evaluate", "operationId": "calculate"} + }, + "/search": {"put": {"summary": "Search", "operationId": "web_search"}}, + }, + } + mock_client = AsyncMock() + mock_client.request = AsyncMock(return_value=_mock_response(spec=multi_op_spec)) + module = self._make_module(_StubHttpClientFactory(mock_client)) + + result = await module.get_tools(_request()) + + names = {tool["name"] for tool in result} + assert names == {"calculate", "web_search"} + def test_has_no_router(self): - module = self._make_module([]) + deps = ModuleDependencies({"http_client": _StubHttpClientFactory(AsyncMock())}) + module = OpenAPIToolRegistryModule(deps, {}) assert not hasattr(module, "router") - def test_stores_tool_services_from_config(self): - tools = [ - {"url": "http://a:8000/run", "method": "POST"}, - {"url": "http://b:9000/exec", "method": "PUT"}, + def test_stores_tool_servers_from_config(self): + tool_servers = [ + {"url": "http://a:8000/openapi.json"}, + {"url": "http://b:9000/openapi.json"}, ] - module = self._make_module(tools) - assert module.tool_services == tools + deps = ModuleDependencies({"http_client": _StubHttpClientFactory(AsyncMock())}) + module = OpenAPIToolRegistryModule(deps, {"tool_servers": tool_servers}) + assert module.tool_servers == tool_servers - def test_defaults_to_empty_tools_list(self): - deps = ModuleDependencies() + def test_defaults_to_empty_tool_servers_list(self): + deps = ModuleDependencies({"http_client": _StubHttpClientFactory(AsyncMock())}) module = OpenAPIToolRegistryModule(deps, {}) - assert module.tool_services == [] + assert module.tool_servers == [] -class TestToolRun: - """Tool.run invokes the tool microservice over HTTP.""" - +class TestRunTool: def _make_module( - self, tools: list[dict], factory=None + self, spec: dict[str, Any], run_client: AsyncMock ) -> OpenAPIToolRegistryModule: - if factory is None: - factory = _StubHttpClientFactory(AsyncMock()) - deps = ModuleDependencies({"http_client": factory}) - return OpenAPIToolRegistryModule(deps, {"tools": tools}) - - @pytest.mark.asyncio - async def test_run_makes_http_request_to_tool_endpoint(self): spec_client = AsyncMock() - spec_client.request = AsyncMock( - return_value=_mock_response(spec=SAMPLE_OPENAPI_SPEC) + spec_client.request = AsyncMock(return_value=_mock_response(spec=spec)) + deps = ModuleDependencies( + {"http_client": _StubHttpClientFactory(spec_client, run_client)} + ) + return OpenAPIToolRegistryModule( + deps, {"tool_servers": [{"url": "http://calc:8000/openapi.json"}]} ) - run_response = _mock_response(text='{"result": 42}') + @pytest.mark.asyncio + async def test_run_tool_executes_operation_by_name(self): run_client = AsyncMock() - run_client.request = AsyncMock(return_value=run_response) - - # factory yields spec_client on first new() call, run_client on second - module = self._make_module( - [{"url": "http://calc:8000/calculate", "method": "POST"}], - factory=_StubHttpClientFactory(spec_client, run_client), + run_client.request = AsyncMock( + return_value=_mock_response(text='{"result": 42}') ) + module = self._make_module(SAMPLE_OPENAPI_SPEC, run_client) - tools = await module.get_tools() - assert len(tools) == 1 - - result = await tools[0].run({"expression": "6*7"}) + result = await module.run_tool( + _request("/api/responses"), + {"name": "calculate", "arguments": {"expression": "6*7"}}, + ) run_client.request.assert_called_once_with( method="POST", @@ -506,190 +513,128 @@ async def test_run_makes_http_request_to_tool_endpoint(self): assert result == '{"result": 42}' @pytest.mark.asyncio - async def test_run_forwards_bearer_token_as_authorization_header(self): - """When _bearer_token is in params it becomes Authorization: Bearer .""" - spec_client = AsyncMock() - spec_client.request = AsyncMock( - return_value=_mock_response(spec=SAMPLE_OPENAPI_SPEC) - ) - - run_response = _mock_response(text='{"result": 42}') + async def test_run_tool_forwards_bearer_and_header_params(self): + spec = { + "openapi": "3.1.0", + "paths": { + "/submit": { + "post": { + "operationId": "submit", + "parameters": [ + { + "name": "X-Request-Id", + "in": "header", + "required": True, + "schema": {"type": "string"}, + } + ], + } + } + }, + } run_client = AsyncMock() - run_client.request = AsyncMock(return_value=run_response) - - module = self._make_module( - [{"url": "http://calc:8000/calculate", "method": "POST"}], - factory=_StubHttpClientFactory(spec_client, run_client), + run_client.request = AsyncMock(return_value=_mock_response(text='{"ok": true}')) + module = self._make_module(spec, run_client) + + await module.run_tool( + _request(), + { + "name": "submit", + "arguments": { + "payload": "hello", + "X-Request-Id": "req-1", + "_bearer_token": "secret", + }, + }, ) - tools = await module.get_tools() - await tools[0].run({"expression": "2+2", "_bearer_token": "secret"}) - run_client.request.assert_called_once_with( method="POST", - url="http://calc:8000/calculate", - json={"expression": "2+2"}, - headers={"Authorization": "Bearer secret"}, + url="http://calc:8000/submit", + json={"payload": "hello"}, + headers={"Authorization": "Bearer secret", "X-Request-Id": "req-1"}, ) @pytest.mark.asyncio - async def test_run_substitutes_path_parameters_into_url(self): - """Path parameters are substituted into the URL template, not sent in the body.""" - spec_client = AsyncMock() - spec_client.request = AsyncMock( - return_value=_mock_response(spec=PATH_PARAMS_SPEC) - ) - - run_response = _mock_response(text='{"order": "details"}') + async def test_run_tool_raises_for_unknown_name(self): run_client = AsyncMock() - run_client.request = AsyncMock(return_value=run_response) + run_client.request = AsyncMock(return_value=_mock_response(text='{"ok": true}')) + module = self._make_module(SAMPLE_OPENAPI_SPEC, run_client) - module = self._make_module( - [ - { - "url": "http://users:8000/users/{user_id}/orders/{order_id}", - "method": "GET", - } - ], - factory=_StubHttpClientFactory(spec_client, run_client), - ) + with pytest.raises(ValueError, match="not found"): + await module.run_tool(_request(), {"name": "unknown", "arguments": {}}) - tools = await module.get_tools() - assert len(tools) == 1 + @pytest.mark.asyncio + async def test_run_tool_substitutes_path_parameters_into_url(self): + run_client = AsyncMock() + run_client.request = AsyncMock( + return_value=_mock_response(text='{"order": "details"}') + ) + module = self._make_module(PATH_PARAMS_SPEC, run_client) - result = await tools[0].run({"user_id": "alice", "order_id": 42}) + result = await module.run_tool( + _request(), + { + "name": "get_user_order", + "arguments": {"user_id": "alice", "order_id": 42}, + }, + ) run_client.request.assert_called_once_with( method="GET", - url="http://users:8000/users/alice/orders/42", + url="http://calc:8000/users/alice/orders/42", json={}, headers={}, ) assert result == '{"order": "details"}' @pytest.mark.asyncio - async def test_run_substitutes_path_parameters_leaving_body_params(self): - """Path params are substituted into URL; remaining params go in the request body.""" - spec_client = AsyncMock() - spec_client.request = AsyncMock( - return_value=_mock_response(spec=PATH_PARAMS_WITH_BODY_SPEC) - ) - - run_response = _mock_response(text='{"updated": true}') + async def test_run_tool_substitutes_path_params_leaving_body(self): run_client = AsyncMock() - run_client.request = AsyncMock(return_value=run_response) - - module = self._make_module( - [ - { - "url": "http://items:8000/items/{item_id}", - "method": "PUT", - } - ], - factory=_StubHttpClientFactory(spec_client, run_client), + run_client.request = AsyncMock( + return_value=_mock_response(text='{"updated": true}') ) + module = self._make_module(PATH_PARAMS_WITH_BODY_SPEC, run_client) - tools = await module.get_tools() - assert len(tools) == 1 - - await tools[0].run({"item_id": 7, "name": "Widget"}) + await module.run_tool( + _request(), + {"name": "update_item", "arguments": {"item_id": 7, "name": "Widget"}}, + ) run_client.request.assert_called_once_with( method="PUT", - url="http://items:8000/items/7", + url="http://calc:8000/items/7", json={"name": "Widget"}, headers={}, ) - @pytest.mark.asyncio - async def test_run_forwards_header_parameters_as_http_headers(self): - """Header parameters declared in the spec are forwarded as HTTP headers, not in the body.""" - spec_client = AsyncMock() - spec_client.request = AsyncMock( - return_value=_mock_response(spec=HEADER_AND_BODY_SPEC) - ) - - run_response = _mock_response(text='{"ok": true}') - run_client = AsyncMock() - run_client.request = AsyncMock(return_value=run_response) - - module = self._make_module( - [{"url": "http://submit:8000/submit", "method": "POST"}], - factory=_StubHttpClientFactory(spec_client, run_client), - ) - - tools = await module.get_tools() - assert len(tools) == 1 - - await tools[0].run({"payload": "hello", "X-Request-Id": "req-abc"}) - - run_client.request.assert_called_once_with( - method="POST", - url="http://submit:8000/submit", - json={"payload": "hello"}, - headers={"X-Request-Id": "req-abc"}, - ) - - @pytest.mark.asyncio - async def test_run_combines_bearer_token_and_header_parameters(self): - """Both _bearer_token and header params end up in the headers dict.""" - spec_client = AsyncMock() - spec_client.request = AsyncMock( - return_value=_mock_response(spec=HEADER_AND_BODY_SPEC) - ) - - run_response = _mock_response(text='{"ok": true}') - run_client = AsyncMock() - run_client.request = AsyncMock(return_value=run_response) - - module = self._make_module( - [{"url": "http://submit:8000/submit", "method": "POST"}], - factory=_StubHttpClientFactory(spec_client, run_client), - ) - - tools = await module.get_tools() - await tools[0].run( - {"payload": "hello", "X-Request-Id": "req-abc", "_bearer_token": "tok"} - ) - - run_client.request.assert_called_once_with( - method="POST", - url="http://submit:8000/submit", - json={"payload": "hello"}, - headers={"Authorization": "Bearer tok", "X-Request-Id": "req-abc"}, - ) - -class TestFetchOpenapiSpec: +class TestHelpers: @pytest.mark.asyncio - async def test_success(self): - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.raise_for_status = lambda: None - mock_response.json.return_value = SAMPLE_OPENAPI_SPEC - + async def test_fetch_openapi_spec_appends_default_path(self): client = AsyncMock() - client.request = AsyncMock(return_value=mock_response) + client.request = AsyncMock( + return_value=_mock_response(spec=SAMPLE_OPENAPI_SPEC) + ) result = await _fetch_openapi_spec(client, "http://tool:8000") + assert result == SAMPLE_OPENAPI_SPEC client.request.assert_called_once_with("GET", "http://tool:8000/openapi.json") @pytest.mark.asyncio - async def test_strips_trailing_slash(self): - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.raise_for_status = lambda: None - mock_response.json.return_value = SAMPLE_OPENAPI_SPEC - + async def test_fetch_openapi_spec_strips_trailing_slash(self): client = AsyncMock() - client.request = AsyncMock(return_value=mock_response) + client.request = AsyncMock( + return_value=_mock_response(spec=SAMPLE_OPENAPI_SPEC) + ) await _fetch_openapi_spec(client, "http://tool:8000/") + client.request.assert_called_once_with("GET", "http://tool:8000/openapi.json") @pytest.mark.asyncio - async def test_http_error_returns_none(self): + async def test_fetch_openapi_spec_http_error_returns_none(self): mock_response = MagicMock() mock_response.status_code = 500 mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( @@ -697,93 +642,48 @@ async def test_http_error_returns_none(self): request=httpx.Request("GET", "http://tool:8000/openapi.json"), response=mock_response, ) - client = AsyncMock() client.request = AsyncMock(return_value=mock_response) result = await _fetch_openapi_spec(client, "http://tool:8000") + assert result is None @pytest.mark.asyncio - async def test_connection_error_returns_none(self): + async def test_fetch_openapi_spec_connection_error_returns_none(self): client = AsyncMock() client.request = AsyncMock(side_effect=httpx.ConnectError("Connection refused")) result = await _fetch_openapi_spec(client, "http://tool:8000") + assert result is None @pytest.mark.asyncio - async def test_unexpected_error_returns_none(self): + async def test_fetch_openapi_spec_unexpected_error_returns_none(self): client = AsyncMock() client.request = AsyncMock(side_effect=RuntimeError("something went wrong")) result = await _fetch_openapi_spec(client, "http://tool:8000") - assert result is None - - -class TestDeriveBaseUrl: - def test_strips_path(self): - assert _derive_base_url("http://calc:8000/calculate") == "http://calc:8000" - - def test_strips_nested_path(self): - assert _derive_base_url("http://host:9000/api/v1/run") == "http://host:9000" - - def test_no_path(self): - assert _derive_base_url("http://tool:8000") == "http://tool:8000" - - def test_trailing_slash(self): - assert _derive_base_url("http://tool:8000/") == "http://tool:8000" + assert result is None -class TestGetToolByName: - def _make_module( - self, tools: list[dict], factory=None - ) -> OpenAPIToolRegistryModule: - if factory is None: - factory = _StubHttpClientFactory(AsyncMock()) - deps = ModuleDependencies({"http_client": factory}) - return OpenAPIToolRegistryModule(deps, {"tools": tools}) - - def _make_spec_factory(self, spec_map: dict[str, dict]): - """Build an HttpClientFactory whose client dispatches by URL key.""" - - async def mock_request(method, url, **kwargs): - for key, spec in spec_map.items(): - if key in url: - return _mock_response(spec=spec) - raise httpx.ConnectError("No mock for " + url) - - mock_client = AsyncMock() - mock_client.request = mock_request - return _StubHttpClientFactory(mock_client) - - @pytest.mark.asyncio - async def test_finds_tool_by_name(self): - module = self._make_module( - [{"url": "http://calc:8000/calculate", "method": "POST"}], - factory=self._make_spec_factory({"calc": SAMPLE_OPENAPI_SPEC}), + def test_derive_base_url(self): + assert ( + _derive_base_url("http://host:9000/api/openapi.json") + == "http://host:9000/api" ) - result = await module.get_tool_by_name("calculate") - - assert result is not None - assert isinstance(result, Tool) - assert result.definition.name == "calculate" - assert result.definition.description == "Evaluate a math expression" + def test_derive_base_url_strips_openapi_json(self): + assert _derive_base_url("http://calc:8000/openapi.json") == "http://calc:8000" - @pytest.mark.asyncio - async def test_returns_none_for_unknown_name(self): - module = self._make_module( - [{"url": "http://calc:8000/calculate", "method": "POST"}], - factory=self._make_spec_factory({"calc": SAMPLE_OPENAPI_SPEC}), + def test_derive_base_url_keeps_nested_base_path(self): + assert ( + _derive_base_url("http://host:9000/api/v1/openapi.json") + == "http://host:9000/api/v1" ) - result = await module.get_tool_by_name("nonexistent") - - assert result is None + def test_derive_base_url_accepts_bare_base_url(self): + assert _derive_base_url("http://tool:8000") == "http://tool:8000" - @pytest.mark.asyncio - async def test_returns_none_for_empty_registry(self): - module = self._make_module([]) - result = await module.get_tool_by_name("calculate") - assert result is None + def test_derive_base_url_trailing_slash(self): + assert _derive_base_url("http://tool:8000/") == "http://tool:8000" diff --git a/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_predefined_vars.py b/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_predefined_vars.py index 7c16dc79..9da83cc5 100644 --- a/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_predefined_vars.py +++ b/backend/omni/src/modai/modules/tools/__tests__/test_tool_registry_predefined_vars.py @@ -1,465 +1,160 @@ from typing import Any -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import ANY, AsyncMock, MagicMock import pytest +from fastapi import Request from modai.module import ModuleDependencies -from modai.modules.tools.module import Tool, ToolDefinition, ToolRegistryModule +from modai.modules.tools.module import ToolRegistryModule from modai.modules.tools.tool_registry_predefined_vars import ( PredefinedVariablesToolRegistryModule, + _extract_predefined_params, ) -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_tool(definition: ToolDefinition) -> Tool: - class _StubTool(Tool): - @property - def definition(self) -> ToolDefinition: - return definition - - async def run(self, params: dict[str, Any]) -> Any: - return params - - return _StubTool() - - -def _make_capturing_tool(definition: ToolDefinition) -> tuple[Tool, list[dict]]: - """Return a tool and a list that receives each run() params call.""" - calls: list[dict] = [] - - class _CapturingTool(Tool): - @property - def definition(self) -> ToolDefinition: - return definition - - async def run(self, params: dict[str, Any]) -> Any: - calls.append(dict(params)) - return "ok" - - return _CapturingTool(), calls - - -def _stub_registry(*tools: Tool) -> ToolRegistryModule: - """Build a mock ToolRegistryModule that returns the given tools.""" - registry = MagicMock(spec=ToolRegistryModule) - registry.get_tools = AsyncMock(return_value=list(tools)) - - async def _get_by_name(name: str, predefined_params=None) -> Tool | None: - return next((t for t in tools if t.definition.name == name), None) - - registry.get_tool_by_name = _get_by_name - return registry +def _request(headers: dict[str, str] | None = None) -> Request: + raw_headers = [ + (name.lower().encode("latin-1"), value.encode("latin-1")) + for name, value in (headers or {}).items() + ] + scope: dict[str, Any] = { + "type": "http", + "method": "GET", + "path": "/api/tools", + "headers": raw_headers, + } + return Request(scope) + + +FULL_TOOL = { + "type": "function", + "name": "get_user_order", + "description": "Retrieve an order", + "parameters": { + "type": "object", + "properties": { + "user_id": {"type": "string"}, + "session_id": {"type": "string"}, + "X-Session-Id": {"type": "string"}, + }, + "required": ["user_id", "session_id", "X-Session-Id"], + }, + "strict": True, +} def _make_module( - inner: ToolRegistryModule, + inner_get_tools: AsyncMock, + inner_run_tool: AsyncMock, variable_mappings: dict[str, str] | None = None, ) -> PredefinedVariablesToolRegistryModule: + inner = MagicMock(spec=ToolRegistryModule) + inner.get_tools = inner_get_tools + inner.run_tool = inner_run_tool + deps = ModuleDependencies({"delegate_registry": inner}) - config: dict = {} + config: dict[str, Any] = {} if variable_mappings: config["variable_mappings"] = variable_mappings return PredefinedVariablesToolRegistryModule(deps, config) -FULL_DEFINITION = ToolDefinition( - name="get_user_order", - description="Retrieve an order", - parameters={ - "type": "object", - "properties": { - "user_id": {"type": "string", "description": "The user's ID"}, - "order_id": {"type": "integer", "description": "The order's ID"}, - "session_id": {"type": "string", "description": "Active session"}, - }, - "required": ["user_id", "order_id", "session_id"], - }, -) - -SIMPLE_DEFINITION = ToolDefinition( - name="calculate", - description="Evaluate a math expression", - parameters={ - "type": "object", - "properties": {"expression": {"type": "string"}}, - "required": ["expression"], - }, -) - -HEADER_DEFINITION = ToolDefinition( - name="fetch_data", - description="Fetch session data", - parameters={ - "type": "object", - "properties": { - "X-Session-Id": {"type": "string", "description": "Active session"}, - "filter": {"type": "string", "description": "Optional filter"}, - }, - "required": ["X-Session-Id"], - }, -) - - -# --------------------------------------------------------------------------- -# get_tools: definition filtering -# --------------------------------------------------------------------------- - - -class TestGetToolsDefinitionFiltering: - @pytest.mark.asyncio - async def test_no_predefined_params_returns_full_definition(self): - tool = _make_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(tool)) - - result = await module.get_tools() - - assert len(result) == 1 - assert result[0].definition == FULL_DEFINITION - - @pytest.mark.asyncio - async def test_predefined_param_stripped_from_properties(self): - tool = _make_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(tool)) - - result = await module.get_tools(predefined_params={"_session_id": "abc"}) - - assert len(result) == 1 - params = result[0].definition.parameters - assert "session_id" not in params["properties"] - assert "user_id" in params["properties"] - assert "order_id" in params["properties"] - - @pytest.mark.asyncio - async def test_predefined_param_stripped_from_required(self): - tool = _make_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(tool)) - - result = await module.get_tools(predefined_params={"_session_id": "abc"}) - - required = result[0].definition.parameters["required"] - assert "session_id" not in required - assert "user_id" in required - assert "order_id" in required - - @pytest.mark.asyncio - async def test_multiple_predefined_params_stripped(self): - tool = _make_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(tool)) - - result = await module.get_tools( - predefined_params={"_session_id": "s1", "_user_id": "u1"} - ) - - params = result[0].definition.parameters - assert "session_id" not in params["properties"] - assert "user_id" not in params["properties"] - assert "order_id" in params["properties"] - - @pytest.mark.asyncio - async def test_predefined_param_not_in_schema_leaves_definition_unchanged(self): - tool = _make_tool(SIMPLE_DEFINITION) - module = _make_module(_stub_registry(tool)) - - result = await module.get_tools(predefined_params={"_session_id": "abc"}) - - # session_id doesn't exist in the schema → tool is returned as-is (no wrapper) - assert result[0].definition == SIMPLE_DEFINITION - - @pytest.mark.asyncio - async def test_non_prefixed_predefined_key_is_ignored(self): - """Keys without a leading _ are not treated as predefined variables.""" - tool = _make_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(tool)) - - result = await module.get_tools(predefined_params={"session_id": "abc"}) - - params = result[0].definition.parameters - assert "session_id" in params["properties"] - - @pytest.mark.asyncio - async def test_empty_predefined_params_returns_full_definition(self): - tool = _make_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(tool)) - - result = await module.get_tools(predefined_params={}) - - assert result[0].definition == FULL_DEFINITION - - @pytest.mark.asyncio - async def test_multiple_tools_each_filtered(self): - tool_a = _make_tool(FULL_DEFINITION) - tool_b = _make_tool(SIMPLE_DEFINITION) - module = _make_module(_stub_registry(tool_a, tool_b)) - - result = await module.get_tools(predefined_params={"_session_id": "s1"}) - - assert len(result) == 2 - # tool_a had session_id → stripped - assert "session_id" not in result[0].definition.parameters["properties"] - # tool_b had no session_id → unchanged - assert result[1].definition == SIMPLE_DEFINITION - - -# --------------------------------------------------------------------------- -# get_tool_by_name -# --------------------------------------------------------------------------- - - -class TestGetToolByName: - @pytest.mark.asyncio - async def test_returns_filtered_tool_when_found(self): - tool = _make_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(tool)) - - result = await module.get_tool_by_name( - "get_user_order", predefined_params={"_session_id": "s1"} +class TestExtractPredefinedParams: + def test_extracts_header_values_with_underscore_prefix(self): + params = _extract_predefined_params( + _request({"Authorization": "Bearer abc", "X-Session-Id": "sid-1"}) ) + assert params == { + "_authorization": "Bearer abc", + "_x_session_id": "sid-1", + } - assert result is not None - assert "session_id" not in result.definition.parameters["properties"] +class TestGetTools: @pytest.mark.asyncio - async def test_returns_none_when_not_found(self): - module = _make_module(_stub_registry()) - - result = await module.get_tool_by_name( - "nonexistent", predefined_params={"_session_id": "s1"} + async def test_hides_schema_properties_that_are_predefined(self): + module = _make_module( + inner_get_tools=AsyncMock(return_value=[FULL_TOOL]), + inner_run_tool=AsyncMock(), ) - assert result is None - - @pytest.mark.asyncio - async def test_no_predefined_params_returns_full_definition(self): - tool = _make_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(tool)) - - result = await module.get_tool_by_name("get_user_order") - - assert result is not None - assert result.definition == FULL_DEFINITION - - -# --------------------------------------------------------------------------- -# run() — predefined variable translation -# --------------------------------------------------------------------------- + result = await module.get_tools(_request({"Session-Id": "sid-1"})) + assert len(result) == 1 + props = result[0]["parameters"]["properties"] + assert "session_id" not in props + assert "user_id" in props -class TestRunTranslation: @pytest.mark.asyncio - async def test_predefined_key_translated_to_unprefixed_before_inner_run(self): - inner_tool, calls = _make_capturing_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(inner_tool)) - - wrapped = await module.get_tool_by_name( - "get_user_order", predefined_params={"_session_id": "session-xyz"} + async def test_hides_mapped_property_from_definition(self): + module = _make_module( + inner_get_tools=AsyncMock(return_value=[FULL_TOOL]), + inner_run_tool=AsyncMock(), + variable_mappings={"X-Session-Id": "x_session_id"}, ) - assert wrapped is not None - await wrapped.run( - {"user_id": "alice", "order_id": 7, "_session_id": "session-xyz"} - ) + result = await module.get_tools(_request({"X-Session-Id": "sid-1"})) - assert len(calls) == 1 - assert calls[0]["session_id"] == "session-xyz" - assert "_session_id" not in calls[0] + props = result[0]["parameters"]["properties"] + assert "X-Session-Id" not in props - @pytest.mark.asyncio - async def test_non_predefined_params_passed_through_unchanged(self): - inner_tool, calls = _make_capturing_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(inner_tool)) - - wrapped = await module.get_tool_by_name( - "get_user_order", predefined_params={"_session_id": "s1"} - ) - assert wrapped is not None - - await wrapped.run({"user_id": "alice", "order_id": 7, "_session_id": "s1"}) - - assert calls[0]["user_id"] == "alice" - assert calls[0]["order_id"] == 7 +class TestRunTool: @pytest.mark.asyncio - async def test_bearer_token_not_in_schema_stays_prefixed(self): - """_bearer_token is a reserved key not found in the schema — it must - remain as _bearer_token so the inner tool (e.g. _OpenAPITool) can - handle it for the Authorization header.""" - inner_tool, calls = _make_capturing_tool(FULL_DEFINITION) - module = _make_module(_stub_registry(inner_tool)) - - # _bearer_token is NOT in the schema, so it is NOT a hidden property - wrapped = await module.get_tool_by_name( - "get_user_order", - predefined_params={"_session_id": "s1", "_bearer_token": "tok"}, + async def test_injects_predefined_values_into_arguments(self): + inner_run_tool = AsyncMock(return_value="ok") + module = _make_module( + inner_get_tools=AsyncMock(return_value=[FULL_TOOL]), + inner_run_tool=inner_run_tool, ) - assert wrapped is not None - await wrapped.run( + await module.run_tool( + _request({"Session-Id": "sid-1"}), { - "user_id": "alice", - "order_id": 7, - "_session_id": "s1", - "_bearer_token": "tok", - } - ) - - # _bearer_token was not in schema so it is not translated - assert "_bearer_token" in calls[0] - assert "bearer_token" not in calls[0] - - @pytest.mark.asyncio - async def test_tool_not_requiring_wrapping_is_returned_directly(self): - """When predefined params have no overlap with the schema, the original - tool object is returned without a wrapper.""" - tool = _make_tool(SIMPLE_DEFINITION) - module = _make_module(_stub_registry(tool)) - - result = await module.get_tool_by_name( - "calculate", predefined_params={"_session_id": "s1"} - ) - - # Same object — no wrapper was needed - assert result is tool - - -# --------------------------------------------------------------------------- -# variable_mappings config -# --------------------------------------------------------------------------- - - -class TestVariableMappings: - @pytest.mark.asyncio - async def test_mapped_tool_param_hidden_from_definition(self): - """X-Session-Id is stripped when _session_id is predefined and mapping is configured.""" - tool = _make_tool(HEADER_DEFINITION) - module = _make_module( - _stub_registry(tool), - variable_mappings={"X-Session-Id": "session_id"}, + "name": "get_user_order", + "arguments": {"user_id": "alice"}, + }, ) - result = await module.get_tools(predefined_params={"_session_id": "sess-abc"}) - - assert len(result) == 1 - params = result[0].definition.parameters - assert "X-Session-Id" not in params["properties"] - assert "filter" in params["properties"] - assert "X-Session-Id" not in params.get("required", []) - - @pytest.mark.asyncio - async def test_mapped_param_not_hidden_when_predefined_var_absent(self): - """If _session_id is not in predefined_params, X-Session-Id stays visible.""" - tool = _make_tool(HEADER_DEFINITION) - module = _make_module( - _stub_registry(tool), - variable_mappings={"X-Session-Id": "session_id"}, + inner_run_tool.assert_awaited_once_with( + ANY, + { + "name": "get_user_order", + "arguments": {"user_id": "alice", "session_id": "sid-1"}, + }, ) - result = await module.get_tools(predefined_params={}) - - assert result[0].definition == HEADER_DEFINITION - @pytest.mark.asyncio - async def test_run_translates_predefined_key_to_mapped_tool_param(self): - """_session_id is translated to X-Session-Id (not session_id) per the mapping.""" - inner_tool, calls = _make_capturing_tool(HEADER_DEFINITION) + async def test_mapped_injection_uses_configured_tool_param_name(self): + inner_run_tool = AsyncMock(return_value="ok") module = _make_module( - _stub_registry(inner_tool), - variable_mappings={"X-Session-Id": "session_id"}, - ) - - wrapped = await module.get_tool_by_name( - "fetch_data", predefined_params={"_session_id": "sess-xyz"} + inner_get_tools=AsyncMock(return_value=[FULL_TOOL]), + inner_run_tool=inner_run_tool, + variable_mappings={"X-Session-Id": "x_session_id"}, ) - assert wrapped is not None - await wrapped.run({"filter": "recent", "_session_id": "sess-xyz"}) - - assert len(calls) == 1 - assert calls[0]["X-Session-Id"] == "sess-xyz" - assert "session_id" not in calls[0] - assert "_session_id" not in calls[0] - - @pytest.mark.asyncio - async def test_direct_and_configured_mappings_coexist(self): - """A direct-mapped param (session_id) and a configured mapping (X-Session-Id) - for different predefined vars can both be active at the same time.""" - definition = ToolDefinition( - name="multi_param_tool", - description="Tool with both direct and mapped params", - parameters={ - "type": "object", - "properties": { - "session_id": {"type": "string"}, - "X-Tenant": {"type": "string"}, - "query": {"type": "string"}, - }, - "required": ["session_id", "X-Tenant", "query"], + await module.run_tool( + _request({"X-Session-Id": "sid-1"}), + { + "name": "get_user_order", + "arguments": {"user_id": "alice"}, }, ) - inner_tool, calls = _make_capturing_tool(definition) - module = _make_module( - _stub_registry(inner_tool), - variable_mappings={"X-Tenant": "tenant_id"}, - ) - wrapped = await module.get_tool_by_name( - "multi_param_tool", - predefined_params={"_session_id": "s1", "_tenant_id": "acme"}, + inner_run_tool.assert_awaited_once_with( + ANY, + { + "name": "get_user_order", + "arguments": {"user_id": "alice", "X-Session-Id": "sid-1"}, + }, ) - assert wrapped is not None - - # Both session_id and X-Tenant should be hidden from the definition - params = wrapped.definition.parameters - assert "session_id" not in params["properties"] - assert "X-Tenant" not in params["properties"] - assert "query" in params["properties"] - - await wrapped.run({"query": "hello", "_session_id": "s1", "_tenant_id": "acme"}) - - assert calls[0]["session_id"] == "s1" # direct mapping - assert calls[0]["X-Tenant"] == "acme" # configured mapping - assert "_session_id" not in calls[0] - assert "_tenant_id" not in calls[0] @pytest.mark.asyncio - async def test_configured_mapping_overrides_direct_for_same_var(self): - """When a mapping routes _session_id to X-Session-Id, the default - session_id → _session_id direct mapping must NOT also be applied.""" - definition = ToolDefinition( - name="override_tool", - description="Test override", - parameters={ - "type": "object", - "properties": { - "session_id": {"type": "string"}, - "X-Session-Id": {"type": "string"}, - }, - "required": ["session_id", "X-Session-Id"], - }, - ) - inner_tool, calls = _make_capturing_tool(definition) - # Map _session_id to X-Session-Id only — session_id in schema remains unaffected + async def test_unknown_tool_raises(self): module = _make_module( - _stub_registry(inner_tool), - variable_mappings={"X-Session-Id": "session_id"}, + inner_get_tools=AsyncMock(return_value=[]), + inner_run_tool=AsyncMock(), ) - wrapped = await module.get_tool_by_name( - "override_tool", - predefined_params={"_session_id": "s1"}, - ) - assert wrapped is not None - - # Only X-Session-Id should be hidden; session_id (different schema prop) stays - params = wrapped.definition.parameters - assert "X-Session-Id" not in params["properties"] - assert "session_id" in params["properties"] - - await wrapped.run({"session_id": "manual", "_session_id": "s1"}) - - assert calls[0]["X-Session-Id"] == "s1" - assert calls[0]["session_id"] == "manual" - assert "_session_id" not in calls[0] + with pytest.raises(ValueError, match="not found"): + await module.run_tool(_request(), {"name": "missing", "arguments": {}}) diff --git a/backend/omni/src/modai/modules/tools/__tests__/test_tool_router.py b/backend/omni/src/modai/modules/tools/__tests__/test_tool_router.py new file mode 100644 index 00000000..a2840cbc --- /dev/null +++ b/backend/omni/src/modai/modules/tools/__tests__/test_tool_router.py @@ -0,0 +1,162 @@ +from typing import Any +from unittest.mock import AsyncMock + +import pytest +from fastapi import Request + +from modai.module import ModuleDependencies +from modai.modules.tools.module import ToolRegistryModule +from modai.modules.tools.tool_router import ToolsRouterModule + + +class _RegistryStub(ToolRegistryModule): + def __init__( + self, + get_tools_result: list[dict[str, Any]], + run_tool_result: Any, + ): + super().__init__(ModuleDependencies(), {}) + self._get_tools_result = get_tools_result + self._run_tool_mock = AsyncMock(return_value=run_tool_result) + + async def get_tools(self, request: Request) -> list[dict[str, Any]]: + del request + return self._get_tools_result + + async def run_tool(self, request: Request, params: dict[str, Any]) -> Any: + return await self._run_tool_mock(request, params) + + +def _request() -> Request: + scope: dict[str, Any] = { + "type": "http", + "method": "GET", + "path": "/api/tools", + "headers": [], + } + return Request(scope) + + +@pytest.mark.asyncio +async def test_aggregates_tool_names_without_prefix(): + registry_a = _RegistryStub( + get_tools_result=[ + { + "type": "function", + "name": "calculate", + "parameters": {"type": "object", "properties": {}}, + "strict": True, + } + ], + run_tool_result="a", + ) + registry_b = _RegistryStub( + get_tools_result=[ + { + "type": "function", + "name": "search", + "parameters": {"type": "object", "properties": {}}, + "strict": True, + } + ], + run_tool_result="b", + ) + + module = ToolsRouterModule( + ModuleDependencies( + { + "registry_a": registry_a, + "registry_b": registry_b, + } + ), + {}, + ) + + tools = await module.get_tools(_request()) + + assert [tool["name"] for tool in tools] == ["calculate", "search"] + + +@pytest.mark.asyncio +async def test_dispatches_to_registry_that_exposes_tool_name(): + registry = _RegistryStub( + get_tools_result=[ + { + "type": "function", + "name": "calculate", + "parameters": {"type": "object", "properties": {}}, + "strict": True, + } + ], + run_tool_result='{"ok": true}', + ) + module = ToolsRouterModule( + ModuleDependencies({"openapi": registry}), + {}, + ) + + request = _request() + result = await module.run_tool( + request, + { + "name": "calculate", + "arguments": {"expression": "1+1"}, + }, + ) + + assert result == '{"ok": true}' + registry._run_tool_mock.assert_awaited_once_with( + request, + { + "name": "calculate", + "arguments": {"expression": "1+1"}, + }, + ) + + +@pytest.mark.asyncio +async def test_raises_when_tool_name_is_not_found(): + registry = _RegistryStub( + get_tools_result=[], + run_tool_result="ok", + ) + module = ToolsRouterModule( + ModuleDependencies({"registry": registry}), + {}, + ) + + with pytest.raises(ValueError, match="not found"): + await module.run_tool(_request(), {"name": "calculate", "arguments": {}}) + + +@pytest.mark.asyncio +async def test_raises_when_tool_name_is_ambiguous(): + registry_a = _RegistryStub( + get_tools_result=[ + { + "type": "function", + "name": "calculate", + "parameters": {"type": "object", "properties": {}}, + "strict": True, + } + ], + run_tool_result="a", + ) + registry_b = _RegistryStub( + get_tools_result=[ + { + "type": "function", + "name": "calculate", + "parameters": {"type": "object", "properties": {}}, + "strict": True, + } + ], + run_tool_result="b", + ) + module = ToolsRouterModule( + ModuleDependencies({"a": registry_a, "b": registry_b}), + {}, + ) + + with pytest.raises(ValueError, match="multiple registries"): + await module.run_tool(_request(), {"name": "calculate", "arguments": {}}) diff --git a/backend/omni/src/modai/modules/tools/__tests__/test_tools_web_module.py b/backend/omni/src/modai/modules/tools/__tests__/test_tools_web_module.py deleted file mode 100644 index 26aeb896..00000000 --- a/backend/omni/src/modai/modules/tools/__tests__/test_tools_web_module.py +++ /dev/null @@ -1,182 +0,0 @@ -from typing import Any -from unittest.mock import AsyncMock - -import pytest - -from modai.module import ModuleDependencies -from modai.modules.tools.module import Tool, ToolDefinition -from modai.modules.tools.tools_web_module import ( - OpenAIToolsWebModule, - _to_openai_format, -) - - -SAMPLE_OPENAPI_SPEC = { - "openapi": "3.1.0", - "info": {"title": "Calculator Tool", "version": "1.0.0"}, - "paths": { - "/calculate": { - "post": { - "summary": "Evaluate a math expression", - "operationId": "calculate", - "requestBody": { - "required": True, - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Math expression to evaluate", - } - }, - "required": ["expression"], - } - } - }, - }, - } - } - }, -} - -SAMPLE_DEFINITION = ToolDefinition( - name="calculate", - description="Evaluate a math expression", - parameters={ - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Math expression to evaluate", - } - }, - "required": ["expression"], - }, -) - - -def _make_tool(definition: ToolDefinition) -> Tool: - """Create a minimal Tool stub for testing.""" - - class _StubTool(Tool): - @property - def definition(self) -> ToolDefinition: - return definition - - async def run(self, params: dict[str, Any]) -> Any: - return "" - - return _StubTool() - - -class TestToOpenAIFormat: - def test_formats_valid_definition(self): - result = _to_openai_format(SAMPLE_DEFINITION) - assert result == { - "type": "function", - "function": { - "name": "calculate", - "description": "Evaluate a math expression", - "parameters": SAMPLE_DEFINITION.parameters, - "strict": True, - }, - } - - def test_uses_provided_description(self): - definition = ToolDefinition( - name="run_task", - description="Runs something", - parameters={"type": "object", "properties": {}}, - ) - result = _to_openai_format(definition) - assert result["function"]["description"] == "Runs something" - - def test_empty_description_is_preserved(self): - definition = ToolDefinition( - name="run_task", - description="", - parameters={"type": "object", "properties": {}}, - ) - result = _to_openai_format(definition) - assert result["function"]["description"] == "" - - def test_parameters_are_passed_through(self): - custom_params = {"type": "object", "properties": {"x": {"type": "integer"}}} - definition = ToolDefinition( - name="calc", description="desc", parameters=custom_params - ) - result = _to_openai_format(definition) - assert result["function"]["parameters"] == custom_params - - def test_strict_is_always_true(self): - result = _to_openai_format(SAMPLE_DEFINITION) - assert result["function"]["strict"] is True - - -class TestToolsWebModule: - def _make_module(self, registry_tools: list[Tool]) -> OpenAIToolsWebModule: - mock_registry = AsyncMock() - mock_registry.get_tools = AsyncMock(return_value=registry_tools) - deps = ModuleDependencies(modules={"tool_registry": mock_registry}) - return OpenAIToolsWebModule(deps, {}) - - def test_has_router_with_tools_endpoint(self): - module = self._make_module([]) - assert hasattr(module, "router") - routes = [r.path for r in module.router.routes] - assert "/api/tools" in routes - - @pytest.mark.asyncio - async def test_returns_empty_tools_when_registry_empty(self): - module = self._make_module([]) - result = await module.get_tools() - assert result == {"tools": []} - - @pytest.mark.asyncio - async def test_transforms_registry_tools_to_openai_format(self): - definition = ToolDefinition( - name="calculate", - description="Evaluate a math expression", - parameters={ - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Math expression to evaluate", - } - }, - "required": ["expression"], - }, - ) - module = self._make_module([_make_tool(definition)]) - result = await module.get_tools() - - assert len(result["tools"]) == 1 - tool = result["tools"][0] - assert tool["type"] == "function" - assert tool["function"]["name"] == "calculate" - assert tool["function"]["description"] == "Evaluate a math expression" - assert "expression" in tool["function"]["parameters"]["properties"] - - @pytest.mark.asyncio - async def test_multiple_tools_returned(self): - search_def = ToolDefinition( - name="web_search", - description="Search the web", - parameters={ - "type": "object", - "properties": {"query": {"type": "string"}}, - "required": ["query"], - }, - ) - module = self._make_module( - [_make_tool(SAMPLE_DEFINITION), _make_tool(search_def)] - ) - result = await module.get_tools() - - assert len(result["tools"]) == 2 - names = [t["function"]["name"] for t in result["tools"]] - assert "calculate" in names - assert "web_search" in names diff --git a/backend/omni/src/modai/modules/tools/module.py b/backend/omni/src/modai/modules/tools/module.py index 204de908..31b270a9 100644 --- a/backend/omni/src/modai/modules/tools/module.py +++ b/backend/omni/src/modai/modules/tools/module.py @@ -1,137 +1,49 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Any -from fastapi import APIRouter +from fastapi import Request +from openai.types.responses.function_tool_param import FunctionToolParam -from modai.module import ModaiModule, ModuleDependencies +from modai.module import ModaiModule -@dataclass(frozen=True) -class ToolDefinition: - """LLM-agnostic description of a tool. - - Contains enough information to construct LLM tool calls but is not tied - to any specific LLM API format. Parameters are fully resolved (no $ref) - so they can be passed directly to any LLM. - """ - - name: str - description: str - parameters: dict[str, Any] - - -class Tool(ABC): - """A tool with its LLM-agnostic definition and run capability. - - Implementations provide both the definition (used by LLMs to understand - and invoke the tool) and the ability to execute the tool with parameters - returned by the LLM. - """ - - @property - @abstractmethod - def definition(self) -> ToolDefinition: - """The tool's LLM-agnostic definition (name, description, parameters).""" - pass - - @abstractmethod - async def run(self, params: dict[str, Any]) -> Any: - """Execute the tool with the given parameters. - - Args: - params: Parameters to pass to the tool, typically the arguments - returned by an LLM tool call. Callers may inject - additional transport-level properties using ``_``-prefixed - keys (e.g. ``_bearer_token``). These reserved keys must - be extracted and consumed by the implementation before - building the request payload — they are never forwarded - to the tool microservice as part of the JSON body. - - Returns: - The tool's result (implementation-specific). - """ - pass +ToolDefinition = FunctionToolParam class ToolRegistryModule(ModaiModule, ABC): """ Module Declaration for: Tool Registry (Plain Module) - Aggregates tools from all configured sources and provides lookup by name. + Abstract contract for tool discovery and execution. - Configuration: - tools: list of dicts, each with: - - "url": the full trigger endpoint URL of the tool microservice - - "method": the HTTP method to invoke the tool (e.g. PUT, POST, GET) + This interface defines runtime tool discovery via ``get_tools`` and + tool execution via ``run_tool``. - Example config: - tools: - - url: http://calculator-service:8000/calculate - method: POST - - url: http://web-search-service:8000/search - method: PUT + Concrete implementations decide how tools are sourced and invoked + (for example via OpenAPI specs, static definitions, or other backends). """ - def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): - super().__init__(dependencies, config) - @abstractmethod - async def get_tools( - self, predefined_params: dict[str, Any] | None = None - ) -> list[Tool]: + async def get_tools(self, request: Request) -> list[ToolDefinition]: """ Returns all configured tools. - Each Tool provides its definition and run capability. + Each entry is returned in OpenAI ``function`` tool format. Unavailable tool services are omitted with a warning logged. Args: - predefined_params: Optional dict of ``_``-prefixed keys whose - values are already known by the caller (e.g. - ``{"_session_id": "abc", "_bearer_token": "xyz"}``). - Implementations may use these to strip the corresponding - properties from tool definitions so the LLM is not asked to - supply values that are already available. + request: FastAPI request context. """ pass @abstractmethod - async def get_tool_by_name( - self, name: str, predefined_params: dict[str, Any] | None = None - ) -> Tool | None: + async def run_tool(self, request: Request, params: dict[str, Any]) -> Any: """ - Look up a tool by its name. - - Returns the matching Tool if found, or None if not found. + Execute a configured tool. Args: - name: The tool's unique name (derived from OpenAPI ``operationId``). - predefined_params: Same semantics as in :meth:`get_tools`. - """ - pass - - -class ToolsWebModule(ModaiModule, ABC): - """ - Module Declaration for: Tools Web Module (Web Module) - - Exposes GET /api/tools. Retrieves tools from the Tool Registry and returns - their definitions in a format suitable for the consumer. - """ - - def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): - super().__init__(dependencies, config) - self.router = APIRouter() - self.router.add_api_route("/api/tools", self.get_tools, methods=["GET"]) - - @abstractmethod - async def get_tools(self) -> dict[str, Any]: - """ - Returns all available tool definitions in a consumer-specific format. - - The response must contain a "tools" key with a list of tool definitions. - The exact structure of each tool definition is determined by the - implementation. + request: FastAPI request context. + params: Invocation payload. Implementations should expect at least + ``name`` and ``arguments`` keys. """ pass diff --git a/backend/omni/src/modai/modules/tools/tool_registry_openapi.py b/backend/omni/src/modai/modules/tools/tool_registry_openapi.py index ecad84d9..cb40cc0e 100644 --- a/backend/omni/src/modai/modules/tools/tool_registry_openapi.py +++ b/backend/omni/src/modai/modules/tools/tool_registry_openapi.py @@ -1,13 +1,15 @@ import logging import re +from dataclasses import dataclass from typing import Any from urllib.parse import urlparse import httpx +from fastapi import Request from modai.module import ModuleDependencies from modai.modules.http_client.module import HttpClientModule -from modai.modules.tools.module import Tool, ToolDefinition, ToolRegistryModule +from modai.modules.tools.module import ToolDefinition, ToolRegistryModule logger = logging.getLogger(__name__) @@ -15,178 +17,169 @@ TOOL_HTTP_TIMEOUT_SECONDS = 30.0 -class _OpenAPITool(Tool): - """Tool backed by an OpenAPI microservice endpoint. - - Holds the tool's pre-built definition and invokes the microservice - over HTTP when ``run`` is called. - """ - - def __init__( - self, - url: str, - method: str, - definition: ToolDefinition, - http_client_factory: HttpClientModule, - header_param_names: frozenset[str] = frozenset(), - ) -> None: - self._url = url - self._method = method - self._definition_val = definition - self._http_client_factory = http_client_factory - self._header_param_names = header_param_names - - @property - def definition(self) -> ToolDefinition: - return self._definition_val - - async def run(self, params: dict[str, Any]) -> Any: - """Invoke the tool microservice over HTTP with the given parameters. - - Extracts reserved metadata keys from ``params`` before sending the - request. Currently recognised keys: - - * ``_bearer_token`` — forwarded as the ``Authorization: Bearer`` - header; never included in the JSON request body. - - Path parameters present in the URL template (e.g. ``{user_id}``) are - substituted directly into the URL and removed from the request body. - - Header parameters declared in the OpenAPI spec (``in: header``) are - forwarded as-is HTTP headers and removed from the request body. - """ - body = dict(params) - bearer_token = body.pop("_bearer_token", None) - headers: dict[str, str] = {} - if bearer_token: - headers["Authorization"] = f"Bearer {bearer_token}" - - for header_name in self._header_param_names: - if header_name in body: - headers[header_name] = str(body.pop(header_name)) - - url = self._url - for name in re.findall(r"\{(\w+)\}", url): - if name in body: - url = url.replace(f"{{{name}}}", str(body.pop(name))) - - async with self._http_client_factory.new( - timeout=TOOL_HTTP_TIMEOUT_SECONDS - ) as client: - response = await client.request( - method=self._method.upper(), - url=url, - json=body, - headers=headers, - ) - response.raise_for_status() - return response.text - - class OpenAPIToolRegistryModule(ToolRegistryModule): - """ - Tool Registry that fetches OpenAPI specs from configured microservices - and creates Tool instances with HTTP-based run capability. - - Each Tool's definition (name, description, parameters) is extracted from - the service's OpenAPI spec. The run method makes an HTTP request to the - configured trigger endpoint. - - Configuration: - tools: list of dicts, each with: - - "url": the full trigger endpoint URL of the tool microservice - - "method": the HTTP method to invoke the tool (e.g. PUT, POST, GET) - The registry derives the base URL from "url" and appends - "/openapi.json" to fetch the spec. - - Module Dependencies: - http_client: an HttpClientModule used for all outbound HTTP requests. - - Example config: - tools: - - url: http://calculator-service:8000/calculate - method: POST - - url: http://web-search-service:8000/search - method: PUT + """Tool Registry that fetches OpenAPI specs from configured microservices. + + It exposes tools in OpenAI ``function`` tool format and executes tool calls + over HTTP based on the operation metadata discovered from OpenAPI specs. """ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): super().__init__(dependencies, config) - self.tool_services: list[dict[str, str]] = config.get("tools", []) - self._http_client: HttpClientModule = dependencies.get_module("http_client") # type: ignore[assignment] + self.tool_servers: list[dict[str, str]] = config.get("tool_servers", []) + self._http_client: HttpClientModule = dependencies.get_module("http_client") + + async def get_tools(self, request: Request) -> list[ToolDefinition]: + del request + operation_specs = await self._collect_operation_specs() + return [spec.definition for spec in operation_specs] + + async def run_tool(self, request: Request, params: dict[str, Any]) -> Any: + del request + name = params.get("name") + if not isinstance(name, str) or not name: + raise ValueError("Tool invocation requires a non-empty 'name'") + + raw_arguments = params.get("arguments") + if raw_arguments is None: + raw_arguments = {k: v for k, v in params.items() if k != "name"} + if not isinstance(raw_arguments, dict): + raise ValueError("Tool invocation 'arguments' must be an object") + + operation_spec = await self._find_operation_spec_by_name(name) + if operation_spec is None: + raise ValueError(f"Tool '{name}' not found") + + return await _run_operation( + http_client_factory=self._http_client, + operation_spec=operation_spec, + params=raw_arguments, + ) - async def get_tools( - self, - predefined_params: dict[str, Any] | None = None, # noqa: ARG002 - ) -> list[Tool]: - tools: list[Tool] = [] + async def _find_operation_spec_by_name(self, name: str) -> "_OperationSpec | None": + operation_specs = await self._collect_operation_specs() + return next( + (spec for spec in operation_specs if spec.definition["name"] == name), None + ) + + async def _collect_operation_specs(self) -> list["_OperationSpec"]: + operation_specs: list[_OperationSpec] = [] async with self._http_client.new(timeout=HTTP_TIMEOUT_SECONDS) as client: - for service in self.tool_services: - url = service["url"] - method = service["method"] - base_url = _derive_base_url(url) - spec = await _fetch_openapi_spec(client, base_url) + for server in self.tool_servers: + openapi_url = server["url"] + spec = await _fetch_openapi_spec(client, openapi_url) if spec is None: continue - definition, header_param_names = _build_tool_definition(spec) - if definition is None: + + service_base_url = _derive_base_url(openapi_url) + built_specs = _build_operation_specs(spec) + if not built_specs: logger.warning( "No valid operation with operationId found in spec from %s, skipping", - base_url, + openapi_url, ) continue - tools.append( - _OpenAPITool( - url=url, - method=method, - definition=definition, - http_client_factory=self._http_client, - header_param_names=header_param_names, + + for built_spec in built_specs: + operation_specs.append( + _OperationSpec( + url=_build_operation_url(service_base_url, built_spec.path), + method=built_spec.method, + definition=built_spec.definition, + header_param_names=built_spec.header_param_names, + ) ) - ) - return tools - - async def get_tool_by_name( - self, - name: str, - predefined_params: dict[str, Any] | None = None, # noqa: ARG002 - ) -> Tool | None: - tools = await self.get_tools() - for tool in tools: - if tool.definition.name == name: - return tool - return None + return operation_specs -def _build_tool_definition( - spec: dict[str, Any], -) -> tuple[ToolDefinition, frozenset[str]] | tuple[None, frozenset[str]]: - """Build a ToolDefinition from an OpenAPI spec. +@dataclass(frozen=True) +class _BuiltOperationSpec: + path: str + method: str + definition: ToolDefinition + header_param_names: frozenset[str] - Extracts name (operationId), description (summary/description), and - parameters (request body schema with $ref fully resolved, path params and - header params merged in). - Returns a ``(ToolDefinition, header_param_names)`` tuple so callers know - which parameter names map to HTTP headers at invocation time. Returns - ``(None, frozenset())`` if no operation with an operationId is found. - """ +@dataclass(frozen=True) +class _OperationSpec: + url: str + method: str + definition: ToolDefinition + header_param_names: frozenset[str] + + +def _build_operation_specs(spec: dict[str, Any]) -> list[_BuiltOperationSpec]: paths = spec.get("paths", {}) - for _path, methods in paths.items(): - for _method, operation in methods.items(): + operation_specs: list[_BuiltOperationSpec] = [] + for path, methods in paths.items(): + if not isinstance(methods, dict): + continue + + shared_parameters = methods.get("parameters", []) + for method, operation in methods.items(): + if method == "parameters": + continue if not isinstance(operation, dict) or "operationId" not in operation: continue - name = operation["operationId"] - description = operation.get("summary") or operation.get("description", "") - parameters = _extract_parameters(operation, spec) - header_params = _extract_header_parameters(operation, spec) - header_param_names = frozenset(p["name"] for p in header_params) - return ToolDefinition( - name=name, description=description, parameters=parameters - ), header_param_names - return None, frozenset() + + merged_operation = { + **operation, + "parameters": [*shared_parameters, *operation.get("parameters", [])], + } + operation_specs.append( + _BuiltOperationSpec( + path=path, + method=method, + definition={ + "type": "function", + "name": operation["operationId"], + "description": operation.get("summary") + or operation.get("description", ""), + "parameters": _extract_parameters(merged_operation, spec), + "strict": True, + }, + header_param_names=frozenset( + p["name"] + for p in _extract_header_parameters(merged_operation, spec) + ), + ) + ) + + return operation_specs + + +async def _run_operation( + http_client_factory: HttpClientModule, + operation_spec: _OperationSpec, + params: dict[str, Any], +) -> Any: + body = dict(params) + bearer_token = body.pop("_bearer_token", None) + + headers: dict[str, str] = {} + if bearer_token: + headers["Authorization"] = f"Bearer {bearer_token}" + + for header_name in operation_spec.header_param_names: + if header_name in body: + headers[header_name] = str(body.pop(header_name)) + + resolved_url = operation_spec.url + for name in re.findall(r"\{(\w+)\}", resolved_url): + if name in body: + resolved_url = resolved_url.replace(f"{{{name}}}", str(body.pop(name))) + + async with http_client_factory.new(timeout=TOOL_HTTP_TIMEOUT_SECONDS) as client: + response = await client.request( + method=operation_spec.method.upper(), + url=resolved_url, + json=body, + headers=headers, + ) + response.raise_for_status() + return response.text def _extract_parameters( @@ -237,7 +230,6 @@ def _extract_parameters( def _extract_path_parameters( operation: dict[str, Any], spec: dict[str, Any] ) -> list[dict[str, Any]]: - """Return resolved path parameters (``in: path``) from an OpenAPI operation.""" return [ _resolve_refs(param, spec) for param in operation.get("parameters", []) @@ -248,7 +240,6 @@ def _extract_path_parameters( def _extract_header_parameters( operation: dict[str, Any], spec: dict[str, Any] ) -> list[dict[str, Any]]: - """Return resolved header parameters (``in: header``) from an OpenAPI operation.""" return [ _resolve_refs(param, spec) for param in operation.get("parameters", []) @@ -257,7 +248,6 @@ def _extract_header_parameters( def _resolve_refs(node: Any, spec: dict[str, Any]) -> Any: - """Recursively resolve all $ref pointers in a JSON Schema against the OpenAPI spec.""" if isinstance(node, dict): if "$ref" in node: resolved = _follow_ref(node["$ref"], spec) @@ -269,7 +259,6 @@ def _resolve_refs(node: Any, spec: dict[str, Any]) -> Any: def _follow_ref(ref: str, spec: dict[str, Any]) -> dict[str, Any]: - """Follow a JSON Pointer reference like '#/components/schemas/Foo'.""" if not ref.startswith("#/"): logger.warning("Unsupported $ref format: %s", ref) return {} @@ -284,33 +273,53 @@ def _follow_ref(ref: str, spec: dict[str, Any]) -> dict[str, Any]: return current if isinstance(current, dict) else {} -def _derive_base_url(trigger_url: str) -> str: - """Derive the service base URL from a full trigger endpoint URL. +def _derive_base_url(openapi_url: str) -> str: + """Derive service base URL from an OpenAPI spec URL. - E.g. 'http://calc:8000/calculate' -> 'http://calc:8000' + E.g. 'http://calc:8000/openapi.json' -> 'http://calc:8000' + 'http://calc:8000/api/openapi.json' -> 'http://calc:8000/api' """ - parsed = urlparse(trigger_url) - return f"{parsed.scheme}://{parsed.netloc}" + parsed = urlparse(_normalize_openapi_url(openapi_url)) + base_path = parsed.path.rsplit("/", maxsplit=1)[0] + origin = f"{parsed.scheme}://{parsed.netloc}" + return f"{origin}{base_path}" if base_path else origin + + +def _build_operation_url(service_base_url: str, operation_path: str) -> str: + if operation_path.startswith("/"): + return f"{service_base_url.rstrip('/')}{operation_path}" + return f"{service_base_url.rstrip('/')}/{operation_path}" + + +def _normalize_openapi_url(openapi_url: str) -> str: + normalized = openapi_url.rstrip("/") + if normalized.endswith(".json"): + return normalized + return f"{normalized}/openapi.json" async def _fetch_openapi_spec( - client: httpx.AsyncClient, base_url: str + client: httpx.AsyncClient, openapi_url: str ) -> dict[str, Any] | None: - openapi_url = f"{base_url.rstrip('/')}/openapi.json" + normalized_openapi_url = _normalize_openapi_url(openapi_url) try: - response = await client.request("GET", openapi_url) + response = await client.request("GET", normalized_openapi_url) response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: logger.warning( - "Tool service %s returned HTTP %s", base_url, e.response.status_code + "Tool service %s returned HTTP %s", + normalized_openapi_url, + e.response.status_code, ) return None except httpx.RequestError as e: - logger.warning("Failed to reach tool service %s: %s", base_url, e) + logger.warning("Failed to reach tool service %s: %s", normalized_openapi_url, e) return None except Exception: logger.warning( - "Unexpected error fetching spec from %s", base_url, exc_info=True + "Unexpected error fetching spec from %s", + normalized_openapi_url, + exc_info=True, ) return None diff --git a/backend/omni/src/modai/modules/tools/tool_registry_predefined_vars.py b/backend/omni/src/modai/modules/tools/tool_registry_predefined_vars.py index 2eeb367d..851979dd 100644 --- a/backend/omni/src/modai/modules/tools/tool_registry_predefined_vars.py +++ b/backend/omni/src/modai/modules/tools/tool_registry_predefined_vars.py @@ -21,45 +21,20 @@ supplied via the ``delegate_registry`` module dependency. """ -import logging from typing import Any -from modai.module import ModuleDependencies -from modai.modules.tools.module import Tool, ToolDefinition, ToolRegistryModule - -logger = logging.getLogger(__name__) - +from fastapi import Request -class _PredefinedVariablesTool(Tool): - """Wraps an inner Tool, hiding known variables from its public definition. +from modai.module import ModuleDependencies +from modai.modules.tools.module import ToolDefinition, ToolRegistryModule - ``translations`` is a mapping of ``tool_param_name → prefixed_predefined_key`` - (e.g. ``{"X-Session-Id": "_session_id"}``). In :meth:`run` each prefixed - key is popped from ``params`` and re-injected under its tool parameter name - before the call is forwarded to the inner tool. - """ - def __init__( - self, - inner: Tool, - translations: dict[str, str], - filtered_definition: ToolDefinition, - ) -> None: - self._inner = inner - self._translations = translations - self._filtered_definition = filtered_definition - - @property - def definition(self) -> ToolDefinition: - return self._filtered_definition - - async def run(self, params: dict[str, Any]) -> Any: - """Forward to inner tool, substituting predefined values into their tool param names.""" - translated = dict(params) - for tool_param, prefixed_key in self._translations.items(): - if prefixed_key in translated: - translated[tool_param] = translated.pop(prefixed_key) - return await self._inner.run(translated) +def _extract_predefined_params(request: Request) -> dict[str, Any]: + """Extract predefined tool params from all request headers.""" + return { + f"_{header_name.lower().replace('-', '_')}": value + for header_name, value in request.headers.items() + } class PredefinedVariablesToolRegistryModule(ToolRegistryModule): @@ -108,52 +83,58 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): ) # type: ignore[assignment] self._variable_mappings: dict[str, str] = config.get("variable_mappings", {}) - async def get_tools( - self, predefined_params: dict[str, Any] | None = None - ) -> list[Tool]: - tools = await self._inner_registry.get_tools() + async def get_tools(self, request: Request) -> list[ToolDefinition]: + tools = await self._inner_registry.get_tools(request) + predefined_params = _extract_predefined_params(request) return [ - _wrap_tool(tool, predefined_params, self._variable_mappings) - for tool in tools + _filter_tool_definition( + definition, predefined_params, self._variable_mappings + ) + for definition in tools ] - async def get_tool_by_name( - self, name: str, predefined_params: dict[str, Any] | None = None - ) -> Tool | None: - tool = await self._inner_registry.get_tool_by_name(name) - if tool is None: - return None - return _wrap_tool(tool, predefined_params, self._variable_mappings) - - -# --------------------------------------------------------------------------- -# Pure helper functions -# --------------------------------------------------------------------------- + async def run_tool(self, request: Request, params: dict[str, Any]) -> Any: + predefined_params = _extract_predefined_params(request) + name = params.get("name") + if not isinstance(name, str) or not name: + raise ValueError("Tool invocation requires a non-empty 'name'") + + tool_definitions = await self._inner_registry.get_tools(request) + definition = next( + (tool for tool in tool_definitions if tool["name"] == name), None + ) + if definition is None: + raise ValueError(f"Tool '{name}' not found") + + translated_params = dict(params) + raw_arguments = translated_params.get("arguments", {}) + if not isinstance(raw_arguments, dict): + raise ValueError("Tool invocation 'arguments' must be an object") + + translated_arguments = dict(raw_arguments) + translations = _build_translations( + definition=definition, + predefined_params=predefined_params, + variable_mappings=self._variable_mappings, + ) + for tool_param, prefixed_key in translations.items(): + predefined_value = predefined_params.get(prefixed_key) + if predefined_value is not None and tool_param not in translated_arguments: + translated_arguments[tool_param] = predefined_value + + translated_params["arguments"] = translated_arguments + return await self._inner_registry.run_tool(request, translated_params) def _build_translations( definition: ToolDefinition, - predefined_params: dict[str, Any] | None, + predefined_params: dict[str, Any], variable_mappings: dict[str, str], ) -> dict[str, str]: - """Build a ``tool_param → prefixed_predefined_key`` map for *definition*. - - Only includes entries where: - - the prefixed predefined key is present in ``predefined_params``, AND - - the target tool parameter exists in the definition's schema properties. - - Direct mappings (``_session_id`` → ``session_id``) are derived - automatically from ``predefined_params``. ``variable_mappings`` entries - (``X-Session-Id: session_id``) override the direct mapping for the same - predefined variable so the value is routed to the correct tool parameter. - """ - if not predefined_params: - return {} - - schema_properties = set(definition.parameters.get("properties", {}).keys()) + schema = definition.get("parameters") or {} + schema_properties = set(schema.get("properties", {}).keys()) translations: dict[str, str] = {} - # Direct: _session_id → session_id (when session_id is in the schema) for prefixed_key in predefined_params: if not prefixed_key.startswith("_"): continue @@ -161,56 +142,49 @@ def _build_translations( if var_name in schema_properties: translations[var_name] = prefixed_key - # Configured: X-Session-Id ← _session_id (overrides the direct mapping) for tool_param, var_name in variable_mappings.items(): prefixed_key = f"_{var_name}" if prefixed_key not in predefined_params: continue if tool_param not in schema_properties: continue - # Remove default direct mapping for var_name if it was added above translations.pop(var_name, None) translations[tool_param] = prefixed_key return translations -def _wrap_tool( - tool: Tool, - predefined_params: dict[str, Any] | None, +def _filter_tool_definition( + definition: ToolDefinition, + predefined_params: dict[str, Any], variable_mappings: dict[str, str], -) -> Tool: - """Return a filtered wrapper around *tool*, or *tool* itself if nothing to hide.""" - translations = _build_translations( - tool.definition, predefined_params, variable_mappings - ) - if not translations: - return tool - hidden = set(translations.keys()) - filtered_definition = _filter_definition(tool.definition, hidden) - return _PredefinedVariablesTool( - inner=tool, - translations=translations, - filtered_definition=filtered_definition, - ) - - -def _filter_definition( - definition: ToolDefinition, hidden_properties: set[str] ) -> ToolDefinition: - """Return a new :class:`ToolDefinition` with *hidden_properties* removed.""" - params = definition.parameters + translations = _build_translations(definition, predefined_params, variable_mappings) + if not translations: + return definition + + hidden_properties = set(translations.keys()) + parameters = definition.get("parameters") or {} new_properties = { k: v - for k, v in params.get("properties", {}).items() + for k, v in parameters.get("properties", {}).items() if k not in hidden_properties } - new_required = [r for r in params.get("required", []) if r not in hidden_properties] - new_params: dict[str, Any] = {**params, "properties": new_properties} - if "required" in params: - new_params["required"] = new_required - return ToolDefinition( - name=definition.name, - description=definition.description, - parameters=new_params, - ) + new_required = [ + r for r in parameters.get("required", []) if r not in hidden_properties + ] + + new_parameters: dict[str, Any] = {**parameters, "properties": new_properties} + if "required" in parameters: + new_parameters["required"] = new_required + + filtered: ToolDefinition = { + "type": "function", + "name": definition["name"], + "parameters": new_parameters, + "strict": definition.get("strict", True), + } + if "description" in definition: + filtered["description"] = definition.get("description") + + return filtered diff --git a/backend/omni/src/modai/modules/tools/tool_router.py b/backend/omni/src/modai/modules/tools/tool_router.py new file mode 100644 index 00000000..c6165bd2 --- /dev/null +++ b/backend/omni/src/modai/modules/tools/tool_router.py @@ -0,0 +1,68 @@ +import logging +from typing import Any + +from fastapi import APIRouter, Request + +from modai.module import ModuleDependencies +from modai.modules.tools.module import ToolDefinition, ToolRegistryModule + +logger = logging.getLogger(__name__) + + +class ToolsRouterModule(ToolRegistryModule): + """Central tools router that aggregates multiple tool registries. + + Exposes ``GET /api/tools`` and returns the union of all tool definitions + from configured registries without renaming tool names. + + Runtime invocations are dispatched by finding the registry that provides + the requested tool name. + """ + + def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): + super().__init__(dependencies, config) + + self._registries: dict[str, ToolRegistryModule] = {} + for dependency_name, module in dependencies.modules.items(): + if isinstance(module, ToolRegistryModule): + self._registries[dependency_name] = module + + self.router = APIRouter() + self.router.add_api_route("/api/tools", self.get_tools, methods=["GET"]) + + async def get_tools(self, request: Request) -> list[ToolDefinition]: + aggregated_tools: list[ToolDefinition] = [] + for registry_dependency_key, registry in self._registries.items(): + try: + tools = await registry.get_tools(request) + except Exception as exc: + logger.warning( + "Failed to load tools from registry '%s': %s", + registry_dependency_key, + exc, + ) + continue + + aggregated_tools.extend(tools) + + return aggregated_tools + + async def run_tool(self, request: Request, params: dict[str, Any]) -> Any: + name = params.get("name") + if not isinstance(name, str) or not name: + raise ValueError("Tool invocation requires a non-empty 'name'") + + matching_registries: list[ToolRegistryModule] = [] + for registry in self._registries.values(): + tools = await registry.get_tools(request) + if any(tool.get("name") == name for tool in tools): + matching_registries.append(registry) + + if not matching_registries: + raise ValueError(f"Tool '{name}' not found") + if len(matching_registries) > 1: + raise ValueError( + f"Tool '{name}' is provided by multiple registries; tool names must be unique" + ) + + return await matching_registries[0].run_tool(request, dict(params)) diff --git a/backend/omni/src/modai/modules/tools/tools_web_module.py b/backend/omni/src/modai/modules/tools/tools_web_module.py deleted file mode 100644 index 1b8bee20..00000000 --- a/backend/omni/src/modai/modules/tools/tools_web_module.py +++ /dev/null @@ -1,54 +0,0 @@ -import logging -from typing import Any - -from modai.module import ModuleDependencies -from modai.modules.tools.module import ( - ToolDefinition, - ToolRegistryModule, - ToolsWebModule, -) - -logger = logging.getLogger(__name__) - - -class OpenAIToolsWebModule(ToolsWebModule): - """ - ToolsWebModule implementation that returns tools in OpenAI - function-calling format. - - Converts each tool's ToolDefinition into the format expected by - the OpenAI Chat Completions API: - { - "type": "function", - "function": { - "name": "", - "description": "", - "parameters": { }, - "strict": true - } - } - """ - - def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]): - super().__init__(dependencies, config) - self.tool_registry: ToolRegistryModule = dependencies.get_module( - "tool_registry" - ) - - async def get_tools(self) -> dict[str, Any]: - tools = await self.tool_registry.get_tools() - openai_tools = [_to_openai_format(tool.definition) for tool in tools] - return {"tools": openai_tools} - - -def _to_openai_format(definition: ToolDefinition) -> dict[str, Any]: - """Convert a ToolDefinition to OpenAI function-calling format.""" - return { - "type": "function", - "function": { - "name": definition.name, - "description": definition.description, - "parameters": definition.parameters, - "strict": True, - }, - } diff --git a/backend/tools/dice-roller/justfile b/backend/tools/dice-roller/justfile index b41018a0..59a87821 100644 --- a/backend/tools/dice-roller/justfile +++ b/backend/tools/dice-roller/justfile @@ -17,6 +17,6 @@ check: uv run ruff check . # Fix code style and linting issues -check-write: +format: uv run ruff format . uv run ruff check --fix . diff --git a/e2e_tests/tests_omni_full/backend-config-e2e.yaml b/e2e_tests/tests_omni_full/backend-config-e2e.yaml index cd5eef5f..619b95bc 100644 --- a/e2e_tests/tests_omni_full/backend-config-e2e.yaml +++ b/e2e_tests/tests_omni_full/backend-config-e2e.yaml @@ -4,6 +4,5 @@ includes: modules: openapi_tool_registry: config: - tools: - - url: http://localhost:8001/roll - method: POST + tool_servers: + - url: http://localhost:8001/openapi.json diff --git a/e2e_tests/tests_omni_full/justfile b/e2e_tests/tests_omni_full/justfile index bf7499f8..b1c5199b 100644 --- a/e2e_tests/tests_omni_full/justfile +++ b/e2e_tests/tests_omni_full/justfile @@ -21,5 +21,5 @@ check: pnpm check # Fix code style and linting issues -check-write: +format: pnpm check:write diff --git a/e2e_tests/tests_omni_full/src/chat.spec.ts b/e2e_tests/tests_omni_full/src/chat.spec.ts index 14b5ee7b..fd30d070 100644 --- a/e2e_tests/tests_omni_full/src/chat.spec.ts +++ b/e2e_tests/tests_omni_full/src/chat.spec.ts @@ -55,7 +55,7 @@ test.describe("Chat", () => { const chatPage = new ChatPage(page); await chatPage.navigateTo(); await chatPage.selectFirstModel(); - await chatPage.enableTool("roll_dice"); + await chatPage.enableTool("Roll Dice"); // llmock trigger: "call tool '' with ''" causes it to return // a tool_call response. The backend Strands agent then calls the diff --git a/e2e_tests/tests_omni_light/justfile b/e2e_tests/tests_omni_light/justfile index 7df9fffe..92bbd0d4 100644 --- a/e2e_tests/tests_omni_light/justfile +++ b/e2e_tests/tests_omni_light/justfile @@ -21,5 +21,5 @@ check: pnpm check # Fix code style and linting issues -check-write: +format: pnpm check:write diff --git a/frontend/omni/.gitignore b/frontend/omni/.gitignore index 8227b370..3f0cf390 100644 --- a/frontend/omni/.gitignore +++ b/frontend/omni/.gitignore @@ -2,5 +2,7 @@ node_modules/ dist/ src/modules/external-* -# Frontend generated -frontend/omni/public/modules.json +public/modules.json +public/modules*.json +!public/modules_browser_only.json +!public/modules_with_backend.json diff --git a/frontend/omni/CHANGELOG.md b/frontend/omni/CHANGELOG.md index b76bb868..c575ce9c 100644 --- a/frontend/omni/CHANGELOG.md +++ b/frontend/omni/CHANGELOG.md @@ -12,12 +12,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - i18n support via `i18next` with browser language detection (supports English fallback and German). Translations are module-scoped: each module with user-facing strings has its own `locales/de.json`. New `src/modules/i18n/` module provides the `getT(namespace)` helper used by all components. - Tools can now be toggled directly from the chat input panel via a wrench-icon popover — no separate Tools page needed. - External frontend modules under `src/modules/external-*` can now carry their own npm dependencies through pnpm workspace package discovery, without requiring edits to the root `package.json`. +- Manifest `includes` support: the root `modules.json` can now declare an `includes` array to compose the module list from multiple JSON files. Included files are merged left-to-right; the root manifest always wins. Per-module `collisionStrategy` (`merge` | `replace` | `drop`) controls collision behaviour, mirroring the backend YAML config loader. ### Changed - Switch frontend to Svelte +- Updated tools fetching and chat tool handling to support `/api/tools` responses as a raw list of function tools (instead of requiring a `{ tools: [...] }` envelope) - ChatInputPanel redesigned: textarea, model selector, tools selector, and send button are now visually inside a single input box (ai-sdk.dev/examples/chatbot style). - Removed the Tools sidebar navigation item and `/tools` route from all module configurations. +- Made the chat tools selector popover scrollable with a viewport-constrained height so long tool lists remain fully accessible. +- Simplified the chat tools selector list to show only tool names; tool descriptions are now shown as tooltips on hover/focus. +- Tool identifiers are now displayed as human-readable names in the chat tools selector (for example, `some_tool` -> `Some Tool`). ## [0.0.1] - 2026-02-12 diff --git a/frontend/omni/docs/architecture/core.md b/frontend/omni/docs/architecture/core.md index b39afb57..4b8a8395 100644 --- a/frontend/omni/docs/architecture/core.md +++ b/frontend/omni/docs/architecture/core.md @@ -115,6 +115,52 @@ To activate a module, add it to `modules*.json`. > **Auto-discovery**: The module registry automatically discovers all `.svelte` and `.svelte.ts` files under `src/modules/**` via Vite's `import.meta.glob`. No manual TypeScript registry entry is required. +#### `modules*.json` — top-level structure and includes + +A manifest file has the following top-level shape: + +```json +{ + "version": "1.0.0", + "includes": [{ "path": "base.json" }], + "modules": [...] +} +``` + +- **version**: manifest format version (currently `"1.0.0"`) +- **includes** _(optional, root manifest only)_: list of other manifest files to merge in before this file's own modules. Each entry is an object with a `path` field. Relative paths are resolved relative to the current manifest. Nested includes (an included file itself containing `includes`) are not supported and throw an error. +- **modules**: array of module entries (see below) + +**Load order** — mirrors the backend YAML config loader: +1. Included files are fetched and applied left-to-right; later includes win on collision. +2. The root manifest's own `modules` are applied last and always win. + +**`collisionStrategy`** on a module entry controls what happens when that module's `id` already exists from an earlier include: +- `"merge"` *(default)* — deep-merges `config`; incoming wins on shared keys. `dependencies` is shallow-merged; incoming wins on key collision. +- `"replace"` — incoming entry fully replaces the existing one. +- `"drop"` — removes the existing entry; does not add the incoming one either. + +`collisionStrategy` is stripped from the resolved manifest before it is used. + +**Example** — composing a deployment-specific manifest from a base: + +```json +{ + "version": "1.0.0", + "includes": [{ "path": "modules_with_backend.json" }], + "modules": [ + { + "id": "fetch-service", + "path": "@/modules/fetch-service/sessionFetchService/create", + "collisionStrategy": "replace", + "dependencies": { + "module:sessionService": "session-service" + } + } + ] +} +``` + #### `modules*.json` — regular module entry ```json diff --git a/frontend/omni/justfile b/frontend/omni/justfile index cb690c4e..be98d9e7 100644 --- a/frontend/omni/justfile +++ b/frontend/omni/justfile @@ -24,5 +24,5 @@ check: pnpm check # Fix code style and linting issues -check-write: +format: pnpm check:write diff --git a/frontend/omni/public/modules_with_backend.json b/frontend/omni/public/modules_with_backend.json index 9c6c317a..3c157934 100644 --- a/frontend/omni/public/modules_with_backend.json +++ b/frontend/omni/public/modules_with_backend.json @@ -60,6 +60,7 @@ "type": "FetchService", "path": "@/modules/fetch-service/sessionFetchService/create", "dependencies": { + "module:fetchService": "plain-fetch-service", "module:sessionService": "session-service", "module:noSessionAction": "no-session-action" } diff --git a/frontend/omni/src/core/module-system/ModulesProvider.svelte b/frontend/omni/src/core/module-system/ModulesProvider.svelte index 3445eaea..9cec1796 100644 --- a/frontend/omni/src/core/module-system/ModulesProvider.svelte +++ b/frontend/omni/src/core/module-system/ModulesProvider.svelte @@ -4,7 +4,7 @@ import { setContext, untrack } from "svelte"; import { ComponentResolver } from "./componentResolver"; import { MODULES_KEY, type Modules } from "./index"; import { resolveManifestDependencies } from "./manifestDependencyResolver"; -import { fetchManifestJson } from "./manifestJson"; +import { resolveManifest } from "./manifestJson"; import { ActiveModulesImpl } from "./module"; interface Props { @@ -28,7 +28,7 @@ const modules: Modules = { setContext(MODULES_KEY, modules); // untrack: manifest path is intentionally captured once at mount time -const ready = fetchManifestJson(untrack(() => manifestPath)).then( +const ready = resolveManifest(untrack(() => manifestPath)).then( async (json) => { const activeEntries = resolveManifestDependencies(json.modules, []); const componentResolver = diff --git a/frontend/omni/src/core/module-system/manifestJson.test.ts b/frontend/omni/src/core/module-system/manifestJson.test.ts new file mode 100644 index 00000000..c75bdf3f --- /dev/null +++ b/frontend/omni/src/core/module-system/manifestJson.test.ts @@ -0,0 +1,381 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import type { ManifestEntry, ManifestJson } from "./manifestJson"; +import { resolveManifest } from "./manifestJson"; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function manifest( + modules: ManifestEntry[], + extras: Partial = {}, +): ManifestJson { + return { version: "1.0.0", modules, ...extras }; +} + +function entry(id: string, extras: Partial = {}): ManifestEntry { + return { id, path: `@/modules/${id}`, ...extras }; +} + +function mockFetch(responses: Record) { + vi.spyOn(globalThis, "fetch").mockImplementation(async (url) => { + const key = String(url); + if (key in responses) { + return { + ok: true, + json: () => Promise.resolve(responses[key]), + } as Response; + } + return { ok: false, statusText: "Not Found" } as Response; + }); +} + +afterEach(() => { + vi.restoreAllMocks(); +}); + +// --------------------------------------------------------------------------- +// No includes — passes straight through +// --------------------------------------------------------------------------- + +describe("resolveManifest — no includes", () => { + it("returns the manifest unchanged when there are no includes", async () => { + mockFetch({ + "/modules.json": manifest([entry("chat"), entry("auth")]), + }); + + const result = await resolveManifest("/modules.json"); + + expect(result.version).toBe("1.0.0"); + expect(result.modules.map((m) => m.id)).toEqual(["chat", "auth"]); + }); + + it("returns an empty module list when root has no modules and no includes", async () => { + mockFetch({ "/modules.json": manifest([]) }); + + const result = await resolveManifest("/modules.json"); + + expect(result.modules).toEqual([]); + }); +}); + +// --------------------------------------------------------------------------- +// Basic include merging +// --------------------------------------------------------------------------- + +describe("resolveManifest — includes", () => { + it("adds modules from an included file that do not collide", async () => { + mockFetch({ + "/modules.json": manifest([entry("chat")], { + includes: [{ path: "/base.json" }], + }), + "/base.json": manifest([entry("auth")]), + }); + + const result = await resolveManifest("/modules.json"); + + expect(result.modules.map((m) => m.id)).toContain("chat"); + expect(result.modules.map((m) => m.id)).toContain("auth"); + }); + + it("root modules win over included modules (applied last)", async () => { + mockFetch({ + "/modules.json": manifest( + [entry("chat", { path: "@/modules/chat-root" })], + { includes: [{ path: "/base.json" }] }, + ), + "/base.json": manifest([ + entry("chat", { path: "@/modules/chat-base" }), + ]), + }); + + const result = await resolveManifest("/modules.json"); + + const chatModule = result.modules.find((m) => m.id === "chat"); + expect(chatModule?.path).toBe("@/modules/chat-root"); + }); + + it("strips collisionStrategy from all entries in the result", async () => { + mockFetch({ + "/modules.json": manifest( + [entry("chat", { collisionStrategy: "replace" })], + { includes: [{ path: "/base.json" }] }, + ), + "/base.json": manifest([entry("other")]), + }); + + const result = await resolveManifest("/modules.json"); + + for (const mod of result.modules) { + expect(mod.collisionStrategy).toBeUndefined(); + } + }); +}); + +// --------------------------------------------------------------------------- +// Collision strategies +// --------------------------------------------------------------------------- + +describe("resolveManifest — collisionStrategy", () => { + it("merge (default): incoming wins on shared fields, base-only keys survive", async () => { + mockFetch({ + "/modules.json": manifest( + [ + entry("svc", { + config: { rootKey: 1 }, + }), + ], + { includes: [{ path: "/base.json" }] }, + ), + "/base.json": manifest([ + entry("svc", { + config: { baseKey: 2, rootKey: 99 }, + }), + ]), + }); + + const result = await resolveManifest("/modules.json"); + const svc = result.modules.find((m) => m.id === "svc"); + + expect(svc?.config).toEqual({ rootKey: 1, baseKey: 2 }); + }); + + it("merge: deeply merges nested config objects", async () => { + mockFetch({ + "/modules.json": manifest( + [entry("svc", { config: { nested: { a: 1, c: 3 } } })], + { includes: [{ path: "/base.json" }] }, + ), + "/base.json": manifest([ + entry("svc", { config: { nested: { a: 99, b: 2 } } }), + ]), + }); + + const result = await resolveManifest("/modules.json"); + const svc = result.modules.find((m) => m.id === "svc"); + + expect(svc?.config).toEqual({ nested: { a: 1, b: 2, c: 3 } }); + }); + + it("merge: merges dependencies records, incoming wins on key collision", async () => { + mockFetch({ + "/modules.json": manifest( + [ + entry("chat", { + dependencies: { + "module:svc": "chat-service-v2", + "module:router": "router", + }, + }), + ], + { includes: [{ path: "/base.json" }] }, + ), + "/base.json": manifest([ + entry("chat", { + dependencies: { + "module:svc": "chat-service-v1", + "module:sidebar": "sidebar", + }, + }), + ]), + }); + + const result = await resolveManifest("/modules.json"); + const chat = result.modules.find((m) => m.id === "chat"); + + expect(chat?.dependencies).toEqual({ + "module:svc": "chat-service-v2", + "module:router": "router", + "module:sidebar": "sidebar", + }); + }); + + it("replace: incoming entry fully replaces the included one", async () => { + mockFetch({ + "/modules.json": manifest( + [ + entry("svc", { + path: "@/modules/svc-root", + config: { key: "root" }, + collisionStrategy: "replace", + }), + ], + { includes: [{ path: "/base.json" }] }, + ), + "/base.json": manifest([ + entry("svc", { + path: "@/modules/svc-base", + config: { key: "base", extraKey: "kept_by_base" }, + }), + ]), + }); + + const result = await resolveManifest("/modules.json"); + const svc = result.modules.find((m) => m.id === "svc"); + + expect(svc?.path).toBe("@/modules/svc-root"); + expect(svc?.config).toEqual({ key: "root" }); + }); + + it("drop: removes the existing entry and does not add the incoming one", async () => { + mockFetch({ + "/modules.json": manifest( + [entry("legacy", { collisionStrategy: "drop" })], + { includes: [{ path: "/base.json" }] }, + ), + "/base.json": manifest([entry("legacy"), entry("other")]), + }); + + const result = await resolveManifest("/modules.json"); + const ids = result.modules.map((m) => m.id); + + expect(ids).not.toContain("legacy"); + expect(ids).toContain("other"); + }); + + it("drop on a new id (not from includes) has no effect (nothing to drop)", async () => { + mockFetch({ + "/modules.json": manifest( + [entry("ghost", { collisionStrategy: "drop" })], + { includes: [{ path: "/base.json" }] }, + ), + "/base.json": manifest([entry("other")]), + }); + + const result = await resolveManifest("/modules.json"); + const ids = result.modules.map((m) => m.id); + + expect(ids).not.toContain("ghost"); + expect(ids).toContain("other"); + }); +}); + +// --------------------------------------------------------------------------- +// Multiple includes — load order +// --------------------------------------------------------------------------- + +describe("resolveManifest — multiple includes", () => { + it("later include wins over earlier on collision", async () => { + mockFetch({ + "/modules.json": manifest([], { + includes: [{ path: "/first.json" }, { path: "/second.json" }], + }), + "/first.json": manifest([ + entry("svc", { path: "@/modules/svc-first" }), + ]), + "/second.json": manifest([ + entry("svc", { path: "@/modules/svc-second" }), + ]), + }); + + const result = await resolveManifest("/modules.json"); + const svc = result.modules.find((m) => m.id === "svc"); + + expect(svc?.path).toBe("@/modules/svc-second"); + }); + + it("root still wins over all includes regardless of order", async () => { + mockFetch({ + "/modules.json": manifest( + [entry("svc", { path: "@/modules/svc-root" })], + { + includes: [ + { path: "/first.json" }, + { path: "/second.json" }, + ], + }, + ), + "/first.json": manifest([ + entry("svc", { path: "@/modules/svc-first" }), + ]), + "/second.json": manifest([ + entry("svc", { path: "@/modules/svc-second" }), + ]), + }); + + const result = await resolveManifest("/modules.json"); + const svc = result.modules.find((m) => m.id === "svc"); + + expect(svc?.path).toBe("@/modules/svc-root"); + }); +}); + +// --------------------------------------------------------------------------- +// Path resolution +// --------------------------------------------------------------------------- + +describe("resolveManifest — include path resolution", () => { + it("resolves absolute include paths as-is", async () => { + mockFetch({ + "/manifests/root.json": manifest([entry("root-mod")], { + includes: [{ path: "/shared/base.json" }], + }), + "/shared/base.json": manifest([entry("base-mod")]), + }); + + const result = await resolveManifest("/manifests/root.json"); + const ids = result.modules.map((m) => m.id); + + expect(ids).toContain("base-mod"); + expect(ids).toContain("root-mod"); + }); + + it("resolves relative include paths relative to the manifest directory", async () => { + mockFetch({ + "/manifests/root.json": manifest([entry("root-mod")], { + includes: [{ path: "base.json" }], + }), + "/manifests/base.json": manifest([entry("base-mod")]), + }); + + const result = await resolveManifest("/manifests/root.json"); + const ids = result.modules.map((m) => m.id); + + expect(ids).toContain("base-mod"); + expect(ids).toContain("root-mod"); + }); +}); + +// --------------------------------------------------------------------------- +// Error cases +// --------------------------------------------------------------------------- + +describe("resolveManifest — error cases", () => { + it("throws when fetch fails", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue({ + ok: false, + statusText: "Internal Server Error", + } as Response); + + await expect(resolveManifest("/modules.json")).rejects.toThrow( + "Failed to fetch manifest", + ); + }); + + it("throws when an included file itself contains includes (nested includes)", async () => { + mockFetch({ + "/modules.json": manifest([], { + includes: [{ path: "/child.json" }], + }), + "/child.json": manifest([], { + includes: [{ path: "/grandchild.json" }], + }), + }); + + await expect(resolveManifest("/modules.json")).rejects.toThrow( + "Nested includes are not supported", + ); + }); + + it("throws when an included file fetch fails", async () => { + mockFetch({ + "/modules.json": manifest([], { + includes: [{ path: "/missing.json" }], + }), + }); + + await expect(resolveManifest("/modules.json")).rejects.toThrow( + "Failed to fetch manifest", + ); + }); +}); diff --git a/frontend/omni/src/core/module-system/manifestJson.ts b/frontend/omni/src/core/module-system/manifestJson.ts index f8f509cc..71912fc7 100644 --- a/frontend/omni/src/core/module-system/manifestJson.ts +++ b/frontend/omni/src/core/module-system/manifestJson.ts @@ -1,5 +1,13 @@ +export type CollisionStrategy = "merge" | "replace" | "drop"; + +export interface IncludeEntry { + path: string; +} + export interface ManifestJson { version: string; + /** Only allowed in the root manifest. Nested includes are not supported. */ + includes?: IncludeEntry[]; modules: ManifestEntry[]; } @@ -8,6 +16,58 @@ export interface ManifestEntry { path: string; dependencies?: Record; config?: Record; + /** + * Controls how this entry behaves when its `id` already exists from an + * included file. Only meaningful on root-manifest entries (or later + * includes that overwrite earlier ones). + * + * - `merge` (default): deep-merge config; incoming wins on shared keys. + * - `replace`: incoming entry fully replaces the existing one. + * - `drop`: remove the existing entry; do not add this entry either. + * + * Stripped from the final resolved manifest. + */ + collisionStrategy?: CollisionStrategy; +} + +/** + * Fetch and resolve a manifest, expanding any `includes` into a single merged + * module list. Only the root manifest may contain `includes`; nested includes + * throw an error. + * + * Load order (mirrors the backend YAML config loader): + * 1. Includes are applied left-to-right; later includes win on collision. + * 2. Root modules are applied last and always win. + */ +export async function resolveManifest(path: string): Promise { + const root = await fetchManifestJson(path); + const includes = root.includes ?? []; + + if (includes.length === 0) { + return { version: root.version, modules: root.modules }; + } + + let accumulated = new Map(); + + for (const include of includes) { + const includePath = resolveIncludePath(path, include.path); + const included = await fetchManifestJson(includePath); + if (included.includes && included.includes.length > 0) { + throw new Error( + `Nested includes are not supported. '${includePath}' contains an 'includes' key. ` + + "Only the root manifest may use 'includes'.", + ); + } + accumulated = applyModules(accumulated, included.modules); + } + + accumulated = applyModules(accumulated, root.modules); + + const resolvedModules: ManifestEntry[] = Array.from( + accumulated.values(), + ).map(({ collisionStrategy: _, ...rest }) => rest as ManifestEntry); + + return { version: root.version, modules: resolvedModules }; } export async function fetchManifestJson(path: string): Promise { @@ -17,3 +77,94 @@ export async function fetchManifestJson(path: string): Promise { } return response.json() as Promise; } + +function resolveIncludePath(manifestPath: string, includePath: string): string { + if (includePath.startsWith("/")) return includePath; + const lastSlash = manifestPath.lastIndexOf("/"); + const base = lastSlash >= 0 ? manifestPath.substring(0, lastSlash + 1) : ""; + return base + includePath; +} + +function applyModules( + base: Map, + incoming: ManifestEntry[], +): Map { + const result = new Map(base); + + for (const entry of incoming) { + const strategy = entry.collisionStrategy ?? "merge"; + + if (strategy === "drop") { + result.delete(entry.id); + continue; + } + + if (!result.has(entry.id)) { + result.set(entry.id, entry); + continue; + } + + if (strategy === "replace") { + result.set(entry.id, entry); + } else { + const existing = result.get(entry.id); + if (existing) { + result.set(entry.id, deepMergeEntries(existing, entry)); + } + } + } + + return result; +} + +function deepMergeEntries( + base: ManifestEntry, + incoming: ManifestEntry, +): ManifestEntry { + const merged: ManifestEntry = { ...base, ...incoming }; + + if ( + base.dependencies !== undefined || + incoming.dependencies !== undefined + ) { + merged.dependencies = { + ...(base.dependencies ?? {}), + ...(incoming.dependencies ?? {}), + }; + } + + if (base.config !== undefined || incoming.config !== undefined) { + merged.config = deepMergeRecords( + base.config ?? {}, + incoming.config ?? {}, + ); + } + + return merged; +} + +function deepMergeRecords( + base: Record, + incoming: Record, +): Record { + const result: Record = { ...base }; + for (const [key, val] of Object.entries(incoming)) { + if ( + val !== null && + typeof val === "object" && + !Array.isArray(val) && + key in result && + result[key] !== null && + typeof result[key] === "object" && + !Array.isArray(result[key]) + ) { + result[key] = deepMergeRecords( + result[key] as Record, + val as Record, + ); + } else { + result[key] = val; + } + } + return result; +} diff --git a/frontend/omni/src/modules/chat-service/openai.svelte.test.ts b/frontend/omni/src/modules/chat-service/openai.svelte.test.ts index 2bac9201..b5300d6a 100644 --- a/frontend/omni/src/modules/chat-service/openai.svelte.test.ts +++ b/frontend/omni/src/modules/chat-service/openai.svelte.test.ts @@ -72,13 +72,11 @@ describe("OpenAIChatService", () => { [ { type: "function", - function: { - name: "calculate", - description: "A calculator", - parameters: { - type: "object", - properties: { x: { type: "number" } }, - }, + name: "calculate", + description: "A calculator", + parameters: { + type: "object", + properties: { x: { type: "number" } }, }, }, ], diff --git a/frontend/omni/src/modules/chat-service/openai.svelte.ts b/frontend/omni/src/modules/chat-service/openai.svelte.ts index a1470e57..eb6c502c 100644 --- a/frontend/omni/src/modules/chat-service/openai.svelte.ts +++ b/frontend/omni/src/modules/chat-service/openai.svelte.ts @@ -49,11 +49,11 @@ function trimTrailingSlash(url: string): string { function convertToAiSdkTools(openAiTools: OpenAIFunctionTool[]) { return Object.fromEntries( openAiTools.map((t) => [ - t.function.name, + t.name, tool({ - description: t.function.description, + description: t.description, inputSchema: jsonSchema( - (t.function.parameters ?? { + (t.parameters ?? { type: "object", properties: {}, }) as object, diff --git a/frontend/omni/src/modules/chat/ChatComponent.svelte b/frontend/omni/src/modules/chat/ChatComponent.svelte index 82f74ecb..b9f07963 100644 --- a/frontend/omni/src/modules/chat/ChatComponent.svelte +++ b/frontend/omni/src/modules/chat/ChatComponent.svelte @@ -113,7 +113,7 @@ async function handleSend(text: string) { try { chatStatus = "streaming"; const selectedTools = availableTools.filter((tool) => - selectedToolNames.includes(tool.function.name), + selectedToolNames.includes(tool.name), ); for await (const textPart of chatService.streamChat( selectedModelData, diff --git a/frontend/omni/src/modules/chat/ChatToolsSelector.svelte b/frontend/omni/src/modules/chat/ChatToolsSelector.svelte index b6914d57..28590407 100644 --- a/frontend/omni/src/modules/chat/ChatToolsSelector.svelte +++ b/frontend/omni/src/modules/chat/ChatToolsSelector.svelte @@ -2,8 +2,10 @@ import { Check, Wrench } from "lucide-svelte"; import { getT } from "@/modules/i18n/index.svelte.js"; import type { OpenAIFunctionTool } from "@/modules/tools-service/index.svelte.js"; +import { sortToolsByName } from "@/modules/tools-service/toolName"; import { Button } from "$lib/shadcnui/components/ui/button/index.js"; import * as Popover from "$lib/shadcnui/components/ui/popover/index.js"; +import * as Tooltip from "$lib/shadcnui/components/ui/tooltip/index.js"; const t = getT("chat"); @@ -18,10 +20,19 @@ let { } = $props(); let open = $state(false); +const sortedTools = $derived(sortToolsByName(availableTools)); function isSelected(name: string): boolean { return selectedToolNames.includes(name); } + +function formatToolName(name: string): string { + return name + .replace(/[_-]+/g, " ") + .trim() + .replace(/\s+/g, " ") + .replace(/\b\w/g, (match) => match.toUpperCase()); +} @@ -31,7 +42,7 @@ function isSelected(name: string): boolean { variant="ghost" size="sm" class="text-muted-foreground h-auto gap-1.5 px-2 py-1 text-xs" - aria-label={t("selectTools", { defaultValue: "Select tools" })} + aria-label={t("selectTools", { defaultValue: "Select tools" }) as string} {...props} > @@ -41,28 +52,43 @@ function isSelected(name: string): boolean { {/snippet} - +
{t("tools", { defaultValue: "Tools" })}
- {#each availableTools as tool} - + {#each sortedTools as tool} + {@const displayName = formatToolName(tool.name)} + {#if tool.description} + + + {#snippet child({ props })} + + {/snippet} + + + {tool.description} + + + {:else} + + {/if} {/each}
diff --git a/frontend/omni/src/modules/fetch-service/fetchService.test.ts b/frontend/omni/src/modules/fetch-service/fetchService.test.ts index ab0764b3..13de52f2 100644 --- a/frontend/omni/src/modules/fetch-service/fetchService.test.ts +++ b/frontend/omni/src/modules/fetch-service/fetchService.test.ts @@ -4,6 +4,7 @@ import type { NoSessionAction, SessionService, } from "@/modules/session-service/index.svelte"; +import type { FetchService } from "./index.svelte"; import { create as createPureFetchService } from "./pureFetchService.svelte"; import { create as createSessionFetchService } from "./sessionFetchService.svelte"; @@ -37,6 +38,9 @@ describe("PureFetchService", () => { }); describe("SessionFetchService", () => { + const mockFetchService: FetchService = { + fetch: vi.fn(), + }; const mockSessionService: SessionService = { refresh: vi.fn(), isSessionActive: vi.fn(), @@ -48,6 +52,7 @@ describe("SessionFetchService", () => { function makeDeps(): ModuleDependencies { return { getOne: vi.fn().mockImplementation((name: string) => { + if (name === "fetchService") return mockFetchService; if (name === "sessionService") return mockSessionService; if (name === "noSessionAction") return mockNoSessionAction; throw new Error(`Unknown dep "${name}"`); @@ -64,19 +69,17 @@ describe("SessionFetchService", () => { } beforeEach(() => { - vi.stubGlobal("fetch", vi.fn()); vi.mocked(mockSessionService.refresh).mockResolvedValue(undefined); vi.mocked(mockNoSessionAction.execute).mockImplementation(() => {}); }); afterEach(() => { - vi.unstubAllGlobals(); vi.clearAllMocks(); }); it("returns the response without touching session service on success", async () => { const mockResponse = new Response(null, { status: 200 }); - vi.mocked(fetch).mockResolvedValue(mockResponse); + vi.mocked(mockFetchService.fetch).mockResolvedValue(mockResponse); const response = await makeService().fetch("/api/data"); @@ -85,7 +88,9 @@ describe("SessionFetchService", () => { }); it("does not interact with session service on non-401 error responses", async () => { - vi.mocked(fetch).mockResolvedValue(new Response(null, { status: 500 })); + vi.mocked(mockFetchService.fetch).mockResolvedValue( + new Response(null, { status: 500 }), + ); await makeService().fetch("/api/data"); @@ -93,7 +98,9 @@ describe("SessionFetchService", () => { }); it("calls session service refresh on 401", async () => { - vi.mocked(fetch).mockResolvedValue(new Response(null, { status: 401 })); + vi.mocked(mockFetchService.fetch).mockResolvedValue( + new Response(null, { status: 401 }), + ); await makeService().fetch("/api/data"); @@ -101,7 +108,9 @@ describe("SessionFetchService", () => { }); it("executes no-session action when session is inactive after 401", async () => { - vi.mocked(fetch).mockResolvedValue(new Response(null, { status: 401 })); + vi.mocked(mockFetchService.fetch).mockResolvedValue( + new Response(null, { status: 401 }), + ); await makeService(false).fetch("/api/data"); @@ -109,30 +118,54 @@ describe("SessionFetchService", () => { }); it("does not execute no-session action when session is still active after 401", async () => { - vi.mocked(fetch).mockResolvedValue(new Response(null, { status: 401 })); + vi.mocked(mockFetchService.fetch).mockResolvedValue( + new Response(null, { status: 401 }), + ); await makeService(true).fetch("/api/data"); expect(mockNoSessionAction.execute).not.toHaveBeenCalled(); }); + it("delegates fetch call to injected fetch service", async () => { + vi.mocked(mockFetchService.fetch).mockResolvedValue( + new Response(null, { status: 200 }), + ); + const init: RequestInit = { method: "POST", body: "data" }; + + await makeService().fetch("/api/data", init); + + expect(mockFetchService.fetch).toHaveBeenCalledWith( + "/api/data", + expect.objectContaining({ + credentials: "include", + method: "POST", + body: "data", + }), + ); + }); + it("includes credentials by default", async () => { - vi.mocked(fetch).mockResolvedValue(new Response(null, { status: 200 })); + vi.mocked(mockFetchService.fetch).mockResolvedValue( + new Response(null, { status: 200 }), + ); await makeService().fetch("/api/data"); - expect(fetch).toHaveBeenCalledWith( + expect(mockFetchService.fetch).toHaveBeenCalledWith( "/api/data", expect.objectContaining({ credentials: "include" }), ); }); it("allows caller to override credentials", async () => { - vi.mocked(fetch).mockResolvedValue(new Response(null, { status: 200 })); + vi.mocked(mockFetchService.fetch).mockResolvedValue( + new Response(null, { status: 200 }), + ); await makeService().fetch("/api/data", { credentials: "omit" }); - expect(fetch).toHaveBeenCalledWith( + expect(mockFetchService.fetch).toHaveBeenCalledWith( "/api/data", expect.objectContaining({ credentials: "omit" }), ); diff --git a/frontend/omni/src/modules/fetch-service/sessionFetchService.svelte.ts b/frontend/omni/src/modules/fetch-service/sessionFetchService.svelte.ts index 8c5d5c5b..a64278bc 100644 --- a/frontend/omni/src/modules/fetch-service/sessionFetchService.svelte.ts +++ b/frontend/omni/src/modules/fetch-service/sessionFetchService.svelte.ts @@ -6,6 +6,7 @@ import type { import type { FetchService } from "./index.svelte.js"; export function create(deps: ModuleDependencies): FetchService { + const fetchService = deps.getOne("fetchService"); const sessionService = deps.getOne("sessionService"); const noSessionAction = deps.getOne("noSessionAction"); return { @@ -13,7 +14,7 @@ export function create(deps: ModuleDependencies): FetchService { input: RequestInfo | URL, init?: RequestInit, ): Promise { - const response = await fetch(input, { + const response = await fetchService.fetch(input, { credentials: "include", ...init, }); diff --git a/frontend/omni/src/modules/tools-service/index.svelte.test.ts b/frontend/omni/src/modules/tools-service/index.svelte.test.ts index 396f9d2a..f919a6b0 100644 --- a/frontend/omni/src/modules/tools-service/index.svelte.test.ts +++ b/frontend/omni/src/modules/tools-service/index.svelte.test.ts @@ -18,20 +18,18 @@ describe("ModaiBackendToolsService", () => { const tools = [ { type: "function" as const, - function: { - name: "calculate", - description: "Evaluate a math expression", - parameters: { - type: "object", - properties: { - expression: { type: "string" }, - }, + name: "calculate", + description: "Evaluate a math expression", + parameters: { + type: "object", + properties: { + expression: { type: "string" }, }, }, }, ]; vi.mocked(fetchService.fetch).mockResolvedValue( - new Response(JSON.stringify({ tools }), { + new Response(JSON.stringify(tools), { status: 200, headers: { "Content-Type": "application/json" }, }), diff --git a/frontend/omni/src/modules/tools-service/index.svelte.ts b/frontend/omni/src/modules/tools-service/index.svelte.ts index 8c5d5177..33d0de8c 100644 --- a/frontend/omni/src/modules/tools-service/index.svelte.ts +++ b/frontend/omni/src/modules/tools-service/index.svelte.ts @@ -1,11 +1,9 @@ export type OpenAIFunctionTool = { type: "function"; - function: { - name: string; - description?: string; - parameters?: Record; - strict?: boolean; - }; + name: string; + description?: string; + parameters?: Record; + strict?: boolean; }; /** diff --git a/frontend/omni/src/modules/tools-service/modaiBackendToolsService.svelte.ts b/frontend/omni/src/modules/tools-service/modaiBackendToolsService.svelte.ts index d34dff86..82de619a 100644 --- a/frontend/omni/src/modules/tools-service/modaiBackendToolsService.svelte.ts +++ b/frontend/omni/src/modules/tools-service/modaiBackendToolsService.svelte.ts @@ -4,10 +4,6 @@ import type { OpenAIFunctionTool, ToolsService } from "./index.svelte.js"; const API_BASE = "/api/tools"; -type BackendToolsResponse = { - tools: OpenAIFunctionTool[]; -}; - export class ModaiBackendToolsService implements ToolsService { readonly #fetchService: FetchService; @@ -21,8 +17,7 @@ export class ModaiBackendToolsService implements ToolsService { return []; } - const data = (await response.json()) as BackendToolsResponse; - return data.tools ?? []; + return (await response.json()) as OpenAIFunctionTool[]; } } diff --git a/frontend/omni/src/modules/tools-service/toolName.test.ts b/frontend/omni/src/modules/tools-service/toolName.test.ts new file mode 100644 index 00000000..14fc04ab --- /dev/null +++ b/frontend/omni/src/modules/tools-service/toolName.test.ts @@ -0,0 +1,30 @@ +import { describe, expect, it } from "vitest"; +import type { OpenAIFunctionTool } from "./index.svelte.js"; +import { sortToolsByName } from "./toolName"; + +describe("sortToolsByName", () => { + it("sorts tools by plain tool name", () => { + const tools: OpenAIFunctionTool[] = [ + { + type: "function", + name: "zoom", + }, + { + type: "function", + name: "calculate", + }, + { + type: "function", + name: "plain", + }, + ]; + + const sorted = sortToolsByName(tools); + + expect(sorted.map((tool) => tool.name)).toEqual([ + "calculate", + "plain", + "zoom", + ]); + }); +}); diff --git a/frontend/omni/src/modules/tools-service/toolName.ts b/frontend/omni/src/modules/tools-service/toolName.ts new file mode 100644 index 00000000..5ab112c4 --- /dev/null +++ b/frontend/omni/src/modules/tools-service/toolName.ts @@ -0,0 +1,11 @@ +import type { OpenAIFunctionTool } from "./index.svelte.js"; + +export function sortToolsByName( + tools: OpenAIFunctionTool[], +): OpenAIFunctionTool[] { + return [...tools].sort((left, right) => + left.name.localeCompare(right.name, "und", { + sensitivity: "base", + }), + ); +}