diff --git a/.gitignore b/.gitignore
index 7f5be4bf18..34dcc23d79 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,4 +42,11 @@ cmd/thv-operator/.task/checksum/crdref-gen
# Test coverage
coverage*
-crd-helm-wrapper
\ No newline at end of file
+crd-helm-wrapper
+cmd/vmcp/__debug_bin*
+
+# Demo files
+examples/operator/virtual-mcps/vmcp_optimizer.yaml
+scripts/k8s_vmcp_optimizer_demo.sh
+examples/ingress/mcp-servers-ingress.yaml
+/vmcp
diff --git a/.golangci.yml b/.golangci.yml
index 62c3611473..ff2b3d54e9 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -139,6 +139,7 @@ linters:
- third_party$
- builtin$
- examples$
+ - scripts$
formatters:
enable:
- gci
@@ -155,3 +156,4 @@ formatters:
- third_party$
- builtin$
- examples$
+ - scripts$
diff --git a/cmd/thv-operator/controllers/mcpserver_controller.go b/cmd/thv-operator/controllers/mcpserver_controller.go
index 2313d85ace..f1221c6ecc 100644
--- a/cmd/thv-operator/controllers/mcpserver_controller.go
+++ b/cmd/thv-operator/controllers/mcpserver_controller.go
@@ -1250,12 +1250,13 @@ func (r *MCPServerReconciler) deploymentForMCPServer(
Spec: corev1.PodSpec{
ServiceAccountName: ctrlutil.ProxyRunnerServiceAccountName(m.Name),
Containers: []corev1.Container{{
- Image: getToolhiveRunnerImage(),
- Name: "toolhive",
- Args: args,
- Env: env,
- VolumeMounts: volumeMounts,
- Resources: resources,
+ Image: getToolhiveRunnerImage(),
+ Name: "toolhive",
+ ImagePullPolicy: getImagePullPolicyForToolhiveRunner(),
+ Args: args,
+ Env: env,
+ VolumeMounts: volumeMounts,
+ Resources: resources,
Ports: []corev1.ContainerPort{{
ContainerPort: m.GetProxyPort(),
Name: "http",
@@ -1813,6 +1814,19 @@ func getToolhiveRunnerImage() string {
return image
}
+// getImagePullPolicyForToolhiveRunner returns the appropriate imagePullPolicy for the toolhive runner container.
+// If the image is a local image (starts with "kind.local/" or "localhost/"), use Never.
+// Otherwise, use IfNotPresent to allow pulling when needed but avoid unnecessary pulls.
+func getImagePullPolicyForToolhiveRunner() corev1.PullPolicy {
+ image := getToolhiveRunnerImage()
+ // Check if it's a local image that should use Never
+ if strings.HasPrefix(image, "kind.local/") || strings.HasPrefix(image, "localhost/") {
+ return corev1.PullNever
+ }
+ // For other images, use IfNotPresent to allow pulling when needed
+ return corev1.PullIfNotPresent
+}
+
// handleExternalAuthConfig validates and tracks the hash of the referenced MCPExternalAuthConfig.
// It updates the MCPServer status when the external auth configuration changes.
func (r *MCPServerReconciler) handleExternalAuthConfig(ctx context.Context, m *mcpv1alpha1.MCPServer) error {
diff --git a/cmd/thv-operator/pkg/optimizer/INTEGRATION.md b/cmd/thv-operator/pkg/optimizer/INTEGRATION.md
new file mode 100644
index 0000000000..a231a0dabb
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/INTEGRATION.md
@@ -0,0 +1,134 @@
+# Integrating Optimizer with vMCP
+
+## Overview
+
+The optimizer package ingests MCP server and tool metadata into a searchable database with semantic embeddings. This enables intelligent tool discovery and token optimization for LLM consumption.
+
+## Integration Approach
+
+**Event-Driven Ingestion**: The optimizer integrates directly with vMCP's startup process. When vMCP starts and loads its configured servers, it calls the optimizer to ingest each server's metadata and tools.
+
+❌ **NOT** a separate polling service discovering backends
+✅ **IS** called directly by vMCP during server initialization
+
+## How It Is Integrated
+
+The optimizer is already integrated into vMCP and works automatically when enabled via configuration. Here's how the integration works:
+
+### Initialization
+
+When vMCP starts with optimizer enabled in the configuration, it:
+
+1. Initializes the optimizer database (chromem-go + SQLite FTS5)
+2. Configures the embedding backend (placeholder, Ollama, or vLLM)
+3. Sets up the ingestion service
+
+### Automatic Ingestion
+
+The optimizer integrates with vMCP's `OnRegisterSession` hook, which is called whenever:
+
+- vMCP starts and loads configured MCP servers
+- A new MCP server is dynamically added
+- A session reconnects or refreshes
+
+When this hook is triggered, the optimizer:
+
+1. Retrieves the server's metadata and tools via MCP protocol
+2. Generates embeddings for searchable content
+3. Stores the data in both the vector database (chromem-go) and FTS5 database
+4. Makes the tools immediately available for semantic search
+
+### Exposed Tools
+
+When the optimizer is enabled, vMCP automatically exposes these tools to LLM clients:
+
+- `optim.find_tool`: Semantic search for tools across all registered servers
+- `optim.call_tool`: Dynamic tool invocation after discovery
+
+### Implementation Location
+
+The integration code is located in:
+- `cmd/vmcp/optimizer.go`: Optimizer initialization and configuration
+- `pkg/vmcp/optimizer/optimizer.go`: Session registration hook implementation
+- `cmd/thv-operator/pkg/optimizer/ingestion/service.go`: Core ingestion service
+
+## Configuration
+
+Add optimizer configuration to vMCP's config:
+
+```yaml
+# vMCP config
+optimizer:
+ enabled: true
+ db_path: /data/optimizer.db
+ embedding:
+ backend: vllm # or "ollama" for local dev, "placeholder" for testing
+ url: http://vllm-service:8000
+ model: sentence-transformers/all-MiniLM-L6-v2
+ dimension: 384
+```
+
+## Error Handling
+
+**Important**: Optimizer failures should NOT break vMCP functionality:
+
+- ✅ Log warnings if optimizer fails
+- ✅ Continue server startup even if ingestion fails
+- ✅ Run ingestion in goroutines to avoid blocking
+- ❌ Don't fail server startup if optimizer is unavailable
+
+## Benefits
+
+1. **Automatic**: Servers are indexed as they're added to vMCP
+2. **Up-to-date**: Database reflects current vMCP state
+3. **No polling**: Event-driven, efficient
+4. **Semantic search**: Enables intelligent tool discovery
+5. **Token optimization**: Tracks token usage for LLM efficiency
+
+## Testing
+
+```go
+func TestOptimizerIntegration(t *testing.T) {
+ // Initialize optimizer
+ optimizerSvc, err := ingestion.NewService(&ingestion.Config{
+ DBConfig: &db.Config{PersistPath: "/tmp/test-optimizer.db"},
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ // all-minilm produces 384-dimensional embeddings
+ },
+ })
+ require.NoError(t, err)
+ defer optimizerSvc.Close()
+
+ // Simulate vMCP starting a server
+ ctx := context.Background()
+ tools := []mcp.Tool{
+ {Name: "get_weather", Description: "Get current weather"},
+ {Name: "get_forecast", Description: "Get weather forecast"},
+ }
+
+ err = optimizerSvc.IngestServer(
+ ctx,
+ "weather-001",
+ "weather-service",
+ "http://weather.local",
+ models.TransportSSE,
+ ptr("Weather information service"),
+ tools,
+ )
+ require.NoError(t, err)
+
+ // Verify ingestion
+ server, err := optimizerSvc.GetServer(ctx, "weather-001")
+ require.NoError(t, err)
+ assert.Equal(t, "weather-service", server.Name)
+}
+```
+
+## See Also
+
+- [Optimizer Package README](./README.md) - Package overview and API
+
diff --git a/cmd/thv-operator/pkg/optimizer/README.md b/cmd/thv-operator/pkg/optimizer/README.md
new file mode 100644
index 0000000000..7db703b711
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/README.md
@@ -0,0 +1,339 @@
+# Optimizer Package
+
+The optimizer package provides semantic tool discovery and ingestion for MCP servers in ToolHive's vMCP. It enables intelligent, context-aware tool selection to reduce token usage and improve LLM performance.
+
+## Features
+
+- **Pure Go**: No CGO dependencies - uses [chromem-go](https://github.com/philippgille/chromem-go) for vector search and `modernc.org/sqlite` for FTS5
+- **Hybrid Search**: Combines semantic search (chromem-go) with BM25 full-text search (SQLite FTS5)
+- **In-Memory by Default**: Fast ephemeral database with optional persistence
+- **Pluggable Embeddings**: Supports vLLM, Ollama, and placeholder backends
+- **Event-Driven**: Integrates with vMCP's `OnRegisterSession` hook for automatic ingestion
+- **Semantic + Keyword Search**: Configurable ratio between semantic and BM25 search
+- **Token Counting**: Tracks token usage for LLM consumption metrics
+
+## Architecture
+
+```
+cmd/thv-operator/pkg/optimizer/
+├── models/ # Domain models (Server, Tool, etc.)
+├── db/ # Hybrid database layer (chromem-go + SQLite FTS5)
+│ ├── db.go # Database coordinator
+│ ├── fts.go # SQLite FTS5 for BM25 search (pure Go)
+│ ├── hybrid.go # Hybrid search combining semantic + BM25
+│ ├── backend_server.go # Server operations
+│ └── backend_tool.go # Tool operations
+├── embeddings/ # Embedding backends (vLLM, Ollama, placeholder)
+├── ingestion/ # Event-driven ingestion service
+└── tokens/ # Token counting for LLM metrics
+```
+
+## Embedding Backends
+
+The optimizer supports multiple embedding backends:
+
+| Backend | Use Case | Performance | Setup |
+|---------|----------|-------------|-------|
+| **vLLM** | **Production/Kubernetes (recommended)** | Excellent (GPU) | Deploy vLLM service |
+| Ollama | Local development, CPU-only | Good | `ollama serve` |
+| Placeholder | Testing, CI/CD | Fast (hash-based) | Zero setup |
+
+**For production Kubernetes deployments, vLLM is recommended** due to its high-throughput performance, GPU efficiency (PagedAttention), and scalability for multi-user environments.
+
+## Hybrid Search
+
+The optimizer **always uses hybrid search** combining:
+
+1. **Semantic Search** (chromem-go): Understands meaning and context via embeddings
+2. **BM25 Full-Text Search** (SQLite FTS5): Keyword matching with Porter stemming
+
+This dual approach ensures the best of both worlds: semantic understanding for intent-based queries and keyword precision for technical terms and acronyms.
+
+### Configuration
+
+```yaml
+optimizer:
+ enabled: true
+ embeddingBackend: placeholder
+ embeddingDimension: 384
+ # persistPath: /data/optimizer # Optional: for persistence
+ # ftsDBPath: /data/optimizer-fts.db # Optional: defaults to :memory: or {persistPath}/fts.db
+ hybridSearchRatio: 70 # 70% semantic, 30% BM25 (default, 0-100 percentage)
+```
+
+| Ratio | Semantic | BM25 | Best For |
+|-------|----------|------|----------|
+| 100 | 100% | 0% | Pure semantic (intent-heavy queries) |
+| 70 | 70% | 30% | **Default**: Balanced hybrid |
+| 50 | 50% | 50% | Equal weight |
+| 0 | 0% | 100% | Pure keyword (exact term matching) |
+
+### How It Works
+
+1. **Parallel Execution**: Semantic and BM25 searches run concurrently
+2. **Result Merging**: Combines results and removes duplicates
+3. **Ranking**: Sorts by similarity/relevance score
+4. **Limit Enforcement**: Returns top N results
+
+### Example Queries
+
+| Query | Semantic Match | BM25 Match | Winner |
+|-------|----------------|------------|--------|
+| "What's the weather?" | ✅ `get_current_weather` | ✅ `weather_forecast` | Both (deduped) |
+| "SQL database query" | ❌ (no close semantic match) | ✅ `execute_sql` | BM25 |
+| "Make it rain outside" | ✅ `weather_control` | ❌ (no keyword) | Semantic |
+
+## Quick Start
+
+### vMCP Integration (Recommended)
+
+The optimizer is designed to work as part of vMCP, not standalone:
+
+```yaml
+# examples/vmcp-config-optimizer.yaml
+optimizer:
+ enabled: true
+ embeddingBackend: placeholder # or "ollama", "openai-compatible"
+ embeddingDimension: 384
+ # persistPath: /data/optimizer # Optional: for chromem-go persistence
+ # ftsDBPath: /data/fts.db # Optional: auto-defaults to :memory: or {persistPath}/fts.db
+ # hybridSearchRatio: 70 # Optional: 70% semantic, 30% BM25 (default, 0-100 percentage)
+```
+
+Start vMCP with optimizer:
+
+```bash
+thv vmcp serve --config examples/vmcp-config-optimizer.yaml
+```
+
+When optimizer is enabled, vMCP exposes:
+- `optim.find_tool`: Semantic search for tools
+- `optim.call_tool`: Dynamic tool invocation
+
+### Programmatic Usage
+
+```go
+import (
+ "context"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/ingestion"
+)
+
+func main() {
+ ctx := context.Background()
+
+ // Initialize database (in-memory)
+ database, err := db.NewDB(&db.Config{
+ PersistPath: "", // Empty = in-memory only
+ })
+ if err != nil {
+ panic(err)
+ }
+
+ // Initialize embedding manager with Ollama (default)
+ embeddingMgr, err := embeddings.NewManager(&embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ })
+ if err != nil {
+ panic(err)
+ }
+
+ // Create ingestion service
+ svc, err := ingestion.NewService(&ingestion.Config{
+ DBConfig: &db.Config{PersistPath: ""},
+ EmbeddingConfig: embeddingMgr.Config(),
+ })
+ if err != nil {
+ panic(err)
+ }
+ defer svc.Close()
+
+ // Ingest a server (called by vMCP on session registration)
+ err = svc.IngestServer(ctx, "server-id", "MyServer", nil, []mcp.Tool{...})
+ if err != nil {
+ panic(err)
+ }
+}
+```
+
+### Production Deployment with vLLM (Kubernetes)
+
+```yaml
+optimizer:
+ enabled: true
+ embeddingBackend: openai-compatible
+ embeddingURL: http://vllm-service:8000/v1
+ embeddingModel: BAAI/bge-small-en-v1.5
+ embeddingDimension: 768
+ persistPath: /data/optimizer # Persistent storage for faster restarts
+```
+
+Deploy vLLM alongside vMCP:
+
+```yaml
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: vllm-embeddings
+spec:
+ template:
+ spec:
+ containers:
+ - name: vllm
+ image: vllm/vllm-openai:latest
+ args:
+ - --model
+ - BAAI/bge-small-en-v1.5
+ - --port
+ - "8000"
+ resources:
+ limits:
+ nvidia.com/gpu: 1
+```
+
+### Local Development with Ollama
+
+```bash
+# Start Ollama
+ollama serve
+
+# Pull an embedding model
+ollama pull all-minilm
+```
+
+Configure vMCP:
+
+```yaml
+optimizer:
+ enabled: true
+ embeddingBackend: ollama
+ embeddingURL: http://localhost:11434
+ embeddingModel: all-minilm
+ embeddingDimension: 384
+```
+
+## Configuration
+
+### Database
+
+- **Storage**: chromem-go (pure Go, no CGO)
+- **Default**: In-memory (ephemeral)
+- **Persistence**: Optional via `persistPath`
+- **Format**: Binary (gob encoding)
+
+### Embedding Models
+
+Common embedding dimensions:
+- **384**: all-MiniLM-L6-v2, nomic-embed-text (default)
+- **768**: BAAI/bge-small-en-v1.5
+- **1536**: OpenAI text-embedding-3-small
+
+### Performance
+
+From chromem-go benchmarks (mid-range 2020 Intel laptop):
+- **1,000 tools**: ~0.5ms query time
+- **5,000 tools**: ~2.2ms query time
+- **25,000 tools**: ~9.9ms query time
+- **100,000 tools**: ~39.6ms query time
+
+Perfect for typical vMCP deployments (hundreds to thousands of tools).
+
+## Testing
+
+Run the unit tests:
+
+```bash
+# Test all packages
+go test ./cmd/thv-operator/pkg/optimizer/...
+
+# Test with coverage
+go test -cover ./cmd/thv-operator/pkg/optimizer/...
+
+# Test specific package
+go test ./cmd/thv-operator/pkg/optimizer/models
+```
+
+## Inspecting the Database
+
+The optimizer uses a hybrid database (chromem-go + SQLite FTS5). Here's how to inspect each:
+
+### Inspecting SQLite FTS5 (Easiest)
+
+The FTS5 database is standard SQLite and can be opened with any SQLite tool:
+
+```bash
+# Use sqlite3 CLI
+sqlite3 /tmp/vmcp-optimizer-fts.db
+
+# Count documents
+SELECT COUNT(*) FROM backend_servers_fts;
+SELECT COUNT(*) FROM backend_tools_fts;
+
+# View tool names and descriptions
+SELECT tool_name, tool_description FROM backend_tools_fts LIMIT 10;
+
+# Full-text search with BM25 ranking
+SELECT tool_name, rank
+FROM backend_tool_fts_index
+WHERE backend_tool_fts_index MATCH 'github repository'
+ORDER BY rank
+LIMIT 5;
+
+# Join servers and tools
+SELECT s.name, t.tool_name, t.tool_description
+FROM backend_tools_fts t
+JOIN backend_servers_fts s ON t.mcpserver_id = s.id
+LIMIT 10;
+```
+
+**VSCode Extension**: Install `alexcvzz.vscode-sqlite` to view `.db` files directly in VSCode.
+
+### Inspecting chromem-go (Vector Database)
+
+chromem-go uses `.gob` binary files. Use the provided inspection scripts:
+
+```bash
+# Quick summary (shows collection sizes and first few documents)
+go run scripts/inspect-chromem-raw.go /tmp/vmcp-optimizer-debug.db
+
+# View specific tool with full metadata and embeddings
+go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db get_file_contents
+
+# View all documents (warning: lots of output)
+go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db
+
+# Search by content
+go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db "search"
+```
+
+### chromem-go Schema
+
+Each document in chromem-go contains:
+
+```go
+Document {
+ ID: string // "github" or UUID for tools
+ Content: string // "tool_name. description..."
+ Embedding: []float32 // 384-dimensional vector
+ Metadata: map[string]string // {"type": "backend_tool", "server_id": "github", "data": "...JSON..."}
+}
+```
+
+**Collections**:
+- `backend_servers`: Server metadata (3 documents in typical setup)
+- `backend_tools`: Tool metadata and embeddings (40+ documents)
+
+## Known Limitations
+
+1. **Scale**: Optimized for <100,000 tools (more than sufficient for typical vMCP deployments)
+2. **Exhaustive Search**: chromem-go uses exhaustive (brute-force, exact) search rather than an approximate index such as HNSW, but this is fine for our scale
+3. **Persistence Format**: Binary gob format (not human-readable)
+
+## License
+
+This package is part of ToolHive and follows the same license.
diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_server.go b/cmd/thv-operator/pkg/optimizer/db/backend_server.go
new file mode 100644
index 0000000000..296969f07d
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/backend_server.go
@@ -0,0 +1,243 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package db provides chromem-go based database operations for the optimizer.
+package db
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/philippgille/chromem-go"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+ "github.com/stacklok/toolhive/pkg/logger"
+)
+
+// BackendServerOps provides operations for backend servers in chromem-go
+type BackendServerOps struct {
+ db *DB
+ embeddingFunc chromem.EmbeddingFunc
+}
+
+// NewBackendServerOps creates a new BackendServerOps instance
+func NewBackendServerOps(db *DB, embeddingFunc chromem.EmbeddingFunc) *BackendServerOps {
+ return &BackendServerOps{
+ db: db,
+ embeddingFunc: embeddingFunc,
+ }
+}
+
+// Create adds a new backend server to the collection
+func (ops *BackendServerOps) Create(ctx context.Context, server *models.BackendServer) error {
+ collection, err := ops.db.GetOrCreateCollection(ctx, BackendServerCollection, ops.embeddingFunc)
+ if err != nil {
+ return fmt.Errorf("failed to get backend server collection: %w", err)
+ }
+
+ // Prepare content for embedding (name + description)
+ content := server.Name
+ if server.Description != nil && *server.Description != "" {
+ content += ". " + *server.Description
+ }
+
+ // Serialize metadata
+ metadata, err := serializeServerMetadata(server)
+ if err != nil {
+ return fmt.Errorf("failed to serialize server metadata: %w", err)
+ }
+
+ // Create document
+ doc := chromem.Document{
+ ID: server.ID,
+ Content: content,
+ Metadata: metadata,
+ }
+
+ // If embedding is provided, use it
+ if len(server.ServerEmbedding) > 0 {
+ doc.Embedding = server.ServerEmbedding
+ }
+
+ // Add document to chromem-go collection
+ err = collection.AddDocument(ctx, doc)
+ if err != nil {
+ return fmt.Errorf("failed to add server document to chromem-go: %w", err)
+ }
+
+ // Also add to FTS5 database if available (for keyword filtering)
+ // Use background context to avoid cancellation issues - FTS5 is supplementary
+ if ftsDB := ops.db.GetFTSDB(); ftsDB != nil {
+ // Use background context with timeout for FTS operations
+ // This ensures FTS operations complete even if the original context is canceled
+ ftsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ if err := ftsDB.UpsertServer(ftsCtx, server); err != nil {
+ // Log but don't fail - FTS5 is supplementary
+ logger.Warnf("Failed to upsert server to FTS5: %v", err)
+ }
+ }
+
+ logger.Debugf("Created backend server: %s (chromem-go + FTS5)", server.ID)
+ return nil
+}
+
+// Get retrieves a backend server by ID
+func (ops *BackendServerOps) Get(ctx context.Context, serverID string) (*models.BackendServer, error) {
+ collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc)
+ if err != nil {
+ return nil, fmt.Errorf("backend server collection not found: %w", err)
+ }
+
+ // Query by ID with exact match
+ results, err := collection.Query(ctx, serverID, 1, nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query server: %w", err)
+ }
+
+ if len(results) == 0 {
+ return nil, fmt.Errorf("server not found: %s", serverID)
+ }
+
+ // Deserialize from metadata
+ server, err := deserializeServerMetadata(results[0].Metadata)
+ if err != nil {
+ return nil, fmt.Errorf("failed to deserialize server: %w", err)
+ }
+
+ return server, nil
+}
+
+// Update updates an existing backend server
+func (ops *BackendServerOps) Update(ctx context.Context, server *models.BackendServer) error {
+ // chromem-go doesn't have an update operation, so we delete and re-create
+ err := ops.Delete(ctx, server.ID)
+ if err != nil {
+ // If server doesn't exist, that's fine
+ logger.Debugf("Server %s not found for update, will create new", server.ID)
+ }
+
+ return ops.Create(ctx, server)
+}
+
+// Delete removes a backend server
+func (ops *BackendServerOps) Delete(ctx context.Context, serverID string) error {
+ collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc)
+ if err != nil {
+ // Collection doesn't exist, nothing to delete
+ return nil
+ }
+
+ err = collection.Delete(ctx, nil, nil, serverID)
+ if err != nil {
+ return fmt.Errorf("failed to delete server from chromem-go: %w", err)
+ }
+
+ // Also delete from FTS5 database if available
+ if ftsDB := ops.db.GetFTSDB(); ftsDB != nil {
+ if err := ftsDB.DeleteServer(ctx, serverID); err != nil {
+ // Log but don't fail
+ logger.Warnf("Failed to delete server from FTS5: %v", err)
+ }
+ }
+
+ logger.Debugf("Deleted backend server: %s (chromem-go + FTS5)", serverID)
+ return nil
+}
+
+// List returns all backend servers
+func (ops *BackendServerOps) List(ctx context.Context) ([]*models.BackendServer, error) {
+ collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc)
+ if err != nil {
+ // Collection doesn't exist yet, return empty list
+ return []*models.BackendServer{}, nil
+ }
+
+ // Get count to determine nResults
+ count := collection.Count()
+ if count == 0 {
+ return []*models.BackendServer{}, nil
+ }
+
+ // Query with a generic term to get all servers
+ // Using "server" as a generic query that should match all servers
+ results, err := collection.Query(ctx, "server", count, nil, nil)
+ if err != nil {
+ return []*models.BackendServer{}, nil
+ }
+
+ servers := make([]*models.BackendServer, 0, len(results))
+ for _, result := range results {
+ server, err := deserializeServerMetadata(result.Metadata)
+ if err != nil {
+ logger.Warnf("Failed to deserialize server: %v", err)
+ continue
+ }
+ servers = append(servers, server)
+ }
+
+ return servers, nil
+}
+
+// Search performs semantic search for backend servers
+func (ops *BackendServerOps) Search(ctx context.Context, query string, limit int) ([]*models.BackendServer, error) {
+ collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc)
+ if err != nil {
+ return []*models.BackendServer{}, nil
+ }
+
+ // Get collection count and adjust limit if necessary
+ count := collection.Count()
+ if count == 0 {
+ return []*models.BackendServer{}, nil
+ }
+ if limit > count {
+ limit = count
+ }
+
+ results, err := collection.Query(ctx, query, limit, nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to search servers: %w", err)
+ }
+
+ servers := make([]*models.BackendServer, 0, len(results))
+ for _, result := range results {
+ server, err := deserializeServerMetadata(result.Metadata)
+ if err != nil {
+ logger.Warnf("Failed to deserialize server: %v", err)
+ continue
+ }
+ servers = append(servers, server)
+ }
+
+ return servers, nil
+}
+
+// Helper functions for metadata serialization
+
+func serializeServerMetadata(server *models.BackendServer) (map[string]string, error) {
+ data, err := json.Marshal(server)
+ if err != nil {
+ return nil, err
+ }
+ return map[string]string{
+ "data": string(data),
+ "type": "backend_server",
+ }, nil
+}
+
+func deserializeServerMetadata(metadata map[string]string) (*models.BackendServer, error) {
+ data, ok := metadata["data"]
+ if !ok {
+ return nil, fmt.Errorf("missing data field in metadata")
+ }
+
+ var server models.BackendServer
+ if err := json.Unmarshal([]byte(data), &server); err != nil {
+ return nil, err
+ }
+
+ return &server, nil
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_server_test.go b/cmd/thv-operator/pkg/optimizer/db/backend_server_test.go
new file mode 100644
index 0000000000..9cc9a8aa43
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/backend_server_test.go
@@ -0,0 +1,427 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+)
+
+// TestBackendServerOps_Create tests creating a backend server
+func TestBackendServerOps_Create(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ description := "A test MCP server"
+ server := &models.BackendServer{
+ ID: "server-1",
+ Name: "Test Server",
+ Description: &description,
+ Group: "default",
+ }
+
+ err := ops.Create(ctx, server)
+ require.NoError(t, err)
+
+ // Verify server was created by retrieving it
+ retrieved, err := ops.Get(ctx, "server-1")
+ require.NoError(t, err)
+ assert.Equal(t, "Test Server", retrieved.Name)
+ assert.Equal(t, "server-1", retrieved.ID)
+ assert.Equal(t, description, *retrieved.Description)
+}
+
+// TestBackendServerOps_CreateWithEmbedding tests creating server with precomputed embedding
+func TestBackendServerOps_CreateWithEmbedding(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ description := "Server with embedding"
+ embedding := make([]float32, 384)
+ for i := range embedding {
+ embedding[i] = 0.5
+ }
+
+ server := &models.BackendServer{
+ ID: "server-2",
+ Name: "Embedded Server",
+ Description: &description,
+ Group: "default",
+ ServerEmbedding: embedding,
+ }
+
+ err := ops.Create(ctx, server)
+ require.NoError(t, err)
+
+ // Verify server was created
+ retrieved, err := ops.Get(ctx, "server-2")
+ require.NoError(t, err)
+ assert.Equal(t, "Embedded Server", retrieved.Name)
+}
+
+// TestBackendServerOps_Get tests retrieving a backend server
+func TestBackendServerOps_Get(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ // Create a server first
+ description := "GitHub MCP server"
+ server := &models.BackendServer{
+ ID: "github-server",
+ Name: "GitHub",
+ Description: &description,
+ Group: "development",
+ }
+
+ err := ops.Create(ctx, server)
+ require.NoError(t, err)
+
+ // Test Get
+ retrieved, err := ops.Get(ctx, "github-server")
+ require.NoError(t, err)
+ assert.Equal(t, "github-server", retrieved.ID)
+ assert.Equal(t, "GitHub", retrieved.Name)
+ assert.Equal(t, "development", retrieved.Group)
+}
+
+// TestBackendServerOps_Get_NotFound tests retrieving non-existent server
+func TestBackendServerOps_Get_NotFound(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ // Try to get a non-existent server
+ _, err := ops.Get(ctx, "non-existent")
+ assert.Error(t, err)
+ // Error message could be "server not found" or "collection not found" depending on state
+ assert.True(t, err != nil, "Should return an error for non-existent server")
+}
+
+// TestBackendServerOps_Update tests updating a backend server
+func TestBackendServerOps_Update(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ // Create initial server
+ description := "Original description"
+ server := &models.BackendServer{
+ ID: "server-1",
+ Name: "Original Name",
+ Description: &description,
+ Group: "default",
+ }
+
+ err := ops.Create(ctx, server)
+ require.NoError(t, err)
+
+ // Update the server
+ updatedDescription := "Updated description"
+ server.Name = "Updated Name"
+ server.Description = &updatedDescription
+
+ err = ops.Update(ctx, server)
+ require.NoError(t, err)
+
+ // Verify update
+ retrieved, err := ops.Get(ctx, "server-1")
+ require.NoError(t, err)
+ assert.Equal(t, "Updated Name", retrieved.Name)
+ assert.Equal(t, "Updated description", *retrieved.Description)
+}
+
+// TestBackendServerOps_Update_NonExistent tests updating non-existent server
+func TestBackendServerOps_Update_NonExistent(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ // Try to update non-existent server (should create it)
+ description := "New server"
+ server := &models.BackendServer{
+ ID: "new-server",
+ Name: "New Server",
+ Description: &description,
+ Group: "default",
+ }
+
+ err := ops.Update(ctx, server)
+ require.NoError(t, err)
+
+ // Verify server was created
+ retrieved, err := ops.Get(ctx, "new-server")
+ require.NoError(t, err)
+ assert.Equal(t, "New Server", retrieved.Name)
+}
+
+// TestBackendServerOps_Delete tests deleting a backend server
+func TestBackendServerOps_Delete(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ // Create a server
+ description := "Server to delete"
+ server := &models.BackendServer{
+ ID: "delete-me",
+ Name: "Delete Me",
+ Description: &description,
+ Group: "default",
+ }
+
+ err := ops.Create(ctx, server)
+ require.NoError(t, err)
+
+ // Delete the server
+ err = ops.Delete(ctx, "delete-me")
+ require.NoError(t, err)
+
+ // Verify deletion
+ _, err = ops.Get(ctx, "delete-me")
+ assert.Error(t, err, "Should not find deleted server")
+}
+
+// TestBackendServerOps_Delete_NonExistent tests deleting non-existent server
+func TestBackendServerOps_Delete_NonExistent(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ // Try to delete a non-existent server - should not error
+ err := ops.Delete(ctx, "non-existent")
+ assert.NoError(t, err)
+}
+
+// TestBackendServerOps_List tests listing all servers
+func TestBackendServerOps_List(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ // Create multiple servers
+ desc1 := "Server 1"
+ server1 := &models.BackendServer{
+ ID: "server-1",
+ Name: "Server 1",
+ Description: &desc1,
+ Group: "group-a",
+ }
+
+ desc2 := "Server 2"
+ server2 := &models.BackendServer{
+ ID: "server-2",
+ Name: "Server 2",
+ Description: &desc2,
+ Group: "group-b",
+ }
+
+ desc3 := "Server 3"
+ server3 := &models.BackendServer{
+ ID: "server-3",
+ Name: "Server 3",
+ Description: &desc3,
+ Group: "group-a",
+ }
+
+ err := ops.Create(ctx, server1)
+ require.NoError(t, err)
+ err = ops.Create(ctx, server2)
+ require.NoError(t, err)
+ err = ops.Create(ctx, server3)
+ require.NoError(t, err)
+
+ // List all servers
+ servers, err := ops.List(ctx)
+ require.NoError(t, err)
+ assert.Len(t, servers, 3, "Should have 3 servers")
+
+ // Verify server names
+ serverNames := make(map[string]bool)
+ for _, server := range servers {
+ serverNames[server.Name] = true
+ }
+ assert.True(t, serverNames["Server 1"])
+ assert.True(t, serverNames["Server 2"])
+ assert.True(t, serverNames["Server 3"])
+}
+
+// TestBackendServerOps_List_Empty tests listing servers on empty database
+func TestBackendServerOps_List_Empty(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ // List empty database
+ servers, err := ops.List(ctx)
+ require.NoError(t, err)
+ assert.Empty(t, servers, "Should return empty list for empty database")
+}
+
+// TestBackendServerOps_Search tests semantic search for servers
+func TestBackendServerOps_Search(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ // Create test servers
+ desc1 := "GitHub integration server"
+ server1 := &models.BackendServer{
+ ID: "github",
+ Name: "GitHub Server",
+ Description: &desc1,
+ Group: "vcs",
+ }
+
+ desc2 := "Slack messaging server"
+ server2 := &models.BackendServer{
+ ID: "slack",
+ Name: "Slack Server",
+ Description: &desc2,
+ Group: "messaging",
+ }
+
+ err := ops.Create(ctx, server1)
+ require.NoError(t, err)
+ err = ops.Create(ctx, server2)
+ require.NoError(t, err)
+
+ // Search for servers
+ results, err := ops.Search(ctx, "integration", 5)
+ require.NoError(t, err)
+ assert.NotEmpty(t, results, "Should find servers")
+}
+
+// TestBackendServerOps_Search_Empty tests search on empty database
+func TestBackendServerOps_Search_Empty(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ // Search empty database
+ results, err := ops.Search(ctx, "anything", 5)
+ require.NoError(t, err)
+ assert.Empty(t, results, "Should return empty results for empty database")
+}
+
+// TestBackendServerOps_MetadataSerialization tests metadata serialization/deserialization
+func TestBackendServerOps_MetadataSerialization(t *testing.T) {
+ t.Parallel()
+
+ description := "Test server"
+ server := &models.BackendServer{
+ ID: "server-1",
+ Name: "Test Server",
+ Description: &description,
+ Group: "default",
+ }
+
+ // Test serialization
+ metadata, err := serializeServerMetadata(server)
+ require.NoError(t, err)
+ assert.Contains(t, metadata, "data")
+ assert.Equal(t, "backend_server", metadata["type"])
+
+ // Test deserialization
+ deserializedServer, err := deserializeServerMetadata(metadata)
+ require.NoError(t, err)
+ assert.Equal(t, server.ID, deserializedServer.ID)
+ assert.Equal(t, server.Name, deserializedServer.Name)
+ assert.Equal(t, server.Group, deserializedServer.Group)
+}
+
+// TestBackendServerOps_MetadataDeserialization_MissingData tests error handling
+func TestBackendServerOps_MetadataDeserialization_MissingData(t *testing.T) {
+ t.Parallel()
+
+ // Test with missing data field
+ metadata := map[string]string{
+ "type": "backend_server",
+ }
+
+ _, err := deserializeServerMetadata(metadata)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "missing data field")
+}
+
+// TestBackendServerOps_MetadataDeserialization_InvalidJSON tests invalid JSON handling
+func TestBackendServerOps_MetadataDeserialization_InvalidJSON(t *testing.T) {
+ t.Parallel()
+
+ // Test with invalid JSON
+ metadata := map[string]string{
+ "data": "invalid json {",
+ "type": "backend_server",
+ }
+
+ _, err := deserializeServerMetadata(metadata)
+ assert.Error(t, err)
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_server_coverage_test.go b/cmd/thv-operator/pkg/optimizer/db/backend_server_coverage_test.go
new file mode 100644
index 0000000000..055b6a3353
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/backend_server_coverage_test.go
@@ -0,0 +1,97 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+ "context"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+)
+
+// TestBackendServerOps_Create_FTS tests FTS integration in Create
+func TestBackendServerOps_Create_FTS(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ config := &Config{
+ PersistPath: filepath.Join(tmpDir, "test-db"),
+ FTSDBPath: filepath.Join(tmpDir, "fts.db"),
+ }
+
+ db, err := NewDB(config)
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
+ return []float32{0.1, 0.2, 0.3}, nil
+ }
+
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ server := &models.BackendServer{
+ ID: "server-1",
+ Name: "Test Server",
+ Description: stringPtr("A test server"),
+ Group: "default",
+ CreatedAt: time.Now(),
+ LastUpdated: time.Now(),
+ }
+
+ // Create should also update FTS
+ err = ops.Create(ctx, server)
+ require.NoError(t, err)
+
+ // Verify FTS was updated by checking FTS DB directly
+ ftsDB := db.GetFTSDB()
+ require.NotNil(t, ftsDB)
+
+ // FTS should have the server
+ // We can't easily query FTS directly, but we can verify it doesn't error
+}
+
+// TestBackendServerOps_Delete_FTS tests FTS integration in Delete
+func TestBackendServerOps_Delete_FTS(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ config := &Config{
+ PersistPath: filepath.Join(tmpDir, "test-db"),
+ FTSDBPath: filepath.Join(tmpDir, "fts.db"),
+ }
+
+ db, err := NewDB(config)
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
+ return []float32{0.1, 0.2, 0.3}, nil
+ }
+
+ ops := NewBackendServerOps(db, embeddingFunc)
+
+ desc := "A test server"
+ server := &models.BackendServer{
+ ID: "server-1",
+ Name: "Test Server",
+ Description: &desc,
+ Group: "default",
+ CreatedAt: time.Now(),
+ LastUpdated: time.Now(),
+ }
+
+ // Create server
+ err = ops.Create(ctx, server)
+ require.NoError(t, err)
+
+ // Delete should also delete from FTS
+ err = ops.Delete(ctx, server.ID)
+ require.NoError(t, err)
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_tool.go b/cmd/thv-operator/pkg/optimizer/db/backend_tool.go
new file mode 100644
index 0000000000..3dfa860f1a
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/backend_tool.go
@@ -0,0 +1,319 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/philippgille/chromem-go"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+ "github.com/stacklok/toolhive/pkg/logger"
+)
+
+// BackendToolOps provides operations for backend tools in chromem-go
+type BackendToolOps struct {
+ db *DB
+ embeddingFunc chromem.EmbeddingFunc
+}
+
+// NewBackendToolOps creates a new BackendToolOps instance
+func NewBackendToolOps(db *DB, embeddingFunc chromem.EmbeddingFunc) *BackendToolOps {
+ return &BackendToolOps{
+ db: db,
+ embeddingFunc: embeddingFunc,
+ }
+}
+
+// Create adds a new backend tool to the collection
+func (ops *BackendToolOps) Create(ctx context.Context, tool *models.BackendTool, serverName string) error {
+ collection, err := ops.db.GetOrCreateCollection(ctx, BackendToolCollection, ops.embeddingFunc)
+ if err != nil {
+ return fmt.Errorf("failed to get backend tool collection: %w", err)
+ }
+
+ // Prepare content for embedding (name + description + input schema summary)
+ content := tool.ToolName
+ if tool.Description != nil && *tool.Description != "" {
+ content += ". " + *tool.Description
+ }
+
+ // Serialize metadata
+ metadata, err := serializeToolMetadata(tool)
+ if err != nil {
+ return fmt.Errorf("failed to serialize tool metadata: %w", err)
+ }
+
+ // Create document
+ doc := chromem.Document{
+ ID: tool.ID,
+ Content: content,
+ Metadata: metadata,
+ }
+
+ // If embedding is provided, use it
+ if len(tool.ToolEmbedding) > 0 {
+ doc.Embedding = tool.ToolEmbedding
+ }
+
+ // Add document to chromem-go collection
+ err = collection.AddDocument(ctx, doc)
+ if err != nil {
+ return fmt.Errorf("failed to add tool document to chromem-go: %w", err)
+ }
+
+ // Also add to FTS5 database if available (for BM25 search)
+ // Use background context to avoid cancellation issues - FTS5 is supplementary
+ if ops.db.fts != nil {
+ // Use background context with timeout for FTS operations
+ // This ensures FTS operations complete even if the original context is canceled
+ ftsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ if err := ops.db.fts.UpsertToolMeta(ftsCtx, tool, serverName); err != nil {
+ // Log but don't fail - FTS5 is supplementary
+ logger.Warnf("Failed to upsert tool to FTS5: %v", err)
+ }
+ }
+
+ logger.Debugf("Created backend tool: %s (chromem-go + FTS5)", tool.ID)
+ return nil
+}
+
+// Get retrieves a backend tool by ID
+func (ops *BackendToolOps) Get(ctx context.Context, toolID string) (*models.BackendTool, error) {
+ collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc)
+ if err != nil {
+ return nil, fmt.Errorf("backend tool collection not found: %w", err)
+ }
+
+	// NOTE(review): chromem-go Query is nearest-neighbor semantic search, not an exact-ID lookup; on a non-empty collection a bogus ID returns the closest document — results[0].ID should be verified to equal toolID before deserializing.
+ results, err := collection.Query(ctx, toolID, 1, nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query tool: %w", err)
+ }
+
+ if len(results) == 0 {
+ return nil, fmt.Errorf("tool not found: %s", toolID)
+ }
+
+ // Deserialize from metadata
+ tool, err := deserializeToolMetadata(results[0].Metadata)
+ if err != nil {
+ return nil, fmt.Errorf("failed to deserialize tool: %w", err)
+ }
+
+ return tool, nil
+}
+
+// Update updates an existing backend tool in chromem-go
+// Note: This only updates chromem-go, not FTS5. Use Create to update both.
+func (ops *BackendToolOps) Update(ctx context.Context, tool *models.BackendTool) error {
+ collection, err := ops.db.GetOrCreateCollection(ctx, BackendToolCollection, ops.embeddingFunc)
+ if err != nil {
+ return fmt.Errorf("failed to get backend tool collection: %w", err)
+ }
+
+ // Prepare content for embedding
+ content := tool.ToolName
+ if tool.Description != nil && *tool.Description != "" {
+ content += ". " + *tool.Description
+ }
+
+ // Serialize metadata
+ metadata, err := serializeToolMetadata(tool)
+ if err != nil {
+ return fmt.Errorf("failed to serialize tool metadata: %w", err)
+ }
+
+ // Delete existing document
+ _ = collection.Delete(ctx, nil, nil, tool.ID) // Ignore error if doesn't exist
+
+ // Create updated document
+ doc := chromem.Document{
+ ID: tool.ID,
+ Content: content,
+ Metadata: metadata,
+ }
+
+ if len(tool.ToolEmbedding) > 0 {
+ doc.Embedding = tool.ToolEmbedding
+ }
+
+ err = collection.AddDocument(ctx, doc)
+ if err != nil {
+ return fmt.Errorf("failed to update tool document: %w", err)
+ }
+
+ logger.Debugf("Updated backend tool: %s", tool.ID)
+ return nil
+}
+
+// Delete removes a backend tool
+func (ops *BackendToolOps) Delete(ctx context.Context, toolID string) error {
+ collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc)
+ if err != nil {
+ // Collection doesn't exist, nothing to delete
+ return nil
+ }
+
+ err = collection.Delete(ctx, nil, nil, toolID)
+ if err != nil {
+ return fmt.Errorf("failed to delete tool: %w", err)
+ }
+
+ logger.Debugf("Deleted backend tool: %s", toolID)
+ return nil
+}
+
+// DeleteByServer removes all tools for a given server from both chromem-go and FTS5
+func (ops *BackendToolOps) DeleteByServer(ctx context.Context, serverID string) error {
+ collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc)
+ if err != nil {
+ // Collection doesn't exist, nothing to delete in chromem-go
+ logger.Debug("Backend tool collection not found, skipping chromem-go deletion")
+ } else {
+ // Query all tools for this server
+ tools, err := ops.ListByServer(ctx, serverID)
+ if err != nil {
+ return fmt.Errorf("failed to list tools for server: %w", err)
+ }
+
+ // Delete each tool from chromem-go
+ for _, tool := range tools {
+ if err := collection.Delete(ctx, nil, nil, tool.ID); err != nil {
+ logger.Warnf("Failed to delete tool %s from chromem-go: %v", tool.ID, err)
+ }
+ }
+
+ logger.Debugf("Deleted %d tools from chromem-go for server: %s", len(tools), serverID)
+ }
+
+ // Also delete from FTS5 database if available
+ if ops.db.fts != nil {
+ if err := ops.db.fts.DeleteToolsByServer(ctx, serverID); err != nil {
+ logger.Warnf("Failed to delete tools from FTS5 for server %s: %v", serverID, err)
+ } else {
+ logger.Debugf("Deleted tools from FTS5 for server: %s", serverID)
+ }
+ }
+
+ return nil
+}
+
+// ListByServer returns all tools for a given server
+func (ops *BackendToolOps) ListByServer(ctx context.Context, serverID string) ([]*models.BackendTool, error) {
+ collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc)
+ if err != nil {
+ // Collection doesn't exist yet, return empty list
+ return []*models.BackendTool{}, nil
+ }
+
+ // Get count to determine nResults
+ count := collection.Count()
+ if count == 0 {
+ return []*models.BackendTool{}, nil
+ }
+
+ // Query with a generic term and metadata filter
+ // Using "tool" as a generic query that should match all tools
+ results, err := collection.Query(ctx, "tool", count, map[string]string{"server_id": serverID}, nil)
+ if err != nil {
+		// NOTE(review): this swallows ALL Query errors (not just "no match"), hiding real failures behind an empty result — confirm intended.
+ return []*models.BackendTool{}, nil
+ }
+
+ tools := make([]*models.BackendTool, 0, len(results))
+ for _, result := range results {
+ tool, err := deserializeToolMetadata(result.Metadata)
+ if err != nil {
+ logger.Warnf("Failed to deserialize tool: %v", err)
+ continue
+ }
+ tools = append(tools, tool)
+ }
+
+ return tools, nil
+}
+
+// Search performs semantic search for backend tools
+func (ops *BackendToolOps) Search(
+ ctx context.Context,
+ query string,
+ limit int,
+ serverID *string,
+) ([]*models.BackendToolWithMetadata, error) {
+ collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc)
+ if err != nil {
+ return []*models.BackendToolWithMetadata{}, nil
+ }
+
+ // Get collection count and adjust limit if necessary
+ count := collection.Count()
+ if count == 0 {
+ return []*models.BackendToolWithMetadata{}, nil
+ }
+ if limit > count {
+ limit = count
+ }
+
+ // Build metadata filter if server ID is provided
+ var metadataFilter map[string]string
+ if serverID != nil {
+ metadataFilter = map[string]string{"server_id": *serverID}
+ }
+
+ results, err := collection.Query(ctx, query, limit, metadataFilter, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to search tools: %w", err)
+ }
+
+ tools := make([]*models.BackendToolWithMetadata, 0, len(results))
+ for _, result := range results {
+ tool, err := deserializeToolMetadata(result.Metadata)
+ if err != nil {
+ logger.Warnf("Failed to deserialize tool: %v", err)
+ continue
+ }
+
+ // Add similarity score
+ toolWithMeta := &models.BackendToolWithMetadata{
+ BackendTool: *tool,
+ Similarity: result.Similarity,
+ }
+ tools = append(tools, toolWithMeta)
+ }
+
+ return tools, nil
+}
+
+// Helper functions for metadata serialization
+
+func serializeToolMetadata(tool *models.BackendTool) (map[string]string, error) {
+ data, err := json.Marshal(tool)
+ if err != nil {
+ return nil, err
+ }
+ return map[string]string{
+ "data": string(data),
+ "type": "backend_tool",
+ "server_id": tool.MCPServerID,
+ }, nil
+}
+
+func deserializeToolMetadata(metadata map[string]string) (*models.BackendTool, error) {
+ data, ok := metadata["data"]
+ if !ok {
+ return nil, fmt.Errorf("missing data field in metadata")
+ }
+
+ var tool models.BackendTool
+ if err := json.Unmarshal([]byte(data), &tool); err != nil {
+ return nil, err
+ }
+
+ return &tool, nil
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_tool_test.go b/cmd/thv-operator/pkg/optimizer/db/backend_tool_test.go
new file mode 100644
index 0000000000..4f9a58b01e
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/backend_tool_test.go
@@ -0,0 +1,590 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+ "context"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+)
+
+// createTestDB creates a test database
+func createTestDB(t *testing.T) *DB {
+ t.Helper()
+ tmpDir := t.TempDir()
+
+ config := &Config{
+ PersistPath: filepath.Join(tmpDir, "test-db"),
+ }
+
+ db, err := NewDB(config)
+ require.NoError(t, err)
+
+ return db
+}
+
+// createTestEmbeddingFunc creates a test embedding function using Ollama embeddings
+func createTestEmbeddingFunc(t *testing.T) func(ctx context.Context, text string) ([]float32, error) {
+ t.Helper()
+
+ // Try to use Ollama if available, otherwise skip test
+ config := &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ }
+
+ manager, err := embeddings.NewManager(config)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err)
+ return nil
+ }
+ t.Cleanup(func() { _ = manager.Close() })
+
+ return func(_ context.Context, text string) ([]float32, error) {
+ results, err := manager.GenerateEmbedding([]string{text})
+ if err != nil {
+ return nil, err
+ }
+ if len(results) == 0 {
+ return nil, assert.AnError
+ }
+ return results[0], nil
+ }
+}
+
+// TestBackendToolOps_Create tests creating a backend tool
+func TestBackendToolOps_Create(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ description := "Get current weather information"
+ tool := &models.BackendTool{
+ ID: "tool-1",
+ MCPServerID: "server-1",
+ ToolName: "get_weather",
+ Description: &description,
+ InputSchema: []byte(`{"type":"object","properties":{"location":{"type":"string"}}}`),
+ TokenCount: 100,
+ }
+
+ err := ops.Create(ctx, tool, "Test Server")
+ require.NoError(t, err)
+
+ // Verify tool was created by retrieving it
+ retrieved, err := ops.Get(ctx, "tool-1")
+ require.NoError(t, err)
+ assert.Equal(t, "get_weather", retrieved.ToolName)
+ assert.Equal(t, "server-1", retrieved.MCPServerID)
+ assert.Equal(t, description, *retrieved.Description)
+}
+
+// TestBackendToolOps_CreateWithPrecomputedEmbedding tests creating tool with existing embedding
+func TestBackendToolOps_CreateWithPrecomputedEmbedding(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ description := "Search the web"
+ // Generate a precomputed embedding
+ precomputedEmbedding := make([]float32, 384)
+ for i := range precomputedEmbedding {
+ precomputedEmbedding[i] = 0.1
+ }
+
+ tool := &models.BackendTool{
+ ID: "tool-2",
+ MCPServerID: "server-1",
+ ToolName: "search_web",
+ Description: &description,
+ InputSchema: []byte(`{}`),
+ ToolEmbedding: precomputedEmbedding,
+ TokenCount: 50,
+ }
+
+ err := ops.Create(ctx, tool, "Test Server")
+ require.NoError(t, err)
+
+ // Verify tool was created
+ retrieved, err := ops.Get(ctx, "tool-2")
+ require.NoError(t, err)
+ assert.Equal(t, "search_web", retrieved.ToolName)
+}
+
+// TestBackendToolOps_Get tests retrieving a backend tool
+func TestBackendToolOps_Get(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // Create a tool first
+ description := "Send an email"
+ tool := &models.BackendTool{
+ ID: "tool-3",
+ MCPServerID: "server-1",
+ ToolName: "send_email",
+ Description: &description,
+ InputSchema: []byte(`{}`),
+ TokenCount: 75,
+ }
+
+ err := ops.Create(ctx, tool, "Test Server")
+ require.NoError(t, err)
+
+ // Test Get
+ retrieved, err := ops.Get(ctx, "tool-3")
+ require.NoError(t, err)
+ assert.Equal(t, "tool-3", retrieved.ID)
+ assert.Equal(t, "send_email", retrieved.ToolName)
+}
+
+// TestBackendToolOps_Get_NotFound tests retrieving non-existent tool
+func TestBackendToolOps_Get_NotFound(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // Try to get a non-existent tool
+ _, err := ops.Get(ctx, "non-existent")
+ assert.Error(t, err)
+}
+
+// TestBackendToolOps_Update tests updating a backend tool
+func TestBackendToolOps_Update(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // Create initial tool
+ description := "Original description"
+ tool := &models.BackendTool{
+ ID: "tool-4",
+ MCPServerID: "server-1",
+ ToolName: "test_tool",
+ Description: &description,
+ InputSchema: []byte(`{}`),
+ TokenCount: 50,
+ }
+
+ err := ops.Create(ctx, tool, "Test Server")
+ require.NoError(t, err)
+
+ // Update the tool
+ const updatedDescription = "Updated description"
+ updatedDescriptionCopy := updatedDescription
+ tool.Description = &updatedDescriptionCopy
+ tool.TokenCount = 75
+
+ err = ops.Update(ctx, tool)
+ require.NoError(t, err)
+
+ // Verify update
+ retrieved, err := ops.Get(ctx, "tool-4")
+ require.NoError(t, err)
+ assert.Equal(t, "Updated description", *retrieved.Description)
+}
+
+// TestBackendToolOps_Delete tests deleting a backend tool
+func TestBackendToolOps_Delete(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // Create a tool
+ description := "Tool to delete"
+ tool := &models.BackendTool{
+ ID: "tool-5",
+ MCPServerID: "server-1",
+ ToolName: "delete_me",
+ Description: &description,
+ InputSchema: []byte(`{}`),
+ TokenCount: 25,
+ }
+
+ err := ops.Create(ctx, tool, "Test Server")
+ require.NoError(t, err)
+
+ // Delete the tool
+ err = ops.Delete(ctx, "tool-5")
+ require.NoError(t, err)
+
+ // Verify deletion
+ _, err = ops.Get(ctx, "tool-5")
+ assert.Error(t, err, "Should not find deleted tool")
+}
+
+// TestBackendToolOps_Delete_NonExistent tests deleting non-existent tool
+func TestBackendToolOps_Delete_NonExistent(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // Try to delete a non-existent tool - should not error
+ err := ops.Delete(ctx, "non-existent")
+ // Delete may or may not error depending on implementation
+ // Just ensure it doesn't panic
+ _ = err
+}
+
+// TestBackendToolOps_ListByServer tests listing tools for a server
+func TestBackendToolOps_ListByServer(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // Create multiple tools for different servers
+ desc1 := "Tool 1"
+ tool1 := &models.BackendTool{
+ ID: "tool-1",
+ MCPServerID: "server-1",
+ ToolName: "tool_1",
+ Description: &desc1,
+ InputSchema: []byte(`{}`),
+ TokenCount: 10,
+ }
+
+ desc2 := "Tool 2"
+ tool2 := &models.BackendTool{
+ ID: "tool-2",
+ MCPServerID: "server-1",
+ ToolName: "tool_2",
+ Description: &desc2,
+ InputSchema: []byte(`{}`),
+ TokenCount: 20,
+ }
+
+ desc3 := "Tool 3"
+ tool3 := &models.BackendTool{
+ ID: "tool-3",
+ MCPServerID: "server-2",
+ ToolName: "tool_3",
+ Description: &desc3,
+ InputSchema: []byte(`{}`),
+ TokenCount: 30,
+ }
+
+ err := ops.Create(ctx, tool1, "Server 1")
+ require.NoError(t, err)
+ err = ops.Create(ctx, tool2, "Server 1")
+ require.NoError(t, err)
+ err = ops.Create(ctx, tool3, "Server 2")
+ require.NoError(t, err)
+
+ // List tools for server-1
+ tools, err := ops.ListByServer(ctx, "server-1")
+ require.NoError(t, err)
+ assert.Len(t, tools, 2, "Should have 2 tools for server-1")
+
+ // Verify tool names
+ toolNames := make(map[string]bool)
+ for _, tool := range tools {
+ toolNames[tool.ToolName] = true
+ }
+ assert.True(t, toolNames["tool_1"])
+ assert.True(t, toolNames["tool_2"])
+
+ // List tools for server-2
+ tools, err = ops.ListByServer(ctx, "server-2")
+ require.NoError(t, err)
+ assert.Len(t, tools, 1, "Should have 1 tool for server-2")
+ assert.Equal(t, "tool_3", tools[0].ToolName)
+}
+
+// TestBackendToolOps_ListByServer_Empty tests listing tools for server with no tools
+func TestBackendToolOps_ListByServer_Empty(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // List tools for non-existent server
+ tools, err := ops.ListByServer(ctx, "non-existent-server")
+ require.NoError(t, err)
+ assert.Empty(t, tools, "Should return empty list for server with no tools")
+}
+
+// TestBackendToolOps_DeleteByServer tests deleting all tools for a server
+func TestBackendToolOps_DeleteByServer(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // Create tools for two servers
+ desc1 := "Tool 1"
+ tool1 := &models.BackendTool{
+ ID: "tool-1",
+ MCPServerID: "server-1",
+ ToolName: "tool_1",
+ Description: &desc1,
+ InputSchema: []byte(`{}`),
+ TokenCount: 10,
+ }
+
+ desc2 := "Tool 2"
+ tool2 := &models.BackendTool{
+ ID: "tool-2",
+ MCPServerID: "server-1",
+ ToolName: "tool_2",
+ Description: &desc2,
+ InputSchema: []byte(`{}`),
+ TokenCount: 20,
+ }
+
+ desc3 := "Tool 3"
+ tool3 := &models.BackendTool{
+ ID: "tool-3",
+ MCPServerID: "server-2",
+ ToolName: "tool_3",
+ Description: &desc3,
+ InputSchema: []byte(`{}`),
+ TokenCount: 30,
+ }
+
+ err := ops.Create(ctx, tool1, "Server 1")
+ require.NoError(t, err)
+ err = ops.Create(ctx, tool2, "Server 1")
+ require.NoError(t, err)
+ err = ops.Create(ctx, tool3, "Server 2")
+ require.NoError(t, err)
+
+ // Delete all tools for server-1
+ err = ops.DeleteByServer(ctx, "server-1")
+ require.NoError(t, err)
+
+ // Verify server-1 tools are deleted
+ tools, err := ops.ListByServer(ctx, "server-1")
+ require.NoError(t, err)
+ assert.Empty(t, tools, "All server-1 tools should be deleted")
+
+ // Verify server-2 tools are still present
+ tools, err = ops.ListByServer(ctx, "server-2")
+ require.NoError(t, err)
+ assert.Len(t, tools, 1, "Server-2 tools should remain")
+}
+
+// TestBackendToolOps_Search tests semantic search for tools
+func TestBackendToolOps_Search(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // Create test tools
+ desc1 := "Get current weather conditions"
+ tool1 := &models.BackendTool{
+ ID: "tool-1",
+ MCPServerID: "server-1",
+ ToolName: "get_weather",
+ Description: &desc1,
+ InputSchema: []byte(`{}`),
+ TokenCount: 50,
+ }
+
+ desc2 := "Send email message"
+ tool2 := &models.BackendTool{
+ ID: "tool-2",
+ MCPServerID: "server-1",
+ ToolName: "send_email",
+ Description: &desc2,
+ InputSchema: []byte(`{}`),
+ TokenCount: 40,
+ }
+
+ err := ops.Create(ctx, tool1, "Server 1")
+ require.NoError(t, err)
+ err = ops.Create(ctx, tool2, "Server 1")
+ require.NoError(t, err)
+
+ // Search for tools
+ results, err := ops.Search(ctx, "weather information", 5, nil)
+ require.NoError(t, err)
+ assert.NotEmpty(t, results, "Should find tools")
+
+ // Weather tool should be most similar to weather query
+ assert.NotEmpty(t, results, "Should find at least one tool")
+ if len(results) > 0 {
+ assert.Equal(t, "get_weather", results[0].ToolName,
+ "Weather tool should be most similar to weather query")
+ }
+}
+
+// TestBackendToolOps_Search_WithServerFilter tests search with server ID filter
+func TestBackendToolOps_Search_WithServerFilter(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // Create tools for different servers
+ desc1 := "Weather tool"
+ tool1 := &models.BackendTool{
+ ID: "tool-1",
+ MCPServerID: "server-1",
+ ToolName: "get_weather",
+ Description: &desc1,
+ InputSchema: []byte(`{}`),
+ TokenCount: 50,
+ }
+
+ desc2 := "Email tool"
+ tool2 := &models.BackendTool{
+ ID: "tool-2",
+ MCPServerID: "server-2",
+ ToolName: "send_email",
+ Description: &desc2,
+ InputSchema: []byte(`{}`),
+ TokenCount: 40,
+ }
+
+ err := ops.Create(ctx, tool1, "Server 1")
+ require.NoError(t, err)
+ err = ops.Create(ctx, tool2, "Server 2")
+ require.NoError(t, err)
+
+ // Search with server filter
+ serverID := "server-1"
+ results, err := ops.Search(ctx, "tool", 5, &serverID)
+ require.NoError(t, err)
+ assert.Len(t, results, 1, "Should only return tools from server-1")
+ assert.Equal(t, "server-1", results[0].MCPServerID)
+}
+
+// TestBackendToolOps_Search_Empty tests search on empty database
+func TestBackendToolOps_Search_Empty(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ db := createTestDB(t)
+ defer func() { _ = db.Close() }()
+
+ embeddingFunc := createTestEmbeddingFunc(t)
+ ops := NewBackendToolOps(db, embeddingFunc)
+
+ // Search empty database
+ results, err := ops.Search(ctx, "anything", 5, nil)
+ require.NoError(t, err)
+ assert.Empty(t, results, "Should return empty results for empty database")
+}
+
+// TestBackendToolOps_MetadataSerialization tests metadata serialization/deserialization
+func TestBackendToolOps_MetadataSerialization(t *testing.T) {
+ t.Parallel()
+
+ description := "Test tool"
+ tool := &models.BackendTool{
+ ID: "tool-1",
+ MCPServerID: "server-1",
+ ToolName: "test_tool",
+ Description: &description,
+ InputSchema: []byte(`{"type":"object"}`),
+ TokenCount: 100,
+ }
+
+ // Test serialization
+ metadata, err := serializeToolMetadata(tool)
+ require.NoError(t, err)
+ assert.Contains(t, metadata, "data")
+ assert.Equal(t, "backend_tool", metadata["type"])
+ assert.Equal(t, "server-1", metadata["server_id"])
+
+ // Test deserialization
+ deserializedTool, err := deserializeToolMetadata(metadata)
+ require.NoError(t, err)
+ assert.Equal(t, tool.ID, deserializedTool.ID)
+ assert.Equal(t, tool.ToolName, deserializedTool.ToolName)
+ assert.Equal(t, tool.MCPServerID, deserializedTool.MCPServerID)
+}
+
+// TestBackendToolOps_MetadataDeserialization_MissingData tests error handling
+func TestBackendToolOps_MetadataDeserialization_MissingData(t *testing.T) {
+ t.Parallel()
+
+ // Test with missing data field
+ metadata := map[string]string{
+ "type": "backend_tool",
+ }
+
+ _, err := deserializeToolMetadata(metadata)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "missing data field")
+}
+
+// TestBackendToolOps_MetadataDeserialization_InvalidJSON tests invalid JSON handling
+func TestBackendToolOps_MetadataDeserialization_InvalidJSON(t *testing.T) {
+ t.Parallel()
+
+ // Test with invalid JSON
+ metadata := map[string]string{
+ "data": "invalid json {",
+ "type": "backend_tool",
+ }
+
+ _, err := deserializeToolMetadata(metadata)
+ assert.Error(t, err)
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_tool_coverage_test.go b/cmd/thv-operator/pkg/optimizer/db/backend_tool_coverage_test.go
new file mode 100644
index 0000000000..1e3c7b7e84
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/backend_tool_coverage_test.go
@@ -0,0 +1,99 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+ "context"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+)
+
+// TestBackendToolOps_Create_FTS tests that Create also indexes the tool in the
+// FTS5 side of the hybrid store.
+//
+// NOTE(review): this file is named *_test_coverage.go rather than *_test.go,
+// so `go test` never compiles or runs it and `testing` is pulled into the
+// normal build — consider renaming the file.
+func TestBackendToolOps_Create_FTS(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	tmpDir := t.TempDir()
+
+	config := &Config{
+		PersistPath: filepath.Join(tmpDir, "test-db"),
+		FTSDBPath:   filepath.Join(tmpDir, "fts.db"),
+	}
+
+	db, err := NewDB(config)
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
+		return []float32{0.1, 0.2, 0.3}, nil
+	}
+
+	ops := NewBackendToolOps(db, embeddingFunc)
+
+	desc := "A test tool"
+	tool := &models.BackendTool{
+		ID:          "tool-1",
+		MCPServerID: "server-1",
+		ToolName:    "test_tool",
+		Description: &desc,
+		InputSchema: []byte(`{"type": "object"}`),
+		TokenCount:  10,
+		CreatedAt:   time.Now(),
+		LastUpdated: time.Now(),
+	}
+
+	// Create should also update FTS
+	err = ops.Create(ctx, tool, "TestServer")
+	require.NoError(t, err)
+
+	// Verify FTS was actually updated: the previous assertion only checked the
+	// handle was non-nil, which passes even when nothing was indexed. The FTS
+	// token sum equaling the tool's TokenCount proves the row landed in FTS.
+	ftsDB := db.GetFTSDB()
+	require.NotNil(t, ftsDB)
+	totalTokens, err := ftsDB.GetTotalToolTokens(ctx)
+	require.NoError(t, err)
+	require.Equal(t, 10, totalTokens)
+}
+
+// TestBackendToolOps_DeleteByServer_FTS verifies that DeleteByServer also
+// clears the FTS index entries for that server's tools without error.
+func TestBackendToolOps_DeleteByServer_FTS(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	tmpDir := t.TempDir()
+
+	db, err := NewDB(&Config{
+		PersistPath: filepath.Join(tmpDir, "test-db"),
+		FTSDBPath:   filepath.Join(tmpDir, "fts.db"),
+	})
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	embed := func(_ context.Context, _ string) ([]float32, error) {
+		return []float32{0.1, 0.2, 0.3}, nil
+	}
+	ops := NewBackendToolOps(db, embed)
+
+	now := time.Now()
+	desc := "A test tool"
+	tool := &models.BackendTool{
+		ID:          "tool-1",
+		MCPServerID: "server-1",
+		ToolName:    "test_tool",
+		Description: &desc,
+		InputSchema: []byte(`{"type": "object"}`),
+		TokenCount:  10,
+		CreatedAt:   now,
+		LastUpdated: now,
+	}
+
+	// Seed one tool, then remove everything for its server.
+	require.NoError(t, ops.Create(ctx, tool, "TestServer"))
+
+	// DeleteByServer should also delete from FTS.
+	require.NoError(t, ops.DeleteByServer(ctx, "server-1"))
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/db.go b/cmd/thv-operator/pkg/optimizer/db/db.go
new file mode 100644
index 0000000000..1e850309ed
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/db.go
@@ -0,0 +1,215 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+
+	"github.com/philippgille/chromem-go"
+
+	"github.com/stacklok/toolhive/pkg/logger"
+)
+
+// Config holds database configuration.
+//
+// The optimizer database is designed to be ephemeral - it's rebuilt from scratch
+// on each startup by ingesting MCP backends. Persistence is optional and primarily
+// useful for development/debugging to avoid re-generating embeddings.
+type Config struct {
+	// PersistPath is the optional path for chromem-go persistence.
+	// If empty, chromem-go will be in-memory only (recommended for production).
+	PersistPath string
+
+	// FTSDBPath is the path for the SQLite FTS5 database used for BM25 search.
+	// If empty, it defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db"
+	// when PersistPath is set. FTS5 is always enabled for hybrid search.
+	FTSDBPath string
+}
+
+// DB represents the hybrid database (chromem-go + SQLite FTS5) for optimizer data.
+type DB struct {
+	config *Config
+	chromem *chromem.DB // Vector/semantic search
+	fts *FTSDatabase // BM25 full-text search; always initialized by NewDB (hybrid search is always on)
+	mu sync.RWMutex // guards collection create/get/delete on the chromem side
+}
+
+// Collection names used for the two chromem-go collections.
+//
+// Terminology: We use "backend_servers" and "backend_tools" to be explicit about
+// tracking MCP server metadata. While vMCP uses "Backend" for the workload concept,
+// the optimizer focuses on the MCP server component for semantic search and tool discovery.
+// This naming convention provides clarity across the database layer.
+const (
+	BackendServerCollection = "backend_servers"
+	BackendToolCollection = "backend_tools"
+)
+
+// NewDB creates a new chromem-go database with FTS5 for hybrid search.
+//
+// When config.PersistPath is set, chromem-go persists to disk; a store whose
+// collection metadata is missing is treated as corrupted and recreated on a
+// best-effort basis. With an empty PersistPath everything is in-memory. The
+// FTS5 side is always initialized; its path defaults to "<PersistPath>/fts.db"
+// in persistent mode and ":memory:" otherwise.
+func NewDB(config *Config) (*DB, error) {
+	var chromemDB *chromem.DB
+	var err error
+
+	if config.PersistPath != "" {
+		logger.Infof("Creating chromem-go database with persistence at: %s", config.PersistPath)
+		chromemDB, err = chromem.NewPersistentDB(config.PersistPath, false)
+		if err != nil {
+			// Only the missing-collection-metadata error is treated as
+			// recoverable corruption; anything else is returned as-is.
+			if strings.Contains(err.Error(), "collection metadata file not found") {
+				logger.Warnf("Database appears corrupted, attempting to remove and recreate: %v", err)
+				// Remove the corrupted directory; if that fails (e.g. still in
+				// use), fall back to renaming it aside so a fresh one can be
+				// created in its place.
+				if removeErr := os.RemoveAll(config.PersistPath); removeErr != nil {
+					logger.Warnf("Failed to remove corrupted database directory (may be in use): %v. Will try to recreate anyway.", removeErr)
+					backupPath := config.PersistPath + ".corrupted"
+					if renameErr := os.Rename(config.PersistPath, backupPath); renameErr != nil {
+						logger.Warnf("Failed to rename corrupted database: %v. Attempting to create database anyway.", renameErr)
+						// Continue and let chromem-go handle it - it might work if the corruption is partial
+					} else {
+						logger.Infof("Renamed corrupted database to: %s", backupPath)
+					}
+				}
+				// Retry once after cleanup.
+				chromemDB, err = chromem.NewPersistentDB(config.PersistPath, false)
+				if err != nil {
+					// Still failing: surface the error and suggest manual cleanup.
+					return nil, fmt.Errorf(
+						"failed to create persistent database after cleanup attempt. Please manually remove %s and try again: %w",
+						config.PersistPath, err)
+				}
+				logger.Info("Successfully recreated database after cleanup")
+			} else {
+				return nil, fmt.Errorf("failed to create persistent database: %w", err)
+			}
+		}
+	} else {
+		logger.Info("Creating in-memory chromem-go database")
+		chromemDB = chromem.NewDB()
+	}
+
+	db := &DB{
+		config:  config,
+		chromem: chromemDB,
+	}
+
+	// Resolve the FTS5 path: explicit config wins, then alongside the
+	// persistent store, then fully in-memory.
+	ftsPath := config.FTSDBPath
+	if ftsPath == "" {
+		if config.PersistPath != "" {
+			// Persistent mode: store FTS5 alongside chromem-go. Use
+			// filepath.Join for an OS-correct path instead of string
+			// concatenation with "/".
+			ftsPath = filepath.Join(config.PersistPath, "fts.db")
+		} else {
+			// In-memory mode: use SQLite in-memory database
+			ftsPath = ":memory:"
+		}
+	}
+
+	// Initialize FTS5 database for BM25 text search (always enabled)
+	logger.Infof("Initializing FTS5 database for hybrid search at: %s", ftsPath)
+	ftsDB, err := NewFTSDatabase(&FTSConfig{DBPath: ftsPath})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create FTS5 database: %w", err)
+	}
+	db.fts = ftsDB
+	logger.Info("Hybrid search enabled (chromem-go + FTS5)")
+
+	logger.Info("Optimizer database initialized successfully")
+	return db, nil
+}
+
+// GetOrCreateCollection returns the named collection, creating it on first use.
+func (db *DB) GetOrCreateCollection(
+	_ context.Context,
+	name string,
+	embeddingFunc chromem.EmbeddingFunc,
+) (*chromem.Collection, error) {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+
+	// Fast path: the collection already exists.
+	if existing := db.chromem.GetCollection(name, embeddingFunc); existing != nil {
+		return existing, nil
+	}
+
+	// Slow path: first use of this name, create it.
+	created, err := db.chromem.CreateCollection(name, nil, embeddingFunc)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create collection %s: %w", name, err)
+	}
+
+	logger.Debugf("Created new collection: %s", name)
+	return created, nil
+}
+
+// GetCollection returns an existing collection, or an error if the name is unknown.
+func (db *DB) GetCollection(name string, embeddingFunc chromem.EmbeddingFunc) (*chromem.Collection, error) {
+	db.mu.RLock()
+	defer db.mu.RUnlock()
+
+	if collection := db.chromem.GetCollection(name, embeddingFunc); collection != nil {
+		return collection, nil
+	}
+	return nil, fmt.Errorf("collection not found: %s", name)
+}
+
+// DeleteCollection removes the named collection from the vector store.
+func (db *DB) DeleteCollection(name string) {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+
+	//nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error
+	db.chromem.DeleteCollection(name)
+
+	logger.Debugf("Deleted collection: %s", name)
+}
+
+// Close shuts down the optimizer databases. chromem-go needs no explicit
+// close; only the FTS5 SQLite handle does.
+func (db *DB) Close() error {
+	logger.Info("Closing optimizer databases")
+	if db.fts == nil {
+		return nil
+	}
+	if err := db.fts.Close(); err != nil {
+		return fmt.Errorf("failed to close FTS database: %w", err)
+	}
+	return nil
+}
+
+// GetChromemDB returns the underlying chromem.DB instance (vector store) for
+// callers that need direct access.
+func (db *DB) GetChromemDB() *chromem.DB {
+	return db.chromem
+}
+
+// GetFTSDB returns the FTS database handle. NewDB always initializes FTS
+// (hybrid search is always enabled), so this is non-nil for any DB built via
+// NewDB; only a zero-value DB would return nil here.
+func (db *DB) GetFTSDB() *FTSDatabase {
+	return db.fts
+}
+
+// Reset drops both well-known chromem collections and empties the FTS5
+// tables. Useful at startup and in tests to begin from a clean slate.
+func (db *DB) Reset() {
+	db.mu.Lock()
+	defer db.mu.Unlock()
+
+	for _, name := range []string{BackendServerCollection, BackendToolCollection} {
+		//nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error
+		db.chromem.DeleteCollection(name)
+	}
+
+	// Clear the FTS5 side as well, best effort.
+	if db.fts != nil {
+		for _, stmt := range []string{
+			"DELETE FROM backend_tools_fts",
+			"DELETE FROM backend_servers_fts",
+		} {
+			//nolint:errcheck // Best effort cleanup
+			_, _ = db.fts.db.Exec(stmt)
+		}
+	}
+
+	logger.Debug("Reset all collections and FTS tables")
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/db_test.go b/cmd/thv-operator/pkg/optimizer/db/db_test.go
new file mode 100644
index 0000000000..4eb98daaeb
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/db_test.go
@@ -0,0 +1,305 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestNewDB_CorruptedDatabase verifies NewDB recovers when the persistence
+// directory exists but does not contain a valid chromem-go database.
+func TestNewDB_CorruptedDatabase(t *testing.T) {
+	t.Parallel()
+
+	// Fabricate a directory that looks like a corrupted database.
+	persistDir := filepath.Join(t.TempDir(), "corrupted-db")
+	require.NoError(t, os.MkdirAll(persistDir, 0755))
+	require.NoError(t, os.WriteFile(filepath.Join(persistDir, "some-file"), []byte("corrupted"), 0644))
+
+	// NewDB should detect the problem and recreate the store.
+	db, err := NewDB(&Config{PersistPath: persistDir})
+	require.NoError(t, err)
+	require.NotNil(t, db)
+	defer func() { _ = db.Close() }()
+}
+
+// TestNewDB_CorruptedDatabase_RecoveryFailure tests that NewDB surfaces an
+// error when the persistence path cannot be created at all.
+//
+// The previous version of this test also built a "corrupted" directory in a
+// temp dir but never used it (the assertion ran against an unrelated invalid
+// path); that dead setup has been removed.
+func TestNewDB_CorruptedDatabase_RecoveryFailure(t *testing.T) {
+	t.Parallel()
+
+	// Making a directory read-only is not portable across systems, so the
+	// failure path is exercised with a path that cannot be created instead.
+	config := &Config{
+		PersistPath: "/invalid/path/that/does/not/exist",
+	}
+
+	_, err := NewDB(config)
+	// Should return error for invalid path
+	assert.Error(t, err)
+}
+
+// TestDB_GetOrCreateCollection verifies create-then-get returns the same collection.
+func TestDB_GetOrCreateCollection(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	db, err := NewDB(&Config{PersistPath: ""}) // in-memory
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	embed := func(_ context.Context, _ string) ([]float32, error) {
+		return []float32{0.1, 0.2, 0.3}, nil
+	}
+
+	first, err := db.GetOrCreateCollection(ctx, "test-collection", embed)
+	require.NoError(t, err)
+	require.NotNil(t, first)
+
+	// A second call must hand back the collection created above, not a new one.
+	second, err := db.GetOrCreateCollection(ctx, "test-collection", embed)
+	require.NoError(t, err)
+	require.NotNil(t, second)
+	assert.Equal(t, first, second)
+}
+
+// TestDB_GetCollection verifies lookup fails before creation and succeeds after.
+func TestDB_GetCollection(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	db, err := NewDB(&Config{PersistPath: ""}) // in-memory
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	embed := func(_ context.Context, _ string) ([]float32, error) {
+		return []float32{0.1, 0.2, 0.3}, nil
+	}
+
+	// Unknown names are an error.
+	_, err = db.GetCollection("non-existent", embed)
+	assert.Error(t, err)
+
+	// After creation the same name resolves.
+	_, err = db.GetOrCreateCollection(ctx, "test-collection", embed)
+	require.NoError(t, err)
+
+	got, err := db.GetCollection("test-collection", embed)
+	require.NoError(t, err)
+	require.NotNil(t, got)
+}
+
+// TestDB_DeleteCollection verifies a deleted collection can no longer be fetched.
+func TestDB_DeleteCollection(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	db, err := NewDB(&Config{PersistPath: ""}) // in-memory
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	embed := func(_ context.Context, _ string) ([]float32, error) {
+		return []float32{0.1, 0.2, 0.3}, nil
+	}
+
+	_, err = db.GetOrCreateCollection(ctx, "test-collection", embed)
+	require.NoError(t, err)
+
+	db.DeleteCollection("test-collection")
+
+	// Fetching the deleted collection must fail.
+	_, err = db.GetCollection("test-collection", embed)
+	assert.Error(t, err)
+}
+
+// TestDB_Reset verifies Reset drops both well-known collections.
+func TestDB_Reset(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	db, err := NewDB(&Config{PersistPath: ""}) // in-memory
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	embed := func(_ context.Context, _ string) ([]float32, error) {
+		return []float32{0.1, 0.2, 0.3}, nil
+	}
+
+	// Populate both well-known collections.
+	for _, name := range []string{BackendServerCollection, BackendToolCollection} {
+		_, err = db.GetOrCreateCollection(ctx, name, embed)
+		require.NoError(t, err)
+	}
+
+	db.Reset()
+
+	// Both collections must be gone after the reset.
+	for _, name := range []string{BackendServerCollection, BackendToolCollection} {
+		_, err = db.GetCollection(name, embed)
+		assert.Error(t, err)
+	}
+}
+
+// TestDB_GetChromemDB checks the vector-store accessor returns a non-nil handle.
+func TestDB_GetChromemDB(t *testing.T) {
+	t.Parallel()
+
+	db, err := NewDB(&Config{PersistPath: ""}) // in-memory
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	require.NotNil(t, db.GetChromemDB())
+}
+
+// TestDB_GetFTSDB checks the FTS accessor returns a non-nil handle.
+func TestDB_GetFTSDB(t *testing.T) {
+	t.Parallel()
+
+	db, err := NewDB(&Config{PersistPath: ""}) // in-memory
+	require.NoError(t, err)
+	defer func() { _ = db.Close() }()
+
+	require.NotNil(t, db.GetFTSDB())
+}
+
+// TestDB_Close verifies Close succeeds and is safe to call more than once.
+func TestDB_Close(t *testing.T) {
+	t.Parallel()
+
+	db, err := NewDB(&Config{PersistPath: ""}) // in-memory
+	require.NoError(t, err)
+
+	require.NoError(t, db.Close())
+	// A second Close must also succeed.
+	require.NoError(t, db.Close())
+}
+
+// TestNewDB_FTSDBPath exercises FTS path resolution for the persistent /
+// in-memory combinations, including the "{PersistPath}/fts.db" default.
+func TestNewDB_FTSDBPath(t *testing.T) {
+	t.Parallel()
+	tmpDir := t.TempDir()
+
+	tests := []struct {
+		name    string
+		config  *Config
+		wantErr bool
+	}{
+		{
+			name: "in-memory FTS with persistent chromem",
+			config: &Config{
+				PersistPath: filepath.Join(tmpDir, "db"),
+				FTSDBPath:   ":memory:",
+			},
+			wantErr: false,
+		},
+		{
+			name: "persistent FTS with persistent chromem",
+			config: &Config{
+				PersistPath: filepath.Join(tmpDir, "db2"),
+				FTSDBPath:   filepath.Join(tmpDir, "fts.db"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "default FTS path with persistent chromem",
+			config: &Config{
+				PersistPath: filepath.Join(tmpDir, "db3"),
+				// FTSDBPath not set, should default to {PersistPath}/fts.db
+			},
+			wantErr: false,
+		},
+		{
+			name: "in-memory FTS with in-memory chromem",
+			config: &Config{
+				PersistPath: "",
+				FTSDBPath:   ":memory:",
+			},
+			wantErr: false,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt // capture range variable: subtests run in parallel, so without this shadow every subtest would see the last case on Go < 1.22
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			db, err := NewDB(tt.config)
+			if tt.wantErr {
+				assert.Error(t, err)
+			} else {
+				require.NoError(t, err)
+				require.NotNil(t, db)
+				defer func() { _ = db.Close() }()
+
+				// Verify FTS DB is accessible
+				ftsDB := db.GetFTSDB()
+				require.NotNil(t, ftsDB)
+			}
+		})
+	}
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/fts.go b/cmd/thv-operator/pkg/optimizer/db/fts.go
new file mode 100644
index 0000000000..2f444cfae0
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/fts.go
@@ -0,0 +1,360 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+ "context"
+ "database/sql"
+ _ "embed"
+ "fmt"
+ "strings"
+ "sync"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+ "github.com/stacklok/toolhive/pkg/logger"
+)
+
+//go:embed schema_fts.sql
+var schemaFTS string
+
+// FTSConfig holds FTS5 database configuration.
+type FTSConfig struct {
+	// DBPath is the path to the SQLite database file.
+	// If empty, ":memory:" is used for an in-memory database.
+	DBPath string
+}
+
+// FTSDatabase handles FTS5 (BM25) search operations.
+type FTSDatabase struct {
+	config *FTSConfig
+	db *sql.DB // SQLite handle opened with driver name "sqlite"
+	mu sync.RWMutex // writers take the exclusive lock; searches share the read lock
+}
+
+// NewFTSDatabase creates a new FTS5 database for BM25 search.
+//
+// NOTE(review): sql.Open("sqlite", ...) requires the modernc.org/sqlite driver
+// to be registered via a blank import somewhere in the binary; it is not
+// imported in this file — confirm another file in the package registers it.
+func NewFTSDatabase(config *FTSConfig) (*FTSDatabase, error) {
+	dbPath := config.DBPath
+	if dbPath == "" {
+		dbPath = ":memory:"
+	}
+
+	// Open with modernc.org/sqlite (pure Go)
+	sqlDB, err := sql.Open("sqlite", dbPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to open FTS database: %w", err)
+	}
+
+	// A bare ":memory:" DSN gives EACH pooled connection its own private,
+	// empty database, so a second connection opened by database/sql would not
+	// see the schema created below. Pin the pool to a single connection so
+	// every query hits the same in-memory database.
+	if dbPath == ":memory:" {
+		sqlDB.SetMaxOpenConns(1)
+	}
+
+	// Set pragmas for performance.
+	// NOTE(review): journal_mode=WAL persists in the database file, but
+	// synchronous/foreign_keys/busy_timeout apply per-connection; with a
+	// multi-connection pool only the connection that ran them is affected —
+	// confirm this is acceptable for file-backed databases.
+	pragmas := []string{
+		"PRAGMA journal_mode=WAL",
+		"PRAGMA synchronous=NORMAL",
+		"PRAGMA foreign_keys=ON",
+		"PRAGMA busy_timeout=5000",
+	}
+
+	for _, pragma := range pragmas {
+		if _, err := sqlDB.Exec(pragma); err != nil {
+			_ = sqlDB.Close()
+			return nil, fmt.Errorf("failed to set pragma: %w", err)
+		}
+	}
+
+	ftsDB := &FTSDatabase{
+		config: config,
+		db:     sqlDB,
+	}
+
+	// Initialize schema (tables + triggers from embedded schema_fts.sql).
+	if err := ftsDB.initializeSchema(); err != nil {
+		_ = sqlDB.Close()
+		return nil, fmt.Errorf("failed to initialize FTS schema: %w", err)
+	}
+
+	logger.Infof("FTS5 database initialized successfully at: %s", dbPath)
+	return ftsDB, nil
+}
+
+// initializeSchema applies the embedded FTS5 schema (tables and triggers).
+//
+// The schema is executed directly instead of via a migration framework: the
+// FTS database is ephemeral (destroyed on shutdown, recreated on startup),
+// and migrations only matter when data must survive schema changes.
+func (fts *FTSDatabase) initializeSchema() error {
+	fts.mu.Lock()
+	defer fts.mu.Unlock()
+
+	if _, err := fts.db.Exec(schemaFTS); err != nil {
+		return fmt.Errorf("failed to execute schema: %w", err)
+	}
+
+	logger.Debug("FTS5 schema initialized")
+	return nil
+}
+
+// UpsertServer inserts or updates a server row in the FTS database.
+// On conflict the row is refreshed in place; created_at is intentionally not
+// part of the UPDATE so the original creation time is preserved.
+func (fts *FTSDatabase) UpsertServer(
+	ctx context.Context,
+	server *models.BackendServer,
+) error {
+	fts.mu.Lock()
+	defer fts.mu.Unlock()
+
+	query := `
+	INSERT INTO backend_servers_fts (id, name, description, server_group, last_updated, created_at)
+	VALUES (?, ?, ?, ?, ?, ?)
+	ON CONFLICT(id) DO UPDATE SET
+		name = excluded.name,
+		description = excluded.description,
+		server_group = excluded.server_group,
+		last_updated = excluded.last_updated
+	`
+
+	_, err := fts.db.ExecContext(
+		ctx,
+		query,
+		server.ID,
+		server.Name,
+		server.Description,
+		server.Group,
+		server.LastUpdated,
+		server.CreatedAt,
+	)
+
+	if err != nil {
+		return fmt.Errorf("failed to upsert server in FTS: %w", err)
+	}
+
+	logger.Debugf("Upserted server in FTS: %s", server.ID)
+	return nil
+}
+
+// UpsertToolMeta inserts or updates a tool row in the FTS database.
+// On conflict every column except created_at is refreshed in place. The third
+// parameter (server name) is unused; it is kept so the signature stays
+// compatible with existing callers.
+func (fts *FTSDatabase) UpsertToolMeta(
+	ctx context.Context,
+	tool *models.BackendTool,
+	_ string, // serverName - unused, keeping for interface compatibility
+) error {
+	fts.mu.Lock()
+	defer fts.mu.Unlock()
+
+	// Store the raw JSON schema as TEXT; NULL when the tool has no schema.
+	var schemaStr *string
+	if len(tool.InputSchema) > 0 {
+		str := string(tool.InputSchema)
+		schemaStr = &str
+	}
+
+	query := `
+	INSERT INTO backend_tools_fts (
+		id, mcpserver_id, tool_name, tool_description,
+		input_schema, token_count, last_updated, created_at
+	)
+	VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+	ON CONFLICT(id) DO UPDATE SET
+		mcpserver_id = excluded.mcpserver_id,
+		tool_name = excluded.tool_name,
+		tool_description = excluded.tool_description,
+		input_schema = excluded.input_schema,
+		token_count = excluded.token_count,
+		last_updated = excluded.last_updated
+	`
+
+	_, err := fts.db.ExecContext(
+		ctx,
+		query,
+		tool.ID,
+		tool.MCPServerID,
+		tool.ToolName,
+		tool.Description,
+		schemaStr,
+		tool.TokenCount,
+		tool.LastUpdated,
+		tool.CreatedAt,
+	)
+
+	if err != nil {
+		return fmt.Errorf("failed to upsert tool in FTS: %w", err)
+	}
+
+	logger.Debugf("Upserted tool in FTS: %s", tool.ToolName)
+	return nil
+}
+
+// DeleteServer removes a server row from the FTS database. Related tool rows
+// are expected to be removed by the schema's foreign-key cascade (see
+// schema_fts.sql) — confirm the schema declares ON DELETE CASCADE.
+func (fts *FTSDatabase) DeleteServer(ctx context.Context, serverID string) error {
+	fts.mu.Lock()
+	defer fts.mu.Unlock()
+
+	if _, err := fts.db.ExecContext(ctx, "DELETE FROM backend_servers_fts WHERE id = ?", serverID); err != nil {
+		return fmt.Errorf("failed to delete server from FTS: %w", err)
+	}
+
+	logger.Debugf("Deleted server from FTS: %s", serverID)
+	return nil
+}
+
+// DeleteToolsByServer removes every tool row belonging to the given server.
+func (fts *FTSDatabase) DeleteToolsByServer(ctx context.Context, serverID string) error {
+	fts.mu.Lock()
+	defer fts.mu.Unlock()
+
+	res, err := fts.db.ExecContext(ctx, "DELETE FROM backend_tools_fts WHERE mcpserver_id = ?", serverID)
+	if err != nil {
+		return fmt.Errorf("failed to delete tools from FTS: %w", err)
+	}
+
+	// The affected-row count is informational only; its error is ignored.
+	n, _ := res.RowsAffected()
+	logger.Debugf("Deleted %d tools from FTS for server: %s", n, serverID)
+	return nil
+}
+
+// DeleteTool removes a single tool row from the FTS database by ID.
+func (fts *FTSDatabase) DeleteTool(ctx context.Context, toolID string) error {
+	fts.mu.Lock()
+	defer fts.mu.Unlock()
+
+	if _, err := fts.db.ExecContext(ctx, "DELETE FROM backend_tools_fts WHERE id = ?", toolID); err != nil {
+		return fmt.Errorf("failed to delete tool from FTS: %w", err)
+	}
+
+	logger.Debugf("Deleted tool from FTS: %s", toolID)
+	return nil
+}
+
+// SearchBM25 performs BM25 full-text search on tools.
+//
+// The query text is sanitized to a plain keyword list before being passed to
+// FTS5 MATCH; a query that sanitizes to the empty string short-circuits to an
+// empty result set. serverID, when non-nil, restricts results to one server.
+// Rows that fail to scan are logged and skipped rather than aborting the search.
+func (fts *FTSDatabase) SearchBM25(
+	ctx context.Context,
+	query string,
+	limit int,
+	serverID *string,
+) ([]*models.BackendToolWithMetadata, error) {
+	fts.mu.RLock()
+	defer fts.mu.RUnlock()
+
+	// Sanitize FTS5 query; empty after sanitization means nothing to match.
+	sanitizedQuery := sanitizeFTS5Query(query)
+	if sanitizedQuery == "" {
+		return []*models.BackendToolWithMetadata{}, nil
+	}
+
+	// Join the FTS5 index table to the metadata table; the index rows carry
+	// the BM25 rank. (Table names come from schema_fts.sql — assumes
+	// backend_tool_fts_index is the FTS5 virtual table; confirm against the schema.)
+	sqlQuery := `
+	SELECT
+		t.id,
+		t.mcpserver_id,
+		t.tool_name,
+		t.tool_description,
+		t.input_schema,
+		t.token_count,
+		t.last_updated,
+		t.created_at,
+		fts.rank
+	FROM backend_tool_fts_index fts
+	JOIN backend_tools_fts t ON fts.tool_id = t.id
+	WHERE backend_tool_fts_index MATCH ?
+	`
+
+	args := []interface{}{sanitizedQuery}
+
+	// Optional per-server filter.
+	if serverID != nil {
+		sqlQuery += " AND t.mcpserver_id = ?"
+		args = append(args, *serverID)
+	}
+
+	// FTS5 rank sorts best matches first (most negative rank).
+	sqlQuery += " ORDER BY rank LIMIT ?"
+	args = append(args, limit)
+
+	rows, err := fts.db.QueryContext(ctx, sqlQuery, args...)
+	if err != nil {
+		return nil, fmt.Errorf("failed to search tools: %w", err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	var results []*models.BackendToolWithMetadata
+	for rows.Next() {
+		var tool models.BackendTool
+		var schemaStr sql.NullString // input_schema may be NULL
+		var rank float32
+
+		// Scan order must match the SELECT column order above.
+		err := rows.Scan(
+			&tool.ID,
+			&tool.MCPServerID,
+			&tool.ToolName,
+			&tool.Description,
+			&schemaStr,
+			&tool.TokenCount,
+			&tool.LastUpdated,
+			&tool.CreatedAt,
+			&rank,
+		)
+		if err != nil {
+			// Skip malformed rows instead of failing the whole search.
+			logger.Warnf("Failed to scan tool row: %v", err)
+			continue
+		}
+
+		if schemaStr.Valid {
+			tool.InputSchema = []byte(schemaStr.String)
+		}
+
+		// Convert BM25 rank to similarity score (higher is better).
+		// FTS5 rank is negative (more negative = better), so 1/(1-rank)
+		// maps it into (0, 1] with better matches closer to 1.
+		similarity := float32(1.0 / (1.0 - float64(rank)))
+
+		results = append(results, &models.BackendToolWithMetadata{
+			BackendTool: tool,
+			Similarity:  similarity,
+		})
+	}
+
+	if err := rows.Err(); err != nil {
+		return nil, fmt.Errorf("error iterating tool rows: %w", err)
+	}
+
+	logger.Debugf("BM25 search found %d tools for query: %s", len(results), query)
+	return results, nil
+}
+
+// GetTotalToolTokens returns the sum of token_count across all indexed tools.
+// COALESCE makes the empty-table case yield 0 rather than NULL.
+func (fts *FTSDatabase) GetTotalToolTokens(ctx context.Context) (int, error) {
+	fts.mu.RLock()
+	defer fts.mu.RUnlock()
+
+	const query = "SELECT COALESCE(SUM(token_count), 0) FROM backend_tools_fts"
+
+	var total int
+	if err := fts.db.QueryRowContext(ctx, query).Scan(&total); err != nil {
+		return 0, fmt.Errorf("failed to get total tool tokens: %w", err)
+	}
+	return total, nil
+}
+
+// Close closes the FTS database connection pool.
+func (fts *FTSDatabase) Close() error {
+	return fts.db.Close()
+}
+
+// sanitizeFTS5Query strips FTS5 query syntax from user input so it can be fed
+// to MATCH as a plain bag of keywords.
+//
+// Double quotes, wildcards and parentheses are all replaced with spaces.
+// The previous implementation doubled quotes (`"` -> `""`), but doubling is
+// only valid INSIDE a quoted phrase — in a bare query it produces an FTS5
+// syntax error, and it contradicted this package's own tests, which expect
+// `"test query"` to sanitize to `test query` and a query of only special
+// characters to sanitize to the empty string.
+func sanitizeFTS5Query(query string) string {
+	replacer := strings.NewReplacer(
+		`"`, ` `, // drop quotes entirely; see doc comment
+		`*`, ` `, // drop wildcards
+		`(`, ` `, // drop grouping
+		`)`, ` `,
+	)
+
+	// Fields+Join collapses the runs of whitespace left by the removals and
+	// trims the ends, so no extra TrimSpace is needed.
+	return strings.Join(strings.Fields(replacer.Replace(query)), " ")
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/fts_test_coverage.go b/cmd/thv-operator/pkg/optimizer/db/fts_test_coverage.go
new file mode 100644
index 0000000000..b4b1911b93
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/fts_test_coverage.go
@@ -0,0 +1,162 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+)
+
+// stringPtr returns a pointer to the given string (helper for optional fields).
+//
+// NOTE(review): this file is named *_test_coverage.go rather than *_test.go,
+// so `go test` never runs these tests and `testing` is compiled into the
+// normal build — consider renaming the file.
+func stringPtr(s string) *string {
+	return &s
+}
+
+// TestFTSDatabase_GetTotalToolTokens checks the token sum is zero when empty
+// and accumulates across upserted tools.
+func TestFTSDatabase_GetTotalToolTokens(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	ftsDB, err := NewFTSDatabase(&FTSConfig{DBPath: ":memory:"})
+	require.NoError(t, err)
+	defer func() { _ = ftsDB.Close() }()
+
+	// An empty database sums to zero.
+	total, err := ftsDB.GetTotalToolTokens(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, 0, total)
+
+	now := time.Now()
+	makeTool := func(id, name, desc string, tokens int) *models.BackendTool {
+		return &models.BackendTool{
+			ID:          id,
+			MCPServerID: "server-1",
+			ToolName:    name,
+			Description: stringPtr(desc),
+			TokenCount:  tokens,
+			CreatedAt:   now,
+			LastUpdated: now,
+		}
+	}
+
+	// One tool contributes its own count.
+	require.NoError(t, ftsDB.UpsertToolMeta(ctx, makeTool("tool-1", "test_tool", "Test tool", 100), "TestServer"))
+	total, err = ftsDB.GetTotalToolTokens(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, 100, total)
+
+	// A second tool's count is added to the sum.
+	require.NoError(t, ftsDB.UpsertToolMeta(ctx, makeTool("tool-2", "test_tool2", "Test tool 2", 50), "TestServer"))
+	total, err = ftsDB.GetTotalToolTokens(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, 150, total)
+}
+
+// TestSanitizeFTS5Query table-tests sanitizeFTS5Query against quote,
+// wildcard, parenthesis, and whitespace handling.
+func TestSanitizeFTS5Query(t *testing.T) {
+	t.Parallel()
+
+	cases := []struct {
+		name     string
+		input    string
+		expected string
+	}{
+		{
+			name:     "remove quotes",
+			input:    `"test query"`,
+			expected: "test query",
+		},
+		{
+			name:     "remove wildcards",
+			input:    "test*query",
+			expected: "test query",
+		},
+		{
+			name:     "remove parentheses",
+			input:    "test(query)",
+			expected: "test query",
+		},
+		{
+			name:     "remove multiple spaces",
+			input:    "test query",
+			expected: "test query",
+		},
+		{
+			name:     "trim whitespace",
+			input:    " test query ",
+			expected: "test query",
+		},
+		{
+			name:     "empty string",
+			input:    "",
+			expected: "",
+		},
+		{
+			name:     "only special characters",
+			input:    `"*()`,
+			expected: "",
+		},
+		{
+			name:     "mixed special characters",
+			input:    `test"query*with(special)chars`,
+			expected: "test query with special chars",
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			assert.Equal(t, tc.expected, sanitizeFTS5Query(tc.input))
+		})
+	}
+}
+
+// TestFTSDatabase_SearchBM25_EmptyQuery verifies that queries which sanitize
+// to the empty string short-circuit to empty results without error.
+func TestFTSDatabase_SearchBM25_EmptyQuery(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	ftsDB, err := NewFTSDatabase(&FTSConfig{DBPath: ":memory:"})
+	require.NoError(t, err)
+	defer func() { _ = ftsDB.Close() }()
+
+	// Both an empty query and one made of only FTS5 special characters must
+	// come back empty with no error.
+	for _, q := range []string{"", `"*()`} {
+		results, err := ftsDB.SearchBM25(ctx, q, 10, nil)
+		require.NoError(t, err)
+		assert.Empty(t, results)
+	}
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/hybrid.go b/cmd/thv-operator/pkg/optimizer/db/hybrid.go
new file mode 100644
index 0000000000..27df70d696
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/hybrid.go
@@ -0,0 +1,172 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package db
+
+import (
+	"context"
+	"fmt"
+	"sort"
+
+	"github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+	"github.com/stacklok/toolhive/pkg/logger"
+)
+
+// HybridSearchConfig configures hybrid search behavior.
+// Pass nil to SearchHybrid to use DefaultHybridConfig.
+type HybridSearchConfig struct {
+	// SemanticRatio controls the mix of semantic vs BM25 results (0-100, representing percentage)
+	// Default: 70 (70% semantic, 30% BM25)
+	SemanticRatio int
+
+	// Limit is the total number of results to return
+	Limit int
+
+	// ServerID optionally filters results to a specific server
+	ServerID *string
+}
+
+// DefaultHybridConfig returns the default hybrid-search settings:
+// a 70/30 semantic-to-BM25 split and up to 10 results.
+func DefaultHybridConfig() *HybridSearchConfig {
+	cfg := &HybridSearchConfig{}
+	cfg.SemanticRatio = 70
+	cfg.Limit = 10
+	return cfg
+}
+
+// SearchHybrid performs hybrid search combining semantic (chromem-go) and BM25 (FTS5) results
+// This matches the Python mcp-optimizer's hybrid search implementation.
+//
+// Both searches run concurrently. If one method fails, the other's results
+// are still returned; an error is reported only when both methods fail.
+func (ops *BackendToolOps) SearchHybrid(
+	ctx context.Context,
+	queryText string,
+	config *HybridSearchConfig,
+) ([]*models.BackendToolWithMetadata, error) {
+	if config == nil {
+		config = DefaultHybridConfig()
+	}
+
+	// Clamp the semantic ratio to [0, 100] so out-of-range values cannot
+	// produce negative or oversized per-method limits. Clamped locally; the
+	// caller's config is deliberately not mutated.
+	ratio := config.SemanticRatio
+	if ratio < 0 {
+		ratio = 0
+	} else if ratio > 100 {
+		ratio = 100
+	}
+
+	// Calculate limits for each search method. Each method gets at least one
+	// result slot; the combined set is truncated to config.Limit at the end.
+	semanticLimit := max(1, int(float64(config.Limit)*float64(ratio)/100.0))
+	bm25Limit := max(1, config.Limit-semanticLimit)
+
+	logger.Debugf(
+		"Hybrid search: semantic_limit=%d, bm25_limit=%d, ratio=%d%%",
+		semanticLimit, bm25Limit, ratio,
+	)
+
+	// Execute both searches in parallel
+	type searchResult struct {
+		results []*models.BackendToolWithMetadata
+		err     error
+	}
+
+	// Buffered channels so neither goroutine blocks on send.
+	semanticCh := make(chan searchResult, 1)
+	bm25Ch := make(chan searchResult, 1)
+
+	// Semantic search
+	go func() {
+		results, err := ops.Search(ctx, queryText, semanticLimit, config.ServerID)
+		semanticCh <- searchResult{results, err}
+	}()
+
+	// BM25 search
+	go func() {
+		results, err := ops.db.fts.SearchBM25(ctx, queryText, bm25Limit, config.ServerID)
+		bm25Ch <- searchResult{results, err}
+	}()
+
+	// Collect results
+	var semanticResults, bm25Results []*models.BackendToolWithMetadata
+	var errs []error
+
+	// Wait for semantic results
+	semanticRes := <-semanticCh
+	if semanticRes.err != nil {
+		logger.Warnf("Semantic search failed: %v", semanticRes.err)
+		errs = append(errs, semanticRes.err)
+	} else {
+		semanticResults = semanticRes.results
+	}
+
+	// Wait for BM25 results
+	bm25Res := <-bm25Ch
+	if bm25Res.err != nil {
+		logger.Warnf("BM25 search failed: %v", bm25Res.err)
+		errs = append(errs, bm25Res.err)
+	} else {
+		bm25Results = bm25Res.results
+	}
+
+	// If both failed, return error
+	if len(errs) == 2 {
+		return nil, fmt.Errorf("both search methods failed: semantic=%v, bm25=%v", errs[0], errs[1])
+	}
+
+	// Combine and deduplicate results
+	combined := combineAndDeduplicateResults(semanticResults, bm25Results, config.Limit)
+
+	logger.Infof(
+		"Hybrid search completed: semantic=%d, bm25=%d, combined=%d (requested=%d)",
+		len(semanticResults), len(bm25Results), len(combined), config.Limit,
+	)
+
+	return combined, nil
+}
+
+// combineAndDeduplicateResults merges the semantic and BM25 result sets,
+// deduplicating by tool ID. Semantic hits are recorded first; a BM25 hit for
+// the same ID only replaces one when its similarity score is strictly higher.
+// The merged set is returned sorted by similarity (descending) and capped at
+// limit entries.
+func combineAndDeduplicateResults(
+	semantic, bm25 []*models.BackendToolWithMetadata,
+	limit int,
+) []*models.BackendToolWithMetadata {
+	best := make(map[string]*models.BackendToolWithMetadata, len(semantic)+len(bm25))
+
+	// Semantic results win by default (typically higher quality).
+	for _, r := range semantic {
+		best[r.ID] = r
+	}
+
+	// BM25 results fill gaps, or replace on a strictly better score.
+	for _, r := range bm25 {
+		prev, ok := best[r.ID]
+		if !ok || r.Similarity > prev.Similarity {
+			best[r.ID] = r
+		}
+	}
+
+	// Flatten, rank, and truncate.
+	merged := make([]*models.BackendToolWithMetadata, 0, len(best))
+	for _, r := range best {
+		merged = append(merged, r)
+	}
+	ranked := sortBySimilarity(merged)
+	if limit < len(ranked) {
+		ranked = ranked[:limit]
+	}
+	return ranked
+}
+
+// sortBySimilarity returns a copy of results ordered by similarity score in
+// descending order. The input slice is not modified.
+// Uses the standard library sort (O(n log n)) instead of the previous
+// hand-rolled O(n^2) exchange sort.
+func sortBySimilarity(results []*models.BackendToolWithMetadata) []*models.BackendToolWithMetadata {
+	sorted := make([]*models.BackendToolWithMetadata, len(results))
+	copy(sorted, results)
+
+	sort.Slice(sorted, func(i, j int) bool {
+		return sorted[i].Similarity > sorted[j].Similarity
+	})
+
+	return sorted
+}
diff --git a/cmd/thv-operator/pkg/optimizer/db/schema_fts.sql b/cmd/thv-operator/pkg/optimizer/db/schema_fts.sql
new file mode 100644
index 0000000000..101dbea7d7
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/schema_fts.sql
@@ -0,0 +1,120 @@
+-- FTS5 schema for BM25 full-text search
+-- Complements chromem-go (which handles vector/semantic search)
+--
+-- This schema only contains:
+-- 1. Metadata tables for tool/server information
+-- 2. FTS5 virtual tables for BM25 keyword search
+--
+-- Note: chromem-go handles embeddings separately in memory/persistent storage
+
+-- Backend servers metadata (for FTS queries and joining)
+-- One row per MCP server; tool rows in backend_tools_fts reference id.
+CREATE TABLE IF NOT EXISTS backend_servers_fts (
+    id TEXT PRIMARY KEY,
+    name TEXT NOT NULL,
+    description TEXT,
+    server_group TEXT NOT NULL DEFAULT 'default',
+    last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+-- Supports filtering servers by group.
+CREATE INDEX IF NOT EXISTS idx_backend_servers_fts_group ON backend_servers_fts(server_group);
+
+-- Backend tools metadata (for FTS queries and joining)
+-- Content table backing the backend_tool_fts_index virtual table below;
+-- rows are mirrored into the FTS index via the AFTER INSERT/UPDATE/DELETE
+-- triggers defined further down.
+CREATE TABLE IF NOT EXISTS backend_tools_fts (
+    id TEXT PRIMARY KEY,
+    mcpserver_id TEXT NOT NULL,
+    tool_name TEXT NOT NULL,
+    tool_description TEXT,
+    input_schema TEXT,  -- JSON string
+    token_count INTEGER NOT NULL DEFAULT 0,  -- LLM token cost of the tool definition
+    last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    FOREIGN KEY (mcpserver_id) REFERENCES backend_servers_fts(id) ON DELETE CASCADE
+);
+
+-- Supports per-server tool listings and lookups by tool name.
+CREATE INDEX IF NOT EXISTS idx_backend_tools_fts_server ON backend_tools_fts(mcpserver_id);
+CREATE INDEX IF NOT EXISTS idx_backend_tools_fts_name ON backend_tools_fts(tool_name);
+
+-- FTS5 virtual table for backend tools
+-- Uses Porter stemming for better keyword matching
+-- Indexes: server name, tool name, and tool description
+-- External-content table: content='backend_tools_fts' means FTS5 stores only
+-- the index, not the text, and reads row text from the content table by rowid.
+CREATE VIRTUAL TABLE IF NOT EXISTS backend_tool_fts_index
+USING fts5(
+    tool_id UNINDEXED,
+    mcp_server_name,
+    tool_name,
+    tool_description,
+    tokenize='porter',
+    content='backend_tools_fts',
+    content_rowid='rowid'
+);
+
+-- Triggers to keep FTS5 index in sync with backend_tools_fts table.
+-- Because backend_tool_fts_index is an external-content table, every
+-- insert/update/delete on the content table must be mirrored explicitly.
+-- The inserted row's rowid is available as new.rowid inside the trigger,
+-- so no self-join back to backend_tools_fts is needed.
+CREATE TRIGGER IF NOT EXISTS backend_tools_fts_ai AFTER INSERT ON backend_tools_fts BEGIN
+    INSERT INTO backend_tool_fts_index(
+        rowid,
+        tool_id,
+        mcp_server_name,
+        tool_name,
+        tool_description
+    ) VALUES (
+        new.rowid,
+        new.id,
+        (SELECT name FROM backend_servers_fts WHERE id = new.mcpserver_id),
+        new.tool_name,
+        COALESCE(new.tool_description, '')
+    );
+END;
+
+-- For an external-content FTS5 table, the 'delete' command must be supplied
+-- the exact column values that were originally indexed; inserting NULLs (as
+-- the previous version did) corrupts the FTS index. Reconstruct the indexed
+-- values from OLD. NOTE(review): the server name is looked up again here and
+-- must match what was indexed at insert time; if the parent server row can be
+-- gone (or renamed) before this trigger fires, the lookup yields a different
+-- value -- verify cascade/update ordering against this assumption.
+CREATE TRIGGER IF NOT EXISTS backend_tools_fts_ad AFTER DELETE ON backend_tools_fts BEGIN
+    INSERT INTO backend_tool_fts_index(
+        backend_tool_fts_index,
+        rowid,
+        tool_id,
+        mcp_server_name,
+        tool_name,
+        tool_description
+    ) VALUES (
+        'delete',
+        old.rowid,
+        old.id,
+        (SELECT name FROM backend_servers_fts WHERE id = old.mcpserver_id),
+        old.tool_name,
+        COALESCE(old.tool_description, '')
+    );
+END;
+
+-- Update = delete old index entry + insert new one. As in the AFTER DELETE
+-- trigger, the 'delete' half must receive the originally indexed values
+-- (NULLs corrupt an external-content FTS index). The insert half uses
+-- new.rowid directly instead of re-querying the content table.
+CREATE TRIGGER IF NOT EXISTS backend_tools_fts_au AFTER UPDATE ON backend_tools_fts BEGIN
+    INSERT INTO backend_tool_fts_index(
+        backend_tool_fts_index,
+        rowid,
+        tool_id,
+        mcp_server_name,
+        tool_name,
+        tool_description
+    ) VALUES (
+        'delete',
+        old.rowid,
+        old.id,
+        (SELECT name FROM backend_servers_fts WHERE id = old.mcpserver_id),
+        old.tool_name,
+        COALESCE(old.tool_description, '')
+    );
+    INSERT INTO backend_tool_fts_index(
+        rowid,
+        tool_id,
+        mcp_server_name,
+        tool_name,
+        tool_description
+    ) VALUES (
+        new.rowid,
+        new.id,
+        (SELECT name FROM backend_servers_fts WHERE id = new.mcpserver_id),
+        new.tool_name,
+        COALESCE(new.tool_description, '')
+    );
+END;
diff --git a/cmd/thv-operator/pkg/optimizer/db/sqlite_fts.go b/cmd/thv-operator/pkg/optimizer/db/sqlite_fts.go
new file mode 100644
index 0000000000..23ae5bcdfb
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/db/sqlite_fts.go
@@ -0,0 +1,11 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package db provides database operations for the optimizer.
+// This file handles FTS5 (Full-Text Search) using modernc.org/sqlite (pure Go).
+package db
+
+import (
+ // Pure Go SQLite driver with FTS5 support
+ _ "modernc.org/sqlite"
+)
diff --git a/cmd/thv-operator/pkg/optimizer/doc.go b/cmd/thv-operator/pkg/optimizer/doc.go
new file mode 100644
index 0000000000..c59b7556a1
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/doc.go
@@ -0,0 +1,88 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package optimizer provides semantic tool discovery and ingestion for MCP servers.
+//
+// The optimizer package implements an ingestion service that discovers MCP backends
+// from ToolHive, generates semantic embeddings for tools using ONNX Runtime, and stores
+// them in a SQLite database with vector search capabilities.
+//
+// # Architecture
+//
+// The optimizer follows a similar architecture to mcp-optimizer (Python) but adapted
+// for Go idioms and patterns:
+//
+// pkg/optimizer/
+// ├── doc.go // Package documentation
+// ├── models/ // Database models and types
+// │ ├── models.go // Core domain models (Server, Tool, etc.)
+// │ └── transport.go // Transport and status enums
+// ├── db/ // Database layer
+// │ ├── db.go // Database connection and config
+// │ ├── fts.go // FTS5 database for BM25 search
+// │ ├── schema_fts.sql // Embedded FTS5 schema (executed directly)
+// │ ├── hybrid.go // Hybrid search (semantic + BM25)
+// │ ├── backend_server.go // Backend server operations
+// │ └── backend_tool.go // Backend tool operations
+// ├── embeddings/ // Embedding generation
+// │ ├── manager.go // Embedding manager with ONNX Runtime
+// │ └── cache.go // Optional embedding cache
+// ├── mcpclient/ // MCP client for tool discovery
+// │ └── client.go // MCP client wrapper
+// ├── ingestion/ // Core ingestion service
+// │ ├── service.go // Ingestion service implementation
+// │ └── errors.go // Custom errors
+// └── tokens/ // Token counting (for LLM consumption)
+// └── counter.go // Token counter using tiktoken-go
+//
+// # Core Concepts
+//
+// **Ingestion**: Discovers MCP backends from ToolHive (via Docker or Kubernetes),
+// connects to each backend to list tools, generates embeddings, and stores in database.
+//
+// **Embeddings**: Uses ONNX Runtime to generate semantic embeddings for tools and servers.
+// Embeddings enable semantic search to find relevant tools based on natural language queries.
+//
+// **Database**: Hybrid approach using chromem-go for vector search and SQLite FTS5 for
+// keyword search. The database is ephemeral (in-memory by default, optional persistence)
+// and schema is initialized directly on startup without migrations.
+//
+// **Terminology**: Uses "BackendServer" and "BackendTool" to explicitly refer to MCP server
+// metadata, distinguishing from vMCP's broader "Backend" concept which represents workloads.
+//
+// **Token Counting**: Tracks token counts for tools to measure LLM consumption and
+// calculate token savings from semantic filtering.
+//
+// # Usage
+//
+// The optimizer is integrated into vMCP as native tools:
+//
+// 1. **vMCP Integration**: The optimizer runs as part of vMCP, exposing
+// optim.find_tool and optim.call_tool to clients.
+//
+// 2. **Event-Driven Ingestion**: Tools are ingested when vMCP sessions
+// are registered, not via polling.
+//
+// Example vMCP integration (see pkg/vmcp/optimizer):
+//
+// import (
+// "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/ingestion"
+// "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+// )
+//
+// // Create embedding manager
+// embMgr, err := embeddings.NewManager(embeddings.Config{
+// BackendType: "ollama", // or "openai-compatible" or "vllm"
+// BaseURL: "http://localhost:11434",
+// Model: "all-minilm",
+// Dimension: 384,
+// })
+//
+// // Create ingestion service
+// svc, err := ingestion.NewService(ctx, ingestion.Config{
+// DBConfig: dbConfig,
+// }, embMgr)
+//
+// // Ingest a server (called by vMCP's OnRegisterSession hook)
+// err = svc.IngestServer(ctx, "weather-service", tools, target)
+package optimizer
diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/cache.go b/cmd/thv-operator/pkg/optimizer/embeddings/cache.go
new file mode 100644
index 0000000000..68f6bbe74b
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/embeddings/cache.go
@@ -0,0 +1,104 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package embeddings provides caching for embedding vectors.
+package embeddings
+
+import (
+ "container/list"
+ "sync"
+)
+
+// cache implements an LRU cache for embeddings.
+// All fields (including the hits/misses counters) are guarded by mu.
+type cache struct {
+	maxSize int                      // maximum number of entries before eviction
+	mu      sync.RWMutex             // guards items, lru, hits, misses
+	items   map[string]*list.Element // key -> list element holding a *cacheEntry
+	lru     *list.List               // front = most recently used
+	hits    int64
+	misses  int64
+}
+
+// cacheEntry is the value stored in each LRU list element; key is kept so
+// eviction can remove the corresponding map entry.
+type cacheEntry struct {
+	key   string
+	value []float32
+}
+
+// newCache constructs an empty LRU cache holding at most maxSize entries.
+func newCache(maxSize int) *cache {
+	c := &cache{maxSize: maxSize}
+	c.items = make(map[string]*list.Element)
+	c.lru = list.New()
+	return c
+}
+
+// Get returns the cached embedding for key, or nil on a miss.
+// A hit promotes the entry to most-recently-used and bumps the hit counter;
+// a miss bumps the miss counter. The write lock is taken because both the
+// LRU order and the counters are mutated.
+func (c *cache) Get(key string) []float32 {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	elem, found := c.items[key]
+	if !found {
+		c.misses++
+		return nil
+	}
+
+	c.hits++
+	c.lru.MoveToFront(elem)
+	return elem.Value.(*cacheEntry).value
+}
+
+// Put inserts or replaces the embedding stored under key, promoting it to
+// most-recently-used. When the cache grows past maxSize, the least recently
+// used entry is dropped.
+func (c *cache) Put(key string, value []float32) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// Existing key: refresh value and recency; capacity is unchanged.
+	if elem, found := c.items[key]; found {
+		c.lru.MoveToFront(elem)
+		elem.Value.(*cacheEntry).value = value
+		return
+	}
+
+	// New key: insert at the front, then evict if over capacity.
+	c.items[key] = c.lru.PushFront(&cacheEntry{key: key, value: value})
+	if c.lru.Len() > c.maxSize {
+		c.evict()
+	}
+}
+
+// evict drops the least-recently-used entry (the back of the LRU list),
+// if any. Caller must hold c.mu.
+func (c *cache) evict() {
+	oldest := c.lru.Back()
+	if oldest == nil {
+		return
+	}
+	c.lru.Remove(oldest)
+	delete(c.items, oldest.Value.(*cacheEntry).key)
+}
+
+// Size reports the number of entries currently cached.
+func (c *cache) Size() int {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+	return c.lru.Len()
+}
+
+// Clear empties the cache and resets the hit/miss statistics.
+func (c *cache) Clear() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.items = make(map[string]*list.Element)
+	c.lru = list.New()
+	c.hits, c.misses = 0, 0
+}
diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/cache_test.go b/cmd/thv-operator/pkg/optimizer/embeddings/cache_test.go
new file mode 100644
index 0000000000..9b16346056
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/embeddings/cache_test.go
@@ -0,0 +1,172 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package embeddings
+
+import (
+ "testing"
+)
+
+// TestCache_GetPut verifies the basic miss/put/hit cycle, the hit/miss
+// counters, and that a hit returns the exact stored vector.
+func TestCache_GetPut(t *testing.T) {
+	t.Parallel()
+	c := newCache(2)
+
+	// Test cache miss
+	result := c.Get("key1")
+	if result != nil {
+		t.Error("Expected cache miss for non-existent key")
+	}
+	if c.misses != 1 {
+		t.Errorf("Expected 1 miss, got %d", c.misses)
+	}
+
+	// Test cache put and hit
+	embedding := []float32{1.0, 2.0, 3.0}
+	c.Put("key1", embedding)
+
+	result = c.Get("key1")
+	if result == nil {
+		t.Fatal("Expected cache hit for existing key")
+	}
+	if c.hits != 1 {
+		t.Errorf("Expected 1 hit, got %d", c.hits)
+	}
+
+	// Verify embedding values
+	if len(result) != len(embedding) {
+		t.Errorf("Embedding length mismatch: got %d, want %d", len(result), len(embedding))
+	}
+	for i := range embedding {
+		if result[i] != embedding[i] {
+			t.Errorf("Embedding value mismatch at index %d: got %f, want %f", i, result[i], embedding[i])
+		}
+	}
+}
+
+// TestCache_LRUEviction verifies that inserting past capacity evicts the
+// least-recently-used entry and leaves the size at maxSize.
+func TestCache_LRUEviction(t *testing.T) {
+	t.Parallel()
+	c := newCache(2)
+
+	// Add two items (fills cache)
+	c.Put("key1", []float32{1.0})
+	c.Put("key2", []float32{2.0})
+
+	if c.Size() != 2 {
+		t.Errorf("Expected cache size 2, got %d", c.Size())
+	}
+
+	// Add third item (should evict key1)
+	c.Put("key3", []float32{3.0})
+
+	if c.Size() != 2 {
+		t.Errorf("Expected cache size 2 after eviction, got %d", c.Size())
+	}
+
+	// key1 should be evicted (oldest)
+	if result := c.Get("key1"); result != nil {
+		t.Error("key1 should have been evicted")
+	}
+
+	// key2 and key3 should still exist
+	if result := c.Get("key2"); result == nil {
+		t.Error("key2 should still exist")
+	}
+	if result := c.Get("key3"); result == nil {
+		t.Error("key3 should still exist")
+	}
+}
+
+// TestCache_MoveToFrontOnAccess verifies that Get refreshes an entry's
+// recency, so a recently read key survives eviction over an older one.
+func TestCache_MoveToFrontOnAccess(t *testing.T) {
+	t.Parallel()
+	c := newCache(2)
+
+	// Add two items
+	c.Put("key1", []float32{1.0})
+	c.Put("key2", []float32{2.0})
+
+	// Access key1 (moves it to front)
+	c.Get("key1")
+
+	// Add third item (should evict key2, not key1)
+	c.Put("key3", []float32{3.0})
+
+	// key1 should still exist (was accessed recently)
+	if result := c.Get("key1"); result == nil {
+		t.Error("key1 should still exist (was accessed recently)")
+	}
+
+	// key2 should be evicted (was oldest)
+	if result := c.Get("key2"); result != nil {
+		t.Error("key2 should have been evicted")
+	}
+
+	// key3 should exist
+	if result := c.Get("key3"); result == nil {
+		t.Error("key3 should exist")
+	}
+}
+
+// TestCache_UpdateExistingKey verifies that Put on an existing key replaces
+// the value in place without growing the cache.
+func TestCache_UpdateExistingKey(t *testing.T) {
+	t.Parallel()
+	c := newCache(2)
+
+	// Add initial value
+	c.Put("key1", []float32{1.0})
+
+	// Update with new value
+	newEmbedding := []float32{2.0, 3.0}
+	c.Put("key1", newEmbedding)
+
+	// Should get updated value
+	result := c.Get("key1")
+	if result == nil {
+		t.Fatal("Expected cache hit for existing key")
+	}
+
+	if len(result) != len(newEmbedding) {
+		t.Errorf("Embedding length mismatch: got %d, want %d", len(result), len(newEmbedding))
+	}
+
+	// Cache size should still be 1
+	if c.Size() != 1 {
+		t.Errorf("Expected cache size 1, got %d", c.Size())
+	}
+}
+
+// TestCache_Clear verifies that Clear removes all entries and resets the
+// hit/miss statistics.
+func TestCache_Clear(t *testing.T) {
+	t.Parallel()
+	c := newCache(10)
+
+	// Add some items
+	c.Put("key1", []float32{1.0})
+	c.Put("key2", []float32{2.0})
+	c.Put("key3", []float32{3.0})
+
+	// Access some items to generate stats
+	c.Get("key1")
+	c.Get("missing")
+
+	if c.Size() != 3 {
+		t.Errorf("Expected cache size 3, got %d", c.Size())
+	}
+
+	// Clear cache
+	c.Clear()
+
+	if c.Size() != 0 {
+		t.Errorf("Expected cache size 0 after clear, got %d", c.Size())
+	}
+
+	// Stats should be reset
+	if c.hits != 0 {
+		t.Errorf("Expected 0 hits after clear, got %d", c.hits)
+	}
+	if c.misses != 0 {
+		t.Errorf("Expected 0 misses after clear, got %d", c.misses)
+	}
+
+	// Items should be gone
+	if result := c.Get("key1"); result != nil {
+		t.Error("key1 should be gone after clear")
+	}
+}
diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/manager.go b/cmd/thv-operator/pkg/optimizer/embeddings/manager.go
new file mode 100644
index 0000000000..4f29729e3b
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/embeddings/manager.go
@@ -0,0 +1,219 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package embeddings
+
+import (
+ "fmt"
+ "strings"
+ "sync"
+
+ "github.com/stacklok/toolhive/pkg/logger"
+)
+
+const (
+	// DefaultModelAllMiniLM is the default Ollama model name
+	// (Ollama's packaging of all-MiniLM-L6-v2, 384-dimensional vectors).
+	DefaultModelAllMiniLM = "all-minilm"
+	// BackendTypeOllama is the Ollama backend type
+	BackendTypeOllama = "ollama"
+)
+
+// Config holds configuration for the embedding manager.
+// Zero values are filled in by NewManager (backend "ollama", dimension 384,
+// cache size 1000).
+type Config struct {
+	// BackendType specifies which backend to use:
+	// - "ollama": Ollama native API (default)
+	// - "vllm": vLLM OpenAI-compatible API
+	// - "unified": Generic OpenAI-compatible API (works with both)
+	// - "openai": OpenAI-compatible API
+	BackendType string
+
+	// BaseURL is the base URL for the embedding service
+	// - Ollama: http://127.0.0.1:11434 (or http://localhost:11434, will be normalized to 127.0.0.1)
+	// - vLLM: http://localhost:8000
+	BaseURL string
+
+	// Model is the model name to use
+	// - Ollama: "all-minilm" (default), "nomic-embed-text"
+	// - vLLM: "sentence-transformers/all-MiniLM-L6-v2", "intfloat/e5-mistral-7b-instruct"
+	Model string
+
+	// Dimension is the embedding dimension (default 384 for all-MiniLM-L6-v2)
+	Dimension int
+
+	// EnableCache enables caching of embeddings
+	EnableCache bool
+
+	// MaxCacheSize is the maximum number of embeddings to cache (default 1000)
+	MaxCacheSize int
+}
+
+// Backend interface for different embedding implementations
+type Backend interface {
+	// Embed returns the embedding vector for a single text.
+	Embed(text string) ([]float32, error)
+	// EmbedBatch returns one embedding per input text, in input order.
+	EmbedBatch(texts []string) ([][]float32, error)
+	// Dimension reports the length of the vectors this backend produces.
+	Dimension() int
+	// Close releases any resources held by the backend.
+	Close() error
+}
+
+// Manager manages embedding generation using pluggable backends
+// Default backend is all-MiniLM-L6-v2 (same model as codegate)
+type Manager struct {
+	config  *Config
+	backend Backend
+	cache   *cache       // nil unless config.EnableCache; has its own internal lock
+	mu      sync.RWMutex // serializes backend calls and Close
+}
+
+// NewManager creates a new embedding manager.
+//
+// Defaults applied here: dimension 384 (all-MiniLM-L6-v2), cache size 1000,
+// backend type "ollama". Returns an error when the selected backend cannot
+// be initialized or the backend type is unknown.
+func NewManager(config *Config) (*Manager, error) {
+	if config.Dimension == 0 {
+		config.Dimension = 384 // Default dimension for all-MiniLM-L6-v2
+	}
+
+	if config.MaxCacheSize == 0 {
+		config.MaxCacheSize = 1000
+	}
+
+	// Default to Ollama
+	if config.BackendType == "" {
+		config.BackendType = BackendTypeOllama
+	}
+
+	// Initialize backend based on configuration
+	var backend Backend
+	var err error
+
+	switch config.BackendType {
+	case BackendTypeOllama:
+		// Use Ollama native API (requires ollama serve)
+		baseURL := config.BaseURL
+		if baseURL == "" {
+			baseURL = "http://127.0.0.1:11434"
+		} else {
+			// Normalize localhost to 127.0.0.1 to avoid IPv6 resolution issues
+			baseURL = strings.ReplaceAll(baseURL, "localhost", "127.0.0.1")
+		}
+		model := config.Model
+		if model == "" {
+			model = DefaultModelAllMiniLM // Default: all-MiniLM-L6-v2
+		}
+		// (A previous re-check of config.Dimension here was dead code:
+		// Dimension was already defaulted to 384 at the top of this function.)
+		backend, err = NewOllamaBackend(baseURL, model)
+		if err != nil {
+			// Name the model actually in use (not always the default) so the
+			// "ollama pull" hint in the error is actionable.
+			return nil, fmt.Errorf(
+				"failed to initialize Ollama backend: %w (ensure 'ollama serve' is running and 'ollama pull %s' has been executed)",
+				err, model)
+		}
+
+	case "vllm", "unified", "openai":
+		// Use OpenAI-compatible API
+		// vLLM is recommended for production Kubernetes deployments (GPU-accelerated, high-throughput)
+		// Also supports: Ollama v1 API, OpenAI, or any OpenAI-compatible service
+		if config.BaseURL == "" {
+			return nil, fmt.Errorf("BaseURL is required for %s backend", config.BackendType)
+		}
+		if config.Model == "" {
+			return nil, fmt.Errorf("model is required for %s backend", config.BackendType)
+		}
+		backend, err = NewOpenAICompatibleBackend(config.BaseURL, config.Model, config.Dimension)
+		if err != nil {
+			return nil, fmt.Errorf("failed to initialize %s backend: %w", config.BackendType, err)
+		}
+
+	default:
+		return nil, fmt.Errorf("unknown backend type: %s (supported: ollama, vllm, unified, openai)", config.BackendType)
+	}
+
+	m := &Manager{
+		config:  config,
+		backend: backend,
+	}
+
+	if config.EnableCache {
+		m.cache = newCache(config.MaxCacheSize)
+	}
+
+	return m, nil
+}
+
+// GenerateEmbedding generates embeddings for the given texts.
+// Returns a 2D slice where each row is an embedding for the corresponding text.
+// Uses all-MiniLM-L6-v2 by default (same model as codegate).
+// Single-text requests are served from the cache when caching is enabled.
+func (m *Manager) GenerateEmbedding(texts []string) ([][]float32, error) {
+	if len(texts) == 0 {
+		return nil, fmt.Errorf("no texts provided")
+	}
+
+	// Check cache for single text requests (the cache has its own lock).
+	if len(texts) == 1 && m.config.EnableCache && m.cache != nil {
+		if cached := m.cache.Get(texts[0]); cached != nil {
+			logger.Debugf("Cache hit for embedding")
+			return [][]float32{cached}, nil
+		}
+	}
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	// Use backend to generate embeddings
+	embeddings, err := m.backend.EmbedBatch(texts)
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate embeddings: %w", err)
+	}
+
+	// Guard against a misbehaving backend returning a different number of
+	// vectors than texts; indexing embeddings[0] below would otherwise panic.
+	if len(embeddings) != len(texts) {
+		return nil, fmt.Errorf("backend returned %d embeddings for %d texts", len(embeddings), len(texts))
+	}
+
+	// Cache single embeddings
+	if len(texts) == 1 && m.config.EnableCache && m.cache != nil {
+		m.cache.Put(texts[0], embeddings[0])
+	}
+
+	logger.Debugf("Generated %d embeddings (dimension: %d)", len(embeddings), m.backend.Dimension())
+	return embeddings, nil
+}
+
+// GetCacheStats returns cache statistics.
+// The hit/miss counters and size are read under the cache's lock: Get/Put
+// mutate them while holding that lock, so the previous unguarded reads were
+// a data race. Size() is not called here to avoid re-acquiring the lock.
+func (m *Manager) GetCacheStats() map[string]interface{} {
+	if !m.config.EnableCache || m.cache == nil {
+		return map[string]interface{}{
+			"enabled": false,
+		}
+	}
+
+	m.cache.mu.RLock()
+	hits := m.cache.hits
+	misses := m.cache.misses
+	size := m.cache.lru.Len()
+	m.cache.mu.RUnlock()
+
+	return map[string]interface{}{
+		"enabled": true,
+		"hits":    hits,
+		"misses":  misses,
+		"size":    size,
+		"maxsize": m.config.MaxCacheSize,
+	}
+}
+
+// ClearCache drops all cached embeddings; a no-op when caching is disabled.
+func (m *Manager) ClearCache() {
+	if !m.config.EnableCache || m.cache == nil {
+		return
+	}
+	m.cache.Clear()
+	logger.Info("Embedding cache cleared")
+}
+
+// Close releases the underlying backend's resources, if a backend is set.
+func (m *Manager) Close() error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if m.backend == nil {
+		return nil
+	}
+	return m.backend.Close()
+}
+
+// Dimension reports the embedding dimension, preferring the live backend's
+// value and falling back to the configured one when no backend is set.
+func (m *Manager) Dimension() int {
+	if m.backend == nil {
+		return m.config.Dimension
+	}
+	return m.backend.Dimension()
+}
diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/manager_test_coverage.go b/cmd/thv-operator/pkg/optimizer/embeddings/manager_test_coverage.go
new file mode 100644
index 0000000000..529d65ec4c
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/embeddings/manager_test_coverage.go
@@ -0,0 +1,158 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package embeddings
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestManager_GetCacheStats tests cache statistics
+// NOTE(review): this file is named manager_test_coverage.go, not *_test.go,
+// so these Test* functions compile into the production package and pull
+// testify/testing in as regular dependencies — confirm the name is intended.
+func TestManager_GetCacheStats(t *testing.T) {
+	t.Parallel()
+
+	config := &Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+		EnableCache: true,
+		MaxCacheSize: 100,
+	}
+
+	manager, err := NewManager(config)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	defer func() { _ = manager.Close() }()
+
+	stats := manager.GetCacheStats()
+	require.NotNil(t, stats)
+	assert.True(t, stats["enabled"].(bool))
+	assert.Contains(t, stats, "hits")
+	assert.Contains(t, stats, "misses")
+	assert.Contains(t, stats, "size")
+	assert.Contains(t, stats, "maxsize")
+}
+
+// TestManager_GetCacheStats_Disabled tests cache statistics when cache is disabled
+// (only the "enabled" key is meaningful in that case).
+func TestManager_GetCacheStats_Disabled(t *testing.T) {
+	t.Parallel()
+
+	config := &Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+		EnableCache: false,
+	}
+
+	manager, err := NewManager(config)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	defer func() { _ = manager.Close() }()
+
+	stats := manager.GetCacheStats()
+	require.NotNil(t, stats)
+	assert.False(t, stats["enabled"].(bool))
+}
+
+// TestManager_ClearCache tests cache clearing
+// (smoke test: clearing an enabled cache, including repeatedly, must not panic).
+func TestManager_ClearCache(t *testing.T) {
+	t.Parallel()
+
+	config := &Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+		EnableCache: true,
+		MaxCacheSize: 100,
+	}
+
+	manager, err := NewManager(config)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	defer func() { _ = manager.Close() }()
+
+	// Clear cache should not panic
+	manager.ClearCache()
+
+	// Multiple clears should be safe
+	manager.ClearCache()
+}
+
+// TestManager_ClearCache_Disabled tests cache clearing when cache is disabled
+// (ClearCache must be a safe no-op when no cache was created).
+func TestManager_ClearCache_Disabled(t *testing.T) {
+	t.Parallel()
+
+	config := &Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+		EnableCache: false,
+	}
+
+	manager, err := NewManager(config)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	defer func() { _ = manager.Close() }()
+
+	// Clear cache should not panic even when disabled
+	manager.ClearCache()
+}
+
+// TestManager_Dimension tests dimension accessor
+// (an explicitly configured dimension should be reported back).
+func TestManager_Dimension(t *testing.T) {
+	t.Parallel()
+
+	config := &Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+	}
+
+	manager, err := NewManager(config)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	defer func() { _ = manager.Close() }()
+
+	dimension := manager.Dimension()
+	assert.Equal(t, 384, dimension)
+}
+
+// TestManager_Dimension_Default tests default dimension
+// (NewManager defaults an unset Dimension to 384 for all-MiniLM-L6-v2).
+func TestManager_Dimension_Default(t *testing.T) {
+	t.Parallel()
+
+	config := &Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		// Dimension not set, should default to 384
+	}
+
+	manager, err := NewManager(config)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	defer func() { _ = manager.Close() }()
+
+	dimension := manager.Dimension()
+	assert.Equal(t, 384, dimension)
+}
diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/ollama.go b/cmd/thv-operator/pkg/optimizer/embeddings/ollama.go
new file mode 100644
index 0000000000..6cb6e1f8a2
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/embeddings/ollama.go
@@ -0,0 +1,148 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package embeddings
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+
+ "github.com/stacklok/toolhive/pkg/logger"
+)
+
+// OllamaBackend implements the Backend interface using Ollama
+// This provides local embeddings without remote API calls
+// Ollama must be running locally (ollama serve)
+type OllamaBackend struct {
+    baseURL   string       // Ollama server root; "localhost" is rewritten to 127.0.0.1 by the constructor
+    model     string       // embedding model name, e.g. "all-minilm"
+    dimension int          // vector length reported by Dimension(); derived from the model name
+    client    *http.Client // NOTE(review): no Timeout configured — requests can block indefinitely; confirm this is acceptable
+}
+
+// ollamaEmbedRequest is the JSON body for POST /api/embeddings.
+type ollamaEmbedRequest struct {
+    Model  string `json:"model"`
+    Prompt string `json:"prompt"`
+}
+
+// ollamaEmbedResponse is the subset of the Ollama response that is consumed.
+// Ollama returns float64 values; Embed narrows them to float32.
+type ollamaEmbedResponse struct {
+    Embedding []float64 `json:"embedding"`
+}
+
+// normalizeLocalhostURL rewrites a URL whose host component is exactly
+// "localhost" to use 127.0.0.1, forcing an IPv4 connection. This prevents
+// connection refused errors when Ollama only listens on IPv4 while
+// "localhost" resolves to ::1 first.
+//
+// Unlike a blanket string replacement, only the host is touched, so hosts
+// such as "mylocalhost.example.com" or "localhost" appearing in a path
+// segment or query string are left alone. URLs with other schemes (or no
+// scheme) are returned unchanged.
+func normalizeLocalhostURL(url string) string {
+    for _, scheme := range []string{"http://", "https://"} {
+        if !strings.HasPrefix(url, scheme) {
+            continue
+        }
+        rest := url[len(scheme):]
+        // The host ends at ':' (port), '/' (path), or end of string.
+        if rest == "localhost" || strings.HasPrefix(rest, "localhost:") || strings.HasPrefix(rest, "localhost/") {
+            return scheme + "127.0.0.1" + strings.TrimPrefix(rest, "localhost")
+        }
+    }
+    return url
+}
+
+// NewOllamaBackend creates a new Ollama backend
+// Requires Ollama to be running locally: ollama serve
+// Default model: all-minilm (all-MiniLM-L6-v2, 384 dimensions)
+//
+// The constructor probes the base URL with a plain GET so a missing Ollama
+// daemon is reported at construction time rather than on the first Embed call.
+func NewOllamaBackend(baseURL, model string) (*OllamaBackend, error) {
+    if baseURL == "" {
+        baseURL = "http://127.0.0.1:11434"
+    } else {
+        // Normalize localhost to 127.0.0.1 to avoid IPv6 resolution issues
+        baseURL = normalizeLocalhostURL(baseURL)
+    }
+    if model == "" {
+        model = "all-minilm" // Default embedding model (all-MiniLM-L6-v2)
+    }
+
+    logger.Infof("Initializing Ollama backend (model: %s, url: %s)", model, baseURL)
+
+    // Determine dimension based on model
+    // Only nomic-embed-text is special-cased; any other model name falls
+    // back to 384, which may not match that model's real output size.
+    dimension := 384 // Default for all-minilm
+    if model == "nomic-embed-text" {
+        dimension = 768
+    }
+
+    backend := &OllamaBackend{
+        baseURL:   baseURL,
+        model:     model,
+        dimension: dimension,
+        client:    &http.Client{}, // NOTE(review): no Timeout set — confirm hung requests are acceptable
+    }
+
+    // Test connection
+    // NOTE(review): any HTTP status counts as "connected"; only transport
+    // errors are rejected here.
+    resp, err := backend.client.Get(baseURL)
+    if err != nil {
+        return nil, fmt.Errorf("failed to connect to Ollama at %s: %w (is 'ollama serve' running?)", baseURL, err)
+    }
+    _ = resp.Body.Close()
+
+    logger.Info("Successfully connected to Ollama")
+    return backend, nil
+}
+
+// Embed generates an embedding for a single text by calling the Ollama
+// /api/embeddings endpoint and narrowing the float64 result to float32.
+func (o *OllamaBackend) Embed(text string) ([]float32, error) {
+    payload, err := json.Marshal(ollamaEmbedRequest{Model: o.model, Prompt: text})
+    if err != nil {
+        return nil, fmt.Errorf("failed to marshal request: %w", err)
+    }
+
+    resp, err := o.client.Post(o.baseURL+"/api/embeddings", "application/json", bytes.NewBuffer(payload))
+    if err != nil {
+        return nil, fmt.Errorf("failed to call Ollama API: %w", err)
+    }
+    defer func() { _ = resp.Body.Close() }()
+
+    if resp.StatusCode != http.StatusOK {
+        body, _ := io.ReadAll(resp.Body)
+        return nil, fmt.Errorf("ollama API returned status %d: %s", resp.StatusCode, string(body))
+    }
+
+    var parsed ollamaEmbedResponse
+    if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
+        return nil, fmt.Errorf("failed to decode response: %w", err)
+    }
+
+    // Ollama emits float64 values; convert to the float32 interface type.
+    out := make([]float32, len(parsed.Embedding))
+    for i := range parsed.Embedding {
+        out[i] = float32(parsed.Embedding[i])
+    }
+    return out, nil
+}
+
+// EmbedBatch generates embeddings for multiple texts by embedding each one
+// sequentially; it fails fast on the first error.
+func (o *OllamaBackend) EmbedBatch(texts []string) ([][]float32, error) {
+    results := make([][]float32, len(texts))
+    for i := range texts {
+        vec, err := o.Embed(texts[i])
+        if err != nil {
+            return nil, fmt.Errorf("failed to embed text %d: %w", i, err)
+        }
+        results[i] = vec
+    }
+    return results, nil
+}
+
+// Dimension returns the embedding dimension
+// (derived from the model name at construction; responses are not re-checked).
+func (o *OllamaBackend) Dimension() int {
+    return o.dimension
+}
+
+// Close releases any resources
+func (*OllamaBackend) Close() error {
+    // HTTP client doesn't need explicit cleanup
+    return nil
+}
diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/ollama_test.go b/cmd/thv-operator/pkg/optimizer/embeddings/ollama_test.go
new file mode 100644
index 0000000000..16d7793e85
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/embeddings/ollama_test.go
@@ -0,0 +1,69 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package embeddings
+
+import (
+ "testing"
+)
+
+// TestOllamaBackend_ConnectionFailure checks that constructing a backend
+// against an unusable URL surfaces an error instead of succeeding.
+func TestOllamaBackend_ConnectionFailure(t *testing.T) {
+    t.Parallel()
+
+    // Port 99999 is out of range, so the constructor's connectivity probe
+    // must fail and the error must be propagated.
+    if _, err := NewOllamaBackend("http://localhost:99999", "all-minilm"); err == nil {
+        t.Error("Expected error when connecting to invalid Ollama URL")
+    }
+}
+
+// TestManagerWithOllama exercises the manager end-to-end against a live
+// Ollama instance, skipping when the daemon or the model is unavailable.
+func TestManagerWithOllama(t *testing.T) {
+    t.Parallel()
+    // Test that Manager works with Ollama when available
+    config := &Config{
+        BackendType: BackendTypeOllama,
+        BaseURL:     "http://localhost:11434",
+        Model:       DefaultModelAllMiniLM,
+        // all-minilm produces 384-dimensional vectors; the assertions below
+        // expect 384, so the config must agree.
+        Dimension:    384,
+        EnableCache:  true,
+        MaxCacheSize: 100,
+    }
+
+    manager, err := NewManager(config)
+    if err != nil {
+        t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err)
+        return
+    }
+    defer func() { _ = manager.Close() }()
+
+    // Test single embedding
+    embeddings, err := manager.GenerateEmbedding([]string{"test text"})
+    if err != nil {
+        // Model might not be pulled - skip gracefully
+        t.Skipf("Skipping test: Failed to generate embedding. Error: %v. Run 'ollama pull all-minilm'", err)
+        return
+    }
+
+    if len(embeddings) != 1 {
+        t.Errorf("Expected 1 embedding, got %d", len(embeddings))
+    }
+
+    // Ollama all-minilm uses 384 dimensions
+    if len(embeddings[0]) != 384 {
+        t.Errorf("Expected dimension 384, got %d", len(embeddings[0]))
+    }
+
+    // Test batch embeddings
+    texts := []string{"text 1", "text 2", "text 3"}
+    embeddings, err = manager.GenerateEmbedding(texts)
+    if err != nil {
+        // Model might not be pulled - skip gracefully
+        t.Skipf("Skipping test: Failed to generate batch embeddings. Error: %v. Run 'ollama pull all-minilm'", err)
+        return
+    }
+
+    if len(embeddings) != 3 {
+        t.Errorf("Expected 3 embeddings, got %d", len(embeddings))
+    }
+}
diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible.go b/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible.go
new file mode 100644
index 0000000000..c98adba54a
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible.go
@@ -0,0 +1,152 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package embeddings
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/stacklok/toolhive/pkg/logger"
+)
+
+// OpenAICompatibleBackend implements the Backend interface for OpenAI-compatible APIs.
+//
+// Supported Services:
+// - vLLM: Recommended for production Kubernetes deployments
+// - High-throughput GPU-accelerated inference
+// - PagedAttention for efficient GPU memory utilization
+// - Superior scalability for multi-user environments
+// - Ollama: Good for local development (via /v1/embeddings endpoint)
+// - OpenAI: For cloud-based embeddings
+// - Any OpenAI-compatible embedding service
+//
+// For production deployments, vLLM is strongly recommended due to its performance
+// characteristics and Kubernetes-native design.
+type OpenAICompatibleBackend struct {
+    baseURL   string       // service root; Embed calls baseURL + "/v1/embeddings"
+    model     string       // model identifier sent in every request
+    dimension int          // dimension reported by Dimension(); not validated against responses
+    client    *http.Client // NOTE(review): no Timeout configured — confirm hung requests are acceptable
+}
+
+// openaiEmbedRequest is the JSON body for POST /v1/embeddings.
+type openaiEmbedRequest struct {
+    Model string `json:"model"`
+    Input string `json:"input"` // OpenAI standard uses "input"
+}
+
+// openaiEmbedResponse mirrors the OpenAI embeddings response envelope;
+// only Data[0].Embedding is consumed by Embed.
+type openaiEmbedResponse struct {
+    Object string `json:"object"`
+    Data   []struct {
+        Object    string    `json:"object"`
+        Embedding []float32 `json:"embedding"`
+        Index     int       `json:"index"`
+    } `json:"data"`
+    Model string `json:"model"`
+}
+
+// NewOpenAICompatibleBackend creates a new OpenAI-compatible backend.
+//
+// Examples:
+// - vLLM: NewOpenAICompatibleBackend("http://vllm-service:8000", "sentence-transformers/all-MiniLM-L6-v2", 384)
+// - Ollama: NewOpenAICompatibleBackend("http://localhost:11434", "nomic-embed-text", 768)
+// - OpenAI: NewOpenAICompatibleBackend("https://api.openai.com", "text-embedding-3-small", 1536)
+//
+// baseURL and model are required; a zero dimension falls back to 384. The
+// constructor issues a plain GET against baseURL so an unreachable service
+// fails fast at construction time.
+func NewOpenAICompatibleBackend(baseURL, model string, dimension int) (*OpenAICompatibleBackend, error) {
+    if baseURL == "" {
+        return nil, fmt.Errorf("baseURL is required for OpenAI-compatible backend")
+    }
+    if model == "" {
+        return nil, fmt.Errorf("model is required for OpenAI-compatible backend")
+    }
+    if dimension == 0 {
+        dimension = 384 // Default dimension
+    }
+
+    logger.Infof("Initializing OpenAI-compatible backend (model: %s, url: %s)", model, baseURL)
+
+    backend := &OpenAICompatibleBackend{
+        baseURL:   baseURL,
+        model:     model,
+        dimension: dimension,
+        client:    &http.Client{}, // NOTE(review): no Timeout set — confirm hung requests are acceptable
+    }
+
+    // Test connection
+    // NOTE(review): any HTTP status counts as "connected"; only transport
+    // errors are rejected here.
+    resp, err := backend.client.Get(baseURL)
+    if err != nil {
+        return nil, fmt.Errorf("failed to connect to %s: %w", baseURL, err)
+    }
+    _ = resp.Body.Close()
+
+    logger.Info("Successfully connected to OpenAI-compatible service")
+    return backend, nil
+}
+
+// Embed generates an embedding for a single text using OpenAI-compatible API
+// (POST /v1/embeddings) and returns the first embedding from the response.
+func (o *OpenAICompatibleBackend) Embed(text string) ([]float32, error) {
+    payload, err := json.Marshal(openaiEmbedRequest{Model: o.model, Input: text})
+    if err != nil {
+        return nil, fmt.Errorf("failed to marshal request: %w", err)
+    }
+
+    // Use standard OpenAI v1 endpoint
+    resp, err := o.client.Post(o.baseURL+"/v1/embeddings", "application/json", bytes.NewBuffer(payload))
+    if err != nil {
+        return nil, fmt.Errorf("failed to call embeddings API: %w", err)
+    }
+    defer func() { _ = resp.Body.Close() }()
+
+    if resp.StatusCode != http.StatusOK {
+        body, _ := io.ReadAll(resp.Body)
+        return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
+    }
+
+    var parsed openaiEmbedResponse
+    if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
+        return nil, fmt.Errorf("failed to decode response: %w", err)
+    }
+    if len(parsed.Data) == 0 {
+        return nil, fmt.Errorf("no embeddings in response")
+    }
+    return parsed.Data[0].Embedding, nil
+}
+
+// EmbedBatch generates embeddings for multiple texts by embedding each one
+// sequentially; it fails fast on the first error.
+func (o *OpenAICompatibleBackend) EmbedBatch(texts []string) ([][]float32, error) {
+    results := make([][]float32, len(texts))
+    for i := range texts {
+        vec, err := o.Embed(texts[i])
+        if err != nil {
+            return nil, fmt.Errorf("failed to embed text %d: %w", i, err)
+        }
+        results[i] = vec
+    }
+    return results, nil
+}
+
+// Dimension returns the embedding dimension
+// (the value supplied at construction; responses are not re-checked).
+func (o *OpenAICompatibleBackend) Dimension() int {
+    return o.dimension
+}
+
+// Close releases any resources
+func (*OpenAICompatibleBackend) Close() error {
+    // HTTP client doesn't need explicit cleanup
+    return nil
+}
diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible_test.go b/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible_test.go
new file mode 100644
index 0000000000..f9a686e953
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible_test.go
@@ -0,0 +1,226 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package embeddings
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+const testEmbeddingsEndpoint = "/v1/embeddings"
+
+// TestOpenAICompatibleBackend drives the backend against a mock
+// OpenAI-compatible HTTP server and checks single and batch embedding.
+func TestOpenAICompatibleBackend(t *testing.T) {
+    t.Parallel()
+    // Create a test server that mimics OpenAI-compatible API
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        if r.URL.Path == testEmbeddingsEndpoint {
+            var req openaiEmbedRequest
+            if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+                // t.Fatalf must not be called from the handler goroutine
+                // (FailNow/Goexit only work on the test goroutine), so
+                // report via t.Errorf and answer with an HTTP error.
+                t.Errorf("Failed to decode request: %v", err)
+                http.Error(w, "bad request", http.StatusBadRequest)
+                return
+            }
+
+            // Return a mock embedding response
+            resp := openaiEmbedResponse{
+                Object: "list",
+                Data: []struct {
+                    Object    string    `json:"object"`
+                    Embedding []float32 `json:"embedding"`
+                    Index     int       `json:"index"`
+                }{
+                    {
+                        Object:    "embedding",
+                        Embedding: make([]float32, 384),
+                        Index:     0,
+                    },
+                },
+                Model: req.Model,
+            }
+
+            // Fill with test data
+            for i := range resp.Data[0].Embedding {
+                resp.Data[0].Embedding[i] = float32(i) / 384.0
+            }
+
+            w.Header().Set("Content-Type", "application/json")
+            if err := json.NewEncoder(w).Encode(resp); err != nil {
+                t.Errorf("Failed to encode response: %v", err)
+            }
+            return
+        }
+
+        // Health check endpoint
+        w.WriteHeader(http.StatusOK)
+    }))
+    defer server.Close()
+
+    // Test backend creation
+    backend, err := NewOpenAICompatibleBackend(server.URL, "test-model", 384)
+    if err != nil {
+        t.Fatalf("Failed to create backend: %v", err)
+    }
+    defer func() { _ = backend.Close() }()
+
+    // Test embedding generation
+    embedding, err := backend.Embed("test text")
+    if err != nil {
+        t.Fatalf("Failed to generate embedding: %v", err)
+    }
+
+    if len(embedding) != 384 {
+        t.Errorf("Expected embedding dimension 384, got %d", len(embedding))
+    }
+
+    // Test batch embedding
+    texts := []string{"text1", "text2", "text3"}
+    embeddings, err := backend.EmbedBatch(texts)
+    if err != nil {
+        t.Fatalf("Failed to generate batch embeddings: %v", err)
+    }
+
+    if len(embeddings) != len(texts) {
+        t.Errorf("Expected %d embeddings, got %d", len(texts), len(embeddings))
+    }
+}
+
+// TestOpenAICompatibleBackendErrors exercises constructor validation for
+// the required baseURL and model arguments.
+func TestOpenAICompatibleBackendErrors(t *testing.T) {
+    t.Parallel()
+
+    cases := []struct {
+        name    string
+        baseURL string
+        model   string
+    }{
+        {"missing baseURL", "", "model"},
+        {"missing model", "http://localhost:8000", ""},
+    }
+    for _, tc := range cases {
+        if _, err := NewOpenAICompatibleBackend(tc.baseURL, tc.model, 384); err == nil {
+            t.Errorf("Expected error for %s", tc.name)
+        }
+    }
+}
+
+// TestManagerWithVLLM verifies the manager against a mock vLLM-style
+// OpenAI-compatible endpoint.
+func TestManagerWithVLLM(t *testing.T) {
+    t.Parallel()
+    // Create a test server
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        if r.URL.Path == testEmbeddingsEndpoint {
+            resp := openaiEmbedResponse{
+                Object: "list",
+                Data: []struct {
+                    Object    string    `json:"object"`
+                    Embedding []float32 `json:"embedding"`
+                    Index     int       `json:"index"`
+                }{
+                    {
+                        Object:    "embedding",
+                        Embedding: make([]float32, 384),
+                        Index:     0,
+                    },
+                },
+            }
+            w.Header().Set("Content-Type", "application/json")
+            // Report (rather than silently drop) encoding failures; t.Fatalf
+            // is not safe on the handler goroutine.
+            if err := json.NewEncoder(w).Encode(resp); err != nil {
+                t.Errorf("Failed to encode response: %v", err)
+            }
+            return
+        }
+        w.WriteHeader(http.StatusOK)
+    }))
+    defer server.Close()
+
+    // Test manager with vLLM backend
+    config := &Config{
+        BackendType: "vllm",
+        BaseURL:     server.URL,
+        Model:       "sentence-transformers/all-MiniLM-L6-v2",
+        Dimension:   384,
+        EnableCache: true,
+    }
+
+    manager, err := NewManager(config)
+    if err != nil {
+        t.Fatalf("Failed to create manager: %v", err)
+    }
+    defer func() { _ = manager.Close() }()
+
+    // Test embedding generation
+    embeddings, err := manager.GenerateEmbedding([]string{"test"})
+    if err != nil {
+        t.Fatalf("Failed to generate embeddings: %v", err)
+    }
+
+    if len(embeddings) != 1 {
+        t.Errorf("Expected 1 embedding, got %d", len(embeddings))
+    }
+    if len(embeddings[0]) != 384 {
+        t.Errorf("Expected dimension 384, got %d", len(embeddings[0]))
+    }
+}
+
+// TestManagerWithUnified verifies the manager with the "unified" backend
+// type against a mock 768-dimension endpoint, with caching disabled.
+func TestManagerWithUnified(t *testing.T) {
+    t.Parallel()
+    // Create a test server
+    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        if r.URL.Path == testEmbeddingsEndpoint {
+            resp := openaiEmbedResponse{
+                Object: "list",
+                Data: []struct {
+                    Object    string    `json:"object"`
+                    Embedding []float32 `json:"embedding"`
+                    Index     int       `json:"index"`
+                }{
+                    {
+                        Object:    "embedding",
+                        Embedding: make([]float32, 768),
+                        Index:     0,
+                    },
+                },
+            }
+            w.Header().Set("Content-Type", "application/json")
+            // Report (rather than silently drop) encoding failures; t.Fatalf
+            // is not safe on the handler goroutine.
+            if err := json.NewEncoder(w).Encode(resp); err != nil {
+                t.Errorf("Failed to encode response: %v", err)
+            }
+            return
+        }
+        w.WriteHeader(http.StatusOK)
+    }))
+    defer server.Close()
+
+    // Test manager with unified backend
+    config := &Config{
+        BackendType: "unified",
+        BaseURL:     server.URL,
+        Model:       "nomic-embed-text",
+        Dimension:   768,
+        EnableCache: false,
+    }
+
+    manager, err := NewManager(config)
+    if err != nil {
+        t.Fatalf("Failed to create manager: %v", err)
+    }
+    defer func() { _ = manager.Close() }()
+
+    // Test embedding generation
+    embeddings, err := manager.GenerateEmbedding([]string{"test"})
+    if err != nil {
+        t.Fatalf("Failed to generate embeddings: %v", err)
+    }
+
+    if len(embeddings) != 1 {
+        t.Errorf("Expected 1 embedding, got %d", len(embeddings))
+    }
+}
+
+// TestManagerFallbackBehavior documents that there is no fallback: a vLLM
+// backend pointing at an unreachable host must fail at construction time.
+func TestManagerFallbackBehavior(t *testing.T) {
+    t.Parallel()
+
+    cfg := &Config{
+        BackendType: "vllm",
+        BaseURL:     "http://invalid-host-that-does-not-exist:9999",
+        Model:       "test-model",
+        Dimension:   384,
+    }
+
+    // The test passes when an error is returned (no fallback behavior).
+    if _, err := NewManager(cfg); err == nil {
+        t.Error("Expected error when creating manager with invalid backend URL")
+    }
+}
diff --git a/cmd/thv-operator/pkg/optimizer/ingestion/errors.go b/cmd/thv-operator/pkg/optimizer/ingestion/errors.go
new file mode 100644
index 0000000000..93e8eab31c
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/ingestion/errors.go
@@ -0,0 +1,24 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package ingestion provides services for ingesting MCP tools into the database.
+package ingestion
+
+import "errors"
+
+// Sentinel errors for the ingestion package; compare with errors.Is.
+var (
+    // ErrIngestionFailed is returned when ingestion fails
+    ErrIngestionFailed = errors.New("ingestion failed")
+
+    // ErrBackendRetrievalFailed is returned when backend retrieval fails
+    ErrBackendRetrievalFailed = errors.New("backend retrieval failed")
+
+    // ErrToolHiveUnavailable is returned when ToolHive is unavailable
+    ErrToolHiveUnavailable = errors.New("ToolHive unavailable")
+
+    // ErrBackendStatusNil is returned when backend status is nil
+    ErrBackendStatusNil = errors.New("backend status cannot be nil")
+
+    // ErrInvalidRuntimeMode is returned for invalid runtime mode
+    ErrInvalidRuntimeMode = errors.New("invalid runtime mode: must be 'docker' or 'k8s'")
+)
diff --git a/cmd/thv-operator/pkg/optimizer/ingestion/service.go b/cmd/thv-operator/pkg/optimizer/ingestion/service.go
new file mode 100644
index 0000000000..0b78423e12
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/ingestion/service.go
@@ -0,0 +1,346 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package ingestion
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/mark3labs/mcp-go/mcp"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/codes"
+ "go.opentelemetry.io/otel/trace"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/tokens"
+ "github.com/stacklok/toolhive/pkg/logger"
+)
+
+// Config holds configuration for the ingestion service
+type Config struct {
+    // Database configuration
+    DBConfig *db.Config
+
+    // Embedding configuration
+    EmbeddingConfig *embeddings.Config
+
+    // MCP timeout in seconds (NewService defaults this to 30 when zero)
+    MCPTimeout int
+
+    // Workloads to skip during ingestion
+    // (NewService defaults this to {"inspector", "mcp-optimizer"} when empty)
+    SkippedWorkloads []string
+
+    // Runtime mode: "docker" or "k8s"
+    RuntimeMode string
+
+    // Kubernetes configuration (used when RuntimeMode is "k8s")
+    K8sAPIServerURL  string
+    K8sNamespace     string
+    K8sAllNamespaces bool
+}
+
+// Service handles ingestion of MCP backends and their tools
+type Service struct {
+    config           *Config
+    database         *db.DB               // chromem-go backed optimizer database
+    embeddingManager *embeddings.Manager  // generates embeddings for servers and tools
+    tokenCounter     *tokens.Counter      // counts tokens per tool for budgeting
+    backendServerOps *db.BackendServerOps // server-level database operations
+    backendToolOps   *db.BackendToolOps   // tool-level database operations
+    tracer           trace.Tracer
+
+    // Embedding time tracking
+    embeddingTimeMu    sync.Mutex    // guards totalEmbeddingTime (embedding calls may be concurrent)
+    totalEmbeddingTime time.Duration // cumulative wall-clock time spent generating embeddings
+}
+
+// NewService creates a new ingestion service
+//
+// It wires together the database, the embedding manager and the token
+// counter, and installs a chromem-go embedding function that records a
+// tracing span and cumulative latency for every embedding call.
+//
+// NOTE: the optimizer database is wiped on every startup (see below), so
+// all backends must be re-ingested after a restart.
+func NewService(config *Config) (*Service, error) {
+    // Set defaults
+    if config.MCPTimeout == 0 {
+        config.MCPTimeout = 30
+    }
+    if len(config.SkippedWorkloads) == 0 {
+        config.SkippedWorkloads = []string{"inspector", "mcp-optimizer"}
+    }
+
+    // Initialize database
+    database, err := db.NewDB(config.DBConfig)
+    if err != nil {
+        return nil, fmt.Errorf("failed to initialize database: %w", err)
+    }
+
+    // Clear database on startup to ensure fresh embeddings
+    // This is important when the embedding model changes or for consistency
+    database.Reset()
+    logger.Info("Cleared optimizer database on startup")
+
+    // Initialize embedding manager
+    embeddingManager, err := embeddings.NewManager(config.EmbeddingConfig)
+    if err != nil {
+        // Avoid leaking the already-open database on partial construction.
+        _ = database.Close()
+        return nil, fmt.Errorf("failed to initialize embedding manager: %w", err)
+    }
+
+    // Initialize token counter
+    tokenCounter := tokens.NewCounter()
+
+    // Initialize tracer
+    tracer := otel.Tracer("github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/ingestion")
+
+    svc := &Service{
+        config:             config,
+        database:           database,
+        embeddingManager:   embeddingManager,
+        tokenCounter:       tokenCounter,
+        tracer:             tracer,
+        totalEmbeddingTime: 0,
+    }
+
+    // Create chromem-go embeddingFunc from our embedding manager with tracing.
+    // The closure captures svc so it can accumulate per-call latency.
+    embeddingFunc := func(ctx context.Context, text string) ([]float32, error) {
+        // Create a span for embedding calculation
+        _, span := svc.tracer.Start(ctx, "optimizer.ingestion.calculate_embedding",
+            trace.WithAttributes(
+                attribute.String("operation", "embedding_calculation"),
+            ))
+        defer span.End()
+
+        start := time.Now()
+
+        // Our manager takes a slice, so wrap the single text
+        embeddingsResult, err := embeddingManager.GenerateEmbedding([]string{text})
+        if err != nil {
+            span.RecordError(err)
+            span.SetStatus(codes.Error, err.Error())
+            return nil, err
+        }
+        if len(embeddingsResult) == 0 {
+            err := fmt.Errorf("no embeddings generated")
+            span.RecordError(err)
+            span.SetStatus(codes.Error, err.Error())
+            return nil, err
+        }
+
+        // Track embedding time
+        duration := time.Since(start)
+        svc.embeddingTimeMu.Lock()
+        svc.totalEmbeddingTime += duration
+        svc.embeddingTimeMu.Unlock()
+
+        span.SetAttributes(
+            attribute.Int64("embedding.duration_ms", duration.Milliseconds()),
+        )
+
+        return embeddingsResult[0], nil
+    }
+
+    svc.backendServerOps = db.NewBackendServerOps(database, embeddingFunc)
+    svc.backendToolOps = db.NewBackendToolOps(database, embeddingFunc)
+
+    logger.Info("Ingestion service initialized for event-driven ingestion (chromem-go)")
+    return svc, nil
+}
+
+// IngestServer ingests a single MCP server and its tools into the optimizer database.
+// This is called by vMCP during session registration for each backend server.
+//
+// Parameters:
+//   - serverID: Unique identifier for the backend server
+//   - serverName: Human-readable server name
+//   - description: Optional server description
+//   - tools: List of tools available from this server
+//
+// This method will:
+// 1. Create or update the backend server record (simplified metadata only)
+// 2. Generate embeddings for server and tools
+// 3. Count tokens for each tool
+// 4. Store everything in the database for semantic search
+//
+// Note: URL, transport, status are NOT stored - vMCP manages backend lifecycle
+func (s *Service) IngestServer(
+    ctx context.Context,
+    serverID string,
+    serverName string,
+    description *string,
+    tools []mcp.Tool,
+) error {
+    // Create a span for the entire ingestion operation
+    ctx, span := s.tracer.Start(ctx, "optimizer.ingestion.ingest_server",
+        trace.WithAttributes(
+            attribute.String("server.id", serverID),
+            attribute.String("server.name", serverName),
+            attribute.Int("tools.count", len(tools)),
+        ))
+    defer span.End()
+
+    start := time.Now()
+    logger.Infof("Ingesting server: %s (%d tools) [serverID=%s]", serverName, len(tools), serverID)
+
+    // Create backend server record (simplified - vMCP manages lifecycle)
+    // chromem-go will generate embeddings automatically from the content
+    backendServer := &models.BackendServer{
+        ID:          serverID,
+        Name:        serverName,
+        Description: description,
+        Group:       "default", // TODO: Pass group from vMCP if needed
+        CreatedAt:   time.Now(),
+        LastUpdated: time.Now(),
+    }
+
+    // Create or update server (chromem-go handles embeddings)
+    if err := s.backendServerOps.Update(ctx, backendServer); err != nil {
+        span.RecordError(err)
+        span.SetStatus(codes.Error, err.Error())
+        return fmt.Errorf("failed to create/update server %s: %w", serverName, err)
+    }
+    logger.Debugf("Created/updated server: %s", serverName)
+
+    // Sync tools for this server (delete-then-recreate; see syncBackendTools)
+    toolCount, err := s.syncBackendTools(ctx, serverID, serverName, tools)
+    if err != nil {
+        span.RecordError(err)
+        span.SetStatus(codes.Error, err.Error())
+        return fmt.Errorf("failed to sync tools for %s: %w", serverName, err)
+    }
+
+    duration := time.Since(start)
+    span.SetAttributes(
+        attribute.Int64("ingestion.duration_ms", duration.Milliseconds()),
+        attribute.Int("tools.ingested", toolCount),
+    )
+
+    logger.Infow("Successfully ingested server",
+        "server_name", serverName,
+        "server_id", serverID,
+        "tools_count", toolCount,
+        "duration_ms", duration.Milliseconds())
+    return nil
+}
+
+// syncBackendTools synchronizes tools for a backend server.
+//
+// Strategy: all existing tools for the server are deleted, then every tool
+// in the incoming list is re-created with a fresh UUID (embeddings are
+// regenerated by chromem-go on create). Returns the number of tools stored.
+// NOTE(review): delete-then-recreate is not atomic — a failure mid-way
+// leaves the server with a partial tool set; confirm this is acceptable.
+func (s *Service) syncBackendTools(ctx context.Context, serverID string, serverName string, tools []mcp.Tool) (int, error) {
+    // Create a span for tool synchronization
+    ctx, span := s.tracer.Start(ctx, "optimizer.ingestion.sync_backend_tools",
+        trace.WithAttributes(
+            attribute.String("server.id", serverID),
+            attribute.String("server.name", serverName),
+            attribute.Int("tools.count", len(tools)),
+        ))
+    defer span.End()
+
+    logger.Debugf("syncBackendTools: server=%s, serverID=%s, tool_count=%d", serverName, serverID, len(tools))
+
+    // Delete existing tools
+    if err := s.backendToolOps.DeleteByServer(ctx, serverID); err != nil {
+        span.RecordError(err)
+        span.SetStatus(codes.Error, err.Error())
+        return 0, fmt.Errorf("failed to delete existing tools: %w", err)
+    }
+
+    if len(tools) == 0 {
+        return 0, nil
+    }
+
+    // Create tool records (chromem-go will generate embeddings automatically)
+    for _, tool := range tools {
+        // Extract description for embedding; a per-iteration copy so the
+        // pointer stored below stays stable.
+        description := tool.Description
+
+        // Convert InputSchema to JSON
+        schemaJSON, err := json.Marshal(tool.InputSchema)
+        if err != nil {
+            span.RecordError(err)
+            span.SetStatus(codes.Error, err.Error())
+            return 0, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err)
+        }
+
+        backendTool := &models.BackendTool{
+            ID:          uuid.New().String(),
+            MCPServerID: serverID,
+            ToolName:    tool.Name,
+            Description: &description,
+            InputSchema: schemaJSON,
+            TokenCount:  s.tokenCounter.CountToolTokens(tool),
+            CreatedAt:   time.Now(),
+            LastUpdated: time.Now(),
+        }
+
+        if err := s.backendToolOps.Create(ctx, backendTool, serverName); err != nil {
+            span.RecordError(err)
+            span.SetStatus(codes.Error, err.Error())
+            return 0, fmt.Errorf("failed to create tool %s: %w", tool.Name, err)
+        }
+    }
+
+    logger.Infof("Synced %d tools for server %s", len(tools), serverName)
+    return len(tools), nil
+}
+
+// GetEmbeddingManager returns the embedding manager for this service
+// (shared, not a copy — callers must not Close it).
+func (s *Service) GetEmbeddingManager() *embeddings.Manager {
+    return s.embeddingManager
+}
+
+// GetBackendToolOps returns the backend tool operations for search and retrieval
+func (s *Service) GetBackendToolOps() *db.BackendToolOps {
+    return s.backendToolOps
+}
+
+// GetTotalToolTokens returns the total token count across all tools in the database.
+// Returns 0 both on FTS query failure and when the FTS database is absent;
+// callers cannot distinguish "no tools" from "count unavailable".
+func (s *Service) GetTotalToolTokens(ctx context.Context) int {
+    // Use FTS database to efficiently count all tool tokens
+    if s.database.GetFTSDB() != nil {
+        totalTokens, err := s.database.GetFTSDB().GetTotalToolTokens(ctx)
+        if err != nil {
+            logger.Warnw("Failed to get total tool tokens from FTS", "error", err)
+            return 0
+        }
+        return totalTokens
+    }
+
+    // Fallback: query all tools (less efficient but works)
+    // NOTE(review): the fallback is not implemented — it just returns 0.
+    logger.Warn("FTS database not available, using fallback for token counting")
+    return 0
+}
+
+// GetTotalEmbeddingTime returns the total time spent calculating embeddings
+// since startup (or since the last ResetEmbeddingTime).
+func (s *Service) GetTotalEmbeddingTime() time.Duration {
+    s.embeddingTimeMu.Lock()
+    total := s.totalEmbeddingTime
+    s.embeddingTimeMu.Unlock()
+    return total
+}
+
+// ResetEmbeddingTime resets the total embedding time counter to zero.
+func (s *Service) ResetEmbeddingTime() {
+    s.embeddingTimeMu.Lock()
+    s.totalEmbeddingTime = 0
+    s.embeddingTimeMu.Unlock()
+}
+
+// Close releases resources held by the service: the embedding manager and
+// the database. Both are always attempted; failures are collected and
+// reported together.
+func (s *Service) Close() error {
+    closers := []struct {
+        name string
+        fn   func() error
+    }{
+        {"embedding manager", s.embeddingManager.Close},
+        {"database", s.database.Close},
+    }
+
+    var errs []error
+    for _, c := range closers {
+        if err := c.fn(); err != nil {
+            errs = append(errs, fmt.Errorf("failed to close %s: %w", c.name, err))
+        }
+    }
+
+    if len(errs) > 0 {
+        return fmt.Errorf("errors closing service: %v", errs)
+    }
+    return nil
+}
diff --git a/cmd/thv-operator/pkg/optimizer/ingestion/service_test.go b/cmd/thv-operator/pkg/optimizer/ingestion/service_test.go
new file mode 100644
index 0000000000..0475737071
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/ingestion/service_test.go
@@ -0,0 +1,253 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package ingestion
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+)
+
+// TestServiceCreationAndIngestion demonstrates the complete chromem-go workflow:
+// 1. Create in-memory database
+// 2. Initialize ingestion service
+// 3. Ingest server and tools
+// 4. Query the database
+// Requires a local Ollama instance with the nomic-embed-text model; skipped otherwise.
+func TestServiceCreationAndIngestion(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	// Create temporary directory for persistence (optional)
+	tmpDir := t.TempDir()
+
+	// Try to use Ollama if available, otherwise skip test
+	// Check for the actual model we'll use: nomic-embed-text
+	embeddingConfig := &embeddings.Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "nomic-embed-text",
+		Dimension:   768,
+	}
+
+	embeddingManager, err := embeddings.NewManager(embeddingConfig)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available or model not found. Error: %v. Run 'ollama serve && ollama pull nomic-embed-text'", err)
+		return
+	}
+	_ = embeddingManager.Close()
+
+	// Initialize service with Ollama embeddings
+	config := &Config{
+		DBConfig: &db.Config{
+			PersistPath: filepath.Join(tmpDir, "test-db"),
+		},
+		EmbeddingConfig: &embeddings.Config{
+			BackendType: "ollama",
+			BaseURL:     "http://localhost:11434",
+			Model:       "nomic-embed-text",
+			Dimension:   768,
+		},
+	}
+
+	svc, err := NewService(config)
+	if err != nil {
+		t.Skipf("Skipping test: Failed to create service. Error: %v. Run 'ollama serve && ollama pull nomic-embed-text'", err)
+		return
+	}
+	defer func() { _ = svc.Close() }()
+
+	// Create test tools
+	tools := []mcp.Tool{
+		{
+			Name:        "get_weather",
+			Description: "Get the current weather for a location",
+		},
+		{
+			Name:        "search_web",
+			Description: "Search the web for information",
+		},
+	}
+
+	// Ingest server with tools
+	serverName := "test-server"
+	serverID := "test-server-id"
+	description := "A test MCP server"
+
+	err = svc.IngestServer(ctx, serverID, serverName, &description, tools)
+	if err != nil {
+		// Check if error is due to missing model
+		errStr := err.Error()
+		if strings.Contains(errStr, "model") || strings.Contains(errStr, "not found") || strings.Contains(errStr, "404") {
+			t.Skipf("Skipping test: Model not available. Error: %v. Run 'ollama serve && ollama pull nomic-embed-text'", err)
+			return
+		}
+		require.NoError(t, err)
+	}
+
+	// Query tools
+	allTools, err := svc.backendToolOps.ListByServer(ctx, serverID)
+	require.NoError(t, err)
+	require.Len(t, allTools, 2, "Expected 2 tools to be ingested")
+
+	// Verify tool names
+	toolNames := make(map[string]bool)
+	for _, tool := range allTools {
+		toolNames[tool.ToolName] = true
+	}
+	require.True(t, toolNames["get_weather"], "get_weather tool should be present")
+	require.True(t, toolNames["search_web"], "search_web tool should be present")
+
+	// Search for similar tools
+	results, err := svc.backendToolOps.Search(ctx, "weather information", 5, &serverID)
+	require.NoError(t, err)
+	require.NotEmpty(t, results, "Should find at least one similar tool")
+
+	// Weather tool should be most similar to weather query
+	require.Equal(t, "get_weather", results[0].ToolName,
+		"Weather tool should be most similar to weather query")
+	toolNamesFound := make(map[string]bool)
+	for _, result := range results {
+		toolNamesFound[result.ToolName] = true
+	}
+	require.True(t, toolNamesFound["get_weather"], "get_weather should be in results")
+	require.True(t, toolNamesFound["search_web"], "search_web should be in results")
+}
+
+// TestService_EmbeddingTimeTracking tests that embedding time is tracked correctly.
+// Requires a local Ollama instance with the all-minilm model; skipped otherwise.
+func TestService_EmbeddingTimeTracking(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	tmpDir := t.TempDir()
+
+	// Try to use Ollama if available, otherwise skip test
+	embeddingConfig := &embeddings.Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+	}
+
+	// Probe connectivity only; the service below creates its own manager.
+	embeddingManager, err := embeddings.NewManager(embeddingConfig)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err)
+		return
+	}
+	_ = embeddingManager.Close()
+
+	// Initialize service
+	config := &Config{
+		DBConfig: &db.Config{
+			PersistPath: filepath.Join(tmpDir, "test-db"),
+		},
+		EmbeddingConfig: &embeddings.Config{
+			BackendType: "ollama",
+			BaseURL:     "http://localhost:11434",
+			Model:       "all-minilm",
+			Dimension:   384,
+		},
+	}
+
+	svc, err := NewService(config)
+	require.NoError(t, err)
+	defer func() { _ = svc.Close() }()
+
+	// Initially, embedding time should be 0
+	initialTime := svc.GetTotalEmbeddingTime()
+	require.Equal(t, time.Duration(0), initialTime, "Initial embedding time should be 0")
+
+	// Create test tools
+	tools := []mcp.Tool{
+		{
+			Name:        "test_tool_1",
+			Description: "First test tool for embedding",
+		},
+		{
+			Name:        "test_tool_2",
+			Description: "Second test tool for embedding",
+		},
+	}
+
+	// Reset embedding time before ingestion
+	svc.ResetEmbeddingTime()
+
+	// Ingest server with tools (this will generate embeddings)
+	err = svc.IngestServer(ctx, "test-server-id", "TestServer", nil, tools)
+	require.NoError(t, err)
+
+	// After ingestion, embedding time should be greater than 0
+	totalEmbeddingTime := svc.GetTotalEmbeddingTime()
+	require.Greater(t, totalEmbeddingTime, time.Duration(0),
+		"Total embedding time should be greater than 0 after ingestion")
+
+	// Reset and verify it's back to 0
+	svc.ResetEmbeddingTime()
+	resetTime := svc.GetTotalEmbeddingTime()
+	require.Equal(t, time.Duration(0), resetTime, "Embedding time should be 0 after reset")
+}
+
+// TestServiceWithOllama demonstrates using real embeddings (requires Ollama running)
+// This test can be enabled locally to verify Ollama integration
+func TestServiceWithOllama(t *testing.T) {
+	t.Parallel()
+
+	// Skip if not explicitly enabled or Ollama is not available
+	if os.Getenv("TEST_OLLAMA") != "true" {
+		t.Skip("Skipping Ollama integration test (set TEST_OLLAMA=true to enable)")
+	}
+
+	ctx := context.Background()
+	tmpDir := t.TempDir()
+
+	// Initialize service with Ollama embeddings
+	config := &Config{
+		DBConfig: &db.Config{
+			PersistPath: filepath.Join(tmpDir, "ollama-db"),
+		},
+		EmbeddingConfig: &embeddings.Config{
+			BackendType: "ollama",
+			BaseURL:     "http://localhost:11434",
+			Model:       "nomic-embed-text",
+			// nomic-embed-text produces 768-dimensional vectors (matches the
+			// configuration used elsewhere in this package's tests).
+			Dimension: 768,
+		},
+	}
+
+	svc, err := NewService(config)
+	require.NoError(t, err)
+	defer func() { _ = svc.Close() }()
+
+	// Create test tools
+	tools := []mcp.Tool{
+		{
+			Name:        "get_weather",
+			Description: "Get current weather conditions for any location worldwide",
+		},
+		{
+			Name:        "send_email",
+			Description: "Send an email message to a recipient",
+		},
+	}
+
+	// Ingest server
+	err = svc.IngestServer(ctx, "server-1", "TestServer", nil, tools)
+	require.NoError(t, err)
+
+	// Search for weather-related tools
+	results, err := svc.backendToolOps.Search(ctx, "What's the temperature outside?", 5, nil)
+	require.NoError(t, err)
+	require.NotEmpty(t, results)
+
+	require.Equal(t, "get_weather", results[0].ToolName,
+		"Weather tool should be most similar to weather query")
+}
diff --git a/cmd/thv-operator/pkg/optimizer/ingestion/service_coverage_test.go b/cmd/thv-operator/pkg/optimizer/ingestion/service_coverage_test.go
new file mode 100644
index 0000000000..a068eab687
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/ingestion/service_coverage_test.go
@@ -0,0 +1,285 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package ingestion
+
+import (
+ "context"
+ "path/filepath"
+ "testing"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+)
+
+// TestService_GetTotalToolTokens tests token counting after a real ingestion.
+// Requires a local Ollama instance with the all-minilm model; skipped otherwise.
+func TestService_GetTotalToolTokens(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	tmpDir := t.TempDir()
+
+	embeddingConfig := &embeddings.Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+	}
+
+	// Probe Ollama availability before constructing the service.
+	embeddingManager, err := embeddings.NewManager(embeddingConfig)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	_ = embeddingManager.Close()
+
+	config := &Config{
+		DBConfig: &db.Config{
+			PersistPath: filepath.Join(tmpDir, "test-db"),
+		},
+		EmbeddingConfig: &embeddings.Config{
+			BackendType: "ollama",
+			BaseURL:     "http://localhost:11434",
+			Model:       "all-minilm",
+			Dimension:   384,
+		},
+	}
+
+	svc, err := NewService(config)
+	require.NoError(t, err)
+	defer func() { _ = svc.Close() }()
+
+	// Ingest some tools
+	tools := []mcp.Tool{
+		{
+			Name:        "tool1",
+			Description: "Tool 1",
+		},
+		{
+			Name:        "tool2",
+			Description: "Tool 2",
+		},
+	}
+
+	err = svc.IngestServer(ctx, "server-1", "TestServer", nil, tools)
+	require.NoError(t, err)
+
+	// Get total tokens
+	totalTokens := svc.GetTotalToolTokens(ctx)
+	assert.GreaterOrEqual(t, totalTokens, 0, "Total tokens should be non-negative")
+}
+
+// TestService_GetTotalToolTokens_NoFTS tests token counting without FTS
+// (empty persistence paths exercise the in-memory configuration).
+func TestService_GetTotalToolTokens_NoFTS(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	embeddingConfig := &embeddings.Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+	}
+
+	embeddingManager, err := embeddings.NewManager(embeddingConfig)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	_ = embeddingManager.Close()
+
+	config := &Config{
+		DBConfig: &db.Config{
+			PersistPath: "", // In-memory
+			FTSDBPath:   "", // Will default to :memory:
+		},
+		EmbeddingConfig: &embeddings.Config{
+			BackendType: "ollama",
+			BaseURL:     "http://localhost:11434",
+			Model:       "all-minilm",
+			Dimension:   384,
+		},
+	}
+
+	svc, err := NewService(config)
+	require.NoError(t, err)
+	defer func() { _ = svc.Close() }()
+
+	// Get total tokens (should use FTS if available, fallback otherwise)
+	totalTokens := svc.GetTotalToolTokens(ctx)
+	assert.GreaterOrEqual(t, totalTokens, 0, "Total tokens should be non-negative")
+}
+
+// TestService_GetBackendToolOps tests that the backend tool ops accessor
+// returns a non-nil handle on a freshly constructed service.
+func TestService_GetBackendToolOps(t *testing.T) {
+	t.Parallel()
+	tmpDir := t.TempDir()
+
+	embeddingConfig := &embeddings.Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+	}
+
+	embeddingManager, err := embeddings.NewManager(embeddingConfig)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	_ = embeddingManager.Close()
+
+	config := &Config{
+		DBConfig: &db.Config{
+			PersistPath: filepath.Join(tmpDir, "test-db"),
+		},
+		EmbeddingConfig: &embeddings.Config{
+			BackendType: "ollama",
+			BaseURL:     "http://localhost:11434",
+			Model:       "all-minilm",
+			Dimension:   384,
+		},
+	}
+
+	svc, err := NewService(config)
+	require.NoError(t, err)
+	defer func() { _ = svc.Close() }()
+
+	toolOps := svc.GetBackendToolOps()
+	require.NotNil(t, toolOps)
+}
+
+// TestService_GetEmbeddingManager tests that the embedding manager accessor
+// returns a non-nil handle on a freshly constructed service.
+func TestService_GetEmbeddingManager(t *testing.T) {
+	t.Parallel()
+	tmpDir := t.TempDir()
+
+	embeddingConfig := &embeddings.Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+	}
+
+	embeddingManager, err := embeddings.NewManager(embeddingConfig)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	_ = embeddingManager.Close()
+
+	config := &Config{
+		DBConfig: &db.Config{
+			PersistPath: filepath.Join(tmpDir, "test-db"),
+		},
+		EmbeddingConfig: &embeddings.Config{
+			BackendType: "ollama",
+			BaseURL:     "http://localhost:11434",
+			Model:       "all-minilm",
+			Dimension:   384,
+		},
+	}
+
+	svc, err := NewService(config)
+	require.NoError(t, err)
+	defer func() { _ = svc.Close() }()
+
+	manager := svc.GetEmbeddingManager()
+	require.NotNil(t, manager)
+}
+
+// TestService_IngestServer_ErrorHandling tests error handling during ingestion:
+// empty tool lists and nil server descriptions must not fail.
+func TestService_IngestServer_ErrorHandling(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	tmpDir := t.TempDir()
+
+	embeddingConfig := &embeddings.Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+	}
+
+	embeddingManager, err := embeddings.NewManager(embeddingConfig)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	_ = embeddingManager.Close()
+
+	config := &Config{
+		DBConfig: &db.Config{
+			PersistPath: filepath.Join(tmpDir, "test-db"),
+		},
+		EmbeddingConfig: &embeddings.Config{
+			BackendType: "ollama",
+			BaseURL:     "http://localhost:11434",
+			Model:       "all-minilm",
+			Dimension:   384,
+		},
+	}
+
+	svc, err := NewService(config)
+	require.NoError(t, err)
+	defer func() { _ = svc.Close() }()
+
+	// Test with empty tools list
+	err = svc.IngestServer(ctx, "server-1", "TestServer", nil, []mcp.Tool{})
+	require.NoError(t, err, "Should handle empty tools list gracefully")
+
+	// Test with nil description
+	err = svc.IngestServer(ctx, "server-2", "TestServer2", nil, []mcp.Tool{
+		{
+			Name:        "tool1",
+			Description: "Tool 1",
+		},
+	})
+	require.NoError(t, err, "Should handle nil description gracefully")
+}
+
+// TestService_Close_ErrorHandling tests error handling during close,
+// including that calling Close more than once does not fail.
+func TestService_Close_ErrorHandling(t *testing.T) {
+	t.Parallel()
+	tmpDir := t.TempDir()
+
+	embeddingConfig := &embeddings.Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+	}
+
+	embeddingManager, err := embeddings.NewManager(embeddingConfig)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+		return
+	}
+	_ = embeddingManager.Close()
+
+	config := &Config{
+		DBConfig: &db.Config{
+			PersistPath: filepath.Join(tmpDir, "test-db"),
+		},
+		EmbeddingConfig: &embeddings.Config{
+			BackendType: "ollama",
+			BaseURL:     "http://localhost:11434",
+			Model:       "all-minilm",
+			Dimension:   384,
+		},
+	}
+
+	svc, err := NewService(config)
+	require.NoError(t, err)
+
+	// Close should succeed
+	err = svc.Close()
+	require.NoError(t, err)
+
+	// Multiple closes should be safe
+	err = svc.Close()
+	require.NoError(t, err)
+}
diff --git a/cmd/thv-operator/pkg/optimizer/models/errors.go b/cmd/thv-operator/pkg/optimizer/models/errors.go
new file mode 100644
index 0000000000..c5b10eebe6
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/models/errors.go
@@ -0,0 +1,19 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package models defines domain models for the optimizer.
+// It includes structures for MCP servers, tools, and related metadata.
+package models
+
+import "errors"
+
+var (
+	// ErrRemoteServerMissingURL is returned by RegistryServer.Validate when a
+	// remote server doesn't have a URL.
+	ErrRemoteServerMissingURL = errors.New("remote servers must have URL")
+
+	// ErrContainerServerMissingPackage is returned by RegistryServer.Validate
+	// when a container server doesn't have a package.
+	ErrContainerServerMissingPackage = errors.New("container servers must have package")
+
+	// ErrInvalidTokenMetrics is returned by TokenMetrics.Validate when token
+	// metrics are internally inconsistent.
+	ErrInvalidTokenMetrics = errors.New("invalid token metrics: calculated values don't match")
+)
diff --git a/cmd/thv-operator/pkg/optimizer/models/models.go b/cmd/thv-operator/pkg/optimizer/models/models.go
new file mode 100644
index 0000000000..6c810fbe04
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/models/models.go
@@ -0,0 +1,176 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package models
+
+import (
+ "encoding/json"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// BaseMCPServer represents the common fields for MCP servers.
+// It is embedded by RegistryServer; ServerEmbedding is excluded from JSON
+// and persisted separately as a BLOB.
+type BaseMCPServer struct {
+	ID              string        `json:"id"`
+	Name            string        `json:"name"`
+	Remote          bool          `json:"remote"`
+	Transport       TransportType `json:"transport"`
+	Description     *string       `json:"description,omitempty"`
+	ServerEmbedding []float32     `json:"-"` // Excluded from JSON, stored as BLOB
+	Group           string        `json:"group"`
+	LastUpdated     time.Time     `json:"last_updated"`
+	CreatedAt       time.Time     `json:"created_at"`
+}
+
+// RegistryServer represents an MCP server from the registry catalog.
+// Exactly one of URL (remote servers) or Package (container servers) is
+// expected to be set; Validate enforces this invariant.
+type RegistryServer struct {
+	BaseMCPServer
+	URL     *string `json:"url,omitempty"`     // For remote servers
+	Package *string `json:"package,omitempty"` // For container servers
+}
+
+// Validate checks if the registry server has valid data.
+// Remote servers must have URL, container servers must have package.
+func (r *RegistryServer) Validate() error {
+	switch {
+	case r.Remote && r.URL == nil:
+		return ErrRemoteServerMissingURL
+	case !r.Remote && r.Package == nil:
+		return ErrContainerServerMissingPackage
+	default:
+		return nil
+	}
+}
+
+// BackendServer represents a running MCP server backend.
+// Simplified: Only stores metadata needed for tool organization and search results.
+// vMCP manages backend lifecycle (URL, status, transport, etc.)
+type BackendServer struct {
+	ID              string    `json:"id"`
+	Name            string    `json:"name"`
+	Description     *string   `json:"description,omitempty"`
+	Group           string    `json:"group"`
+	ServerEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB
+	LastUpdated     time.Time `json:"last_updated"`
+	CreatedAt       time.Time `json:"created_at"`
+}
+
+// BaseTool represents the common fields for tools.
+// Embedded by RegistryTool; Details carries the full mcp.Tool definition.
+type BaseTool struct {
+	ID               string    `json:"id"`
+	MCPServerID      string    `json:"mcpserver_id"`
+	Details          mcp.Tool  `json:"details"`
+	DetailsEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB
+	LastUpdated      time.Time `json:"last_updated"`
+	CreatedAt        time.Time `json:"created_at"`
+}
+
+// RegistryTool represents a tool from a registry MCP server.
+type RegistryTool struct {
+	BaseTool
+}
+
+// BackendTool represents a tool from a backend MCP server.
+// With chromem-go, embeddings are managed by the database.
+type BackendTool struct {
+	ID           string          `json:"id"`
+	MCPServerID  string          `json:"mcpserver_id"`
+	ToolName     string          `json:"tool_name"`
+	Description  *string         `json:"description,omitempty"`
+	InputSchema  json.RawMessage `json:"input_schema,omitempty"`
+	ToolEmbedding []float32      `json:"-"` // Managed by chromem-go
+	TokenCount   int             `json:"token_count"`
+	LastUpdated  time.Time       `json:"last_updated"`
+	CreatedAt    time.Time       `json:"created_at"`
+}
+
+// ToolDetailsToJSON converts mcp.Tool to its JSON representation for storage
+// in the database. Returns an empty string on marshal failure.
+func ToolDetailsToJSON(tool mcp.Tool) (string, error) {
+	encoded, marshalErr := json.Marshal(tool)
+	if marshalErr != nil {
+		return "", marshalErr
+	}
+	return string(encoded), nil
+}
+
+// ToolDetailsFromJSON converts JSON to mcp.Tool
+func ToolDetailsFromJSON(data string) (*mcp.Tool, error) {
+ var tool mcp.Tool
+ err := json.Unmarshal([]byte(data), &tool)
+ if err != nil {
+ return nil, err
+ }
+ return &tool, nil
+}
+
+// BackendToolWithMetadata represents a backend tool with similarity score.
+type BackendToolWithMetadata struct {
+	BackendTool
+	Similarity float32 `json:"similarity"` // Cosine similarity from chromem-go (0-1, higher is better)
+}
+
+// RegistryToolWithMetadata represents a registry tool with server information and similarity distance.
+type RegistryToolWithMetadata struct {
+	ServerName        string       `json:"server_name"`
+	ServerDescription *string      `json:"server_description,omitempty"`
+	Distance          float64      `json:"distance"` // Cosine distance from query embedding
+	Tool              RegistryTool `json:"tool"`
+}
+
+// BackendWithRegistry represents a backend server with its resolved registry relationship.
+// The Effective* helpers below prefer registry data when present.
+type BackendWithRegistry struct {
+	Backend  BackendServer   `json:"backend"`
+	Registry *RegistryServer `json:"registry,omitempty"` // NULL if autonomous
+}
+
+// EffectiveDescription returns the description to use for this backend:
+// the registry's description when a registry relationship exists, otherwise
+// the backend's own description.
+func (b *BackendWithRegistry) EffectiveDescription() *string {
+	if b.Registry == nil {
+		return b.Backend.Description
+	}
+	return b.Registry.Description
+}
+
+// EffectiveEmbedding returns the server embedding to use for this backend:
+// inherited from the registry when present, otherwise the backend's own.
+func (b *BackendWithRegistry) EffectiveEmbedding() []float32 {
+	if b.Registry == nil {
+		return b.Backend.ServerEmbedding
+	}
+	return b.Registry.ServerEmbedding
+}
+
+// ServerNameForTools returns the server name to use as context for tool
+// embeddings: the registry name when present, otherwise the backend name.
+func (b *BackendWithRegistry) ServerNameForTools() string {
+	if b.Registry == nil {
+		return b.Backend.Name
+	}
+	return b.Registry.Name
+}
+
+// TokenMetrics represents token efficiency metrics for tool filtering.
+// Invariants (checked by Validate): TokensSaved == BaselineTokens - ReturnedTokens,
+// and SavingsPercentage == TokensSaved/BaselineTokens*100 (0 when baseline is 0).
+type TokenMetrics struct {
+	BaselineTokens    int     `json:"baseline_tokens"`    // Total tokens for all running server tools
+	ReturnedTokens    int     `json:"returned_tokens"`    // Total tokens for returned/filtered tools
+	TokensSaved       int     `json:"tokens_saved"`       // Number of tokens saved by filtering
+	SavingsPercentage float64 `json:"savings_percentage"` // Percentage of tokens saved (0-100)
+}
+
+// Validate checks if the token metrics are consistent.
+func (t *TokenMetrics) Validate() error {
+ if t.TokensSaved != t.BaselineTokens-t.ReturnedTokens {
+ return ErrInvalidTokenMetrics
+ }
+
+ var expectedPct float64
+ if t.BaselineTokens > 0 {
+ expectedPct = (float64(t.TokensSaved) / float64(t.BaselineTokens)) * 100
+ // Allow small floating point differences (0.01%)
+ if expectedPct-t.SavingsPercentage > 0.01 || t.SavingsPercentage-expectedPct > 0.01 {
+ return ErrInvalidTokenMetrics
+ }
+ } else if t.SavingsPercentage != 0.0 {
+ return ErrInvalidTokenMetrics
+ }
+
+ return nil
+}
diff --git a/cmd/thv-operator/pkg/optimizer/models/models_test.go b/cmd/thv-operator/pkg/optimizer/models/models_test.go
new file mode 100644
index 0000000000..af06e90bf4
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/models/models_test.go
@@ -0,0 +1,273 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package models
+
+import (
+ "testing"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// TestRegistryServer_Validate exercises the URL/package invariant:
+// remote servers need a URL, container servers need a package.
+func TestRegistryServer_Validate(t *testing.T) {
+	t.Parallel()
+	url := "http://example.com/mcp"
+	pkg := "github.com/example/mcp-server"
+
+	tests := []struct {
+		name    string
+		server  *RegistryServer
+		wantErr bool
+	}{
+		{
+			name: "Remote server with URL is valid",
+			server: &RegistryServer{
+				BaseMCPServer: BaseMCPServer{
+					Remote: true,
+				},
+				URL: &url,
+			},
+			wantErr: false,
+		},
+		{
+			name: "Container server with package is valid",
+			server: &RegistryServer{
+				BaseMCPServer: BaseMCPServer{
+					Remote: false,
+				},
+				Package: &pkg,
+			},
+			wantErr: false,
+		},
+		{
+			name: "Remote server without URL is invalid",
+			server: &RegistryServer{
+				BaseMCPServer: BaseMCPServer{
+					Remote: true,
+				},
+			},
+			wantErr: true,
+		},
+		{
+			name: "Container server without package is invalid",
+			server: &RegistryServer{
+				BaseMCPServer: BaseMCPServer{
+					Remote: false,
+				},
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			err := tt.server.Validate()
+			if (err != nil) != tt.wantErr {
+				t.Errorf("RegistryServer.Validate() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+// TestToolDetailsToJSON round-trips a tool through ToolDetailsToJSON and
+// ToolDetailsFromJSON and checks the fields survive.
+func TestToolDetailsToJSON(t *testing.T) {
+	t.Parallel()
+	tool := mcp.Tool{
+		Name:        "test_tool",
+		Description: "A test tool",
+	}
+
+	json, err := ToolDetailsToJSON(tool)
+	if err != nil {
+		t.Fatalf("ToolDetailsToJSON() error = %v", err)
+	}
+
+	if json == "" {
+		t.Error("ToolDetailsToJSON() returned empty string")
+	}
+
+	// Try to parse it back
+	parsed, err := ToolDetailsFromJSON(json)
+	if err != nil {
+		t.Fatalf("ToolDetailsFromJSON() error = %v", err)
+	}
+
+	if parsed.Name != tool.Name {
+		t.Errorf("Tool name mismatch: got %v, want %v", parsed.Name, tool.Name)
+	}
+
+	if parsed.Description != tool.Description {
+		t.Errorf("Tool description mismatch: got %v, want %v", parsed.Description, tool.Description)
+	}
+}
+
+// TestTokenMetrics_Validate exercises the consistency checks on TokenMetrics:
+// the saved-token arithmetic and the savings-percentage recomputation.
+func TestTokenMetrics_Validate(t *testing.T) {
+	t.Parallel()
+	tests := []struct {
+		name    string
+		metrics *TokenMetrics
+		wantErr bool
+	}{
+		{
+			name: "Valid metrics with savings",
+			metrics: &TokenMetrics{
+				BaselineTokens:    1000,
+				ReturnedTokens:    600,
+				TokensSaved:       400,
+				SavingsPercentage: 40.0,
+			},
+			wantErr: false,
+		},
+		{
+			name: "Valid metrics with no savings",
+			metrics: &TokenMetrics{
+				BaselineTokens:    1000,
+				ReturnedTokens:    1000,
+				TokensSaved:       0,
+				SavingsPercentage: 0.0,
+			},
+			wantErr: false,
+		},
+		{
+			name: "Invalid: tokens saved doesn't match",
+			metrics: &TokenMetrics{
+				BaselineTokens:    1000,
+				ReturnedTokens:    600,
+				TokensSaved:       500, // Should be 400
+				SavingsPercentage: 40.0,
+			},
+			wantErr: true,
+		},
+		{
+			name: "Invalid: savings percentage doesn't match",
+			metrics: &TokenMetrics{
+				BaselineTokens:    1000,
+				ReturnedTokens:    600,
+				TokensSaved:       400,
+				SavingsPercentage: 50.0, // Should be 40.0
+			},
+			wantErr: true,
+		},
+		{
+			name: "Invalid: non-zero percentage with zero baseline",
+			metrics: &TokenMetrics{
+				BaselineTokens:    0,
+				ReturnedTokens:    0,
+				TokensSaved:       0,
+				SavingsPercentage: 10.0, // Should be 0
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			err := tt.metrics.Validate()
+			if (err != nil) != tt.wantErr {
+				t.Errorf("TokenMetrics.Validate() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+// TestBackendWithRegistry_EffectiveDescription verifies the registry-first
+// fallback logic of EffectiveDescription.
+func TestBackendWithRegistry_EffectiveDescription(t *testing.T) {
+	t.Parallel()
+	registryDesc := "Registry description"
+	backendDesc := "Backend description"
+
+	tests := []struct {
+		name string
+		w    *BackendWithRegistry
+		want *string
+	}{
+		{
+			name: "Uses registry description when available",
+			w: &BackendWithRegistry{
+				Backend: BackendServer{
+					Description: &backendDesc,
+				},
+				Registry: &RegistryServer{
+					BaseMCPServer: BaseMCPServer{
+						Description: &registryDesc,
+					},
+				},
+			},
+			want: &registryDesc,
+		},
+		{
+			name: "Uses backend description when no registry",
+			w: &BackendWithRegistry{
+				Backend: BackendServer{
+					Description: &backendDesc,
+				},
+				Registry: nil,
+			},
+			want: &backendDesc,
+		},
+		{
+			name: "Returns nil when no description",
+			w: &BackendWithRegistry{
+				Backend:  BackendServer{},
+				Registry: nil,
+			},
+			want: nil,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			got := tt.w.EffectiveDescription()
+			// Compare nil-ness first, then the pointed-to values.
+			if (got == nil) != (tt.want == nil) {
+				t.Errorf("BackendWithRegistry.EffectiveDescription() = %v, want %v", got, tt.want)
+			}
+			if got != nil && tt.want != nil && *got != *tt.want {
+				t.Errorf("BackendWithRegistry.EffectiveDescription() = %v, want %v", *got, *tt.want)
+			}
+		})
+	}
+}
+
+// TestBackendWithRegistry_ServerNameForTools verifies the registry-first
+// fallback logic of ServerNameForTools.
+func TestBackendWithRegistry_ServerNameForTools(t *testing.T) {
+	t.Parallel()
+	tests := []struct {
+		name string
+		w    *BackendWithRegistry
+		want string
+	}{
+		{
+			name: "Uses registry name when available",
+			w: &BackendWithRegistry{
+				Backend: BackendServer{
+					Name: "backend-name",
+				},
+				Registry: &RegistryServer{
+					BaseMCPServer: BaseMCPServer{
+						Name: "registry-name",
+					},
+				},
+			},
+			want: "registry-name",
+		},
+		{
+			name: "Uses backend name when no registry",
+			w: &BackendWithRegistry{
+				Backend: BackendServer{
+					Name: "backend-name",
+				},
+				Registry: nil,
+			},
+			want: "backend-name",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			if got := tt.w.ServerNameForTools(); got != tt.want {
+				t.Errorf("BackendWithRegistry.ServerNameForTools() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/cmd/thv-operator/pkg/optimizer/models/transport.go b/cmd/thv-operator/pkg/optimizer/models/transport.go
new file mode 100644
index 0000000000..8764b7fd48
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/models/transport.go
@@ -0,0 +1,114 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package models
+
+import (
+ "database/sql/driver"
+ "fmt"
+)
+
+// TransportType represents the transport protocol used by an MCP server.
+// Maps 1:1 to ToolHive transport modes. Persisted via the Value/Scan
+// implementations below.
+type TransportType string
+
+const (
+	// TransportSSE represents Server-Sent Events transport
+	TransportSSE TransportType = "sse"
+	// TransportStreamable represents Streamable HTTP transport
+	TransportStreamable TransportType = "streamable-http"
+)
+
+// Valid reports whether the transport type is one of the supported values.
+func (t TransportType) Valid() bool {
+	return t == TransportSSE || t == TransportStreamable
+}
+
+// String returns the string representation of the transport type.
+func (t TransportType) String() string {
+	return string(t)
+}
+
+// Value implements the driver.Valuer interface for database storage.
+// Invalid transport types are rejected rather than silently persisted.
+func (t TransportType) Value() (driver.Value, error) {
+	if t.Valid() {
+		return string(t), nil
+	}
+	return nil, fmt.Errorf("invalid transport type: %s", t)
+}
+
+// Scan implements the sql.Scanner interface for database retrieval.
+// The value must be a non-nil string holding a supported transport type.
+func (t *TransportType) Scan(value interface{}) error {
+	if value == nil {
+		return fmt.Errorf("transport type cannot be nil")
+	}
+
+	raw, isString := value.(string)
+	if !isString {
+		return fmt.Errorf("transport type must be a string, got %T", value)
+	}
+
+	// Assign first, then validate (matching Scanner convention used here:
+	// the receiver holds the raw value even when validation fails).
+	*t = TransportType(raw)
+	if t.Valid() {
+		return nil
+	}
+	return fmt.Errorf("invalid transport type from database: %s", raw)
+}
+
+// MCPStatus represents the status of an MCP server backend.
+// Persisted via the Value/Scan implementations below.
+type MCPStatus string
+
+const (
+	// StatusRunning indicates the backend is running
+	StatusRunning MCPStatus = "running"
+	// StatusStopped indicates the backend is stopped
+	StatusStopped MCPStatus = "stopped"
+)
+
+// Valid reports whether the status is one of the supported values.
+func (s MCPStatus) Valid() bool {
+	return s == StatusRunning || s == StatusStopped
+}
+
+// String returns the string representation of the status.
+func (s MCPStatus) String() string {
+	return string(s)
+}
+
+// Value implements the driver.Valuer interface for database storage.
+// Invalid statuses are rejected rather than silently persisted.
+func (s MCPStatus) Value() (driver.Value, error) {
+	if s.Valid() {
+		return string(s), nil
+	}
+	return nil, fmt.Errorf("invalid MCP status: %s", s)
+}
+
+// Scan implements the sql.Scanner interface for database retrieval.
+// The value must be a non-nil string holding a supported status.
+func (s *MCPStatus) Scan(value interface{}) error {
+	if value == nil {
+		return fmt.Errorf("MCP status cannot be nil")
+	}
+
+	raw, isString := value.(string)
+	if !isString {
+		return fmt.Errorf("MCP status must be a string, got %T", value)
+	}
+
+	// Assign first, then validate (matching the TransportType scanner: the
+	// receiver holds the raw value even when validation fails).
+	*s = MCPStatus(raw)
+	if s.Valid() {
+		return nil
+	}
+	return fmt.Errorf("invalid MCP status from database: %s", raw)
+}
diff --git a/cmd/thv-operator/pkg/optimizer/models/transport_test.go b/cmd/thv-operator/pkg/optimizer/models/transport_test.go
new file mode 100644
index 0000000000..156062c595
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/models/transport_test.go
@@ -0,0 +1,276 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package models
+
+import (
+ "testing"
+)
+
+func TestTransportType_Valid(t *testing.T) {
+	t.Parallel()
+	cases := map[string]struct {
+		transport TransportType
+		want      bool
+	}{
+		"SSE transport is valid":         {transport: TransportSSE, want: true},
+		"Streamable transport is valid":  {transport: TransportStreamable, want: true},
+		"Invalid transport is not valid": {transport: TransportType("invalid"), want: false},
+		"Empty transport is not valid":   {transport: TransportType(""), want: false},
+	}
+
+	for name, tc := range cases {
+		t.Run(name, func(t *testing.T) {
+			t.Parallel()
+			if got := tc.transport.Valid(); got != tc.want {
+				t.Errorf("TransportType.Valid() = %v, want %v", got, tc.want)
+			}
+		})
+	}
+}
+
+func TestTransportType_Value(t *testing.T) {
+	t.Parallel()
+	cases := map[string]struct {
+		transport TransportType
+		wantValue string
+		wantErr   bool
+	}{
+		"SSE transport value":              {transport: TransportSSE, wantValue: "sse"},
+		"Streamable transport value":       {transport: TransportStreamable, wantValue: "streamable-http"},
+		"Invalid transport returns error":  {transport: TransportType("invalid"), wantErr: true},
+	}
+
+	for name, tc := range cases {
+		t.Run(name, func(t *testing.T) {
+			t.Parallel()
+			got, err := tc.transport.Value()
+			if (err != nil) != tc.wantErr {
+				t.Errorf("TransportType.Value() error = %v, wantErr %v", err, tc.wantErr)
+				return
+			}
+			if !tc.wantErr && got != tc.wantValue {
+				t.Errorf("TransportType.Value() = %v, want %v", got, tc.wantValue)
+			}
+		})
+	}
+}
+
+func TestTransportType_Scan(t *testing.T) {
+	t.Parallel()
+	cases := map[string]struct {
+		value   interface{}
+		want    TransportType
+		wantErr bool
+	}{
+		"Scan SSE transport":                   {value: "sse", want: TransportSSE},
+		"Scan streamable transport":            {value: "streamable-http", want: TransportStreamable},
+		"Scan invalid transport returns error": {value: "invalid", wantErr: true},
+		"Scan nil returns error":               {value: nil, wantErr: true},
+		"Scan non-string returns error":        {value: 123, wantErr: true},
+	}
+
+	for name, tc := range cases {
+		t.Run(name, func(t *testing.T) {
+			t.Parallel()
+			var transport TransportType
+			err := transport.Scan(tc.value)
+			if (err != nil) != tc.wantErr {
+				t.Errorf("TransportType.Scan() error = %v, wantErr %v", err, tc.wantErr)
+				return
+			}
+			if !tc.wantErr && transport != tc.want {
+				t.Errorf("TransportType.Scan() = %v, want %v", transport, tc.want)
+			}
+		})
+	}
+}
+
+func TestMCPStatus_Valid(t *testing.T) {
+	t.Parallel()
+	cases := map[string]struct {
+		status MCPStatus
+		want   bool
+	}{
+		"Running status is valid":     {status: StatusRunning, want: true},
+		"Stopped status is valid":     {status: StatusStopped, want: true},
+		"Invalid status is not valid": {status: MCPStatus("invalid"), want: false},
+		"Empty status is not valid":   {status: MCPStatus(""), want: false},
+	}
+
+	for name, tc := range cases {
+		t.Run(name, func(t *testing.T) {
+			t.Parallel()
+			if got := tc.status.Valid(); got != tc.want {
+				t.Errorf("MCPStatus.Valid() = %v, want %v", got, tc.want)
+			}
+		})
+	}
+}
+
+func TestMCPStatus_Value(t *testing.T) {
+	t.Parallel()
+	cases := map[string]struct {
+		status    MCPStatus
+		wantValue string
+		wantErr   bool
+	}{
+		"Running status value":         {status: StatusRunning, wantValue: "running"},
+		"Stopped status value":         {status: StatusStopped, wantValue: "stopped"},
+		"Invalid status returns error": {status: MCPStatus("invalid"), wantErr: true},
+	}
+
+	for name, tc := range cases {
+		t.Run(name, func(t *testing.T) {
+			t.Parallel()
+			got, err := tc.status.Value()
+			if (err != nil) != tc.wantErr {
+				t.Errorf("MCPStatus.Value() error = %v, wantErr %v", err, tc.wantErr)
+				return
+			}
+			if !tc.wantErr && got != tc.wantValue {
+				t.Errorf("MCPStatus.Value() = %v, want %v", got, tc.wantValue)
+			}
+		})
+	}
+}
+
+func TestMCPStatus_Scan(t *testing.T) {
+	t.Parallel()
+	cases := map[string]struct {
+		value   interface{}
+		want    MCPStatus
+		wantErr bool
+	}{
+		"Scan running status":               {value: "running", want: StatusRunning},
+		"Scan stopped status":               {value: "stopped", want: StatusStopped},
+		"Scan invalid status returns error": {value: "invalid", wantErr: true},
+		"Scan nil returns error":            {value: nil, wantErr: true},
+		"Scan non-string returns error":     {value: 123, wantErr: true},
+	}
+
+	for name, tc := range cases {
+		t.Run(name, func(t *testing.T) {
+			t.Parallel()
+			var status MCPStatus
+			err := status.Scan(tc.value)
+			if (err != nil) != tc.wantErr {
+				t.Errorf("MCPStatus.Scan() error = %v, wantErr %v", err, tc.wantErr)
+				return
+			}
+			if !tc.wantErr && status != tc.want {
+				t.Errorf("MCPStatus.Scan() = %v, want %v", status, tc.want)
+			}
+		})
+	}
+}
diff --git a/cmd/thv-operator/pkg/optimizer/tokens/counter.go b/cmd/thv-operator/pkg/optimizer/tokens/counter.go
new file mode 100644
index 0000000000..11ed33c118
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/tokens/counter.go
@@ -0,0 +1,68 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package tokens provides token counting utilities for LLM cost estimation.
+// It estimates token counts for MCP tools and their metadata.
+package tokens
+
+import (
+ "encoding/json"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// Counter counts tokens for LLM consumption
+// This provides estimates of token usage for tools.
+// Construct with NewCounter: the zero value has charsPerToken == 0,
+// which would divide by zero in the estimator methods.
+type Counter struct {
+	// Simple heuristic: ~4 characters per token for English text
+	charsPerToken float64
+}
+
+// NewCounter creates a new token counter using the ~4 characters per
+// token heuristic typical of GPT-style tokenizers.
+func NewCounter() *Counter {
+	return &Counter{
+		charsPerToken: 4.0, // GPT-style tokenization approximation
+	}
+}
+
+// CountToolTokens estimates the number of tokens for a tool.
+// The tool is serialized to JSON (as it would be sent to an LLM) and
+// the byte length is converted to tokens via the chars-per-token ratio.
+func (c *Counter) CountToolTokens(tool mcp.Tool) int {
+	if toolJSON, err := json.Marshal(tool); err == nil {
+		return int(float64(len(toolJSON)) / c.charsPerToken)
+	}
+	// Marshalling failed; fall back to summing the individual fields.
+	return c.estimateFromTool(tool)
+}
+
+// estimateFromTool provides a fallback estimation from the tool's
+// name, description, and input schema sizes.
+func (c *Counter) estimateFromTool(tool mcp.Tool) int {
+	// Best-effort schema size; a failed marshal contributes zero bytes.
+	schemaJSON, _ := json.Marshal(tool.InputSchema)
+	chars := len(tool.Name) + len(tool.Description) + len(schemaJSON)
+	return int(float64(chars) / c.charsPerToken)
+}
+
+// CountToolsTokens calculates total tokens for multiple tools
+func (c *Counter) CountToolsTokens(tools []mcp.Tool) int {
+ total := 0
+ for _, tool := range tools {
+ total += c.CountToolTokens(tool)
+ }
+ return total
+}
+
+// EstimateText estimates tokens for arbitrary text.
+// The result is truncated toward zero, so strings shorter than
+// charsPerToken characters estimate to 0 tokens.
+// NOTE(review): assumes c was built via NewCounter; a zero-value
+// Counter divides by zero here.
+func (c *Counter) EstimateText(text string) int {
+	return int(float64(len(text)) / c.charsPerToken)
+}
diff --git a/cmd/thv-operator/pkg/optimizer/tokens/counter_test.go b/cmd/thv-operator/pkg/optimizer/tokens/counter_test.go
new file mode 100644
index 0000000000..082ee385a1
--- /dev/null
+++ b/cmd/thv-operator/pkg/optimizer/tokens/counter_test.go
@@ -0,0 +1,146 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package tokens
+
+import (
+ "testing"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+func TestCountToolTokens(t *testing.T) {
+	t.Parallel()
+
+	got := NewCounter().CountToolTokens(mcp.Tool{
+		Name:        "test_tool",
+		Description: "A test tool for counting tokens",
+	})
+
+	// Should return a positive number
+	if got <= 0 {
+		t.Errorf("Expected positive token count, got %d", got)
+	}
+
+	// Rough estimate: tool should have at least a few tokens
+	if got < 5 {
+		t.Errorf("Expected at least 5 tokens for a tool with name and description, got %d", got)
+	}
+}
+
+func TestCountToolTokens_MinimalTool(t *testing.T) {
+	t.Parallel()
+
+	// Minimal tool with just a name should still estimate positive.
+	if got := NewCounter().CountToolTokens(mcp.Tool{Name: "minimal"}); got <= 0 {
+		t.Errorf("Expected positive token count for minimal tool, got %d", got)
+	}
+}
+
+func TestCountToolTokens_NoDescription(t *testing.T) {
+	t.Parallel()
+
+	// A tool lacking a description should still estimate positive.
+	if got := NewCounter().CountToolTokens(mcp.Tool{Name: "test_tool"}); got <= 0 {
+		t.Errorf("Expected positive token count for tool without description, got %d", got)
+	}
+}
+
+func TestCountToolsTokens(t *testing.T) {
+	t.Parallel()
+	counter := NewCounter()
+
+	tools := []mcp.Tool{
+		{Name: "tool1", Description: "First tool"},
+		{Name: "tool2", Description: "Second tool with longer description"},
+	}
+
+	// The aggregate must equal the sum of the per-tool estimates.
+	want := counter.CountToolTokens(tools[0]) + counter.CountToolTokens(tools[1])
+	if got := counter.CountToolsTokens(tools); got != want {
+		t.Errorf("Expected total tokens %d, got %d", want, got)
+	}
+}
+
+func TestCountToolsTokens_EmptyList(t *testing.T) {
+	t.Parallel()
+
+	// Should return 0 for empty list
+	if got := NewCounter().CountToolsTokens([]mcp.Tool{}); got != 0 {
+		t.Errorf("Expected 0 tokens for empty list, got %d", got)
+	}
+}
+
+func TestEstimateText(t *testing.T) {
+	t.Parallel()
+	counter := NewCounter()
+
+	tests := []struct {
+		name string
+		text string
+		want int
+	}{
+		{
+			name: "Empty text",
+			text: "",
+			want: 0,
+		},
+		{
+			name: "Short text",
+			text: "Hello",
+			want: 1, // 5 chars / 4 chars per token = 1 (floored from 1.25)
+		},
+		{
+			name: "Medium text",
+			text: "This is a test message",
+			want: 5, // 22 chars / 4 chars per token = 5 (floored from 5.5)
+		},
+		{
+			name: "Long text",
+			text: "This is a much longer test message that should have more tokens because it contains significantly more characters",
+			want: 28, // 113 chars / 4 chars per token = 28 (floored from 28.25)
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			got := counter.EstimateText(tt.text)
+			if got != tt.want {
+				t.Errorf("EstimateText() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/cmd/thv-operator/pkg/vmcpconfig/converter.go b/cmd/thv-operator/pkg/vmcpconfig/converter.go
index d5e283f87b..47264f422e 100644
--- a/cmd/thv-operator/pkg/vmcpconfig/converter.go
+++ b/cmd/thv-operator/pkg/vmcpconfig/converter.go
@@ -135,6 +135,17 @@ func (c *Converter) Convert(
// are handled by kubebuilder annotations in pkg/telemetry/config.go and applied by the API server.
config.Telemetry = spectoconfig.NormalizeTelemetryConfig(vmcp.Spec.Config.Telemetry, vmcp.Name)
+ // Convert audit config
+ c.convertAuditConfig(config, vmcp)
+
+ // Apply operational defaults (fills missing values)
+ config.EnsureOperationalDefaults()
+
+ return config, nil
+}
+
+// convertAuditConfig converts audit configuration from CRD to vmcp config.
+func (*Converter) convertAuditConfig(config *vmcpconfig.Config, vmcp *mcpv1alpha1.VirtualMCPServer) {
if vmcp.Spec.Config.Audit != nil && vmcp.Spec.Config.Audit.Enabled {
config.Audit = vmcp.Spec.Config.Audit
}
@@ -142,11 +153,6 @@ func (c *Converter) Convert(
if config.Audit != nil && config.Audit.Component == "" {
config.Audit.Component = vmcp.Name
}
-
- // Apply operational defaults (fills missing values)
- config.EnsureOperationalDefaults()
-
- return config, nil
}
// convertIncomingAuth converts IncomingAuthConfig from CRD to vmcp config.
diff --git a/cmd/vmcp/README.md b/cmd/vmcp/README.md
index 30ac862ca2..10a60bef2a 100644
--- a/cmd/vmcp/README.md
+++ b/cmd/vmcp/README.md
@@ -6,7 +6,7 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC
## Features
-### Implemented (Phase 1)
+### Implemented
- ✅ **Group-Based Backend Management**: Automatic workload discovery from ToolHive groups
- ✅ **Tool Aggregation**: Combines tools from multiple MCP servers with conflict resolution (prefix, priority, manual)
- ✅ **Resource & Prompt Aggregation**: Unified access to resources and prompts from all backends
@@ -15,12 +15,14 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC
- ✅ **Health Endpoints**: `/health` and `/ping` for service monitoring
- ✅ **Configuration Validation**: `vmcp validate` command for config verification
- ✅ **Observability**: OpenTelemetry metrics and traces for backend operations and workflow executions
+- ✅ **Composite Tools**: Multi-step workflows with elicitation support
### In Progress
- 🚧 **Incoming Authentication** (Issue #165): OIDC, local, anonymous authentication
- 🚧 **Outgoing Authentication** (Issue #160): RFC 8693 token exchange for backend API access
- 🚧 **Token Caching**: Memory and Redis cache providers
- 🚧 **Health Monitoring** (Issue #166): Circuit breakers, backend health checks
+- 🚧 **Optimizer**: Support for the MCP optimizer in vMCP for context optimization on large toolsets.
### Future (Phase 2+)
- 📋 **Authorization**: Cedar policy-based access control
diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go
index f9c0aa8a70..317c0e1ad8 100644
--- a/cmd/vmcp/app/commands.go
+++ b/cmd/vmcp/app/commands.go
@@ -27,7 +27,7 @@ import (
"github.com/stacklok/toolhive/pkg/vmcp/discovery"
"github.com/stacklok/toolhive/pkg/vmcp/health"
"github.com/stacklok/toolhive/pkg/vmcp/k8s"
- "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
+ vmcpoptimizer "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router"
vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server"
vmcpstatus "github.com/stacklok/toolhive/pkg/vmcp/status"
@@ -436,9 +436,28 @@ func runServe(cmd *cobra.Command, _ []string) error {
StatusReporter: statusReporter,
}
- if cfg.Optimizer != nil {
- // TODO: update this with the real optimizer.
- serverCfg.OptimizerFactory = optimizer.NewDummyOptimizer
+ // Configure optimizer if enabled in YAML config
+ if cfg.Optimizer != nil && cfg.Optimizer.Enabled {
+ logger.Info("🔬 Optimizer enabled via configuration (chromem-go)")
+ optimizerCfg := vmcpoptimizer.ConfigFromVMCPConfig(cfg.Optimizer)
+ serverCfg.OptimizerConfig = optimizerCfg
+ persistInfo := "in-memory"
+ if cfg.Optimizer.PersistPath != "" {
+ persistInfo = cfg.Optimizer.PersistPath
+ }
+ // FTS5 is always enabled with configurable semantic/BM25 ratio
+ ratio := 70 // Default (70%)
+ if cfg.Optimizer.HybridSearchRatio != nil {
+ ratio = *cfg.Optimizer.HybridSearchRatio
+ }
+ searchMode := fmt.Sprintf("hybrid (%d%% semantic, %d%% BM25)",
+ ratio,
+ 100-ratio)
+ logger.Infof("Optimizer configured: backend=%s, dimension=%d, persistence=%s, search=%s",
+ cfg.Optimizer.EmbeddingBackend,
+ cfg.Optimizer.EmbeddingDimension,
+ persistInfo,
+ searchMode)
}
// Convert composite tool configurations to workflow definitions
diff --git a/codecov.yaml b/codecov.yaml
index 1a8032e484..410f9ae7ee 100644
--- a/codecov.yaml
+++ b/codecov.yaml
@@ -13,6 +13,8 @@ coverage:
- "**/mocks/**/*"
- "**/mock_*.go"
- "**/zz_generated.deepcopy.go"
+ - "**/*_test.go"
+ - "**/*_test_coverage.go"
status:
project:
default:
diff --git a/deploy/charts/operator-crds/Chart.yaml b/deploy/charts/operator-crds/Chart.yaml
index 01865a110d..1b14897d71 100644
--- a/deploy/charts/operator-crds/Chart.yaml
+++ b/deploy/charts/operator-crds/Chart.yaml
@@ -2,5 +2,5 @@ apiVersion: v2
name: toolhive-operator-crds
description: A Helm chart for installing the ToolHive Operator CRDs into Kubernetes.
type: application
-version: 0.0.101
+version: 0.0.102
appVersion: "0.0.1"
diff --git a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml
index 318099bce9..0f153da6a3 100644
--- a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml
+++ b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml
@@ -677,17 +677,74 @@ spec:
optimizer:
description: |-
Optimizer configures the MCP optimizer for context optimization on large toolsets.
- When enabled, vMCP exposes only find_tool and call_tool operations to clients
+ When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients
instead of all backend tools directly. This reduces token usage by allowing
LLMs to discover relevant tools on demand rather than receiving all tool definitions.
properties:
- embeddingService:
+ embeddingBackend:
description: |-
- EmbeddingService is the name of a Kubernetes Service that provides the embedding service
- for semantic tool discovery. The service must implement the optimizer embedding API.
+ EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder".
+ - "ollama": Uses local Ollama HTTP API for embeddings
+ - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.)
+ - "placeholder": Uses deterministic hash-based embeddings (for testing/development)
+ enum:
+ - ollama
+ - openai-compatible
+ - placeholder
+ type: string
+ embeddingDimension:
+ description: |-
+ EmbeddingDimension is the dimension of the embedding vectors.
+ Common values:
+ - 384: all-MiniLM-L6-v2, nomic-embed-text
+ - 768: BAAI/bge-small-en-v1.5
+ - 1536: OpenAI text-embedding-3-small
+ minimum: 1
+ type: integer
+ embeddingModel:
+ description: |-
+ EmbeddingModel is the model name to use for embeddings.
+ Required when EmbeddingBackend is "ollama" or "openai-compatible".
+ Examples:
+ - Ollama: "nomic-embed-text", "all-minilm"
+ - vLLM: "BAAI/bge-small-en-v1.5"
+ - OpenAI: "text-embedding-3-small"
+ type: string
+ embeddingURL:
+ description: |-
+ EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API).
+ Required when EmbeddingBackend is "ollama" or "openai-compatible".
+ Examples:
+ - Ollama: "http://localhost:11434"
+ - vLLM: "http://vllm-service:8000/v1"
+ - OpenAI: "https://api.openai.com/v1"
+ type: string
+ enabled:
+ description: |-
+ Enabled determines whether the optimizer is active.
+ When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools.
+ type: boolean
+ ftsDBPath:
+ description: |-
+ FTSDBPath is the path to the SQLite FTS5 database for BM25 text search.
+ If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set.
+ Hybrid search (semantic + BM25) is always enabled.
+ type: string
+ hybridSearchRatio:
+ description: |-
+ HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search.
+ Value range: 0 (all BM25) to 100 (all semantic), representing percentage.
+ Default: 70 (70% semantic, 30% BM25)
+ Only used when FTSDBPath is set.
+ maximum: 100
+ minimum: 0
+ type: integer
+ persistPath:
+ description: |-
+ PersistPath is the optional filesystem path for persisting the chromem-go database.
+ If empty, the database will be in-memory only (ephemeral).
+ When set, tool metadata and embeddings are persisted to disk for faster restarts.
type: string
- required:
- - embeddingService
type: object
outgoingAuth:
description: |-
diff --git a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml
index 7ebc65a9ab..48a26c7069 100644
--- a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml
+++ b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml
@@ -680,17 +680,74 @@ spec:
optimizer:
description: |-
Optimizer configures the MCP optimizer for context optimization on large toolsets.
- When enabled, vMCP exposes only find_tool and call_tool operations to clients
+ When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients
instead of all backend tools directly. This reduces token usage by allowing
LLMs to discover relevant tools on demand rather than receiving all tool definitions.
properties:
- embeddingService:
+ embeddingBackend:
description: |-
- EmbeddingService is the name of a Kubernetes Service that provides the embedding service
- for semantic tool discovery. The service must implement the optimizer embedding API.
+ EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder".
+ - "ollama": Uses local Ollama HTTP API for embeddings
+ - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.)
+ - "placeholder": Uses deterministic hash-based embeddings (for testing/development)
+ enum:
+ - ollama
+ - openai-compatible
+ - placeholder
+ type: string
+ embeddingDimension:
+ description: |-
+ EmbeddingDimension is the dimension of the embedding vectors.
+ Common values:
+ - 384: all-MiniLM-L6-v2, nomic-embed-text
+ - 768: BAAI/bge-small-en-v1.5
+ - 1536: OpenAI text-embedding-3-small
+ minimum: 1
+ type: integer
+ embeddingModel:
+ description: |-
+ EmbeddingModel is the model name to use for embeddings.
+ Required when EmbeddingBackend is "ollama" or "openai-compatible".
+ Examples:
+ - Ollama: "nomic-embed-text", "all-minilm"
+ - vLLM: "BAAI/bge-small-en-v1.5"
+ - OpenAI: "text-embedding-3-small"
+ type: string
+ embeddingURL:
+ description: |-
+ EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API).
+ Required when EmbeddingBackend is "ollama" or "openai-compatible".
+ Examples:
+ - Ollama: "http://localhost:11434"
+ - vLLM: "http://vllm-service:8000/v1"
+ - OpenAI: "https://api.openai.com/v1"
+ type: string
+ enabled:
+ description: |-
+ Enabled determines whether the optimizer is active.
+ When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools.
+ type: boolean
+ ftsDBPath:
+ description: |-
+ FTSDBPath is the path to the SQLite FTS5 database for BM25 text search.
+ If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set.
+ Hybrid search (semantic + BM25) is always enabled.
+ type: string
+ hybridSearchRatio:
+ description: |-
+ HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search.
+ Value range: 0 (all BM25) to 100 (all semantic), representing percentage.
+ Default: 70 (70% semantic, 30% BM25)
+ Only used when FTSDBPath is set.
+ maximum: 100
+ minimum: 0
+ type: integer
+ persistPath:
+ description: |-
+ PersistPath is the optional filesystem path for persisting the chromem-go database.
+ If empty, the database will be in-memory only (ephemeral).
+ When set, tool metadata and embeddings are persisted to disk for faster restarts.
type: string
- required:
- - embeddingService
type: object
outgoingAuth:
description: |-
diff --git a/docs/operator/crd-api.md b/docs/operator/crd-api.md
index 3d075ce09b..020e665859 100644
--- a/docs/operator/crd-api.md
+++ b/docs/operator/crd-api.md
@@ -245,7 +245,7 @@ _Appears in:_
| `metadata` _object (keys:string, values:string)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | |
| `telemetry` _[pkg.telemetry.Config](#pkgtelemetryconfig)_ | Telemetry configures OpenTelemetry-based observability for the Virtual MCP server
including distributed tracing, OTLP metrics export, and Prometheus metrics endpoint. | | |
| `audit` _[pkg.audit.Config](#pkgauditconfig)_ | Audit configures audit logging for the Virtual MCP server.
When present, audit logs include MCP protocol operations.
See audit.Config for available configuration options. | | |
-| `optimizer` _[vmcp.config.OptimizerConfig](#vmcpconfigoptimizerconfig)_ | Optimizer configures the MCP optimizer for context optimization on large toolsets.
When enabled, vMCP exposes only find_tool and call_tool operations to clients
instead of all backend tools directly. This reduces token usage by allowing
LLMs to discover relevant tools on demand rather than receiving all tool definitions. | | |
+| `optimizer` _[vmcp.config.OptimizerConfig](#vmcpconfigoptimizerconfig)_ | Optimizer configures the MCP optimizer for context optimization on large toolsets.
When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients
instead of all backend tools directly. This reduces token usage by allowing
LLMs to discover relevant tools on demand rather than receiving all tool definitions. | | |
#### vmcp.config.ConflictResolutionConfig
@@ -377,9 +377,9 @@ _Appears in:_
-OptimizerConfig configures the MCP optimizer.
-When enabled, vMCP exposes only find_tool and call_tool operations to clients
-instead of all backend tools directly.
+OptimizerConfig configures the MCP optimizer for semantic tool discovery.
+The optimizer reduces token usage by allowing LLMs to discover relevant tools
+on demand rather than receiving all tool definitions upfront.
@@ -388,7 +388,14 @@ _Appears in:_
| Field | Description | Default | Validation |
| --- | --- | --- | --- |
-| `embeddingService` _string_ | EmbeddingService is the name of a Kubernetes Service that provides the embedding service
for semantic tool discovery. The service must implement the optimizer embedding API. | | Required: \{\}
|
+| `enabled` _boolean_ | Enabled determines whether the optimizer is active.
When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools. | | |
+| `embeddingBackend` _string_ | EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder".
- "ollama": Uses local Ollama HTTP API for embeddings
- "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.)
- "placeholder": Uses deterministic hash-based embeddings (for testing/development) | | Enum: [ollama openai-compatible placeholder]
|
+| `embeddingURL` _string_ | EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API).
Required when EmbeddingBackend is "ollama" or "openai-compatible".
Examples:
- Ollama: "http://localhost:11434"
- vLLM: "http://vllm-service:8000/v1"
- OpenAI: "https://api.openai.com/v1" | | |
+| `embeddingModel` _string_ | EmbeddingModel is the model name to use for embeddings.
Required when EmbeddingBackend is "ollama" or "openai-compatible".
Examples:
- Ollama: "nomic-embed-text", "all-minilm"
- vLLM: "BAAI/bge-small-en-v1.5"
- OpenAI: "text-embedding-3-small" | | |
+| `embeddingDimension` _integer_ | EmbeddingDimension is the dimension of the embedding vectors.
Common values:
- 384: all-MiniLM-L6-v2, nomic-embed-text
- 768: BAAI/bge-small-en-v1.5
- 1536: OpenAI text-embedding-3-small | | Minimum: 1
|
+| `persistPath` _string_ | PersistPath is the optional filesystem path for persisting the chromem-go database.
If empty, the database will be in-memory only (ephemeral).
When set, tool metadata and embeddings are persisted to disk for faster restarts. | | |
+| `ftsDBPath` _string_ | FTSDBPath is the path to the SQLite FTS5 database for BM25 text search.
If empty, defaults to ":memory:" for in-memory FTS5, or "\{PersistPath\}/fts.db" if PersistPath is set.
Hybrid search (semantic + BM25) is always enabled. | | |
+| `hybridSearchRatio` _integer_ | HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search.
Value range: 0 (all BM25) to 100 (all semantic), representing percentage.
Default: 70 (70% semantic, 30% BM25)
Only used when FTSDBPath is set. | | Maximum: 100
Minimum: 0
|
#### vmcp.config.OutgoingAuthConfig
diff --git a/examples/vmcp-config-optimizer.yaml b/examples/vmcp-config-optimizer.yaml
new file mode 100644
index 0000000000..547c60e5f6
--- /dev/null
+++ b/examples/vmcp-config-optimizer.yaml
@@ -0,0 +1,126 @@
+# vMCP Configuration with Optimizer Enabled
+# This configuration enables the optimizer for semantic tool discovery
+
+name: "vmcp-debug"
+
+# Reference to ToolHive group containing MCP servers
+groupRef: "default"
+
+# Client authentication (anonymous for local development)
+incomingAuth:
+ type: anonymous
+
+# Backend authentication (unauthenticated for local development)
+outgoingAuth:
+ source: inline
+ default:
+ type: unauthenticated
+
+# Tool aggregation settings
+aggregation:
+ conflictResolution: prefix
+ conflictResolutionConfig:
+ prefixFormat: "{workload}_"
+
+# Operational settings
+operational:
+ timeouts:
+ default: 30s
+ failureHandling:
+ healthCheckInterval: 30s
+ unhealthyThreshold: 3
+ partialFailureMode: fail
+
+# =============================================================================
+# OPTIMIZER CONFIGURATION
+# =============================================================================
+# When enabled, vMCP exposes optim_find_tool and optim_call_tool instead of
+# all backend tools directly. This reduces token usage by allowing LLMs to
+# discover relevant tools on demand via semantic search.
+#
+# The optimizer ingests tools from all backends in the group, generates
+# embeddings, and provides semantic search capabilities.
+
+optimizer:
+ # Enable the optimizer
+ enabled: true
+
+  # Embedding backend: "ollama" (default), "openai-compatible", or "placeholder"
+  # - "ollama": Uses local Ollama HTTP API for embeddings (default, requires 'ollama serve')
+  # - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.)
+  # - "placeholder": Uses deterministic hash-based embeddings (for testing/development)
+ embeddingBackend: ollama
+
+ # Embedding dimension (common values: 384, 768, 1536)
+ # 384 is standard for all-MiniLM-L6-v2 and nomic-embed-text
+ embeddingDimension: 384
+
+ # Optional: Path for persisting the chromem-go database
+ # If omitted, the database will be in-memory only (ephemeral)
+ persistPath: /tmp/vmcp-optimizer-debug.db
+
+ # Optional: Path for the SQLite FTS5 database (for hybrid search)
+ # Default: ":memory:" (in-memory) or "{persistPath}/fts.db" if persistPath is set
+ # Hybrid search (semantic + BM25) is ALWAYS enabled
+  ftsDBPath: /tmp/vmcp-optimizer-fts.db  # Customize the location as needed
+
+ # Optional: Hybrid search ratio (0-100, representing percentage)
+ # Default: 70 (70% semantic, 30% BM25)
+ # hybridSearchRatio: 70
+
+ # =============================================================================
+ # PRODUCTION CONFIGURATIONS (Commented Examples)
+ # =============================================================================
+
+ # Option 1: Local Ollama (good for development/testing)
+ # embeddingBackend: ollama
+ # embeddingURL: http://localhost:11434
+ # embeddingModel: all-minilm # Default model (all-MiniLM-L6-v2)
+ # embeddingDimension: 384
+
+ # Option 2: vLLM (recommended for production with GPU acceleration)
+ # embeddingBackend: openai-compatible
+ # embeddingURL: http://vllm-service:8000/v1
+ # embeddingModel: BAAI/bge-small-en-v1.5
+ # embeddingDimension: 768
+
+ # Option 3: OpenAI API (cloud-based)
+ # embeddingBackend: openai-compatible
+ # embeddingURL: https://api.openai.com/v1
+ # embeddingModel: text-embedding-3-small
+ # embeddingDimension: 1536
+ # (requires OPENAI_API_KEY environment variable)
+
+ # Option 4: Kubernetes in-cluster service (K8s deployments)
+ # embeddingURL: http://embedding-service-name.namespace.svc.cluster.local:port
+ # Use the full service DNS name with port for in-cluster services
+
+# =============================================================================
+# TELEMETRY CONFIGURATION (for Jaeger tracing)
+# =============================================================================
+# Configure OpenTelemetry to send traces to Jaeger
+telemetry:
+ endpoint: "localhost:4318" # OTLP HTTP endpoint (Jaeger collector) - no http:// prefix needed with insecure: true
+ serviceName: "vmcp-optimizer"
+ serviceVersion: "1.0.0" # Optional: service version
+ tracingEnabled: true
+ metricsEnabled: false # Set to true if you want metrics too
+ samplingRate: "1.0" # 100% sampling for development (use lower in production)
+ insecure: true # Use HTTP instead of HTTPS
+
+# =============================================================================
+# USAGE
+# =============================================================================
+# 1. Start MCP backends in the group:
+# thv run weather --group default
+# thv run github --group default
+#
+# 2. Start vMCP with optimizer:
+# thv vmcp serve --config examples/vmcp-config-optimizer.yaml
+#
+# 3. Connect MCP client to vMCP
+#
+# 4. Available tools from vMCP:
+#    - optim_find_tool: Search for tools by semantic query
+#    - optim_call_tool: Execute a tool by name
+# - (backend tools are NOT directly exposed when optimizer is enabled)
diff --git a/go.mod b/go.mod
index 0dca0108db..e041a8ca0a 100644
--- a/go.mod
+++ b/go.mod
@@ -29,6 +29,7 @@ require (
github.com/onsi/ginkgo/v2 v2.27.5
github.com/onsi/gomega v1.39.0
github.com/ory/fosite v0.49.0
+ github.com/philippgille/chromem-go v0.7.0
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/prometheus/client_golang v1.23.2
github.com/sigstore/protobuf-specs v0.5.0
@@ -59,6 +60,7 @@ require (
k8s.io/api v0.35.0
k8s.io/apimachinery v0.35.0
k8s.io/utils v0.0.0-20260108192941-914a6e750570
+ modernc.org/sqlite v1.44.0
sigs.k8s.io/controller-runtime v0.22.4
sigs.k8s.io/yaml v1.6.0
)
@@ -174,6 +176,7 @@ require (
github.com/muesli/termenv v0.16.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect
+ github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect
github.com/olekukonko/errors v1.1.0 // indirect
@@ -188,6 +191,7 @@ require (
github.com/prometheus/common v0.67.4 // indirect
github.com/prometheus/otlptranslator v1.0.0 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
+ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
@@ -251,6 +255,9 @@ require (
k8s.io/apiextensions-apiserver v0.34.1 // indirect
k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect
+ modernc.org/libc v1.67.4 // indirect
+ modernc.org/mathutil v1.7.1 // indirect
+ modernc.org/memory v1.11.0 // indirect
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect
sigs.k8s.io/randfill v1.0.0 // indirect
sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect
@@ -286,7 +293,7 @@ require (
go.opentelemetry.io/otel/metric v1.39.0
go.opentelemetry.io/otel/trace v1.39.0
golang.org/x/crypto v0.47.0
- golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect
+ golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/sys v0.40.0
k8s.io/client-go v0.35.0
)
diff --git a/go.sum b/go.sum
index 2b66579e95..78126e3d7a 100644
--- a/go.sum
+++ b/go.sum
@@ -602,6 +602,8 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A=
github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM=
+github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
+github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/nyaruka/phonenumbers v1.1.6 h1:DcueYq7QrOArAprAYNoQfDgp0KetO4LqtnBtQC6Wyes=
github.com/nyaruka/phonenumbers v1.1.6/go.mod h1:yShPJHDSH3aTKzCbXyVxNpbl2kA+F+Ne5Pun/MvFRos=
github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4=
@@ -640,6 +642,8 @@ github.com/ory/x v0.0.665 h1:61vv0ObCDSX1vOQYbxBeqDiv4YiPmMT91lYxDaaKX08=
github.com/ory/x v0.0.665/go.mod h1:7SCTki3N0De3ZpqlxhxU/94ZrOCfNEnXwVtd0xVt+L8=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
+github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY=
+github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo=
github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4=
github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
@@ -661,6 +665,8 @@ github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEo
github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM=
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
@@ -909,8 +915,8 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
-golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
-golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
+golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
+golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/exp/event v0.0.0-20251219203646-944ab1f22d93 h1:Fee8ke0jLfLhU4ywDLs7IYmhJ8MrSP0iZE3p39EKKSc=
golang.org/x/exp/event v0.0.0-20251219203646-944ab1f22d93/go.mod h1:HgAgrKXB9WF2wFZJBGBnRVkmsC8n+v2ja/8VR0H3QkY=
golang.org/x/exp/jsonrpc2 v0.0.0-20260112195511-716be5621a96 h1:cN9X2vSBmT3Ruw2UlbJNLJh0iBqTmtSB0dRfh5aumiY=
@@ -1086,6 +1092,34 @@ k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 h1:Y3gxNAuB0OBLImH611+UDZ
k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912/go.mod h1:kdmbQkyfwUagLfXIad1y2TdrjPFWp2Q89B3qkRwf/pQ=
k8s.io/utils v0.0.0-20260108192941-914a6e750570 h1:JT4W8lsdrGENg9W+YwwdLJxklIuKWdRm+BC+xt33FOY=
k8s.io/utils v0.0.0-20260108192941-914a6e750570/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
+modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
+modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
+modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
+modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
+modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
+modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
+modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
+modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
+modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
+modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
+modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
+modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
+modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg=
+modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
+modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
+modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
+modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
+modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
+modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
+modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
+modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
+modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
+modernc.org/sqlite v1.44.0 h1:YjCKJnzZde2mLVy0cMKTSL4PxCmbIguOq9lGp8ZvGOc=
+modernc.org/sqlite v1.44.0/go.mod h1:2Dq41ir5/qri7QJJJKNZcP4UF7TsX/KNeykYgPDtGhE=
+modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
+modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
+modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
+modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
sigs.k8s.io/controller-runtime v0.22.4 h1:GEjV7KV3TY8e+tJ2LCTxUTanW4z/FmNB7l327UfMq9A=
sigs.k8s.io/controller-runtime v0.22.4/go.mod h1:+QX1XUpTXN4mLoblf4tqr5CQcyHPAki2HLXqQMY6vh8=
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg=
diff --git a/pkg/runner/config_builder_test.go b/pkg/runner/config_builder_test.go
index 735c9ccc45..0e4556937d 100644
--- a/pkg/runner/config_builder_test.go
+++ b/pkg/runner/config_builder_test.go
@@ -1076,8 +1076,8 @@ func TestRunConfigBuilder_WithRegistryProxyPort(t *testing.T) {
ProxyPort: 8976,
TargetPort: 8976,
},
- cliProxyPort: 9000,
- expectedProxyPort: 9000,
+ cliProxyPort: 9999,
+ expectedProxyPort: 9999,
},
{
name: "random port when neither CLI nor registry specified",
diff --git a/pkg/vmcp/aggregator/default_aggregator.go b/pkg/vmcp/aggregator/default_aggregator.go
index 95be734af0..3cf2846fcc 100644
--- a/pkg/vmcp/aggregator/default_aggregator.go
+++ b/pkg/vmcp/aggregator/default_aggregator.go
@@ -8,6 +8,10 @@ import (
"fmt"
"sync"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/codes"
+ "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"github.com/stacklok/toolhive/pkg/logger"
@@ -21,6 +25,7 @@ type defaultAggregator struct {
backendClient vmcp.BackendClient
conflictResolver ConflictResolver
toolConfigMap map[string]*config.WorkloadToolConfig // Maps backend ID to tool config
+ tracer trace.Tracer
}
// NewDefaultAggregator creates a new default aggregator implementation.
@@ -43,12 +48,20 @@ func NewDefaultAggregator(
backendClient: backendClient,
conflictResolver: conflictResolver,
toolConfigMap: toolConfigMap,
+ tracer: otel.Tracer("github.com/stacklok/toolhive/pkg/vmcp/aggregator"),
}
}
// QueryCapabilities queries a single backend for its MCP capabilities.
// Returns the raw capabilities (tools, resources, prompts) from the backend.
func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.Backend) (*BackendCapabilities, error) {
+ ctx, span := a.tracer.Start(ctx, "aggregator.QueryCapabilities",
+ trace.WithAttributes(
+ attribute.String("backend.id", backend.ID),
+ ),
+ )
+ defer span.End()
+
logger.Debugf("Querying capabilities from backend %s", backend.ID)
// Create a BackendTarget from the Backend
@@ -58,6 +71,8 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.
// Query capabilities using the backend client
capabilities, err := a.backendClient.ListCapabilities(ctx, target)
if err != nil {
+ span.RecordError(err)
+ span.SetStatus(codes.Error, err.Error())
return nil, fmt.Errorf("%w: %s: %w", ErrBackendQueryFailed, backend.ID, err)
}
@@ -74,6 +89,12 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.
SupportsSampling: capabilities.SupportsSampling,
}
+ span.SetAttributes(
+ attribute.Int("tools.count", len(result.Tools)),
+ attribute.Int("resources.count", len(result.Resources)),
+ attribute.Int("prompts.count", len(result.Prompts)),
+ )
+
logger.Debugf("Backend %s: %d tools (after filtering/overrides), %d resources, %d prompts",
backend.ID, len(result.Tools), len(result.Resources), len(result.Prompts))
@@ -86,6 +107,13 @@ func (a *defaultAggregator) QueryAllCapabilities(
ctx context.Context,
backends []vmcp.Backend,
) (map[string]*BackendCapabilities, error) {
+ ctx, span := a.tracer.Start(ctx, "aggregator.QueryAllCapabilities",
+ trace.WithAttributes(
+ attribute.Int("backends.count", len(backends)),
+ ),
+ )
+ defer span.End()
+
logger.Infof("Querying capabilities from %d backends", len(backends))
// Use errgroup for parallel queries with context cancellation
@@ -118,13 +146,22 @@ func (a *defaultAggregator) QueryAllCapabilities(
// Wait for all queries to complete
if err := g.Wait(); err != nil {
+ span.RecordError(err)
+ span.SetStatus(codes.Error, err.Error())
return nil, fmt.Errorf("capability queries failed: %w", err)
}
if len(capabilities) == 0 {
- return nil, fmt.Errorf("no backends returned capabilities")
+ err := fmt.Errorf("no backends returned capabilities")
+ span.RecordError(err)
+ span.SetStatus(codes.Error, err.Error())
+ return nil, err
}
+ span.SetAttributes(
+ attribute.Int("successful.backends", len(capabilities)),
+ )
+
logger.Infof("Successfully queried %d/%d backends", len(capabilities), len(backends))
return capabilities, nil
}
@@ -135,6 +172,13 @@ func (a *defaultAggregator) ResolveConflicts(
ctx context.Context,
capabilities map[string]*BackendCapabilities,
) (*ResolvedCapabilities, error) {
+ ctx, span := a.tracer.Start(ctx, "aggregator.ResolveConflicts",
+ trace.WithAttributes(
+ attribute.Int("backends.count", len(capabilities)),
+ ),
+ )
+ defer span.End()
+
logger.Debugf("Resolving conflicts across %d backends", len(capabilities))
// Group tools by backend for conflict resolution
@@ -150,6 +194,8 @@ func (a *defaultAggregator) ResolveConflicts(
if a.conflictResolver != nil {
resolvedTools, err = a.conflictResolver.ResolveToolConflicts(ctx, toolsByBackend)
if err != nil {
+ span.RecordError(err)
+ span.SetStatus(codes.Error, err.Error())
return nil, fmt.Errorf("conflict resolution failed: %w", err)
}
} else {
@@ -191,6 +237,12 @@ func (a *defaultAggregator) ResolveConflicts(
resolved.SupportsSampling = resolved.SupportsSampling || caps.SupportsSampling
}
+ span.SetAttributes(
+ attribute.Int("resolved.tools", len(resolved.Tools)),
+ attribute.Int("resolved.resources", len(resolved.Resources)),
+ attribute.Int("resolved.prompts", len(resolved.Prompts)),
+ )
+
logger.Debugf("Resolved %d unique tools, %d resources, %d prompts",
len(resolved.Tools), len(resolved.Resources), len(resolved.Prompts))
@@ -199,11 +251,20 @@ func (a *defaultAggregator) ResolveConflicts(
// MergeCapabilities creates the final unified capability view and routing table.
// Uses the backend registry to populate full BackendTarget information for routing.
-func (*defaultAggregator) MergeCapabilities(
+func (a *defaultAggregator) MergeCapabilities(
ctx context.Context,
resolved *ResolvedCapabilities,
registry vmcp.BackendRegistry,
) (*AggregatedCapabilities, error) {
+ ctx, span := a.tracer.Start(ctx, "aggregator.MergeCapabilities",
+ trace.WithAttributes(
+ attribute.Int("resolved.tools", len(resolved.Tools)),
+ attribute.Int("resolved.resources", len(resolved.Resources)),
+ attribute.Int("resolved.prompts", len(resolved.Prompts)),
+ ),
+ )
+ defer span.End()
+
logger.Debugf("Merging capabilities into final view")
// Create routing table
@@ -304,6 +365,13 @@ func (*defaultAggregator) MergeCapabilities(
},
}
+ span.SetAttributes(
+ attribute.Int("aggregated.tools", aggregated.Metadata.ToolCount),
+ attribute.Int("aggregated.resources", aggregated.Metadata.ResourceCount),
+ attribute.Int("aggregated.prompts", aggregated.Metadata.PromptCount),
+ attribute.String("conflict.strategy", string(aggregated.Metadata.ConflictStrategy)),
+ )
+
logger.Infof("Merged capabilities: %d tools, %d resources, %d prompts",
aggregated.Metadata.ToolCount, aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount)
@@ -316,6 +384,13 @@ func (*defaultAggregator) MergeCapabilities(
// 3. Resolve conflicts
// 4. Merge into final view with full backend information
func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends []vmcp.Backend) (*AggregatedCapabilities, error) {
+ ctx, span := a.tracer.Start(ctx, "aggregator.AggregateCapabilities",
+ trace.WithAttributes(
+ attribute.Int("backends.count", len(backends)),
+ ),
+ )
+ defer span.End()
+
logger.Infof("Starting capability aggregation for %d backends", len(backends))
// Step 1: Create registry from discovered backends
@@ -325,24 +400,38 @@ func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends
// Step 2: Query all backends
capabilities, err := a.QueryAllCapabilities(ctx, backends)
if err != nil {
+ span.RecordError(err)
+ span.SetStatus(codes.Error, err.Error())
return nil, fmt.Errorf("failed to query backends: %w", err)
}
// Step 3: Resolve conflicts
resolved, err := a.ResolveConflicts(ctx, capabilities)
if err != nil {
+ span.RecordError(err)
+ span.SetStatus(codes.Error, err.Error())
return nil, fmt.Errorf("failed to resolve conflicts: %w", err)
}
// Step 4: Merge into final view with full backend information
aggregated, err := a.MergeCapabilities(ctx, resolved, registry)
if err != nil {
+ span.RecordError(err)
+ span.SetStatus(codes.Error, err.Error())
return nil, fmt.Errorf("failed to merge capabilities: %w", err)
}
// Update metadata with backend count
aggregated.Metadata.BackendCount = len(backends)
+ span.SetAttributes(
+ attribute.Int("aggregated.backends", aggregated.Metadata.BackendCount),
+ attribute.Int("aggregated.tools", aggregated.Metadata.ToolCount),
+ attribute.Int("aggregated.resources", aggregated.Metadata.ResourceCount),
+ attribute.Int("aggregated.prompts", aggregated.Metadata.PromptCount),
+ attribute.String("conflict.strategy", string(aggregated.Metadata.ConflictStrategy)),
+ )
+
logger.Infof("Capability aggregation complete: %d backends, %d tools, %d resources, %d prompts",
aggregated.Metadata.BackendCount, aggregated.Metadata.ToolCount,
aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount)
diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go
index 9e53ff994e..0634376de6 100644
--- a/pkg/vmcp/client/client.go
+++ b/pkg/vmcp/client/client.go
@@ -15,6 +15,7 @@ import (
"io"
"net"
"net/http"
+ "time"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
@@ -126,8 +127,6 @@ func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
return nil, fmt.Errorf("authentication failed for backend %s: %w", a.target.WorkloadID, err)
}
- logger.Debugf("Applied authentication strategy %q to backend %s", a.authStrategy.Name(), a.target.WorkloadID)
-
return a.base.RoundTrip(reqClone)
}
@@ -203,8 +202,10 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm
})
// Create HTTP client with configured transport chain
+ // Set timeouts to prevent long-lived connections that require continuous listening
httpClient := &http.Client{
Transport: sizeLimitedTransport,
+ Timeout: 30 * time.Second, // Prevent hanging on connections
}
var c *client.Client
@@ -213,8 +214,7 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm
case "streamable-http", "streamable":
c, err = client.NewStreamableHttpClient(
target.BaseURL,
- transport.WithHTTPTimeout(0),
- transport.WithContinuousListening(),
+ transport.WithHTTPTimeout(30*time.Second), // Set timeout instead of 0
transport.WithHTTPBasicClient(httpClient),
)
if err != nil {
diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go
index aa9583cce0..f477c01232 100644
--- a/pkg/vmcp/config/config.go
+++ b/pkg/vmcp/config/config.go
@@ -151,7 +151,7 @@ type Config struct {
Audit *audit.Config `json:"audit,omitempty" yaml:"audit,omitempty"`
// Optimizer configures the MCP optimizer for context optimization on large toolsets.
- // When enabled, vMCP exposes only find_tool and call_tool operations to clients
+ // When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients
// instead of all backend tools directly. This reduces token usage by allowing
// LLMs to discover relevant tools on demand rather than receiving all tool definitions.
// +optional
@@ -696,16 +696,72 @@ type OutputProperty struct {
Default thvjson.Any `json:"default,omitempty" yaml:"default,omitempty"`
}
-// OptimizerConfig configures the MCP optimizer.
-// When enabled, vMCP exposes only find_tool and call_tool operations to clients
-// instead of all backend tools directly.
+// OptimizerConfig configures the MCP optimizer for semantic tool discovery.
+// The optimizer reduces token usage by allowing LLMs to discover relevant tools
+// on demand rather than receiving all tool definitions upfront.
// +kubebuilder:object:generate=true
// +gendoc
type OptimizerConfig struct {
- // EmbeddingService is the name of a Kubernetes Service that provides the embedding service
- // for semantic tool discovery. The service must implement the optimizer embedding API.
- // +kubebuilder:validation:Required
- EmbeddingService string `json:"embeddingService" yaml:"embeddingService"`
+ // Enabled determines whether the optimizer is active.
+ // When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools.
+ // +optional
+ Enabled bool `json:"enabled" yaml:"enabled"`
+
+ // EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder".
+ // - "ollama": Uses local Ollama HTTP API for embeddings
+ // - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.)
+ // - "placeholder": Uses deterministic hash-based embeddings (for testing/development)
+ // +kubebuilder:validation:Enum=ollama;openai-compatible;placeholder
+ // +optional
+ EmbeddingBackend string `json:"embeddingBackend,omitempty" yaml:"embeddingBackend,omitempty"`
+
+ // EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API).
+ // Required when EmbeddingBackend is "ollama" or "openai-compatible".
+ // Examples:
+ // - Ollama: "http://localhost:11434"
+ // - vLLM: "http://vllm-service:8000/v1"
+ // - OpenAI: "https://api.openai.com/v1"
+ // +optional
+ EmbeddingURL string `json:"embeddingURL,omitempty" yaml:"embeddingURL,omitempty"`
+
+ // EmbeddingModel is the model name to use for embeddings.
+ // Required when EmbeddingBackend is "ollama" or "openai-compatible".
+ // Examples:
+ // - Ollama: "nomic-embed-text", "all-minilm"
+ // - vLLM: "BAAI/bge-small-en-v1.5"
+ // - OpenAI: "text-embedding-3-small"
+ // +optional
+ EmbeddingModel string `json:"embeddingModel,omitempty" yaml:"embeddingModel,omitempty"`
+
+ // EmbeddingDimension is the dimension of the embedding vectors.
+ // Common values:
+ // - 384: all-MiniLM-L6-v2, nomic-embed-text
+ // - 768: BAAI/bge-small-en-v1.5
+ // - 1536: OpenAI text-embedding-3-small
+ // +kubebuilder:validation:Minimum=1
+ // +optional
+ EmbeddingDimension int `json:"embeddingDimension,omitempty" yaml:"embeddingDimension,omitempty"`
+
+ // PersistPath is the optional filesystem path for persisting the chromem-go database.
+ // If empty, the database will be in-memory only (ephemeral).
+ // When set, tool metadata and embeddings are persisted to disk for faster restarts.
+ // +optional
+ PersistPath string `json:"persistPath,omitempty" yaml:"persistPath,omitempty"`
+
+ // FTSDBPath is the path to the SQLite FTS5 database for BM25 text search.
+ // If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set.
+ // Hybrid search (semantic + BM25) is always enabled.
+ // +optional
+ FTSDBPath string `json:"ftsDBPath,omitempty" yaml:"ftsDBPath,omitempty"`
+
+ // HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search.
+ // Value range: 0 (all BM25) to 100 (all semantic), representing percentage.
+ // Default: 70 (70% semantic, 30% BM25)
+ // Applies whenever hybrid search runs (FTSDBPath defaults to in-memory when unset).
+ // +optional
+ // +kubebuilder:validation:Minimum=0
+ // +kubebuilder:validation:Maximum=100
+ HybridSearchRatio *int `json:"hybridSearchRatio,omitempty" yaml:"hybridSearchRatio,omitempty"`
}
// Validator validates configuration.
diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go
index 0845118ee1..6dfa659512 100644
--- a/pkg/vmcp/discovery/manager.go
+++ b/pkg/vmcp/discovery/manager.go
@@ -18,6 +18,8 @@ import (
"sync"
"time"
+ "golang.org/x/sync/singleflight"
+
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/vmcp"
@@ -68,6 +70,9 @@ type DefaultManager struct {
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
+ // singleFlight ensures only one aggregation happens per cache key at a time
+ // This prevents concurrent requests from all triggering aggregation
+ singleFlight singleflight.Group
}
// NewManager creates a new discovery manager with the given aggregator.
@@ -131,6 +136,9 @@ func NewManagerWithRegistry(agg aggregator.Aggregator, registry vmcp.DynamicRegi
//
// The context must contain an authenticated user identity (set by auth middleware).
// Returns ErrNoIdentity if user identity is not found in context.
+//
+// This method uses singleflight to ensure that concurrent requests for the same
+// cache key only trigger one aggregation, preventing duplicate work.
func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) {
// Validate user identity is present (set by auth middleware)
// This ensures discovery happens with proper user authentication context
@@ -142,7 +150,7 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend)
// Generate cache key from user identity and backend set
cacheKey := m.generateCacheKey(identity.Subject, backends)
- // Check cache first
+ // Check cache first (with read lock)
if caps := m.getCachedCapabilities(cacheKey); caps != nil {
logger.Debugf("Cache hit for user %s (key: %s)", identity.Subject, cacheKey)
return caps, nil
@@ -150,16 +158,33 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend)
logger.Debugf("Cache miss - performing capability discovery for user: %s", identity.Subject)
- // Cache miss - perform aggregation
- caps, err := m.aggregator.AggregateCapabilities(ctx, backends)
+ // Use singleflight to ensure only one aggregation happens per cache key
+ // Even if multiple requests come in concurrently, they'll all wait for the same result
+ result, err, _ := m.singleFlight.Do(cacheKey, func() (interface{}, error) {
+ // Double-check cache after acquiring singleflight lock
+ // Another goroutine might have populated it while we were waiting
+ if caps := m.getCachedCapabilities(cacheKey); caps != nil {
+ logger.Debugf("Cache populated while waiting - returning cached result for user %s", identity.Subject)
+ return caps, nil
+ }
+
+ // Perform aggregation
+ caps, err := m.aggregator.AggregateCapabilities(ctx, backends)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err)
+ }
+
+ // Cache the result (skips caching if at capacity and key doesn't exist)
+ m.cacheCapabilities(cacheKey, caps)
+
+ return caps, nil
+ })
+
if err != nil {
- return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err)
+ return nil, err
}
- // Cache the result (skips caching if at capacity and key doesn't exist)
- m.cacheCapabilities(cacheKey, caps)
-
- return caps, nil
+ return result.(*aggregator.AggregatedCapabilities), nil
}
// Stop gracefully stops the manager and cleans up resources.
diff --git a/pkg/vmcp/discovery/manager_test_coverage.go b/pkg/vmcp/discovery/manager_test_coverage.go
new file mode 100644
index 0000000000..3826fc2849
--- /dev/null
+++ b/pkg/vmcp/discovery/manager_test_coverage.go
@@ -0,0 +1,176 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package discovery
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "go.uber.org/mock/gomock"
+
+ "github.com/stacklok/toolhive/pkg/auth"
+ "github.com/stacklok/toolhive/pkg/vmcp"
+ "github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+ aggmocks "github.com/stacklok/toolhive/pkg/vmcp/aggregator/mocks"
+ vmcpmocks "github.com/stacklok/toolhive/pkg/vmcp/mocks"
+)
+
+// TestDefaultManager_CacheVersionMismatch tests cache invalidation on version mismatch
+func TestDefaultManager_CacheVersionMismatch(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockAggregator := aggmocks.NewMockAggregator(ctrl)
+ mockRegistry := vmcpmocks.NewMockDynamicRegistry(ctrl)
+
+ // First call - version 1
+ mockRegistry.EXPECT().Version().Return(uint64(1)).Times(2)
+ mockAggregator.EXPECT().
+ AggregateCapabilities(gomock.Any(), gomock.Any()).
+ Return(&aggregator.AggregatedCapabilities{}, nil).
+ Times(1)
+
+ manager, err := NewManagerWithRegistry(mockAggregator, mockRegistry)
+ require.NoError(t, err)
+ defer manager.Stop()
+
+ ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{
+ Subject: "user-1",
+ })
+
+ backends := []vmcp.Backend{
+ {ID: "backend-1", Name: "Backend 1"},
+ }
+
+ // First discovery - should cache
+ caps1, err := manager.Discover(ctx, backends)
+ require.NoError(t, err)
+ require.NotNil(t, caps1)
+
+ // Second discovery with same version - should use cache
+ mockRegistry.EXPECT().Version().Return(uint64(1)).Times(1)
+ caps2, err := manager.Discover(ctx, backends)
+ require.NoError(t, err)
+ require.NotNil(t, caps2)
+
+ // Third discovery with different version - should invalidate cache
+ mockRegistry.EXPECT().Version().Return(uint64(2)).Times(2)
+ mockAggregator.EXPECT().
+ AggregateCapabilities(gomock.Any(), gomock.Any()).
+ Return(&aggregator.AggregatedCapabilities{}, nil).
+ Times(1)
+
+ caps3, err := manager.Discover(ctx, backends)
+ require.NoError(t, err)
+ require.NotNil(t, caps3)
+}
+
+// TestDefaultManager_CacheAtCapacity tests cache eviction when at capacity
+func TestDefaultManager_CacheAtCapacity(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockAggregator := aggmocks.NewMockAggregator(ctrl)
+
+ // Create many different cache keys to fill cache
+ mockAggregator.EXPECT().
+ AggregateCapabilities(gomock.Any(), gomock.Any()).
+ Return(&aggregator.AggregatedCapabilities{}, nil).
+ Times(maxCacheSize + 1) // One more than capacity
+
+ manager, err := NewManager(mockAggregator)
+ require.NoError(t, err)
+ defer manager.Stop()
+
+ // Fill cache to capacity
+ for i := 0; i < maxCacheSize; i++ {
+ ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{
+ Subject: "user-" + string(rune(i)),
+ })
+
+ backends := []vmcp.Backend{
+ {ID: "backend-" + string(rune(i)), Name: "Backend"},
+ }
+
+ _, err := manager.Discover(ctx, backends)
+ require.NoError(t, err)
+ }
+
+ // Next discovery should not cache (at capacity)
+ ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{
+ Subject: "user-new",
+ })
+
+ backends := []vmcp.Backend{
+ {ID: "backend-new", Name: "Backend"},
+ }
+
+ _, err = manager.Discover(ctx, backends)
+ require.NoError(t, err)
+}
+
+// TestDefaultManager_CacheAtCapacity_ExistingKey tests cache update when at capacity but key exists
+func TestDefaultManager_CacheAtCapacity_ExistingKey(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockAggregator := aggmocks.NewMockAggregator(ctrl)
+
+ // First call
+ mockAggregator.EXPECT().
+ AggregateCapabilities(gomock.Any(), gomock.Any()).
+ Return(&aggregator.AggregatedCapabilities{}, nil).
+ Times(1)
+
+ manager, err := NewManager(mockAggregator)
+ require.NoError(t, err)
+ defer manager.Stop()
+
+ ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{
+ Subject: "user-1",
+ })
+
+ backends := []vmcp.Backend{
+ {ID: "backend-1", Name: "Backend 1"},
+ }
+
+ // First discovery
+ _, err = manager.Discover(ctx, backends)
+ require.NoError(t, err)
+
+ // Fill cache to capacity with other keys
+ for i := 0; i < maxCacheSize-1; i++ {
+ ctxOther := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{
+ Subject: "user-" + string(rune(i+2)),
+ })
+
+ backendsOther := []vmcp.Backend{
+ {ID: "backend-" + string(rune(i+2)), Name: "Backend"},
+ }
+
+ mockAggregator.EXPECT().
+ AggregateCapabilities(gomock.Any(), gomock.Any()).
+ Return(&aggregator.AggregatedCapabilities{}, nil).
+ Times(1)
+
+ _, err := manager.Discover(ctxOther, backendsOther)
+ require.NoError(t, err)
+ }
+
+ // Update existing key should work even at capacity
+ mockAggregator.EXPECT().
+ AggregateCapabilities(gomock.Any(), gomock.Any()).
+ Return(&aggregator.AggregatedCapabilities{}, nil).
+ Times(1)
+
+ _, err = manager.Discover(ctx, backends)
+ require.NoError(t, err)
+}
diff --git a/pkg/vmcp/discovery/middleware_test.go b/pkg/vmcp/discovery/middleware_test.go
index d1b36a870c..3c8cd8e9ca 100644
--- a/pkg/vmcp/discovery/middleware_test.go
+++ b/pkg/vmcp/discovery/middleware_test.go
@@ -348,8 +348,19 @@ func TestMiddleware_CapabilitiesInContext(t *testing.T) {
},
}
+ // Use Do to capture and verify backends separately, since order may vary
mockMgr.EXPECT().
- Discover(gomock.Any(), unorderedBackendsMatcher{backends}).
+ Discover(gomock.Any(), gomock.Any()).
+ Do(func(_ context.Context, actualBackends []vmcp.Backend) {
+ // Verify that we got the expected backends regardless of order
+ assert.Len(t, actualBackends, 2)
+ backendIDs := make(map[string]bool)
+ for _, b := range actualBackends {
+ backendIDs[b.ID] = true
+ }
+ assert.True(t, backendIDs["backend1"], "backend1 should be present")
+ assert.True(t, backendIDs["backend2"], "backend2 should be present")
+ }).
Return(expectedCaps, nil)
// Create handler that inspects context in detail
diff --git a/pkg/vmcp/health/checker.go b/pkg/vmcp/health/checker.go
index d593aa0401..bf6f5c329c 100644
--- a/pkg/vmcp/health/checker.go
+++ b/pkg/vmcp/health/checker.go
@@ -11,6 +11,8 @@ import (
"context"
"errors"
"fmt"
+ "net/url"
+ "strings"
"time"
"github.com/stacklok/toolhive/pkg/logger"
@@ -29,6 +31,10 @@ type healthChecker struct {
// If a health check succeeds but takes longer than this duration, the backend is marked degraded.
// Zero means disabled (backends will never be marked degraded based on response time alone).
degradedThreshold time.Duration
+
+ // selfURL is the server's own URL. If a health check targets this URL, it's short-circuited.
+ // This prevents the server from trying to health check itself.
+ selfURL string
}
// NewHealthChecker creates a new health checker that uses BackendClient.ListCapabilities
@@ -39,13 +45,20 @@ type healthChecker struct {
// - client: BackendClient for communicating with backend MCP servers
// - timeout: Maximum duration for health check operations (0 = no timeout)
// - degradedThreshold: Response time threshold for marking backend as degraded (0 = disabled)
+// - selfURL: Optional server's own URL. If provided, health checks targeting this URL are short-circuited.
//
// Returns a new HealthChecker implementation.
-func NewHealthChecker(client vmcp.BackendClient, timeout time.Duration, degradedThreshold time.Duration) vmcp.HealthChecker {
+func NewHealthChecker(
+ client vmcp.BackendClient,
+ timeout time.Duration,
+ degradedThreshold time.Duration,
+ selfURL string,
+) vmcp.HealthChecker {
return &healthChecker{
client: client,
timeout: timeout,
degradedThreshold: degradedThreshold,
+ selfURL: selfURL,
}
}
@@ -62,16 +75,28 @@ func NewHealthChecker(client vmcp.BackendClient, timeout time.Duration, degraded
// The error return is informational and provides context about what failed.
// The BackendHealthStatus return indicates the categorized health state.
func (h *healthChecker) CheckHealth(ctx context.Context, target *vmcp.BackendTarget) (vmcp.BackendHealthStatus, error) {
- // Apply timeout if configured
- checkCtx := ctx
+ // Mark context as health check to bypass authentication logging
+ // Health checks verify backend availability and should not require user credentials
+ healthCheckCtx := WithHealthCheckMarker(ctx)
+
+ // Apply timeout if configured (after adding health check marker)
+ checkCtx := healthCheckCtx
var cancel context.CancelFunc
if h.timeout > 0 {
- checkCtx, cancel = context.WithTimeout(ctx, h.timeout)
+ checkCtx, cancel = context.WithTimeout(healthCheckCtx, h.timeout)
defer cancel()
}
logger.Debugf("Performing health check for backend %s (%s)", target.WorkloadName, target.BaseURL)
+ // Short-circuit health check if targeting ourselves
+ // This prevents the server from trying to health check itself, which would work
+ // but is wasteful and can cause connection issues during startup
+ if h.selfURL != "" && h.isSelfCheck(target.BaseURL) {
+ logger.Debugf("Skipping health check for backend %s - this is the server itself", target.WorkloadName)
+ return vmcp.BackendHealthy, nil
+ }
+
// Track response time for degraded detection
startTime := time.Now()
@@ -137,3 +162,62 @@ func categorizeError(err error) vmcp.BackendHealthStatus {
// Default to unhealthy for unknown errors
return vmcp.BackendUnhealthy
}
+
+// isSelfCheck checks if a backend URL matches the server's own URL.
+// URLs are normalized before comparison to handle variations like:
+// - http://127.0.0.1:PORT vs http://localhost:PORT
+// - http://HOST:PORT vs http://HOST:PORT/
+func (h *healthChecker) isSelfCheck(backendURL string) bool {
+ if h.selfURL == "" || backendURL == "" {
+ return false
+ }
+
+ // Normalize both URLs for comparison
+ backendNormalized, err := NormalizeURLForComparison(backendURL)
+ if err != nil {
+ return false
+ }
+
+ selfNormalized, err := NormalizeURLForComparison(h.selfURL)
+ if err != nil {
+ return false
+ }
+
+ return backendNormalized == selfNormalized
+}
+
+// NormalizeURLForComparison normalizes a URL for comparison by:
+// - Parsing and reconstructing the URL
+// - Converting localhost/127.0.0.1 to a canonical form
+// - Comparing only scheme://host:port (ignoring path, query, fragment)
+// - Lowercasing scheme and host
+// Exported for testing purposes. NOTE(review): IPv6 literals are not bracket-safe here
+func NormalizeURLForComparison(rawURL string) (string, error) {
+ u, err := url.Parse(rawURL)
+ if err != nil {
+ return "", err
+ }
+ // Validate that we have a scheme and host (basic URL validation)
+ if u.Scheme == "" || u.Host == "" {
+		return "", errors.New("invalid URL: missing scheme or host")
+ }
+
+ // Normalize host: convert localhost to 127.0.0.1 for consistency
+ host := strings.ToLower(u.Hostname())
+ if host == "localhost" {
+ host = "127.0.0.1"
+ }
+
+ // Reconstruct URL with normalized components (scheme://host:port only)
+ // We ignore path, query, and fragment for comparison
+ normalized := &url.URL{
+ Scheme: strings.ToLower(u.Scheme),
+ }
+ if u.Port() != "" {
+ normalized.Host = host + ":" + u.Port()
+ } else {
+ normalized.Host = host
+ }
+
+ return normalized.String(), nil
+}
diff --git a/pkg/vmcp/health/checker_selfcheck_test.go b/pkg/vmcp/health/checker_selfcheck_test.go
new file mode 100644
index 0000000000..ff963d8d35
--- /dev/null
+++ b/pkg/vmcp/health/checker_selfcheck_test.go
@@ -0,0 +1,504 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package health
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/mock/gomock"
+
+ "github.com/stacklok/toolhive/pkg/vmcp"
+ "github.com/stacklok/toolhive/pkg/vmcp/mocks"
+)
+
+// TestHealthChecker_CheckHealth_SelfCheck tests self-check detection
+func TestHealthChecker_CheckHealth_SelfCheck(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+ // Should not call ListCapabilities for self-check
+ mockClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Times(0)
+
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080")
+ target := &vmcp.BackendTarget{
+ WorkloadID: "backend-1",
+ WorkloadName: "test-backend",
+ BaseURL: "http://127.0.0.1:8080", // Same as selfURL
+ }
+
+ status, err := checker.CheckHealth(context.Background(), target)
+ assert.NoError(t, err)
+ assert.Equal(t, vmcp.BackendHealthy, status)
+}
+
+// TestHealthChecker_CheckHealth_SelfCheck_Localhost tests localhost normalization
+func TestHealthChecker_CheckHealth_SelfCheck_Localhost(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+ mockClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Times(0)
+
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://localhost:8080")
+ target := &vmcp.BackendTarget{
+ WorkloadID: "backend-1",
+ WorkloadName: "test-backend",
+ BaseURL: "http://127.0.0.1:8080", // localhost should match 127.0.0.1
+ }
+
+ status, err := checker.CheckHealth(context.Background(), target)
+ assert.NoError(t, err)
+ assert.Equal(t, vmcp.BackendHealthy, status)
+}
+
+// TestHealthChecker_CheckHealth_SelfCheck_Reverse tests reverse localhost normalization
+func TestHealthChecker_CheckHealth_SelfCheck_Reverse(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+ mockClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Times(0)
+
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080")
+ target := &vmcp.BackendTarget{
+ WorkloadID: "backend-1",
+ WorkloadName: "test-backend",
+ BaseURL: "http://localhost:8080", // 127.0.0.1 should match localhost
+ }
+
+ status, err := checker.CheckHealth(context.Background(), target)
+ assert.NoError(t, err)
+ assert.Equal(t, vmcp.BackendHealthy, status)
+}
+
+// TestHealthChecker_CheckHealth_SelfCheck_DifferentPort tests different ports don't match
+func TestHealthChecker_CheckHealth_SelfCheck_DifferentPort(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+ mockClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Return(&vmcp.CapabilityList{}, nil).
+ Times(1)
+
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080")
+ target := &vmcp.BackendTarget{
+ WorkloadID: "backend-1",
+ WorkloadName: "test-backend",
+ BaseURL: "http://127.0.0.1:8081", // Different port
+ }
+
+ status, err := checker.CheckHealth(context.Background(), target)
+ assert.NoError(t, err)
+ assert.Equal(t, vmcp.BackendHealthy, status)
+}
+
+// TestHealthChecker_CheckHealth_SelfCheck_EmptyURL tests empty URLs
+func TestHealthChecker_CheckHealth_SelfCheck_EmptyURL(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+ mockClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Return(&vmcp.CapabilityList{}, nil).
+ Times(1)
+
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "")
+ target := &vmcp.BackendTarget{
+ WorkloadID: "backend-1",
+ WorkloadName: "test-backend",
+ BaseURL: "http://127.0.0.1:8080",
+ }
+
+ status, err := checker.CheckHealth(context.Background(), target)
+ assert.NoError(t, err)
+ assert.Equal(t, vmcp.BackendHealthy, status)
+}
+
+// TestHealthChecker_CheckHealth_SelfCheck_InvalidURL tests invalid URLs
+func TestHealthChecker_CheckHealth_SelfCheck_InvalidURL(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+ mockClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Return(&vmcp.CapabilityList{}, nil).
+ Times(1)
+
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "not-a-valid-url")
+ target := &vmcp.BackendTarget{
+ WorkloadID: "backend-1",
+ WorkloadName: "test-backend",
+ BaseURL: "http://127.0.0.1:8080",
+ }
+
+ status, err := checker.CheckHealth(context.Background(), target)
+ assert.NoError(t, err)
+ assert.Equal(t, vmcp.BackendHealthy, status)
+}
+
+// TestHealthChecker_CheckHealth_SelfCheck_WithPath tests URLs with paths are normalized
+func TestHealthChecker_CheckHealth_SelfCheck_WithPath(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+ mockClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Times(0)
+
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080")
+ target := &vmcp.BackendTarget{
+ WorkloadID: "backend-1",
+ WorkloadName: "test-backend",
+ BaseURL: "http://127.0.0.1:8080/mcp", // Path should be ignored
+ }
+
+ status, err := checker.CheckHealth(context.Background(), target)
+ assert.NoError(t, err)
+ assert.Equal(t, vmcp.BackendHealthy, status)
+}
+
+// TestHealthChecker_CheckHealth_DegradedThreshold tests degraded threshold detection
+func TestHealthChecker_CheckHealth_DegradedThreshold(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+ mockClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) {
+ // Simulate slow response
+ time.Sleep(150 * time.Millisecond)
+ return &vmcp.CapabilityList{}, nil
+ }).
+ Times(1)
+
+ // Set degraded threshold to 100ms
+ checker := NewHealthChecker(mockClient, 5*time.Second, 100*time.Millisecond, "")
+ target := &vmcp.BackendTarget{
+ WorkloadID: "backend-1",
+ WorkloadName: "test-backend",
+ BaseURL: "http://localhost:8080",
+ }
+
+ status, err := checker.CheckHealth(context.Background(), target)
+ assert.NoError(t, err)
+ assert.Equal(t, vmcp.BackendDegraded, status, "Should mark as degraded when response time exceeds threshold")
+}
+
+// TestHealthChecker_CheckHealth_DegradedThreshold_Disabled tests disabled degraded threshold
+func TestHealthChecker_CheckHealth_DegradedThreshold_Disabled(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+ mockClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) {
+ // Simulate slow response
+ time.Sleep(150 * time.Millisecond)
+ return &vmcp.CapabilityList{}, nil
+ }).
+ Times(1)
+
+ // Set degraded threshold to 0 (disabled)
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "")
+ target := &vmcp.BackendTarget{
+ WorkloadID: "backend-1",
+ WorkloadName: "test-backend",
+ BaseURL: "http://localhost:8080",
+ }
+
+ status, err := checker.CheckHealth(context.Background(), target)
+ assert.NoError(t, err)
+ assert.Equal(t, vmcp.BackendHealthy, status, "Should not mark as degraded when threshold is disabled")
+}
+
+// TestHealthChecker_CheckHealth_DegradedThreshold_FastResponse tests fast response doesn't trigger degraded
+func TestHealthChecker_CheckHealth_DegradedThreshold_FastResponse(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+ mockClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Return(&vmcp.CapabilityList{}, nil).
+ Times(1)
+
+ // Set degraded threshold to 100ms
+ checker := NewHealthChecker(mockClient, 5*time.Second, 100*time.Millisecond, "")
+ target := &vmcp.BackendTarget{
+ WorkloadID: "backend-1",
+ WorkloadName: "test-backend",
+ BaseURL: "http://localhost:8080",
+ }
+
+ status, err := checker.CheckHealth(context.Background(), target)
+ assert.NoError(t, err)
+ assert.Equal(t, vmcp.BackendHealthy, status, "Should not mark as degraded when response is fast")
+}
+
+// TestCategorizeError_SentinelErrors tests sentinel error categorization
+func TestCategorizeError_SentinelErrors(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ expectedStatus vmcp.BackendHealthStatus
+ }{
+ {
+ name: "ErrAuthenticationFailed",
+ err: vmcp.ErrAuthenticationFailed,
+ expectedStatus: vmcp.BackendUnauthenticated,
+ },
+ {
+ name: "ErrAuthorizationFailed",
+ err: vmcp.ErrAuthorizationFailed,
+ expectedStatus: vmcp.BackendUnauthenticated,
+ },
+ {
+ name: "ErrTimeout",
+ err: vmcp.ErrTimeout,
+ expectedStatus: vmcp.BackendUnhealthy,
+ },
+ {
+ name: "ErrCancelled",
+ err: vmcp.ErrCancelled,
+ expectedStatus: vmcp.BackendUnhealthy,
+ },
+ {
+ name: "ErrBackendUnavailable",
+ err: vmcp.ErrBackendUnavailable,
+ expectedStatus: vmcp.BackendUnhealthy,
+ },
+ {
+ name: "wrapped ErrAuthenticationFailed",
+			err:            errors.Join(errors.New("wrapped"), vmcp.ErrAuthenticationFailed),
+ expectedStatus: vmcp.BackendUnauthenticated,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ status := categorizeError(tt.err)
+ assert.Equal(t, tt.expectedStatus, status)
+ })
+ }
+}
+
+// TestNormalizeURLForComparison tests URL normalization
+func TestNormalizeURLForComparison(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ expected string
+ wantErr bool
+ }{
+ {
+ name: "localhost normalized to 127.0.0.1",
+ input: "http://localhost:8080",
+ expected: "http://127.0.0.1:8080",
+ wantErr: false,
+ },
+ {
+ name: "127.0.0.1 stays as is",
+ input: "http://127.0.0.1:8080",
+ expected: "http://127.0.0.1:8080",
+ wantErr: false,
+ },
+ {
+ name: "path is ignored",
+ input: "http://127.0.0.1:8080/mcp",
+ expected: "http://127.0.0.1:8080",
+ wantErr: false,
+ },
+ {
+ name: "query is ignored",
+ input: "http://127.0.0.1:8080?param=value",
+ expected: "http://127.0.0.1:8080",
+ wantErr: false,
+ },
+ {
+ name: "fragment is ignored",
+ input: "http://127.0.0.1:8080#fragment",
+ expected: "http://127.0.0.1:8080",
+ wantErr: false,
+ },
+ {
+ name: "scheme is lowercased",
+ input: "HTTP://127.0.0.1:8080",
+ expected: "http://127.0.0.1:8080",
+ wantErr: false,
+ },
+ {
+ name: "host is lowercased",
+ input: "http://EXAMPLE.COM:8080",
+ expected: "http://example.com:8080",
+ wantErr: false,
+ },
+ {
+ name: "no port",
+ input: "http://127.0.0.1",
+ expected: "http://127.0.0.1",
+ wantErr: false,
+ },
+ {
+ name: "invalid URL",
+ input: "not-a-url",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := NormalizeURLForComparison(tt.input)
+ if tt.wantErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+// TestIsSelfCheck_EdgeCases tests edge cases for self-check detection
+func TestIsSelfCheck_EdgeCases(t *testing.T) {
+ t.Parallel()
+
+ ctrl := gomock.NewController(t)
+ t.Cleanup(func() { ctrl.Finish() })
+
+ mockClient := mocks.NewMockBackendClient(ctrl)
+
+ tests := []struct {
+ name string
+ selfURL string
+ backendURL string
+ expected bool
+ }{
+ {
+ name: "both empty",
+ selfURL: "",
+ backendURL: "",
+ expected: false,
+ },
+ {
+ name: "selfURL empty",
+ selfURL: "",
+ backendURL: "http://127.0.0.1:8080",
+ expected: false,
+ },
+ {
+ name: "backendURL empty",
+ selfURL: "http://127.0.0.1:8080",
+ backendURL: "",
+ expected: false,
+ },
+ {
+ name: "localhost matches 127.0.0.1",
+ selfURL: "http://localhost:8080",
+ backendURL: "http://127.0.0.1:8080",
+ expected: true,
+ },
+ {
+ name: "127.0.0.1 matches localhost",
+ selfURL: "http://127.0.0.1:8080",
+ backendURL: "http://localhost:8080",
+ expected: true,
+ },
+ {
+ name: "different ports",
+ selfURL: "http://127.0.0.1:8080",
+ backendURL: "http://127.0.0.1:8081",
+ expected: false,
+ },
+ {
+ name: "different hosts",
+ selfURL: "http://127.0.0.1:8080",
+ backendURL: "http://192.168.1.1:8080",
+ expected: false,
+ },
+ {
+ name: "path ignored",
+ selfURL: "http://127.0.0.1:8080",
+ backendURL: "http://127.0.0.1:8080/mcp",
+ expected: true,
+ },
+ {
+ name: "query ignored",
+ selfURL: "http://127.0.0.1:8080",
+ backendURL: "http://127.0.0.1:8080?param=value",
+ expected: true,
+ },
+ {
+ name: "invalid selfURL",
+ selfURL: "not-a-url",
+ backendURL: "http://127.0.0.1:8080",
+ expected: false,
+ },
+ {
+ name: "invalid backendURL",
+ selfURL: "http://127.0.0.1:8080",
+ backendURL: "not-a-url",
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, tt.selfURL)
+ hc, ok := checker.(*healthChecker)
+ require.True(t, ok)
+
+ result := hc.isSelfCheck(tt.backendURL)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
diff --git a/pkg/vmcp/health/checker_test.go b/pkg/vmcp/health/checker_test.go
index 39f7258d82..63c3c986b6 100644
--- a/pkg/vmcp/health/checker_test.go
+++ b/pkg/vmcp/health/checker_test.go
@@ -44,7 +44,7 @@ func TestNewHealthChecker(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
- checker := NewHealthChecker(mockClient, tt.timeout, 0)
+ checker := NewHealthChecker(mockClient, tt.timeout, 0, "")
require.NotNil(t, checker)
// Type assert to access internals for verification
@@ -68,7 +68,7 @@ func TestHealthChecker_CheckHealth_Success(t *testing.T) {
Return(&vmcp.CapabilityList{}, nil).
Times(1)
- checker := NewHealthChecker(mockClient, 5*time.Second, 0)
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "")
target := &vmcp.BackendTarget{
WorkloadID: "backend-1",
WorkloadName: "test-backend",
@@ -95,7 +95,7 @@ func TestHealthChecker_CheckHealth_ContextCancellation(t *testing.T) {
}).
Times(1)
- checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0)
+ checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0, "")
target := &vmcp.BackendTarget{
WorkloadID: "backend-1",
WorkloadName: "test-backend",
@@ -123,7 +123,7 @@ func TestHealthChecker_CheckHealth_NoTimeout(t *testing.T) {
Times(1)
// Create checker with no timeout
- checker := NewHealthChecker(mockClient, 0, 0)
+ checker := NewHealthChecker(mockClient, 0, 0, "")
target := &vmcp.BackendTarget{
WorkloadID: "backend-1",
WorkloadName: "test-backend",
@@ -213,7 +213,7 @@ func TestHealthChecker_CheckHealth_ErrorCategorization(t *testing.T) {
Return(nil, tt.err).
Times(1)
- checker := NewHealthChecker(mockClient, 5*time.Second, 0)
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "")
target := &vmcp.BackendTarget{
WorkloadID: "backend-1",
WorkloadName: "test-backend",
@@ -430,7 +430,7 @@ func TestHealthChecker_CheckHealth_Timeout(t *testing.T) {
}).
Times(1)
- checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0)
+ checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0, "")
target := &vmcp.BackendTarget{
WorkloadID: "backend-1",
WorkloadName: "test-backend",
@@ -467,7 +467,7 @@ func TestHealthChecker_CheckHealth_MultipleBackends(t *testing.T) {
}).
Times(4)
- checker := NewHealthChecker(mockClient, 5*time.Second, 0)
+ checker := NewHealthChecker(mockClient, 5*time.Second, 0, "")
// Test healthy backend
status, err := checker.CheckHealth(context.Background(), &vmcp.BackendTarget{
diff --git a/pkg/vmcp/health/monitor.go b/pkg/vmcp/health/monitor.go
index 50f00d788d..62aea9b735 100644
--- a/pkg/vmcp/health/monitor.go
+++ b/pkg/vmcp/health/monitor.go
@@ -108,12 +108,14 @@ func DefaultConfig() MonitorConfig {
// - client: BackendClient for communicating with backend MCP servers
// - backends: List of backends to monitor
// - config: Configuration for health monitoring
+// - selfURL: Optional server's own URL. If provided, health checks targeting this URL are short-circuited.
//
// Returns (monitor, error). Error is returned if configuration is invalid.
func NewMonitor(
client vmcp.BackendClient,
backends []vmcp.Backend,
config MonitorConfig,
+ selfURL string,
) (*Monitor, error) {
// Validate configuration
if config.CheckInterval <= 0 {
@@ -123,8 +125,8 @@ func NewMonitor(
return nil, fmt.Errorf("unhealthy threshold must be >= 1, got %d", config.UnhealthyThreshold)
}
- // Create health checker with degraded threshold
- checker := NewHealthChecker(client, config.Timeout, config.DegradedThreshold)
+ // Create health checker with degraded threshold and self URL
+ checker := NewHealthChecker(client, config.Timeout, config.DegradedThreshold, selfURL)
// Create status tracker
statusTracker := newStatusTracker(config.UnhealthyThreshold)
diff --git a/pkg/vmcp/health/monitor_test.go b/pkg/vmcp/health/monitor_test.go
index bb177017e7..8d2de11bdd 100644
--- a/pkg/vmcp/health/monitor_test.go
+++ b/pkg/vmcp/health/monitor_test.go
@@ -66,7 +66,7 @@ func TestNewMonitor_Validation(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
- monitor, err := NewMonitor(mockClient, backends, tt.config)
+ monitor, err := NewMonitor(mockClient, backends, tt.config, "")
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, monitor)
@@ -101,7 +101,7 @@ func TestMonitor_StartStop(t *testing.T) {
Return(&vmcp.CapabilityList{}, nil).
AnyTimes()
- monitor, err := NewMonitor(mockClient, backends, config)
+ monitor, err := NewMonitor(mockClient, backends, config, "")
require.NoError(t, err)
// Start monitor
@@ -178,7 +178,7 @@ func TestMonitor_StartErrors(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
- monitor, err := NewMonitor(mockClient, backends, config)
+ monitor, err := NewMonitor(mockClient, backends, config, "")
require.NoError(t, err)
err = tt.setupFunc(monitor)
@@ -208,7 +208,7 @@ func TestMonitor_StopWithoutStart(t *testing.T) {
Timeout: 50 * time.Millisecond,
}
- monitor, err := NewMonitor(mockClient, backends, config)
+ monitor, err := NewMonitor(mockClient, backends, config, "")
require.NoError(t, err)
// Try to stop without starting
@@ -239,7 +239,7 @@ func TestMonitor_PeriodicHealthChecks(t *testing.T) {
Return(nil, errors.New("backend unavailable")).
MinTimes(2)
- monitor, err := NewMonitor(mockClient, backends, config)
+ monitor, err := NewMonitor(mockClient, backends, config, "")
require.NoError(t, err)
ctx := context.Background()
@@ -289,7 +289,7 @@ func TestMonitor_GetHealthSummary(t *testing.T) {
}).
AnyTimes()
- monitor, err := NewMonitor(mockClient, backends, config)
+ monitor, err := NewMonitor(mockClient, backends, config, "")
require.NoError(t, err)
ctx := context.Background()
@@ -333,7 +333,7 @@ func TestMonitor_GetBackendStatus(t *testing.T) {
Return(&vmcp.CapabilityList{}, nil).
AnyTimes()
- monitor, err := NewMonitor(mockClient, backends, config)
+ monitor, err := NewMonitor(mockClient, backends, config, "")
require.NoError(t, err)
ctx := context.Background()
@@ -382,7 +382,7 @@ func TestMonitor_GetBackendState(t *testing.T) {
Return(&vmcp.CapabilityList{}, nil).
AnyTimes()
- monitor, err := NewMonitor(mockClient, backends, config)
+ monitor, err := NewMonitor(mockClient, backends, config, "")
require.NoError(t, err)
ctx := context.Background()
@@ -433,7 +433,7 @@ func TestMonitor_GetAllBackendStates(t *testing.T) {
Return(&vmcp.CapabilityList{}, nil).
AnyTimes()
- monitor, err := NewMonitor(mockClient, backends, config)
+ monitor, err := NewMonitor(mockClient, backends, config, "")
require.NoError(t, err)
ctx := context.Background()
@@ -477,7 +477,7 @@ func TestMonitor_ContextCancellation(t *testing.T) {
Return(&vmcp.CapabilityList{}, nil).
AnyTimes()
- monitor, err := NewMonitor(mockClient, backends, config)
+ monitor, err := NewMonitor(mockClient, backends, config, "")
require.NoError(t, err)
// Start with cancellable context
diff --git a/pkg/vmcp/optimizer/config.go b/pkg/vmcp/optimizer/config.go
new file mode 100644
index 0000000000..62aef2669c
--- /dev/null
+++ b/pkg/vmcp/optimizer/config.go
@@ -0,0 +1,42 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package optimizer
+
+import (
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ "github.com/stacklok/toolhive/pkg/vmcp/config"
+)
+
+// ConfigFromVMCPConfig converts a vmcp/config.OptimizerConfig to optimizer.Config.
+// This helper function bridges the gap between the shared config package and
+// the optimizer package's internal configuration structure.
+func ConfigFromVMCPConfig(cfg *config.OptimizerConfig) *Config {
+ if cfg == nil {
+ return nil
+ }
+
+ optimizerCfg := &Config{
+ Enabled: cfg.Enabled,
+ PersistPath: cfg.PersistPath,
+ FTSDBPath: cfg.FTSDBPath,
+ HybridSearchRatio: 70, // Default
+ }
+
+ // Handle HybridSearchRatio (pointer in config, value in optimizer.Config)
+ if cfg.HybridSearchRatio != nil {
+ optimizerCfg.HybridSearchRatio = *cfg.HybridSearchRatio
+ }
+
+ // Convert embedding config
+ if cfg.EmbeddingBackend != "" || cfg.EmbeddingURL != "" || cfg.EmbeddingModel != "" || cfg.EmbeddingDimension > 0 {
+ optimizerCfg.EmbeddingConfig = &embeddings.Config{
+ BackendType: cfg.EmbeddingBackend,
+ BaseURL: cfg.EmbeddingURL,
+ Model: cfg.EmbeddingModel,
+ Dimension: cfg.EmbeddingDimension,
+ }
+ }
+
+ return optimizerCfg
+}
diff --git a/pkg/vmcp/optimizer/dummy_optimizer.go b/pkg/vmcp/optimizer/dummy_optimizer.go
deleted file mode 100644
index 00c9be9eae..0000000000
--- a/pkg/vmcp/optimizer/dummy_optimizer.go
+++ /dev/null
@@ -1,119 +0,0 @@
-// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
-// SPDX-License-Identifier: Apache-2.0
-
-package optimizer
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "strings"
-
- "github.com/mark3labs/mcp-go/mcp"
- "github.com/mark3labs/mcp-go/server"
-)
-
-// DummyOptimizer implements the Optimizer interface using exact string matching.
-//
-// This implementation is intended for testing and development. It performs
-// case-insensitive substring matching on tool names and descriptions.
-//
-// For production use, see the EmbeddingOptimizer which uses semantic similarity.
-type DummyOptimizer struct {
- // tools contains all available tools indexed by name.
- tools map[string]server.ServerTool
-}
-
-// NewDummyOptimizer creates a new DummyOptimizer with the given tools.
-//
-// The tools slice should contain all backend tools (as ServerTool with handlers).
-func NewDummyOptimizer(tools []server.ServerTool) Optimizer {
- toolMap := make(map[string]server.ServerTool, len(tools))
- for _, tool := range tools {
- toolMap[tool.Tool.Name] = tool
- }
-
- return DummyOptimizer{
- tools: toolMap,
- }
-}
-
-// FindTool searches for tools using exact substring matching.
-//
-// The search is case-insensitive and matches against:
-// - Tool name (substring match)
-// - Tool description (substring match)
-//
-// Returns all matching tools with a score of 1.0 (exact match semantics).
-// TokenMetrics are returned as zero values (not implemented in dummy).
-func (d DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindToolOutput, error) {
- if input.ToolDescription == "" {
- return nil, fmt.Errorf("tool_description is required")
- }
-
- searchTerm := strings.ToLower(input.ToolDescription)
-
- var matches []ToolMatch
- for _, tool := range d.tools {
- nameLower := strings.ToLower(tool.Tool.Name)
- descLower := strings.ToLower(tool.Tool.Description)
-
- // Check if search term matches name or description
- if strings.Contains(nameLower, searchTerm) || strings.Contains(descLower, searchTerm) {
- schema, err := getToolSchema(tool.Tool)
- if err != nil {
- return nil, err
- }
- matches = append(matches, ToolMatch{
- Name: tool.Tool.Name,
- Description: tool.Tool.Description,
- InputSchema: schema,
- Score: 1.0, // Exact match semantics
- })
- }
- }
-
- return &FindToolOutput{
- Tools: matches,
- TokenMetrics: TokenMetrics{}, // Zero values for dummy
- }, nil
-}
-
-// CallTool invokes a tool by name using its registered handler.
-//
-// The tool is looked up by exact name match. If found, the handler
-// is invoked directly with the given parameters.
-func (d DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) {
- if input.ToolName == "" {
- return nil, fmt.Errorf("tool_name is required")
- }
-
- // Verify the tool exists
- tool, exists := d.tools[input.ToolName]
- if !exists {
- return mcp.NewToolResultError(fmt.Sprintf("tool not found: %s", input.ToolName)), nil
- }
-
- // Build the MCP request
- request := mcp.CallToolRequest{}
- request.Params.Name = input.ToolName
- request.Params.Arguments = input.Parameters
-
- // Call the tool handler directly
- return tool.Handler(ctx, request)
-}
-
-// getToolSchema returns the input schema for a tool.
-// Prefers RawInputSchema if set, otherwise marshals InputSchema.
-func getToolSchema(tool mcp.Tool) (json.RawMessage, error) {
- if len(tool.RawInputSchema) > 0 {
- return tool.RawInputSchema, nil
- }
-
- // Fall back to InputSchema
- data, err := json.Marshal(tool.InputSchema)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err)
- }
- return data, nil
-}
diff --git a/pkg/vmcp/optimizer/dummy_optimizer_test.go b/pkg/vmcp/optimizer/dummy_optimizer_test.go
deleted file mode 100644
index 2113a5a4c1..0000000000
--- a/pkg/vmcp/optimizer/dummy_optimizer_test.go
+++ /dev/null
@@ -1,191 +0,0 @@
-// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
-// SPDX-License-Identifier: Apache-2.0
-
-package optimizer
-
-import (
- "context"
- "testing"
-
- "github.com/mark3labs/mcp-go/mcp"
- "github.com/mark3labs/mcp-go/server"
- "github.com/stretchr/testify/require"
-)
-
-func TestDummyOptimizer_FindTool(t *testing.T) {
- t.Parallel()
-
- tools := []server.ServerTool{
- {
- Tool: mcp.Tool{
- Name: "fetch_url",
- Description: "Fetch content from a URL",
- },
- },
- {
- Tool: mcp.Tool{
- Name: "read_file",
- Description: "Read a file from the filesystem",
- },
- },
- {
- Tool: mcp.Tool{
- Name: "write_file",
- Description: "Write content to a file",
- },
- },
- }
-
- opt := NewDummyOptimizer(tools)
-
- tests := []struct {
- name string
- input FindToolInput
- expectedNames []string
- expectedError bool
- errorContains string
- }{
- {
- name: "find by exact name",
- input: FindToolInput{
- ToolDescription: "fetch_url",
- },
- expectedNames: []string{"fetch_url"},
- },
- {
- name: "find by description substring",
- input: FindToolInput{
- ToolDescription: "file",
- },
- expectedNames: []string{"read_file", "write_file"},
- },
- {
- name: "case insensitive search",
- input: FindToolInput{
- ToolDescription: "FETCH",
- },
- expectedNames: []string{"fetch_url"},
- },
- {
- name: "no matches",
- input: FindToolInput{
- ToolDescription: "nonexistent",
- },
- expectedNames: []string{},
- },
- {
- name: "empty description",
- input: FindToolInput{},
- expectedError: true,
- errorContains: "tool_description is required",
- },
- }
-
- for _, tc := range tests {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
-
- result, err := opt.FindTool(context.Background(), tc.input)
-
- if tc.expectedError {
- require.Error(t, err)
- require.Contains(t, err.Error(), tc.errorContains)
- return
- }
-
- require.NoError(t, err)
- require.NotNil(t, result)
-
- // Extract names from results
- var names []string
- for _, match := range result.Tools {
- names = append(names, match.Name)
- }
-
- require.ElementsMatch(t, tc.expectedNames, names)
- })
- }
-}
-
-func TestDummyOptimizer_CallTool(t *testing.T) {
- t.Parallel()
-
- tools := []server.ServerTool{
- {
- Tool: mcp.Tool{
- Name: "test_tool",
- Description: "A test tool",
- },
- Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- args, _ := req.Params.Arguments.(map[string]any)
- input := args["input"].(string)
- return mcp.NewToolResultText("Hello, " + input + "!"), nil
- },
- },
- }
-
- opt := NewDummyOptimizer(tools)
-
- tests := []struct {
- name string
- input CallToolInput
- expectedText string
- expectedError bool
- isToolError bool
- errorContains string
- }{
- {
- name: "successful tool call",
- input: CallToolInput{
- ToolName: "test_tool",
- Parameters: map[string]any{"input": "World"},
- },
- expectedText: "Hello, World!",
- },
- {
- name: "tool not found",
- input: CallToolInput{
- ToolName: "nonexistent",
- Parameters: map[string]any{},
- },
- isToolError: true,
- expectedText: "tool not found: nonexistent",
- },
- {
- name: "empty tool name",
- input: CallToolInput{
- Parameters: map[string]any{},
- },
- expectedError: true,
- errorContains: "tool_name is required",
- },
- }
-
- for _, tc := range tests {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
-
- result, err := opt.CallTool(context.Background(), tc.input)
-
- if tc.expectedError {
- require.Error(t, err)
- require.Contains(t, err.Error(), tc.errorContains)
- return
- }
-
- require.NoError(t, err)
- require.NotNil(t, result)
-
- if tc.isToolError {
- require.True(t, result.IsError)
- }
-
- if tc.expectedText != "" {
- require.Len(t, result.Content, 1)
- textContent, ok := result.Content[0].(mcp.TextContent)
- require.True(t, ok)
- require.Equal(t, tc.expectedText, textContent.Text)
- }
- })
- }
-}
diff --git a/pkg/vmcp/optimizer/find_tool_semantic_search_test.go b/pkg/vmcp/optimizer/find_tool_semantic_search_test.go
new file mode 100644
index 0000000000..3868bfd54d
--- /dev/null
+++ b/pkg/vmcp/optimizer/find_tool_semantic_search_test.go
@@ -0,0 +1,693 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package optimizer
+
+import (
+ "context"
+ "encoding/json"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ transportsession "github.com/stacklok/toolhive/pkg/transport/session"
+ "github.com/stacklok/toolhive/pkg/vmcp"
+ "github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+ "github.com/stacklok/toolhive/pkg/vmcp/discovery"
+ vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
+)
+
+const (
+ testBackendOllama = "ollama"
+ testBackendOpenAI = "openai"
+)
+
+// verifyEmbeddingBackendWorking verifies that the embedding backend is actually working by attempting to generate an embedding
+// This ensures the service is not just reachable but actually functional
+func verifyEmbeddingBackendWorking(t *testing.T, manager *embeddings.Manager, backendType string) {
+ t.Helper()
+ _, err := manager.GenerateEmbedding([]string{"test"})
+ if err != nil {
+ if backendType == testBackendOllama {
+ t.Skipf("Skipping test: Ollama is reachable but embedding generation failed. Error: %v. Ensure 'ollama pull %s' has been executed", err, embeddings.DefaultModelAllMiniLM)
+ } else {
+ t.Skipf("Skipping test: Embedding backend is reachable but embedding generation failed. Error: %v", err)
+ }
+ }
+}
+
+// TestFindTool_SemanticSearch tests semantic search capabilities
+// These tests verify that find_tool can find tools based on semantic meaning,
+// not just exact keyword matches
+func TestFindTool_SemanticSearch(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingBackend := testBackendOllama
+ embeddingConfig := &embeddings.Config{
+ BackendType: embeddingBackend,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384, // all-MiniLM-L6-v2 dimension
+ }
+
+ // Test if Ollama is available
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ // Try OpenAI-compatible (might be vLLM or Ollama v1 API)
+ embeddingConfig.BackendType = testBackendOpenAI
+ embeddingConfig.BaseURL = "http://localhost:11434"
+ embeddingConfig.Model = embeddings.DefaultModelAllMiniLM
+ embeddingConfig.Dimension = 768
+ embeddingManager, err = embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping semantic search test: No embedding backend available (Ollama or OpenAI-compatible). Error: %v", err)
+ return
+ }
+ embeddingBackend = testBackendOpenAI
+ }
+ t.Cleanup(func() { _ = embeddingManager.Close() })
+
+ // Verify embedding backend is actually working, not just reachable
+ verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend)
+
+ // Setup optimizer integration with high semantic ratio to favor semantic search
+ mcpServer := server.NewMCPServer("test-server", "1.0")
+ mockClient := &mockBackendClient{}
+
+ config := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: embeddingBackend,
+ BaseURL: embeddingConfig.BaseURL,
+ Model: embeddingConfig.Model,
+ Dimension: embeddingConfig.Dimension,
+ },
+ HybridSearchRatio: 90, // 90% semantic, 10% BM25 to test semantic search
+ }
+
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ require.NotNil(t, integration)
+ t.Cleanup(func() { _ = integration.Close() })
+
+ // Create tools with diverse descriptions to test semantic understanding
+ tools := []vmcp.Tool{
+ {
+ Name: "github_pull_request_read",
+ Description: "Get information on a specific pull request in GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_list_pull_requests",
+ Description: "List pull requests in a GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_create_pull_request",
+ Description: "Create a new pull request in a GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_merge_pull_request",
+ Description: "Merge a pull request in a GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_issue_read",
+ Description: "Get information about a specific issue in a GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_list_issues",
+ Description: "List issues in a GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_create_repository",
+ Description: "Create a new GitHub repository in your account or specified organization",
+ BackendID: "github",
+ },
+ {
+ Name: "github_get_commit",
+ Description: "Get details for a commit from a GitHub repository",
+ BackendID: "github",
+ },
+ {
+ Name: "github_get_branch",
+ Description: "Get information about a branch in a GitHub repository",
+ BackendID: "github",
+ },
+ {
+ Name: "fetch_fetch",
+ Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.",
+ BackendID: "fetch",
+ },
+ }
+
+ capabilities := &aggregator.AggregatedCapabilities{
+ Tools: tools,
+ RoutingTable: &vmcp.RoutingTable{
+ Tools: make(map[string]*vmcp.BackendTarget),
+ Resources: map[string]*vmcp.BackendTarget{},
+ Prompts: map[string]*vmcp.BackendTarget{},
+ },
+ }
+
+ for _, tool := range tools {
+ capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{
+ WorkloadID: tool.BackendID,
+ WorkloadName: tool.BackendID,
+ }
+ }
+
+ session := &mockSession{sessionID: "test-session"}
+ err = integration.OnRegisterSession(ctx, session, capabilities)
+ require.NoError(t, err)
+
+ // Manually ingest tools for testing (OnRegisterSession skips ingestion)
+ mcpTools := make([]mcp.Tool, len(tools))
+ for i, tool := range tools {
+ mcpTools[i] = mcp.Tool{
+ Name: tool.Name,
+ Description: tool.Description,
+ }
+ }
+ err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools)
+ require.NoError(t, err)
+
+ ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities)
+
+ // Test cases for semantic search - queries that mean the same thing but use different words
+ testCases := []struct {
+ name string
+ query string
+ keywords string
+ expectedTools []string // Tools that should be found semantically
+ description string
+ }{
+ {
+ name: "semantic_pr_synonyms",
+ query: "view code review request",
+ keywords: "",
+ expectedTools: []string{"github_pull_request_read", "github_list_pull_requests"},
+ description: "Should find PR tools using semantic synonyms (code review = pull request)",
+ },
+ {
+ name: "semantic_merge_synonyms",
+ query: "combine code changes",
+ keywords: "",
+ expectedTools: []string{"github_merge_pull_request"},
+ description: "Should find merge tool using semantic meaning (combine = merge)",
+ },
+ {
+ name: "semantic_create_synonyms",
+ query: "make a new code review",
+ keywords: "",
+ expectedTools: []string{"github_create_pull_request", "github_list_pull_requests", "github_pull_request_read"},
+ description: "Should find PR-related tools using semantic meaning (make = create, code review = PR)",
+ },
+ {
+ name: "semantic_issue_synonyms",
+ query: "show bug reports",
+ keywords: "",
+ expectedTools: []string{"github_issue_read", "github_list_issues"},
+ description: "Should find issue tools using semantic synonyms (bug report = issue)",
+ },
+ {
+ name: "semantic_repository_synonyms",
+ query: "start a new project",
+ keywords: "",
+ expectedTools: []string{"github_create_repository"},
+ description: "Should find repository tool using semantic meaning (project = repository)",
+ },
+ {
+ name: "semantic_commit_synonyms",
+ query: "get change details",
+ keywords: "",
+ expectedTools: []string{"github_get_commit"},
+ description: "Should find commit tool using semantic meaning (change = commit)",
+ },
+ {
+ name: "semantic_fetch_synonyms",
+ query: "download web page content",
+ keywords: "",
+ expectedTools: []string{"fetch_fetch"},
+ description: "Should find fetch tool using semantic synonyms (download = fetch)",
+ },
+ {
+ name: "semantic_branch_synonyms",
+ query: "get branch information",
+ keywords: "",
+ expectedTools: []string{"github_get_branch"},
+ description: "Should find branch tool using semantic meaning",
+ },
+ {
+ name: "semantic_related_concepts",
+ query: "code collaboration features",
+ keywords: "",
+ expectedTools: []string{"github_pull_request_read", "github_create_pull_request", "github_issue_read"},
+ description: "Should find collaboration-related tools (PRs and issues are collaboration features)",
+ },
+ {
+ name: "semantic_intent_based",
+ query: "I want to see what code changes were made",
+ keywords: "",
+ expectedTools: []string{"github_get_commit", "github_pull_request_read"},
+ description: "Should find tools based on user intent (seeing code changes = commits/PRs)",
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "optim_find_tool",
+ Arguments: map[string]any{
+ "tool_description": tc.query,
+ "tool_keywords": tc.keywords,
+ "limit": 10,
+ },
+ },
+ }
+
+ handler := integration.CreateFindToolHandler()
+ result, err := handler(ctxWithCaps, request)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.False(t, result.IsError, "Tool call should not return error for query: %s", tc.query)
+
+ // Parse the result
+ require.NotEmpty(t, result.Content, "Result should have content")
+ textContent, okText := mcp.AsTextContent(result.Content[0])
+ require.True(t, okText, "Result should be text content")
+
+ var response map[string]any
+ err = json.Unmarshal([]byte(textContent.Text), &response)
+ require.NoError(t, err, "Result should be valid JSON")
+
+ toolsArray, okArray := response["tools"].([]interface{})
+ require.True(t, okArray, "Response should have tools array")
+ require.NotEmpty(t, toolsArray, "Should return at least one result for semantic query: %s", tc.query)
+
+ // Extract tool names from results
+ foundTools := make([]string, 0, len(toolsArray))
+ for _, toolInterface := range toolsArray {
+ toolMap, okMap := toolInterface.(map[string]interface{})
+ require.True(t, okMap, "Tool should be a map")
+ toolName, okName := toolMap["name"].(string)
+ require.True(t, okName, "Tool should have name")
+ foundTools = append(foundTools, toolName)
+
+ // Verify similarity score exists and is reasonable
+ similarity, okScore := toolMap["similarity_score"].(float64)
+ require.True(t, okScore, "Tool should have similarity_score")
+ assert.Greater(t, similarity, 0.0, "Similarity score should be positive")
+ }
+
+ // Check that at least one expected tool is found
+ foundCount := 0
+ for _, expectedTool := range tc.expectedTools {
+ for _, foundTool := range foundTools {
+ if foundTool == expectedTool {
+ foundCount++
+ break
+ }
+ }
+ }
+
+ assert.GreaterOrEqual(t, foundCount, 1,
+ "Semantic query '%s' should find at least one expected tool from %v. Found tools: %v (found %d/%d)",
+ tc.query, tc.expectedTools, foundTools, foundCount, len(tc.expectedTools))
+
+ // Log results for debugging
+ if foundCount < len(tc.expectedTools) {
+ t.Logf("Semantic query '%s': Found %d/%d expected tools. Found: %v, Expected: %v",
+ tc.query, foundCount, len(tc.expectedTools), foundTools, tc.expectedTools)
+ }
+
+ // Verify token metrics exist
+ tokenMetrics, okMetrics := response["token_metrics"].(map[string]interface{})
+ require.True(t, okMetrics, "Response should have token_metrics")
+ assert.Contains(t, tokenMetrics, "baseline_tokens")
+ assert.Contains(t, tokenMetrics, "returned_tokens")
+ })
+ }
+}
+
+// TestFindTool_SemanticVsKeyword tests that semantic search finds different results than keyword search
+func TestFindTool_SemanticVsKeyword(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ // Try to use Ollama if available
+ embeddingBackend := "ollama"
+ embeddingConfig := &embeddings.Config{
+ BackendType: embeddingBackend,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ // Try OpenAI-compatible
+ embeddingConfig.BackendType = testBackendOpenAI
+ embeddingManager, err = embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: No embedding backend available. Error: %v", err)
+ return
+ }
+ embeddingBackend = testBackendOpenAI
+ }
+
+ // Verify embedding backend is actually working, not just reachable
+ verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend)
+ _ = embeddingManager.Close()
+
+ mcpServer := server.NewMCPServer("test-server", "1.0")
+ mockClient := &mockBackendClient{}
+
+ // Test with high semantic ratio
+ configSemantic := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db-semantic"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: embeddingBackend,
+ BaseURL: embeddingConfig.BaseURL,
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ },
+ HybridSearchRatio: 90, // 90% semantic
+ }
+
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integrationSemantic, err := NewIntegration(ctx, configSemantic, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ defer func() { _ = integrationSemantic.Close() }()
+
+ // Test with low semantic ratio (high BM25)
+ configKeyword := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db-keyword"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: embeddingBackend,
+ BaseURL: embeddingConfig.BaseURL,
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ },
+ HybridSearchRatio: 10, // 10% semantic, 90% BM25
+ }
+
+ integrationKeyword, err := NewIntegration(ctx, configKeyword, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ defer func() { _ = integrationKeyword.Close() }()
+
+ tools := []vmcp.Tool{
+ {
+ Name: "github_pull_request_read",
+ Description: "Get information on a specific pull request in GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_create_repository",
+ Description: "Create a new GitHub repository in your account or specified organization",
+ BackendID: "github",
+ },
+ }
+
+ capabilities := &aggregator.AggregatedCapabilities{
+ Tools: tools,
+ RoutingTable: &vmcp.RoutingTable{
+ Tools: make(map[string]*vmcp.BackendTarget),
+ Resources: map[string]*vmcp.BackendTarget{},
+ Prompts: map[string]*vmcp.BackendTarget{},
+ },
+ }
+
+ for _, tool := range tools {
+ capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{
+ WorkloadID: tool.BackendID,
+ WorkloadName: tool.BackendID,
+ }
+ }
+
+ session := &mockSession{sessionID: "test-session"}
+ ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities)
+
+ // Register both integrations
+ err = integrationSemantic.OnRegisterSession(ctx, session, capabilities)
+ require.NoError(t, err)
+
+ err = integrationKeyword.OnRegisterSession(ctx, session, capabilities)
+ require.NoError(t, err)
+
+ // Manually ingest tools for testing (OnRegisterSession skips ingestion)
+ mcpTools := make([]mcp.Tool, len(tools))
+ for i, tool := range tools {
+ mcpTools[i] = mcp.Tool{
+ Name: tool.Name,
+ Description: tool.Description,
+ }
+ }
+ err = integrationSemantic.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools)
+ require.NoError(t, err)
+ err = integrationKeyword.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools)
+ require.NoError(t, err)
+
+ // Query that has semantic meaning but no exact keyword match
+ query := "view code review"
+
+ // Test semantic search
+ requestSemantic := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "optim_find_tool",
+ Arguments: map[string]any{
+ "tool_description": query,
+ "tool_keywords": "",
+ "limit": 10,
+ },
+ },
+ }
+
+ handlerSemantic := integrationSemantic.CreateFindToolHandler()
+ resultSemantic, err := handlerSemantic(ctxWithCaps, requestSemantic)
+ require.NoError(t, err)
+ require.False(t, resultSemantic.IsError)
+
+ // Test keyword search
+ requestKeyword := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "optim_find_tool",
+ Arguments: map[string]any{
+ "tool_description": query,
+ "tool_keywords": "",
+ "limit": 10,
+ },
+ },
+ }
+
+ handlerKeyword := integrationKeyword.CreateFindToolHandler()
+ resultKeyword, err := handlerKeyword(ctxWithCaps, requestKeyword)
+ require.NoError(t, err)
+ require.False(t, resultKeyword.IsError)
+
+ // Parse both results
+ textSemantic, _ := mcp.AsTextContent(resultSemantic.Content[0])
+ var responseSemantic map[string]any
+ json.Unmarshal([]byte(textSemantic.Text), &responseSemantic)
+
+ textKeyword, _ := mcp.AsTextContent(resultKeyword.Content[0])
+ var responseKeyword map[string]any
+ json.Unmarshal([]byte(textKeyword.Text), &responseKeyword)
+
+ toolsSemantic, _ := responseSemantic["tools"].([]interface{})
+ toolsKeyword, _ := responseKeyword["tools"].([]interface{})
+
+ // Both should find results (semantic should find PR tools, keyword might not)
+ assert.NotEmpty(t, toolsSemantic, "Semantic search should find results")
+ assert.NotEmpty(t, toolsKeyword, "Keyword search should find results")
+
+ // Semantic search should find pull request tools even without exact keyword match
+ foundPRSemantic := false
+ for _, toolInterface := range toolsSemantic {
+ toolMap, _ := toolInterface.(map[string]interface{})
+ toolName, _ := toolMap["name"].(string)
+ if toolName == "github_pull_request_read" {
+ foundPRSemantic = true
+ break
+ }
+ }
+
+ t.Logf("Semantic search (90%% semantic): Found %d tools", len(toolsSemantic))
+ t.Logf("Keyword search (10%% semantic): Found %d tools", len(toolsKeyword))
+ t.Logf("Semantic search found PR tool: %v", foundPRSemantic)
+
+ // Semantic search should be able to find semantically related tools
+ // even when keywords don't match exactly
+ assert.True(t, foundPRSemantic,
+ "Semantic search should find 'github_pull_request_read' for query 'view code review' even without exact keyword match")
+}
+
+// TestFindTool_SemanticSimilarityScores tests that similarity scores are meaningful
+func TestFindTool_SemanticSimilarityScores(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ // Try to use Ollama if available
+ embeddingBackend := "ollama"
+ embeddingConfig := &embeddings.Config{
+ BackendType: embeddingBackend,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ // Try OpenAI-compatible
+ embeddingConfig.BackendType = testBackendOpenAI
+ embeddingManager, err = embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: No embedding backend available. Error: %v", err)
+ return
+ }
+ embeddingBackend = testBackendOpenAI
+ }
+
+ // Verify embedding backend is actually working, not just reachable
+ verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend)
+ _ = embeddingManager.Close()
+
+ mcpServer := server.NewMCPServer("test-server", "1.0")
+ mockClient := &mockBackendClient{}
+
+ config := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: embeddingBackend,
+ BaseURL: embeddingConfig.BaseURL,
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ },
+ HybridSearchRatio: 90, // High semantic ratio
+ }
+
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ defer func() { _ = integration.Close() }()
+
+ tools := []vmcp.Tool{
+ {
+ Name: "github_pull_request_read",
+ Description: "Get information on a specific pull request in GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_create_repository",
+ Description: "Create a new GitHub repository in your account or specified organization",
+ BackendID: "github",
+ },
+ {
+ Name: "fetch_fetch",
+ Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.",
+ BackendID: "fetch",
+ },
+ }
+
+ capabilities := &aggregator.AggregatedCapabilities{
+ Tools: tools,
+ RoutingTable: &vmcp.RoutingTable{
+ Tools: make(map[string]*vmcp.BackendTarget),
+ Resources: map[string]*vmcp.BackendTarget{},
+ Prompts: map[string]*vmcp.BackendTarget{},
+ },
+ }
+
+ for _, tool := range tools {
+ capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{
+ WorkloadID: tool.BackendID,
+ WorkloadName: tool.BackendID,
+ }
+ }
+
+ session := &mockSession{sessionID: "test-session"}
+ err = integration.OnRegisterSession(ctx, session, capabilities)
+ require.NoError(t, err)
+
+ // Manually ingest tools for testing (OnRegisterSession skips ingestion)
+ mcpTools := make([]mcp.Tool, len(tools))
+ for i, tool := range tools {
+ mcpTools[i] = mcp.Tool{
+ Name: tool.Name,
+ Description: tool.Description,
+ }
+ }
+ err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools)
+ require.NoError(t, err)
+
+ ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities)
+
+ // Query for pull request
+ query := "view pull request"
+
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "optim_find_tool",
+ Arguments: map[string]any{
+ "tool_description": query,
+ "tool_keywords": "",
+ "limit": 10,
+ },
+ },
+ }
+
+ handler := integration.CreateFindToolHandler()
+ result, err := handler(ctxWithCaps, request)
+ require.NoError(t, err)
+ require.False(t, result.IsError)
+
+ textContent, _ := mcp.AsTextContent(result.Content[0])
+ var response map[string]any
+ json.Unmarshal([]byte(textContent.Text), &response)
+
+ toolsArray, _ := response["tools"].([]interface{})
+ require.NotEmpty(t, toolsArray)
+
+ // Check that results are sorted by similarity (highest first)
+ var similarities []float64
+ for _, toolInterface := range toolsArray {
+ toolMap, _ := toolInterface.(map[string]interface{})
+ similarity, _ := toolMap["similarity_score"].(float64)
+ similarities = append(similarities, similarity)
+ }
+
+ // Verify results are sorted by similarity (descending)
+ for i := 1; i < len(similarities); i++ {
+ assert.GreaterOrEqual(t, similarities[i-1], similarities[i],
+ "Results should be sorted by similarity score (descending). Scores: %v", similarities)
+ }
+
+ // The most relevant tool (pull request) should have a higher similarity than unrelated tools
+ if len(similarities) > 1 {
+ // First result should have highest similarity
+ assert.Greater(t, similarities[0], 0.0, "Top result should have positive similarity")
+ }
+}
diff --git a/pkg/vmcp/optimizer/find_tool_string_matching_test.go b/pkg/vmcp/optimizer/find_tool_string_matching_test.go
new file mode 100644
index 0000000000..6166de6164
--- /dev/null
+++ b/pkg/vmcp/optimizer/find_tool_string_matching_test.go
@@ -0,0 +1,699 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package optimizer
+
+import (
+ "context"
+ "encoding/json"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ transportsession "github.com/stacklok/toolhive/pkg/transport/session"
+ "github.com/stacklok/toolhive/pkg/vmcp"
+ "github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+ "github.com/stacklok/toolhive/pkg/vmcp/discovery"
+ vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
+)
+
+// verifyOllamaWorking verifies that Ollama is actually working by attempting to generate an embedding
+// This ensures the service is not just reachable but actually functional
+func verifyOllamaWorking(t *testing.T, manager *embeddings.Manager) {
+ t.Helper()
+ _, err := manager.GenerateEmbedding([]string{"test"})
+ if err != nil {
+ t.Skipf("Skipping test: Ollama is reachable but embedding generation failed. Error: %v. Ensure 'ollama pull %s' has been executed", err, embeddings.DefaultModelAllMiniLM)
+ }
+}
+
+// getRealToolData returns test data based on actual MCP server tools
+// These are real tool descriptions from GitHub and other MCP servers
+func getRealToolData() []vmcp.Tool {
+ return []vmcp.Tool{
+ {
+ Name: "github_pull_request_read",
+ Description: "Get information on a specific pull request in GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_list_pull_requests",
+ Description: "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_search_pull_requests",
+ Description: "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr",
+ BackendID: "github",
+ },
+ {
+ Name: "github_create_pull_request",
+ Description: "Create a new pull request in a GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_merge_pull_request",
+ Description: "Merge a pull request in a GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_pull_request_review_write",
+ Description: "Create and/or submit, delete review of a pull request.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_issue_read",
+ Description: "Get information about a specific issue in a GitHub repository.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_list_issues",
+ Description: "List issues in a GitHub repository. For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter.",
+ BackendID: "github",
+ },
+ {
+ Name: "github_create_repository",
+ Description: "Create a new GitHub repository in your account or specified organization",
+ BackendID: "github",
+ },
+ {
+ Name: "github_get_commit",
+ Description: "Get details for a commit from a GitHub repository",
+ BackendID: "github",
+ },
+ {
+ Name: "fetch_fetch",
+ Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.",
+ BackendID: "fetch",
+ },
+ }
+}
+
+// TestFindTool_StringMatching tests that find_tool can match strings correctly
+func TestFindTool_StringMatching(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ // Setup optimizer integration
+ mcpServer := server.NewMCPServer("test-server", "1.0")
+ mockClient := &mockBackendClient{}
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingConfig := &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM)
+ return
+ }
+ t.Cleanup(func() { _ = embeddingManager.Close() })
+
+ // Verify Ollama is actually working, not just reachable
+ verifyOllamaWorking(t, embeddingManager)
+
+ config := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ },
+ HybridSearchRatio: 50, // 50% semantic, 50% BM25 for better string matching
+ }
+
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ require.NotNil(t, integration)
+ t.Cleanup(func() { _ = integration.Close() })
+
+ // Get real tool data
+ tools := getRealToolData()
+
+ // Create capabilities with real tools
+ capabilities := &aggregator.AggregatedCapabilities{
+ Tools: tools,
+ RoutingTable: &vmcp.RoutingTable{
+ Tools: make(map[string]*vmcp.BackendTarget),
+ Resources: map[string]*vmcp.BackendTarget{},
+ Prompts: map[string]*vmcp.BackendTarget{},
+ },
+ }
+
+ // Build routing table
+ for _, tool := range tools {
+ capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{
+ WorkloadID: tool.BackendID,
+ WorkloadName: tool.BackendID,
+ }
+ }
+
+ // Register session and generate embeddings
+ session := &mockSession{sessionID: "test-session"}
+ err = integration.OnRegisterSession(ctx, session, capabilities)
+ require.NoError(t, err)
+
+ // Manually ingest tools for testing (OnRegisterSession skips ingestion)
+ mcpTools := make([]mcp.Tool, len(tools))
+ for i, tool := range tools {
+ mcpTools[i] = mcp.Tool{
+ Name: tool.Name,
+ Description: tool.Description,
+ }
+ }
+ err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools)
+ require.NoError(t, err)
+
+ // Create context with capabilities
+ ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities)
+
+ // Test cases: query -> expected tool names that should be found
+ testCases := []struct {
+ name string
+ query string
+ keywords string
+ expectedTools []string // Tools that should definitely be in results
+ minResults int // Minimum number of results expected
+ description string
+ }{
+ {
+ name: "exact_pull_request_match",
+ query: "pull request",
+ keywords: "pull request",
+ expectedTools: []string{"github_pull_request_read", "github_list_pull_requests", "github_create_pull_request"},
+ minResults: 3,
+ description: "Should find tools with exact 'pull request' string match",
+ },
+ {
+ name: "pull_request_in_name",
+ query: "pull request",
+ keywords: "pull_request",
+ expectedTools: []string{"github_pull_request_read", "github_list_pull_requests"},
+ minResults: 2,
+ description: "Should match tools with 'pull_request' in name",
+ },
+ {
+ name: "list_pull_requests",
+ query: "list pull requests",
+ keywords: "list pull requests",
+ expectedTools: []string{"github_list_pull_requests"},
+ minResults: 1,
+ description: "Should find list pull requests tool",
+ },
+ {
+ name: "read_pull_request",
+ query: "read pull request",
+ keywords: "read pull request",
+ expectedTools: []string{"github_pull_request_read"},
+ minResults: 1,
+ description: "Should find read pull request tool",
+ },
+ {
+ name: "create_pull_request",
+ query: "create pull request",
+ keywords: "create pull request",
+ expectedTools: []string{"github_create_pull_request"},
+ minResults: 1,
+ description: "Should find create pull request tool",
+ },
+ {
+ name: "merge_pull_request",
+ query: "merge pull request",
+ keywords: "merge pull request",
+ expectedTools: []string{"github_merge_pull_request"},
+ minResults: 1,
+ description: "Should find merge pull request tool",
+ },
+ {
+ name: "search_pull_requests",
+ query: "search pull requests",
+ keywords: "search pull requests",
+ expectedTools: []string{"github_search_pull_requests"},
+ minResults: 1,
+ description: "Should find search pull requests tool",
+ },
+ {
+ name: "issue_tools",
+ query: "issue",
+ keywords: "issue",
+ expectedTools: []string{"github_issue_read", "github_list_issues"},
+ minResults: 2,
+ description: "Should find issue-related tools",
+ },
+ {
+ name: "repository_tool",
+ query: "create repository",
+ keywords: "create repository",
+ expectedTools: []string{"github_create_repository"},
+ minResults: 1,
+ description: "Should find create repository tool",
+ },
+ {
+ name: "commit_tool",
+ query: "get commit",
+ keywords: "commit",
+ expectedTools: []string{"github_get_commit"},
+ minResults: 1,
+ description: "Should find get commit tool",
+ },
+ {
+ name: "fetch_tool",
+ query: "fetch URL",
+ keywords: "fetch",
+ expectedTools: []string{"fetch_fetch"},
+ minResults: 1,
+ description: "Should find fetch tool",
+ },
+ }
+
+ for _, tc := range testCases {
+		tc := tc // capture loop variable (pre-Go-1.22 idiom; redundant on newer toolchains)
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Create the tool call request
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "optim_find_tool",
+ Arguments: map[string]any{
+ "tool_description": tc.query,
+ "tool_keywords": tc.keywords,
+ "limit": 20,
+ },
+ },
+ }
+
+ // Call the handler
+ handler := integration.CreateFindToolHandler()
+ result, err := handler(ctxWithCaps, request)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.False(t, result.IsError, "Tool call should not return error")
+
+ // Parse the result
+ require.NotEmpty(t, result.Content, "Result should have content")
+ textContent, ok := mcp.AsTextContent(result.Content[0])
+ require.True(t, ok, "Result should be text content")
+
+ // Parse JSON response
+ var response map[string]any
+ err = json.Unmarshal([]byte(textContent.Text), &response)
+ require.NoError(t, err, "Result should be valid JSON")
+
+ // Check tools array exists
+ toolsArray, ok := response["tools"].([]interface{})
+ require.True(t, ok, "Response should have tools array")
+ require.GreaterOrEqual(t, len(toolsArray), tc.minResults,
+ "Should return at least %d results for query: %s", tc.minResults, tc.query)
+
+ // Extract tool names from results
+ foundTools := make([]string, 0, len(toolsArray))
+ for _, toolInterface := range toolsArray {
+ toolMap, okMap := toolInterface.(map[string]interface{})
+ require.True(t, okMap, "Tool should be a map")
+ toolName, okName := toolMap["name"].(string)
+ require.True(t, okName, "Tool should have name")
+ foundTools = append(foundTools, toolName)
+ }
+
+ // Check that at least some expected tools are found
+			// String matching may not be perfect, so we only require a subset (at least half, see below) of the expected tools
+ foundCount := 0
+ for _, expectedTool := range tc.expectedTools {
+ for _, foundTool := range foundTools {
+ if foundTool == expectedTool {
+ foundCount++
+ break
+ }
+ }
+ }
+
+ // We should find at least one expected tool, or at least 50% of expected tools
+ minExpected := 1
+ if len(tc.expectedTools) > 1 {
+ half := len(tc.expectedTools) / 2
+ if half > minExpected {
+ minExpected = half
+ }
+ }
+
+ assert.GreaterOrEqual(t, foundCount, minExpected,
+ "Query '%s' should find at least %d of expected tools %v. Found tools: %v (found %d/%d)",
+ tc.query, minExpected, tc.expectedTools, foundTools, foundCount, len(tc.expectedTools))
+
+ // Log which expected tools were found for debugging
+ if foundCount < len(tc.expectedTools) {
+ t.Logf("Query '%s': Found %d/%d expected tools. Found: %v, Expected: %v",
+ tc.query, foundCount, len(tc.expectedTools), foundTools, tc.expectedTools)
+ }
+
+ // Verify token metrics exist
+ tokenMetrics, ok := response["token_metrics"].(map[string]interface{})
+ require.True(t, ok, "Response should have token_metrics")
+ assert.Contains(t, tokenMetrics, "baseline_tokens")
+ assert.Contains(t, tokenMetrics, "returned_tokens")
+ assert.Contains(t, tokenMetrics, "tokens_saved")
+ assert.Contains(t, tokenMetrics, "savings_percentage")
+ })
+ }
+}
+
+// TestFindTool_ExactStringMatch tests that exact string matches work correctly
+func TestFindTool_ExactStringMatch(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ // Setup optimizer integration with higher BM25 ratio for better string matching
+ mcpServer := server.NewMCPServer("test-server", "1.0")
+ mockClient := &mockBackendClient{}
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingConfig := &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM)
+ return
+ }
+ t.Cleanup(func() { _ = embeddingManager.Close() })
+
+ // Verify Ollama is actually working, not just reachable
+ verifyOllamaWorking(t, embeddingManager)
+
+ config := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ },
+ HybridSearchRatio: 30, // 30% semantic, 70% BM25 for better exact string matching
+ }
+
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ require.NotNil(t, integration)
+ t.Cleanup(func() { _ = integration.Close() })
+
+ // Create tools with specific strings to match
+ tools := []vmcp.Tool{
+ {
+ Name: "test_pull_request_tool",
+ Description: "This tool handles pull requests in GitHub",
+ BackendID: "test",
+ },
+ {
+ Name: "test_issue_tool",
+ Description: "This tool handles issues in GitHub",
+ BackendID: "test",
+ },
+ {
+ Name: "test_repository_tool",
+ Description: "This tool creates repositories",
+ BackendID: "test",
+ },
+ }
+
+ capabilities := &aggregator.AggregatedCapabilities{
+ Tools: tools,
+ RoutingTable: &vmcp.RoutingTable{
+ Tools: make(map[string]*vmcp.BackendTarget),
+ Resources: map[string]*vmcp.BackendTarget{},
+ Prompts: map[string]*vmcp.BackendTarget{},
+ },
+ }
+
+ for _, tool := range tools {
+ capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{
+ WorkloadID: tool.BackendID,
+ WorkloadName: tool.BackendID,
+ }
+ }
+
+ session := &mockSession{sessionID: "test-session"}
+ err = integration.OnRegisterSession(ctx, session, capabilities)
+ require.NoError(t, err)
+
+ // Manually ingest tools for testing (OnRegisterSession skips ingestion)
+ mcpTools := make([]mcp.Tool, len(tools))
+ for i, tool := range tools {
+ mcpTools[i] = mcp.Tool{
+ Name: tool.Name,
+ Description: tool.Description,
+ }
+ }
+ err = integration.IngestToolsForTesting(ctx, "test", "test", nil, mcpTools)
+ require.NoError(t, err)
+
+ ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities)
+
+ // Test exact string matching
+ testCases := []struct {
+ name string
+ query string
+ keywords string
+ expectedTool string
+ description string
+ }{
+ {
+ name: "exact_pull_request_string",
+ query: "pull request",
+ keywords: "pull request",
+ expectedTool: "test_pull_request_tool",
+ description: "Should match exact 'pull request' string",
+ },
+ {
+ name: "exact_issue_string",
+ query: "issue",
+ keywords: "issue",
+ expectedTool: "test_issue_tool",
+ description: "Should match exact 'issue' string",
+ },
+ {
+ name: "exact_repository_string",
+ query: "repository",
+ keywords: "repository",
+ expectedTool: "test_repository_tool",
+ description: "Should match exact 'repository' string",
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "optim_find_tool",
+ Arguments: map[string]any{
+ "tool_description": tc.query,
+ "tool_keywords": tc.keywords,
+ "limit": 10,
+ },
+ },
+ }
+
+ handler := integration.CreateFindToolHandler()
+ result, err := handler(ctxWithCaps, request)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.False(t, result.IsError)
+
+ textContent, okText := mcp.AsTextContent(result.Content[0])
+ require.True(t, okText)
+
+ var response map[string]any
+ err = json.Unmarshal([]byte(textContent.Text), &response)
+ require.NoError(t, err)
+
+ toolsArray, okArray := response["tools"].([]interface{})
+ require.True(t, okArray)
+ require.NotEmpty(t, toolsArray, "Should find at least one tool for query: %s", tc.query)
+
+ // Check that the expected tool is in the results
+ found := false
+ for _, toolInterface := range toolsArray {
+ toolMap, okMap := toolInterface.(map[string]interface{})
+ require.True(t, okMap)
+ toolName, okName := toolMap["name"].(string)
+ require.True(t, okName)
+ if toolName == tc.expectedTool {
+ found = true
+ break
+ }
+ }
+
+ assert.True(t, found,
+ "Expected tool '%s' not found in results for query '%s'. This indicates string matching is not working correctly.",
+ tc.expectedTool, tc.query)
+ })
+ }
+}
+
+// TestFindTool_CaseInsensitive tests case-insensitive string matching
+func TestFindTool_CaseInsensitive(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ mcpServer := server.NewMCPServer("test-server", "1.0")
+ mockClient := &mockBackendClient{}
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingConfig := &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM)
+ return
+ }
+ t.Cleanup(func() { _ = embeddingManager.Close() })
+
+ // Verify Ollama is actually working, not just reachable
+ verifyOllamaWorking(t, embeddingManager)
+
+ config := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ },
+ HybridSearchRatio: 30, // Favor BM25 for string matching
+ }
+
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ require.NotNil(t, integration)
+ t.Cleanup(func() { _ = integration.Close() })
+
+ tools := []vmcp.Tool{
+ {
+ Name: "github_pull_request_read",
+ Description: "Get information on a specific pull request in GitHub repository.",
+ BackendID: "github",
+ },
+ }
+
+ capabilities := &aggregator.AggregatedCapabilities{
+ Tools: tools,
+ RoutingTable: &vmcp.RoutingTable{
+ Tools: map[string]*vmcp.BackendTarget{
+ "github_pull_request_read": {
+ WorkloadID: "github",
+ WorkloadName: "github",
+ },
+ },
+ Resources: map[string]*vmcp.BackendTarget{},
+ Prompts: map[string]*vmcp.BackendTarget{},
+ },
+ }
+
+ session := &mockSession{sessionID: "test-session"}
+ err = integration.OnRegisterSession(ctx, session, capabilities)
+ require.NoError(t, err)
+
+ // Manually ingest tools for testing (OnRegisterSession skips ingestion)
+ mcpTools := make([]mcp.Tool, len(tools))
+ for i, tool := range tools {
+ mcpTools[i] = mcp.Tool{
+ Name: tool.Name,
+ Description: tool.Description,
+ }
+ }
+ err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools)
+ require.NoError(t, err)
+
+ ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities)
+
+ // Test different case variations
+ queries := []string{
+ "PULL REQUEST",
+ "Pull Request",
+ "pull request",
+ "PuLl ReQuEsT",
+ }
+
+ for _, query := range queries {
+ query := query
+ t.Run("case_"+strings.ToLower(query), func(t *testing.T) {
+ t.Parallel()
+
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "optim_find_tool",
+ Arguments: map[string]any{
+ "tool_description": query,
+ "tool_keywords": strings.ToLower(query),
+ "limit": 10,
+ },
+ },
+ }
+
+ handler := integration.CreateFindToolHandler()
+ result, err := handler(ctxWithCaps, request)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.False(t, result.IsError)
+
+ textContent, okText := mcp.AsTextContent(result.Content[0])
+ require.True(t, okText)
+
+ var response map[string]any
+ err = json.Unmarshal([]byte(textContent.Text), &response)
+ require.NoError(t, err)
+
+ toolsArray, okArray := response["tools"].([]interface{})
+ require.True(t, okArray)
+
+ // Should find the pull request tool regardless of case
+ found := false
+ for _, toolInterface := range toolsArray {
+ toolMap, okMap := toolInterface.(map[string]interface{})
+ require.True(t, okMap)
+ toolName, okName := toolMap["name"].(string)
+ require.True(t, okName)
+ if toolName == "github_pull_request_read" {
+ found = true
+ break
+ }
+ }
+
+ assert.True(t, found,
+ "Should find pull request tool with case-insensitive query: %s", query)
+ })
+ }
+}
diff --git a/pkg/vmcp/optimizer/integration.go b/pkg/vmcp/optimizer/integration.go
new file mode 100644
index 0000000000..01d2f74291
--- /dev/null
+++ b/pkg/vmcp/optimizer/integration.go
@@ -0,0 +1,42 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package optimizer
+
+import (
+ "context"
+
+ "github.com/mark3labs/mcp-go/server"
+
+ "github.com/stacklok/toolhive/pkg/vmcp"
+ "github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+ "github.com/stacklok/toolhive/pkg/vmcp/server/adapter"
+)
+
+// Integration is the interface for optimizer functionality in vMCP.
+// This interface encapsulates all optimizer logic, keeping server.go clean.
+type Integration interface {
+ // Initialize performs all optimizer initialization:
+ // - Registers optimizer tools globally with the MCP server
+ // - Ingests initial backends from the registry
+ // This should be called once during server startup, after the MCP server is created.
+ Initialize(ctx context.Context, mcpServer *server.MCPServer, backendRegistry vmcp.BackendRegistry) error
+
+ // HandleSessionRegistration handles session registration for optimizer mode.
+ // Returns true if optimizer mode is enabled and handled the registration,
+ // false if optimizer is disabled and normal registration should proceed.
+ // The resourceConverter function converts vmcp.Resource to server.ServerResource.
+ HandleSessionRegistration(
+ ctx context.Context,
+ sessionID string,
+ caps *aggregator.AggregatedCapabilities,
+ mcpServer *server.MCPServer,
+ resourceConverter func([]vmcp.Resource) []server.ServerResource,
+ ) (bool, error)
+
+ // Close cleans up optimizer resources
+ Close() error
+
+ // OptimizerHandlerProvider is embedded to provide tool handlers
+ adapter.OptimizerHandlerProvider
+}
diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go
index fea0425bb5..d3640419ec 100644
--- a/pkg/vmcp/optimizer/optimizer.go
+++ b/pkg/vmcp/optimizer/optimizer.go
@@ -1,91 +1,889 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0
-// Package optimizer provides the Optimizer interface for intelligent tool discovery
-// and invocation in the Virtual MCP Server.
+// Package optimizer provides vMCP integration for semantic tool discovery.
//
-// When the optimizer is enabled, vMCP exposes only two tools to clients:
-// - find_tool: Semantic search over available tools
-// - call_tool: Dynamic invocation of any backend tool
+// This package implements the RFC-0022 optimizer integration, exposing:
+// - optim_find_tool: Semantic/keyword-based tool discovery
+// - optim_call_tool: Dynamic tool invocation across backends
//
-// This reduces token usage by avoiding the need to send all tool definitions
-// to the LLM, instead allowing it to discover relevant tools on demand.
+// Architecture:
+// - Embeddings are generated during session initialization (OnRegisterSession hook)
+// - Tools are exposed as standard MCP tools callable via tools/call
+// - Integrates with vMCP's two-boundary authentication model
+// - Uses existing router for backend tool invocation
package optimizer
import (
"context"
"encoding/json"
+ "fmt"
+ "sync"
+ "time"
"github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/codes"
+ "go.opentelemetry.io/otel/metric"
+ "go.opentelemetry.io/otel/trace"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/ingestion"
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models"
+ "github.com/stacklok/toolhive/pkg/logger"
+ transportsession "github.com/stacklok/toolhive/pkg/transport/session"
+ "github.com/stacklok/toolhive/pkg/vmcp"
+ "github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+ "github.com/stacklok/toolhive/pkg/vmcp/discovery"
+ "github.com/stacklok/toolhive/pkg/vmcp/server/adapter"
)
-// Optimizer defines the interface for intelligent tool discovery and invocation.
+// Config holds optimizer configuration for vMCP integration.
+type Config struct {
+ // Enabled controls whether optimizer tools are available
+ Enabled bool
+
+ // PersistPath is the optional path for chromem-go database persistence (empty = in-memory)
+ PersistPath string
+
+ // FTSDBPath is the path to SQLite FTS5 database for BM25 search
+ // (empty = auto-default: ":memory:" or "{PersistPath}/fts.db")
+ FTSDBPath string
+
+ // HybridSearchRatio controls semantic vs BM25 mix (0-100 percentage, default: 70)
+ HybridSearchRatio int
+
+ // EmbeddingConfig configures the embedding backend (vLLM, Ollama, placeholder)
+ EmbeddingConfig *embeddings.Config
+}
+
+// OptimizerIntegration manages optimizer functionality within vMCP.
//
-// Implementations may use various strategies for tool matching:
-// - DummyOptimizer: Exact string matching (for testing)
-// - EmbeddingOptimizer: Semantic similarity via embeddings (production)
-type Optimizer interface {
-	// FindTool searches for tools matching the given description and keywords.
-	// Returns matching tools ranked by relevance score.
-	FindTool(ctx context.Context, input FindToolInput) (*FindToolOutput, error)
+//nolint:revive // Name is intentional for clarity in external packages
+type OptimizerIntegration struct {
+	config           *Config            // optimizer settings; non-nil when enabled
+	ingestionService *ingestion.Service // embedding + FTS ingestion and hybrid search
+	mcpServer        *server.MCPServer  // For registering tools
+	backendClient    vmcp.BackendClient // For querying backends at startup
+	sessionManager   *transportsession.Manager
+	processedSessions sync.Map // Track sessions that have already been processed
+	tracer           trace.Tracer // spans for ingestion instrumentation
+}
+
+// NewIntegration creates a new optimizer integration.
+// It returns (nil, nil) when cfg is nil or cfg.Enabled is false; all methods
+// on *OptimizerIntegration check for a nil receiver, so callers may use the
+// returned value without an explicit nil check.
+func NewIntegration(
+	_ context.Context,
+	cfg *Config,
+	mcpServer *server.MCPServer,
+	backendClient vmcp.BackendClient,
+	sessionManager *transportsession.Manager,
+) (*OptimizerIntegration, error) {
+	if cfg == nil || !cfg.Enabled {
+		return nil, nil // Optimizer disabled
+	}
+
+	// Initialize ingestion service with embedding backend
+	ingestionCfg := &ingestion.Config{
+		DBConfig: &db.Config{
+			PersistPath: cfg.PersistPath,
+			FTSDBPath:   cfg.FTSDBPath,
+		},
+		EmbeddingConfig: cfg.EmbeddingConfig,
+	}
+
+	svc, err := ingestion.NewService(ingestionCfg)
+	if err != nil {
+		return nil, fmt.Errorf("failed to initialize optimizer service: %w", err)
+	}
-	// CallTool invokes a tool by name with the given parameters.
-	// Returns the tool's result or an error if the tool is not found or execution fails.
-	// Returns the MCP CallToolResult directly from the underlying tool handler.
-	CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error)
+	return &OptimizerIntegration{
+		config:           cfg,
+		ingestionService: svc,
+		mcpServer:        mcpServer,
+		backendClient:    backendClient,
+		sessionManager:   sessionManager,
+		tracer:           otel.Tracer("github.com/stacklok/toolhive/pkg/vmcp/optimizer"),
+	}, nil
}
-// FindToolInput contains the parameters for finding tools.
-type FindToolInput struct {
- // ToolDescription is a natural language description of the tool to find.
- ToolDescription string `json:"tool_description" description:"Natural language description of the tool to find"`
+// Ensure OptimizerIntegration implements Integration interface at compile time.
+var _ Integration = (*OptimizerIntegration)(nil)
+
+// HandleSessionRegistration handles session registration for optimizer mode.
+// Returns true if optimizer mode is enabled and handled the registration,
+// false if optimizer is disabled and normal registration should proceed.
+//
+// When optimizer is enabled:
+// 1. Registers optimizer tools (find_tool, call_tool) for the session
+// 2. Injects resources (but not backend tools or composite tools)
+// 3. Backend tools are accessible via find_tool and call_tool
+//
+// Safe to call on a nil receiver: returns (false, nil) so the caller falls
+// back to normal registration.
+func (o *OptimizerIntegration) HandleSessionRegistration(
+	_ context.Context,
+	sessionID string,
+	caps *aggregator.AggregatedCapabilities,
+	mcpServer *server.MCPServer,
+	resourceConverter func([]vmcp.Resource) []server.ServerResource,
+) (bool, error) {
+	if o == nil {
+		return false, nil // Optimizer not enabled, use normal registration
+	}
+
+	logger.Debugw("HandleSessionRegistration called for optimizer mode", "session_id", sessionID)
-	// ToolKeywords is an optional list of keywords to narrow the search.
-	ToolKeywords []string `json:"tool_keywords,omitempty" description:"Optional keywords to narrow search"`
+	// Register optimizer tools for this session
+	// Tools are already registered globally, but we need to add them to the session
+	// when using WithToolCapabilities(false)
+	optimizerTools, err := adapter.CreateOptimizerTools(o)
+	if err != nil {
+		return false, fmt.Errorf("failed to create optimizer tools: %w", err)
+	}
+
+	// Add optimizer tools to session
+	if err := mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil {
+		return false, fmt.Errorf("failed to add optimizer tools to session: %w", err)
+	}
+
+	logger.Debugw("Optimizer tools registered for session", "session_id", sessionID)
+
+	// Inject resources (but not backend tools or composite tools)
+	// Backend tools will be accessible via find_tool and call_tool
+	if len(caps.Resources) > 0 {
+		sdkResources := resourceConverter(caps.Resources)
+		if err := mcpServer.AddSessionResources(sessionID, sdkResources...); err != nil {
+			return false, fmt.Errorf("failed to add session resources: %w", err)
+		}
+		logger.Debugw("Added session resources (optimizer mode)",
+			"session_id", sessionID,
+			"count", len(sdkResources))
+	}
+
+	logger.Infow("Optimizer mode: backend tools not exposed directly",
+		"session_id", sessionID,
+		"backend_tool_count", len(caps.Tools),
+		"resource_count", len(caps.Resources))
+
+	return true, nil // Optimizer handled the registration
}
-// FindToolOutput contains the results of a tool search.
-type FindToolOutput struct {
- // Tools contains the matching tools, ranked by relevance.
- Tools []ToolMatch `json:"tools"`
+// OnRegisterSession is a legacy method kept for test compatibility.
+// It does nothing since ingestion is now handled by Initialize().
+//
+// Deprecated: use HandleSessionRegistration instead; this method will be
+// removed in a future version.
+func (o *OptimizerIntegration) OnRegisterSession(
+	_ context.Context,
+	session server.ClientSession,
+	_ *aggregator.AggregatedCapabilities,
+) error {
+	if o == nil {
+		return nil // Optimizer not enabled
+	}
+
+	sessionID := session.SessionID()
+
+	logger.Debugw("OnRegisterSession called (legacy method, no-op)", "session_id", sessionID)
-	// TokenMetrics provides information about token savings from using the optimizer.
-	TokenMetrics TokenMetrics `json:"token_metrics"`
+	// Check if this session has already been processed.
+	// NOTE(review): entries are never deleted from processedSessions, so the
+	// map grows with every unique session for the life of the process —
+	// confirm this is acceptable for long-running servers.
+	if _, alreadyProcessed := o.processedSessions.LoadOrStore(sessionID, true); alreadyProcessed {
+		logger.Debugw("Session already processed, skipping duplicate ingestion",
+			"session_id", sessionID)
+		return nil
+	}
+
+	// Skip ingestion in OnRegisterSession - IngestInitialBackends already handles ingestion at startup
+	// This prevents duplicate ingestion when sessions are registered
+	// The optimizer database is populated once at startup, not per-session
+	logger.Infow("Skipping ingestion in OnRegisterSession (handled by Initialize at startup)",
+		"session_id", sessionID)
+
+	return nil
}
-// ToolMatch represents a tool that matched the search criteria.
-type ToolMatch struct {
- // Name is the unique identifier of the tool.
- Name string `json:"name"`
+// Initialize performs all optimizer initialization:
+// - Registers optimizer tools globally with the MCP server
+// - Ingests initial backends from the registry
+//
+// This should be called once during server startup, after the MCP server is created.
+// Safe to call on a nil receiver (no-op). Ingestion failures are logged but do
+// not fail initialization.
+func (o *OptimizerIntegration) Initialize(
+	ctx context.Context,
+	mcpServer *server.MCPServer,
+	backendRegistry vmcp.BackendRegistry,
+) error {
+	if o == nil {
+		return nil // Optimizer not enabled
+	}
-	// Description is the human-readable description of the tool.
-	Description string `json:"description"`
+	// Register optimizer tools globally (available to all sessions immediately)
+	optimizerTools, err := adapter.CreateOptimizerTools(o)
+	if err != nil {
+		return fmt.Errorf("failed to create optimizer tools: %w", err)
+	}
+	for _, tool := range optimizerTools {
+		mcpServer.AddTool(tool.Tool, tool.Handler)
+	}
+	logger.Info("Optimizer tools registered globally")
-	// InputSchema is the JSON schema for the tool's input parameters.
-	// Uses json.RawMessage to preserve the original schema format.
-	InputSchema json.RawMessage `json:"input_schema"`
+	// Ingest discovered backends into optimizer database
+	initialBackends := backendRegistry.List(ctx)
+	if err := o.IngestInitialBackends(ctx, initialBackends); err != nil {
+		logger.Warnf("Failed to ingest initial backends into optimizer: %v", err)
+		// Don't fail initialization - optimizer can still work with incremental ingestion
+	}
-	// Score indicates how well this tool matches the search criteria (0.0-1.0).
-	Score float64 `json:"score"`
+	return nil
}
-// TokenMetrics provides information about token usage optimization.
-type TokenMetrics struct {
- // BaselineTokens is the estimated tokens if all tools were sent.
- BaselineTokens int `json:"baseline_tokens"`
+// RegisterTools adds optimizer tools to the session.
+// Even though tools are registered globally during Initialize(),
+// with WithToolCapabilities(false) we also need to register them per-session
+// to ensure they appear in list_tools responses.
+// This should be called after OnRegisterSession completes.
+// Safe to call on a nil receiver (no-op).
+func (o *OptimizerIntegration) RegisterTools(_ context.Context, session server.ClientSession) error {
+	if o == nil {
+		return nil // Optimizer not enabled
+	}
+
+	sessionID := session.SessionID()
-	// ReturnedTokens is the actual tokens for the returned tools.
-	ReturnedTokens int `json:"returned_tokens"`
+	// Build the optimizer tool set from the shared adapter constructor so the
+	// definitions cannot drift from the ones used during global registration
+	// and session registration elsewhere in this file.
+	optimizerTools, err := adapter.CreateOptimizerTools(o)
+	if err != nil {
+		return fmt.Errorf("failed to create optimizer tools: %w", err)
+	}
+
+	// Add tools to session (required when WithToolCapabilities(false))
+	if err := o.mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil {
+		return fmt.Errorf("failed to add optimizer tools to session: %w", err)
+	}
+
+	logger.Debugw("Optimizer tools registered for session", "session_id", sessionID)
+	return nil
+}
- // SavingsPercent is the percentage of tokens saved.
- SavingsPercent float64 `json:"savings_percent"`
+// GetOptimizerToolDefinitions returns the tool definitions for optimizer tools
+// without handlers. This is useful for adding tools to capabilities before session registration.
+// NOTE(review): these schemas are duplicated where handlers are attached; keep
+// the copies in sync (ideally derive all of them from a single source).
+func (o *OptimizerIntegration) GetOptimizerToolDefinitions() []mcp.Tool {
+	if o == nil {
+		return nil
+	}
+	return []mcp.Tool{
+		{
+			Name:        "optim_find_tool",
+			Description: "Semantic search across all backend tools using natural language description and optional keywords",
+			InputSchema: mcp.ToolInputSchema{
+				Type: "object",
+				Properties: map[string]any{
+					"tool_description": map[string]any{
+						"type":        "string",
+						"description": "Natural language description of the tool you're looking for",
+					},
+					"tool_keywords": map[string]any{
+						"type":        "string",
+						"description": "Optional space-separated keywords for keyword-based search",
+					},
+					"limit": map[string]any{
+						"type":        "integer",
+						"description": "Maximum number of tools to return (default: 10)",
+						"default":     10,
+					},
+				},
+				Required: []string{"tool_description"},
+			},
+		},
+		{
+			Name:        "optim_call_tool",
+			Description: "Dynamically invoke any tool on any backend using the backend_id from find_tool",
+			InputSchema: mcp.ToolInputSchema{
+				Type: "object",
+				Properties: map[string]any{
+					"backend_id": map[string]any{
+						"type":        "string",
+						"description": "Backend ID from find_tool results",
+					},
+					"tool_name": map[string]any{
+						"type":        "string",
+						"description": "Tool name to invoke",
+					},
+					"parameters": map[string]any{
+						"type":        "object",
+						"description": "Parameters to pass to the tool",
+					},
+				},
+				Required: []string{"backend_id", "tool_name", "parameters"},
+			},
+		},
+	}
}
-// CallToolInput contains the parameters for calling a tool.
-type CallToolInput struct {
- // ToolName is the name of the tool to invoke.
- ToolName string `json:"tool_name" description:"Name of the tool to call"`
+// CreateFindToolHandler creates the handler for optim_find_tool.
+// Exported for testing purposes; it simply delegates to the unexported constructor.
+func (o *OptimizerIntegration) CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+	return o.createFindToolHandler()
+}
+
+// extractFindToolParams extracts and validates parameters from the find_tool request.
+// The err result is a client-facing tool error payload (non-nil on validation
+// failure), not a Go error.
+func extractFindToolParams(args map[string]any) (toolDescription, toolKeywords string, limit int, err *mcp.CallToolResult) {
+	// Extract tool_description (required)
+	toolDescription, ok := args["tool_description"].(string)
+	if !ok || toolDescription == "" {
+		return "", "", 0, mcp.NewToolResultError("tool_description is required and must be a non-empty string")
+	}
+
+	// Extract tool_keywords (optional)
+	toolKeywords, _ = args["tool_keywords"].(string)
+
+	// Extract limit (optional, default: 10). JSON numbers decode as float64.
+	// Non-positive values are ignored and fall back to the default so that a
+	// zero or negative limit is never forwarded to the search layer.
+	limit = 10
+	if limitVal, ok := args["limit"]; ok {
+		if limitFloat, ok := limitVal.(float64); ok && limitFloat > 0 {
+			limit = int(limitFloat)
+		}
+	}
+
+	return toolDescription, toolKeywords, limit, nil
+}
+
+// resolveToolName looks up the resolved name for a tool in the routing table.
+// Returns the resolved name if found, otherwise returns the original name.
+//
+// The routing table maps resolved names (after conflict resolution) to BackendTarget.
+// Each BackendTarget contains:
+// - WorkloadID: the backend ID
+// - OriginalCapabilityName: the original tool name (empty if not renamed)
+//
+// We need to find the resolved name by matching backend ID and original name.
+// NOTE(review): this performs a linear scan of the routing table and is called
+// once per search result; fine for modest table sizes, but consider a reverse
+// index if backend/tool counts grow large.
+func resolveToolName(routingTable *vmcp.RoutingTable, backendID string, originalName string) string {
+	if routingTable == nil || routingTable.Tools == nil {
+		return originalName
+	}
+
+	// Search through routing table to find the resolved name
+	// Match by backend ID and original capability name
+	for resolvedName, target := range routingTable.Tools {
+		// Case 1: Tool was renamed (OriginalCapabilityName is set)
+		// Match by backend ID and original name
+		if target.WorkloadID == backendID && target.OriginalCapabilityName == originalName {
+			logger.Debugw("Resolved tool name (renamed)",
+				"backend_id", backendID,
+				"original_name", originalName,
+				"resolved_name", resolvedName)
+			return resolvedName
+		}
+
+		// Case 2: Tool was not renamed (OriginalCapabilityName is empty)
+		// Match by backend ID and resolved name equals original name
+		if target.WorkloadID == backendID && target.OriginalCapabilityName == "" && resolvedName == originalName {
+			logger.Debugw("Resolved tool name (not renamed)",
+				"backend_id", backendID,
+				"original_name", originalName,
+				"resolved_name", resolvedName)
+			return resolvedName
+		}
+	}
+
+	// If not found, return original name (fallback for tools not in routing table)
+	// This can happen if:
+	// - Tool was just ingested but routing table hasn't been updated yet
+	// - Tool belongs to a backend that's not currently registered
+	logger.Debugw("Tool name not found in routing table, using original name",
+		"backend_id", backendID,
+		"original_name", originalName)
+	return originalName
+}
+
+// convertSearchResultsToResponse converts database search results to the response format.
+// It resolves tool names using the routing table to ensure returned names match routing table keys.
+// Returns the per-tool response maps and the summed token count of the returned tools.
+func convertSearchResultsToResponse(
+	results []*models.BackendToolWithMetadata,
+	routingTable *vmcp.RoutingTable,
+) ([]map[string]any, int) {
+	responseTools := make([]map[string]any, 0, len(results))
+	totalReturnedTokens := 0
+
+	for _, result := range results {
+		// Unmarshal InputSchema
+		var inputSchema map[string]any
+		if len(result.InputSchema) > 0 {
+			if err := json.Unmarshal(result.InputSchema, &inputSchema); err != nil {
+				logger.Warnw("Failed to unmarshal input schema",
+					"tool_id", result.ID,
+					"tool_name", result.ToolName,
+					"error", err)
+				inputSchema = map[string]any{} // Use empty schema on error
+			}
+		}
+
+		// Handle nil description
+		description := ""
+		if result.Description != nil {
+			description = *result.Description
+		}
+
+		// Resolve tool name using routing table to ensure it matches routing table keys
+		resolvedName := resolveToolName(routingTable, result.MCPServerID, result.ToolName)
+
+		tool := map[string]any{
+			"name":             resolvedName,
+			"description":      description,
+			"input_schema":     inputSchema,
+			"backend_id":       result.MCPServerID,
+			"similarity_score": result.Similarity,
+			"token_count":      result.TokenCount,
+		}
+		responseTools = append(responseTools, tool)
+		totalReturnedTokens += result.TokenCount
+	}
+
+	return responseTools, totalReturnedTokens
+}
+
+// createFindToolHandler creates the handler for optim_find_tool.
+// The handler runs a hybrid (semantic + BM25) search over the ingested tools,
+// resolves result names against the routing table, and reports token metrics.
+func (o *OptimizerIntegration) createFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		logger.Debugw("optim_find_tool called", "request", request)
+
+		// Extract parameters from request arguments
+		args, ok := request.Params.Arguments.(map[string]any)
+		if !ok {
+			return mcp.NewToolResultError("invalid arguments: expected object"), nil
+		}
+
+		// Extract and validate parameters; err here is a client-facing tool
+		// error result (returned to the caller), not a Go error.
+		toolDescription, toolKeywords, limit, err := extractFindToolParams(args)
+		if err != nil {
+			return err, nil
+		}
+
+		// Perform hybrid search using database operations
+		if o.ingestionService == nil {
+			return mcp.NewToolResultError("backend tool operations not initialized"), nil
+		}
+		backendToolOps := o.ingestionService.GetBackendToolOps()
+		if backendToolOps == nil {
+			return mcp.NewToolResultError("backend tool operations not initialized"), nil
+		}
+
+		// Configure hybrid search
+		hybridConfig := &db.HybridSearchConfig{
+			SemanticRatio: o.config.HybridSearchRatio,
+			Limit:         limit,
+			ServerID:      nil, // Search across all servers
+		}
+
+		// Execute hybrid search; keywords are appended to the description to
+		// form a single query string.
+		queryText := toolDescription
+		if toolKeywords != "" {
+			queryText = toolDescription + " " + toolKeywords
+		}
+		results, err2 := backendToolOps.SearchHybrid(ctx, queryText, hybridConfig)
+		if err2 != nil {
+			logger.Errorw("Hybrid search failed",
+				"error", err2,
+				"tool_description", toolDescription,
+				"tool_keywords", toolKeywords,
+				"query_text", queryText)
+			return mcp.NewToolResultError(fmt.Sprintf("search failed: %v", err2)), nil
+		}
+
+		// Get routing table from context to resolve tool names
+		var routingTable *vmcp.RoutingTable
+		if capabilities, ok := discovery.DiscoveredCapabilitiesFromContext(ctx); ok && capabilities != nil {
+			routingTable = capabilities.RoutingTable
+		}
+
+		// Convert results to response format, resolving tool names to match routing table
+		responseTools, totalReturnedTokens := convertSearchResultsToResponse(results, routingTable)
+
+		// Calculate token metrics
+		baselineTokens := o.ingestionService.GetTotalToolTokens(ctx)
+		tokensSaved := baselineTokens - totalReturnedTokens
+		savingsPercentage := 0.0
+		if baselineTokens > 0 {
+			savingsPercentage = (float64(tokensSaved) / float64(baselineTokens)) * 100.0
+		}
+
+		tokenMetrics := map[string]any{
+			"baseline_tokens":    baselineTokens,
+			"returned_tokens":    totalReturnedTokens,
+			"tokens_saved":       tokensSaved,
+			"savings_percentage": savingsPercentage,
+		}
+
+		// Record OpenTelemetry metrics for token savings
+		o.recordTokenMetrics(ctx, baselineTokens, totalReturnedTokens, tokensSaved, savingsPercentage)
+
+		// Build response
+		response := map[string]any{
+			"tools":         responseTools,
+			"token_metrics": tokenMetrics,
+		}
+
+		// Marshal to JSON for the result
+		responseJSON, err3 := json.Marshal(response)
+		if err3 != nil {
+			logger.Errorw("Failed to marshal response", "error", err3)
+			return mcp.NewToolResultError(fmt.Sprintf("failed to marshal response: %v", err3)), nil
+		}
+
+		logger.Infow("optim_find_tool completed",
+			"query", toolDescription,
+			"results_count", len(responseTools),
+			"tokens_saved", tokensSaved,
+			"savings_percentage", fmt.Sprintf("%.2f%%", savingsPercentage))
+
+		return mcp.NewToolResultText(string(responseJSON)), nil
+	}
+}
+
+// recordTokenMetrics records OpenTelemetry metrics for token savings.
+// NOTE(review): instruments are re-created on every call; the existing comment
+// relies on the meter caching them by name. Hoisting instrument creation to
+// initialization would avoid the repeated lookups — confirm against the OTel
+// SDK in use.
+func (*OptimizerIntegration) recordTokenMetrics(
+	ctx context.Context,
+	baselineTokens int,
+	returnedTokens int,
+	tokensSaved int,
+	savingsPercentage float64,
+) {
+	// Get meter from global OpenTelemetry provider
+	meter := otel.Meter("github.com/stacklok/toolhive/pkg/vmcp/optimizer")
+
+	// Create metrics if they don't exist (they'll be cached by the meter)
+	baselineCounter, err := meter.Int64Counter(
+		"toolhive_vmcp_optimizer_baseline_tokens",
+		metric.WithDescription("Total tokens for all tools in the optimizer database (baseline)"),
+	)
+	if err != nil {
+		logger.Debugw("Failed to create baseline_tokens counter", "error", err)
+		return
+	}
+
+	returnedCounter, err := meter.Int64Counter(
+		"toolhive_vmcp_optimizer_returned_tokens",
+		metric.WithDescription("Total tokens for tools returned by optim_find_tool"),
+	)
+	if err != nil {
+		logger.Debugw("Failed to create returned_tokens counter", "error", err)
+		return
+	}
+
+	savedCounter, err := meter.Int64Counter(
+		"toolhive_vmcp_optimizer_tokens_saved",
+		metric.WithDescription("Number of tokens saved by filtering tools with optim_find_tool"),
+	)
+	if err != nil {
+		logger.Debugw("Failed to create tokens_saved counter", "error", err)
+		return
+	}
+
+	savingsGauge, err := meter.Float64Gauge(
+		"toolhive_vmcp_optimizer_savings_percentage",
+		metric.WithDescription("Percentage of tokens saved by filtering tools (0-100)"),
+		metric.WithUnit("%"),
+	)
+	if err != nil {
+		logger.Debugw("Failed to create savings_percentage gauge", "error", err)
+		return
+	}
+
+	// Record metrics with attributes
+	attrs := metric.WithAttributes(
+		attribute.String("operation", "find_tool"),
+	)
+
+	baselineCounter.Add(ctx, int64(baselineTokens), attrs)
+	returnedCounter.Add(ctx, int64(returnedTokens), attrs)
+	savedCounter.Add(ctx, int64(tokensSaved), attrs)
+	savingsGauge.Record(ctx, savingsPercentage, attrs)
+
+	logger.Debugw("Token metrics recorded",
+		"baseline_tokens", baselineTokens,
+		"returned_tokens", returnedTokens,
+		"tokens_saved", tokensSaved,
+		"savings_percentage", savingsPercentage)
+}
+
+// CreateCallToolHandler creates the handler for optim_call_tool.
+// Exported for testing purposes; it simply delegates to the unexported constructor.
+func (o *OptimizerIntegration) CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+	return o.createCallToolHandler()
+}
+
+// createCallToolHandler creates the handler for optim_call_tool.
+// The handler resolves the requested tool via the routing table from the
+// request context and forwards the call to the owning backend.
+func (o *OptimizerIntegration) createCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		logger.Debugw("optim_call_tool called", "request", request)
+
+		// Extract parameters from request arguments
+		args, ok := request.Params.Arguments.(map[string]any)
+		if !ok {
+			return mcp.NewToolResultError("invalid arguments: expected object"), nil
+		}
+
+		// Extract backend_id (required)
+		backendID, ok := args["backend_id"].(string)
+		if !ok || backendID == "" {
+			return mcp.NewToolResultError("backend_id is required and must be a non-empty string"), nil
+		}
+
+		// Extract tool_name (required)
+		toolName, ok := args["tool_name"].(string)
+		if !ok || toolName == "" {
+			return mcp.NewToolResultError("tool_name is required and must be a non-empty string"), nil
+		}
+
+		// Extract parameters (required)
+		parameters, ok := args["parameters"].(map[string]any)
+		if !ok {
+			return mcp.NewToolResultError("parameters is required and must be an object"), nil
+		}
+
+		// Get routing table from context via discovered capabilities
+		capabilities, ok := discovery.DiscoveredCapabilitiesFromContext(ctx)
+		if !ok || capabilities == nil {
+			return mcp.NewToolResultError("routing information not available in context"), nil
+		}
+
+		if capabilities.RoutingTable == nil || capabilities.RoutingTable.Tools == nil {
+			return mcp.NewToolResultError("routing table not initialized"), nil
+		}
+
+		// Find the tool in the routing table
+		target, exists := capabilities.RoutingTable.Tools[toolName]
+		if !exists {
+			return mcp.NewToolResultError(fmt.Sprintf("tool not found in routing table: %s", toolName)), nil
+		}
+
+		// Verify the tool belongs to the specified backend — rejects requests
+		// that pair a tool name with a backend other than the one that owns it.
+		if target.WorkloadID != backendID {
+			return mcp.NewToolResultError(fmt.Sprintf(
+				"tool %s belongs to backend %s, not %s",
+				toolName,
+				target.WorkloadID,
+				backendID,
+			)), nil
+		}
+
+		// Get the backend capability name (handles renamed tools)
+		backendToolName := target.GetBackendCapabilityName(toolName)
+
+		logger.Infow("Calling tool via optimizer",
+			"backend_id", backendID,
+			"tool_name", toolName,
+			"backend_tool_name", backendToolName,
+			"workload_name", target.WorkloadName)
+
+		// Call the tool on the backend using the backend client
+		result, err := o.backendClient.CallTool(ctx, target, backendToolName, parameters)
+		if err != nil {
+			logger.Errorw("Tool call failed",
+				"error", err,
+				"backend_id", backendID,
+				"tool_name", toolName,
+				"backend_tool_name", backendToolName)
+			return mcp.NewToolResultError(fmt.Sprintf("tool call failed: %v", err)), nil
+		}
+
+		// Convert result to JSON
+		resultJSON, err := json.Marshal(result)
+		if err != nil {
+			logger.Errorw("Failed to marshal tool result", "error", err)
+			return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil
+		}
+
+		logger.Infow("optim_call_tool completed successfully",
+			"backend_id", backendID,
+			"tool_name", toolName)
+
+		return mcp.NewToolResultText(string(resultJSON)), nil
+	}
+}
+
+// IngestInitialBackends ingests all discovered backends and their tools at startup.
+// This should be called after backends are discovered during server initialization.
+// Per-backend failures are logged and skipped; they never fail startup.
+// Safe to call on a nil receiver or when the optimizer is disabled (no-op).
+func (o *OptimizerIntegration) IngestInitialBackends(ctx context.Context, backends []vmcp.Backend) error {
+	if o == nil || o.ingestionService == nil {
+		// Optimizer disabled - log that embedding time is 0
+		logger.Infow("Optimizer disabled, embedding time: 0ms")
+		return nil
+	}
+
+	// Reset embedding time before starting ingestion
+	o.ingestionService.ResetEmbeddingTime()
+
+	// Create a span for the entire ingestion process
+	ctx, span := o.tracer.Start(ctx, "optimizer.ingestion.ingest_initial_backends",
+		trace.WithAttributes(
+			attribute.Int("backends.count", len(backends)),
+		))
+	defer span.End()
+
+	start := time.Now()
+	logger.Infof("Ingesting %d discovered backends into optimizer", len(backends))
+
+	ingestedCount := 0
+	totalToolsIngested := 0
+	for i := range backends {
+		// Each backend is ingested in a helper so its span's defer End() fires
+		// at the end of the iteration instead of accumulating until this
+		// function returns (defer-in-loop would skew span durations and grow
+		// memory with the number of backends).
+		toolCount, ok := o.ingestBackend(ctx, &backends[i])
+		if !ok {
+			continue // Log but don't fail startup
+		}
+		ingestedCount++
+		totalToolsIngested += toolCount
+	}
+
+	// Get total embedding time
+	totalEmbeddingTime := o.ingestionService.GetTotalEmbeddingTime()
+	totalDuration := time.Since(start)
+
+	span.SetAttributes(
+		attribute.Int64("ingestion.duration_ms", totalDuration.Milliseconds()),
+		attribute.Int64("embedding.duration_ms", totalEmbeddingTime.Milliseconds()),
+		attribute.Int("backends.ingested", ingestedCount),
+		attribute.Int("tools.ingested", totalToolsIngested),
+	)
+
+	// Guard the percentage computation against a zero total duration (possible
+	// with no backends on coarse-resolution clocks).
+	embeddingPct := 0.0
+	if totalDuration > 0 {
+		embeddingPct = float64(totalEmbeddingTime) / float64(totalDuration) * 100
+	}
+
+	logger.Infow("Initial backend ingestion completed",
+		"servers_ingested", ingestedCount,
+		"tools_ingested", totalToolsIngested,
+		"total_duration_ms", totalDuration.Milliseconds(),
+		"total_embedding_time_ms", totalEmbeddingTime.Milliseconds(),
+		"embedding_time_percentage", fmt.Sprintf("%.2f%%", embeddingPct))
+
+	return nil
+}
+
+// ingestBackend queries one backend's capabilities and ingests its tools into
+// the optimizer database. Returns the number of tools ingested and whether the
+// backend was ingested successfully; failures are recorded on the span and
+// logged, never propagated.
+func (o *OptimizerIntegration) ingestBackend(ctx context.Context, backend *vmcp.Backend) (int, bool) {
+	backendCtx, backendSpan := o.tracer.Start(ctx, "optimizer.ingestion.ingest_backend",
+		trace.WithAttributes(
+			attribute.String("backend.id", backend.ID),
+			attribute.String("backend.name", backend.Name),
+		))
+	defer backendSpan.End()
+
+	// Convert Backend to BackendTarget for client API
+	target := vmcp.BackendToTarget(backend)
+	if target == nil {
+		logger.Warnf("Failed to convert backend %s to target", backend.Name)
+		backendSpan.RecordError(fmt.Errorf("failed to convert backend to target"))
+		backendSpan.SetStatus(codes.Error, "conversion failed")
+		return 0, false
+	}
+
+	// Query backend capabilities to get its tools
+	capabilities, err := o.backendClient.ListCapabilities(backendCtx, target)
+	if err != nil {
+		logger.Warnf("Failed to query capabilities for backend %s: %v", backend.Name, err)
+		backendSpan.RecordError(err)
+		backendSpan.SetStatus(codes.Error, err.Error())
+		return 0, false
+	}
+
+	// Extract tools from capabilities
+	// Note: For ingestion, we only need name and description (for generating embeddings)
+	// InputSchema is not used by the ingestion service
+	var tools []mcp.Tool
+	for _, tool := range capabilities.Tools {
+		tools = append(tools, mcp.Tool{
+			Name:        tool.Name,
+			Description: tool.Description,
+			// InputSchema not needed for embedding generation
+		})
+	}
+
+	// Get description from metadata (may be empty)
+	var description *string
+	if backend.Metadata != nil {
+		if desc := backend.Metadata["description"]; desc != "" {
+			description = &desc
+		}
+	}
+
+	backendSpan.SetAttributes(
+		attribute.Int("tools.count", len(tools)),
+	)
+
+	// Ingest this backend's tools (IngestServer will create its own spans)
+	if err := o.ingestionService.IngestServer(
+		backendCtx,
+		backend.ID,
+		backend.Name,
+		description,
+		tools,
+	); err != nil {
+		logger.Warnf("Failed to ingest backend %s: %v", backend.Name, err)
+		backendSpan.RecordError(err)
+		backendSpan.SetStatus(codes.Error, err.Error())
+		return 0, false
+	}
+
+	backendSpan.SetAttributes(
+		attribute.Int("tools.ingested", len(tools)),
+	)
+	backendSpan.SetStatus(codes.Ok, "backend ingested successfully")
+	return len(tools), true
+}
+
+// Close cleans up optimizer resources.
+// Safe to call on a nil receiver or when the optimizer is disabled (no-op).
+func (o *OptimizerIntegration) Close() error {
+	if o == nil || o.ingestionService == nil {
+		return nil
+	}
+	return o.ingestionService.Close()
+}
-	// Parameters are the arguments to pass to the tool.
-	Parameters map[string]any `json:"parameters" description:"Parameters to pass to the tool"`
+// IngestToolsForTesting manually ingests tools for testing purposes.
+// This is a test helper that bypasses the normal ingestion flow.
+// Unlike IngestInitialBackends, failures here are returned to the caller.
+func (o *OptimizerIntegration) IngestToolsForTesting(
+	ctx context.Context,
+	serverID string,
+	serverName string,
+	description *string,
+	tools []mcp.Tool,
+) error {
+	if o == nil || o.ingestionService == nil {
+		return fmt.Errorf("optimizer integration not initialized")
+	}
+	return o.ingestionService.IngestServer(ctx, serverID, serverName, description, tools)
}
diff --git a/pkg/vmcp/optimizer/optimizer_handlers_test.go b/pkg/vmcp/optimizer/optimizer_handlers_test.go
new file mode 100644
index 0000000000..6adee847ee
--- /dev/null
+++ b/pkg/vmcp/optimizer/optimizer_handlers_test.go
@@ -0,0 +1,1029 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package optimizer
+
+import (
+ "context"
+ "encoding/json"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ transportsession "github.com/stacklok/toolhive/pkg/transport/session"
+ "github.com/stacklok/toolhive/pkg/vmcp"
+ "github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+ "github.com/stacklok/toolhive/pkg/vmcp/discovery"
+ vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
+)
+
+// mockMCPServerWithSession implements AddSessionTools for testing.
+// toolsAdded records, per session ID, the tools registered through
+// AddSessionTools so tests can assert on what was added.
+type mockMCPServerWithSession struct {
+ *server.MCPServer
+ toolsAdded map[string][]server.ServerTool
+}
+
+// newMockMCPServerWithSession wraps a real in-process MCPServer in the mock.
+func newMockMCPServerWithSession() *mockMCPServerWithSession {
+ return &mockMCPServerWithSession{
+ MCPServer: server.NewMCPServer("test-server", "1.0"),
+ toolsAdded: make(map[string][]server.ServerTool),
+ }
+}
+
+// AddSessionTools records the tools added for the given session and always
+// succeeds. NOTE(review): each call overwrites the previously recorded tools
+// for the session rather than appending — confirm that is the intended capture.
+func (m *mockMCPServerWithSession) AddSessionTools(sessionID string, tools ...server.ServerTool) error {
+ m.toolsAdded[sessionID] = tools
+ return nil
+}
+
+// mockBackendClientWithCallTool implements the backend client interface for
+// testing, with a configurable CallTool result and error.
+type mockBackendClientWithCallTool struct {
+	callToolResult map[string]any
+	callToolError  error
+}
+
+// ListCapabilities returns an empty capability list.
+func (*mockBackendClientWithCallTool) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) {
+	return &vmcp.CapabilityList{}, nil
+}
+
+// CallTool returns the configured error if set, otherwise the configured result.
+func (m *mockBackendClientWithCallTool) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) {
+	if m.callToolError != nil {
+		return nil, m.callToolError
+	}
+	return m.callToolResult, nil
+}
+
+// GetPrompt is a stub; these tests never fetch prompts.
+func (*mockBackendClientWithCallTool) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (string, error) {
+	return "", nil
+}
+
+// ReadResource is a stub; these tests never read resources.
+func (*mockBackendClientWithCallTool) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) ([]byte, error) {
+	return nil, nil
+}
+
+// ollamaEmbeddingConfig returns a fresh embedding config pointing at a local
+// Ollama instance, as used by every test in this file.
+func ollamaEmbeddingConfig() *embeddings.Config {
+	return &embeddings.Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+	}
+}
+
+// newTestIntegration builds an OptimizerIntegration backed by a temporary
+// database and the shared Ollama embedding config. The calling test is
+// skipped when Ollama is not reachable, and Close is registered as cleanup.
+func newTestIntegration(t *testing.T, client vmcp.BackendClient) *OptimizerIntegration {
+	t.Helper()
+
+	// Probe Ollama availability first so tests skip instead of failing on
+	// machines without a local model server.
+	embeddingManager, err := embeddings.NewManager(ollamaEmbeddingConfig())
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+	}
+	_ = embeddingManager.Close()
+
+	mcpServer := newMockMCPServerWithSession()
+	config := &Config{
+		Enabled:         true,
+		PersistPath:     filepath.Join(t.TempDir(), "optimizer-db"),
+		EmbeddingConfig: ollamaEmbeddingConfig(),
+	}
+
+	sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+	integration, err := NewIntegration(context.Background(), config, mcpServer.MCPServer, client, sessionMgr)
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = integration.Close() })
+	return integration
+}
+
+// routingContext returns ctx carrying discovered capabilities whose routing
+// table maps the given tool names to backend targets.
+func routingContext(ctx context.Context, tools map[string]*vmcp.BackendTarget) context.Context {
+	return discovery.WithDiscoveredCapabilities(ctx, &aggregator.AggregatedCapabilities{
+		RoutingTable: &vmcp.RoutingTable{
+			Tools:     tools,
+			Resources: map[string]*vmcp.BackendTarget{},
+			Prompts:   map[string]*vmcp.BackendTarget{},
+		},
+	})
+}
+
+// callRequest builds a CallToolRequest for the named virtual tool with the
+// given (possibly malformed) arguments payload.
+func callRequest(name string, args any) mcp.CallToolRequest {
+	return mcp.CallToolRequest{
+		Params: mcp.CallToolParams{
+			Name:      name,
+			Arguments: args,
+		},
+	}
+}
+
+// TestCreateFindToolHandler_InvalidArguments tests error handling for invalid arguments.
+func TestCreateFindToolHandler_InvalidArguments(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	integration := newTestIntegration(t, &mockBackendClient{})
+	handler := integration.CreateFindToolHandler()
+
+	// Each case is a malformed argument payload the handler must reject with
+	// a tool error (IsError), not a transport-level error.
+	cases := []struct {
+		name string
+		args any
+	}{
+		{"invalid arguments type", "not a map"},
+		{"missing tool_description", map[string]any{"limit": 10}},
+		{"empty tool_description", map[string]any{"tool_description": ""}},
+		{"non-string tool_description", map[string]any{"tool_description": 123}},
+	}
+	for _, tc := range cases {
+		result, err := handler(ctx, callRequest("optim_find_tool", tc.args))
+		require.NoError(t, err, tc.name)
+		require.True(t, result.IsError, "Should return error for %s", tc.name)
+	}
+}
+
+// TestCreateFindToolHandler_WithKeywords tests find_tool with keywords.
+func TestCreateFindToolHandler_WithKeywords(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	integration := newTestIntegration(t, &mockBackendClient{})
+
+	// Ingest a tool so the semantic search has something to match.
+	tools := []mcp.Tool{{
+		Name:        "test_tool",
+		Description: "A test tool for searching",
+	}}
+	require.NoError(t, integration.IngestToolsForTesting(ctx, "server-1", "TestServer", nil, tools))
+
+	handler := integration.CreateFindToolHandler()
+	result, err := handler(ctx, callRequest("optim_find_tool", map[string]any{
+		"tool_description": "search tool",
+		"tool_keywords":    "test search",
+		"limit":            10,
+	}))
+	require.NoError(t, err)
+	require.False(t, result.IsError, "Should not return error")
+
+	// The handler responds with a JSON document holding the matched tools
+	// plus token accounting metadata.
+	textContent, ok := mcp.AsTextContent(result.Content[0])
+	require.True(t, ok)
+
+	var response map[string]any
+	require.NoError(t, json.Unmarshal([]byte(textContent.Text), &response))
+	require.Contains(t, response, "tools", "Response should have tools")
+	require.Contains(t, response, "token_metrics", "Response should have token_metrics")
+}
+
+// TestCreateFindToolHandler_Limit tests limit parameter handling.
+func TestCreateFindToolHandler_Limit(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	integration := newTestIntegration(t, &mockBackendClient{})
+	handler := integration.CreateFindToolHandler()
+
+	// The limit may arrive as a native int or as float64 (JSON-decoded
+	// payloads); both representations must be accepted.
+	for _, limit := range []any{5, float64(3)} {
+		result, err := handler(ctx, callRequest("optim_find_tool", map[string]any{
+			"tool_description": "test",
+			"limit":            limit,
+		}))
+		require.NoError(t, err)
+		require.False(t, result.IsError)
+	}
+}
+
+// TestCreateFindToolHandler_BackendToolOpsNil tests error when backend tool ops is nil.
+func TestCreateFindToolHandler_BackendToolOpsNil(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	// A nil ingestion service makes GetBackendToolOps return nil, which the
+	// handler must surface as a tool error rather than panicking.
+	integration := &OptimizerIntegration{
+		config:           &Config{Enabled: true},
+		ingestionService: nil,
+	}
+
+	handler := integration.CreateFindToolHandler()
+	result, err := handler(ctx, callRequest("optim_find_tool", map[string]any{
+		"tool_description": "test",
+	}))
+	require.NoError(t, err)
+	require.True(t, result.IsError, "Should return error when backend tool ops is nil")
+}
+
+// TestCreateCallToolHandler_InvalidArguments tests error handling for invalid arguments.
+func TestCreateCallToolHandler_InvalidArguments(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	integration := newTestIntegration(t, &mockBackendClientWithCallTool{})
+	handler := integration.CreateCallToolHandler()
+
+	cases := []struct {
+		name string
+		args any
+	}{
+		{"invalid arguments type", "not a map"},
+		{"missing backend_id", map[string]any{"tool_name": "test_tool", "parameters": map[string]any{}}},
+		{"empty backend_id", map[string]any{"backend_id": "", "tool_name": "test_tool", "parameters": map[string]any{}}},
+		{"missing tool_name", map[string]any{"backend_id": "backend-1", "parameters": map[string]any{}}},
+		{"missing parameters", map[string]any{"backend_id": "backend-1", "tool_name": "test_tool"}},
+		{"invalid parameters type", map[string]any{"backend_id": "backend-1", "tool_name": "test_tool", "parameters": "not a map"}},
+	}
+	for _, tc := range cases {
+		result, err := handler(ctx, callRequest("optim_call_tool", tc.args))
+		require.NoError(t, err, tc.name)
+		require.True(t, result.IsError, "Should return error for %s", tc.name)
+	}
+}
+
+// TestCreateCallToolHandler_NoRoutingTable tests error when routing table is missing.
+func TestCreateCallToolHandler_NoRoutingTable(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	integration := newTestIntegration(t, &mockBackendClientWithCallTool{})
+	handler := integration.CreateCallToolHandler()
+
+	// ctx deliberately carries no discovered capabilities.
+	result, err := handler(ctx, callRequest("optim_call_tool", map[string]any{
+		"backend_id": "backend-1",
+		"tool_name":  "test_tool",
+		"parameters": map[string]any{},
+	}))
+	require.NoError(t, err)
+	require.True(t, result.IsError, "Should return error when routing table is missing")
+}
+
+// TestCreateCallToolHandler_ToolNotFound tests error when tool is not found.
+func TestCreateCallToolHandler_ToolNotFound(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	integration := newTestIntegration(t, &mockBackendClientWithCallTool{})
+	handler := integration.CreateCallToolHandler()
+
+	// Routing table is present but empty, so the tool cannot be resolved.
+	ctxWithCaps := routingContext(ctx, map[string]*vmcp.BackendTarget{})
+
+	result, err := handler(ctxWithCaps, callRequest("optim_call_tool", map[string]any{
+		"backend_id": "backend-1",
+		"tool_name":  "nonexistent_tool",
+		"parameters": map[string]any{},
+	}))
+	require.NoError(t, err)
+	require.True(t, result.IsError, "Should return error when tool is not found")
+}
+
+// TestCreateCallToolHandler_BackendMismatch tests error when backend doesn't match.
+func TestCreateCallToolHandler_BackendMismatch(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	integration := newTestIntegration(t, &mockBackendClientWithCallTool{})
+	handler := integration.CreateCallToolHandler()
+
+	// The tool is routed to backend-2, but the caller asks for backend-1.
+	ctxWithCaps := routingContext(ctx, map[string]*vmcp.BackendTarget{
+		"test_tool": {
+			WorkloadID:   "backend-2",
+			WorkloadName: "Backend 2",
+		},
+	})
+
+	result, err := handler(ctxWithCaps, callRequest("optim_call_tool", map[string]any{
+		"backend_id": "backend-1",
+		"tool_name":  "test_tool",
+		"parameters": map[string]any{},
+	}))
+	require.NoError(t, err)
+	require.True(t, result.IsError, "Should return error when backend doesn't match")
+}
+
+// TestCreateCallToolHandler_Success tests successful tool call.
+func TestCreateCallToolHandler_Success(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	mockClient := &mockBackendClientWithCallTool{
+		callToolResult: map[string]any{"result": "success"},
+	}
+	integration := newTestIntegration(t, mockClient)
+	handler := integration.CreateCallToolHandler()
+
+	ctxWithCaps := routingContext(ctx, map[string]*vmcp.BackendTarget{
+		"test_tool": {
+			WorkloadID:   "backend-1",
+			WorkloadName: "Backend 1",
+			BaseURL:      "http://localhost:8000",
+		},
+	})
+
+	result, err := handler(ctxWithCaps, callRequest("optim_call_tool", map[string]any{
+		"backend_id": "backend-1",
+		"tool_name":  "test_tool",
+		"parameters": map[string]any{"param1": "value1"},
+	}))
+	require.NoError(t, err)
+	require.False(t, result.IsError, "Should not return error")
+
+	// The backend's result is passed through as JSON text content.
+	textContent, ok := mcp.AsTextContent(result.Content[0])
+	require.True(t, ok)
+
+	var response map[string]any
+	require.NoError(t, json.Unmarshal([]byte(textContent.Text), &response))
+	assert.Equal(t, "success", response["result"])
+}
+
+// TestCreateCallToolHandler_CallToolError tests error handling when CallTool fails.
+func TestCreateCallToolHandler_CallToolError(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	integration := newTestIntegration(t, &mockBackendClientWithCallTool{callToolError: assert.AnError})
+	handler := integration.CreateCallToolHandler()
+
+	ctxWithCaps := routingContext(ctx, map[string]*vmcp.BackendTarget{
+		"test_tool": {
+			WorkloadID:   "backend-1",
+			WorkloadName: "Backend 1",
+			BaseURL:      "http://localhost:8000",
+		},
+	})
+
+	result, err := handler(ctxWithCaps, callRequest("optim_call_tool", map[string]any{
+		"backend_id": "backend-1",
+		"tool_name":  "test_tool",
+		"parameters": map[string]any{},
+	}))
+	require.NoError(t, err)
+	require.True(t, result.IsError, "Should return error when CallTool fails")
+}
+
+// TestCreateFindToolHandler_InputSchemaUnmarshalError tests error handling for invalid input schema.
+func TestCreateFindToolHandler_InputSchemaUnmarshalError(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	integration := newTestIntegration(t, &mockBackendClient{})
+	handler := integration.CreateFindToolHandler()
+
+	// The handler should handle invalid input schemas gracefully and still
+	// produce a non-error result.
+	result, err := handler(ctx, callRequest("optim_find_tool", map[string]any{
+		"tool_description": "test",
+	}))
+	require.NoError(t, err)
+	require.False(t, result.IsError)
+}
+
+// TestOnRegisterSession_DuplicateSession tests duplicate session handling.
+func TestOnRegisterSession_DuplicateSession(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	integration := newTestIntegration(t, &mockBackendClient{})
+
+	session := &mockSession{sessionID: "test-session"}
+	capabilities := &aggregator.AggregatedCapabilities{}
+
+	// First registration succeeds; re-registering the same session ID must
+	// be skipped gracefully rather than fail.
+	require.NoError(t, integration.OnRegisterSession(ctx, session, capabilities))
+	require.NoError(t, integration.OnRegisterSession(ctx, session, capabilities),
+		"Should handle duplicate session gracefully")
+}
+
+// TestIngestInitialBackends_ErrorHandling tests error handling during ingestion.
+func TestIngestInitialBackends_ErrorHandling(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	// Simulate an error when listing backend capabilities.
+	integration := newTestIntegration(t, &mockBackendClient{err: assert.AnError})
+
+	backends := []vmcp.Backend{
+		{
+			ID:            "backend-1",
+			Name:          "Backend 1",
+			BaseURL:       "http://localhost:8000",
+			TransportType: "sse",
+		},
+	}
+
+	// Ingestion is best-effort per backend: a failing backend must not fail
+	// the whole call.
+	require.NoError(t, integration.IngestInitialBackends(ctx, backends),
+		"Should handle backend query errors gracefully")
+}
+
+// TestIngestInitialBackends_NilIntegration tests nil integration handling.
+func TestIngestInitialBackends_NilIntegration(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	// Calling through a nil receiver must be a safe no-op.
+	var integration *OptimizerIntegration
+	require.NoError(t, integration.IngestInitialBackends(ctx, []vmcp.Backend{}),
+		"Should handle nil integration gracefully")
+}
diff --git a/pkg/vmcp/optimizer/optimizer_integration_test.go b/pkg/vmcp/optimizer/optimizer_integration_test.go
new file mode 100644
index 0000000000..bb3ecf9583
--- /dev/null
+++ b/pkg/vmcp/optimizer/optimizer_integration_test.go
@@ -0,0 +1,439 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package optimizer
+
+import (
+ "context"
+ "encoding/json"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ transportsession "github.com/stacklok/toolhive/pkg/transport/session"
+ "github.com/stacklok/toolhive/pkg/vmcp"
+ "github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+ vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
+)
+
+// mockIntegrationBackendClient implements vmcp.BackendClient for integration testing
+type mockIntegrationBackendClient struct {
+ backends map[string]*vmcp.CapabilityList
+}
+
+func newMockIntegrationBackendClient() *mockIntegrationBackendClient {
+ return &mockIntegrationBackendClient{
+ backends: make(map[string]*vmcp.CapabilityList),
+ }
+}
+
+func (m *mockIntegrationBackendClient) addBackend(backendID string, caps *vmcp.CapabilityList) {
+ m.backends[backendID] = caps
+}
+
+func (m *mockIntegrationBackendClient) ListCapabilities(_ context.Context, target *vmcp.BackendTarget) (*vmcp.CapabilityList, error) {
+ if caps, exists := m.backends[target.WorkloadID]; exists {
+ return caps, nil
+ }
+ return &vmcp.CapabilityList{}, nil
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockIntegrationBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) {
+ return nil, nil
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockIntegrationBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (string, error) {
+ return "", nil
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockIntegrationBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) ([]byte, error) {
+ return nil, nil
+}
+
+// mockIntegrationSession implements server.ClientSession for testing
+type mockIntegrationSession struct {
+ sessionID string
+}
+
+func (m *mockIntegrationSession) SessionID() string {
+ return m.sessionID
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockIntegrationSession) Send(_ interface{}) error {
+ return nil
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockIntegrationSession) Close() error {
+ return nil
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockIntegrationSession) Initialize() {
+ // No-op for testing
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockIntegrationSession) Initialized() bool {
+ return true
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockIntegrationSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
+ // Return a dummy channel for testing
+ ch := make(chan mcp.JSONRPCNotification, 1)
+ return ch
+}
+
+// TestOptimizerIntegration_WithVMCP tests the complete integration with vMCP
+func TestOptimizerIntegration_WithVMCP(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ // Create MCP server
+ mcpServer := server.NewMCPServer("vmcp-test", "1.0")
+
+ // Create mock backend client
+ mockClient := newMockIntegrationBackendClient()
+ mockClient.addBackend("github", &vmcp.CapabilityList{
+ Tools: []vmcp.Tool{
+ {
+ Name: "create_issue",
+ Description: "Create a GitHub issue",
+ },
+ },
+ })
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingConfig := &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM)
+ return
+ }
+ t.Cleanup(func() { _ = embeddingManager.Close() })
+
+ // Configure optimizer
+ optimizerConfig := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ },
+ }
+
+ // Create optimizer integration
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ defer func() { _ = integration.Close() }()
+
+ // Ingest backends
+ backends := []vmcp.Backend{
+ {
+ ID: "github",
+ Name: "GitHub",
+ BaseURL: "http://localhost:8000",
+ TransportType: "sse",
+ },
+ }
+
+ err = integration.IngestInitialBackends(ctx, backends)
+ require.NoError(t, err)
+
+ // Simulate session registration
+ session := &mockIntegrationSession{sessionID: "test-session"}
+ capabilities := &aggregator.AggregatedCapabilities{
+ Tools: []vmcp.Tool{
+ {
+ Name: "create_issue",
+ Description: "Create a GitHub issue",
+ BackendID: "github",
+ },
+ },
+ RoutingTable: &vmcp.RoutingTable{
+ Tools: map[string]*vmcp.BackendTarget{
+ "create_issue": {
+ WorkloadID: "github",
+ WorkloadName: "GitHub",
+ },
+ },
+ Resources: map[string]*vmcp.BackendTarget{},
+ Prompts: map[string]*vmcp.BackendTarget{},
+ },
+ }
+
+ err = integration.OnRegisterSession(ctx, session, capabilities)
+ require.NoError(t, err)
+
+ // Note: We don't test RegisterTools here because it requires the session
+ // to be properly registered with the MCP server, which is beyond the scope
+ // of this integration test. The RegisterTools method is tested separately
+ // in unit tests where we can properly mock the MCP server behavior.
+}
+
+// TestOptimizerIntegration_EmbeddingTimeTracking tests that embedding time is tracked and logged
+func TestOptimizerIntegration_EmbeddingTimeTracking(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ // Create MCP server
+ mcpServer := server.NewMCPServer("vmcp-test", "1.0")
+
+ // Create mock backend client
+ mockClient := newMockIntegrationBackendClient()
+ mockClient.addBackend("github", &vmcp.CapabilityList{
+ Tools: []vmcp.Tool{
+ {
+ Name: "create_issue",
+ Description: "Create a GitHub issue",
+ },
+ {
+ Name: "get_repo",
+ Description: "Get repository information",
+ },
+ },
+ })
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingConfig := &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM)
+ return
+ }
+ t.Cleanup(func() { _ = embeddingManager.Close() })
+
+ // Configure optimizer
+ optimizerConfig := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ },
+ }
+
+ // Create optimizer integration
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ defer func() { _ = integration.Close() }()
+
+ // Verify embedding time starts at 0
+ embeddingTime := integration.ingestionService.GetTotalEmbeddingTime()
+ require.Equal(t, time.Duration(0), embeddingTime, "Initial embedding time should be 0")
+
+ // Ingest backends
+ backends := []vmcp.Backend{
+ {
+ ID: "github",
+ Name: "GitHub",
+ BaseURL: "http://localhost:8000",
+ TransportType: "sse",
+ },
+ }
+
+ err = integration.IngestInitialBackends(ctx, backends)
+ require.NoError(t, err)
+
+ // After ingestion, embedding time should be tracked
+ // Note: The actual time depends on Ollama performance, but it should be > 0
+ finalEmbeddingTime := integration.ingestionService.GetTotalEmbeddingTime()
+ require.Greater(t, finalEmbeddingTime, time.Duration(0),
+ "Embedding time should be tracked after ingestion")
+}
+
+// TestOptimizerIntegration_DisabledEmbeddingTime tests that a disabled optimizer yields a nil integration and that ingestion on a nil integration is safe
+func TestOptimizerIntegration_DisabledEmbeddingTime(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ // Create optimizer integration with disabled optimizer
+ optimizerConfig := &Config{
+ Enabled: false,
+ }
+
+ mcpServer := server.NewMCPServer("vmcp-test", "1.0")
+ mockClient := newMockIntegrationBackendClient()
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+
+ integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ require.Nil(t, integration, "Integration should be nil when optimizer is disabled")
+
+ // Try to ingest backends - should return nil without error
+ backends := []vmcp.Backend{
+ {
+ ID: "github",
+ Name: "GitHub",
+ BaseURL: "http://localhost:8000",
+ TransportType: "sse",
+ },
+ }
+
+ // This should handle nil integration gracefully
+ var nilIntegration *OptimizerIntegration
+ err = nilIntegration.IngestInitialBackends(ctx, backends)
+ require.NoError(t, err, "Should handle nil integration gracefully")
+}
+
+// TestOptimizerIntegration_TokenMetrics tests that token metrics are calculated and returned in optim_find_tool
+func TestOptimizerIntegration_TokenMetrics(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ // Create MCP server
+ mcpServer := server.NewMCPServer("vmcp-test", "1.0")
+
+ // Create mock backend client with multiple tools
+ mockClient := newMockIntegrationBackendClient()
+ mockClient.addBackend("github", &vmcp.CapabilityList{
+ Tools: []vmcp.Tool{
+ {
+ Name: "create_issue",
+ Description: "Create a GitHub issue",
+ },
+ {
+ Name: "get_pull_request",
+ Description: "Get a pull request from GitHub",
+ },
+ {
+ Name: "list_repositories",
+ Description: "List repositories from GitHub",
+ },
+ },
+ })
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingConfig := &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM)
+ return
+ }
+ t.Cleanup(func() { _ = embeddingManager.Close() })
+
+ // Configure optimizer
+ optimizerConfig := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: embeddings.BackendTypeOllama,
+ BaseURL: "http://localhost:11434",
+ Model: embeddings.DefaultModelAllMiniLM,
+ Dimension: 384,
+ },
+ }
+
+ // Create optimizer integration
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ defer func() { _ = integration.Close() }()
+
+ // Ingest backends
+ backends := []vmcp.Backend{
+ {
+ ID: "github",
+ Name: "GitHub",
+ BaseURL: "http://localhost:8000",
+ TransportType: "sse",
+ },
+ }
+
+ err = integration.IngestInitialBackends(ctx, backends)
+ require.NoError(t, err)
+
+ // Get the find_tool handler
+ handler := integration.CreateFindToolHandler()
+ require.NotNil(t, handler)
+
+ // Call optim_find_tool
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: "optim_find_tool",
+ Arguments: map[string]any{
+ "tool_description": "create issue",
+ "limit": 5,
+ },
+ },
+ }
+
+ result, err := handler(ctx, request)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+
+ // Verify result contains token_metrics
+ require.NotNil(t, result.Content)
+ require.Len(t, result.Content, 1)
+ textResult, ok := result.Content[0].(mcp.TextContent)
+ require.True(t, ok, "Result should be TextContent")
+
+ // Parse JSON response
+ var response map[string]any
+ err = json.Unmarshal([]byte(textResult.Text), &response)
+ require.NoError(t, err)
+
+ // Verify token_metrics exist
+ tokenMetrics, ok := response["token_metrics"].(map[string]any)
+ require.True(t, ok, "Response should contain token_metrics")
+
+ // Verify token metrics fields
+ baselineTokens, ok := tokenMetrics["baseline_tokens"].(float64)
+ require.True(t, ok, "token_metrics should contain baseline_tokens")
+ require.Greater(t, baselineTokens, float64(0), "baseline_tokens should be greater than 0")
+
+ returnedTokens, ok := tokenMetrics["returned_tokens"].(float64)
+ require.True(t, ok, "token_metrics should contain returned_tokens")
+ require.GreaterOrEqual(t, returnedTokens, float64(0), "returned_tokens should be >= 0")
+
+ tokensSaved, ok := tokenMetrics["tokens_saved"].(float64)
+ require.True(t, ok, "token_metrics should contain tokens_saved")
+ require.GreaterOrEqual(t, tokensSaved, float64(0), "tokens_saved should be >= 0")
+
+ savingsPercentage, ok := tokenMetrics["savings_percentage"].(float64)
+ require.True(t, ok, "token_metrics should contain savings_percentage")
+ require.GreaterOrEqual(t, savingsPercentage, float64(0), "savings_percentage should be >= 0")
+ require.LessOrEqual(t, savingsPercentage, float64(100), "savings_percentage should be <= 100")
+
+ // Verify tools are returned
+ tools, ok := response["tools"].([]any)
+ require.True(t, ok, "Response should contain tools")
+ require.Greater(t, len(tools), 0, "Should return at least one tool")
+}
diff --git a/pkg/vmcp/optimizer/optimizer_unit_test.go b/pkg/vmcp/optimizer/optimizer_unit_test.go
new file mode 100644
index 0000000000..c764d54aeb
--- /dev/null
+++ b/pkg/vmcp/optimizer/optimizer_unit_test.go
@@ -0,0 +1,338 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package optimizer
+
+import (
+ "context"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ transportsession "github.com/stacklok/toolhive/pkg/transport/session"
+ "github.com/stacklok/toolhive/pkg/vmcp"
+ "github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+ vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
+)
+
+// mockBackendClient implements vmcp.BackendClient for testing
+type mockBackendClient struct {
+ capabilities *vmcp.CapabilityList
+ err error
+}
+
+func (m *mockBackendClient) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) {
+ if m.err != nil {
+ return nil, m.err
+ }
+ return m.capabilities, nil
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) {
+ return nil, nil
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (string, error) {
+ return "", nil
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) ([]byte, error) {
+ return nil, nil
+}
+
+// mockSession implements server.ClientSession for testing
+type mockSession struct {
+ sessionID string
+}
+
+func (m *mockSession) SessionID() string {
+ return m.sessionID
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockSession) Send(_ interface{}) error {
+ return nil
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockSession) Close() error {
+ return nil
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockSession) Initialize() {
+ // No-op for testing
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockSession) Initialized() bool {
+ return true
+}
+
+//nolint:revive // Receiver unused in mock implementation
+func (m *mockSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
+ // Return a dummy channel for testing
+ ch := make(chan mcp.JSONRPCNotification, 1)
+ return ch
+}
+
+// TestNewIntegration_Disabled tests that nil is returned when optimizer is disabled
+func TestNewIntegration_Disabled(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ // Test with nil config
+ integration, err := NewIntegration(ctx, nil, nil, nil, nil)
+ require.NoError(t, err)
+ assert.Nil(t, integration, "Should return nil when config is nil")
+
+ // Test with disabled config
+ config := &Config{Enabled: false}
+ integration, err = NewIntegration(ctx, config, nil, nil, nil)
+ require.NoError(t, err)
+ assert.Nil(t, integration, "Should return nil when optimizer is disabled")
+}
+
+// TestNewIntegration_Enabled tests successful creation
+func TestNewIntegration_Enabled(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingConfig := &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err)
+ return
+ }
+ _ = embeddingManager.Close()
+
+ mcpServer := server.NewMCPServer("test-server", "1.0")
+ mockClient := &mockBackendClient{}
+
+ config := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "nomic-embed-text",
+ Dimension: 768,
+ },
+ }
+
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ require.NotNil(t, integration)
+ defer func() { _ = integration.Close() }()
+}
+
+// TestOnRegisterSession tests session registration
+func TestOnRegisterSession(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ mcpServer := server.NewMCPServer("test-server", "1.0")
+ mockClient := &mockBackendClient{}
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingConfig := &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err)
+ return
+ }
+ _ = embeddingManager.Close()
+
+ config := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "nomic-embed-text",
+ Dimension: 768,
+ },
+ }
+
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ defer func() { _ = integration.Close() }()
+
+ session := &mockSession{sessionID: "test-session"}
+ capabilities := &aggregator.AggregatedCapabilities{
+ Tools: []vmcp.Tool{
+ {
+ Name: "test_tool",
+ Description: "A test tool",
+ BackendID: "backend-1",
+ },
+ },
+ RoutingTable: &vmcp.RoutingTable{
+ Tools: map[string]*vmcp.BackendTarget{
+ "test_tool": {
+ WorkloadID: "backend-1",
+ WorkloadName: "Test Backend",
+ },
+ },
+ Resources: map[string]*vmcp.BackendTarget{},
+ Prompts: map[string]*vmcp.BackendTarget{},
+ },
+ }
+
+ err = integration.OnRegisterSession(ctx, session, capabilities)
+ assert.NoError(t, err)
+}
+
+// TestOnRegisterSession_NilIntegration tests nil integration handling
+func TestOnRegisterSession_NilIntegration(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ var integration *OptimizerIntegration = nil
+ session := &mockSession{sessionID: "test-session"}
+ capabilities := &aggregator.AggregatedCapabilities{}
+
+ err := integration.OnRegisterSession(ctx, session, capabilities)
+ assert.NoError(t, err)
+}
+
+// TestRegisterTools tests tool registration behavior
+func TestRegisterTools(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ mcpServer := server.NewMCPServer("test-server", "1.0")
+ mockClient := &mockBackendClient{}
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingConfig := &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err)
+ return
+ }
+ _ = embeddingManager.Close()
+
+ config := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "nomic-embed-text",
+ Dimension: 768,
+ },
+ }
+
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+ defer func() { _ = integration.Close() }()
+
+ session := &mockSession{sessionID: "test-session"}
+ // RegisterTools will fail with "session not found" because the mock session
+ // is not actually registered with the MCP server. This is expected behavior.
+ // We're just testing that the method executes without panicking.
+ _ = integration.RegisterTools(ctx, session)
+}
+
+// TestRegisterTools_NilIntegration tests nil integration handling
+func TestRegisterTools_NilIntegration(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ var integration *OptimizerIntegration = nil
+ session := &mockSession{sessionID: "test-session"}
+
+ err := integration.RegisterTools(ctx, session)
+ assert.NoError(t, err)
+}
+
+// TestClose tests cleanup
+func TestClose(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+
+ mcpServer := server.NewMCPServer("test-server", "1.0")
+ mockClient := &mockBackendClient{}
+
+ // Try to use Ollama if available, otherwise skip test
+ embeddingConfig := &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err)
+ return
+ }
+ _ = embeddingManager.Close()
+
+ config := &Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "nomic-embed-text",
+ Dimension: 768,
+ },
+ }
+
+ sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory())
+ integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr)
+ require.NoError(t, err)
+
+ err = integration.Close()
+ assert.NoError(t, err)
+
+ // Multiple closes should be safe
+ err = integration.Close()
+ assert.NoError(t, err)
+}
+
+// TestClose_NilIntegration tests nil integration close
+func TestClose_NilIntegration(t *testing.T) {
+ t.Parallel()
+
+ var integration *OptimizerIntegration = nil
+ err := integration.Close()
+ assert.NoError(t, err)
+}
diff --git a/pkg/vmcp/schema/reflect_test.go b/pkg/vmcp/schema/reflect_test.go
index 55d9491019..2e0da8ed28 100644
--- a/pkg/vmcp/schema/reflect_test.go
+++ b/pkg/vmcp/schema/reflect_test.go
@@ -8,10 +8,85 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
-
- "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
)
+// FindToolInput represents the input schema for optim_find_tool
+// This matches the schema defined in pkg/vmcp/optimizer/optimizer.go
+type FindToolInput struct {
+ ToolDescription string `json:"tool_description" description:"Natural language description of the tool you're looking for"`
+ ToolKeywords string `json:"tool_keywords,omitempty" description:"Optional space-separated keywords for keyword-based search"`
+ Limit int `json:"limit,omitempty" description:"Maximum number of tools to return (default: 10)"`
+}
+
+// CallToolInput represents the input schema for optim_call_tool
+// This matches the schema defined in pkg/vmcp/optimizer/optimizer.go
+type CallToolInput struct {
+ BackendID string `json:"backend_id" description:"Backend ID from find_tool results"`
+ ToolName string `json:"tool_name" description:"Tool name to invoke"`
+ Parameters map[string]any `json:"parameters" description:"Parameters to pass to the tool"`
+}
+
+func TestGenerateSchema_AllTypes(t *testing.T) {
+ t.Parallel()
+
+ type TestStruct struct {
+ StringField string `json:"string_field,omitempty"`
+ IntField int `json:"int_field"`
+ FloatField float64 `json:"float_field,omitempty"`
+ BoolField bool `json:"bool_field"`
+ OptionalStr string `json:"optional_str,omitempty"`
+ SliceField []int `json:"slice_field"`
+ MapField map[string]string `json:"map_field"`
+ StructField struct {
+ RequiredField string `json:"field"`
+ OptionalField string `json:"optional_field,omitempty"`
+ } `json:"struct_field"`
+ PointerField *int `json:"pointer_field"`
+ }
+
+ expected := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "string_field": map[string]any{"type": "string"},
+ "int_field": map[string]any{"type": "integer"},
+ "float_field": map[string]any{"type": "number"},
+ "bool_field": map[string]any{"type": "boolean"},
+ "optional_str": map[string]any{"type": "string"},
+ "slice_field": map[string]any{
+ "type": "array",
+ "items": map[string]any{"type": "integer"},
+ },
+ "map_field": map[string]any{"type": "object"},
+ "struct_field": map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "field": map[string]any{"type": "string"},
+ "optional_field": map[string]any{"type": "string"},
+ },
+ "required": []string{"field"},
+ },
+ "pointer_field": map[string]any{
+ "type": "integer",
+ },
+ },
+ "required": []string{
+ "int_field",
+ "bool_field",
+ "map_field",
+ "struct_field",
+ "pointer_field",
+ "slice_field",
+ },
+ }
+
+ actual, err := GenerateSchema[TestStruct]()
+ require.NoError(t, err)
+
+ require.Equal(t, expected["type"], actual["type"])
+ require.Equal(t, expected["properties"], actual["properties"])
+ require.ElementsMatch(t, expected["required"], actual["required"])
+}
+
func TestGenerateSchema_FindToolInput(t *testing.T) {
t.Parallel()
@@ -20,18 +95,21 @@ func TestGenerateSchema_FindToolInput(t *testing.T) {
"properties": map[string]any{
"tool_description": map[string]any{
"type": "string",
- "description": "Natural language description of the tool to find",
+ "description": "Natural language description of the tool you're looking for",
},
"tool_keywords": map[string]any{
- "type": "array",
- "items": map[string]any{"type": "string"},
- "description": "Optional keywords to narrow search",
+ "type": "string",
+ "description": "Optional space-separated keywords for keyword-based search",
+ },
+ "limit": map[string]any{
+ "type": "integer",
+ "description": "Maximum number of tools to return (default: 10)",
},
},
"required": []string{"tool_description"},
}
- actual, err := GenerateSchema[optimizer.FindToolInput]()
+ actual, err := GenerateSchema[FindToolInput]()
require.NoError(t, err)
require.Equal(t, expected, actual)
@@ -43,19 +121,23 @@ func TestGenerateSchema_CallToolInput(t *testing.T) {
expected := map[string]any{
"type": "object",
"properties": map[string]any{
+ "backend_id": map[string]any{
+ "type": "string",
+ "description": "Backend ID from find_tool results",
+ },
"tool_name": map[string]any{
"type": "string",
- "description": "Name of the tool to call",
+ "description": "Tool name to invoke",
},
"parameters": map[string]any{
"type": "object",
"description": "Parameters to pass to the tool",
},
},
- "required": []string{"tool_name", "parameters"},
+ "required": []string{"backend_id", "tool_name", "parameters"},
}
- actual, err := GenerateSchema[optimizer.CallToolInput]()
+ actual, err := GenerateSchema[CallToolInput]()
require.NoError(t, err)
require.Equal(t, expected, actual)
@@ -66,15 +148,17 @@ func TestTranslate_FindToolInput(t *testing.T) {
input := map[string]any{
"tool_description": "find a tool to read files",
- "tool_keywords": []any{"file", "read"},
+ "tool_keywords": "file read",
+ "limit": 5,
}
- result, err := Translate[optimizer.FindToolInput](input)
+ result, err := Translate[FindToolInput](input)
require.NoError(t, err)
- require.Equal(t, optimizer.FindToolInput{
+ require.Equal(t, FindToolInput{
ToolDescription: "find a tool to read files",
- ToolKeywords: []string{"file", "read"},
+ ToolKeywords: "file read",
+ Limit: 5,
}, result)
}
@@ -82,16 +166,18 @@ func TestTranslate_CallToolInput(t *testing.T) {
t.Parallel()
input := map[string]any{
- "tool_name": "read_file",
+ "backend_id": "backend-123",
+ "tool_name": "read_file",
"parameters": map[string]any{
"path": "/etc/hosts",
},
}
- result, err := Translate[optimizer.CallToolInput](input)
+ result, err := Translate[CallToolInput](input)
require.NoError(t, err)
- require.Equal(t, optimizer.CallToolInput{
+ require.Equal(t, CallToolInput{
+ BackendID: "backend-123",
ToolName: "read_file",
Parameters: map[string]any{"path": "/etc/hosts"},
}, result)
@@ -104,12 +190,13 @@ func TestTranslate_PartialInput(t *testing.T) {
"tool_description": "find a file reader",
}
- result, err := Translate[optimizer.FindToolInput](input)
+ result, err := Translate[FindToolInput](input)
require.NoError(t, err)
- require.Equal(t, optimizer.FindToolInput{
+ require.Equal(t, FindToolInput{
ToolDescription: "find a file reader",
- ToolKeywords: nil,
+ ToolKeywords: "",
+ Limit: 0,
}, result)
}
@@ -118,68 +205,7 @@ func TestTranslate_InvalidInput(t *testing.T) {
input := make(chan int)
- _, err := Translate[optimizer.FindToolInput](input)
+ _, err := Translate[FindToolInput](input)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to marshal input")
}
-
-func TestGenerateSchema_AllTypes(t *testing.T) {
- t.Parallel()
-
- type TestStruct struct {
- StringField string `json:"string_field,omitempty"`
- IntField int `json:"int_field"`
- FloatField float64 `json:"float_field,omitempty"`
- BoolField bool `json:"bool_field"`
- OptionalStr string `json:"optional_str,omitempty"`
- SliceField []int `json:"slice_field"`
- MapField map[string]string `json:"map_field"`
- StructField struct {
- RequiredField string `json:"field"`
- OptionalField string `json:"optional_field,omitempty"`
- } `json:"struct_field"`
- PointerField *int `json:"pointer_field"`
- }
-
- expected := map[string]any{
- "type": "object",
- "properties": map[string]any{
- "string_field": map[string]any{"type": "string"},
- "int_field": map[string]any{"type": "integer"},
- "float_field": map[string]any{"type": "number"},
- "bool_field": map[string]any{"type": "boolean"},
- "optional_str": map[string]any{"type": "string"},
- "slice_field": map[string]any{
- "type": "array",
- "items": map[string]any{"type": "integer"},
- },
- "map_field": map[string]any{"type": "object"},
- "struct_field": map[string]any{
- "type": "object",
- "properties": map[string]any{
- "field": map[string]any{"type": "string"},
- "optional_field": map[string]any{"type": "string"},
- },
- "required": []string{"field"},
- },
- "pointer_field": map[string]any{
- "type": "integer",
- },
- },
- "required": []string{
- "int_field",
- "bool_field",
- "map_field",
- "struct_field",
- "pointer_field",
- "slice_field",
- },
- }
-
- actual, err := GenerateSchema[TestStruct]()
- require.NoError(t, err)
-
- require.Equal(t, expected["type"], actual["type"])
- require.Equal(t, expected["properties"], actual["properties"])
- require.ElementsMatch(t, expected["required"], actual["required"])
-}
diff --git a/pkg/vmcp/server/adapter/capability_adapter.go b/pkg/vmcp/server/adapter/capability_adapter.go
index 875ecbd9b0..e3b488dacc 100644
--- a/pkg/vmcp/server/adapter/capability_adapter.go
+++ b/pkg/vmcp/server/adapter/capability_adapter.go
@@ -4,6 +4,7 @@
package adapter
import (
+ "context"
"encoding/json"
"fmt"
@@ -14,6 +15,17 @@ import (
"github.com/stacklok/toolhive/pkg/vmcp"
)
+// OptimizerHandlerProvider provides handlers for optimizer tools.
+// This interface allows the adapter to create optimizer tools without
+// depending on the optimizer package implementation.
+type OptimizerHandlerProvider interface {
+ // CreateFindToolHandler returns the handler for optim_find_tool
+ CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error)
+
+ // CreateCallToolHandler returns the handler for optim_call_tool
+ CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error)
+}
+
// CapabilityAdapter converts aggregator domain models to SDK types.
//
// This is the Anti-Corruption Layer between:
@@ -208,3 +220,15 @@ func (a *CapabilityAdapter) ToCompositeToolSDKTools(
return sdkTools, nil
}
+
+// CreateOptimizerTools creates SDK tools for optimizer mode.
+//
+// When optimizer is enabled, only optim_find_tool and optim_call_tool are exposed
+// to clients instead of all backend tools. This method delegates to the standalone
+// CreateOptimizerTools function in optimizer_adapter.go for consistency.
+//
+// This keeps optimizer tool creation consistent with other tool types (backend,
+// composite) by going through the adapter layer.
+func (*CapabilityAdapter) CreateOptimizerTools(provider OptimizerHandlerProvider) ([]server.ServerTool, error) {
+ return CreateOptimizerTools(provider)
+}
diff --git a/pkg/vmcp/server/adapter/optimizer_adapter.go b/pkg/vmcp/server/adapter/optimizer_adapter.go
index 07a6f4cb72..d38d2fa514 100644
--- a/pkg/vmcp/server/adapter/optimizer_adapter.go
+++ b/pkg/vmcp/server/adapter/optimizer_adapter.go
@@ -4,15 +4,11 @@
package adapter
import (
- "context"
"encoding/json"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
-
- "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
- "github.com/stacklok/toolhive/pkg/vmcp/schema"
)
// OptimizerToolNames defines the tool names exposed when optimizer is enabled.
@@ -24,80 +20,88 @@ const (
// Pre-generated schemas for optimizer tools.
// Generated at package init time so any schema errors panic at startup.
var (
- findToolInputSchema = mustGenerateSchema[optimizer.FindToolInput]()
- callToolInputSchema = mustGenerateSchema[optimizer.CallToolInput]()
+ findToolInputSchema = mustMarshalSchema(findToolSchema)
+ callToolInputSchema = mustMarshalSchema(callToolSchema)
+)
+
+// Tool schemas defined once to eliminate duplication.
+var (
+ findToolSchema = mcp.ToolInputSchema{
+ Type: "object",
+ Properties: map[string]any{
+ "tool_description": map[string]any{
+ "type": "string",
+ "description": "Natural language description of the tool you're looking for",
+ },
+ "tool_keywords": map[string]any{
+ "type": "string",
+ "description": "Optional space-separated keywords for keyword-based search",
+ },
+ "limit": map[string]any{
+ "type": "integer",
+ "description": "Maximum number of tools to return (default: 10)",
+ "default": 10,
+ },
+ },
+ Required: []string{"tool_description"},
+ }
+
+ callToolSchema = mcp.ToolInputSchema{
+ Type: "object",
+ Properties: map[string]any{
+ "backend_id": map[string]any{
+ "type": "string",
+ "description": "Backend ID from find_tool results",
+ },
+ "tool_name": map[string]any{
+ "type": "string",
+ "description": "Tool name to invoke",
+ },
+ "parameters": map[string]any{
+ "type": "object",
+ "description": "Parameters to pass to the tool",
+ },
+ },
+ Required: []string{"backend_id", "tool_name", "parameters"},
+ }
)
// CreateOptimizerTools creates the SDK tools for optimizer mode.
// When optimizer is enabled, only these two tools are exposed to clients
// instead of all backend tools.
-func CreateOptimizerTools(opt optimizer.Optimizer) []server.ServerTool {
+//
+// This function uses the OptimizerHandlerProvider interface to obtain handlers,
+// allowing it to work with OptimizerIntegration without a direct dependency on it.
+func CreateOptimizerTools(provider OptimizerHandlerProvider) ([]server.ServerTool, error) {
+ if provider == nil {
+ return nil, fmt.Errorf("optimizer handler provider cannot be nil")
+ }
+
return []server.ServerTool{
{
Tool: mcp.Tool{
Name: FindToolName,
- Description: "Search for tools by description. Returns matching tools ranked by relevance.",
+ Description: "Semantic search across all backend tools using natural language description and optional keywords",
RawInputSchema: findToolInputSchema,
},
- Handler: createFindToolHandler(opt),
+ Handler: provider.CreateFindToolHandler(),
},
{
Tool: mcp.Tool{
Name: CallToolName,
- Description: "Call a tool by name with the given parameters.",
+ Description: "Dynamically invoke any tool on any backend using the backend_id from find_tool",
RawInputSchema: callToolInputSchema,
},
- Handler: createCallToolHandler(opt),
+ Handler: provider.CreateCallToolHandler(),
},
- }
-}
-
-// createFindToolHandler creates a handler for the find_tool optimizer operation.
-func createFindToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- input, err := schema.Translate[optimizer.FindToolInput](request.Params.Arguments)
- if err != nil {
- return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil
- }
-
- output, err := opt.FindTool(ctx, input)
- if err != nil {
- return mcp.NewToolResultError(fmt.Sprintf("find_tool failed: %v", err)), nil
- }
-
- return mcp.NewToolResultStructuredOnly(output), nil
- }
-}
-
-// createCallToolHandler creates a handler for the call_tool optimizer operation.
-func createCallToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
- input, err := schema.Translate[optimizer.CallToolInput](request.Params.Arguments)
- if err != nil {
- return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil
- }
-
- result, err := opt.CallTool(ctx, input)
- if err != nil {
- // Exposing the error to the MCP client is important if you want it to correct its behavior.
- // Without information on the failure, the model is pretty much hopeless in figuring out the problem.
- return mcp.NewToolResultError(fmt.Sprintf("call_tool failed: %v", err)), nil
- }
-
- return result, nil
- }
+ }, nil
}
// mustMarshalSchema marshals a schema to JSON, panicking on error.
// This is safe because schemas are generated from known types at startup.
// This should NOT be called by runtime code.
-func mustGenerateSchema[T any]() json.RawMessage {
- s, err := schema.GenerateSchema[T]()
- if err != nil {
- panic(fmt.Sprintf("failed to generate schema: %v", err))
- }
-
- data, err := json.Marshal(s)
+func mustMarshalSchema(schema mcp.ToolInputSchema) json.RawMessage {
+ data, err := json.Marshal(schema)
if err != nil {
panic(fmt.Sprintf("failed to marshal schema: %v", err))
}
diff --git a/pkg/vmcp/server/adapter/optimizer_adapter_test.go b/pkg/vmcp/server/adapter/optimizer_adapter_test.go
index b5ad7e066a..4272a978c4 100644
--- a/pkg/vmcp/server/adapter/optimizer_adapter_test.go
+++ b/pkg/vmcp/server/adapter/optimizer_adapter_test.go
@@ -9,65 +9,76 @@ import (
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/require"
-
- "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
)
-// mockOptimizer implements optimizer.Optimizer for testing.
-type mockOptimizer struct {
- findToolFunc func(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error)
- callToolFunc func(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error)
+// mockOptimizerHandlerProvider implements OptimizerHandlerProvider for testing.
+type mockOptimizerHandlerProvider struct {
+ findToolHandler func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error)
+ callToolHandler func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error)
}
-func (m *mockOptimizer) FindTool(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) {
- if m.findToolFunc != nil {
- return m.findToolFunc(ctx, input)
+func (m *mockOptimizerHandlerProvider) CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ if m.findToolHandler != nil {
+ return m.findToolHandler
+ }
+ return func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ return mcp.NewToolResultText("ok"), nil
}
- return &optimizer.FindToolOutput{}, nil
}
-func (m *mockOptimizer) CallTool(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) {
- if m.callToolFunc != nil {
- return m.callToolFunc(ctx, input)
+func (m *mockOptimizerHandlerProvider) CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ if m.callToolHandler != nil {
+ return m.callToolHandler
+ }
+ return func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ return mcp.NewToolResultText("ok"), nil
}
- return mcp.NewToolResultText("ok"), nil
}
func TestCreateOptimizerTools(t *testing.T) {
t.Parallel()
- opt := &mockOptimizer{}
- tools := CreateOptimizerTools(opt)
+ provider := &mockOptimizerHandlerProvider{}
+ tools, err := CreateOptimizerTools(provider)
+ require.NoError(t, err)
require.Len(t, tools, 2)
require.Equal(t, FindToolName, tools[0].Tool.Name)
require.Equal(t, CallToolName, tools[1].Tool.Name)
}
+func TestCreateOptimizerTools_NilProvider(t *testing.T) {
+ t.Parallel()
+
+ tools, err := CreateOptimizerTools(nil)
+
+ require.Error(t, err)
+ require.Nil(t, tools)
+ require.Contains(t, err.Error(), "cannot be nil")
+}
+
func TestFindToolHandler(t *testing.T) {
t.Parallel()
- opt := &mockOptimizer{
- findToolFunc: func(_ context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) {
- require.Equal(t, "read files", input.ToolDescription)
- return &optimizer.FindToolOutput{
- Tools: []optimizer.ToolMatch{
- {
- Name: "read_file",
- Description: "Read a file",
- Score: 1.0,
- },
- },
- }, nil
+ provider := &mockOptimizerHandlerProvider{
+ findToolHandler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ args, ok := req.Params.Arguments.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "read files", args["tool_description"])
+ return mcp.NewToolResultText("found tools"), nil
},
}
- tools := CreateOptimizerTools(opt)
+ tools, err := CreateOptimizerTools(provider)
+ require.NoError(t, err)
handler := tools[0].Handler
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]any{
- "tool_description": "read files",
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Arguments: map[string]any{
+ "tool_description": "read files",
+ },
+ },
}
result, err := handler(context.Background(), request)
@@ -80,22 +91,29 @@ func TestFindToolHandler(t *testing.T) {
func TestCallToolHandler(t *testing.T) {
t.Parallel()
- opt := &mockOptimizer{
- callToolFunc: func(_ context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) {
- require.Equal(t, "read_file", input.ToolName)
- require.Equal(t, "/etc/hosts", input.Parameters["path"])
+ provider := &mockOptimizerHandlerProvider{
+ callToolHandler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+ args, ok := req.Params.Arguments.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "read_file", args["tool_name"])
+ params := args["parameters"].(map[string]any)
+ require.Equal(t, "/etc/hosts", params["path"])
return mcp.NewToolResultText("file contents here"), nil
},
}
- tools := CreateOptimizerTools(opt)
+ tools, err := CreateOptimizerTools(provider)
+ require.NoError(t, err)
handler := tools[1].Handler
- request := mcp.CallToolRequest{}
- request.Params.Arguments = map[string]any{
- "tool_name": "read_file",
- "parameters": map[string]any{
- "path": "/etc/hosts",
+ request := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Arguments: map[string]any{
+ "tool_name": "read_file",
+ "parameters": map[string]any{
+ "path": "/etc/hosts",
+ },
+ },
},
}
diff --git a/pkg/vmcp/server/optimizer_test.go b/pkg/vmcp/server/optimizer_test.go
new file mode 100644
index 0000000000..56cfeff396
--- /dev/null
+++ b/pkg/vmcp/server/optimizer_test.go
@@ -0,0 +1,362 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package server
+
+import (
+ "context"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/mock/gomock"
+
+ "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+ "github.com/stacklok/toolhive/pkg/vmcp"
+ "github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+ discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks"
+ "github.com/stacklok/toolhive/pkg/vmcp/mocks"
+ "github.com/stacklok/toolhive/pkg/vmcp/optimizer"
+ "github.com/stacklok/toolhive/pkg/vmcp/router"
+)
+
+// TestNew_OptimizerEnabled tests server creation with optimizer enabled
+func TestNew_OptimizerEnabled(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockBackendClient := mocks.NewMockBackendClient(ctrl)
+ mockBackendClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Return(&vmcp.CapabilityList{}, nil).
+ AnyTimes()
+
+ mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
+ mockDiscoveryMgr.EXPECT().
+ Discover(gomock.Any(), gomock.Any()).
+ Return(&aggregator.AggregatedCapabilities{}, nil).
+ AnyTimes()
+ mockDiscoveryMgr.EXPECT().Stop().AnyTimes()
+
+ tmpDir := t.TempDir()
+
+ // Try to use Ollama if available
+ embeddingConfig := &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+ return
+ }
+ _ = embeddingManager.Close()
+
+ cfg := &Config{
+ Name: "test-server",
+ Version: "1.0.0",
+ Host: "127.0.0.1",
+ Port: 0,
+ SessionTTL: 5 * time.Minute,
+ OptimizerConfig: &optimizer.Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ HybridSearchRatio: 70,
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ },
+ },
+ }
+
+ rt := router.NewDefaultRouter()
+ backends := []vmcp.Backend{
+ {
+ ID: "backend-1",
+ Name: "Backend 1",
+ BaseURL: "http://localhost:8000",
+ TransportType: "sse",
+ },
+ }
+
+ srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil)
+ require.NoError(t, err)
+ require.NotNil(t, srv)
+ defer func() { _ = srv.Stop(context.Background()) }()
+
+ // Verify optimizer integration was created
+ // We can't directly access optimizerIntegration, but we can verify server was created successfully
+}
+
+// TestNew_OptimizerDisabled tests server creation with optimizer disabled
+func TestNew_OptimizerDisabled(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockBackendClient := mocks.NewMockBackendClient(ctrl)
+ mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
+ mockDiscoveryMgr.EXPECT().Stop().AnyTimes()
+
+ cfg := &Config{
+ Name: "test-server",
+ Version: "1.0.0",
+ Host: "127.0.0.1",
+ Port: 0,
+ SessionTTL: 5 * time.Minute,
+ OptimizerConfig: &optimizer.Config{
+ Enabled: false, // Disabled
+ },
+ }
+
+ rt := router.NewDefaultRouter()
+ backends := []vmcp.Backend{}
+
+ srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil)
+ require.NoError(t, err)
+ require.NotNil(t, srv)
+ defer func() { _ = srv.Stop(context.Background()) }()
+}
+
+// TestNew_OptimizerConfigNil tests server creation with nil optimizer config
+func TestNew_OptimizerConfigNil(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockBackendClient := mocks.NewMockBackendClient(ctrl)
+ mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
+ mockDiscoveryMgr.EXPECT().Stop().AnyTimes()
+
+ cfg := &Config{
+ Name: "test-server",
+ Version: "1.0.0",
+ Host: "127.0.0.1",
+ Port: 0,
+ SessionTTL: 5 * time.Minute,
+ OptimizerConfig: nil, // Nil config
+ }
+
+ rt := router.NewDefaultRouter()
+ backends := []vmcp.Backend{}
+
+ srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil)
+ require.NoError(t, err)
+ require.NotNil(t, srv)
+ defer func() { _ = srv.Stop(context.Background()) }()
+}
+
+// TestNew_OptimizerIngestionError tests error handling during optimizer ingestion
+func TestNew_OptimizerIngestionError(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockBackendClient := mocks.NewMockBackendClient(ctrl)
+ // Return error when listing capabilities
+ mockBackendClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Return(nil, assert.AnError).
+ AnyTimes()
+
+ mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
+ mockDiscoveryMgr.EXPECT().Stop().AnyTimes()
+
+ tmpDir := t.TempDir()
+
+ embeddingConfig := &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+ return
+ }
+ _ = embeddingManager.Close()
+
+ cfg := &Config{
+ Name: "test-server",
+ Version: "1.0.0",
+ Host: "127.0.0.1",
+ Port: 0,
+ SessionTTL: 5 * time.Minute,
+ OptimizerConfig: &optimizer.Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ },
+ },
+ }
+
+ rt := router.NewDefaultRouter()
+ backends := []vmcp.Backend{
+ {
+ ID: "backend-1",
+ Name: "Backend 1",
+ BaseURL: "http://localhost:8000",
+ TransportType: "sse",
+ },
+ }
+
+ // Should not fail even if ingestion fails
+ srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil)
+ require.NoError(t, err, "Server should be created even if optimizer ingestion fails")
+ require.NotNil(t, srv)
+ defer func() { _ = srv.Stop(context.Background()) }()
+}
+
+// TestNew_OptimizerHybridRatio tests hybrid ratio configuration
+func TestNew_OptimizerHybridRatio(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockBackendClient := mocks.NewMockBackendClient(ctrl)
+ mockBackendClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Return(&vmcp.CapabilityList{}, nil).
+ AnyTimes()
+
+ mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
+ mockDiscoveryMgr.EXPECT().
+ Discover(gomock.Any(), gomock.Any()).
+ Return(&aggregator.AggregatedCapabilities{}, nil).
+ AnyTimes()
+ mockDiscoveryMgr.EXPECT().Stop().AnyTimes()
+
+ tmpDir := t.TempDir()
+
+ embeddingConfig := &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+ return
+ }
+ _ = embeddingManager.Close()
+
+ cfg := &Config{
+ Name: "test-server",
+ Version: "1.0.0",
+ Host: "127.0.0.1",
+ Port: 0,
+ SessionTTL: 5 * time.Minute,
+ OptimizerConfig: &optimizer.Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ HybridSearchRatio: 50, // Custom ratio
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ },
+ },
+ }
+
+ rt := router.NewDefaultRouter()
+ backends := []vmcp.Backend{}
+
+ srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil)
+ require.NoError(t, err)
+ require.NotNil(t, srv)
+ defer func() { _ = srv.Stop(context.Background()) }()
+}
+
+// TestServer_Stop_OptimizerCleanup tests optimizer cleanup on server stop
+func TestServer_Stop_OptimizerCleanup(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockBackendClient := mocks.NewMockBackendClient(ctrl)
+ mockBackendClient.EXPECT().
+ ListCapabilities(gomock.Any(), gomock.Any()).
+ Return(&vmcp.CapabilityList{}, nil).
+ AnyTimes()
+
+ mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl)
+ mockDiscoveryMgr.EXPECT().
+ Discover(gomock.Any(), gomock.Any()).
+ Return(&aggregator.AggregatedCapabilities{}, nil).
+ AnyTimes()
+ mockDiscoveryMgr.EXPECT().Stop().AnyTimes()
+
+ tmpDir := t.TempDir()
+
+ embeddingConfig := &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ }
+
+ embeddingManager, err := embeddings.NewManager(embeddingConfig)
+ if err != nil {
+ t.Skipf("Skipping test: Ollama not available. Error: %v", err)
+ return
+ }
+ _ = embeddingManager.Close()
+
+ cfg := &Config{
+ Name: "test-server",
+ Version: "1.0.0",
+ Host: "127.0.0.1",
+ Port: 0,
+ SessionTTL: 5 * time.Minute,
+ OptimizerConfig: &optimizer.Config{
+ Enabled: true,
+ PersistPath: filepath.Join(tmpDir, "optimizer-db"),
+ EmbeddingConfig: &embeddings.Config{
+ BackendType: "ollama",
+ BaseURL: "http://localhost:11434",
+ Model: "all-minilm",
+ Dimension: 384,
+ },
+ },
+ }
+
+ rt := router.NewDefaultRouter()
+ backends := []vmcp.Backend{}
+
+ srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil)
+ require.NoError(t, err)
+ require.NotNil(t, srv)
+
+ // Stop should clean up optimizer
+ err = srv.Stop(context.Background())
+ require.NoError(t, err)
+}
diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go
index 3ccecdf39c..910494218f 100644
--- a/pkg/vmcp/server/server.go
+++ b/pkg/vmcp/server/server.go
@@ -125,9 +125,15 @@ type Config struct {
// Used for /readyz endpoint to gate readiness on cache sync.
Watcher Watcher
- // OptimizerFactory builds an optimizer from a list of tools.
- // If not set, the optimizer is disabled.
- OptimizerFactory func([]server.ServerTool) optimizer.Optimizer
+ // OptimizerIntegration is the optional optimizer integration.
+ // If nil, optimizer is disabled and backend tools are exposed directly.
+ // If set, this takes precedence over OptimizerConfig.
+ OptimizerIntegration optimizer.Integration
+
+ // OptimizerConfig is the optional optimizer configuration (for backward compatibility).
+ // If OptimizerIntegration is set, this is ignored.
+ // If both are nil, optimizer is disabled.
+ OptimizerConfig *optimizer.Config
// StatusReporter enables vMCP runtime to report operational status.
// In Kubernetes mode: Updates VirtualMCPServer.Status (requires RBAC)
@@ -340,7 +346,15 @@ func New(
if cfg.HealthMonitorConfig != nil {
// Get initial backends list from registry for health monitoring setup
initialBackends := backendRegistry.List(ctx)
- healthMon, err = health.NewMonitor(backendClient, initialBackends, *cfg.HealthMonitorConfig)
+
+ // Construct server's own URL for self-check detection
+ // Use http:// as default scheme (most common for local development)
+ var selfURL string
+ if cfg.Host != "" && cfg.Port > 0 {
+ selfURL = fmt.Sprintf("http://%s:%d", cfg.Host, cfg.Port)
+ }
+
+ healthMon, err = health.NewMonitor(backendClient, initialBackends, *cfg.HealthMonitorConfig, selfURL)
if err != nil {
return nil, fmt.Errorf("failed to create health monitor: %w", err)
}
@@ -533,6 +547,23 @@ func (s *Server) Start(ctx context.Context) error {
}
}
+ // Initialize optimizer integration if configured
+ if s.config.OptimizerIntegration == nil && s.config.OptimizerConfig != nil && s.config.OptimizerConfig.Enabled {
+ // Create optimizer integration from config (for backward compatibility)
+ optimizerInteg, err := optimizer.NewIntegration(ctx, s.config.OptimizerConfig, s.mcpServer, s.backendClient, s.sessionManager)
+ if err != nil {
+ return fmt.Errorf("failed to create optimizer integration: %w", err)
+ }
+ s.config.OptimizerIntegration = optimizerInteg
+ }
+
+ // Initialize optimizer if configured (registers tools and ingests backends)
+ if s.config.OptimizerIntegration != nil {
+ if err := s.config.OptimizerIntegration.Initialize(ctx, s.mcpServer, s.backendRegistry); err != nil {
+ return fmt.Errorf("failed to initialize optimizer: %w", err)
+ }
+ }
+
// Start status reporter if configured
if s.statusReporter != nil {
shutdown, err := s.statusReporter.Start(ctx)
@@ -592,6 +623,13 @@ func (s *Server) Stop(ctx context.Context) error {
}
}
+ // Stop optimizer integration if configured
+ if s.config.OptimizerIntegration != nil {
+ if err := s.config.OptimizerIntegration.Close(); err != nil {
+ errs = append(errs, fmt.Errorf("failed to close optimizer integration: %w", err))
+ }
+ }
+
// Run shutdown functions (e.g., status reporter, future components)
for _, shutdown := range s.shutdownFuncs {
if err := shutdown(ctx); err != nil {
@@ -746,7 +784,6 @@ func (s *Server) Ready() <-chan struct{} {
// - No previous capabilities exist, so no deletion needed
// - Capabilities are IMMUTABLE for the session lifetime (see limitation below)
// - Discovery middleware does not re-run for subsequent requests
-// - If injectOptimizerCapabilities is called, this should not be called again.
//
// LIMITATION: Session capabilities are fixed at creation time.
// If backends change (new tools added, resources removed), existing sessions won't see updates.
@@ -820,54 +857,6 @@ func (s *Server) injectCapabilities(
return nil
}
-// injectOptimizerCapabilities injects all capabilities into the session, including optimizer tools.
-// It should not be called if not in optimizer mode and replaces injectCapabilities.
-//
-// When optimizer mode is enabled, instead of exposing all backend tools directly,
-// vMCP exposes only two meta-tools:
-// - find_tool: Search for tools by description
-// - call_tool: Invoke a tool by name with parameters
-//
-// This method:
-// 1. Converts all tools (backend + composite) to SDK format with handlers
-// 2. Injects the optimizer capabilities into the session
-func (s *Server) injectOptimizerCapabilities(
- sessionID string,
- caps *aggregator.AggregatedCapabilities,
-) error {
-
- tools := append([]vmcp.Tool{}, caps.Tools...)
- tools = append(tools, caps.CompositeTools...)
-
- sdkTools, err := s.capabilityAdapter.ToSDKTools(tools)
- if err != nil {
- return fmt.Errorf("failed to convert tools to SDK format: %w", err)
- }
-
- // Create optimizer tools (find_tool, call_tool)
- optimizerTools := adapter.CreateOptimizerTools(s.config.OptimizerFactory(sdkTools))
-
- logger.Debugw("created optimizer tools for session",
- "session_id", sessionID,
- "backend_tool_count", len(caps.Tools),
- "composite_tool_count", len(caps.CompositeTools),
- "total_tools_indexed", len(sdkTools))
-
- // Clear tools from caps - they're now wrapped by optimizer
- // Resources and prompts are preserved and handled normally
- capsCopy := *caps
- capsCopy.Tools = nil
- capsCopy.CompositeTools = nil
-
- // Manually add the optimizer tools, since we don't want to bother converting
- // optimizer tools into `vmcp.Tool`s as well.
- if err := s.mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil {
- return fmt.Errorf("failed to add session tools: %w", err)
- }
-
- return s.injectCapabilities(sessionID, &capsCopy)
-}
-
// handleSessionRegistration processes a new MCP session registration.
//
// This hook fires AFTER the session is registered in the SDK (unlike AfterInitialize which
@@ -880,7 +869,7 @@ func (s *Server) injectOptimizerCapabilities(
// 1. Retrieves discovered capabilities from context
// 2. Adds composite tools from configuration
// 3. Stores routing table in VMCPSession for request routing
-// 4. Injects capabilities into the SDK session
+// 4. Injects capabilities into the SDK session (or delegates to optimizer if enabled)
//
// IMPORTANT: Session capabilities are immutable after injection.
// - Capabilities discovered during initialize are fixed for the session lifetime
@@ -955,16 +944,26 @@ func (s *Server) handleSessionRegistration(
"resource_count", len(caps.RoutingTable.Resources),
"prompt_count", len(caps.RoutingTable.Prompts))
- if s.config.OptimizerFactory != nil {
- err = s.injectOptimizerCapabilities(sessionID, caps)
+ // Delegate to optimizer integration if enabled
+ if s.config.OptimizerIntegration != nil {
+ handled, err := s.config.OptimizerIntegration.HandleSessionRegistration(
+ ctx,
+ sessionID,
+ caps,
+ s.mcpServer,
+ s.capabilityAdapter.ToSDKResources,
+ )
if err != nil {
- logger.Errorw("failed to create optimizer tools",
+ logger.Errorw("failed to handle session registration with optimizer",
"error", err,
"session_id", sessionID)
- } else {
- logger.Infow("optimizer capabilities injected")
+ return
}
- return
+ if handled {
+ // Optimizer handled the registration, we're done
+ return
+ }
+ // If optimizer didn't handle it, fall through to normal registration
}
// Inject capabilities into SDK session
diff --git a/test/e2e/thv-operator/virtualmcp/helpers.go b/test/e2e/thv-operator/virtualmcp/helpers.go
index ca73e206f2..e48d3fdea5 100644
--- a/test/e2e/thv-operator/virtualmcp/helpers.go
+++ b/test/e2e/thv-operator/virtualmcp/helpers.go
@@ -89,8 +89,9 @@ func checkPodsReady(ctx context.Context, c client.Client, namespace string, labe
}
for _, pod := range podList.Items {
+ // Skip pods that are not running (e.g., Succeeded, Failed from old deployments)
if pod.Status.Phase != corev1.PodRunning {
- return fmt.Errorf("pod %s is in phase %s", pod.Name, pod.Status.Phase)
+ continue
}
containerReady := false
@@ -114,6 +115,17 @@ func checkPodsReady(ctx context.Context, c client.Client, namespace string, labe
return fmt.Errorf("pod %s not ready", pod.Name)
}
}
+
+ // After filtering, ensure we found at least one running pod
+ runningPods := 0
+ for _, pod := range podList.Items {
+ if pod.Status.Phase == corev1.PodRunning {
+ runningPods++
+ }
+ }
+ if runningPods == 0 {
+ return fmt.Errorf("no running pods found with labels %v", labels)
+ }
return nil
}
diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go
index e7e33fd623..18af2c94df 100644
--- a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go
+++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go
@@ -1162,12 +1162,28 @@ with socketserver.TCPServer(("", PORT), OIDCHandler) as httpd:
}
It("should list and call tools from all backends with discovered auth", func() {
+ By("Verifying vMCP pods are still running and ready before health check")
+ vmcpLabels := map[string]string{
+ "app.kubernetes.io/name": "virtualmcpserver",
+ "app.kubernetes.io/instance": vmcpServerName,
+ }
+ WaitForPodsReady(ctx, k8sClient, testNamespace, vmcpLabels, 30*time.Second, 2*time.Second)
+
+ // Create HTTP client with timeout for health checks
+ healthCheckClient := &http.Client{
+ Timeout: 10 * time.Second,
+ }
+
By("Verifying HTTP connectivity to VirtualMCPServer health endpoint")
Eventually(func() error {
+ // Re-check pod readiness before each health check attempt
+ if err := checkPodsReady(ctx, k8sClient, testNamespace, vmcpLabels); err != nil {
+ return fmt.Errorf("pods not ready: %w", err)
+ }
url := fmt.Sprintf("http://localhost:%d/health", vmcpNodePort)
- resp, err := http.Get(url)
+ resp, err := healthCheckClient.Get(url)
if err != nil {
- return err
+ return fmt.Errorf("health check failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go
index 67610b043f..b08039b94e 100644
--- a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go
+++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go
@@ -72,8 +72,9 @@ var _ = Describe("VirtualMCPServer Optimizer Mode", Ordered, func() {
Config: vmcpconfig.Config{
Group: mcpGroupName,
Optimizer: &vmcpconfig.OptimizerConfig{
- // EmbeddingService is required but not used by DummyOptimizer
- EmbeddingService: "dummy-embedding-service",
+ // EmbeddingURL is required for optimizer configuration
+ // For in-cluster services, use the full service DNS name with port
+ EmbeddingURL: "http://dummy-embedding-service.default.svc.cluster.local:11434",
},
// Define a composite tool that calls fetch twice
CompositeTools: []vmcpconfig.CompositeToolConfig{