diff --git a/.gitignore b/.gitignore index 7f5be4bf18..34dcc23d79 100644 --- a/.gitignore +++ b/.gitignore @@ -42,4 +42,11 @@ cmd/thv-operator/.task/checksum/crdref-gen # Test coverage coverage* -crd-helm-wrapper \ No newline at end of file +crd-helm-wrapper +cmd/vmcp/__debug_bin* + +# Demo files +examples/operator/virtual-mcps/vmcp_optimizer.yaml +scripts/k8s_vmcp_optimizer_demo.sh +examples/ingress/mcp-servers-ingress.yaml +/vmcp diff --git a/.golangci.yml b/.golangci.yml index 62c3611473..ff2b3d54e9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -139,6 +139,7 @@ linters: - third_party$ - builtin$ - examples$ + - scripts$ formatters: enable: - gci @@ -155,3 +156,4 @@ formatters: - third_party$ - builtin$ - examples$ + - scripts$ diff --git a/cmd/thv-operator/controllers/mcpserver_controller.go b/cmd/thv-operator/controllers/mcpserver_controller.go index 2313d85ace..f1221c6ecc 100644 --- a/cmd/thv-operator/controllers/mcpserver_controller.go +++ b/cmd/thv-operator/controllers/mcpserver_controller.go @@ -1250,12 +1250,13 @@ func (r *MCPServerReconciler) deploymentForMCPServer( Spec: corev1.PodSpec{ ServiceAccountName: ctrlutil.ProxyRunnerServiceAccountName(m.Name), Containers: []corev1.Container{{ - Image: getToolhiveRunnerImage(), - Name: "toolhive", - Args: args, - Env: env, - VolumeMounts: volumeMounts, - Resources: resources, + Image: getToolhiveRunnerImage(), + Name: "toolhive", + ImagePullPolicy: getImagePullPolicyForToolhiveRunner(), + Args: args, + Env: env, + VolumeMounts: volumeMounts, + Resources: resources, Ports: []corev1.ContainerPort{{ ContainerPort: m.GetProxyPort(), Name: "http", @@ -1813,6 +1814,19 @@ func getToolhiveRunnerImage() string { return image } +// getImagePullPolicyForToolhiveRunner returns the appropriate imagePullPolicy for the toolhive runner container. +// If the image is a local image (starts with "kind.local/" or "localhost/"), use Never. +// Otherwise, use IfNotPresent to allow pulling when needed but avoid unnecessary pulls. +func getImagePullPolicyForToolhiveRunner() corev1.PullPolicy { + image := getToolhiveRunnerImage() + // Check if it's a local image that should use Never + if strings.HasPrefix(image, "kind.local/") || strings.HasPrefix(image, "localhost/") { + return corev1.PullNever + } + // For other images, use IfNotPresent to allow pulling when needed + return corev1.PullIfNotPresent +} + // handleExternalAuthConfig validates and tracks the hash of the referenced MCPExternalAuthConfig. // It updates the MCPServer status when the external auth configuration changes. func (r *MCPServerReconciler) handleExternalAuthConfig(ctx context.Context, m *mcpv1alpha1.MCPServer) error { diff --git a/cmd/thv-operator/pkg/optimizer/INTEGRATION.md b/cmd/thv-operator/pkg/optimizer/INTEGRATION.md new file mode 100644 index 0000000000..a231a0dabb --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/INTEGRATION.md @@ -0,0 +1,134 @@ +# Integrating Optimizer with vMCP + +## Overview + +The optimizer package ingests MCP server and tool metadata into a searchable database with semantic embeddings. This enables intelligent tool discovery and token optimization for LLM consumption. + +## Integration Approach + +**Event-Driven Ingestion**: The optimizer integrates directly with vMCP's startup process. When vMCP starts and loads its configured servers, it calls the optimizer to ingest each server's metadata and tools. + +❌ **NOT** a separate polling service discovering backends +✅ **IS** called directly by vMCP during server initialization + +## How It Is Integrated + +The optimizer is already integrated into vMCP and works automatically when enabled via configuration. Here's how the integration works: + +### Initialization + +When vMCP starts with optimizer enabled in the configuration, it: + +1. Initializes the optimizer database (chromem-go + SQLite FTS5) +2. Configures the embedding backend (placeholder, Ollama, or vLLM) +3. Sets up the ingestion service + +### Automatic Ingestion + +The optimizer integrates with vMCP's `OnRegisterSession` hook, which is called whenever: + +- vMCP starts and loads configured MCP servers +- A new MCP server is dynamically added +- A session reconnects or refreshes + +When this hook is triggered, the optimizer: + +1. Retrieves the server's metadata and tools via MCP protocol +2. Generates embeddings for searchable content +3. Stores the data in both the vector database (chromem-go) and FTS5 database +4. Makes the tools immediately available for semantic search + +### Exposed Tools + +When the optimizer is enabled, vMCP automatically exposes these tools to LLM clients: + +- `optim.find_tool`: Semantic search for tools across all registered servers +- `optim.call_tool`: Dynamic tool invocation after discovery + +### Implementation Location + +The integration code is located in: +- `cmd/vmcp/optimizer.go`: Optimizer initialization and configuration +- `pkg/vmcp/optimizer/optimizer.go`: Session registration hook implementation +- `cmd/thv-operator/pkg/optimizer/ingestion/service.go`: Core ingestion service + +## Configuration + +Add optimizer configuration to vMCP's config: + +```yaml +# vMCP config +optimizer: + enabled: true + db_path: /data/optimizer.db + embedding: + backend: vllm # or "ollama" for local dev, "placeholder" for testing + url: http://vllm-service:8000 + model: sentence-transformers/all-MiniLM-L6-v2 + dimension: 384 +``` + +## Error Handling + +**Important**: Optimizer failures should NOT break vMCP functionality: + +- ✅ Log warnings if optimizer fails +- ✅ Continue server startup even if ingestion fails +- ✅ Run ingestion in goroutines to avoid blocking +- ❌ Don't fail server startup if optimizer is unavailable + +## Benefits + +1. **Automatic**: Servers are indexed as they're added to vMCP +2. **Up-to-date**: Database reflects current vMCP state +3. **No polling**: Event-driven, efficient +4. **Semantic search**: Enables intelligent tool discovery +5. **Token optimization**: Tracks token usage for LLM efficiency + +## Testing + +```go +func TestOptimizerIntegration(t *testing.T) { + // Initialize optimizer + optimizerSvc, err := ingestion.NewService(&ingestion.Config{ + DBConfig: &db.Config{Path: "/tmp/test-optimizer.db"}, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + Dimension: 384, + }, + }) + require.NoError(t, err) + defer optimizerSvc.Close() + + // Simulate vMCP starting a server + ctx := context.Background() + tools := []mcp.Tool{ + {Name: "get_weather", Description: "Get current weather"}, + {Name: "get_forecast", Description: "Get weather forecast"}, + } + + err = optimizerSvc.IngestServer( + ctx, + "weather-001", + "weather-service", + "http://weather.local", + models.TransportSSE, + ptr("Weather information service"), + tools, + ) + require.NoError(t, err) + + // Verify ingestion + server, err := optimizerSvc.GetServer(ctx, "weather-001") + require.NoError(t, err) + assert.Equal(t, "weather-service", server.Name) +} +``` + +## See Also + +- [Optimizer Package README](./README.md) - Package overview and API + diff --git a/cmd/thv-operator/pkg/optimizer/README.md b/cmd/thv-operator/pkg/optimizer/README.md new file mode 100644 index 0000000000..7db703b711 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/README.md @@ -0,0 +1,339 @@ +# Optimizer Package + +The optimizer package provides semantic tool discovery and ingestion for MCP servers in ToolHive's vMCP. It enables intelligent, context-aware tool selection to reduce token usage and improve LLM performance. + +## Features + +- **Pure Go**: No CGO dependencies - uses [chromem-go](https://github.com/philippgille/chromem-go) for vector search and `modernc.org/sqlite` for FTS5 +- **Hybrid Search**: Combines semantic search (chromem-go) with BM25 full-text search (SQLite FTS5) +- **In-Memory by Default**: Fast ephemeral database with optional persistence +- **Pluggable Embeddings**: Supports vLLM, Ollama, and placeholder backends +- **Event-Driven**: Integrates with vMCP's `OnRegisterSession` hook for automatic ingestion +- **Semantic + Keyword Search**: Configurable ratio between semantic and BM25 search +- **Token Counting**: Tracks token usage for LLM consumption metrics + +## Architecture + +``` +cmd/thv-operator/pkg/optimizer/ +├── models/ # Domain models (Server, Tool, etc.) +├── db/ # Hybrid database layer (chromem-go + SQLite FTS5) +│ ├── db.go # Database coordinator +│ ├── fts.go # SQLite FTS5 for BM25 search (pure Go) +│ ├── hybrid.go # Hybrid search combining semantic + BM25 +│ ├── backend_server.go # Server operations +│ └── backend_tool.go # Tool operations +├── embeddings/ # Embedding backends (vLLM, Ollama, placeholder) +├── ingestion/ # Event-driven ingestion service +└── tokens/ # Token counting for LLM metrics +``` + +## Embedding Backends + +The optimizer supports multiple embedding backends: + +| Backend | Use Case | Performance | Setup | +|---------|----------|-------------|-------| +| **vLLM** | **Production/Kubernetes (recommended)** | Excellent (GPU) | Deploy vLLM service | +| Ollama | Local development, CPU-only | Good | `ollama serve` | +| Placeholder | Testing, CI/CD | Fast (hash-based) | Zero setup | + +**For production Kubernetes deployments, vLLM is recommended** due to its high-throughput performance, GPU efficiency (PagedAttention), and scalability for multi-user environments. + +## Hybrid Search + +The optimizer **always uses hybrid search** combining: + +1. **Semantic Search** (chromem-go): Understands meaning and context via embeddings +2. **BM25 Full-Text Search** (SQLite FTS5): Keyword matching with Porter stemming + +This dual approach ensures the best of both worlds: semantic understanding for intent-based queries and keyword precision for technical terms and acronyms. + +### Configuration + +```yaml +optimizer: + enabled: true + embeddingBackend: placeholder + embeddingDimension: 384 + # persistPath: /data/optimizer # Optional: for persistence + # ftsDBPath: /data/optimizer-fts.db # Optional: defaults to :memory: or {persistPath}/fts.db + hybridSearchRatio: 70 # 70% semantic, 30% BM25 (default, 0-100 percentage) +``` + +| Ratio | Semantic | BM25 | Best For | +|-------|----------|------|----------| +| 1.0 | 100% | 0% | Pure semantic (intent-heavy queries) | +| 0.7 | 70% | 30% | **Default**: Balanced hybrid | +| 0.5 | 50% | 50% | Equal weight | +| 0.0 | 0% | 100% | Pure keyword (exact term matching) | + +### How It Works + +1. **Parallel Execution**: Semantic and BM25 searches run concurrently +2. **Result Merging**: Combines results and removes duplicates +3. **Ranking**: Sorts by similarity/relevance score +4. **Limit Enforcement**: Returns top N results + +### Example Queries + +| Query | Semantic Match | BM25 Match | Winner | +|-------|----------------|------------|--------| +| "What's the weather?" | ✅ `get_current_weather` | ✅ `weather_forecast` | Both (deduped) | +| "SQL database query" | ❌ (no embeddings) | ✅ `execute_sql` | BM25 | +| "Make it rain outside" | ✅ `weather_control` | ❌ (no keyword) | Semantic | + +## Quick Start + +### vMCP Integration (Recommended) + +The optimizer is designed to work as part of vMCP, not standalone: + +```yaml +# examples/vmcp-config-optimizer.yaml +optimizer: + enabled: true + embeddingBackend: placeholder # or "ollama", "openai-compatible" + embeddingDimension: 384 + # persistPath: /data/optimizer # Optional: for chromem-go persistence + # ftsDBPath: /data/fts.db # Optional: auto-defaults to :memory: or {persistPath}/fts.db + # hybridSearchRatio: 70 # Optional: 70% semantic, 30% BM25 (default, 0-100 percentage) +``` + +Start vMCP with optimizer: + +```bash +thv vmcp serve --config examples/vmcp-config-optimizer.yaml +``` + +When optimizer is enabled, vMCP exposes: +- `optim.find_tool`: Semantic search for tools +- `optim.call_tool`: Dynamic tool invocation + +### Programmatic Usage + +```go +import ( + "context" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/ingestion" +) + +func main() { + ctx := context.Background() + + // Initialize database (in-memory) + database, err := db.NewDB(&db.Config{ + PersistPath: "", // Empty = in-memory only + }) + if err != nil { + panic(err) + } + + // Initialize embedding manager with Ollama (default) + embeddingMgr, err := embeddings.NewManager(&embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }) + if err != nil { + panic(err) + } + + // Create ingestion service + svc, err := ingestion.NewService(&ingestion.Config{ + DBConfig: &db.Config{PersistPath: ""}, + EmbeddingConfig: embeddingMgr.Config(), + }) + if err != nil { + panic(err) + } + defer svc.Close() + + // Ingest a server (called by vMCP on session registration) + err = svc.IngestServer(ctx, "server-id", "MyServer", nil, []mcp.Tool{...}) + if err != nil { + panic(err) + } +} +``` + +### Production Deployment with vLLM (Kubernetes) + +```yaml +optimizer: + enabled: true + embeddingBackend: openai-compatible + embeddingURL: http://vllm-service:8000/v1 + embeddingModel: BAAI/bge-small-en-v1.5 + embeddingDimension: 768 + persistPath: /data/optimizer # Persistent storage for faster restarts +``` + +Deploy vLLM alongside vMCP: + +```yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: vllm-embeddings +spec: + template: + spec: + containers: + - name: vllm + image: vllm/vllm-openai:latest + args: + - --model + - BAAI/bge-small-en-v1.5 + - --port + - "8000" + resources: + limits: + nvidia.com/gpu: 1 +``` + +### Local Development with Ollama + +```bash +# Start Ollama +ollama serve + +# Pull an embedding model +ollama pull all-minilm +``` + +Configure vMCP: + +```yaml +optimizer: + enabled: true + embeddingBackend: ollama + embeddingURL: http://localhost:11434 + embeddingModel: all-minilm + embeddingDimension: 384 +``` + +## Configuration + +### Database + +- **Storage**: chromem-go (pure Go, no CGO) +- **Default**: In-memory (ephemeral) +- **Persistence**: Optional via `persistPath` +- **Format**: Binary (gob encoding) + +### Embedding Models + +Common embedding dimensions: +- **384**: all-MiniLM-L6-v2, nomic-embed-text (default) +- **768**: BAAI/bge-small-en-v1.5 +- **1536**: OpenAI text-embedding-3-small + +### Performance + +From chromem-go benchmarks (mid-range 2020 Intel laptop): +- **1,000 tools**: ~0.5ms query time +- **5,000 tools**: ~2.2ms query time +- **25,000 tools**: ~9.9ms query time +- **100,000 tools**: ~39.6ms query time + +Perfect for typical vMCP deployments (hundreds to thousands of tools). + +## Testing + +Run the unit tests: + +```bash +# Test all packages +go test ./cmd/thv-operator/pkg/optimizer/... + +# Test with coverage +go test -cover ./cmd/thv-operator/pkg/optimizer/... + +# Test specific package +go test ./cmd/thv-operator/pkg/optimizer/models +``` + +## Inspecting the Database + +The optimizer uses a hybrid database (chromem-go + SQLite FTS5). Here's how to inspect each: + +### Inspecting SQLite FTS5 (Easiest) + +The FTS5 database is standard SQLite and can be opened with any SQLite tool: + +```bash +# Use sqlite3 CLI +sqlite3 /tmp/vmcp-optimizer-fts.db + +# Count documents +SELECT COUNT(*) FROM backend_servers_fts; +SELECT COUNT(*) FROM backend_tools_fts; + +# View tool names and descriptions +SELECT tool_name, tool_description FROM backend_tools_fts LIMIT 10; + +# Full-text search with BM25 ranking +SELECT tool_name, rank +FROM backend_tool_fts_index +WHERE backend_tool_fts_index MATCH 'github repository' +ORDER BY rank +LIMIT 5; + +# Join servers and tools +SELECT s.name, t.tool_name, t.tool_description +FROM backend_tools_fts t +JOIN backend_servers_fts s ON t.mcpserver_id = s.id +LIMIT 10; +``` + +**VSCode Extension**: Install `alexcvzz.vscode-sqlite` to view `.db` files directly in VSCode. + +### Inspecting chromem-go (Vector Database) + +chromem-go uses `.gob` binary files. Use the provided inspection scripts: + +```bash +# Quick summary (shows collection sizes and first few documents) +go run scripts/inspect-chromem-raw.go /tmp/vmcp-optimizer-debug.db + +# View specific tool with full metadata and embeddings +go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db get_file_contents + +# View all documents (warning: lots of output) +go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db + +# Search by content +go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db "search" +``` + +### chromem-go Schema + +Each document in chromem-go contains: + +```go +Document { + ID: string // "github" or UUID for tools + Content: string // "tool_name. description..." + Embedding: []float32 // 384-dimensional vector + Metadata: map[string]string // {"type": "backend_tool", "server_id": "github", "data": "...JSON..."} +} +``` + +**Collections**: +- `backend_servers`: Server metadata (3 documents in typical setup) +- `backend_tools`: Tool metadata and embeddings (40+ documents) + +## Known Limitations + +1. **Scale**: Optimized for <100,000 tools (more than sufficient for typical vMCP deployments) +2. **Approximate Search**: chromem-go uses exhaustive search (not HNSW), but this is fine for our scale +3. **Persistence Format**: Binary gob format (not human-readable) + +## License + +This package is part of ToolHive and follows the same license. diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_server.go b/cmd/thv-operator/pkg/optimizer/db/backend_server.go new file mode 100644 index 0000000000..296969f07d --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/backend_server.go @@ -0,0 +1,243 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package db provides chromem-go based database operations for the optimizer. +package db + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/philippgille/chromem-go" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" + "github.com/stacklok/toolhive/pkg/logger" +) + +// BackendServerOps provides operations for backend servers in chromem-go +type BackendServerOps struct { + db *DB + embeddingFunc chromem.EmbeddingFunc +} + +// NewBackendServerOps creates a new BackendServerOps instance +func NewBackendServerOps(db *DB, embeddingFunc chromem.EmbeddingFunc) *BackendServerOps { + return &BackendServerOps{ + db: db, + embeddingFunc: embeddingFunc, + } +} + +// Create adds a new backend server to the collection +func (ops *BackendServerOps) Create(ctx context.Context, server *models.BackendServer) error { + collection, err := ops.db.GetOrCreateCollection(ctx, BackendServerCollection, ops.embeddingFunc) + if err != nil { + return fmt.Errorf("failed to get backend server collection: %w", err) + } + + // Prepare content for embedding (name + description) + content := server.Name + if server.Description != nil && *server.Description != "" { + content += ". " + *server.Description + } + + // Serialize metadata + metadata, err := serializeServerMetadata(server) + if err != nil { + return fmt.Errorf("failed to serialize server metadata: %w", err) + } + + // Create document + doc := chromem.Document{ + ID: server.ID, + Content: content, + Metadata: metadata, + } + + // If embedding is provided, use it + if len(server.ServerEmbedding) > 0 { + doc.Embedding = server.ServerEmbedding + } + + // Add document to chromem-go collection + err = collection.AddDocument(ctx, doc) + if err != nil { + return fmt.Errorf("failed to add server document to chromem-go: %w", err) + } + + // Also add to FTS5 database if available (for keyword filtering) + // Use background context to avoid cancellation issues - FTS5 is supplementary + if ftsDB := ops.db.GetFTSDB(); ftsDB != nil { + // Use background context with timeout for FTS operations + // This ensures FTS operations complete even if the original context is canceled + ftsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := ftsDB.UpsertServer(ftsCtx, server); err != nil { + // Log but don't fail - FTS5 is supplementary + logger.Warnf("Failed to upsert server to FTS5: %v", err) + } + } + + logger.Debugf("Created backend server: %s (chromem-go + FTS5)", server.ID) + return nil +} + +// Get retrieves a backend server by ID +func (ops *BackendServerOps) Get(ctx context.Context, serverID string) (*models.BackendServer, error) { + collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc) + if err != nil { + return nil, fmt.Errorf("backend server collection not found: %w", err) + } + + // Query by ID with exact match + results, err := collection.Query(ctx, serverID, 1, nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to query server: %w", err) + } + + if len(results) == 0 { + return nil, fmt.Errorf("server not found: %s", serverID) + } + + // Deserialize from metadata + server, err := deserializeServerMetadata(results[0].Metadata) + if err != nil { + return nil, fmt.Errorf("failed to deserialize server: %w", err) + } + + return server, nil +} + +// Update updates an existing backend server +func (ops *BackendServerOps) Update(ctx context.Context, server *models.BackendServer) error { + // chromem-go doesn't have an update operation, so we delete and re-create + err := ops.Delete(ctx, server.ID) + if err != nil { + // If server doesn't exist, that's fine + logger.Debugf("Server %s not found for update, will create new", server.ID) + } + + return ops.Create(ctx, server) +} + +// Delete removes a backend server +func (ops *BackendServerOps) Delete(ctx context.Context, serverID string) error { + collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist, nothing to delete + return nil + } + + err = collection.Delete(ctx, nil, nil, serverID) + if err != nil { + return fmt.Errorf("failed to delete server from chromem-go: %w", err) + } + + // Also delete from FTS5 database if available + if ftsDB := ops.db.GetFTSDB(); ftsDB != nil { + if err := ftsDB.DeleteServer(ctx, serverID); err != nil { + // Log but don't fail + logger.Warnf("Failed to delete server from FTS5: %v", err) + } + } + + logger.Debugf("Deleted backend server: %s (chromem-go + FTS5)", serverID) + return nil +} + +// List returns all backend servers +func (ops *BackendServerOps) List(ctx context.Context) ([]*models.BackendServer, error) { + collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist yet, return empty list + return []*models.BackendServer{}, nil + } + + // Get count to determine nResults + count := collection.Count() + if count == 0 { + return []*models.BackendServer{}, nil + } + + // Query with a generic term to get all servers + // Using "server" as a generic query that should match all servers + results, err := collection.Query(ctx, "server", count, nil, nil) + if err != nil { + return []*models.BackendServer{}, nil + } + + servers := make([]*models.BackendServer, 0, len(results)) + for _, result := range results { + server, err := deserializeServerMetadata(result.Metadata) + if err != nil { + logger.Warnf("Failed to deserialize server: %v", err) + continue + } + servers = append(servers, server) + } + + return servers, nil +} + +// Search performs semantic search for backend servers +func (ops *BackendServerOps) Search(ctx context.Context, query string, limit int) ([]*models.BackendServer, error) { + collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc) + if err != nil { + return []*models.BackendServer{}, nil + } + + // Get collection count and adjust limit if necessary + count := collection.Count() + if count == 0 { + return []*models.BackendServer{}, nil + } + if limit > count { + limit = count + } + + results, err := collection.Query(ctx, query, limit, nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to search servers: %w", err) + } + + servers := make([]*models.BackendServer, 0, len(results)) + for _, result := range results { + server, err := deserializeServerMetadata(result.Metadata) + if err != nil { + logger.Warnf("Failed to deserialize server: %v", err) + continue + } + servers = append(servers, server) + } + + return servers, nil +} + +// Helper functions for metadata serialization + +func serializeServerMetadata(server *models.BackendServer) (map[string]string, error) { + data, err := json.Marshal(server) + if err != nil { + return nil, err + } + return map[string]string{ + "data": string(data), + "type": "backend_server", + }, nil +} + +func deserializeServerMetadata(metadata map[string]string) (*models.BackendServer, error) { + data, ok := metadata["data"] + if !ok { + return nil, fmt.Errorf("missing data field in metadata") + } + + var server models.BackendServer + if err := json.Unmarshal([]byte(data), &server); err != nil { + return nil, err + } + + return &server, nil +} diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_server_test.go b/cmd/thv-operator/pkg/optimizer/db/backend_server_test.go new file mode 100644 index 0000000000..9cc9a8aa43 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/backend_server_test.go @@ -0,0 +1,427 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" +) + +// TestBackendServerOps_Create tests creating a backend server +func TestBackendServerOps_Create(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + description := "A test MCP server" + server := &models.BackendServer{ + ID: "server-1", + Name: "Test Server", + Description: &description, + Group: "default", + } + + err := ops.Create(ctx, server) + require.NoError(t, err) + + // Verify server was created by retrieving it + retrieved, err := ops.Get(ctx, "server-1") + require.NoError(t, err) + assert.Equal(t, "Test Server", retrieved.Name) + assert.Equal(t, "server-1", retrieved.ID) + assert.Equal(t, description, *retrieved.Description) +} + +// TestBackendServerOps_CreateWithEmbedding tests creating server with precomputed embedding +func TestBackendServerOps_CreateWithEmbedding(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + description := "Server with embedding" + embedding := make([]float32, 384) + for i := range embedding { + embedding[i] = 0.5 + } + + server := &models.BackendServer{ + ID: "server-2", + Name: "Embedded Server", + Description: &description, + Group: "default", + ServerEmbedding: embedding, + } + + err := ops.Create(ctx, server) + require.NoError(t, err) + + // Verify server was created + retrieved, err := ops.Get(ctx, "server-2") + require.NoError(t, err) + assert.Equal(t, "Embedded Server", retrieved.Name) +} + +// TestBackendServerOps_Get tests retrieving a backend server +func TestBackendServerOps_Get(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Create a server first + description := "GitHub MCP server" + server := &models.BackendServer{ + ID: "github-server", + Name: "GitHub", + Description: &description, + Group: "development", + } + + err := ops.Create(ctx, server) + require.NoError(t, err) + + // Test Get + retrieved, err := ops.Get(ctx, "github-server") + require.NoError(t, err) + assert.Equal(t, "github-server", retrieved.ID) + assert.Equal(t, "GitHub", retrieved.Name) + assert.Equal(t, "development", retrieved.Group) +} + +// TestBackendServerOps_Get_NotFound tests retrieving non-existent server +func TestBackendServerOps_Get_NotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Try to get a non-existent server + _, err := ops.Get(ctx, "non-existent") + assert.Error(t, err) + // Error message could be "server not found" or "collection not found" depending on state + assert.True(t, err != nil, "Should return an error for non-existent server") +} + +// TestBackendServerOps_Update tests updating a backend server +func TestBackendServerOps_Update(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Create initial server + description := "Original description" + server := &models.BackendServer{ + ID: "server-1", + Name: "Original Name", + Description: &description, + Group: "default", + } + + err := ops.Create(ctx, server) + require.NoError(t, err) + + // Update the server + updatedDescription := "Updated description" + server.Name = "Updated Name" + server.Description = &updatedDescription + + err = ops.Update(ctx, server) + require.NoError(t, err) + + // Verify update + retrieved, err := ops.Get(ctx, "server-1") + require.NoError(t, err) + assert.Equal(t, "Updated Name", retrieved.Name) + assert.Equal(t, "Updated description", *retrieved.Description) +} + +// TestBackendServerOps_Update_NonExistent tests updating non-existent server +func TestBackendServerOps_Update_NonExistent(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Try to update non-existent server (should create it) + description := "New server" + server := &models.BackendServer{ + ID: "new-server", + Name: "New Server", + Description: &description, + Group: "default", + } + + err := ops.Update(ctx, server) + require.NoError(t, err) + + // Verify server was created + retrieved, err := ops.Get(ctx, "new-server") + require.NoError(t, err) + assert.Equal(t, "New Server", retrieved.Name) +} + +// TestBackendServerOps_Delete tests deleting a backend server +func TestBackendServerOps_Delete(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Create a server + description := "Server to delete" + server := &models.BackendServer{ + ID: "delete-me", + Name: "Delete Me", + Description: &description, + Group: "default", + } + + err := ops.Create(ctx, server) + require.NoError(t, err) + + // Delete the server + err = ops.Delete(ctx, "delete-me") + require.NoError(t, err) + + // Verify deletion + _, err = ops.Get(ctx, "delete-me") + assert.Error(t, err, "Should not find deleted server") +} + +// TestBackendServerOps_Delete_NonExistent tests deleting non-existent server +func TestBackendServerOps_Delete_NonExistent(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Try to delete a non-existent server - should not error + err := ops.Delete(ctx, "non-existent") + assert.NoError(t, err) +} + +// TestBackendServerOps_List tests listing all servers +func TestBackendServerOps_List(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Create multiple servers + desc1 := "Server 1" + server1 := &models.BackendServer{ + ID: "server-1", + Name: "Server 1", + Description: &desc1, + Group: "group-a", + } + + desc2 := "Server 2" + server2 := &models.BackendServer{ + ID: "server-2", + Name: "Server 2", + Description: &desc2, + Group: "group-b", + } + + desc3 := "Server 3" + server3 := &models.BackendServer{ + ID: "server-3", + Name: "Server 3", + Description: &desc3, + Group: "group-a", + } + + err := ops.Create(ctx, server1) + require.NoError(t, err) + err = ops.Create(ctx, server2) + require.NoError(t, err) + err = ops.Create(ctx, server3) + require.NoError(t, err) + + // List all servers + servers, err := ops.List(ctx) + require.NoError(t, err) + assert.Len(t, servers, 3, "Should have 3 servers") + + // Verify server names + serverNames := make(map[string]bool) + for _, server := range servers { + serverNames[server.Name] = true + } + assert.True(t, serverNames["Server 1"]) + assert.True(t, serverNames["Server 2"]) + assert.True(t, serverNames["Server 3"]) +} + +// TestBackendServerOps_List_Empty tests listing servers on empty database +func TestBackendServerOps_List_Empty(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // List empty database + servers, err := ops.List(ctx) + require.NoError(t, err) + assert.Empty(t, servers, "Should return empty list for empty database") +} + +// TestBackendServerOps_Search tests semantic search for servers +func TestBackendServerOps_Search(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Create test servers + desc1 := "GitHub integration server" + server1 := &models.BackendServer{ + ID: "github", + Name: "GitHub Server", + Description: &desc1, + Group: "vcs", + } + + desc2 := "Slack messaging server" + server2 := &models.BackendServer{ + ID: "slack", + Name: "Slack Server", + Description: &desc2, + Group: "messaging", + } + + err := ops.Create(ctx, server1) + require.NoError(t, err) + err = ops.Create(ctx, server2) + require.NoError(t, err) + + // Search for servers + results, err := ops.Search(ctx, "integration", 5) + require.NoError(t, err) + assert.NotEmpty(t, results, "Should find servers") +} + +// TestBackendServerOps_Search_Empty tests search on empty database +func TestBackendServerOps_Search_Empty(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Search empty database + results, err := ops.Search(ctx, "anything", 5) + require.NoError(t, err) + assert.Empty(t, results, "Should return empty results for empty database") +} + +// TestBackendServerOps_MetadataSerialization tests metadata serialization/deserialization +func TestBackendServerOps_MetadataSerialization(t *testing.T) { + t.Parallel() + + description := "Test server" + server := &models.BackendServer{ + ID: "server-1", + Name: "Test Server", + Description: &description, + Group: "default", + } + + // Test serialization + metadata, err := serializeServerMetadata(server) + require.NoError(t, err) + assert.Contains(t, metadata, "data") + assert.Equal(t, "backend_server", metadata["type"]) + + // Test deserialization + deserializedServer, err := deserializeServerMetadata(metadata) + require.NoError(t, err) + assert.Equal(t, server.ID, deserializedServer.ID) + assert.Equal(t, server.Name, deserializedServer.Name) + assert.Equal(t, server.Group, deserializedServer.Group) +} + +// TestBackendServerOps_MetadataDeserialization_MissingData tests error handling +func TestBackendServerOps_MetadataDeserialization_MissingData(t *testing.T) { + t.Parallel() + + // Test with missing data field + metadata := map[string]string{ + "type": "backend_server", + } + + _, err := deserializeServerMetadata(metadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing data field") +} + +// TestBackendServerOps_MetadataDeserialization_InvalidJSON tests invalid JSON handling +func TestBackendServerOps_MetadataDeserialization_InvalidJSON(t *testing.T) { + t.Parallel() + + // Test with invalid JSON + metadata := map[string]string{ + "data": "invalid json {", + "type": "backend_server", + } + + _, err := deserializeServerMetadata(metadata) + assert.Error(t, err) +} diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_server_test_coverage.go b/cmd/thv-operator/pkg/optimizer/db/backend_server_test_coverage.go new file mode 100644 index 0000000000..055b6a3353 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/backend_server_test_coverage.go @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" +) + +// TestBackendServerOps_Create_FTS tests FTS integration in Create +func TestBackendServerOps_Create_FTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + ops := NewBackendServerOps(db, embeddingFunc) + + server := &models.BackendServer{ + ID: "server-1", + Name: "Test Server", + Description: stringPtr("A test server"), + Group: "default", + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create should also update FTS + err = ops.Create(ctx, server) + require.NoError(t, err) + + // Verify FTS was updated by checking FTS DB directly + ftsDB := db.GetFTSDB() + require.NotNil(t, ftsDB) + + // FTS should have the server + // We can't easily query FTS directly, but we can verify it doesn't error +} + +// TestBackendServerOps_Delete_FTS tests FTS integration in Delete +func TestBackendServerOps_Delete_FTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + ops := NewBackendServerOps(db, embeddingFunc) + + desc := "A test server" + server := &models.BackendServer{ + ID: "server-1", + Name: "Test Server", + Description: &desc, + Group: "default", + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create server + err = ops.Create(ctx, server) + require.NoError(t, err) + + // Delete should also delete from FTS + err = ops.Delete(ctx, server.ID) + require.NoError(t, err) +} diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_tool.go b/cmd/thv-operator/pkg/optimizer/db/backend_tool.go new file mode 100644 index 0000000000..3dfa860f1a --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/backend_tool.go @@ -0,0 +1,319 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/philippgille/chromem-go" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" + "github.com/stacklok/toolhive/pkg/logger" +) + +// BackendToolOps provides operations for backend tools in chromem-go +type BackendToolOps struct { + db *DB + embeddingFunc chromem.EmbeddingFunc +} + +// NewBackendToolOps creates a new BackendToolOps instance +func NewBackendToolOps(db *DB, embeddingFunc chromem.EmbeddingFunc) *BackendToolOps { + return &BackendToolOps{ + db: db, + embeddingFunc: embeddingFunc, + } +} + +// Create adds a new backend tool to the collection +func (ops *BackendToolOps) Create(ctx context.Context, tool *models.BackendTool, serverName string) error { + collection, err := ops.db.GetOrCreateCollection(ctx, BackendToolCollection, ops.embeddingFunc) + if err != nil { + return fmt.Errorf("failed to get backend tool collection: %w", err) + } + + // Prepare content for embedding (name + description + input schema summary) + content := tool.ToolName + if tool.Description != nil && *tool.Description != "" { + content += ". " + *tool.Description + } + + // Serialize metadata + metadata, err := serializeToolMetadata(tool) + if err != nil { + return fmt.Errorf("failed to serialize tool metadata: %w", err) + } + + // Create document + doc := chromem.Document{ + ID: tool.ID, + Content: content, + Metadata: metadata, + } + + // If embedding is provided, use it + if len(tool.ToolEmbedding) > 0 { + doc.Embedding = tool.ToolEmbedding + } + + // Add document to chromem-go collection + err = collection.AddDocument(ctx, doc) + if err != nil { + return fmt.Errorf("failed to add tool document to chromem-go: %w", err) + } + + // Also add to FTS5 database if available (for BM25 search) + // Use background context to avoid cancellation issues - FTS5 is supplementary + if ops.db.fts != nil { + // Use background context with timeout for FTS operations + // This ensures FTS operations complete even if the original context is canceled + ftsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := ops.db.fts.UpsertToolMeta(ftsCtx, tool, serverName); err != nil { + // Log but don't fail - FTS5 is supplementary + logger.Warnf("Failed to upsert tool to FTS5: %v", err) + } + } + + logger.Debugf("Created backend tool: %s (chromem-go + FTS5)", tool.ID) + return nil +} + +// Get retrieves a backend tool by ID +func (ops *BackendToolOps) Get(ctx context.Context, toolID string) (*models.BackendTool, error) { + collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + return nil, fmt.Errorf("backend tool collection not found: %w", err) + } + + // Query by ID with exact match + results, err := collection.Query(ctx, toolID, 1, nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to query tool: %w", err) + } + + if len(results) == 0 { + return nil, fmt.Errorf("tool not found: %s", toolID) + } + + // Deserialize from metadata + tool, err := deserializeToolMetadata(results[0].Metadata) + if err != nil { + return nil, fmt.Errorf("failed to deserialize tool: %w", err) + } + + return tool, nil +} + +// Update updates an existing backend tool in chromem-go +// Note: This only updates chromem-go, not FTS5. Use Create to update both. +func (ops *BackendToolOps) Update(ctx context.Context, tool *models.BackendTool) error { + collection, err := ops.db.GetOrCreateCollection(ctx, BackendToolCollection, ops.embeddingFunc) + if err != nil { + return fmt.Errorf("failed to get backend tool collection: %w", err) + } + + // Prepare content for embedding + content := tool.ToolName + if tool.Description != nil && *tool.Description != "" { + content += ". " + *tool.Description + } + + // Serialize metadata + metadata, err := serializeToolMetadata(tool) + if err != nil { + return fmt.Errorf("failed to serialize tool metadata: %w", err) + } + + // Delete existing document + _ = collection.Delete(ctx, nil, nil, tool.ID) // Ignore error if doesn't exist + + // Create updated document + doc := chromem.Document{ + ID: tool.ID, + Content: content, + Metadata: metadata, + } + + if len(tool.ToolEmbedding) > 0 { + doc.Embedding = tool.ToolEmbedding + } + + err = collection.AddDocument(ctx, doc) + if err != nil { + return fmt.Errorf("failed to update tool document: %w", err) + } + + logger.Debugf("Updated backend tool: %s", tool.ID) + return nil +} + +// Delete removes a backend tool +func (ops *BackendToolOps) Delete(ctx context.Context, toolID string) error { + collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist, nothing to delete + return nil + } + + err = collection.Delete(ctx, nil, nil, toolID) + if err != nil { + return fmt.Errorf("failed to delete tool: %w", err) + } + + logger.Debugf("Deleted backend tool: %s", toolID) + return nil +} + +// DeleteByServer removes all tools for a given server from both chromem-go and FTS5 +func (ops *BackendToolOps) DeleteByServer(ctx context.Context, serverID string) error { + collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist, nothing to delete in chromem-go + logger.Debug("Backend tool collection not found, skipping chromem-go deletion") + } else { + // Query all tools for this server + tools, err := ops.ListByServer(ctx, serverID) + if err != nil { + return fmt.Errorf("failed to list tools for server: %w", err) + } + + // Delete each tool from chromem-go + for _, tool := range tools { + if err := collection.Delete(ctx, nil, nil, tool.ID); err != nil { + logger.Warnf("Failed to delete tool %s from chromem-go: %v", tool.ID, err) + } + } + + logger.Debugf("Deleted %d tools from chromem-go for server: %s", len(tools), serverID) + } + + // Also delete from FTS5 database if available + if ops.db.fts != nil { + if err := ops.db.fts.DeleteToolsByServer(ctx, serverID); err != nil { + logger.Warnf("Failed to delete tools from FTS5 for server %s: %v", serverID, err) + } else { + logger.Debugf("Deleted tools from FTS5 for server: %s", serverID) + } + } + + return nil +} + +// ListByServer returns all tools for a given server +func (ops *BackendToolOps) ListByServer(ctx context.Context, serverID string) ([]*models.BackendTool, error) { + collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist yet, return empty list + return []*models.BackendTool{}, nil + } + + // Get count to determine nResults + count := collection.Count() + if count == 0 { + return []*models.BackendTool{}, nil + } + + // Query with a generic term and metadata filter + // Using "tool" as a generic query that should match all tools + results, err := collection.Query(ctx, "tool", count, map[string]string{"server_id": serverID}, nil) + if err != nil { + // If no tools match, return empty list + return []*models.BackendTool{}, nil + } + + tools := make([]*models.BackendTool, 0, len(results)) + for _, result := range results { + tool, err := deserializeToolMetadata(result.Metadata) + if err != nil { + logger.Warnf("Failed to deserialize tool: %v", err) + continue + } + tools = append(tools, tool) + } + + return tools, nil +} + +// Search performs semantic search for backend tools +func (ops *BackendToolOps) Search( + ctx context.Context, + query string, + limit int, + serverID *string, +) ([]*models.BackendToolWithMetadata, error) { + collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + return []*models.BackendToolWithMetadata{}, nil + } + + // Get collection count and adjust limit if necessary + count := collection.Count() + if count == 0 { + return []*models.BackendToolWithMetadata{}, nil + } + if limit > count { + limit = count + } + + // Build metadata filter if server ID is provided + var metadataFilter map[string]string + if serverID != nil { + metadataFilter = map[string]string{"server_id": *serverID} + } + + results, err := collection.Query(ctx, query, limit, metadataFilter, nil) + if err != nil { + return nil, fmt.Errorf("failed to search tools: %w", err) + } + + tools := make([]*models.BackendToolWithMetadata, 0, len(results)) + for _, result := range results { + tool, err := deserializeToolMetadata(result.Metadata) + if err != nil { + logger.Warnf("Failed to deserialize tool: %v", err) + continue + } + + // Add similarity score + toolWithMeta := &models.BackendToolWithMetadata{ + BackendTool: *tool, + Similarity: result.Similarity, + } + tools = append(tools, toolWithMeta) + } + + return tools, nil +} + +// Helper functions for metadata serialization + +func serializeToolMetadata(tool *models.BackendTool) (map[string]string, error) { + data, err := json.Marshal(tool) + if err != nil { + return nil, err + } + return map[string]string{ + "data": string(data), + "type": "backend_tool", + "server_id": tool.MCPServerID, + }, nil +} + +func deserializeToolMetadata(metadata map[string]string) (*models.BackendTool, error) { + data, ok := metadata["data"] + if !ok { + return nil, fmt.Errorf("missing data field in metadata") + } + + var tool models.BackendTool + if err := json.Unmarshal([]byte(data), &tool); err != nil { + return nil, err + } + + return &tool, nil +} diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_tool_test.go b/cmd/thv-operator/pkg/optimizer/db/backend_tool_test.go new file mode 100644 index 0000000000..4f9a58b01e --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/backend_tool_test.go @@ -0,0 +1,590 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" +) + +// createTestDB creates a test database +func createTestDB(t *testing.T) *DB { + t.Helper() + tmpDir := t.TempDir() + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + } + + db, err := NewDB(config) + require.NoError(t, err) + + return db +} + +// createTestEmbeddingFunc creates a test embedding function using Ollama embeddings +func createTestEmbeddingFunc(t *testing.T) func(ctx context.Context, text string) ([]float32, error) { + t.Helper() + + // Try to use Ollama if available, otherwise skip test + config := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + manager, err := embeddings.NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return nil + } + t.Cleanup(func() { _ = manager.Close() }) + + return func(_ context.Context, text string) ([]float32, error) { + results, err := manager.GenerateEmbedding([]string{text}) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, assert.AnError + } + return results[0], nil + } +} + +// TestBackendToolOps_Create tests creating a backend tool +func TestBackendToolOps_Create(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + description := "Get current weather information" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "get_weather", + Description: &description, + InputSchema: []byte(`{"type":"object","properties":{"location":{"type":"string"}}}`), + TokenCount: 100, + } + + err := ops.Create(ctx, tool, "Test Server") + require.NoError(t, err) + + // Verify tool was created by retrieving it + retrieved, err := ops.Get(ctx, "tool-1") + require.NoError(t, err) + assert.Equal(t, "get_weather", retrieved.ToolName) + assert.Equal(t, "server-1", retrieved.MCPServerID) + assert.Equal(t, description, *retrieved.Description) +} + +// TestBackendToolOps_CreateWithPrecomputedEmbedding tests creating tool with existing embedding +func TestBackendToolOps_CreateWithPrecomputedEmbedding(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + description := "Search the web" + // Generate a precomputed embedding + precomputedEmbedding := make([]float32, 384) + for i := range precomputedEmbedding { + precomputedEmbedding[i] = 0.1 + } + + tool := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "search_web", + Description: &description, + InputSchema: []byte(`{}`), + ToolEmbedding: precomputedEmbedding, + TokenCount: 50, + } + + err := ops.Create(ctx, tool, "Test Server") + require.NoError(t, err) + + // Verify tool was created + retrieved, err := ops.Get(ctx, "tool-2") + require.NoError(t, err) + assert.Equal(t, "search_web", retrieved.ToolName) +} + +// TestBackendToolOps_Get tests retrieving a backend tool +func TestBackendToolOps_Get(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create a tool first + description := "Send an email" + tool := &models.BackendTool{ + ID: "tool-3", + MCPServerID: "server-1", + ToolName: "send_email", + Description: &description, + InputSchema: []byte(`{}`), + TokenCount: 75, + } + + err := ops.Create(ctx, tool, "Test Server") + require.NoError(t, err) + + // Test Get + retrieved, err := ops.Get(ctx, "tool-3") + require.NoError(t, err) + assert.Equal(t, "tool-3", retrieved.ID) + assert.Equal(t, "send_email", retrieved.ToolName) +} + +// TestBackendToolOps_Get_NotFound tests retrieving non-existent tool +func TestBackendToolOps_Get_NotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Try to get a non-existent tool + _, err := ops.Get(ctx, "non-existent") + assert.Error(t, err) +} + +// TestBackendToolOps_Update tests updating a backend tool +func TestBackendToolOps_Update(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create initial tool + description := "Original description" + tool := &models.BackendTool{ + ID: "tool-4", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: &description, + InputSchema: []byte(`{}`), + TokenCount: 50, + } + + err := ops.Create(ctx, tool, "Test Server") + require.NoError(t, err) + + // Update the tool + const updatedDescription = "Updated description" + updatedDescriptionCopy := updatedDescription + tool.Description = &updatedDescriptionCopy + tool.TokenCount = 75 + + err = ops.Update(ctx, tool) + require.NoError(t, err) + + // Verify update + retrieved, err := ops.Get(ctx, "tool-4") + require.NoError(t, err) + assert.Equal(t, "Updated description", *retrieved.Description) +} + +// TestBackendToolOps_Delete tests deleting a backend tool +func TestBackendToolOps_Delete(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create a tool + description := "Tool to delete" + tool := &models.BackendTool{ + ID: "tool-5", + MCPServerID: "server-1", + ToolName: "delete_me", + Description: &description, + InputSchema: []byte(`{}`), + TokenCount: 25, + } + + err := ops.Create(ctx, tool, "Test Server") + require.NoError(t, err) + + // Delete the tool + err = ops.Delete(ctx, "tool-5") + require.NoError(t, err) + + // Verify deletion + _, err = ops.Get(ctx, "tool-5") + assert.Error(t, err, "Should not find deleted tool") +} + +// TestBackendToolOps_Delete_NonExistent tests deleting non-existent tool +func TestBackendToolOps_Delete_NonExistent(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Try to delete a non-existent tool - should not error + err := ops.Delete(ctx, "non-existent") + // Delete may or may not error depending on implementation + // Just ensure it doesn't panic + _ = err +} + +// TestBackendToolOps_ListByServer tests listing tools for a server +func TestBackendToolOps_ListByServer(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create multiple tools for different servers + desc1 := "Tool 1" + tool1 := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "tool_1", + Description: &desc1, + InputSchema: []byte(`{}`), + TokenCount: 10, + } + + desc2 := "Tool 2" + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "tool_2", + Description: &desc2, + InputSchema: []byte(`{}`), + TokenCount: 20, + } + + desc3 := "Tool 3" + tool3 := &models.BackendTool{ + ID: "tool-3", + MCPServerID: "server-2", + ToolName: "tool_3", + Description: &desc3, + InputSchema: []byte(`{}`), + TokenCount: 30, + } + + err := ops.Create(ctx, tool1, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool2, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool3, "Server 2") + require.NoError(t, err) + + // List tools for server-1 + tools, err := ops.ListByServer(ctx, "server-1") + require.NoError(t, err) + assert.Len(t, tools, 2, "Should have 2 tools for server-1") + + // Verify tool names + toolNames := make(map[string]bool) + for _, tool := range tools { + toolNames[tool.ToolName] = true + } + assert.True(t, toolNames["tool_1"]) + assert.True(t, toolNames["tool_2"]) + + // List tools for server-2 + tools, err = ops.ListByServer(ctx, "server-2") + require.NoError(t, err) + assert.Len(t, tools, 1, "Should have 1 tool for server-2") + assert.Equal(t, "tool_3", tools[0].ToolName) +} + +// TestBackendToolOps_ListByServer_Empty tests listing tools for server with no tools +func TestBackendToolOps_ListByServer_Empty(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // List tools for non-existent server + tools, err := ops.ListByServer(ctx, "non-existent-server") + require.NoError(t, err) + assert.Empty(t, tools, "Should return empty list for server with no tools") +} + +// TestBackendToolOps_DeleteByServer tests deleting all tools for a server +func TestBackendToolOps_DeleteByServer(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create tools for two servers + desc1 := "Tool 1" + tool1 := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "tool_1", + Description: &desc1, + InputSchema: []byte(`{}`), + TokenCount: 10, + } + + desc2 := "Tool 2" + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "tool_2", + Description: &desc2, + InputSchema: []byte(`{}`), + TokenCount: 20, + } + + desc3 := "Tool 3" + tool3 := &models.BackendTool{ + ID: "tool-3", + MCPServerID: "server-2", + ToolName: "tool_3", + Description: &desc3, + InputSchema: []byte(`{}`), + TokenCount: 30, + } + + err := ops.Create(ctx, tool1, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool2, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool3, "Server 2") + require.NoError(t, err) + + // Delete all tools for server-1 + err = ops.DeleteByServer(ctx, "server-1") + require.NoError(t, err) + + // Verify server-1 tools are deleted + tools, err := ops.ListByServer(ctx, "server-1") + require.NoError(t, err) + assert.Empty(t, tools, "All server-1 tools should be deleted") + + // Verify server-2 tools are still present + tools, err = ops.ListByServer(ctx, "server-2") + require.NoError(t, err) + assert.Len(t, tools, 1, "Server-2 tools should remain") +} + +// TestBackendToolOps_Search tests semantic search for tools +func TestBackendToolOps_Search(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create test tools + desc1 := "Get current weather conditions" + tool1 := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "get_weather", + Description: &desc1, + InputSchema: []byte(`{}`), + TokenCount: 50, + } + + desc2 := "Send email message" + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "send_email", + Description: &desc2, + InputSchema: []byte(`{}`), + TokenCount: 40, + } + + err := ops.Create(ctx, tool1, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool2, "Server 1") + require.NoError(t, err) + + // Search for tools + results, err := ops.Search(ctx, "weather information", 5, nil) + require.NoError(t, err) + assert.NotEmpty(t, results, "Should find tools") + + // Weather tool should be most similar to weather query + assert.NotEmpty(t, results, "Should find at least one tool") + if len(results) > 0 { + assert.Equal(t, "get_weather", results[0].ToolName, + "Weather tool should be most similar to weather query") + } +} + +// TestBackendToolOps_Search_WithServerFilter tests search with server ID filter +func TestBackendToolOps_Search_WithServerFilter(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create tools for different servers + desc1 := "Weather tool" + tool1 := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "get_weather", + Description: &desc1, + InputSchema: []byte(`{}`), + TokenCount: 50, + } + + desc2 := "Email tool" + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-2", + ToolName: "send_email", + Description: &desc2, + InputSchema: []byte(`{}`), + TokenCount: 40, + } + + err := ops.Create(ctx, tool1, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool2, "Server 2") + require.NoError(t, err) + + // Search with server filter + serverID := "server-1" + results, err := ops.Search(ctx, "tool", 5, &serverID) + require.NoError(t, err) + assert.Len(t, results, 1, "Should only return tools from server-1") + assert.Equal(t, "server-1", results[0].MCPServerID) +} + +// TestBackendToolOps_Search_Empty tests search on empty database +func TestBackendToolOps_Search_Empty(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Search empty database + results, err := ops.Search(ctx, "anything", 5, nil) + require.NoError(t, err) + assert.Empty(t, results, "Should return empty results for empty database") +} + +// TestBackendToolOps_MetadataSerialization tests metadata serialization/deserialization +func TestBackendToolOps_MetadataSerialization(t *testing.T) { + t.Parallel() + + description := "Test tool" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: &description, + InputSchema: []byte(`{"type":"object"}`), + TokenCount: 100, + } + + // Test serialization + metadata, err := serializeToolMetadata(tool) + require.NoError(t, err) + assert.Contains(t, metadata, "data") + assert.Equal(t, "backend_tool", metadata["type"]) + assert.Equal(t, "server-1", metadata["server_id"]) + + // Test deserialization + deserializedTool, err := deserializeToolMetadata(metadata) + require.NoError(t, err) + assert.Equal(t, tool.ID, deserializedTool.ID) + assert.Equal(t, tool.ToolName, deserializedTool.ToolName) + assert.Equal(t, tool.MCPServerID, deserializedTool.MCPServerID) +} + +// TestBackendToolOps_MetadataDeserialization_MissingData tests error handling +func TestBackendToolOps_MetadataDeserialization_MissingData(t *testing.T) { + t.Parallel() + + // Test with missing data field + metadata := map[string]string{ + "type": "backend_tool", + } + + _, err := deserializeToolMetadata(metadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing data field") +} + +// TestBackendToolOps_MetadataDeserialization_InvalidJSON tests invalid JSON handling +func TestBackendToolOps_MetadataDeserialization_InvalidJSON(t *testing.T) { + t.Parallel() + + // Test with invalid JSON + metadata := map[string]string{ + "data": "invalid json {", + "type": "backend_tool", + } + + _, err := deserializeToolMetadata(metadata) + assert.Error(t, err) +} diff --git a/cmd/thv-operator/pkg/optimizer/db/backend_tool_test_coverage.go b/cmd/thv-operator/pkg/optimizer/db/backend_tool_test_coverage.go new file mode 100644 index 0000000000..1e3c7b7e84 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/backend_tool_test_coverage.go @@ -0,0 +1,99 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" +) + +// TestBackendToolOps_Create_FTS tests FTS integration in Create +func TestBackendToolOps_Create_FTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + ops := NewBackendToolOps(db, embeddingFunc) + + desc := "A test tool" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: &desc, + InputSchema: []byte(`{"type": "object"}`), + TokenCount: 10, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create should also update FTS + err = ops.Create(ctx, tool, "TestServer") + require.NoError(t, err) + + // Verify FTS was updated + ftsDB := db.GetFTSDB() + require.NotNil(t, ftsDB) +} + +// TestBackendToolOps_DeleteByServer_FTS tests FTS integration in DeleteByServer +func TestBackendToolOps_DeleteByServer_FTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + ops := NewBackendToolOps(db, embeddingFunc) + + desc := "A test tool" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: &desc, + InputSchema: []byte(`{"type": "object"}`), + TokenCount: 10, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create tool + err = ops.Create(ctx, tool, "TestServer") + require.NoError(t, err) + + // DeleteByServer should also delete from FTS + err = ops.DeleteByServer(ctx, "server-1") + require.NoError(t, err) +} diff --git a/cmd/thv-operator/pkg/optimizer/db/db.go b/cmd/thv-operator/pkg/optimizer/db/db.go new file mode 100644 index 0000000000..1e850309ed --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/db.go @@ -0,0 +1,215 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + + "github.com/philippgille/chromem-go" + + "github.com/stacklok/toolhive/pkg/logger" +) + +// Config holds database configuration +// +// The optimizer database is designed to be ephemeral - it's rebuilt from scratch +// on each startup by ingesting MCP backends. Persistence is optional and primarily +// useful for development/debugging to avoid re-generating embeddings. +type Config struct { + // PersistPath is the optional path for chromem-go persistence. + // If empty, chromem-go will be in-memory only (recommended for production). + PersistPath string + + // FTSDBPath is the path for SQLite FTS5 database for BM25 search. + // If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + // FTS5 is always enabled for hybrid search. + FTSDBPath string +} + +// DB represents the hybrid database (chromem-go + SQLite FTS5) for optimizer data +type DB struct { + config *Config + chromem *chromem.DB // Vector/semantic search + fts *FTSDatabase // BM25 full-text search (optional) + mu sync.RWMutex +} + +// Collection names +// +// Terminology: We use "backend_servers" and "backend_tools" to be explicit about +// tracking MCP server metadata. While vMCP uses "Backend" for the workload concept, +// the optimizer focuses on the MCP server component for semantic search and tool discovery. +// This naming convention provides clarity across the database layer. +const ( + BackendServerCollection = "backend_servers" + BackendToolCollection = "backend_tools" +) + +// NewDB creates a new chromem-go database with FTS5 for hybrid search +func NewDB(config *Config) (*DB, error) { + var chromemDB *chromem.DB + var err error + + if config.PersistPath != "" { + logger.Infof("Creating chromem-go database with persistence at: %s", config.PersistPath) + chromemDB, err = chromem.NewPersistentDB(config.PersistPath, false) + if err != nil { + // Check if error is due to corrupted database (missing collection metadata) + if strings.Contains(err.Error(), "collection metadata file not found") { + logger.Warnf("Database appears corrupted, attempting to remove and recreate: %v", err) + // Try to remove corrupted database directory + // Use RemoveAll which should handle directories recursively + // If it fails, we'll try to create with a new path or fall back to in-memory + if removeErr := os.RemoveAll(config.PersistPath); removeErr != nil { + logger.Warnf("Failed to remove corrupted database directory (may be in use): %v. Will try to recreate anyway.", removeErr) + // Try to rename the corrupted directory and create a new one + backupPath := config.PersistPath + ".corrupted" + if renameErr := os.Rename(config.PersistPath, backupPath); renameErr != nil { + logger.Warnf("Failed to rename corrupted database: %v. Attempting to create database anyway.", renameErr) + // Continue and let chromem-go handle it - it might work if the corruption is partial + } else { + logger.Infof("Renamed corrupted database to: %s", backupPath) + } + } + // Retry creating the database + chromemDB, err = chromem.NewPersistentDB(config.PersistPath, false) + if err != nil { + // If still failing, return the error but suggest manual cleanup + return nil, fmt.Errorf( + "failed to create persistent database after cleanup attempt. Please manually remove %s and try again: %w", + config.PersistPath, err) + } + logger.Info("Successfully recreated database after cleanup") + } else { + return nil, fmt.Errorf("failed to create persistent database: %w", err) + } + } + } else { + logger.Info("Creating in-memory chromem-go database") + chromemDB = chromem.NewDB() + } + + db := &DB{ + config: config, + chromem: chromemDB, + } + + // Set default FTS5 path if not provided + ftsPath := config.FTSDBPath + if ftsPath == "" { + if config.PersistPath != "" { + // Persistent mode: store FTS5 alongside chromem-go + ftsPath = config.PersistPath + "/fts.db" + } else { + // In-memory mode: use SQLite in-memory database + ftsPath = ":memory:" + } + } + + // Initialize FTS5 database for BM25 text search (always enabled) + logger.Infof("Initializing FTS5 database for hybrid search at: %s", ftsPath) + ftsDB, err := NewFTSDatabase(&FTSConfig{DBPath: ftsPath}) + if err != nil { + return nil, fmt.Errorf("failed to create FTS5 database: %w", err) + } + db.fts = ftsDB + logger.Info("Hybrid search enabled (chromem-go + FTS5)") + + logger.Info("Optimizer database initialized successfully") + return db, nil +} + +// GetOrCreateCollection gets an existing collection or creates a new one +func (db *DB) GetOrCreateCollection( + _ context.Context, + name string, + embeddingFunc chromem.EmbeddingFunc, +) (*chromem.Collection, error) { + db.mu.Lock() + defer db.mu.Unlock() + + // Try to get existing collection first + collection := db.chromem.GetCollection(name, embeddingFunc) + if collection != nil { + return collection, nil + } + + // Create new collection if it doesn't exist + collection, err := db.chromem.CreateCollection(name, nil, embeddingFunc) + if err != nil { + return nil, fmt.Errorf("failed to create collection %s: %w", name, err) + } + + logger.Debugf("Created new collection: %s", name) + return collection, nil +} + +// GetCollection gets an existing collection +func (db *DB) GetCollection(name string, embeddingFunc chromem.EmbeddingFunc) (*chromem.Collection, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + collection := db.chromem.GetCollection(name, embeddingFunc) + if collection == nil { + return nil, fmt.Errorf("collection not found: %s", name) + } + return collection, nil +} + +// DeleteCollection deletes a collection +func (db *DB) DeleteCollection(name string) { + db.mu.Lock() + defer db.mu.Unlock() + + //nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error + db.chromem.DeleteCollection(name) + logger.Debugf("Deleted collection: %s", name) +} + +// Close closes both databases +func (db *DB) Close() error { + logger.Info("Closing optimizer databases") + // chromem-go doesn't need explicit close, but FTS5 does + if db.fts != nil { + if err := db.fts.Close(); err != nil { + return fmt.Errorf("failed to close FTS database: %w", err) + } + } + return nil +} + +// GetChromemDB returns the underlying chromem.DB instance +func (db *DB) GetChromemDB() *chromem.DB { + return db.chromem +} + +// GetFTSDB returns the FTS database (may be nil if FTS is disabled) +func (db *DB) GetFTSDB() *FTSDatabase { + return db.fts +} + +// Reset clears all collections and FTS tables (useful for testing and startup) +func (db *DB) Reset() { + db.mu.Lock() + defer db.mu.Unlock() + + //nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error + db.chromem.DeleteCollection(BackendServerCollection) + //nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error + db.chromem.DeleteCollection(BackendToolCollection) + + // Clear FTS5 tables if available + if db.fts != nil { + //nolint:errcheck // Best effort cleanup + _, _ = db.fts.db.Exec("DELETE FROM backend_tools_fts") + //nolint:errcheck // Best effort cleanup + _, _ = db.fts.db.Exec("DELETE FROM backend_servers_fts") + } + + logger.Debug("Reset all collections and FTS tables") +} diff --git a/cmd/thv-operator/pkg/optimizer/db/db_test.go b/cmd/thv-operator/pkg/optimizer/db/db_test.go new file mode 100644 index 0000000000..4eb98daaeb --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/db_test.go @@ -0,0 +1,305 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestNewDB_CorruptedDatabase tests database recovery from corruption +func TestNewDB_CorruptedDatabase(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "corrupted-db") + + // Create a directory that looks like a corrupted database + err := os.MkdirAll(dbPath, 0755) + require.NoError(t, err) + + // Create a file that might cause issues + err = os.WriteFile(filepath.Join(dbPath, "some-file"), []byte("corrupted"), 0644) + require.NoError(t, err) + + config := &Config{ + PersistPath: dbPath, + } + + // Should recover from corruption + db, err := NewDB(config) + require.NoError(t, err) + require.NotNil(t, db) + defer func() { _ = db.Close() }() +} + +// TestNewDB_CorruptedDatabase_RecoveryFailure tests when recovery fails +func TestNewDB_CorruptedDatabase_RecoveryFailure(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "corrupted-db") + + // Create a directory that looks like a corrupted database + err := os.MkdirAll(dbPath, 0755) + require.NoError(t, err) + + // Create a file that might cause issues + err = os.WriteFile(filepath.Join(dbPath, "some-file"), []byte("corrupted"), 0644) + require.NoError(t, err) + + // Make directory read-only to simulate recovery failure + // Note: This might not work on all systems, so we'll test the error path differently + // Instead, we'll test with an invalid path that can't be created + config := &Config{ + PersistPath: "/invalid/path/that/does/not/exist", + } + + _, err = NewDB(config) + // Should return error for invalid path + assert.Error(t, err) +} + +// TestDB_GetOrCreateCollection tests collection creation and retrieval +func TestDB_GetOrCreateCollection(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + // Create a simple embedding function + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Get or create collection + collection, err := db.GetOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + require.NotNil(t, collection) + + // Get existing collection + collection2, err := db.GetOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + require.NotNil(t, collection2) + assert.Equal(t, collection, collection2) +} + +// TestDB_GetCollection tests collection retrieval +func TestDB_GetCollection(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Get non-existent collection should fail + _, err = db.GetCollection("non-existent", embeddingFunc) + assert.Error(t, err) + + // Create collection first + _, err = db.GetOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + + // Now get it + collection, err := db.GetCollection("test-collection", embeddingFunc) + require.NoError(t, err) + require.NotNil(t, collection) +} + +// TestDB_DeleteCollection tests collection deletion +func TestDB_DeleteCollection(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Create collection + _, err = db.GetOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + + // Delete collection + db.DeleteCollection("test-collection") + + // Verify it's deleted + _, err = db.GetCollection("test-collection", embeddingFunc) + assert.Error(t, err) +} + +// TestDB_Reset tests database reset +func TestDB_Reset(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Create collections + _, err = db.GetOrCreateCollection(ctx, BackendServerCollection, embeddingFunc) + require.NoError(t, err) + + _, err = db.GetOrCreateCollection(ctx, BackendToolCollection, embeddingFunc) + require.NoError(t, err) + + // Reset database + db.Reset() + + // Verify collections are deleted + _, err = db.GetCollection(BackendServerCollection, embeddingFunc) + assert.Error(t, err) + + _, err = db.GetCollection(BackendToolCollection, embeddingFunc) + assert.Error(t, err) +} + +// TestDB_GetChromemDB tests chromem DB accessor +func TestDB_GetChromemDB(t *testing.T) { + t.Parallel() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + chromemDB := db.GetChromemDB() + require.NotNil(t, chromemDB) +} + +// TestDB_GetFTSDB tests FTS DB accessor +func TestDB_GetFTSDB(t *testing.T) { + t.Parallel() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + ftsDB := db.GetFTSDB() + require.NotNil(t, ftsDB) +} + +// TestDB_Close tests database closing +func TestDB_Close(t *testing.T) { + t.Parallel() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + + err = db.Close() + require.NoError(t, err) + + // Multiple closes should be safe + err = db.Close() + require.NoError(t, err) +} + +// TestNewDB_FTSDBPath tests FTS database path configuration +func TestNewDB_FTSDBPath(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + tests := []struct { + name string + config *Config + wantErr bool + }{ + { + name: "in-memory FTS with persistent chromem", + config: &Config{ + PersistPath: filepath.Join(tmpDir, "db"), + FTSDBPath: ":memory:", + }, + wantErr: false, + }, + { + name: "persistent FTS with persistent chromem", + config: &Config{ + PersistPath: filepath.Join(tmpDir, "db2"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + }, + wantErr: false, + }, + { + name: "default FTS path with persistent chromem", + config: &Config{ + PersistPath: filepath.Join(tmpDir, "db3"), + // FTSDBPath not set, should default to {PersistPath}/fts.db + }, + wantErr: false, + }, + { + name: "in-memory FTS with in-memory chromem", + config: &Config{ + PersistPath: "", + FTSDBPath: ":memory:", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db, err := NewDB(tt.config) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, db) + defer func() { _ = db.Close() }() + + // Verify FTS DB is accessible + ftsDB := db.GetFTSDB() + require.NotNil(t, ftsDB) + } + }) + } +} diff --git a/cmd/thv-operator/pkg/optimizer/db/fts.go b/cmd/thv-operator/pkg/optimizer/db/fts.go new file mode 100644 index 0000000000..2f444cfae0 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/fts.go @@ -0,0 +1,360 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "database/sql" + _ "embed" + "fmt" + "strings" + "sync" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" + "github.com/stacklok/toolhive/pkg/logger" +) + +//go:embed schema_fts.sql +var schemaFTS string + +// FTSConfig holds FTS5 database configuration +type FTSConfig struct { + // DBPath is the path to the SQLite database file + // If empty, uses ":memory:" for in-memory database + DBPath string +} + +// FTSDatabase handles FTS5 (BM25) search operations +type FTSDatabase struct { + config *FTSConfig + db *sql.DB + mu sync.RWMutex +} + +// NewFTSDatabase creates a new FTS5 database for BM25 search +func NewFTSDatabase(config *FTSConfig) (*FTSDatabase, error) { + dbPath := config.DBPath + if dbPath == "" { + dbPath = ":memory:" + } + + // Open with modernc.org/sqlite (pure Go) + sqlDB, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, fmt.Errorf("failed to open FTS database: %w", err) + } + + // Set pragmas for performance + pragmas := []string{ + "PRAGMA journal_mode=WAL", + "PRAGMA synchronous=NORMAL", + "PRAGMA foreign_keys=ON", + "PRAGMA busy_timeout=5000", + } + + for _, pragma := range pragmas { + if _, err := sqlDB.Exec(pragma); err != nil { + _ = sqlDB.Close() + return nil, fmt.Errorf("failed to set pragma: %w", err) + } + } + + ftsDB := &FTSDatabase{ + config: config, + db: sqlDB, + } + + // Initialize schema + if err := ftsDB.initializeSchema(); err != nil { + _ = sqlDB.Close() + return nil, fmt.Errorf("failed to initialize FTS schema: %w", err) + } + + logger.Infof("FTS5 database initialized successfully at: %s", dbPath) + return ftsDB, nil +} + +// initializeSchema creates the FTS5 tables and triggers +// +// Note: We execute the schema directly rather than using a migration framework +// because the FTS database is ephemeral (destroyed on shutdown, recreated on startup). +// Migrations are only needed when you need to preserve data across schema changes. +func (fts *FTSDatabase) initializeSchema() error { + fts.mu.Lock() + defer fts.mu.Unlock() + + _, err := fts.db.Exec(schemaFTS) + if err != nil { + return fmt.Errorf("failed to execute schema: %w", err) + } + + logger.Debug("FTS5 schema initialized") + return nil +} + +// UpsertServer inserts or updates a server in the FTS database +func (fts *FTSDatabase) UpsertServer( + ctx context.Context, + server *models.BackendServer, +) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + query := ` + INSERT INTO backend_servers_fts (id, name, description, server_group, last_updated, created_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + name = excluded.name, + description = excluded.description, + server_group = excluded.server_group, + last_updated = excluded.last_updated + ` + + _, err := fts.db.ExecContext( + ctx, + query, + server.ID, + server.Name, + server.Description, + server.Group, + server.LastUpdated, + server.CreatedAt, + ) + + if err != nil { + return fmt.Errorf("failed to upsert server in FTS: %w", err) + } + + logger.Debugf("Upserted server in FTS: %s", server.ID) + return nil +} + +// UpsertToolMeta inserts or updates a tool in the FTS database +func (fts *FTSDatabase) UpsertToolMeta( + ctx context.Context, + tool *models.BackendTool, + _ string, // serverName - unused, keeping for interface compatibility +) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + // Convert input schema to JSON string + var schemaStr *string + if len(tool.InputSchema) > 0 { + str := string(tool.InputSchema) + schemaStr = &str + } + + query := ` + INSERT INTO backend_tools_fts ( + id, mcpserver_id, tool_name, tool_description, + input_schema, token_count, last_updated, created_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + mcpserver_id = excluded.mcpserver_id, + tool_name = excluded.tool_name, + tool_description = excluded.tool_description, + input_schema = excluded.input_schema, + token_count = excluded.token_count, + last_updated = excluded.last_updated + ` + + _, err := fts.db.ExecContext( + ctx, + query, + tool.ID, + tool.MCPServerID, + tool.ToolName, + tool.Description, + schemaStr, + tool.TokenCount, + tool.LastUpdated, + tool.CreatedAt, + ) + + if err != nil { + return fmt.Errorf("failed to upsert tool in FTS: %w", err) + } + + logger.Debugf("Upserted tool in FTS: %s", tool.ToolName) + return nil +} + +// DeleteServer removes a server and its tools from FTS database +func (fts *FTSDatabase) DeleteServer(ctx context.Context, serverID string) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + // Foreign key cascade will delete related tools + _, err := fts.db.ExecContext(ctx, "DELETE FROM backend_servers_fts WHERE id = ?", serverID) + if err != nil { + return fmt.Errorf("failed to delete server from FTS: %w", err) + } + + logger.Debugf("Deleted server from FTS: %s", serverID) + return nil +} + +// DeleteToolsByServer removes all tools for a server from FTS database +func (fts *FTSDatabase) DeleteToolsByServer(ctx context.Context, serverID string) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + result, err := fts.db.ExecContext(ctx, "DELETE FROM backend_tools_fts WHERE mcpserver_id = ?", serverID) + if err != nil { + return fmt.Errorf("failed to delete tools from FTS: %w", err) + } + + count, _ := result.RowsAffected() + logger.Debugf("Deleted %d tools from FTS for server: %s", count, serverID) + return nil +} + +// DeleteTool removes a tool from FTS database +func (fts *FTSDatabase) DeleteTool(ctx context.Context, toolID string) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + _, err := fts.db.ExecContext(ctx, "DELETE FROM backend_tools_fts WHERE id = ?", toolID) + if err != nil { + return fmt.Errorf("failed to delete tool from FTS: %w", err) + } + + logger.Debugf("Deleted tool from FTS: %s", toolID) + return nil +} + +// SearchBM25 performs BM25 full-text search on tools +func (fts *FTSDatabase) SearchBM25( + ctx context.Context, + query string, + limit int, + serverID *string, +) ([]*models.BackendToolWithMetadata, error) { + fts.mu.RLock() + defer fts.mu.RUnlock() + + // Sanitize FTS5 query + sanitizedQuery := sanitizeFTS5Query(query) + if sanitizedQuery == "" { + return []*models.BackendToolWithMetadata{}, nil + } + + // Build query with optional server filter + sqlQuery := ` + SELECT + t.id, + t.mcpserver_id, + t.tool_name, + t.tool_description, + t.input_schema, + t.token_count, + t.last_updated, + t.created_at, + fts.rank + FROM backend_tool_fts_index fts + JOIN backend_tools_fts t ON fts.tool_id = t.id + WHERE backend_tool_fts_index MATCH ? + ` + + args := []interface{}{sanitizedQuery} + + if serverID != nil { + sqlQuery += " AND t.mcpserver_id = ?" + args = append(args, *serverID) + } + + sqlQuery += " ORDER BY rank LIMIT ?" + args = append(args, limit) + + rows, err := fts.db.QueryContext(ctx, sqlQuery, args...) + if err != nil { + return nil, fmt.Errorf("failed to search tools: %w", err) + } + defer func() { _ = rows.Close() }() + + var results []*models.BackendToolWithMetadata + for rows.Next() { + var tool models.BackendTool + var schemaStr sql.NullString + var rank float32 + + err := rows.Scan( + &tool.ID, + &tool.MCPServerID, + &tool.ToolName, + &tool.Description, + &schemaStr, + &tool.TokenCount, + &tool.LastUpdated, + &tool.CreatedAt, + &rank, + ) + if err != nil { + logger.Warnf("Failed to scan tool row: %v", err) + continue + } + + if schemaStr.Valid { + tool.InputSchema = []byte(schemaStr.String) + } + + // Convert BM25 rank to similarity score (higher is better) + // FTS5 rank is negative, so we negate and normalize + similarity := float32(1.0 / (1.0 - float64(rank))) + + results = append(results, &models.BackendToolWithMetadata{ + BackendTool: tool, + Similarity: similarity, + }) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating tool rows: %w", err) + } + + logger.Debugf("BM25 search found %d tools for query: %s", len(results), query) + return results, nil +} + +// GetTotalToolTokens returns the sum of token_count across all tools +func (fts *FTSDatabase) GetTotalToolTokens(ctx context.Context) (int, error) { + fts.mu.RLock() + defer fts.mu.RUnlock() + + var totalTokens int + query := "SELECT COALESCE(SUM(token_count), 0) FROM backend_tools_fts" + + err := fts.db.QueryRowContext(ctx, query).Scan(&totalTokens) + if err != nil { + return 0, fmt.Errorf("failed to get total tool tokens: %w", err) + } + + return totalTokens, nil +} + +// Close closes the FTS database connection +func (fts *FTSDatabase) Close() error { + return fts.db.Close() +} + +// sanitizeFTS5Query escapes special characters in FTS5 queries +// FTS5 uses: " * ( ) AND OR NOT +func sanitizeFTS5Query(query string) string { + // Remove or escape special FTS5 characters + replacer := strings.NewReplacer( + `"`, `""`, // Escape quotes + `*`, ` `, // Remove wildcards + `(`, ` `, // Remove parentheses + `)`, ` `, + ) + + sanitized := replacer.Replace(query) + + // Remove multiple spaces + sanitized = strings.Join(strings.Fields(sanitized), " ") + + return strings.TrimSpace(sanitized) +} diff --git a/cmd/thv-operator/pkg/optimizer/db/fts_test_coverage.go b/cmd/thv-operator/pkg/optimizer/db/fts_test_coverage.go new file mode 100644 index 0000000000..b4b1911b93 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/fts_test_coverage.go @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" +) + +// stringPtr returns a pointer to the given string +func stringPtr(s string) *string { + return &s +} + +// TestFTSDatabase_GetTotalToolTokens tests token counting +func TestFTSDatabase_GetTotalToolTokens(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &FTSConfig{ + DBPath: ":memory:", + } + + ftsDB, err := NewFTSDatabase(config) + require.NoError(t, err) + defer func() { _ = ftsDB.Close() }() + + // Initially should be 0 + totalTokens, err := ftsDB.GetTotalToolTokens(ctx) + require.NoError(t, err) + assert.Equal(t, 0, totalTokens) + + // Add a tool + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: stringPtr("Test tool"), + TokenCount: 100, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + err = ftsDB.UpsertToolMeta(ctx, tool, "TestServer") + require.NoError(t, err) + + // Should now have tokens + totalTokens, err = ftsDB.GetTotalToolTokens(ctx) + require.NoError(t, err) + assert.Equal(t, 100, totalTokens) + + // Add another tool + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "test_tool2", + Description: stringPtr("Test tool 2"), + TokenCount: 50, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + err = ftsDB.UpsertToolMeta(ctx, tool2, "TestServer") + require.NoError(t, err) + + // Should sum tokens + totalTokens, err = ftsDB.GetTotalToolTokens(ctx) + require.NoError(t, err) + assert.Equal(t, 150, totalTokens) +} + +// TestSanitizeFTS5Query tests query sanitization +func TestSanitizeFTS5Query(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "remove quotes", + input: `"test query"`, + expected: "test query", + }, + { + name: "remove wildcards", + input: "test*query", + expected: "test query", + }, + { + name: "remove parentheses", + input: "test(query)", + expected: "test query", + }, + { + name: "remove multiple spaces", + input: "test query", + expected: "test query", + }, + { + name: "trim whitespace", + input: " test query ", + expected: "test query", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "only special characters", + input: `"*()`, + expected: "", + }, + { + name: "mixed special characters", + input: `test"query*with(special)chars`, + expected: "test query with special chars", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := sanitizeFTS5Query(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestFTSDatabase_SearchBM25_EmptyQuery tests empty query handling +func TestFTSDatabase_SearchBM25_EmptyQuery(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &FTSConfig{ + DBPath: ":memory:", + } + + ftsDB, err := NewFTSDatabase(config) + require.NoError(t, err) + defer func() { _ = ftsDB.Close() }() + + // Empty query should return empty results + results, err := ftsDB.SearchBM25(ctx, "", 10, nil) + require.NoError(t, err) + assert.Empty(t, results) + + // Query with only special characters should return empty results + results, err = ftsDB.SearchBM25(ctx, `"*()`, 10, nil) + require.NoError(t, err) + assert.Empty(t, results) +} diff --git a/cmd/thv-operator/pkg/optimizer/db/hybrid.go b/cmd/thv-operator/pkg/optimizer/db/hybrid.go new file mode 100644 index 0000000000..27df70d696 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/hybrid.go @@ -0,0 +1,172 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "fmt" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" + "github.com/stacklok/toolhive/pkg/logger" +) + +// HybridSearchConfig configures hybrid search behavior +type HybridSearchConfig struct { + // SemanticRatio controls the mix of semantic vs BM25 results (0-100, representing percentage) + // Default: 70 (70% semantic, 30% BM25) + SemanticRatio int + + // Limit is the total number of results to return + Limit int + + // ServerID optionally filters results to a specific server + ServerID *string +} + +// DefaultHybridConfig returns sensible defaults for hybrid search +func DefaultHybridConfig() *HybridSearchConfig { + return &HybridSearchConfig{ + SemanticRatio: 70, + Limit: 10, + } +} + +// SearchHybrid performs hybrid search combining semantic (chromem-go) and BM25 (FTS5) results +// This matches the Python mcp-optimizer's hybrid search implementation +func (ops *BackendToolOps) SearchHybrid( + ctx context.Context, + queryText string, + config *HybridSearchConfig, +) ([]*models.BackendToolWithMetadata, error) { + if config == nil { + config = DefaultHybridConfig() + } + + // Calculate limits for each search method + // Convert percentage to ratio (0-100 -> 0.0-1.0) + semanticRatioFloat := float64(config.SemanticRatio) / 100.0 + semanticLimit := max(1, int(float64(config.Limit)*semanticRatioFloat)) + bm25Limit := max(1, config.Limit-semanticLimit) + + logger.Debugf( + "Hybrid search: semantic_limit=%d, bm25_limit=%d, ratio=%d%%", + semanticLimit, bm25Limit, config.SemanticRatio, + ) + + // Execute both searches in parallel + type searchResult struct { + results []*models.BackendToolWithMetadata + err error + } + + semanticCh := make(chan searchResult, 1) + bm25Ch := make(chan searchResult, 1) + + // Semantic search + go func() { + results, err := ops.Search(ctx, queryText, semanticLimit, config.ServerID) + semanticCh <- searchResult{results, err} + }() + + // BM25 search + go func() { + results, err := ops.db.fts.SearchBM25(ctx, queryText, bm25Limit, config.ServerID) + bm25Ch <- searchResult{results, err} + }() + + // Collect results + var semanticResults, bm25Results []*models.BackendToolWithMetadata + var errs []error + + // Wait for semantic results + semanticRes := <-semanticCh + if semanticRes.err != nil { + logger.Warnf("Semantic search failed: %v", semanticRes.err) + errs = append(errs, semanticRes.err) + } else { + semanticResults = semanticRes.results + } + + // Wait for BM25 results + bm25Res := <-bm25Ch + if bm25Res.err != nil { + logger.Warnf("BM25 search failed: %v", bm25Res.err) + errs = append(errs, bm25Res.err) + } else { + bm25Results = bm25Res.results + } + + // If both failed, return error + if len(errs) == 2 { + return nil, fmt.Errorf("both search methods failed: semantic=%v, bm25=%v", errs[0], errs[1]) + } + + // Combine and deduplicate results + combined := combineAndDeduplicateResults(semanticResults, bm25Results, config.Limit) + + logger.Infof( + "Hybrid search completed: semantic=%d, bm25=%d, combined=%d (requested=%d)", + len(semanticResults), len(bm25Results), len(combined), config.Limit, + ) + + return combined, nil +} + +// combineAndDeduplicateResults merges semantic and BM25 results, removing duplicates +// Keeps the result with the higher similarity score for duplicates +func combineAndDeduplicateResults( + semantic, bm25 []*models.BackendToolWithMetadata, + limit int, +) []*models.BackendToolWithMetadata { + // Use a map to deduplicate by tool ID + seen := make(map[string]*models.BackendToolWithMetadata) + + // Add semantic results first (they typically have higher quality) + for _, result := range semantic { + seen[result.ID] = result + } + + // Add BM25 results, only if not seen or if similarity is higher + for _, result := range bm25 { + if existing, exists := seen[result.ID]; exists { + // Keep the one with higher similarity + if result.Similarity > existing.Similarity { + seen[result.ID] = result + } + } else { + seen[result.ID] = result + } + } + + // Convert map to slice + combined := make([]*models.BackendToolWithMetadata, 0, len(seen)) + for _, result := range seen { + combined = append(combined, result) + } + + // Sort by similarity (descending) and limit + sortedResults := sortBySimilarity(combined) + if len(sortedResults) > limit { + sortedResults = sortedResults[:limit] + } + + return sortedResults +} + +// sortBySimilarity sorts results by similarity score in descending order +func sortBySimilarity(results []*models.BackendToolWithMetadata) []*models.BackendToolWithMetadata { + // Simple bubble sort (fine for small result sets) + sorted := make([]*models.BackendToolWithMetadata, len(results)) + copy(sorted, results) + + for i := 0; i < len(sorted); i++ { + for j := i + 1; j < len(sorted); j++ { + if sorted[j].Similarity > sorted[i].Similarity { + sorted[i], sorted[j] = sorted[j], sorted[i] + } + } + } + + return sorted +} diff --git a/cmd/thv-operator/pkg/optimizer/db/schema_fts.sql b/cmd/thv-operator/pkg/optimizer/db/schema_fts.sql new file mode 100644 index 0000000000..101dbea7d7 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/schema_fts.sql @@ -0,0 +1,120 @@ +-- FTS5 schema for BM25 full-text search +-- Complements chromem-go (which handles vector/semantic search) +-- +-- This schema only contains: +-- 1. Metadata tables for tool/server information +-- 2. FTS5 virtual tables for BM25 keyword search +-- +-- Note: chromem-go handles embeddings separately in memory/persistent storage + +-- Backend servers metadata (for FTS queries and joining) +CREATE TABLE IF NOT EXISTS backend_servers_fts ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + server_group TEXT NOT NULL DEFAULT 'default', + last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_backend_servers_fts_group ON backend_servers_fts(server_group); + +-- Backend tools metadata (for FTS queries and joining) +CREATE TABLE IF NOT EXISTS backend_tools_fts ( + id TEXT PRIMARY KEY, + mcpserver_id TEXT NOT NULL, + tool_name TEXT NOT NULL, + tool_description TEXT, + input_schema TEXT, -- JSON string + token_count INTEGER NOT NULL DEFAULT 0, + last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (mcpserver_id) REFERENCES backend_servers_fts(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_backend_tools_fts_server ON backend_tools_fts(mcpserver_id); +CREATE INDEX IF NOT EXISTS idx_backend_tools_fts_name ON backend_tools_fts(tool_name); + +-- FTS5 virtual table for backend tools +-- Uses Porter stemming for better keyword matching +-- Indexes: server name, tool name, and tool description +CREATE VIRTUAL TABLE IF NOT EXISTS backend_tool_fts_index +USING fts5( + tool_id UNINDEXED, + mcp_server_name, + tool_name, + tool_description, + tokenize='porter', + content='backend_tools_fts', + content_rowid='rowid' +); + +-- Triggers to keep FTS5 index in sync with backend_tools_fts table +CREATE TRIGGER IF NOT EXISTS backend_tools_fts_ai AFTER INSERT ON backend_tools_fts BEGIN + INSERT INTO backend_tool_fts_index( + rowid, + tool_id, + mcp_server_name, + tool_name, + tool_description + ) + SELECT + rowid, + new.id, + (SELECT name FROM backend_servers_fts WHERE id = new.mcpserver_id), + new.tool_name, + COALESCE(new.tool_description, '') + FROM backend_tools_fts + WHERE id = new.id; +END; + +CREATE TRIGGER IF NOT EXISTS backend_tools_fts_ad AFTER DELETE ON backend_tools_fts BEGIN + INSERT INTO backend_tool_fts_index( + backend_tool_fts_index, + rowid, + tool_id, + mcp_server_name, + tool_name, + tool_description + ) VALUES ( + 'delete', + old.rowid, + old.id, + NULL, + NULL, + NULL + ); +END; + +CREATE TRIGGER IF NOT EXISTS backend_tools_fts_au AFTER UPDATE ON backend_tools_fts BEGIN + INSERT INTO backend_tool_fts_index( + backend_tool_fts_index, + rowid, + tool_id, + mcp_server_name, + tool_name, + tool_description + ) VALUES ( + 'delete', + old.rowid, + old.id, + NULL, + NULL, + NULL + ); + INSERT INTO backend_tool_fts_index( + rowid, + tool_id, + mcp_server_name, + tool_name, + tool_description + ) + SELECT + rowid, + new.id, + (SELECT name FROM backend_servers_fts WHERE id = new.mcpserver_id), + new.tool_name, + COALESCE(new.tool_description, '') + FROM backend_tools_fts + WHERE id = new.id; +END; diff --git a/cmd/thv-operator/pkg/optimizer/db/sqlite_fts.go b/cmd/thv-operator/pkg/optimizer/db/sqlite_fts.go new file mode 100644 index 0000000000..23ae5bcdfb --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/db/sqlite_fts.go @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package db provides database operations for the optimizer. +// This file handles FTS5 (Full-Text Search) using modernc.org/sqlite (pure Go). +package db + +import ( + // Pure Go SQLite driver with FTS5 support + _ "modernc.org/sqlite" +) diff --git a/cmd/thv-operator/pkg/optimizer/doc.go b/cmd/thv-operator/pkg/optimizer/doc.go new file mode 100644 index 0000000000..c59b7556a1 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/doc.go @@ -0,0 +1,88 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package optimizer provides semantic tool discovery and ingestion for MCP servers. +// +// The optimizer package implements an ingestion service that discovers MCP backends +// from ToolHive, generates semantic embeddings for tools using ONNX Runtime, and stores +// them in a SQLite database with vector search capabilities. +// +// # Architecture +// +// The optimizer follows a similar architecture to mcp-optimizer (Python) but adapted +// for Go idioms and patterns: +// +// pkg/optimizer/ +// ├── doc.go // Package documentation +// ├── models/ // Database models and types +// │ ├── models.go // Core domain models (Server, Tool, etc.) +// │ └── transport.go // Transport and status enums +// ├── db/ // Database layer +// │ ├── db.go // Database connection and config +// │ ├── fts.go // FTS5 database for BM25 search +// │ ├── schema_fts.sql // Embedded FTS5 schema (executed directly) +// │ ├── hybrid.go // Hybrid search (semantic + BM25) +// │ ├── backend_server.go // Backend server operations +// │ └── backend_tool.go // Backend tool operations +// ├── embeddings/ // Embedding generation +// │ ├── manager.go // Embedding manager with ONNX Runtime +// │ └── cache.go // Optional embedding cache +// ├── mcpclient/ // MCP client for tool discovery +// │ └── client.go // MCP client wrapper +// ├── ingestion/ // Core ingestion service +// │ ├── service.go // Ingestion service implementation +// │ └── errors.go // Custom errors +// └── tokens/ // Token counting (for LLM consumption) +// └── counter.go // Token counter using tiktoken-go +// +// # Core Concepts +// +// **Ingestion**: Discovers MCP backends from ToolHive (via Docker or Kubernetes), +// connects to each backend to list tools, generates embeddings, and stores in database. +// +// **Embeddings**: Uses ONNX Runtime to generate semantic embeddings for tools and servers. +// Embeddings enable semantic search to find relevant tools based on natural language queries. +// +// **Database**: Hybrid approach using chromem-go for vector search and SQLite FTS5 for +// keyword search. The database is ephemeral (in-memory by default, optional persistence) +// and schema is initialized directly on startup without migrations. +// +// **Terminology**: Uses "BackendServer" and "BackendTool" to explicitly refer to MCP server +// metadata, distinguishing from vMCP's broader "Backend" concept which represents workloads. +// +// **Token Counting**: Tracks token counts for tools to measure LLM consumption and +// calculate token savings from semantic filtering. +// +// # Usage +// +// The optimizer is integrated into vMCP as native tools: +// +// 1. **vMCP Integration**: The optimizer runs as part of vMCP, exposing +// optim.find_tool and optim.call_tool to clients. +// +// 2. **Event-Driven Ingestion**: Tools are ingested when vMCP sessions +// are registered, not via polling. +// +// Example vMCP integration (see pkg/vmcp/optimizer): +// +// import ( +// "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/ingestion" +// "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" +// ) +// +// // Create embedding manager +// embMgr, err := embeddings.NewManager(embeddings.Config{ +// BackendType: "ollama", // or "openai-compatible" or "vllm" +// BaseURL: "http://localhost:11434", +// Model: "all-minilm", +// Dimension: 384, +// }) +// +// // Create ingestion service +// svc, err := ingestion.NewService(ctx, ingestion.Config{ +// DBConfig: dbConfig, +// }, embMgr) +// +// // Ingest a server (called by vMCP's OnRegisterSession hook) +// err = svc.IngestServer(ctx, "weather-service", tools, target) +package optimizer diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/cache.go b/cmd/thv-operator/pkg/optimizer/embeddings/cache.go new file mode 100644 index 0000000000..68f6bbe74b --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/embeddings/cache.go @@ -0,0 +1,104 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package embeddings provides caching for embedding vectors. +package embeddings + +import ( + "container/list" + "sync" +) + +// cache implements an LRU cache for embeddings +type cache struct { + maxSize int + mu sync.RWMutex + items map[string]*list.Element + lru *list.List + hits int64 + misses int64 +} + +type cacheEntry struct { + key string + value []float32 +} + +// newCache creates a new LRU cache +func newCache(maxSize int) *cache { + return &cache{ + maxSize: maxSize, + items: make(map[string]*list.Element), + lru: list.New(), + } +} + +// Get retrieves an embedding from the cache +func (c *cache) Get(key string) []float32 { + c.mu.Lock() + defer c.mu.Unlock() + + elem, ok := c.items[key] + if !ok { + c.misses++ + return nil + } + + c.hits++ + c.lru.MoveToFront(elem) + return elem.Value.(*cacheEntry).value +} + +// Put stores an embedding in the cache +func (c *cache) Put(key string, value []float32) { + c.mu.Lock() + defer c.mu.Unlock() + + // Check if key already exists + if elem, ok := c.items[key]; ok { + c.lru.MoveToFront(elem) + elem.Value.(*cacheEntry).value = value + return + } + + // Add new entry + entry := &cacheEntry{ + key: key, + value: value, + } + elem := c.lru.PushFront(entry) + c.items[key] = elem + + // Evict if necessary + if c.lru.Len() > c.maxSize { + c.evict() + } +} + +// evict removes the least recently used item +func (c *cache) evict() { + elem := c.lru.Back() + if elem != nil { + c.lru.Remove(elem) + entry := elem.Value.(*cacheEntry) + delete(c.items, entry.key) + } +} + +// Size returns the current cache size +func (c *cache) Size() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.lru.Len() +} + +// Clear clears the cache +func (c *cache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + c.items = make(map[string]*list.Element) + c.lru = list.New() + c.hits = 0 + c.misses = 0 +} diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/cache_test.go b/cmd/thv-operator/pkg/optimizer/embeddings/cache_test.go new file mode 100644 index 0000000000..9b16346056 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/embeddings/cache_test.go @@ -0,0 +1,172 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "testing" +) + +func TestCache_GetPut(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Test cache miss + result := c.Get("key1") + if result != nil { + t.Error("Expected cache miss for non-existent key") + } + if c.misses != 1 { + t.Errorf("Expected 1 miss, got %d", c.misses) + } + + // Test cache put and hit + embedding := []float32{1.0, 2.0, 3.0} + c.Put("key1", embedding) + + result = c.Get("key1") + if result == nil { + t.Fatal("Expected cache hit for existing key") + } + if c.hits != 1 { + t.Errorf("Expected 1 hit, got %d", c.hits) + } + + // Verify embedding values + if len(result) != len(embedding) { + t.Errorf("Embedding length mismatch: got %d, want %d", len(result), len(embedding)) + } + for i := range embedding { + if result[i] != embedding[i] { + t.Errorf("Embedding value mismatch at index %d: got %f, want %f", i, result[i], embedding[i]) + } + } +} + +func TestCache_LRUEviction(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Add two items (fills cache) + c.Put("key1", []float32{1.0}) + c.Put("key2", []float32{2.0}) + + if c.Size() != 2 { + t.Errorf("Expected cache size 2, got %d", c.Size()) + } + + // Add third item (should evict key1) + c.Put("key3", []float32{3.0}) + + if c.Size() != 2 { + t.Errorf("Expected cache size 2 after eviction, got %d", c.Size()) + } + + // key1 should be evicted (oldest) + if result := c.Get("key1"); result != nil { + t.Error("key1 should have been evicted") + } + + // key2 and key3 should still exist + if result := c.Get("key2"); result == nil { + t.Error("key2 should still exist") + } + if result := c.Get("key3"); result == nil { + t.Error("key3 should still exist") + } +} + +func TestCache_MoveToFrontOnAccess(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Add two items + c.Put("key1", []float32{1.0}) + c.Put("key2", []float32{2.0}) + + // Access key1 (moves it to front) + c.Get("key1") + + // Add third item (should evict key2, not key1) + c.Put("key3", []float32{3.0}) + + // key1 should still exist (was accessed recently) + if result := c.Get("key1"); result == nil { + t.Error("key1 should still exist (was accessed recently)") + } + + // key2 should be evicted (was oldest) + if result := c.Get("key2"); result != nil { + t.Error("key2 should have been evicted") + } + + // key3 should exist + if result := c.Get("key3"); result == nil { + t.Error("key3 should exist") + } +} + +func TestCache_UpdateExistingKey(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Add initial value + c.Put("key1", []float32{1.0}) + + // Update with new value + newEmbedding := []float32{2.0, 3.0} + c.Put("key1", newEmbedding) + + // Should get updated value + result := c.Get("key1") + if result == nil { + t.Fatal("Expected cache hit for existing key") + } + + if len(result) != len(newEmbedding) { + t.Errorf("Embedding length mismatch: got %d, want %d", len(result), len(newEmbedding)) + } + + // Cache size should still be 1 + if c.Size() != 1 { + t.Errorf("Expected cache size 1, got %d", c.Size()) + } +} + +func TestCache_Clear(t *testing.T) { + t.Parallel() + c := newCache(10) + + // Add some items + c.Put("key1", []float32{1.0}) + c.Put("key2", []float32{2.0}) + c.Put("key3", []float32{3.0}) + + // Access some items to generate stats + c.Get("key1") + c.Get("missing") + + if c.Size() != 3 { + t.Errorf("Expected cache size 3, got %d", c.Size()) + } + + // Clear cache + c.Clear() + + if c.Size() != 0 { + t.Errorf("Expected cache size 0 after clear, got %d", c.Size()) + } + + // Stats should be reset + if c.hits != 0 { + t.Errorf("Expected 0 hits after clear, got %d", c.hits) + } + if c.misses != 0 { + t.Errorf("Expected 0 misses after clear, got %d", c.misses) + } + + // Items should be gone + if result := c.Get("key1"); result != nil { + t.Error("key1 should be gone after clear") + } +} diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/manager.go b/cmd/thv-operator/pkg/optimizer/embeddings/manager.go new file mode 100644 index 0000000000..4f29729e3b --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/embeddings/manager.go @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "fmt" + "strings" + "sync" + + "github.com/stacklok/toolhive/pkg/logger" +) + +const ( + // DefaultModelAllMiniLM is the default Ollama model name + DefaultModelAllMiniLM = "all-minilm" + // BackendTypeOllama is the Ollama backend type + BackendTypeOllama = "ollama" +) + +// Config holds configuration for the embedding manager +type Config struct { + // BackendType specifies which backend to use: + // - "ollama": Ollama native API (default) + // - "vllm": vLLM OpenAI-compatible API + // - "unified": Generic OpenAI-compatible API (works with both) + // - "openai": OpenAI-compatible API + BackendType string + + // BaseURL is the base URL for the embedding service + // - Ollama: http://127.0.0.1:11434 (or http://localhost:11434, will be normalized to 127.0.0.1) + // - vLLM: http://localhost:8000 + BaseURL string + + // Model is the model name to use + // - Ollama: "all-minilm" (default), "nomic-embed-text" + // - vLLM: "sentence-transformers/all-MiniLM-L6-v2", "intfloat/e5-mistral-7b-instruct" + Model string + + // Dimension is the embedding dimension (default 384 for all-MiniLM-L6-v2) + Dimension int + + // EnableCache enables caching of embeddings + EnableCache bool + + // MaxCacheSize is the maximum number of embeddings to cache (default 1000) + MaxCacheSize int +} + +// Backend interface for different embedding implementations +type Backend interface { + Embed(text string) ([]float32, error) + EmbedBatch(texts []string) ([][]float32, error) + Dimension() int + Close() error +} + +// Manager manages embedding generation using pluggable backends +// Default backend is all-MiniLM-L6-v2 (same model as codegate) +type Manager struct { + config *Config + backend Backend + cache *cache + mu sync.RWMutex +} + +// NewManager creates a new embedding manager +func NewManager(config *Config) (*Manager, error) { + if config.Dimension == 0 { + config.Dimension = 384 // Default dimension for all-MiniLM-L6-v2 + } + + if config.MaxCacheSize == 0 { + config.MaxCacheSize = 1000 + } + + // Default to Ollama + if config.BackendType == "" { + config.BackendType = BackendTypeOllama + } + + // Initialize backend based on configuration + var backend Backend + var err error + + switch config.BackendType { + case BackendTypeOllama: + // Use Ollama native API (requires ollama serve) + baseURL := config.BaseURL + if baseURL == "" { + baseURL = "http://127.0.0.1:11434" + } else { + // Normalize localhost to 127.0.0.1 to avoid IPv6 resolution issues + baseURL = strings.ReplaceAll(baseURL, "localhost", "127.0.0.1") + } + model := config.Model + if model == "" { + model = DefaultModelAllMiniLM // Default: all-MiniLM-L6-v2 + } + // Update dimension if not set and using default model + if config.Dimension == 0 && model == DefaultModelAllMiniLM { + config.Dimension = 384 + } + backend, err = NewOllamaBackend(baseURL, model) + if err != nil { + return nil, fmt.Errorf( + "failed to initialize Ollama backend: %w (ensure 'ollama serve' is running and 'ollama pull %s' has been executed)", + err, DefaultModelAllMiniLM) + } + + case "vllm", "unified", "openai": + // Use OpenAI-compatible API + // vLLM is recommended for production Kubernetes deployments (GPU-accelerated, high-throughput) + // Also supports: Ollama v1 API, OpenAI, or any OpenAI-compatible service + if config.BaseURL == "" { + return nil, fmt.Errorf("BaseURL is required for %s backend", config.BackendType) + } + if config.Model == "" { + return nil, fmt.Errorf("model is required for %s backend", config.BackendType) + } + backend, err = NewOpenAICompatibleBackend(config.BaseURL, config.Model, config.Dimension) + if err != nil { + return nil, fmt.Errorf("failed to initialize %s backend: %w", config.BackendType, err) + } + + default: + return nil, fmt.Errorf("unknown backend type: %s (supported: ollama, vllm, unified, openai)", config.BackendType) + } + + m := &Manager{ + config: config, + backend: backend, + } + + if config.EnableCache { + m.cache = newCache(config.MaxCacheSize) + } + + return m, nil +} + +// GenerateEmbedding generates embeddings for the given texts +// Returns a 2D slice where each row is an embedding for the corresponding text +// Uses all-MiniLM-L6-v2 by default (same model as codegate) +func (m *Manager) GenerateEmbedding(texts []string) ([][]float32, error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no texts provided") + } + + // Check cache for single text requests + if len(texts) == 1 && m.config.EnableCache && m.cache != nil { + if cached := m.cache.Get(texts[0]); cached != nil { + logger.Debugf("Cache hit for embedding") + return [][]float32{cached}, nil + } + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Use backend to generate embeddings + embeddings, err := m.backend.EmbedBatch(texts) + if err != nil { + return nil, fmt.Errorf("failed to generate embeddings: %w", err) + } + + // Cache single embeddings + if len(texts) == 1 && m.config.EnableCache && m.cache != nil { + m.cache.Put(texts[0], embeddings[0]) + } + + logger.Debugf("Generated %d embeddings (dimension: %d)", len(embeddings), m.backend.Dimension()) + return embeddings, nil +} + +// GetCacheStats returns cache statistics +func (m *Manager) GetCacheStats() map[string]interface{} { + if !m.config.EnableCache || m.cache == nil { + return map[string]interface{}{ + "enabled": false, + } + } + + return map[string]interface{}{ + "enabled": true, + "hits": m.cache.hits, + "misses": m.cache.misses, + "size": m.cache.Size(), + "maxsize": m.config.MaxCacheSize, + } +} + +// ClearCache clears the embedding cache +func (m *Manager) ClearCache() { + if m.config.EnableCache && m.cache != nil { + m.cache.Clear() + logger.Info("Embedding cache cleared") + } +} + +// Close releases resources +func (m *Manager) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.backend != nil { + return m.backend.Close() + } + + return nil +} + +// Dimension returns the embedding dimension +func (m *Manager) Dimension() int { + if m.backend != nil { + return m.backend.Dimension() + } + return m.config.Dimension +} diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/manager_test_coverage.go b/cmd/thv-operator/pkg/optimizer/embeddings/manager_test_coverage.go new file mode 100644 index 0000000000..529d65ec4c --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/embeddings/manager_test_coverage.go @@ -0,0 +1,158 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestManager_GetCacheStats tests cache statistics +func TestManager_GetCacheStats(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: true, + MaxCacheSize: 100, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + stats := manager.GetCacheStats() + require.NotNil(t, stats) + assert.True(t, stats["enabled"].(bool)) + assert.Contains(t, stats, "hits") + assert.Contains(t, stats, "misses") + assert.Contains(t, stats, "size") + assert.Contains(t, stats, "maxsize") +} + +// TestManager_GetCacheStats_Disabled tests cache statistics when cache is disabled +func TestManager_GetCacheStats_Disabled(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: false, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + stats := manager.GetCacheStats() + require.NotNil(t, stats) + assert.False(t, stats["enabled"].(bool)) +} + +// TestManager_ClearCache tests cache clearing +func TestManager_ClearCache(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: true, + MaxCacheSize: 100, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + // Clear cache should not panic + manager.ClearCache() + + // Multiple clears should be safe + manager.ClearCache() +} + +// TestManager_ClearCache_Disabled tests cache clearing when cache is disabled +func TestManager_ClearCache_Disabled(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: false, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + // Clear cache should not panic even when disabled + manager.ClearCache() +} + +// TestManager_Dimension tests dimension accessor +func TestManager_Dimension(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + dimension := manager.Dimension() + assert.Equal(t, 384, dimension) +} + +// TestManager_Dimension_Default tests default dimension +func TestManager_Dimension_Default(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + // Dimension not set, should default to 384 + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + dimension := manager.Dimension() + assert.Equal(t, 384, dimension) +} diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/ollama.go b/cmd/thv-operator/pkg/optimizer/embeddings/ollama.go new file mode 100644 index 0000000000..6cb6e1f8a2 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/embeddings/ollama.go @@ -0,0 +1,148 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/stacklok/toolhive/pkg/logger" +) + +// OllamaBackend implements the Backend interface using Ollama +// This provides local embeddings without remote API calls +// Ollama must be running locally (ollama serve) +type OllamaBackend struct { + baseURL string + model string + dimension int + client *http.Client +} + +type ollamaEmbedRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` +} + +type ollamaEmbedResponse struct { + Embedding []float64 `json:"embedding"` +} + +// normalizeLocalhostURL converts localhost to 127.0.0.1 to avoid IPv6 resolution issues +func normalizeLocalhostURL(url string) string { + // Replace localhost with 127.0.0.1 to ensure IPv4 connection + // This prevents connection refused errors when Ollama only listens on IPv4 + return strings.ReplaceAll(url, "localhost", "127.0.0.1") +} + +// NewOllamaBackend creates a new Ollama backend +// Requires Ollama to be running locally: ollama serve +// Default model: all-minilm (all-MiniLM-L6-v2, 384 dimensions) +func NewOllamaBackend(baseURL, model string) (*OllamaBackend, error) { + if baseURL == "" { + baseURL = "http://127.0.0.1:11434" + } else { + // Normalize localhost to 127.0.0.1 to avoid IPv6 resolution issues + baseURL = normalizeLocalhostURL(baseURL) + } + if model == "" { + model = "all-minilm" // Default embedding model (all-MiniLM-L6-v2) + } + + logger.Infof("Initializing Ollama backend (model: %s, url: %s)", model, baseURL) + + // Determine dimension based on model + dimension := 384 // Default for all-minilm + if model == "nomic-embed-text" { + dimension = 768 + } + + backend := &OllamaBackend{ + baseURL: baseURL, + model: model, + dimension: dimension, + client: &http.Client{}, + } + + // Test connection + resp, err := backend.client.Get(baseURL) + if err != nil { + return nil, fmt.Errorf("failed to connect to Ollama at %s: %w (is 'ollama serve' running?)", baseURL, err) + } + _ = resp.Body.Close() + + logger.Info("Successfully connected to Ollama") + return backend, nil +} + +// Embed generates an embedding for a single text +func (o *OllamaBackend) Embed(text string) ([]float32, error) { + reqBody := ollamaEmbedRequest{ + Model: o.model, + Prompt: text, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + resp, err := o.client.Post( + o.baseURL+"/api/embeddings", + "application/json", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return nil, fmt.Errorf("failed to call Ollama API: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("ollama API returned status %d: %s", resp.StatusCode, string(body)) + } + + var embedResp ollamaEmbedResponse + if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + // Convert []float64 to []float32 + embedding := make([]float32, len(embedResp.Embedding)) + for i, v := range embedResp.Embedding { + embedding[i] = float32(v) + } + + return embedding, nil +} + +// EmbedBatch generates embeddings for multiple texts +func (o *OllamaBackend) EmbedBatch(texts []string) ([][]float32, error) { + embeddings := make([][]float32, len(texts)) + + for i, text := range texts { + emb, err := o.Embed(text) + if err != nil { + return nil, fmt.Errorf("failed to embed text %d: %w", i, err) + } + embeddings[i] = emb + } + + return embeddings, nil +} + +// Dimension returns the embedding dimension +func (o *OllamaBackend) Dimension() int { + return o.dimension +} + +// Close releases any resources +func (*OllamaBackend) Close() error { + // HTTP client doesn't need explicit cleanup + return nil +} diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/ollama_test.go b/cmd/thv-operator/pkg/optimizer/embeddings/ollama_test.go new file mode 100644 index 0000000000..16d7793e85 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/embeddings/ollama_test.go @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "testing" +) + +func TestOllamaBackend_ConnectionFailure(t *testing.T) { + t.Parallel() + // This test verifies that Ollama backend handles connection failures gracefully + + // Test that NewOllamaBackend handles connection failure gracefully + _, err := NewOllamaBackend("http://localhost:99999", "all-minilm") + if err == nil { + t.Error("Expected error when connecting to invalid Ollama URL") + } +} + +func TestManagerWithOllama(t *testing.T) { + t.Parallel() + // Test that Manager works with Ollama when available + config := &Config{ + BackendType: BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: DefaultModelAllMiniLM, + Dimension: 768, + EnableCache: true, + MaxCacheSize: 100, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + defer manager.Close() + + // Test single embedding + embeddings, err := manager.GenerateEmbedding([]string{"test text"}) + if err != nil { + // Model might not be pulled - skip gracefully + t.Skipf("Skipping test: Failed to generate embedding. Error: %v. Run 'ollama pull nomic-embed-text'", err) + return + } + + if len(embeddings) != 1 { + t.Errorf("Expected 1 embedding, got %d", len(embeddings)) + } + + // Ollama all-minilm uses 384 dimensions + if len(embeddings[0]) != 384 { + t.Errorf("Expected dimension 384, got %d", len(embeddings[0])) + } + + // Test batch embeddings + texts := []string{"text 1", "text 2", "text 3"} + embeddings, err = manager.GenerateEmbedding(texts) + if err != nil { + // Model might not be pulled - skip gracefully + t.Skipf("Skipping test: Failed to generate batch embeddings. Error: %v. Run 'ollama pull nomic-embed-text'", err) + return + } + + if len(embeddings) != 3 { + t.Errorf("Expected 3 embeddings, got %d", len(embeddings)) + } +} diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible.go b/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible.go new file mode 100644 index 0000000000..c98adba54a --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible.go @@ -0,0 +1,152 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/stacklok/toolhive/pkg/logger" +) + +// OpenAICompatibleBackend implements the Backend interface for OpenAI-compatible APIs. +// +// Supported Services: +// - vLLM: Recommended for production Kubernetes deployments +// - High-throughput GPU-accelerated inference +// - PagedAttention for efficient GPU memory utilization +// - Superior scalability for multi-user environments +// - Ollama: Good for local development (via /v1/embeddings endpoint) +// - OpenAI: For cloud-based embeddings +// - Any OpenAI-compatible embedding service +// +// For production deployments, vLLM is strongly recommended due to its performance +// characteristics and Kubernetes-native design. +type OpenAICompatibleBackend struct { + baseURL string + model string + dimension int + client *http.Client +} + +type openaiEmbedRequest struct { + Model string `json:"model"` + Input string `json:"input"` // OpenAI standard uses "input" +} + +type openaiEmbedResponse struct { + Object string `json:"object"` + Data []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + Model string `json:"model"` +} + +// NewOpenAICompatibleBackend creates a new OpenAI-compatible backend. +// +// Examples: +// - vLLM: NewOpenAICompatibleBackend("http://vllm-service:8000", "sentence-transformers/all-MiniLM-L6-v2", 384) +// - Ollama: NewOpenAICompatibleBackend("http://localhost:11434", "nomic-embed-text", 768) +// - OpenAI: NewOpenAICompatibleBackend("https://api.openai.com", "text-embedding-3-small", 1536) +func NewOpenAICompatibleBackend(baseURL, model string, dimension int) (*OpenAICompatibleBackend, error) { + if baseURL == "" { + return nil, fmt.Errorf("baseURL is required for OpenAI-compatible backend") + } + if model == "" { + return nil, fmt.Errorf("model is required for OpenAI-compatible backend") + } + if dimension == 0 { + dimension = 384 // Default dimension + } + + logger.Infof("Initializing OpenAI-compatible backend (model: %s, url: %s)", model, baseURL) + + backend := &OpenAICompatibleBackend{ + baseURL: baseURL, + model: model, + dimension: dimension, + client: &http.Client{}, + } + + // Test connection + resp, err := backend.client.Get(baseURL) + if err != nil { + return nil, fmt.Errorf("failed to connect to %s: %w", baseURL, err) + } + _ = resp.Body.Close() + + logger.Info("Successfully connected to OpenAI-compatible service") + return backend, nil +} + +// Embed generates an embedding for a single text using OpenAI-compatible API +func (o *OpenAICompatibleBackend) Embed(text string) ([]float32, error) { + reqBody := openaiEmbedRequest{ + Model: o.model, + Input: text, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // Use standard OpenAI v1 endpoint + resp, err := o.client.Post( + o.baseURL+"/v1/embeddings", + "application/json", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return nil, fmt.Errorf("failed to call embeddings API: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + } + + var embedResp openaiEmbedResponse + if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(embedResp.Data) == 0 { + return nil, fmt.Errorf("no embeddings in response") + } + + return embedResp.Data[0].Embedding, nil +} + +// EmbedBatch generates embeddings for multiple texts +func (o *OpenAICompatibleBackend) EmbedBatch(texts []string) ([][]float32, error) { + embeddings := make([][]float32, len(texts)) + + for i, text := range texts { + emb, err := o.Embed(text) + if err != nil { + return nil, fmt.Errorf("failed to embed text %d: %w", i, err) + } + embeddings[i] = emb + } + + return embeddings, nil +} + +// Dimension returns the embedding dimension +func (o *OpenAICompatibleBackend) Dimension() int { + return o.dimension +} + +// Close releases any resources +func (*OpenAICompatibleBackend) Close() error { + // HTTP client doesn't need explicit cleanup + return nil +} diff --git a/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible_test.go b/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible_test.go new file mode 100644 index 0000000000..f9a686e953 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/embeddings/openai_compatible_test.go @@ -0,0 +1,226 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +const testEmbeddingsEndpoint = "/v1/embeddings" + +func TestOpenAICompatibleBackend(t *testing.T) { + t.Parallel() + // Create a test server that mimics OpenAI-compatible API + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == testEmbeddingsEndpoint { + var req openaiEmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("Failed to decode request: %v", err) + } + + // Return a mock embedding response + resp := openaiEmbedResponse{ + Object: "list", + Data: []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + }{ + { + Object: "embedding", + Embedding: make([]float32, 384), + Index: 0, + }, + }, + Model: req.Model, + } + + // Fill with test data + for i := range resp.Data[0].Embedding { + resp.Data[0].Embedding[i] = float32(i) / 384.0 + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + + // Health check endpoint + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Test backend creation + backend, err := NewOpenAICompatibleBackend(server.URL, "test-model", 384) + if err != nil { + t.Fatalf("Failed to create backend: %v", err) + } + defer backend.Close() + + // Test embedding generation + embedding, err := backend.Embed("test text") + if err != nil { + t.Fatalf("Failed to generate embedding: %v", err) + } + + if len(embedding) != 384 { + t.Errorf("Expected embedding dimension 384, got %d", len(embedding)) + } + + // Test batch embedding + texts := []string{"text1", "text2", "text3"} + embeddings, err := backend.EmbedBatch(texts) + if err != nil { + t.Fatalf("Failed to generate batch embeddings: %v", err) + } + + if len(embeddings) != len(texts) { + t.Errorf("Expected %d embeddings, got %d", len(texts), len(embeddings)) + } +} + +func TestOpenAICompatibleBackendErrors(t *testing.T) { + t.Parallel() + // Test missing baseURL + _, err := NewOpenAICompatibleBackend("", "model", 384) + if err == nil { + t.Error("Expected error for missing baseURL") + } + + // Test missing model + _, err = NewOpenAICompatibleBackend("http://localhost:8000", "", 384) + if err == nil { + t.Error("Expected error for missing model") + } +} + +func TestManagerWithVLLM(t *testing.T) { + t.Parallel() + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == testEmbeddingsEndpoint { + resp := openaiEmbedResponse{ + Object: "list", + Data: []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + }{ + { + Object: "embedding", + Embedding: make([]float32, 384), + Index: 0, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Test manager with vLLM backend + config := &Config{ + BackendType: "vllm", + BaseURL: server.URL, + Model: "sentence-transformers/all-MiniLM-L6-v2", + Dimension: 384, + EnableCache: true, + } + + manager, err := NewManager(config) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer manager.Close() + + // Test embedding generation + embeddings, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + t.Fatalf("Failed to generate embeddings: %v", err) + } + + if len(embeddings) != 1 { + t.Errorf("Expected 1 embedding, got %d", len(embeddings)) + } + if len(embeddings[0]) != 384 { + t.Errorf("Expected dimension 384, got %d", len(embeddings[0])) + } +} + +func TestManagerWithUnified(t *testing.T) { + t.Parallel() + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == testEmbeddingsEndpoint { + resp := openaiEmbedResponse{ + Object: "list", + Data: []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + }{ + { + Object: "embedding", + Embedding: make([]float32, 768), + Index: 0, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Test manager with unified backend + config := &Config{ + BackendType: "unified", + BaseURL: server.URL, + Model: "nomic-embed-text", + Dimension: 768, + EnableCache: false, + } + + manager, err := NewManager(config) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer manager.Close() + + // Test embedding generation + embeddings, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + t.Fatalf("Failed to generate embeddings: %v", err) + } + + if len(embeddings) != 1 { + t.Errorf("Expected 1 embedding, got %d", len(embeddings)) + } +} + +func TestManagerFallbackBehavior(t *testing.T) { + t.Parallel() + // Test that invalid vLLM backend fails gracefully during initialization + // (No fallback behavior is currently implemented) + config := &Config{ + BackendType: "vllm", + BaseURL: "http://invalid-host-that-does-not-exist:9999", + Model: "test-model", + Dimension: 384, + } + + _, err := NewManager(config) + if err == nil { + t.Error("Expected error when creating manager with invalid backend URL") + } + // Test passes if error is returned (no fallback behavior) +} diff --git a/cmd/thv-operator/pkg/optimizer/ingestion/errors.go b/cmd/thv-operator/pkg/optimizer/ingestion/errors.go new file mode 100644 index 0000000000..93e8eab31c --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/ingestion/errors.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package ingestion provides services for ingesting MCP tools into the database. +package ingestion + +import "errors" + +var ( + // ErrIngestionFailed is returned when ingestion fails + ErrIngestionFailed = errors.New("ingestion failed") + + // ErrBackendRetrievalFailed is returned when backend retrieval fails + ErrBackendRetrievalFailed = errors.New("backend retrieval failed") + + // ErrToolHiveUnavailable is returned when ToolHive is unavailable + ErrToolHiveUnavailable = errors.New("ToolHive unavailable") + + // ErrBackendStatusNil is returned when backend status is nil + ErrBackendStatusNil = errors.New("backend status cannot be nil") + + // ErrInvalidRuntimeMode is returned for invalid runtime mode + ErrInvalidRuntimeMode = errors.New("invalid runtime mode: must be 'docker' or 'k8s'") +) diff --git a/cmd/thv-operator/pkg/optimizer/ingestion/service.go b/cmd/thv-operator/pkg/optimizer/ingestion/service.go new file mode 100644 index 0000000000..0b78423e12 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/ingestion/service.go @@ -0,0 +1,346 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package ingestion + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/tokens" + "github.com/stacklok/toolhive/pkg/logger" +) + +// Config holds configuration for the ingestion service +type Config struct { + // Database configuration + DBConfig *db.Config + + // Embedding configuration + EmbeddingConfig *embeddings.Config + + // MCP timeout in seconds + MCPTimeout int + + // Workloads to skip during ingestion + SkippedWorkloads []string + + // Runtime mode: "docker" or "k8s" + RuntimeMode string + + // Kubernetes configuration (used when RuntimeMode is "k8s") + K8sAPIServerURL string + K8sNamespace string + K8sAllNamespaces bool +} + +// Service handles ingestion of MCP backends and their tools +type Service struct { + config *Config + database *db.DB + embeddingManager *embeddings.Manager + tokenCounter *tokens.Counter + backendServerOps *db.BackendServerOps + backendToolOps *db.BackendToolOps + tracer trace.Tracer + + // Embedding time tracking + embeddingTimeMu sync.Mutex + totalEmbeddingTime time.Duration +} + +// NewService creates a new ingestion service +func NewService(config *Config) (*Service, error) { + // Set defaults + if config.MCPTimeout == 0 { + config.MCPTimeout = 30 + } + if len(config.SkippedWorkloads) == 0 { + config.SkippedWorkloads = []string{"inspector", "mcp-optimizer"} + } + + // Initialize database + database, err := db.NewDB(config.DBConfig) + if err != nil { + return nil, fmt.Errorf("failed to initialize database: %w", err) + } + + // Clear database on startup to ensure fresh embeddings + // This is important when the embedding model changes or for consistency + database.Reset() + logger.Info("Cleared optimizer database on startup") + + // Initialize embedding manager + embeddingManager, err := embeddings.NewManager(config.EmbeddingConfig) + if err != nil { + _ = database.Close() + return nil, fmt.Errorf("failed to initialize embedding manager: %w", err) + } + + // Initialize token counter + tokenCounter := tokens.NewCounter() + + // Initialize tracer + tracer := otel.Tracer("github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/ingestion") + + svc := &Service{ + config: config, + database: database, + embeddingManager: embeddingManager, + tokenCounter: tokenCounter, + tracer: tracer, + totalEmbeddingTime: 0, + } + + // Create chromem-go embeddingFunc from our embedding manager with tracing + embeddingFunc := func(ctx context.Context, text string) ([]float32, error) { + // Create a span for embedding calculation + _, span := svc.tracer.Start(ctx, "optimizer.ingestion.calculate_embedding", + trace.WithAttributes( + attribute.String("operation", "embedding_calculation"), + )) + defer span.End() + + start := time.Now() + + // Our manager takes a slice, so wrap the single text + embeddingsResult, err := embeddingManager.GenerateEmbedding([]string{text}) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return nil, err + } + if len(embeddingsResult) == 0 { + err := fmt.Errorf("no embeddings generated") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return nil, err + } + + // Track embedding time + duration := time.Since(start) + svc.embeddingTimeMu.Lock() + svc.totalEmbeddingTime += duration + svc.embeddingTimeMu.Unlock() + + span.SetAttributes( + attribute.Int64("embedding.duration_ms", duration.Milliseconds()), + ) + + return embeddingsResult[0], nil + } + + svc.backendServerOps = db.NewBackendServerOps(database, embeddingFunc) + svc.backendToolOps = db.NewBackendToolOps(database, embeddingFunc) + + logger.Info("Ingestion service initialized for event-driven ingestion (chromem-go)") + return svc, nil +} + +// IngestServer ingests a single MCP server and its tools into the optimizer database. +// This is called by vMCP during session registration for each backend server. +// +// Parameters: +// - serverID: Unique identifier for the backend server +// - serverName: Human-readable server name +// - description: Optional server description +// - tools: List of tools available from this server +// +// This method will: +// 1. Create or update the backend server record (simplified metadata only) +// 2. Generate embeddings for server and tools +// 3. Count tokens for each tool +// 4. Store everything in the database for semantic search +// +// Note: URL, transport, status are NOT stored - vMCP manages backend lifecycle +func (s *Service) IngestServer( + ctx context.Context, + serverID string, + serverName string, + description *string, + tools []mcp.Tool, +) error { + // Create a span for the entire ingestion operation + ctx, span := s.tracer.Start(ctx, "optimizer.ingestion.ingest_server", + trace.WithAttributes( + attribute.String("server.id", serverID), + attribute.String("server.name", serverName), + attribute.Int("tools.count", len(tools)), + )) + defer span.End() + + start := time.Now() + logger.Infof("Ingesting server: %s (%d tools) [serverID=%s]", serverName, len(tools), serverID) + + // Create backend server record (simplified - vMCP manages lifecycle) + // chromem-go will generate embeddings automatically from the content + backendServer := &models.BackendServer{ + ID: serverID, + Name: serverName, + Description: description, + Group: "default", // TODO: Pass group from vMCP if needed + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create or update server (chromem-go handles embeddings) + if err := s.backendServerOps.Update(ctx, backendServer); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return fmt.Errorf("failed to create/update server %s: %w", serverName, err) + } + logger.Debugf("Created/updated server: %s", serverName) + + // Sync tools for this server + toolCount, err := s.syncBackendTools(ctx, serverID, serverName, tools) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return fmt.Errorf("failed to sync tools for %s: %w", serverName, err) + } + + duration := time.Since(start) + span.SetAttributes( + attribute.Int64("ingestion.duration_ms", duration.Milliseconds()), + attribute.Int("tools.ingested", toolCount), + ) + + logger.Infow("Successfully ingested server", + "server_name", serverName, + "server_id", serverID, + "tools_count", toolCount, + "duration_ms", duration.Milliseconds()) + return nil +} + +// syncBackendTools synchronizes tools for a backend server +func (s *Service) syncBackendTools(ctx context.Context, serverID string, serverName string, tools []mcp.Tool) (int, error) { + // Create a span for tool synchronization + ctx, span := s.tracer.Start(ctx, "optimizer.ingestion.sync_backend_tools", + trace.WithAttributes( + attribute.String("server.id", serverID), + attribute.String("server.name", serverName), + attribute.Int("tools.count", len(tools)), + )) + defer span.End() + + logger.Debugf("syncBackendTools: server=%s, serverID=%s, tool_count=%d", serverName, serverID, len(tools)) + + // Delete existing tools + if err := s.backendToolOps.DeleteByServer(ctx, serverID); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return 0, fmt.Errorf("failed to delete existing tools: %w", err) + } + + if len(tools) == 0 { + return 0, nil + } + + // Create tool records (chromem-go will generate embeddings automatically) + for _, tool := range tools { + // Extract description for embedding + description := tool.Description + + // Convert InputSchema to JSON + schemaJSON, err := json.Marshal(tool.InputSchema) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return 0, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err) + } + + backendTool := &models.BackendTool{ + ID: uuid.New().String(), + MCPServerID: serverID, + ToolName: tool.Name, + Description: &description, + InputSchema: schemaJSON, + TokenCount: s.tokenCounter.CountToolTokens(tool), + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + if err := s.backendToolOps.Create(ctx, backendTool, serverName); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return 0, fmt.Errorf("failed to create tool %s: %w", tool.Name, err) + } + } + + logger.Infof("Synced %d tools for server %s", len(tools), serverName) + return len(tools), nil +} + +// GetEmbeddingManager returns the embedding manager for this service +func (s *Service) GetEmbeddingManager() *embeddings.Manager { + return s.embeddingManager +} + +// GetBackendToolOps returns the backend tool operations for search and retrieval +func (s *Service) GetBackendToolOps() *db.BackendToolOps { + return s.backendToolOps +} + +// GetTotalToolTokens returns the total token count across all tools in the database +func (s *Service) GetTotalToolTokens(ctx context.Context) int { + // Use FTS database to efficiently count all tool tokens + if s.database.GetFTSDB() != nil { + totalTokens, err := s.database.GetFTSDB().GetTotalToolTokens(ctx) + if err != nil { + logger.Warnw("Failed to get total tool tokens from FTS", "error", err) + return 0 + } + return totalTokens + } + + // Fallback: query all tools (less efficient but works) + logger.Warn("FTS database not available, using fallback for token counting") + return 0 +} + +// GetTotalEmbeddingTime returns the total time spent calculating embeddings +func (s *Service) GetTotalEmbeddingTime() time.Duration { + s.embeddingTimeMu.Lock() + defer s.embeddingTimeMu.Unlock() + return s.totalEmbeddingTime +} + +// ResetEmbeddingTime resets the total embedding time counter +func (s *Service) ResetEmbeddingTime() { + s.embeddingTimeMu.Lock() + defer s.embeddingTimeMu.Unlock() + s.totalEmbeddingTime = 0 +} + +// Close releases resources +func (s *Service) Close() error { + var errs []error + + if err := s.embeddingManager.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close embedding manager: %w", err)) + } + + if err := s.database.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close database: %w", err)) + } + + if len(errs) > 0 { + return fmt.Errorf("errors closing service: %v", errs) + } + + return nil +} diff --git a/cmd/thv-operator/pkg/optimizer/ingestion/service_test.go b/cmd/thv-operator/pkg/optimizer/ingestion/service_test.go new file mode 100644 index 0000000000..0475737071 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/ingestion/service_test.go @@ -0,0 +1,253 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package ingestion + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" +) + +// TestServiceCreationAndIngestion demonstrates the complete chromem-go workflow: +// 1. Create in-memory database +// 2. Initialize ingestion service +// 3. Ingest server and tools +// 4. Query the database +func TestServiceCreationAndIngestion(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create temporary directory for persistence (optional) + tmpDir := t.TempDir() + + // Try to use Ollama if available, otherwise skip test + // Check for the actual model we'll use: nomic-embed-text + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available or model not found. Error: %v. Run 'ollama serve && ollama pull nomic-embed-text'", err) + return + } + _ = embeddingManager.Close() + + // Initialize service with Ollama embeddings + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, + }, + } + + svc, err := NewService(config) + if err != nil { + t.Skipf("Skipping test: Failed to create service. Error: %v. Run 'ollama serve && ollama pull nomic-embed-text'", err) + return + } + defer func() { _ = svc.Close() }() + + // Create test tools + tools := []mcp.Tool{ + { + Name: "get_weather", + Description: "Get the current weather for a location", + }, + { + Name: "search_web", + Description: "Search the web for information", + }, + } + + // Ingest server with tools + serverName := "test-server" + serverID := "test-server-id" + description := "A test MCP server" + + err = svc.IngestServer(ctx, serverID, serverName, &description, tools) + if err != nil { + // Check if error is due to missing model + errStr := err.Error() + if strings.Contains(errStr, "model") || strings.Contains(errStr, "not found") || strings.Contains(errStr, "404") { + t.Skipf("Skipping test: Model not available. Error: %v. Run 'ollama serve && ollama pull nomic-embed-text'", err) + return + } + require.NoError(t, err) + } + + // Query tools + allTools, err := svc.backendToolOps.ListByServer(ctx, serverID) + require.NoError(t, err) + require.Len(t, allTools, 2, "Expected 2 tools to be ingested") + + // Verify tool names + toolNames := make(map[string]bool) + for _, tool := range allTools { + toolNames[tool.ToolName] = true + } + require.True(t, toolNames["get_weather"], "get_weather tool should be present") + require.True(t, toolNames["search_web"], "search_web tool should be present") + + // Search for similar tools + results, err := svc.backendToolOps.Search(ctx, "weather information", 5, &serverID) + require.NoError(t, err) + require.NotEmpty(t, results, "Should find at least one similar tool") + + require.NotEmpty(t, results, "Should return at least one result") + + // Weather tool should be most similar to weather query + require.Equal(t, "get_weather", results[0].ToolName, + "Weather tool should be most similar to weather query") + toolNamesFound := make(map[string]bool) + for _, result := range results { + toolNamesFound[result.ToolName] = true + } + require.True(t, toolNamesFound["get_weather"], "get_weather should be in results") + require.True(t, toolNamesFound["search_web"], "search_web should be in results") +} + +// TestService_EmbeddingTimeTracking tests that embedding time is tracked correctly +func TestService_EmbeddingTimeTracking(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + // Initialize service + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Initially, embedding time should be 0 + initialTime := svc.GetTotalEmbeddingTime() + require.Equal(t, time.Duration(0), initialTime, "Initial embedding time should be 0") + + // Create test tools + tools := []mcp.Tool{ + { + Name: "test_tool_1", + Description: "First test tool for embedding", + }, + { + Name: "test_tool_2", + Description: "Second test tool for embedding", + }, + } + + // Reset embedding time before ingestion + svc.ResetEmbeddingTime() + + // Ingest server with tools (this will generate embeddings) + err = svc.IngestServer(ctx, "test-server-id", "TestServer", nil, tools) + require.NoError(t, err) + + // After ingestion, embedding time should be greater than 0 + totalEmbeddingTime := svc.GetTotalEmbeddingTime() + require.Greater(t, totalEmbeddingTime, time.Duration(0), + "Total embedding time should be greater than 0 after ingestion") + + // Reset and verify it's back to 0 + svc.ResetEmbeddingTime() + resetTime := svc.GetTotalEmbeddingTime() + require.Equal(t, time.Duration(0), resetTime, "Embedding time should be 0 after reset") +} + +// TestServiceWithOllama demonstrates using real embeddings (requires Ollama running) +// This test can be enabled locally to verify Ollama integration +func TestServiceWithOllama(t *testing.T) { + t.Parallel() + + // Skip if not explicitly enabled or Ollama is not available + if os.Getenv("TEST_OLLAMA") != "true" { + t.Skip("Skipping Ollama integration test (set TEST_OLLAMA=true to enable)") + } + + ctx := context.Background() + tmpDir := t.TempDir() + + // Initialize service with Ollama embeddings + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "ollama-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Create test tools + tools := []mcp.Tool{ + { + Name: "get_weather", + Description: "Get current weather conditions for any location worldwide", + }, + { + Name: "send_email", + Description: "Send an email message to a recipient", + }, + } + + // Ingest server + err = svc.IngestServer(ctx, "server-1", "TestServer", nil, tools) + require.NoError(t, err) + + // Search for weather-related tools + results, err := svc.backendToolOps.Search(ctx, "What's the temperature outside?", 5, nil) + require.NoError(t, err) + require.NotEmpty(t, results) + + require.Equal(t, "get_weather", results[0].ToolName, + "Weather tool should be most similar to weather query") +} diff --git a/cmd/thv-operator/pkg/optimizer/ingestion/service_test_coverage.go b/cmd/thv-operator/pkg/optimizer/ingestion/service_test_coverage.go new file mode 100644 index 0000000000..a068eab687 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/ingestion/service_test_coverage.go @@ -0,0 +1,285 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package ingestion + +import ( + "context" + "path/filepath" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" +) + +// TestService_GetTotalToolTokens tests token counting +func TestService_GetTotalToolTokens(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Ingest some tools + tools := []mcp.Tool{ + { + Name: "tool1", + Description: "Tool 1", + }, + { + Name: "tool2", + Description: "Tool 2", + }, + } + + err = svc.IngestServer(ctx, "server-1", "TestServer", nil, tools) + require.NoError(t, err) + + // Get total tokens + totalTokens := svc.GetTotalToolTokens(ctx) + assert.GreaterOrEqual(t, totalTokens, 0, "Total tokens should be non-negative") +} + +// TestService_GetTotalToolTokens_NoFTS tests token counting without FTS +func TestService_GetTotalToolTokens_NoFTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: "", // In-memory + FTSDBPath: "", // Will default to :memory: + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Get total tokens (should use FTS if available, fallback otherwise) + totalTokens := svc.GetTotalToolTokens(ctx) + assert.GreaterOrEqual(t, totalTokens, 0, "Total tokens should be non-negative") +} + +// TestService_GetBackendToolOps tests backend tool ops accessor +func TestService_GetBackendToolOps(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + toolOps := svc.GetBackendToolOps() + require.NotNil(t, toolOps) +} + +// TestService_GetEmbeddingManager tests embedding manager accessor +func TestService_GetEmbeddingManager(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + manager := svc.GetEmbeddingManager() + require.NotNil(t, manager) +} + +// TestService_IngestServer_ErrorHandling tests error handling during ingestion +func TestService_IngestServer_ErrorHandling(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Test with empty tools list + err = svc.IngestServer(ctx, "server-1", "TestServer", nil, []mcp.Tool{}) + require.NoError(t, err, "Should handle empty tools list gracefully") + + // Test with nil description + err = svc.IngestServer(ctx, "server-2", "TestServer2", nil, []mcp.Tool{ + { + Name: "tool1", + Description: "Tool 1", + }, + }) + require.NoError(t, err, "Should handle nil description gracefully") +} + +// TestService_Close_ErrorHandling tests error handling during close +func TestService_Close_ErrorHandling(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + + // Close should succeed + err = svc.Close() + require.NoError(t, err) + + // Multiple closes should be safe + err = svc.Close() + require.NoError(t, err) +} diff --git a/cmd/thv-operator/pkg/optimizer/models/errors.go b/cmd/thv-operator/pkg/optimizer/models/errors.go new file mode 100644 index 0000000000..c5b10eebe6 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/models/errors.go @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package models defines domain models for the optimizer. +// It includes structures for MCP servers, tools, and related metadata. +package models + +import "errors" + +var ( + // ErrRemoteServerMissingURL is returned when a remote server doesn't have a URL + ErrRemoteServerMissingURL = errors.New("remote servers must have URL") + + // ErrContainerServerMissingPackage is returned when a container server doesn't have a package + ErrContainerServerMissingPackage = errors.New("container servers must have package") + + // ErrInvalidTokenMetrics is returned when token metrics are inconsistent + ErrInvalidTokenMetrics = errors.New("invalid token metrics: calculated values don't match") +) diff --git a/cmd/thv-operator/pkg/optimizer/models/models.go b/cmd/thv-operator/pkg/optimizer/models/models.go new file mode 100644 index 0000000000..6c810fbe04 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/models/models.go @@ -0,0 +1,176 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package models + +import ( + "encoding/json" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// BaseMCPServer represents the common fields for MCP servers. +type BaseMCPServer struct { + ID string `json:"id"` + Name string `json:"name"` + Remote bool `json:"remote"` + Transport TransportType `json:"transport"` + Description *string `json:"description,omitempty"` + ServerEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB + Group string `json:"group"` + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// RegistryServer represents an MCP server from the registry catalog. +type RegistryServer struct { + BaseMCPServer + URL *string `json:"url,omitempty"` // For remote servers + Package *string `json:"package,omitempty"` // For container servers +} + +// Validate checks if the registry server has valid data. +// Remote servers must have URL, container servers must have package. +func (r *RegistryServer) Validate() error { + if r.Remote && r.URL == nil { + return ErrRemoteServerMissingURL + } + if !r.Remote && r.Package == nil { + return ErrContainerServerMissingPackage + } + return nil +} + +// BackendServer represents a running MCP server backend. +// Simplified: Only stores metadata needed for tool organization and search results. +// vMCP manages backend lifecycle (URL, status, transport, etc.) +type BackendServer struct { + ID string `json:"id"` + Name string `json:"name"` + Description *string `json:"description,omitempty"` + Group string `json:"group"` + ServerEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// BaseTool represents the common fields for tools. +type BaseTool struct { + ID string `json:"id"` + MCPServerID string `json:"mcpserver_id"` + Details mcp.Tool `json:"details"` + DetailsEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// RegistryTool represents a tool from a registry MCP server. +type RegistryTool struct { + BaseTool +} + +// BackendTool represents a tool from a backend MCP server. +// With chromem-go, embeddings are managed by the database. +type BackendTool struct { + ID string `json:"id"` + MCPServerID string `json:"mcpserver_id"` + ToolName string `json:"tool_name"` + Description *string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"input_schema,omitempty"` + ToolEmbedding []float32 `json:"-"` // Managed by chromem-go + TokenCount int `json:"token_count"` + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// ToolDetailsToJSON converts mcp.Tool to JSON for storage in the database. +func ToolDetailsToJSON(tool mcp.Tool) (string, error) { + data, err := json.Marshal(tool) + if err != nil { + return "", err + } + return string(data), nil +} + +// ToolDetailsFromJSON converts JSON to mcp.Tool +func ToolDetailsFromJSON(data string) (*mcp.Tool, error) { + var tool mcp.Tool + err := json.Unmarshal([]byte(data), &tool) + if err != nil { + return nil, err + } + return &tool, nil +} + +// BackendToolWithMetadata represents a backend tool with similarity score. +type BackendToolWithMetadata struct { + BackendTool + Similarity float32 `json:"similarity"` // Cosine similarity from chromem-go (0-1, higher is better) +} + +// RegistryToolWithMetadata represents a registry tool with server information and similarity distance. +type RegistryToolWithMetadata struct { + ServerName string `json:"server_name"` + ServerDescription *string `json:"server_description,omitempty"` + Distance float64 `json:"distance"` // Cosine distance from query embedding + Tool RegistryTool `json:"tool"` +} + +// BackendWithRegistry represents a backend server with its resolved registry relationship. +type BackendWithRegistry struct { + Backend BackendServer `json:"backend"` + Registry *RegistryServer `json:"registry,omitempty"` // NULL if autonomous +} + +// EffectiveDescription returns the description (inherited from registry or own). +func (b *BackendWithRegistry) EffectiveDescription() *string { + if b.Registry != nil { + return b.Registry.Description + } + return b.Backend.Description +} + +// EffectiveEmbedding returns the embedding (inherited from registry or own). +func (b *BackendWithRegistry) EffectiveEmbedding() []float32 { + if b.Registry != nil { + return b.Registry.ServerEmbedding + } + return b.Backend.ServerEmbedding +} + +// ServerNameForTools returns the server name to use as context for tool embeddings. +func (b *BackendWithRegistry) ServerNameForTools() string { + if b.Registry != nil { + return b.Registry.Name + } + return b.Backend.Name +} + +// TokenMetrics represents token efficiency metrics for tool filtering. +type TokenMetrics struct { + BaselineTokens int `json:"baseline_tokens"` // Total tokens for all running server tools + ReturnedTokens int `json:"returned_tokens"` // Total tokens for returned/filtered tools + TokensSaved int `json:"tokens_saved"` // Number of tokens saved by filtering + SavingsPercentage float64 `json:"savings_percentage"` // Percentage of tokens saved (0-100) +} + +// Validate checks if the token metrics are consistent. +func (t *TokenMetrics) Validate() error { + if t.TokensSaved != t.BaselineTokens-t.ReturnedTokens { + return ErrInvalidTokenMetrics + } + + var expectedPct float64 + if t.BaselineTokens > 0 { + expectedPct = (float64(t.TokensSaved) / float64(t.BaselineTokens)) * 100 + // Allow small floating point differences (0.01%) + if expectedPct-t.SavingsPercentage > 0.01 || t.SavingsPercentage-expectedPct > 0.01 { + return ErrInvalidTokenMetrics + } + } else if t.SavingsPercentage != 0.0 { + return ErrInvalidTokenMetrics + } + + return nil +} diff --git a/cmd/thv-operator/pkg/optimizer/models/models_test.go b/cmd/thv-operator/pkg/optimizer/models/models_test.go new file mode 100644 index 0000000000..af06e90bf4 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/models/models_test.go @@ -0,0 +1,273 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package models + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestRegistryServer_Validate(t *testing.T) { + t.Parallel() + url := "http://example.com/mcp" + pkg := "github.com/example/mcp-server" + + tests := []struct { + name string + server *RegistryServer + wantErr bool + }{ + { + name: "Remote server with URL is valid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: true, + }, + URL: &url, + }, + wantErr: false, + }, + { + name: "Container server with package is valid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: false, + }, + Package: &pkg, + }, + wantErr: false, + }, + { + name: "Remote server without URL is invalid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: true, + }, + }, + wantErr: true, + }, + { + name: "Container server without package is invalid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: false, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.server.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("RegistryServer.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestToolDetailsToJSON(t *testing.T) { + t.Parallel() + tool := mcp.Tool{ + Name: "test_tool", + Description: "A test tool", + } + + json, err := ToolDetailsToJSON(tool) + if err != nil { + t.Fatalf("ToolDetailsToJSON() error = %v", err) + } + + if json == "" { + t.Error("ToolDetailsToJSON() returned empty string") + } + + // Try to parse it back + parsed, err := ToolDetailsFromJSON(json) + if err != nil { + t.Fatalf("ToolDetailsFromJSON() error = %v", err) + } + + if parsed.Name != tool.Name { + t.Errorf("Tool name mismatch: got %v, want %v", parsed.Name, tool.Name) + } + + if parsed.Description != tool.Description { + t.Errorf("Tool description mismatch: got %v, want %v", parsed.Description, tool.Description) + } +} + +func TestTokenMetrics_Validate(t *testing.T) { + t.Parallel() + tests := []struct { + name string + metrics *TokenMetrics + wantErr bool + }{ + { + name: "Valid metrics with savings", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 600, + TokensSaved: 400, + SavingsPercentage: 40.0, + }, + wantErr: false, + }, + { + name: "Valid metrics with no savings", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 1000, + TokensSaved: 0, + SavingsPercentage: 0.0, + }, + wantErr: false, + }, + { + name: "Invalid: tokens saved doesn't match", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 600, + TokensSaved: 500, // Should be 400 + SavingsPercentage: 40.0, + }, + wantErr: true, + }, + { + name: "Invalid: savings percentage doesn't match", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 600, + TokensSaved: 400, + SavingsPercentage: 50.0, // Should be 40.0 + }, + wantErr: true, + }, + { + name: "Invalid: non-zero percentage with zero baseline", + metrics: &TokenMetrics{ + BaselineTokens: 0, + ReturnedTokens: 0, + TokensSaved: 0, + SavingsPercentage: 10.0, // Should be 0 + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.metrics.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("TokenMetrics.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestBackendWithRegistry_EffectiveDescription(t *testing.T) { + t.Parallel() + registryDesc := "Registry description" + backendDesc := "Backend description" + + tests := []struct { + name string + w *BackendWithRegistry + want *string + }{ + { + name: "Uses registry description when available", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Description: &backendDesc, + }, + Registry: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Description: ®istryDesc, + }, + }, + }, + want: ®istryDesc, + }, + { + name: "Uses backend description when no registry", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Description: &backendDesc, + }, + Registry: nil, + }, + want: &backendDesc, + }, + { + name: "Returns nil when no description", + w: &BackendWithRegistry{ + Backend: BackendServer{}, + Registry: nil, + }, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := tt.w.EffectiveDescription() + if (got == nil) != (tt.want == nil) { + t.Errorf("BackendWithRegistry.EffectiveDescription() = %v, want %v", got, tt.want) + } + if got != nil && tt.want != nil && *got != *tt.want { + t.Errorf("BackendWithRegistry.EffectiveDescription() = %v, want %v", *got, *tt.want) + } + }) + } +} + +func TestBackendWithRegistry_ServerNameForTools(t *testing.T) { + t.Parallel() + tests := []struct { + name string + w *BackendWithRegistry + want string + }{ + { + name: "Uses registry name when available", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Name: "backend-name", + }, + Registry: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Name: "registry-name", + }, + }, + }, + want: "registry-name", + }, + { + name: "Uses backend name when no registry", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Name: "backend-name", + }, + Registry: nil, + }, + want: "backend-name", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.w.ServerNameForTools(); got != tt.want { + t.Errorf("BackendWithRegistry.ServerNameForTools() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/thv-operator/pkg/optimizer/models/transport.go b/cmd/thv-operator/pkg/optimizer/models/transport.go new file mode 100644 index 0000000000..8764b7fd48 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/models/transport.go @@ -0,0 +1,114 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package models + +import ( + "database/sql/driver" + "fmt" +) + +// TransportType represents the transport protocol used by an MCP server. +// Maps 1:1 to ToolHive transport modes. +type TransportType string + +const ( + // TransportSSE represents Server-Sent Events transport + TransportSSE TransportType = "sse" + // TransportStreamable represents Streamable HTTP transport + TransportStreamable TransportType = "streamable-http" +) + +// Valid returns true if the transport type is valid +func (t TransportType) Valid() bool { + switch t { + case TransportSSE, TransportStreamable: + return true + default: + return false + } +} + +// String returns the string representation +func (t TransportType) String() string { + return string(t) +} + +// Value implements the driver.Valuer interface for database storage +func (t TransportType) Value() (driver.Value, error) { + if !t.Valid() { + return nil, fmt.Errorf("invalid transport type: %s", t) + } + return string(t), nil +} + +// Scan implements the sql.Scanner interface for database retrieval +func (t *TransportType) Scan(value interface{}) error { + if value == nil { + return fmt.Errorf("transport type cannot be nil") + } + + str, ok := value.(string) + if !ok { + return fmt.Errorf("transport type must be a string, got %T", value) + } + + *t = TransportType(str) + if !t.Valid() { + return fmt.Errorf("invalid transport type from database: %s", str) + } + + return nil +} + +// MCPStatus represents the status of an MCP server backend. +type MCPStatus string + +const ( + // StatusRunning indicates the backend is running + StatusRunning MCPStatus = "running" + // StatusStopped indicates the backend is stopped + StatusStopped MCPStatus = "stopped" +) + +// Valid returns true if the status is valid +func (s MCPStatus) Valid() bool { + switch s { + case StatusRunning, StatusStopped: + return true + default: + return false + } +} + +// String returns the string representation +func (s MCPStatus) String() string { + return string(s) +} + +// Value implements the driver.Valuer interface for database storage +func (s MCPStatus) Value() (driver.Value, error) { + if !s.Valid() { + return nil, fmt.Errorf("invalid MCP status: %s", s) + } + return string(s), nil +} + +// Scan implements the sql.Scanner interface for database retrieval +func (s *MCPStatus) Scan(value interface{}) error { + if value == nil { + return fmt.Errorf("MCP status cannot be nil") + } + + str, ok := value.(string) + if !ok { + return fmt.Errorf("MCP status must be a string, got %T", value) + } + + *s = MCPStatus(str) + if !s.Valid() { + return fmt.Errorf("invalid MCP status from database: %s", str) + } + + return nil +} diff --git a/cmd/thv-operator/pkg/optimizer/models/transport_test.go b/cmd/thv-operator/pkg/optimizer/models/transport_test.go new file mode 100644 index 0000000000..156062c595 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/models/transport_test.go @@ -0,0 +1,276 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package models + +import ( + "testing" +) + +func TestTransportType_Valid(t *testing.T) { + t.Parallel() + tests := []struct { + name string + transport TransportType + want bool + }{ + { + name: "SSE transport is valid", + transport: TransportSSE, + want: true, + }, + { + name: "Streamable transport is valid", + transport: TransportStreamable, + want: true, + }, + { + name: "Invalid transport is not valid", + transport: TransportType("invalid"), + want: false, + }, + { + name: "Empty transport is not valid", + transport: TransportType(""), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.transport.Valid(); got != tt.want { + t.Errorf("TransportType.Valid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTransportType_Value(t *testing.T) { + t.Parallel() + tests := []struct { + name string + transport TransportType + wantValue string + wantErr bool + }{ + { + name: "SSE transport value", + transport: TransportSSE, + wantValue: "sse", + wantErr: false, + }, + { + name: "Streamable transport value", + transport: TransportStreamable, + wantValue: "streamable-http", + wantErr: false, + }, + { + name: "Invalid transport returns error", + transport: TransportType("invalid"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := tt.transport.Value() + if (err != nil) != tt.wantErr { + t.Errorf("TransportType.Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.wantValue { + t.Errorf("TransportType.Value() = %v, want %v", got, tt.wantValue) + } + }) + } +} + +func TestTransportType_Scan(t *testing.T) { + t.Parallel() + tests := []struct { + name string + value interface{} + want TransportType + wantErr bool + }{ + { + name: "Scan SSE transport", + value: "sse", + want: TransportSSE, + wantErr: false, + }, + { + name: "Scan streamable transport", + value: "streamable-http", + want: TransportStreamable, + wantErr: false, + }, + { + name: "Scan invalid transport returns error", + value: "invalid", + wantErr: true, + }, + { + name: "Scan nil returns error", + value: nil, + wantErr: true, + }, + { + name: "Scan non-string returns error", + value: 123, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var transport TransportType + err := transport.Scan(tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("TransportType.Scan() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && transport != tt.want { + t.Errorf("TransportType.Scan() = %v, want %v", transport, tt.want) + } + }) + } +} + +func TestMCPStatus_Valid(t *testing.T) { + t.Parallel() + tests := []struct { + name string + status MCPStatus + want bool + }{ + { + name: "Running status is valid", + status: StatusRunning, + want: true, + }, + { + name: "Stopped status is valid", + status: StatusStopped, + want: true, + }, + { + name: "Invalid status is not valid", + status: MCPStatus("invalid"), + want: false, + }, + { + name: "Empty status is not valid", + status: MCPStatus(""), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.status.Valid(); got != tt.want { + t.Errorf("MCPStatus.Valid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMCPStatus_Value(t *testing.T) { + t.Parallel() + tests := []struct { + name string + status MCPStatus + wantValue string + wantErr bool + }{ + { + name: "Running status value", + status: StatusRunning, + wantValue: "running", + wantErr: false, + }, + { + name: "Stopped status value", + status: StatusStopped, + wantValue: "stopped", + wantErr: false, + }, + { + name: "Invalid status returns error", + status: MCPStatus("invalid"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := tt.status.Value() + if (err != nil) != tt.wantErr { + t.Errorf("MCPStatus.Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.wantValue { + t.Errorf("MCPStatus.Value() = %v, want %v", got, tt.wantValue) + } + }) + } +} + +func TestMCPStatus_Scan(t *testing.T) { + t.Parallel() + tests := []struct { + name string + value interface{} + want MCPStatus + wantErr bool + }{ + { + name: "Scan running status", + value: "running", + want: StatusRunning, + wantErr: false, + }, + { + name: "Scan stopped status", + value: "stopped", + want: StatusStopped, + wantErr: false, + }, + { + name: "Scan invalid status returns error", + value: "invalid", + wantErr: true, + }, + { + name: "Scan nil returns error", + value: nil, + wantErr: true, + }, + { + name: "Scan non-string returns error", + value: 123, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var status MCPStatus + err := status.Scan(tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("MCPStatus.Scan() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && status != tt.want { + t.Errorf("MCPStatus.Scan() = %v, want %v", status, tt.want) + } + }) + } +} diff --git a/cmd/thv-operator/pkg/optimizer/tokens/counter.go b/cmd/thv-operator/pkg/optimizer/tokens/counter.go new file mode 100644 index 0000000000..11ed33c118 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/tokens/counter.go @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package tokens provides token counting utilities for LLM cost estimation. +// It estimates token counts for MCP tools and their metadata. +package tokens + +import ( + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Counter counts tokens for LLM consumption +// This provides estimates of token usage for tools +type Counter struct { + // Simple heuristic: ~4 characters per token for English text + charsPerToken float64 +} + +// NewCounter creates a new token counter +func NewCounter() *Counter { + return &Counter{ + charsPerToken: 4.0, // GPT-style tokenization approximation + } +} + +// CountToolTokens estimates the number of tokens for a tool +func (c *Counter) CountToolTokens(tool mcp.Tool) int { + // Convert tool to JSON representation (as it would be sent to LLM) + toolJSON, err := json.Marshal(tool) + if err != nil { + // Fallback to simple estimation + return c.estimateFromTool(tool) + } + + // Estimate tokens from JSON length + return int(float64(len(toolJSON)) / c.charsPerToken) +} + +// estimateFromTool provides a fallback estimation from tool fields +func (c *Counter) estimateFromTool(tool mcp.Tool) int { + totalChars := len(tool.Name) + + if tool.Description != "" { + totalChars += len(tool.Description) + } + + // Estimate input schema size + schemaJSON, _ := json.Marshal(tool.InputSchema) + totalChars += len(schemaJSON) + + return int(float64(totalChars) / c.charsPerToken) +} + +// CountToolsTokens calculates total tokens for multiple tools +func (c *Counter) CountToolsTokens(tools []mcp.Tool) int { + total := 0 + for _, tool := range tools { + total += c.CountToolTokens(tool) + } + return total +} + +// EstimateText estimates tokens for arbitrary text +func (c *Counter) EstimateText(text string) int { + return int(float64(len(text)) / c.charsPerToken) +} diff --git a/cmd/thv-operator/pkg/optimizer/tokens/counter_test.go b/cmd/thv-operator/pkg/optimizer/tokens/counter_test.go new file mode 100644 index 0000000000..082ee385a1 --- /dev/null +++ b/cmd/thv-operator/pkg/optimizer/tokens/counter_test.go @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package tokens + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestCountToolTokens(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tool := mcp.Tool{ + Name: "test_tool", + Description: "A test tool for counting tokens", + } + + tokens := counter.CountToolTokens(tool) + + // Should return a positive number + if tokens <= 0 { + t.Errorf("Expected positive token count, got %d", tokens) + } + + // Rough estimate: tool should have at least a few tokens + if tokens < 5 { + t.Errorf("Expected at least 5 tokens for a tool with name and description, got %d", tokens) + } +} + +func TestCountToolTokens_MinimalTool(t *testing.T) { + t.Parallel() + counter := NewCounter() + + // Minimal tool with just a name + tool := mcp.Tool{ + Name: "minimal", + } + + tokens := counter.CountToolTokens(tool) + + // Should return a positive number even for minimal tool + if tokens <= 0 { + t.Errorf("Expected positive token count for minimal tool, got %d", tokens) + } +} + +func TestCountToolTokens_NoDescription(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tool := mcp.Tool{ + Name: "test_tool", + } + + tokens := counter.CountToolTokens(tool) + + // Should still return a positive number + if tokens <= 0 { + t.Errorf("Expected positive token count for tool without description, got %d", tokens) + } +} + +func TestCountToolsTokens(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tools := []mcp.Tool{ + { + Name: "tool1", + Description: "First tool", + }, + { + Name: "tool2", + Description: "Second tool with longer description", + }, + } + + totalTokens := counter.CountToolsTokens(tools) + + // Should be greater than individual tools + tokens1 := counter.CountToolTokens(tools[0]) + tokens2 := counter.CountToolTokens(tools[1]) + + expectedTotal := tokens1 + tokens2 + if totalTokens != expectedTotal { + t.Errorf("Expected total tokens %d, got %d", expectedTotal, totalTokens) + } +} + +func TestCountToolsTokens_EmptyList(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tokens := counter.CountToolsTokens([]mcp.Tool{}) + + // Should return 0 for empty list + if tokens != 0 { + t.Errorf("Expected 0 tokens for empty list, got %d", tokens) + } +} + +func TestEstimateText(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tests := []struct { + name string + text string + want int + }{ + { + name: "Empty text", + text: "", + want: 0, + }, + { + name: "Short text", + text: "Hello", + want: 1, // 5 chars / 4 chars per token ≈ 1 + }, + { + name: "Medium text", + text: "This is a test message", + want: 5, // 22 chars / 4 chars per token ≈ 5 + }, + { + name: "Long text", + text: "This is a much longer test message that should have more tokens because it contains significantly more characters", + want: 28, // 112 chars / 4 chars per token = 28 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := counter.EstimateText(tt.text) + if got != tt.want { + t.Errorf("EstimateText() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/thv-operator/pkg/vmcpconfig/converter.go b/cmd/thv-operator/pkg/vmcpconfig/converter.go index d5e283f87b..47264f422e 100644 --- a/cmd/thv-operator/pkg/vmcpconfig/converter.go +++ b/cmd/thv-operator/pkg/vmcpconfig/converter.go @@ -135,6 +135,17 @@ func (c *Converter) Convert( // are handled by kubebuilder annotations in pkg/telemetry/config.go and applied by the API server. config.Telemetry = spectoconfig.NormalizeTelemetryConfig(vmcp.Spec.Config.Telemetry, vmcp.Name) + // Convert audit config + c.convertAuditConfig(config, vmcp) + + // Apply operational defaults (fills missing values) + config.EnsureOperationalDefaults() + + return config, nil +} + +// convertAuditConfig converts audit configuration from CRD to vmcp config. +func (*Converter) convertAuditConfig(config *vmcpconfig.Config, vmcp *mcpv1alpha1.VirtualMCPServer) { if vmcp.Spec.Config.Audit != nil && vmcp.Spec.Config.Audit.Enabled { config.Audit = vmcp.Spec.Config.Audit } @@ -142,11 +153,6 @@ func (c *Converter) Convert( if config.Audit != nil && config.Audit.Component == "" { config.Audit.Component = vmcp.Name } - - // Apply operational defaults (fills missing values) - config.EnsureOperationalDefaults() - - return config, nil } // convertIncomingAuth converts IncomingAuthConfig from CRD to vmcp config. diff --git a/cmd/vmcp/README.md b/cmd/vmcp/README.md index 30ac862ca2..10a60bef2a 100644 --- a/cmd/vmcp/README.md +++ b/cmd/vmcp/README.md @@ -6,7 +6,7 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC ## Features -### Implemented (Phase 1) +### Implemented - ✅ **Group-Based Backend Management**: Automatic workload discovery from ToolHive groups - ✅ **Tool Aggregation**: Combines tools from multiple MCP servers with conflict resolution (prefix, priority, manual) - ✅ **Resource & Prompt Aggregation**: Unified access to resources and prompts from all backends @@ -15,12 +15,14 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC - ✅ **Health Endpoints**: `/health` and `/ping` for service monitoring - ✅ **Configuration Validation**: `vmcp validate` command for config verification - ✅ **Observability**: OpenTelemetry metrics and traces for backend operations and workflow executions +- ✅ **Composite Tools**: Multi-step workflows with elicitation support ### In Progress - 🚧 **Incoming Authentication** (Issue #165): OIDC, local, anonymous authentication - 🚧 **Outgoing Authentication** (Issue #160): RFC 8693 token exchange for backend API access - 🚧 **Token Caching**: Memory and Redis cache providers - 🚧 **Health Monitoring** (Issue #166): Circuit breakers, backend health checks +- 🚧 **Optimizer** Support the MCP optimizer in vMCP for context optimization on large toolsets. ### Future (Phase 2+) - 📋 **Authorization**: Cedar policy-based access control diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index f9c0aa8a70..317c0e1ad8 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -27,7 +27,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/health" "github.com/stacklok/toolhive/pkg/vmcp/k8s" - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + vmcpoptimizer "github.com/stacklok/toolhive/pkg/vmcp/optimizer" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server" vmcpstatus "github.com/stacklok/toolhive/pkg/vmcp/status" @@ -436,9 +436,28 @@ func runServe(cmd *cobra.Command, _ []string) error { StatusReporter: statusReporter, } - if cfg.Optimizer != nil { - // TODO: update this with the real optimizer. - serverCfg.OptimizerFactory = optimizer.NewDummyOptimizer + // Configure optimizer if enabled in YAML config + if cfg.Optimizer != nil && cfg.Optimizer.Enabled { + logger.Info("🔬 Optimizer enabled via configuration (chromem-go)") + optimizerCfg := vmcpoptimizer.ConfigFromVMCPConfig(cfg.Optimizer) + serverCfg.OptimizerConfig = optimizerCfg + persistInfo := "in-memory" + if cfg.Optimizer.PersistPath != "" { + persistInfo = cfg.Optimizer.PersistPath + } + // FTS5 is always enabled with configurable semantic/BM25 ratio + ratio := 70 // Default (70%) + if cfg.Optimizer.HybridSearchRatio != nil { + ratio = *cfg.Optimizer.HybridSearchRatio + } + searchMode := fmt.Sprintf("hybrid (%d%% semantic, %d%% BM25)", + ratio, + 100-ratio) + logger.Infof("Optimizer configured: backend=%s, dimension=%d, persistence=%s, search=%s", + cfg.Optimizer.EmbeddingBackend, + cfg.Optimizer.EmbeddingDimension, + persistInfo, + searchMode) } // Convert composite tool configurations to workflow definitions diff --git a/codecov.yaml b/codecov.yaml index 1a8032e484..410f9ae7ee 100644 --- a/codecov.yaml +++ b/codecov.yaml @@ -13,6 +13,8 @@ coverage: - "**/mocks/**/*" - "**/mock_*.go" - "**/zz_generated.deepcopy.go" + - "**/*_test.go" + - "**/*_test_coverage.go" status: project: default: diff --git a/deploy/charts/operator-crds/Chart.yaml b/deploy/charts/operator-crds/Chart.yaml index 01865a110d..1b14897d71 100644 --- a/deploy/charts/operator-crds/Chart.yaml +++ b/deploy/charts/operator-crds/Chart.yaml @@ -2,5 +2,5 @@ apiVersion: v2 name: toolhive-operator-crds description: A Helm chart for installing the ToolHive Operator CRDs into Kubernetes. type: application -version: 0.0.101 +version: 0.0.102 appVersion: "0.0.1" diff --git a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml index 318099bce9..0f153da6a3 100644 --- a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -677,17 +677,74 @@ spec: optimizer: description: |- Optimizer configures the MCP optimizer for context optimization on large toolsets. - When enabled, vMCP exposes only find_tool and call_tool operations to clients + When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients instead of all backend tools directly. This reduces token usage by allowing LLMs to discover relevant tools on demand rather than receiving all tool definitions. properties: - embeddingService: + embeddingBackend: description: |- - EmbeddingService is the name of a Kubernetes Service that provides the embedding service - for semantic tool discovery. The service must implement the optimizer embedding API. + EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder". + - "ollama": Uses local Ollama HTTP API for embeddings + - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) + - "placeholder": Uses deterministic hash-based embeddings (for testing/development) + enum: + - ollama + - openai-compatible + - placeholder + type: string + embeddingDimension: + description: |- + EmbeddingDimension is the dimension of the embedding vectors. + Common values: + - 384: all-MiniLM-L6-v2, nomic-embed-text + - 768: BAAI/bge-small-en-v1.5 + - 1536: OpenAI text-embedding-3-small + minimum: 1 + type: integer + embeddingModel: + description: |- + EmbeddingModel is the model name to use for embeddings. + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "nomic-embed-text", "all-minilm" + - vLLM: "BAAI/bge-small-en-v1.5" + - OpenAI: "text-embedding-3-small" + type: string + embeddingURL: + description: |- + EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API). + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "http://localhost:11434" + - vLLM: "http://vllm-service:8000/v1" + - OpenAI: "https://api.openai.com/v1" + type: string + enabled: + description: |- + Enabled determines whether the optimizer is active. + When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools. + type: boolean + ftsDBPath: + description: |- + FTSDBPath is the path to the SQLite FTS5 database for BM25 text search. + If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + Hybrid search (semantic + BM25) is always enabled. + type: string + hybridSearchRatio: + description: |- + HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search. + Value range: 0 (all BM25) to 100 (all semantic), representing percentage. + Default: 70 (70% semantic, 30% BM25) + Only used when FTSDBPath is set. + maximum: 100 + minimum: 0 + type: integer + persistPath: + description: |- + PersistPath is the optional filesystem path for persisting the chromem-go database. + If empty, the database will be in-memory only (ephemeral). + When set, tool metadata and embeddings are persisted to disk for faster restarts. type: string - required: - - embeddingService type: object outgoingAuth: description: |- diff --git a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml index 7ebc65a9ab..48a26c7069 100644 --- a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -680,17 +680,74 @@ spec: optimizer: description: |- Optimizer configures the MCP optimizer for context optimization on large toolsets. - When enabled, vMCP exposes only find_tool and call_tool operations to clients + When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients instead of all backend tools directly. This reduces token usage by allowing LLMs to discover relevant tools on demand rather than receiving all tool definitions. properties: - embeddingService: + embeddingBackend: description: |- - EmbeddingService is the name of a Kubernetes Service that provides the embedding service - for semantic tool discovery. The service must implement the optimizer embedding API. + EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder". + - "ollama": Uses local Ollama HTTP API for embeddings + - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) + - "placeholder": Uses deterministic hash-based embeddings (for testing/development) + enum: + - ollama + - openai-compatible + - placeholder + type: string + embeddingDimension: + description: |- + EmbeddingDimension is the dimension of the embedding vectors. + Common values: + - 384: all-MiniLM-L6-v2, nomic-embed-text + - 768: BAAI/bge-small-en-v1.5 + - 1536: OpenAI text-embedding-3-small + minimum: 1 + type: integer + embeddingModel: + description: |- + EmbeddingModel is the model name to use for embeddings. + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "nomic-embed-text", "all-minilm" + - vLLM: "BAAI/bge-small-en-v1.5" + - OpenAI: "text-embedding-3-small" + type: string + embeddingURL: + description: |- + EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API). + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "http://localhost:11434" + - vLLM: "http://vllm-service:8000/v1" + - OpenAI: "https://api.openai.com/v1" + type: string + enabled: + description: |- + Enabled determines whether the optimizer is active. + When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools. + type: boolean + ftsDBPath: + description: |- + FTSDBPath is the path to the SQLite FTS5 database for BM25 text search. + If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + Hybrid search (semantic + BM25) is always enabled. + type: string + hybridSearchRatio: + description: |- + HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search. + Value range: 0 (all BM25) to 100 (all semantic), representing percentage. + Default: 70 (70% semantic, 30% BM25) + Only used when FTSDBPath is set. + maximum: 100 + minimum: 0 + type: integer + persistPath: + description: |- + PersistPath is the optional filesystem path for persisting the chromem-go database. + If empty, the database will be in-memory only (ephemeral). + When set, tool metadata and embeddings are persisted to disk for faster restarts. type: string - required: - - embeddingService type: object outgoingAuth: description: |- diff --git a/docs/operator/crd-api.md b/docs/operator/crd-api.md index 3d075ce09b..020e665859 100644 --- a/docs/operator/crd-api.md +++ b/docs/operator/crd-api.md @@ -245,7 +245,7 @@ _Appears in:_ | `metadata` _object (keys:string, values:string)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | | | `telemetry` _[pkg.telemetry.Config](#pkgtelemetryconfig)_ | Telemetry configures OpenTelemetry-based observability for the Virtual MCP server
including distributed tracing, OTLP metrics export, and Prometheus metrics endpoint. | | | | `audit` _[pkg.audit.Config](#pkgauditconfig)_ | Audit configures audit logging for the Virtual MCP server.
When present, audit logs include MCP protocol operations.
See audit.Config for available configuration options. | | | -| `optimizer` _[vmcp.config.OptimizerConfig](#vmcpconfigoptimizerconfig)_ | Optimizer configures the MCP optimizer for context optimization on large toolsets.
When enabled, vMCP exposes only find_tool and call_tool operations to clients
instead of all backend tools directly. This reduces token usage by allowing
LLMs to discover relevant tools on demand rather than receiving all tool definitions. | | | +| `optimizer` _[vmcp.config.OptimizerConfig](#vmcpconfigoptimizerconfig)_ | Optimizer configures the MCP optimizer for context optimization on large toolsets.
When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients
instead of all backend tools directly. This reduces token usage by allowing
LLMs to discover relevant tools on demand rather than receiving all tool definitions. | | | #### vmcp.config.ConflictResolutionConfig @@ -377,9 +377,9 @@ _Appears in:_ -OptimizerConfig configures the MCP optimizer. -When enabled, vMCP exposes only find_tool and call_tool operations to clients -instead of all backend tools directly. +OptimizerConfig configures the MCP optimizer for semantic tool discovery. +The optimizer reduces token usage by allowing LLMs to discover relevant tools +on demand rather than receiving all tool definitions upfront. @@ -388,7 +388,14 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `embeddingService` _string_ | EmbeddingService is the name of a Kubernetes Service that provides the embedding service
for semantic tool discovery. The service must implement the optimizer embedding API. | | Required: \{\}
| +| `enabled` _boolean_ | Enabled determines whether the optimizer is active.
When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools. | | | +| `embeddingBackend` _string_ | EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder".
- "ollama": Uses local Ollama HTTP API for embeddings
- "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.)
- "placeholder": Uses deterministic hash-based embeddings (for testing/development) | | Enum: [ollama openai-compatible placeholder]
| +| `embeddingURL` _string_ | EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API).
Required when EmbeddingBackend is "ollama" or "openai-compatible".
Examples:
- Ollama: "http://localhost:11434"
- vLLM: "http://vllm-service:8000/v1"
- OpenAI: "https://api.openai.com/v1" | | | +| `embeddingModel` _string_ | EmbeddingModel is the model name to use for embeddings.
Required when EmbeddingBackend is "ollama" or "openai-compatible".
Examples:
- Ollama: "nomic-embed-text", "all-minilm"
- vLLM: "BAAI/bge-small-en-v1.5"
- OpenAI: "text-embedding-3-small" | | | +| `embeddingDimension` _integer_ | EmbeddingDimension is the dimension of the embedding vectors.
Common values:
- 384: all-MiniLM-L6-v2, nomic-embed-text
- 768: BAAI/bge-small-en-v1.5
- 1536: OpenAI text-embedding-3-small | | Minimum: 1
| +| `persistPath` _string_ | PersistPath is the optional filesystem path for persisting the chromem-go database.
If empty, the database will be in-memory only (ephemeral).
When set, tool metadata and embeddings are persisted to disk for faster restarts. | | | +| `ftsDBPath` _string_ | FTSDBPath is the path to the SQLite FTS5 database for BM25 text search.
If empty, defaults to ":memory:" for in-memory FTS5, or "\{PersistPath\}/fts.db" if PersistPath is set.
Hybrid search (semantic + BM25) is always enabled. | | | +| `hybridSearchRatio` _integer_ | HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search.
Value range: 0 (all BM25) to 100 (all semantic), representing percentage.
Default: 70 (70% semantic, 30% BM25)
Only used when FTSDBPath is set. | | Maximum: 100
Minimum: 0
| #### vmcp.config.OutgoingAuthConfig diff --git a/examples/vmcp-config-optimizer.yaml b/examples/vmcp-config-optimizer.yaml new file mode 100644 index 0000000000..547c60e5f6 --- /dev/null +++ b/examples/vmcp-config-optimizer.yaml @@ -0,0 +1,126 @@ +# vMCP Configuration with Optimizer Enabled +# This configuration enables the optimizer for semantic tool discovery + +name: "vmcp-debug" + +# Reference to ToolHive group containing MCP servers +groupRef: "default" + +# Client authentication (anonymous for local development) +incomingAuth: + type: anonymous + +# Backend authentication (unauthenticated for local development) +outgoingAuth: + source: inline + default: + type: unauthenticated + +# Tool aggregation settings +aggregation: + conflictResolution: prefix + conflictResolutionConfig: + prefixFormat: "{workload}_" + +# Operational settings +operational: + timeouts: + default: 30s + failureHandling: + healthCheckInterval: 30s + unhealthyThreshold: 3 + partialFailureMode: fail + +# ============================================================================= +# OPTIMIZER CONFIGURATION +# ============================================================================= +# When enabled, vMCP exposes optim.find_tool and optim.call_tool instead of +# all backend tools directly. This reduces token usage by allowing LLMs to +# discover relevant tools on demand via semantic search. +# +# The optimizer ingests tools from all backends in the group, generates +# embeddings, and provides semantic search capabilities. + +optimizer: + # Enable the optimizer + enabled: true + + # Embedding backend: "ollama" (default), "openai-compatible", or "vllm" + # - "ollama": Uses local Ollama HTTP API for embeddings (default, requires 'ollama serve') + # - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) + # - "vllm": Alias for OpenAI-compatible API + embeddingBackend: ollama + + # Embedding dimension (common values: 384, 768, 1536) + # 384 is standard for all-MiniLM-L6-v2 and nomic-embed-text + embeddingDimension: 384 + + # Optional: Path for persisting the chromem-go database + # If omitted, the database will be in-memory only (ephemeral) + persistPath: /tmp/vmcp-optimizer-debug.db + + # Optional: Path for the SQLite FTS5 database (for hybrid search) + # Default: ":memory:" (in-memory) or "{persistPath}/fts.db" if persistPath is set + # Hybrid search (semantic + BM25) is ALWAYS enabled + ftsDBPath: /tmp/vmcp-optimizer-fts.db # Uncomment to customize location + + # Optional: Hybrid search ratio (0-100, representing percentage) + # Default: 70 (70% semantic, 30% BM25) + # hybridSearchRatio: 70 + + # ============================================================================= + # PRODUCTION CONFIGURATIONS (Commented Examples) + # ============================================================================= + + # Option 1: Local Ollama (good for development/testing) + # embeddingBackend: ollama + # embeddingURL: http://localhost:11434 + # embeddingModel: all-minilm # Default model (all-MiniLM-L6-v2) + # embeddingDimension: 384 + + # Option 2: vLLM (recommended for production with GPU acceleration) + # embeddingBackend: openai-compatible + # embeddingURL: http://vllm-service:8000/v1 + # embeddingModel: BAAI/bge-small-en-v1.5 + # embeddingDimension: 768 + + # Option 3: OpenAI API (cloud-based) + # embeddingBackend: openai-compatible + # embeddingURL: https://api.openai.com/v1 + # embeddingModel: text-embedding-3-small + # embeddingDimension: 1536 + # (requires OPENAI_API_KEY environment variable) + + # Option 4: Kubernetes in-cluster service (K8s deployments) + # embeddingURL: http://embedding-service-name.namespace.svc.cluster.local:port + # Use the full service DNS name with port for in-cluster services + +# ============================================================================= +# TELEMETRY CONFIGURATION (for Jaeger tracing) +# ============================================================================= +# Configure OpenTelemetry to send traces to Jaeger +telemetry: + endpoint: "localhost:4318" # OTLP HTTP endpoint (Jaeger collector) - no http:// prefix needed with insecure: true + serviceName: "vmcp-optimizer" + serviceVersion: "1.0.0" # Optional: service version + tracingEnabled: true + metricsEnabled: false # Set to true if you want metrics too + samplingRate: "1.0" # 100% sampling for development (use lower in production) + insecure: true # Use HTTP instead of HTTPS + +# ============================================================================= +# USAGE +# ============================================================================= +# 1. Start MCP backends in the group: +# thv run weather --group default +# thv run github --group default +# +# 2. Start vMCP with optimizer: +# thv vmcp serve --config examples/vmcp-config-optimizer.yaml +# +# 3. Connect MCP client to vMCP +# +# 4. Available tools from vMCP: +# - optim.find_tool: Search for tools by semantic query +# - optim.call_tool: Execute a tool by name +# - (backend tools are NOT directly exposed when optimizer is enabled) diff --git a/go.mod b/go.mod index 0dca0108db..e041a8ca0a 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/onsi/ginkgo/v2 v2.27.5 github.com/onsi/gomega v1.39.0 github.com/ory/fosite v0.49.0 + github.com/philippgille/chromem-go v0.7.0 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/prometheus/client_golang v1.23.2 github.com/sigstore/protobuf-specs v0.5.0 @@ -59,6 +60,7 @@ require ( k8s.io/api v0.35.0 k8s.io/apimachinery v0.35.0 k8s.io/utils v0.0.0-20260108192941-914a6e750570 + modernc.org/sqlite v1.44.0 sigs.k8s.io/controller-runtime v0.22.4 sigs.k8s.io/yaml v1.6.0 ) @@ -174,6 +176,7 @@ require ( github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect github.com/olekukonko/errors v1.1.0 // indirect @@ -188,6 +191,7 @@ require ( github.com/prometheus/common v0.67.4 // indirect github.com/prometheus/otlptranslator v1.0.0 // indirect github.com/prometheus/procfs v0.19.2 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect @@ -251,6 +255,9 @@ require ( k8s.io/apiextensions-apiserver v0.34.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect + modernc.org/libc v1.67.4 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect @@ -286,7 +293,7 @@ require ( go.opentelemetry.io/otel/metric v1.39.0 go.opentelemetry.io/otel/trace v1.39.0 golang.org/x/crypto v0.47.0 - golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/sys v0.40.0 k8s.io/client-go v0.35.0 ) diff --git a/go.sum b/go.sum index 2b66579e95..78126e3d7a 100644 --- a/go.sum +++ b/go.sum @@ -602,6 +602,8 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nyaruka/phonenumbers v1.1.6 h1:DcueYq7QrOArAprAYNoQfDgp0KetO4LqtnBtQC6Wyes= github.com/nyaruka/phonenumbers v1.1.6/go.mod h1:yShPJHDSH3aTKzCbXyVxNpbl2kA+F+Ne5Pun/MvFRos= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= @@ -640,6 +642,8 @@ github.com/ory/x v0.0.665 h1:61vv0ObCDSX1vOQYbxBeqDiv4YiPmMT91lYxDaaKX08= github.com/ory/x v0.0.665/go.mod h1:7SCTki3N0De3ZpqlxhxU/94ZrOCfNEnXwVtd0xVt+L8= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY= +github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -661,6 +665,8 @@ github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEo github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM= github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -909,8 +915,8 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= -golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp/event v0.0.0-20251219203646-944ab1f22d93 h1:Fee8ke0jLfLhU4ywDLs7IYmhJ8MrSP0iZE3p39EKKSc= golang.org/x/exp/event v0.0.0-20251219203646-944ab1f22d93/go.mod h1:HgAgrKXB9WF2wFZJBGBnRVkmsC8n+v2ja/8VR0H3QkY= golang.org/x/exp/jsonrpc2 v0.0.0-20260112195511-716be5621a96 h1:cN9X2vSBmT3Ruw2UlbJNLJh0iBqTmtSB0dRfh5aumiY= @@ -1086,6 +1092,34 @@ k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 h1:Y3gxNAuB0OBLImH611+UDZ k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912/go.mod h1:kdmbQkyfwUagLfXIad1y2TdrjPFWp2Q89B3qkRwf/pQ= k8s.io/utils v0.0.0-20260108192941-914a6e750570 h1:JT4W8lsdrGENg9W+YwwdLJxklIuKWdRm+BC+xt33FOY= k8s.io/utils v0.0.0-20260108192941-914a6e750570/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg= +modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.44.0 h1:YjCKJnzZde2mLVy0cMKTSL4PxCmbIguOq9lGp8ZvGOc= +modernc.org/sqlite v1.44.0/go.mod h1:2Dq41ir5/qri7QJJJKNZcP4UF7TsX/KNeykYgPDtGhE= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= sigs.k8s.io/controller-runtime v0.22.4 h1:GEjV7KV3TY8e+tJ2LCTxUTanW4z/FmNB7l327UfMq9A= sigs.k8s.io/controller-runtime v0.22.4/go.mod h1:+QX1XUpTXN4mLoblf4tqr5CQcyHPAki2HLXqQMY6vh8= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg= diff --git a/pkg/runner/config_builder_test.go b/pkg/runner/config_builder_test.go index 735c9ccc45..0e4556937d 100644 --- a/pkg/runner/config_builder_test.go +++ b/pkg/runner/config_builder_test.go @@ -1076,8 +1076,8 @@ func TestRunConfigBuilder_WithRegistryProxyPort(t *testing.T) { ProxyPort: 8976, TargetPort: 8976, }, - cliProxyPort: 9000, - expectedProxyPort: 9000, + cliProxyPort: 9999, + expectedProxyPort: 9999, }, { name: "random port when neither CLI nor registry specified", diff --git a/pkg/vmcp/aggregator/default_aggregator.go b/pkg/vmcp/aggregator/default_aggregator.go index 95be734af0..3cf2846fcc 100644 --- a/pkg/vmcp/aggregator/default_aggregator.go +++ b/pkg/vmcp/aggregator/default_aggregator.go @@ -8,6 +8,10 @@ import ( "fmt" "sync" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" "github.com/stacklok/toolhive/pkg/logger" @@ -21,6 +25,7 @@ type defaultAggregator struct { backendClient vmcp.BackendClient conflictResolver ConflictResolver toolConfigMap map[string]*config.WorkloadToolConfig // Maps backend ID to tool config + tracer trace.Tracer } // NewDefaultAggregator creates a new default aggregator implementation. @@ -43,12 +48,20 @@ func NewDefaultAggregator( backendClient: backendClient, conflictResolver: conflictResolver, toolConfigMap: toolConfigMap, + tracer: otel.Tracer("github.com/stacklok/toolhive/pkg/vmcp/aggregator"), } } // QueryCapabilities queries a single backend for its MCP capabilities. // Returns the raw capabilities (tools, resources, prompts) from the backend. func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.Backend) (*BackendCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.QueryCapabilities", + trace.WithAttributes( + attribute.String("backend.id", backend.ID), + ), + ) + defer span.End() + logger.Debugf("Querying capabilities from backend %s", backend.ID) // Create a BackendTarget from the Backend @@ -58,6 +71,8 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp. // Query capabilities using the backend client capabilities, err := a.backendClient.ListCapabilities(ctx, target) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("%w: %s: %w", ErrBackendQueryFailed, backend.ID, err) } @@ -74,6 +89,12 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp. SupportsSampling: capabilities.SupportsSampling, } + span.SetAttributes( + attribute.Int("tools.count", len(result.Tools)), + attribute.Int("resources.count", len(result.Resources)), + attribute.Int("prompts.count", len(result.Prompts)), + ) + logger.Debugf("Backend %s: %d tools (after filtering/overrides), %d resources, %d prompts", backend.ID, len(result.Tools), len(result.Resources), len(result.Prompts)) @@ -86,6 +107,13 @@ func (a *defaultAggregator) QueryAllCapabilities( ctx context.Context, backends []vmcp.Backend, ) (map[string]*BackendCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.QueryAllCapabilities", + trace.WithAttributes( + attribute.Int("backends.count", len(backends)), + ), + ) + defer span.End() + logger.Infof("Querying capabilities from %d backends", len(backends)) // Use errgroup for parallel queries with context cancellation @@ -118,13 +146,22 @@ func (a *defaultAggregator) QueryAllCapabilities( // Wait for all queries to complete if err := g.Wait(); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("capability queries failed: %w", err) } if len(capabilities) == 0 { - return nil, fmt.Errorf("no backends returned capabilities") + err := fmt.Errorf("no backends returned capabilities") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return nil, err } + span.SetAttributes( + attribute.Int("successful.backends", len(capabilities)), + ) + logger.Infof("Successfully queried %d/%d backends", len(capabilities), len(backends)) return capabilities, nil } @@ -135,6 +172,13 @@ func (a *defaultAggregator) ResolveConflicts( ctx context.Context, capabilities map[string]*BackendCapabilities, ) (*ResolvedCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.ResolveConflicts", + trace.WithAttributes( + attribute.Int("backends.count", len(capabilities)), + ), + ) + defer span.End() + logger.Debugf("Resolving conflicts across %d backends", len(capabilities)) // Group tools by backend for conflict resolution @@ -150,6 +194,8 @@ func (a *defaultAggregator) ResolveConflicts( if a.conflictResolver != nil { resolvedTools, err = a.conflictResolver.ResolveToolConflicts(ctx, toolsByBackend) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("conflict resolution failed: %w", err) } } else { @@ -191,6 +237,12 @@ func (a *defaultAggregator) ResolveConflicts( resolved.SupportsSampling = resolved.SupportsSampling || caps.SupportsSampling } + span.SetAttributes( + attribute.Int("resolved.tools", len(resolved.Tools)), + attribute.Int("resolved.resources", len(resolved.Resources)), + attribute.Int("resolved.prompts", len(resolved.Prompts)), + ) + logger.Debugf("Resolved %d unique tools, %d resources, %d prompts", len(resolved.Tools), len(resolved.Resources), len(resolved.Prompts)) @@ -199,11 +251,20 @@ func (a *defaultAggregator) ResolveConflicts( // MergeCapabilities creates the final unified capability view and routing table. // Uses the backend registry to populate full BackendTarget information for routing. -func (*defaultAggregator) MergeCapabilities( +func (a *defaultAggregator) MergeCapabilities( ctx context.Context, resolved *ResolvedCapabilities, registry vmcp.BackendRegistry, ) (*AggregatedCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.MergeCapabilities", + trace.WithAttributes( + attribute.Int("resolved.tools", len(resolved.Tools)), + attribute.Int("resolved.resources", len(resolved.Resources)), + attribute.Int("resolved.prompts", len(resolved.Prompts)), + ), + ) + defer span.End() + logger.Debugf("Merging capabilities into final view") // Create routing table @@ -304,6 +365,13 @@ func (*defaultAggregator) MergeCapabilities( }, } + span.SetAttributes( + attribute.Int("aggregated.tools", aggregated.Metadata.ToolCount), + attribute.Int("aggregated.resources", aggregated.Metadata.ResourceCount), + attribute.Int("aggregated.prompts", aggregated.Metadata.PromptCount), + attribute.String("conflict.strategy", string(aggregated.Metadata.ConflictStrategy)), + ) + logger.Infof("Merged capabilities: %d tools, %d resources, %d prompts", aggregated.Metadata.ToolCount, aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount) @@ -316,6 +384,13 @@ func (*defaultAggregator) MergeCapabilities( // 3. Resolve conflicts // 4. Merge into final view with full backend information func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends []vmcp.Backend) (*AggregatedCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.AggregateCapabilities", + trace.WithAttributes( + attribute.Int("backends.count", len(backends)), + ), + ) + defer span.End() + logger.Infof("Starting capability aggregation for %d backends", len(backends)) // Step 1: Create registry from discovered backends @@ -325,24 +400,38 @@ func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends // Step 2: Query all backends capabilities, err := a.QueryAllCapabilities(ctx, backends) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to query backends: %w", err) } // Step 3: Resolve conflicts resolved, err := a.ResolveConflicts(ctx, capabilities) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to resolve conflicts: %w", err) } // Step 4: Merge into final view with full backend information aggregated, err := a.MergeCapabilities(ctx, resolved, registry) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to merge capabilities: %w", err) } // Update metadata with backend count aggregated.Metadata.BackendCount = len(backends) + span.SetAttributes( + attribute.Int("aggregated.backends", aggregated.Metadata.BackendCount), + attribute.Int("aggregated.tools", aggregated.Metadata.ToolCount), + attribute.Int("aggregated.resources", aggregated.Metadata.ResourceCount), + attribute.Int("aggregated.prompts", aggregated.Metadata.PromptCount), + attribute.String("conflict.strategy", string(aggregated.Metadata.ConflictStrategy)), + ) + logger.Infof("Capability aggregation complete: %d backends, %d tools, %d resources, %d prompts", aggregated.Metadata.BackendCount, aggregated.Metadata.ToolCount, aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount) diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index 9e53ff994e..0634376de6 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -15,6 +15,7 @@ import ( "io" "net" "net/http" + "time" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" @@ -126,8 +127,6 @@ func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return nil, fmt.Errorf("authentication failed for backend %s: %w", a.target.WorkloadID, err) } - logger.Debugf("Applied authentication strategy %q to backend %s", a.authStrategy.Name(), a.target.WorkloadID) - return a.base.RoundTrip(reqClone) } @@ -203,8 +202,10 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm }) // Create HTTP client with configured transport chain + // Set timeouts to prevent long-lived connections that require continuous listening httpClient := &http.Client{ Transport: sizeLimitedTransport, + Timeout: 30 * time.Second, // Prevent hanging on connections } var c *client.Client @@ -213,8 +214,7 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm case "streamable-http", "streamable": c, err = client.NewStreamableHttpClient( target.BaseURL, - transport.WithHTTPTimeout(0), - transport.WithContinuousListening(), + transport.WithHTTPTimeout(30*time.Second), // Set timeout instead of 0 transport.WithHTTPBasicClient(httpClient), ) if err != nil { diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go index aa9583cce0..f477c01232 100644 --- a/pkg/vmcp/config/config.go +++ b/pkg/vmcp/config/config.go @@ -151,7 +151,7 @@ type Config struct { Audit *audit.Config `json:"audit,omitempty" yaml:"audit,omitempty"` // Optimizer configures the MCP optimizer for context optimization on large toolsets. - // When enabled, vMCP exposes only find_tool and call_tool operations to clients + // When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients // instead of all backend tools directly. This reduces token usage by allowing // LLMs to discover relevant tools on demand rather than receiving all tool definitions. // +optional @@ -696,16 +696,72 @@ type OutputProperty struct { Default thvjson.Any `json:"default,omitempty" yaml:"default,omitempty"` } -// OptimizerConfig configures the MCP optimizer. -// When enabled, vMCP exposes only find_tool and call_tool operations to clients -// instead of all backend tools directly. +// OptimizerConfig configures the MCP optimizer for semantic tool discovery. +// The optimizer reduces token usage by allowing LLMs to discover relevant tools +// on demand rather than receiving all tool definitions upfront. // +kubebuilder:object:generate=true // +gendoc type OptimizerConfig struct { - // EmbeddingService is the name of a Kubernetes Service that provides the embedding service - // for semantic tool discovery. The service must implement the optimizer embedding API. - // +kubebuilder:validation:Required - EmbeddingService string `json:"embeddingService" yaml:"embeddingService"` + // Enabled determines whether the optimizer is active. + // When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools. + // +optional + Enabled bool `json:"enabled" yaml:"enabled"` + + // EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder". + // - "ollama": Uses local Ollama HTTP API for embeddings + // - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) + // - "placeholder": Uses deterministic hash-based embeddings (for testing/development) + // +kubebuilder:validation:Enum=ollama;openai-compatible;placeholder + // +optional + EmbeddingBackend string `json:"embeddingBackend,omitempty" yaml:"embeddingBackend,omitempty"` + + // EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API). + // Required when EmbeddingBackend is "ollama" or "openai-compatible". + // Examples: + // - Ollama: "http://localhost:11434" + // - vLLM: "http://vllm-service:8000/v1" + // - OpenAI: "https://api.openai.com/v1" + // +optional + EmbeddingURL string `json:"embeddingURL,omitempty" yaml:"embeddingURL,omitempty"` + + // EmbeddingModel is the model name to use for embeddings. + // Required when EmbeddingBackend is "ollama" or "openai-compatible". + // Examples: + // - Ollama: "nomic-embed-text", "all-minilm" + // - vLLM: "BAAI/bge-small-en-v1.5" + // - OpenAI: "text-embedding-3-small" + // +optional + EmbeddingModel string `json:"embeddingModel,omitempty" yaml:"embeddingModel,omitempty"` + + // EmbeddingDimension is the dimension of the embedding vectors. + // Common values: + // - 384: all-MiniLM-L6-v2, nomic-embed-text + // - 768: BAAI/bge-small-en-v1.5 + // - 1536: OpenAI text-embedding-3-small + // +kubebuilder:validation:Minimum=1 + // +optional + EmbeddingDimension int `json:"embeddingDimension,omitempty" yaml:"embeddingDimension,omitempty"` + + // PersistPath is the optional filesystem path for persisting the chromem-go database. + // If empty, the database will be in-memory only (ephemeral). + // When set, tool metadata and embeddings are persisted to disk for faster restarts. + // +optional + PersistPath string `json:"persistPath,omitempty" yaml:"persistPath,omitempty"` + + // FTSDBPath is the path to the SQLite FTS5 database for BM25 text search. + // If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + // Hybrid search (semantic + BM25) is always enabled. + // +optional + FTSDBPath string `json:"ftsDBPath,omitempty" yaml:"ftsDBPath,omitempty"` + + // HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search. + // Value range: 0 (all BM25) to 100 (all semantic), representing percentage. + // Default: 70 (70% semantic, 30% BM25) + // Only used when FTSDBPath is set. + // +optional + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=100 + HybridSearchRatio *int `json:"hybridSearchRatio,omitempty" yaml:"hybridSearchRatio,omitempty"` } // Validator validates configuration. diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go index 0845118ee1..6dfa659512 100644 --- a/pkg/vmcp/discovery/manager.go +++ b/pkg/vmcp/discovery/manager.go @@ -18,6 +18,8 @@ import ( "sync" "time" + "golang.org/x/sync/singleflight" + "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" @@ -68,6 +70,9 @@ type DefaultManager struct { stopCh chan struct{} stopOnce sync.Once wg sync.WaitGroup + // singleFlight ensures only one aggregation happens per cache key at a time + // This prevents concurrent requests from all triggering aggregation + singleFlight singleflight.Group } // NewManager creates a new discovery manager with the given aggregator. @@ -131,6 +136,9 @@ func NewManagerWithRegistry(agg aggregator.Aggregator, registry vmcp.DynamicRegi // // The context must contain an authenticated user identity (set by auth middleware). // Returns ErrNoIdentity if user identity is not found in context. +// +// This method uses singleflight to ensure that concurrent requests for the same +// cache key only trigger one aggregation, preventing duplicate work. func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { // Validate user identity is present (set by auth middleware) // This ensures discovery happens with proper user authentication context @@ -142,7 +150,7 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) // Generate cache key from user identity and backend set cacheKey := m.generateCacheKey(identity.Subject, backends) - // Check cache first + // Check cache first (with read lock) if caps := m.getCachedCapabilities(cacheKey); caps != nil { logger.Debugf("Cache hit for user %s (key: %s)", identity.Subject, cacheKey) return caps, nil @@ -150,16 +158,33 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) logger.Debugf("Cache miss - performing capability discovery for user: %s", identity.Subject) - // Cache miss - perform aggregation - caps, err := m.aggregator.AggregateCapabilities(ctx, backends) + // Use singleflight to ensure only one aggregation happens per cache key + // Even if multiple requests come in concurrently, they'll all wait for the same result + result, err, _ := m.singleFlight.Do(cacheKey, func() (interface{}, error) { + // Double-check cache after acquiring singleflight lock + // Another goroutine might have populated it while we were waiting + if caps := m.getCachedCapabilities(cacheKey); caps != nil { + logger.Debugf("Cache populated while waiting - returning cached result for user %s", identity.Subject) + return caps, nil + } + + // Perform aggregation + caps, err := m.aggregator.AggregateCapabilities(ctx, backends) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err) + } + + // Cache the result (skips caching if at capacity and key doesn't exist) + m.cacheCapabilities(cacheKey, caps) + + return caps, nil + }) + if err != nil { - return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err) + return nil, err } - // Cache the result (skips caching if at capacity and key doesn't exist) - m.cacheCapabilities(cacheKey, caps) - - return caps, nil + return result.(*aggregator.AggregatedCapabilities), nil } // Stop gracefully stops the manager and cleans up resources. diff --git a/pkg/vmcp/discovery/manager_test_coverage.go b/pkg/vmcp/discovery/manager_test_coverage.go new file mode 100644 index 0000000000..3826fc2849 --- /dev/null +++ b/pkg/vmcp/discovery/manager_test_coverage.go @@ -0,0 +1,176 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + aggmocks "github.com/stacklok/toolhive/pkg/vmcp/aggregator/mocks" + vmcpmocks "github.com/stacklok/toolhive/pkg/vmcp/mocks" +) + +// TestDefaultManager_CacheVersionMismatch tests cache invalidation on version mismatch +func TestDefaultManager_CacheVersionMismatch(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAggregator := aggmocks.NewMockAggregator(ctrl) + mockRegistry := vmcpmocks.NewMockDynamicRegistry(ctrl) + + // First call - version 1 + mockRegistry.EXPECT().Version().Return(uint64(1)).Times(2) + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + manager, err := NewManagerWithRegistry(mockAggregator, mockRegistry) + require.NoError(t, err) + defer manager.Stop() + + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-1", + }) + + backends := []vmcp.Backend{ + {ID: "backend-1", Name: "Backend 1"}, + } + + // First discovery - should cache + caps1, err := manager.Discover(ctx, backends) + require.NoError(t, err) + require.NotNil(t, caps1) + + // Second discovery with same version - should use cache + mockRegistry.EXPECT().Version().Return(uint64(1)).Times(1) + caps2, err := manager.Discover(ctx, backends) + require.NoError(t, err) + require.NotNil(t, caps2) + + // Third discovery with different version - should invalidate cache + mockRegistry.EXPECT().Version().Return(uint64(2)).Times(2) + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + caps3, err := manager.Discover(ctx, backends) + require.NoError(t, err) + require.NotNil(t, caps3) +} + +// TestDefaultManager_CacheAtCapacity tests cache eviction when at capacity +func TestDefaultManager_CacheAtCapacity(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAggregator := aggmocks.NewMockAggregator(ctrl) + + // Create many different cache keys to fill cache + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(maxCacheSize + 1) // One more than capacity + + manager, err := NewManager(mockAggregator) + require.NoError(t, err) + defer manager.Stop() + + // Fill cache to capacity + for i := 0; i < maxCacheSize; i++ { + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-" + string(rune(i)), + }) + + backends := []vmcp.Backend{ + {ID: "backend-" + string(rune(i)), Name: "Backend"}, + } + + _, err := manager.Discover(ctx, backends) + require.NoError(t, err) + } + + // Next discovery should not cache (at capacity) + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-new", + }) + + backends := []vmcp.Backend{ + {ID: "backend-new", Name: "Backend"}, + } + + _, err = manager.Discover(ctx, backends) + require.NoError(t, err) +} + +// TestDefaultManager_CacheAtCapacity_ExistingKey tests cache update when at capacity but key exists +func TestDefaultManager_CacheAtCapacity_ExistingKey(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAggregator := aggmocks.NewMockAggregator(ctrl) + + // First call + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + manager, err := NewManager(mockAggregator) + require.NoError(t, err) + defer manager.Stop() + + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-1", + }) + + backends := []vmcp.Backend{ + {ID: "backend-1", Name: "Backend 1"}, + } + + // First discovery + _, err = manager.Discover(ctx, backends) + require.NoError(t, err) + + // Fill cache to capacity with other keys + for i := 0; i < maxCacheSize-1; i++ { + ctxOther := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-" + string(rune(i+2)), + }) + + backendsOther := []vmcp.Backend{ + {ID: "backend-" + string(rune(i+2)), Name: "Backend"}, + } + + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + _, err := manager.Discover(ctxOther, backendsOther) + require.NoError(t, err) + } + + // Update existing key should work even at capacity + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + _, err = manager.Discover(ctx, backends) + require.NoError(t, err) +} diff --git a/pkg/vmcp/discovery/middleware_test.go b/pkg/vmcp/discovery/middleware_test.go index d1b36a870c..3c8cd8e9ca 100644 --- a/pkg/vmcp/discovery/middleware_test.go +++ b/pkg/vmcp/discovery/middleware_test.go @@ -348,8 +348,19 @@ func TestMiddleware_CapabilitiesInContext(t *testing.T) { }, } + // Use Do to capture and verify backends separately, since order may vary mockMgr.EXPECT(). - Discover(gomock.Any(), unorderedBackendsMatcher{backends}). + Discover(gomock.Any(), gomock.Any()). + Do(func(_ context.Context, actualBackends []vmcp.Backend) { + // Verify that we got the expected backends regardless of order + assert.Len(t, actualBackends, 2) + backendIDs := make(map[string]bool) + for _, b := range actualBackends { + backendIDs[b.ID] = true + } + assert.True(t, backendIDs["backend1"], "backend1 should be present") + assert.True(t, backendIDs["backend2"], "backend2 should be present") + }). Return(expectedCaps, nil) // Create handler that inspects context in detail diff --git a/pkg/vmcp/health/checker.go b/pkg/vmcp/health/checker.go index d593aa0401..bf6f5c329c 100644 --- a/pkg/vmcp/health/checker.go +++ b/pkg/vmcp/health/checker.go @@ -11,6 +11,8 @@ import ( "context" "errors" "fmt" + "net/url" + "strings" "time" "github.com/stacklok/toolhive/pkg/logger" @@ -29,6 +31,10 @@ type healthChecker struct { // If a health check succeeds but takes longer than this duration, the backend is marked degraded. // Zero means disabled (backends will never be marked degraded based on response time alone). degradedThreshold time.Duration + + // selfURL is the server's own URL. If a health check targets this URL, it's short-circuited. + // This prevents the server from trying to health check itself. + selfURL string } // NewHealthChecker creates a new health checker that uses BackendClient.ListCapabilities @@ -39,13 +45,20 @@ type healthChecker struct { // - client: BackendClient for communicating with backend MCP servers // - timeout: Maximum duration for health check operations (0 = no timeout) // - degradedThreshold: Response time threshold for marking backend as degraded (0 = disabled) +// - selfURL: Optional server's own URL. If provided, health checks targeting this URL are short-circuited. // // Returns a new HealthChecker implementation. -func NewHealthChecker(client vmcp.BackendClient, timeout time.Duration, degradedThreshold time.Duration) vmcp.HealthChecker { +func NewHealthChecker( + client vmcp.BackendClient, + timeout time.Duration, + degradedThreshold time.Duration, + selfURL string, +) vmcp.HealthChecker { return &healthChecker{ client: client, timeout: timeout, degradedThreshold: degradedThreshold, + selfURL: selfURL, } } @@ -62,16 +75,28 @@ func NewHealthChecker(client vmcp.BackendClient, timeout time.Duration, degraded // The error return is informational and provides context about what failed. // The BackendHealthStatus return indicates the categorized health state. func (h *healthChecker) CheckHealth(ctx context.Context, target *vmcp.BackendTarget) (vmcp.BackendHealthStatus, error) { - // Apply timeout if configured - checkCtx := ctx + // Mark context as health check to bypass authentication logging + // Health checks verify backend availability and should not require user credentials + healthCheckCtx := WithHealthCheckMarker(ctx) + + // Apply timeout if configured (after adding health check marker) + checkCtx := healthCheckCtx var cancel context.CancelFunc if h.timeout > 0 { - checkCtx, cancel = context.WithTimeout(ctx, h.timeout) + checkCtx, cancel = context.WithTimeout(healthCheckCtx, h.timeout) defer cancel() } logger.Debugf("Performing health check for backend %s (%s)", target.WorkloadName, target.BaseURL) + // Short-circuit health check if targeting ourselves + // This prevents the server from trying to health check itself, which would work + // but is wasteful and can cause connection issues during startup + if h.selfURL != "" && h.isSelfCheck(target.BaseURL) { + logger.Debugf("Skipping health check for backend %s - this is the server itself", target.WorkloadName) + return vmcp.BackendHealthy, nil + } + // Track response time for degraded detection startTime := time.Now() @@ -137,3 +162,62 @@ func categorizeError(err error) vmcp.BackendHealthStatus { // Default to unhealthy for unknown errors return vmcp.BackendUnhealthy } + +// isSelfCheck checks if a backend URL matches the server's own URL. +// URLs are normalized before comparison to handle variations like: +// - http://127.0.0.1:PORT vs http://localhost:PORT +// - http://HOST:PORT vs http://HOST:PORT/ +func (h *healthChecker) isSelfCheck(backendURL string) bool { + if h.selfURL == "" || backendURL == "" { + return false + } + + // Normalize both URLs for comparison + backendNormalized, err := NormalizeURLForComparison(backendURL) + if err != nil { + return false + } + + selfNormalized, err := NormalizeURLForComparison(h.selfURL) + if err != nil { + return false + } + + return backendNormalized == selfNormalized +} + +// NormalizeURLForComparison normalizes a URL for comparison by: +// - Parsing and reconstructing the URL +// - Converting localhost/127.0.0.1 to a canonical form +// - Comparing only scheme://host:port (ignoring path, query, fragment) +// - Lowercasing scheme and host +// Exported for testing purposes +func NormalizeURLForComparison(rawURL string) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", err + } + // Validate that we have a scheme and host (basic URL validation) + if u.Scheme == "" || u.Host == "" { + return "", fmt.Errorf("invalid URL: missing scheme or host") + } + + // Normalize host: convert localhost to 127.0.0.1 for consistency + host := strings.ToLower(u.Hostname()) + if host == "localhost" { + host = "127.0.0.1" + } + + // Reconstruct URL with normalized components (scheme://host:port only) + // We ignore path, query, and fragment for comparison + normalized := &url.URL{ + Scheme: strings.ToLower(u.Scheme), + } + if u.Port() != "" { + normalized.Host = host + ":" + u.Port() + } else { + normalized.Host = host + } + + return normalized.String(), nil +} diff --git a/pkg/vmcp/health/checker_selfcheck_test.go b/pkg/vmcp/health/checker_selfcheck_test.go new file mode 100644 index 0000000000..ff963d8d35 --- /dev/null +++ b/pkg/vmcp/health/checker_selfcheck_test.go @@ -0,0 +1,504 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package health + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/mocks" +) + +// TestHealthChecker_CheckHealth_SelfCheck tests self-check detection +func TestHealthChecker_CheckHealth_SelfCheck(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + // Should not call ListCapabilities for self-check + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Times(0) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080", // Same as selfURL + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_Localhost tests localhost normalization +func TestHealthChecker_CheckHealth_SelfCheck_Localhost(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Times(0) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://localhost:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080", // localhost should match 127.0.0.1 + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_Reverse tests reverse localhost normalization +func TestHealthChecker_CheckHealth_SelfCheck_Reverse(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Times(0) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://localhost:8080", // 127.0.0.1 should match localhost + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_DifferentPort tests different ports don't match +func TestHealthChecker_CheckHealth_SelfCheck_DifferentPort(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + Times(1) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8081", // Different port + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_EmptyURL tests empty URLs +func TestHealthChecker_CheckHealth_SelfCheck_EmptyURL(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + Times(1) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_InvalidURL tests invalid URLs +func TestHealthChecker_CheckHealth_SelfCheck_InvalidURL(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + Times(1) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "not-a-valid-url") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_WithPath tests URLs with paths are normalized +func TestHealthChecker_CheckHealth_SelfCheck_WithPath(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Times(0) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080/mcp", // Path should be ignored + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_DegradedThreshold tests degraded threshold detection +func TestHealthChecker_CheckHealth_DegradedThreshold(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + // Simulate slow response + time.Sleep(150 * time.Millisecond) + return &vmcp.CapabilityList{}, nil + }). + Times(1) + + // Set degraded threshold to 100ms + checker := NewHealthChecker(mockClient, 5*time.Second, 100*time.Millisecond, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://localhost:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendDegraded, status, "Should mark as degraded when response time exceeds threshold") +} + +// TestHealthChecker_CheckHealth_DegradedThreshold_Disabled tests disabled degraded threshold +func TestHealthChecker_CheckHealth_DegradedThreshold_Disabled(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + // Simulate slow response + time.Sleep(150 * time.Millisecond) + return &vmcp.CapabilityList{}, nil + }). + Times(1) + + // Set degraded threshold to 0 (disabled) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://localhost:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status, "Should not mark as degraded when threshold is disabled") +} + +// TestHealthChecker_CheckHealth_DegradedThreshold_FastResponse tests fast response doesn't trigger degraded +func TestHealthChecker_CheckHealth_DegradedThreshold_FastResponse(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + Times(1) + + // Set degraded threshold to 100ms + checker := NewHealthChecker(mockClient, 5*time.Second, 100*time.Millisecond, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://localhost:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status, "Should not mark as degraded when response is fast") +} + +// TestCategorizeError_SentinelErrors tests sentinel error categorization +func TestCategorizeError_SentinelErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + expectedStatus vmcp.BackendHealthStatus + }{ + { + name: "ErrAuthenticationFailed", + err: vmcp.ErrAuthenticationFailed, + expectedStatus: vmcp.BackendUnauthenticated, + }, + { + name: "ErrAuthorizationFailed", + err: vmcp.ErrAuthorizationFailed, + expectedStatus: vmcp.BackendUnauthenticated, + }, + { + name: "ErrTimeout", + err: vmcp.ErrTimeout, + expectedStatus: vmcp.BackendUnhealthy, + }, + { + name: "ErrCancelled", + err: vmcp.ErrCancelled, + expectedStatus: vmcp.BackendUnhealthy, + }, + { + name: "ErrBackendUnavailable", + err: vmcp.ErrBackendUnavailable, + expectedStatus: vmcp.BackendUnhealthy, + }, + { + name: "wrapped ErrAuthenticationFailed", + err: errors.New("wrapped: " + vmcp.ErrAuthenticationFailed.Error()), + expectedStatus: vmcp.BackendUnauthenticated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + status := categorizeError(tt.err) + assert.Equal(t, tt.expectedStatus, status) + }) + } +} + +// TestNormalizeURLForComparison tests URL normalization +func TestNormalizeURLForComparison(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + wantErr bool + }{ + { + name: "localhost normalized to 127.0.0.1", + input: "http://localhost:8080", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "127.0.0.1 stays as is", + input: "http://127.0.0.1:8080", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "path is ignored", + input: "http://127.0.0.1:8080/mcp", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "query is ignored", + input: "http://127.0.0.1:8080?param=value", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "fragment is ignored", + input: "http://127.0.0.1:8080#fragment", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "scheme is lowercased", + input: "HTTP://127.0.0.1:8080", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "host is lowercased", + input: "http://EXAMPLE.COM:8080", + expected: "http://example.com:8080", + wantErr: false, + }, + { + name: "no port", + input: "http://127.0.0.1", + expected: "http://127.0.0.1", + wantErr: false, + }, + { + name: "invalid URL", + input: "not-a-url", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result, err := NormalizeURLForComparison(tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +// TestIsSelfCheck_EdgeCases tests edge cases for self-check detection +func TestIsSelfCheck_EdgeCases(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(func() { ctrl.Finish() }) + + mockClient := mocks.NewMockBackendClient(ctrl) + + tests := []struct { + name string + selfURL string + backendURL string + expected bool + }{ + { + name: "both empty", + selfURL: "", + backendURL: "", + expected: false, + }, + { + name: "selfURL empty", + selfURL: "", + backendURL: "http://127.0.0.1:8080", + expected: false, + }, + { + name: "backendURL empty", + selfURL: "http://127.0.0.1:8080", + backendURL: "", + expected: false, + }, + { + name: "localhost matches 127.0.0.1", + selfURL: "http://localhost:8080", + backendURL: "http://127.0.0.1:8080", + expected: true, + }, + { + name: "127.0.0.1 matches localhost", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://localhost:8080", + expected: true, + }, + { + name: "different ports", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://127.0.0.1:8081", + expected: false, + }, + { + name: "different hosts", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://192.168.1.1:8080", + expected: false, + }, + { + name: "path ignored", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://127.0.0.1:8080/mcp", + expected: true, + }, + { + name: "query ignored", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://127.0.0.1:8080?param=value", + expected: true, + }, + { + name: "invalid selfURL", + selfURL: "not-a-url", + backendURL: "http://127.0.0.1:8080", + expected: false, + }, + { + name: "invalid backendURL", + selfURL: "http://127.0.0.1:8080", + backendURL: "not-a-url", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, tt.selfURL) + hc, ok := checker.(*healthChecker) + require.True(t, ok) + + result := hc.isSelfCheck(tt.backendURL) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/vmcp/health/checker_test.go b/pkg/vmcp/health/checker_test.go index 39f7258d82..63c3c986b6 100644 --- a/pkg/vmcp/health/checker_test.go +++ b/pkg/vmcp/health/checker_test.go @@ -44,7 +44,7 @@ func TestNewHealthChecker(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - checker := NewHealthChecker(mockClient, tt.timeout, 0) + checker := NewHealthChecker(mockClient, tt.timeout, 0, "") require.NotNil(t, checker) // Type assert to access internals for verification @@ -68,7 +68,7 @@ func TestHealthChecker_CheckHealth_Success(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). Times(1) - checker := NewHealthChecker(mockClient, 5*time.Second, 0) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -95,7 +95,7 @@ func TestHealthChecker_CheckHealth_ContextCancellation(t *testing.T) { }). Times(1) - checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0) + checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -123,7 +123,7 @@ func TestHealthChecker_CheckHealth_NoTimeout(t *testing.T) { Times(1) // Create checker with no timeout - checker := NewHealthChecker(mockClient, 0, 0) + checker := NewHealthChecker(mockClient, 0, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -213,7 +213,7 @@ func TestHealthChecker_CheckHealth_ErrorCategorization(t *testing.T) { Return(nil, tt.err). Times(1) - checker := NewHealthChecker(mockClient, 5*time.Second, 0) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -430,7 +430,7 @@ func TestHealthChecker_CheckHealth_Timeout(t *testing.T) { }). Times(1) - checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0) + checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -467,7 +467,7 @@ func TestHealthChecker_CheckHealth_MultipleBackends(t *testing.T) { }). Times(4) - checker := NewHealthChecker(mockClient, 5*time.Second, 0) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") // Test healthy backend status, err := checker.CheckHealth(context.Background(), &vmcp.BackendTarget{ diff --git a/pkg/vmcp/health/monitor.go b/pkg/vmcp/health/monitor.go index 50f00d788d..62aea9b735 100644 --- a/pkg/vmcp/health/monitor.go +++ b/pkg/vmcp/health/monitor.go @@ -108,12 +108,14 @@ func DefaultConfig() MonitorConfig { // - client: BackendClient for communicating with backend MCP servers // - backends: List of backends to monitor // - config: Configuration for health monitoring +// - selfURL: Optional server's own URL. If provided, health checks targeting this URL are short-circuited. // // Returns (monitor, error). Error is returned if configuration is invalid. func NewMonitor( client vmcp.BackendClient, backends []vmcp.Backend, config MonitorConfig, + selfURL string, ) (*Monitor, error) { // Validate configuration if config.CheckInterval <= 0 { @@ -123,8 +125,8 @@ func NewMonitor( return nil, fmt.Errorf("unhealthy threshold must be >= 1, got %d", config.UnhealthyThreshold) } - // Create health checker with degraded threshold - checker := NewHealthChecker(client, config.Timeout, config.DegradedThreshold) + // Create health checker with degraded threshold and self URL + checker := NewHealthChecker(client, config.Timeout, config.DegradedThreshold, selfURL) // Create status tracker statusTracker := newStatusTracker(config.UnhealthyThreshold) diff --git a/pkg/vmcp/health/monitor_test.go b/pkg/vmcp/health/monitor_test.go index bb177017e7..8d2de11bdd 100644 --- a/pkg/vmcp/health/monitor_test.go +++ b/pkg/vmcp/health/monitor_test.go @@ -66,7 +66,7 @@ func TestNewMonitor_Validation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - monitor, err := NewMonitor(mockClient, backends, tt.config) + monitor, err := NewMonitor(mockClient, backends, tt.config, "") if tt.expectError { assert.Error(t, err) assert.Nil(t, monitor) @@ -101,7 +101,7 @@ func TestMonitor_StartStop(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) // Start monitor @@ -178,7 +178,7 @@ func TestMonitor_StartErrors(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) err = tt.setupFunc(monitor) @@ -208,7 +208,7 @@ func TestMonitor_StopWithoutStart(t *testing.T) { Timeout: 50 * time.Millisecond, } - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) // Try to stop without starting @@ -239,7 +239,7 @@ func TestMonitor_PeriodicHealthChecks(t *testing.T) { Return(nil, errors.New("backend unavailable")). MinTimes(2) - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -289,7 +289,7 @@ func TestMonitor_GetHealthSummary(t *testing.T) { }). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -333,7 +333,7 @@ func TestMonitor_GetBackendStatus(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -382,7 +382,7 @@ func TestMonitor_GetBackendState(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -433,7 +433,7 @@ func TestMonitor_GetAllBackendStates(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -477,7 +477,7 @@ func TestMonitor_ContextCancellation(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) // Start with cancellable context diff --git a/pkg/vmcp/optimizer/config.go b/pkg/vmcp/optimizer/config.go new file mode 100644 index 0000000000..62aef2669c --- /dev/null +++ b/pkg/vmcp/optimizer/config.go @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +// ConfigFromVMCPConfig converts a vmcp/config.OptimizerConfig to optimizer.Config. +// This helper function bridges the gap between the shared config package and +// the optimizer package's internal configuration structure. +func ConfigFromVMCPConfig(cfg *config.OptimizerConfig) *Config { + if cfg == nil { + return nil + } + + optimizerCfg := &Config{ + Enabled: cfg.Enabled, + PersistPath: cfg.PersistPath, + FTSDBPath: cfg.FTSDBPath, + HybridSearchRatio: 70, // Default + } + + // Handle HybridSearchRatio (pointer in config, value in optimizer.Config) + if cfg.HybridSearchRatio != nil { + optimizerCfg.HybridSearchRatio = *cfg.HybridSearchRatio + } + + // Convert embedding config + if cfg.EmbeddingBackend != "" || cfg.EmbeddingURL != "" || cfg.EmbeddingModel != "" || cfg.EmbeddingDimension > 0 { + optimizerCfg.EmbeddingConfig = &embeddings.Config{ + BackendType: cfg.EmbeddingBackend, + BaseURL: cfg.EmbeddingURL, + Model: cfg.EmbeddingModel, + Dimension: cfg.EmbeddingDimension, + } + } + + return optimizerCfg +} diff --git a/pkg/vmcp/optimizer/dummy_optimizer.go b/pkg/vmcp/optimizer/dummy_optimizer.go deleted file mode 100644 index 00c9be9eae..0000000000 --- a/pkg/vmcp/optimizer/dummy_optimizer.go +++ /dev/null @@ -1,119 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package optimizer - -import ( - "context" - "encoding/json" - "fmt" - "strings" - - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" -) - -// DummyOptimizer implements the Optimizer interface using exact string matching. -// -// This implementation is intended for testing and development. It performs -// case-insensitive substring matching on tool names and descriptions. -// -// For production use, see the EmbeddingOptimizer which uses semantic similarity. -type DummyOptimizer struct { - // tools contains all available tools indexed by name. - tools map[string]server.ServerTool -} - -// NewDummyOptimizer creates a new DummyOptimizer with the given tools. -// -// The tools slice should contain all backend tools (as ServerTool with handlers). -func NewDummyOptimizer(tools []server.ServerTool) Optimizer { - toolMap := make(map[string]server.ServerTool, len(tools)) - for _, tool := range tools { - toolMap[tool.Tool.Name] = tool - } - - return DummyOptimizer{ - tools: toolMap, - } -} - -// FindTool searches for tools using exact substring matching. -// -// The search is case-insensitive and matches against: -// - Tool name (substring match) -// - Tool description (substring match) -// -// Returns all matching tools with a score of 1.0 (exact match semantics). -// TokenMetrics are returned as zero values (not implemented in dummy). -func (d DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindToolOutput, error) { - if input.ToolDescription == "" { - return nil, fmt.Errorf("tool_description is required") - } - - searchTerm := strings.ToLower(input.ToolDescription) - - var matches []ToolMatch - for _, tool := range d.tools { - nameLower := strings.ToLower(tool.Tool.Name) - descLower := strings.ToLower(tool.Tool.Description) - - // Check if search term matches name or description - if strings.Contains(nameLower, searchTerm) || strings.Contains(descLower, searchTerm) { - schema, err := getToolSchema(tool.Tool) - if err != nil { - return nil, err - } - matches = append(matches, ToolMatch{ - Name: tool.Tool.Name, - Description: tool.Tool.Description, - InputSchema: schema, - Score: 1.0, // Exact match semantics - }) - } - } - - return &FindToolOutput{ - Tools: matches, - TokenMetrics: TokenMetrics{}, // Zero values for dummy - }, nil -} - -// CallTool invokes a tool by name using its registered handler. -// -// The tool is looked up by exact name match. If found, the handler -// is invoked directly with the given parameters. -func (d DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) { - if input.ToolName == "" { - return nil, fmt.Errorf("tool_name is required") - } - - // Verify the tool exists - tool, exists := d.tools[input.ToolName] - if !exists { - return mcp.NewToolResultError(fmt.Sprintf("tool not found: %s", input.ToolName)), nil - } - - // Build the MCP request - request := mcp.CallToolRequest{} - request.Params.Name = input.ToolName - request.Params.Arguments = input.Parameters - - // Call the tool handler directly - return tool.Handler(ctx, request) -} - -// getToolSchema returns the input schema for a tool. -// Prefers RawInputSchema if set, otherwise marshals InputSchema. -func getToolSchema(tool mcp.Tool) (json.RawMessage, error) { - if len(tool.RawInputSchema) > 0 { - return tool.RawInputSchema, nil - } - - // Fall back to InputSchema - data, err := json.Marshal(tool.InputSchema) - if err != nil { - return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err) - } - return data, nil -} diff --git a/pkg/vmcp/optimizer/dummy_optimizer_test.go b/pkg/vmcp/optimizer/dummy_optimizer_test.go deleted file mode 100644 index 2113a5a4c1..0000000000 --- a/pkg/vmcp/optimizer/dummy_optimizer_test.go +++ /dev/null @@ -1,191 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package optimizer - -import ( - "context" - "testing" - - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" - "github.com/stretchr/testify/require" -) - -func TestDummyOptimizer_FindTool(t *testing.T) { - t.Parallel() - - tools := []server.ServerTool{ - { - Tool: mcp.Tool{ - Name: "fetch_url", - Description: "Fetch content from a URL", - }, - }, - { - Tool: mcp.Tool{ - Name: "read_file", - Description: "Read a file from the filesystem", - }, - }, - { - Tool: mcp.Tool{ - Name: "write_file", - Description: "Write content to a file", - }, - }, - } - - opt := NewDummyOptimizer(tools) - - tests := []struct { - name string - input FindToolInput - expectedNames []string - expectedError bool - errorContains string - }{ - { - name: "find by exact name", - input: FindToolInput{ - ToolDescription: "fetch_url", - }, - expectedNames: []string{"fetch_url"}, - }, - { - name: "find by description substring", - input: FindToolInput{ - ToolDescription: "file", - }, - expectedNames: []string{"read_file", "write_file"}, - }, - { - name: "case insensitive search", - input: FindToolInput{ - ToolDescription: "FETCH", - }, - expectedNames: []string{"fetch_url"}, - }, - { - name: "no matches", - input: FindToolInput{ - ToolDescription: "nonexistent", - }, - expectedNames: []string{}, - }, - { - name: "empty description", - input: FindToolInput{}, - expectedError: true, - errorContains: "tool_description is required", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - result, err := opt.FindTool(context.Background(), tc.input) - - if tc.expectedError { - require.Error(t, err) - require.Contains(t, err.Error(), tc.errorContains) - return - } - - require.NoError(t, err) - require.NotNil(t, result) - - // Extract names from results - var names []string - for _, match := range result.Tools { - names = append(names, match.Name) - } - - require.ElementsMatch(t, tc.expectedNames, names) - }) - } -} - -func TestDummyOptimizer_CallTool(t *testing.T) { - t.Parallel() - - tools := []server.ServerTool{ - { - Tool: mcp.Tool{ - Name: "test_tool", - Description: "A test tool", - }, - Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - args, _ := req.Params.Arguments.(map[string]any) - input := args["input"].(string) - return mcp.NewToolResultText("Hello, " + input + "!"), nil - }, - }, - } - - opt := NewDummyOptimizer(tools) - - tests := []struct { - name string - input CallToolInput - expectedText string - expectedError bool - isToolError bool - errorContains string - }{ - { - name: "successful tool call", - input: CallToolInput{ - ToolName: "test_tool", - Parameters: map[string]any{"input": "World"}, - }, - expectedText: "Hello, World!", - }, - { - name: "tool not found", - input: CallToolInput{ - ToolName: "nonexistent", - Parameters: map[string]any{}, - }, - isToolError: true, - expectedText: "tool not found: nonexistent", - }, - { - name: "empty tool name", - input: CallToolInput{ - Parameters: map[string]any{}, - }, - expectedError: true, - errorContains: "tool_name is required", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - result, err := opt.CallTool(context.Background(), tc.input) - - if tc.expectedError { - require.Error(t, err) - require.Contains(t, err.Error(), tc.errorContains) - return - } - - require.NoError(t, err) - require.NotNil(t, result) - - if tc.isToolError { - require.True(t, result.IsError) - } - - if tc.expectedText != "" { - require.Len(t, result.Content, 1) - textContent, ok := result.Content[0].(mcp.TextContent) - require.True(t, ok) - require.Equal(t, tc.expectedText, textContent.Text) - } - }) - } -} diff --git a/pkg/vmcp/optimizer/find_tool_semantic_search_test.go b/pkg/vmcp/optimizer/find_tool_semantic_search_test.go new file mode 100644 index 0000000000..3868bfd54d --- /dev/null +++ b/pkg/vmcp/optimizer/find_tool_semantic_search_test.go @@ -0,0 +1,693 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +const ( + testBackendOllama = "ollama" + testBackendOpenAI = "openai" +) + +// verifyEmbeddingBackendWorking verifies that the embedding backend is actually working by attempting to generate an embedding +// This ensures the service is not just reachable but actually functional +func verifyEmbeddingBackendWorking(t *testing.T, manager *embeddings.Manager, backendType string) { + t.Helper() + _, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + if backendType == testBackendOllama { + t.Skipf("Skipping test: Ollama is reachable but embedding generation failed. Error: %v. Ensure 'ollama pull %s' has been executed", err, embeddings.DefaultModelAllMiniLM) + } else { + t.Skipf("Skipping test: Embedding backend is reachable but embedding generation failed. Error: %v", err) + } + } +} + +// TestFindTool_SemanticSearch tests semantic search capabilities +// These tests verify that find_tool can find tools based on semantic meaning, +// not just exact keyword matches +func TestFindTool_SemanticSearch(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available, otherwise skip test + embeddingBackend := testBackendOllama + embeddingConfig := &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, // all-MiniLM-L6-v2 dimension + } + + // Test if Ollama is available + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + // Try OpenAI-compatible (might be vLLM or Ollama v1 API) + embeddingConfig.BackendType = testBackendOpenAI + embeddingConfig.BaseURL = "http://localhost:11434" + embeddingConfig.Model = embeddings.DefaultModelAllMiniLM + embeddingConfig.Dimension = 768 + embeddingManager, err = embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping semantic search test: No embedding backend available (Ollama or OpenAI-compatible). Error: %v", err) + return + } + embeddingBackend = testBackendOpenAI + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify embedding backend is actually working, not just reachable + verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend) + + // Setup optimizer integration with high semantic ratio to favor semantic search + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: embeddingConfig.BaseURL, + Model: embeddingConfig.Model, + Dimension: embeddingConfig.Dimension, + }, + HybridSearchRatio: 90, // 90% semantic, 10% BM25 to test semantic search + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + // Create tools with diverse descriptions to test semantic understanding + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_pull_requests", + Description: "List pull requests in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_pull_request", + Description: "Create a new pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_merge_pull_request", + Description: "Merge a pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_issue_read", + Description: "Get information about a specific issue in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_issues", + Description: "List issues in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + { + Name: "github_get_commit", + Description: "Get details for a commit from a GitHub repository", + BackendID: "github", + }, + { + Name: "github_get_branch", + Description: "Get information about a branch in a GitHub repository", + BackendID: "github", + }, + { + Name: "fetch_fetch", + Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.", + BackendID: "fetch", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test cases for semantic search - queries that mean the same thing but use different words + testCases := []struct { + name string + query string + keywords string + expectedTools []string // Tools that should be found semantically + description string + }{ + { + name: "semantic_pr_synonyms", + query: "view code review request", + keywords: "", + expectedTools: []string{"github_pull_request_read", "github_list_pull_requests"}, + description: "Should find PR tools using semantic synonyms (code review = pull request)", + }, + { + name: "semantic_merge_synonyms", + query: "combine code changes", + keywords: "", + expectedTools: []string{"github_merge_pull_request"}, + description: "Should find merge tool using semantic meaning (combine = merge)", + }, + { + name: "semantic_create_synonyms", + query: "make a new code review", + keywords: "", + expectedTools: []string{"github_create_pull_request", "github_list_pull_requests", "github_pull_request_read"}, + description: "Should find PR-related tools using semantic meaning (make = create, code review = PR)", + }, + { + name: "semantic_issue_synonyms", + query: "show bug reports", + keywords: "", + expectedTools: []string{"github_issue_read", "github_list_issues"}, + description: "Should find issue tools using semantic synonyms (bug report = issue)", + }, + { + name: "semantic_repository_synonyms", + query: "start a new project", + keywords: "", + expectedTools: []string{"github_create_repository"}, + description: "Should find repository tool using semantic meaning (project = repository)", + }, + { + name: "semantic_commit_synonyms", + query: "get change details", + keywords: "", + expectedTools: []string{"github_get_commit"}, + description: "Should find commit tool using semantic meaning (change = commit)", + }, + { + name: "semantic_fetch_synonyms", + query: "download web page content", + keywords: "", + expectedTools: []string{"fetch_fetch"}, + description: "Should find fetch tool using semantic synonyms (download = fetch)", + }, + { + name: "semantic_branch_synonyms", + query: "get branch information", + keywords: "", + expectedTools: []string{"github_get_branch"}, + description: "Should find branch tool using semantic meaning", + }, + { + name: "semantic_related_concepts", + query: "code collaboration features", + keywords: "", + expectedTools: []string{"github_pull_request_read", "github_create_pull_request", "github_issue_read"}, + description: "Should find collaboration-related tools (PRs and issues are collaboration features)", + }, + { + name: "semantic_intent_based", + query: "I want to see what code changes were made", + keywords: "", + expectedTools: []string{"github_get_commit", "github_pull_request_read"}, + description: "Should find tools based on user intent (seeing code changes = commits/PRs)", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": tc.query, + "tool_keywords": tc.keywords, + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError, "Tool call should not return error for query: %s", tc.query) + + // Parse the result + require.NotEmpty(t, result.Content, "Result should have content") + textContent, okText := mcp.AsTextContent(result.Content[0]) + require.True(t, okText, "Result should be text content") + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err, "Result should be valid JSON") + + toolsArray, okArray := response["tools"].([]interface{}) + require.True(t, okArray, "Response should have tools array") + require.NotEmpty(t, toolsArray, "Should return at least one result for semantic query: %s", tc.query) + + // Extract tool names from results + foundTools := make([]string, 0, len(toolsArray)) + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap, "Tool should be a map") + toolName, okName := toolMap["name"].(string) + require.True(t, okName, "Tool should have name") + foundTools = append(foundTools, toolName) + + // Verify similarity score exists and is reasonable + similarity, okScore := toolMap["similarity_score"].(float64) + require.True(t, okScore, "Tool should have similarity_score") + assert.Greater(t, similarity, 0.0, "Similarity score should be positive") + } + + // Check that at least one expected tool is found + foundCount := 0 + for _, expectedTool := range tc.expectedTools { + for _, foundTool := range foundTools { + if foundTool == expectedTool { + foundCount++ + break + } + } + } + + assert.GreaterOrEqual(t, foundCount, 1, + "Semantic query '%s' should find at least one expected tool from %v. Found tools: %v (found %d/%d)", + tc.query, tc.expectedTools, foundTools, foundCount, len(tc.expectedTools)) + + // Log results for debugging + if foundCount < len(tc.expectedTools) { + t.Logf("Semantic query '%s': Found %d/%d expected tools. Found: %v, Expected: %v", + tc.query, foundCount, len(tc.expectedTools), foundTools, tc.expectedTools) + } + + // Verify token metrics exist + tokenMetrics, okMetrics := response["token_metrics"].(map[string]interface{}) + require.True(t, okMetrics, "Response should have token_metrics") + assert.Contains(t, tokenMetrics, "baseline_tokens") + assert.Contains(t, tokenMetrics, "returned_tokens") + }) + } +} + +// TestFindTool_SemanticVsKeyword tests that semantic search finds different results than keyword search +func TestFindTool_SemanticVsKeyword(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available + embeddingBackend := "ollama" + embeddingConfig := &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + // Try OpenAI-compatible + embeddingConfig.BackendType = testBackendOpenAI + embeddingManager, err = embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: No embedding backend available. Error: %v", err) + return + } + embeddingBackend = testBackendOpenAI + } + + // Verify embedding backend is actually working, not just reachable + verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend) + _ = embeddingManager.Close() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Test with high semantic ratio + configSemantic := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db-semantic"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: embeddingConfig.BaseURL, + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 90, // 90% semantic + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integrationSemantic, err := NewIntegration(ctx, configSemantic, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integrationSemantic.Close() }() + + // Test with low semantic ratio (high BM25) + configKeyword := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db-keyword"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: embeddingConfig.BaseURL, + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 10, // 10% semantic, 90% BM25 + } + + integrationKeyword, err := NewIntegration(ctx, configKeyword, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integrationKeyword.Close() }() + + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Register both integrations + err = integrationSemantic.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + err = integrationKeyword.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integrationSemantic.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + err = integrationKeyword.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + // Query that has semantic meaning but no exact keyword match + query := "view code review" + + // Test semantic search + requestSemantic := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": "", + "limit": 10, + }, + }, + } + + handlerSemantic := integrationSemantic.CreateFindToolHandler() + resultSemantic, err := handlerSemantic(ctxWithCaps, requestSemantic) + require.NoError(t, err) + require.False(t, resultSemantic.IsError) + + // Test keyword search + requestKeyword := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": "", + "limit": 10, + }, + }, + } + + handlerKeyword := integrationKeyword.CreateFindToolHandler() + resultKeyword, err := handlerKeyword(ctxWithCaps, requestKeyword) + require.NoError(t, err) + require.False(t, resultKeyword.IsError) + + // Parse both results + textSemantic, _ := mcp.AsTextContent(resultSemantic.Content[0]) + var responseSemantic map[string]any + json.Unmarshal([]byte(textSemantic.Text), &responseSemantic) + + textKeyword, _ := mcp.AsTextContent(resultKeyword.Content[0]) + var responseKeyword map[string]any + json.Unmarshal([]byte(textKeyword.Text), &responseKeyword) + + toolsSemantic, _ := responseSemantic["tools"].([]interface{}) + toolsKeyword, _ := responseKeyword["tools"].([]interface{}) + + // Both should find results (semantic should find PR tools, keyword might not) + assert.NotEmpty(t, toolsSemantic, "Semantic search should find results") + assert.NotEmpty(t, toolsKeyword, "Keyword search should find results") + + // Semantic search should find pull request tools even without exact keyword match + foundPRSemantic := false + for _, toolInterface := range toolsSemantic { + toolMap, _ := toolInterface.(map[string]interface{}) + toolName, _ := toolMap["name"].(string) + if toolName == "github_pull_request_read" { + foundPRSemantic = true + break + } + } + + t.Logf("Semantic search (90%% semantic): Found %d tools", len(toolsSemantic)) + t.Logf("Keyword search (10%% semantic): Found %d tools", len(toolsKeyword)) + t.Logf("Semantic search found PR tool: %v", foundPRSemantic) + + // Semantic search should be able to find semantically related tools + // even when keywords don't match exactly + assert.True(t, foundPRSemantic, + "Semantic search should find 'github_pull_request_read' for query 'view code review' even without exact keyword match") +} + +// TestFindTool_SemanticSimilarityScores tests that similarity scores are meaningful +func TestFindTool_SemanticSimilarityScores(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available + embeddingBackend := "ollama" + embeddingConfig := &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + // Try OpenAI-compatible + embeddingConfig.BackendType = testBackendOpenAI + embeddingManager, err = embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: No embedding backend available. Error: %v", err) + return + } + embeddingBackend = testBackendOpenAI + } + + // Verify embedding backend is actually working, not just reachable + verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend) + _ = embeddingManager.Close() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: embeddingConfig.BaseURL, + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 90, // High semantic ratio + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + { + Name: "fetch_fetch", + Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.", + BackendID: "fetch", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Query for pull request + query := "view pull request" + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": "", + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.False(t, result.IsError) + + textContent, _ := mcp.AsTextContent(result.Content[0]) + var response map[string]any + json.Unmarshal([]byte(textContent.Text), &response) + + toolsArray, _ := response["tools"].([]interface{}) + require.NotEmpty(t, toolsArray) + + // Check that results are sorted by similarity (highest first) + var similarities []float64 + for _, toolInterface := range toolsArray { + toolMap, _ := toolInterface.(map[string]interface{}) + similarity, _ := toolMap["similarity_score"].(float64) + similarities = append(similarities, similarity) + } + + // Verify results are sorted by similarity (descending) + for i := 1; i < len(similarities); i++ { + assert.GreaterOrEqual(t, similarities[i-1], similarities[i], + "Results should be sorted by similarity score (descending). Scores: %v", similarities) + } + + // The most relevant tool (pull request) should have a higher similarity than unrelated tools + if len(similarities) > 1 { + // First result should have highest similarity + assert.Greater(t, similarities[0], 0.0, "Top result should have positive similarity") + } +} diff --git a/pkg/vmcp/optimizer/find_tool_string_matching_test.go b/pkg/vmcp/optimizer/find_tool_string_matching_test.go new file mode 100644 index 0000000000..6166de6164 --- /dev/null +++ b/pkg/vmcp/optimizer/find_tool_string_matching_test.go @@ -0,0 +1,699 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "context" + "encoding/json" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +// verifyOllamaWorking verifies that Ollama is actually working by attempting to generate an embedding +// This ensures the service is not just reachable but actually functional +func verifyOllamaWorking(t *testing.T, manager *embeddings.Manager) { + t.Helper() + _, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + t.Skipf("Skipping test: Ollama is reachable but embedding generation failed. Error: %v. Ensure 'ollama pull %s' has been executed", err, embeddings.DefaultModelAllMiniLM) + } +} + +// getRealToolData returns test data based on actual MCP server tools +// These are real tool descriptions from GitHub and other MCP servers +func getRealToolData() []vmcp.Tool { + return []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_pull_requests", + Description: "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead.", + BackendID: "github", + }, + { + Name: "github_search_pull_requests", + Description: "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr", + BackendID: "github", + }, + { + Name: "github_create_pull_request", + Description: "Create a new pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_merge_pull_request", + Description: "Merge a pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_pull_request_review_write", + Description: "Create and/or submit, delete review of a pull request.", + BackendID: "github", + }, + { + Name: "github_issue_read", + Description: "Get information about a specific issue in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_issues", + Description: "List issues in a GitHub repository. For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + { + Name: "github_get_commit", + Description: "Get details for a commit from a GitHub repository", + BackendID: "github", + }, + { + Name: "fetch_fetch", + Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.", + BackendID: "fetch", + }, + } +} + +// TestFindTool_StringMatching tests that find_tool can match strings correctly +func TestFindTool_StringMatching(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Setup optimizer integration + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify Ollama is actually working, not just reachable + verifyOllamaWorking(t, embeddingManager) + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 50, // 50% semantic, 50% BM25 for better string matching + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + // Get real tool data + tools := getRealToolData() + + // Create capabilities with real tools + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + // Build routing table + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + // Register session and generate embeddings + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + // Create context with capabilities + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test cases: query -> expected tool names that should be found + testCases := []struct { + name string + query string + keywords string + expectedTools []string // Tools that should definitely be in results + minResults int // Minimum number of results expected + description string + }{ + { + name: "exact_pull_request_match", + query: "pull request", + keywords: "pull request", + expectedTools: []string{"github_pull_request_read", "github_list_pull_requests", "github_create_pull_request"}, + minResults: 3, + description: "Should find tools with exact 'pull request' string match", + }, + { + name: "pull_request_in_name", + query: "pull request", + keywords: "pull_request", + expectedTools: []string{"github_pull_request_read", "github_list_pull_requests"}, + minResults: 2, + description: "Should match tools with 'pull_request' in name", + }, + { + name: "list_pull_requests", + query: "list pull requests", + keywords: "list pull requests", + expectedTools: []string{"github_list_pull_requests"}, + minResults: 1, + description: "Should find list pull requests tool", + }, + { + name: "read_pull_request", + query: "read pull request", + keywords: "read pull request", + expectedTools: []string{"github_pull_request_read"}, + minResults: 1, + description: "Should find read pull request tool", + }, + { + name: "create_pull_request", + query: "create pull request", + keywords: "create pull request", + expectedTools: []string{"github_create_pull_request"}, + minResults: 1, + description: "Should find create pull request tool", + }, + { + name: "merge_pull_request", + query: "merge pull request", + keywords: "merge pull request", + expectedTools: []string{"github_merge_pull_request"}, + minResults: 1, + description: "Should find merge pull request tool", + }, + { + name: "search_pull_requests", + query: "search pull requests", + keywords: "search pull requests", + expectedTools: []string{"github_search_pull_requests"}, + minResults: 1, + description: "Should find search pull requests tool", + }, + { + name: "issue_tools", + query: "issue", + keywords: "issue", + expectedTools: []string{"github_issue_read", "github_list_issues"}, + minResults: 2, + description: "Should find issue-related tools", + }, + { + name: "repository_tool", + query: "create repository", + keywords: "create repository", + expectedTools: []string{"github_create_repository"}, + minResults: 1, + description: "Should find create repository tool", + }, + { + name: "commit_tool", + query: "get commit", + keywords: "commit", + expectedTools: []string{"github_get_commit"}, + minResults: 1, + description: "Should find get commit tool", + }, + { + name: "fetch_tool", + query: "fetch URL", + keywords: "fetch", + expectedTools: []string{"fetch_fetch"}, + minResults: 1, + description: "Should find fetch tool", + }, + } + + for _, tc := range testCases { + tc := tc // capture loop variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Create the tool call request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": tc.query, + "tool_keywords": tc.keywords, + "limit": 20, + }, + }, + } + + // Call the handler + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError, "Tool call should not return error") + + // Parse the result + require.NotEmpty(t, result.Content, "Result should have content") + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok, "Result should be text content") + + // Parse JSON response + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err, "Result should be valid JSON") + + // Check tools array exists + toolsArray, ok := response["tools"].([]interface{}) + require.True(t, ok, "Response should have tools array") + require.GreaterOrEqual(t, len(toolsArray), tc.minResults, + "Should return at least %d results for query: %s", tc.minResults, tc.query) + + // Extract tool names from results + foundTools := make([]string, 0, len(toolsArray)) + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap, "Tool should be a map") + toolName, okName := toolMap["name"].(string) + require.True(t, okName, "Tool should have name") + foundTools = append(foundTools, toolName) + } + + // Check that at least some expected tools are found + // String matching may not be perfect, so we check that at least one expected tool is found + foundCount := 0 + for _, expectedTool := range tc.expectedTools { + for _, foundTool := range foundTools { + if foundTool == expectedTool { + foundCount++ + break + } + } + } + + // We should find at least one expected tool, or at least 50% of expected tools + minExpected := 1 + if len(tc.expectedTools) > 1 { + half := len(tc.expectedTools) / 2 + if half > minExpected { + minExpected = half + } + } + + assert.GreaterOrEqual(t, foundCount, minExpected, + "Query '%s' should find at least %d of expected tools %v. Found tools: %v (found %d/%d)", + tc.query, minExpected, tc.expectedTools, foundTools, foundCount, len(tc.expectedTools)) + + // Log which expected tools were found for debugging + if foundCount < len(tc.expectedTools) { + t.Logf("Query '%s': Found %d/%d expected tools. Found: %v, Expected: %v", + tc.query, foundCount, len(tc.expectedTools), foundTools, tc.expectedTools) + } + + // Verify token metrics exist + tokenMetrics, ok := response["token_metrics"].(map[string]interface{}) + require.True(t, ok, "Response should have token_metrics") + assert.Contains(t, tokenMetrics, "baseline_tokens") + assert.Contains(t, tokenMetrics, "returned_tokens") + assert.Contains(t, tokenMetrics, "tokens_saved") + assert.Contains(t, tokenMetrics, "savings_percentage") + }) + } +} + +// TestFindTool_ExactStringMatch tests that exact string matches work correctly +func TestFindTool_ExactStringMatch(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Setup optimizer integration with higher BM25 ratio for better string matching + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify Ollama is actually working, not just reachable + verifyOllamaWorking(t, embeddingManager) + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 30, // 30% semantic, 70% BM25 for better exact string matching + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + // Create tools with specific strings to match + tools := []vmcp.Tool{ + { + Name: "test_pull_request_tool", + Description: "This tool handles pull requests in GitHub", + BackendID: "test", + }, + { + Name: "test_issue_tool", + Description: "This tool handles issues in GitHub", + BackendID: "test", + }, + { + Name: "test_repository_tool", + Description: "This tool creates repositories", + BackendID: "test", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "test", "test", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test exact string matching + testCases := []struct { + name string + query string + keywords string + expectedTool string + description string + }{ + { + name: "exact_pull_request_string", + query: "pull request", + keywords: "pull request", + expectedTool: "test_pull_request_tool", + description: "Should match exact 'pull request' string", + }, + { + name: "exact_issue_string", + query: "issue", + keywords: "issue", + expectedTool: "test_issue_tool", + description: "Should match exact 'issue' string", + }, + { + name: "exact_repository_string", + query: "repository", + keywords: "repository", + expectedTool: "test_repository_tool", + description: "Should match exact 'repository' string", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": tc.query, + "tool_keywords": tc.keywords, + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + + textContent, okText := mcp.AsTextContent(result.Content[0]) + require.True(t, okText) + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + + toolsArray, okArray := response["tools"].([]interface{}) + require.True(t, okArray) + require.NotEmpty(t, toolsArray, "Should find at least one tool for query: %s", tc.query) + + // Check that the expected tool is in the results + found := false + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap) + toolName, okName := toolMap["name"].(string) + require.True(t, okName) + if toolName == tc.expectedTool { + found = true + break + } + } + + assert.True(t, found, + "Expected tool '%s' not found in results for query '%s'. This indicates string matching is not working correctly.", + tc.expectedTool, tc.query) + }) + } +} + +// TestFindTool_CaseInsensitive tests case-insensitive string matching +func TestFindTool_CaseInsensitive(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify Ollama is actually working, not just reachable + verifyOllamaWorking(t, embeddingManager) + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 30, // Favor BM25 for string matching + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "github_pull_request_read": { + WorkloadID: "github", + WorkloadName: "github", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test different case variations + queries := []string{ + "PULL REQUEST", + "Pull Request", + "pull request", + "PuLl ReQuEsT", + } + + for _, query := range queries { + query := query + t.Run("case_"+strings.ToLower(query), func(t *testing.T) { + t.Parallel() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": strings.ToLower(query), + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + + textContent, okText := mcp.AsTextContent(result.Content[0]) + require.True(t, okText) + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + + toolsArray, okArray := response["tools"].([]interface{}) + require.True(t, okArray) + + // Should find the pull request tool regardless of case + found := false + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap) + toolName, okName := toolMap["name"].(string) + require.True(t, okName) + if toolName == "github_pull_request_read" { + found = true + break + } + } + + assert.True(t, found, + "Should find pull request tool with case-insensitive query: %s", query) + }) + } +} diff --git a/pkg/vmcp/optimizer/integration.go b/pkg/vmcp/optimizer/integration.go new file mode 100644 index 0000000000..01d2f74291 --- /dev/null +++ b/pkg/vmcp/optimizer/integration.go @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "context" + + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/server/adapter" +) + +// Integration is the interface for optimizer functionality in vMCP. +// This interface encapsulates all optimizer logic, keeping server.go clean. +type Integration interface { + // Initialize performs all optimizer initialization: + // - Registers optimizer tools globally with the MCP server + // - Ingests initial backends from the registry + // This should be called once during server startup, after the MCP server is created. + Initialize(ctx context.Context, mcpServer *server.MCPServer, backendRegistry vmcp.BackendRegistry) error + + // HandleSessionRegistration handles session registration for optimizer mode. + // Returns true if optimizer mode is enabled and handled the registration, + // false if optimizer is disabled and normal registration should proceed. + // The resourceConverter function converts vmcp.Resource to server.ServerResource. + HandleSessionRegistration( + ctx context.Context, + sessionID string, + caps *aggregator.AggregatedCapabilities, + mcpServer *server.MCPServer, + resourceConverter func([]vmcp.Resource) []server.ServerResource, + ) (bool, error) + + // Close cleans up optimizer resources + Close() error + + // OptimizerHandlerProvider is embedded to provide tool handlers + adapter.OptimizerHandlerProvider +} diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go index fea0425bb5..d3640419ec 100644 --- a/pkg/vmcp/optimizer/optimizer.go +++ b/pkg/vmcp/optimizer/optimizer.go @@ -1,91 +1,889 @@ // SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. // SPDX-License-Identifier: Apache-2.0 -// Package optimizer provides the Optimizer interface for intelligent tool discovery -// and invocation in the Virtual MCP Server. +// Package optimizer provides vMCP integration for semantic tool discovery. // -// When the optimizer is enabled, vMCP exposes only two tools to clients: -// - find_tool: Semantic search over available tools -// - call_tool: Dynamic invocation of any backend tool +// This package implements the RFC-0022 optimizer integration, exposing: +// - optim_find_tool: Semantic/keyword-based tool discovery +// - optim_call_tool: Dynamic tool invocation across backends // -// This reduces token usage by avoiding the need to send all tool definitions -// to the LLM, instead allowing it to discover relevant tools on demand. +// Architecture: +// - Embeddings are generated during session initialization (OnRegisterSession hook) +// - Tools are exposed as standard MCP tools callable via tools/call +// - Integrates with vMCP's two-boundary authentication model +// - Uses existing router for backend tool invocation package optimizer import ( "context" "encoding/json" + "fmt" + "sync" + "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/ingestion" + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/models" + "github.com/stacklok/toolhive/pkg/logger" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + "github.com/stacklok/toolhive/pkg/vmcp/server/adapter" ) -// Optimizer defines the interface for intelligent tool discovery and invocation. +// Config holds optimizer configuration for vMCP integration. +type Config struct { + // Enabled controls whether optimizer tools are available + Enabled bool + + // PersistPath is the optional path for chromem-go database persistence (empty = in-memory) + PersistPath string + + // FTSDBPath is the path to SQLite FTS5 database for BM25 search + // (empty = auto-default: ":memory:" or "{PersistPath}/fts.db") + FTSDBPath string + + // HybridSearchRatio controls semantic vs BM25 mix (0-100 percentage, default: 70) + HybridSearchRatio int + + // EmbeddingConfig configures the embedding backend (vLLM, Ollama, placeholder) + EmbeddingConfig *embeddings.Config +} + +// OptimizerIntegration manages optimizer functionality within vMCP. // -// Implementations may use various strategies for tool matching: -// - DummyOptimizer: Exact string matching (for testing) -// - EmbeddingOptimizer: Semantic similarity via embeddings (production) -type Optimizer interface { - // FindTool searches for tools matching the given description and keywords. - // Returns matching tools ranked by relevance score. - FindTool(ctx context.Context, input FindToolInput) (*FindToolOutput, error) +//nolint:revive // Name is intentional for clarity in external packages +type OptimizerIntegration struct { + config *Config + ingestionService *ingestion.Service + mcpServer *server.MCPServer // For registering tools + backendClient vmcp.BackendClient // For querying backends at startup + sessionManager *transportsession.Manager + processedSessions sync.Map // Track sessions that have already been processed + tracer trace.Tracer +} + +// NewIntegration creates a new optimizer integration. +func NewIntegration( + _ context.Context, + cfg *Config, + mcpServer *server.MCPServer, + backendClient vmcp.BackendClient, + sessionManager *transportsession.Manager, +) (*OptimizerIntegration, error) { + if cfg == nil || !cfg.Enabled { + return nil, nil // Optimizer disabled + } + + // Initialize ingestion service with embedding backend + ingestionCfg := &ingestion.Config{ + DBConfig: &db.Config{ + PersistPath: cfg.PersistPath, + FTSDBPath: cfg.FTSDBPath, + }, + EmbeddingConfig: cfg.EmbeddingConfig, + } + + svc, err := ingestion.NewService(ingestionCfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize optimizer service: %w", err) + } - // CallTool invokes a tool by name with the given parameters. - // Returns the tool's result or an error if the tool is not found or execution fails. - // Returns the MCP CallToolResult directly from the underlying tool handler. - CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) + return &OptimizerIntegration{ + config: cfg, + ingestionService: svc, + mcpServer: mcpServer, + backendClient: backendClient, + sessionManager: sessionManager, + tracer: otel.Tracer("github.com/stacklok/toolhive/pkg/vmcp/optimizer"), + }, nil } -// FindToolInput contains the parameters for finding tools. -type FindToolInput struct { - // ToolDescription is a natural language description of the tool to find. - ToolDescription string `json:"tool_description" description:"Natural language description of the tool to find"` +// Ensure OptimizerIntegration implements Integration interface at compile time. +var _ Integration = (*OptimizerIntegration)(nil) + +// HandleSessionRegistration handles session registration for optimizer mode. +// Returns true if optimizer mode is enabled and handled the registration, +// false if optimizer is disabled and normal registration should proceed. +// +// When optimizer is enabled: +// 1. Registers optimizer tools (find_tool, call_tool) for the session +// 2. Injects resources (but not backend tools or composite tools) +// 3. Backend tools are accessible via find_tool and call_tool +func (o *OptimizerIntegration) HandleSessionRegistration( + _ context.Context, + sessionID string, + caps *aggregator.AggregatedCapabilities, + mcpServer *server.MCPServer, + resourceConverter func([]vmcp.Resource) []server.ServerResource, +) (bool, error) { + if o == nil { + return false, nil // Optimizer not enabled, use normal registration + } + + logger.Debugw("HandleSessionRegistration called for optimizer mode", "session_id", sessionID) - // ToolKeywords is an optional list of keywords to narrow the search. - ToolKeywords []string `json:"tool_keywords,omitempty" description:"Optional keywords to narrow search"` + // Register optimizer tools for this session + // Tools are already registered globally, but we need to add them to the session + // when using WithToolCapabilities(false) + optimizerTools, err := adapter.CreateOptimizerTools(o) + if err != nil { + return false, fmt.Errorf("failed to create optimizer tools: %w", err) + } + + // Add optimizer tools to session + if err := mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil { + return false, fmt.Errorf("failed to add optimizer tools to session: %w", err) + } + + logger.Debugw("Optimizer tools registered for session", "session_id", sessionID) + + // Inject resources (but not backend tools or composite tools) + // Backend tools will be accessible via find_tool and call_tool + if len(caps.Resources) > 0 { + sdkResources := resourceConverter(caps.Resources) + if err := mcpServer.AddSessionResources(sessionID, sdkResources...); err != nil { + return false, fmt.Errorf("failed to add session resources: %w", err) + } + logger.Debugw("Added session resources (optimizer mode)", + "session_id", sessionID, + "count", len(sdkResources)) + } + + logger.Infow("Optimizer mode: backend tools not exposed directly", + "session_id", sessionID, + "backend_tool_count", len(caps.Tools), + "resource_count", len(caps.Resources)) + + return true, nil // Optimizer handled the registration } -// FindToolOutput contains the results of a tool search. -type FindToolOutput struct { - // Tools contains the matching tools, ranked by relevance. - Tools []ToolMatch `json:"tools"` +// OnRegisterSession is a legacy method kept for test compatibility. +// It does nothing since ingestion is now handled by Initialize(). +// This method is deprecated and will be removed in a future version. +// Tests should be updated to use HandleSessionRegistration instead. +func (o *OptimizerIntegration) OnRegisterSession( + _ context.Context, + session server.ClientSession, + _ *aggregator.AggregatedCapabilities, +) error { + if o == nil { + return nil // Optimizer not enabled + } + + sessionID := session.SessionID() + + logger.Debugw("OnRegisterSession called (legacy method, no-op)", "session_id", sessionID) - // TokenMetrics provides information about token savings from using the optimizer. - TokenMetrics TokenMetrics `json:"token_metrics"` + // Check if this session has already been processed + if _, alreadyProcessed := o.processedSessions.LoadOrStore(sessionID, true); alreadyProcessed { + logger.Debugw("Session already processed, skipping duplicate ingestion", + "session_id", sessionID) + return nil + } + + // Skip ingestion in OnRegisterSession - IngestInitialBackends already handles ingestion at startup + // This prevents duplicate ingestion when sessions are registered + // The optimizer database is populated once at startup, not per-session + logger.Infow("Skipping ingestion in OnRegisterSession (handled by Initialize at startup)", + "session_id", sessionID) + + return nil } -// ToolMatch represents a tool that matched the search criteria. -type ToolMatch struct { - // Name is the unique identifier of the tool. - Name string `json:"name"` +// Initialize performs all optimizer initialization: +// - Registers optimizer tools globally with the MCP server +// - Ingests initial backends from the registry +// +// This should be called once during server startup, after the MCP server is created. +func (o *OptimizerIntegration) Initialize( + ctx context.Context, + mcpServer *server.MCPServer, + backendRegistry vmcp.BackendRegistry, +) error { + if o == nil { + return nil // Optimizer not enabled + } - // Description is the human-readable description of the tool. - Description string `json:"description"` + // Register optimizer tools globally (available to all sessions immediately) + optimizerTools, err := adapter.CreateOptimizerTools(o) + if err != nil { + return fmt.Errorf("failed to create optimizer tools: %w", err) + } + for _, tool := range optimizerTools { + mcpServer.AddTool(tool.Tool, tool.Handler) + } + logger.Info("Optimizer tools registered globally") - // InputSchema is the JSON schema for the tool's input parameters. - // Uses json.RawMessage to preserve the original schema format. - InputSchema json.RawMessage `json:"input_schema"` + // Ingest discovered backends into optimizer database + initialBackends := backendRegistry.List(ctx) + if err := o.IngestInitialBackends(ctx, initialBackends); err != nil { + logger.Warnf("Failed to ingest initial backends into optimizer: %v", err) + // Don't fail initialization - optimizer can still work with incremental ingestion + } - // Score indicates how well this tool matches the search criteria (0.0-1.0). - Score float64 `json:"score"` + return nil } -// TokenMetrics provides information about token usage optimization. -type TokenMetrics struct { - // BaselineTokens is the estimated tokens if all tools were sent. - BaselineTokens int `json:"baseline_tokens"` +// RegisterTools adds optimizer tools to the session. +// Even though tools are registered globally via RegisterGlobalTools(), +// with WithToolCapabilities(false), we also need to register them per-session +// to ensure they appear in list_tools responses. +// This should be called after OnRegisterSession completes. +func (o *OptimizerIntegration) RegisterTools(_ context.Context, session server.ClientSession) error { + if o == nil { + return nil // Optimizer not enabled + } + + sessionID := session.SessionID() - // ReturnedTokens is the actual tokens for the returned tools. - ReturnedTokens int `json:"returned_tokens"` + // Define optimizer tools with handlers (same as global registration) + optimizerTools := []server.ServerTool{ + { + Tool: mcp.Tool{ + Name: "optim_find_tool", + Description: "Semantic search across all backend tools using natural language description and optional keywords", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "tool_description": map[string]any{ + "type": "string", + "description": "Natural language description of the tool you're looking for", + }, + "tool_keywords": map[string]any{ + "type": "string", + "description": "Optional space-separated keywords for keyword-based search", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of tools to return (default: 10)", + "default": 10, + }, + }, + Required: []string{"tool_description"}, + }, + }, + Handler: o.createFindToolHandler(), + }, + { + Tool: mcp.Tool{ + Name: "optim_call_tool", + Description: "Dynamically invoke any tool on any backend using the backend_id from find_tool", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "backend_id": map[string]any{ + "type": "string", + "description": "Backend ID from find_tool results", + }, + "tool_name": map[string]any{ + "type": "string", + "description": "Tool name to invoke", + }, + "parameters": map[string]any{ + "type": "object", + "description": "Parameters to pass to the tool", + }, + }, + Required: []string{"backend_id", "tool_name", "parameters"}, + }, + }, + Handler: o.CreateCallToolHandler(), + }, + } + + // Add tools to session (required when WithToolCapabilities(false)) + if err := o.mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil { + return fmt.Errorf("failed to add optimizer tools to session: %w", err) + } + + logger.Debugw("Optimizer tools registered for session", "session_id", sessionID) + return nil +} - // SavingsPercent is the percentage of tokens saved. - SavingsPercent float64 `json:"savings_percent"` +// GetOptimizerToolDefinitions returns the tool definitions for optimizer tools +// without handlers. This is useful for adding tools to capabilities before session registration. +func (o *OptimizerIntegration) GetOptimizerToolDefinitions() []mcp.Tool { + if o == nil { + return nil + } + return []mcp.Tool{ + { + Name: "optim_find_tool", + Description: "Semantic search across all backend tools using natural language description and optional keywords", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "tool_description": map[string]any{ + "type": "string", + "description": "Natural language description of the tool you're looking for", + }, + "tool_keywords": map[string]any{ + "type": "string", + "description": "Optional space-separated keywords for keyword-based search", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of tools to return (default: 10)", + "default": 10, + }, + }, + Required: []string{"tool_description"}, + }, + }, + { + Name: "optim_call_tool", + Description: "Dynamically invoke any tool on any backend using the backend_id from find_tool", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "backend_id": map[string]any{ + "type": "string", + "description": "Backend ID from find_tool results", + }, + "tool_name": map[string]any{ + "type": "string", + "description": "Tool name to invoke", + }, + "parameters": map[string]any{ + "type": "object", + "description": "Parameters to pass to the tool", + }, + }, + Required: []string{"backend_id", "tool_name", "parameters"}, + }, + }, + } } -// CallToolInput contains the parameters for calling a tool. -type CallToolInput struct { - // ToolName is the name of the tool to invoke. - ToolName string `json:"tool_name" description:"Name of the tool to call"` +// CreateFindToolHandler creates the handler for optim_find_tool +// Exported for testing purposes +func (o *OptimizerIntegration) CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return o.createFindToolHandler() +} + +// extractFindToolParams extracts and validates parameters from the find_tool request +func extractFindToolParams(args map[string]any) (toolDescription, toolKeywords string, limit int, err *mcp.CallToolResult) { + // Extract tool_description (required) + toolDescription, ok := args["tool_description"].(string) + if !ok || toolDescription == "" { + return "", "", 0, mcp.NewToolResultError("tool_description is required and must be a non-empty string") + } + + // Extract tool_keywords (optional) + toolKeywords, _ = args["tool_keywords"].(string) + + // Extract limit (optional, default: 10) + limit = 10 + if limitVal, ok := args["limit"]; ok { + if limitFloat, ok := limitVal.(float64); ok { + limit = int(limitFloat) + } + } + + return toolDescription, toolKeywords, limit, nil +} + +// resolveToolName looks up the resolved name for a tool in the routing table. +// Returns the resolved name if found, otherwise returns the original name. +// +// The routing table maps resolved names (after conflict resolution) to BackendTarget. +// Each BackendTarget contains: +// - WorkloadID: the backend ID +// - OriginalCapabilityName: the original tool name (empty if not renamed) +// +// We need to find the resolved name by matching backend ID and original name. +func resolveToolName(routingTable *vmcp.RoutingTable, backendID string, originalName string) string { + if routingTable == nil || routingTable.Tools == nil { + return originalName + } + + // Search through routing table to find the resolved name + // Match by backend ID and original capability name + for resolvedName, target := range routingTable.Tools { + // Case 1: Tool was renamed (OriginalCapabilityName is set) + // Match by backend ID and original name + if target.WorkloadID == backendID && target.OriginalCapabilityName == originalName { + logger.Debugw("Resolved tool name (renamed)", + "backend_id", backendID, + "original_name", originalName, + "resolved_name", resolvedName) + return resolvedName + } + + // Case 2: Tool was not renamed (OriginalCapabilityName is empty) + // Match by backend ID and resolved name equals original name + if target.WorkloadID == backendID && target.OriginalCapabilityName == "" && resolvedName == originalName { + logger.Debugw("Resolved tool name (not renamed)", + "backend_id", backendID, + "original_name", originalName, + "resolved_name", resolvedName) + return resolvedName + } + } + + // If not found, return original name (fallback for tools not in routing table) + // This can happen if: + // - Tool was just ingested but routing table hasn't been updated yet + // - Tool belongs to a backend that's not currently registered + logger.Debugw("Tool name not found in routing table, using original name", + "backend_id", backendID, + "original_name", originalName) + return originalName +} + +// convertSearchResultsToResponse converts database search results to the response format. +// It resolves tool names using the routing table to ensure returned names match routing table keys. +func convertSearchResultsToResponse( + results []*models.BackendToolWithMetadata, + routingTable *vmcp.RoutingTable, +) ([]map[string]any, int) { + responseTools := make([]map[string]any, 0, len(results)) + totalReturnedTokens := 0 + + for _, result := range results { + // Unmarshal InputSchema + var inputSchema map[string]any + if len(result.InputSchema) > 0 { + if err := json.Unmarshal(result.InputSchema, &inputSchema); err != nil { + logger.Warnw("Failed to unmarshal input schema", + "tool_id", result.ID, + "tool_name", result.ToolName, + "error", err) + inputSchema = map[string]any{} // Use empty schema on error + } + } + + // Handle nil description + description := "" + if result.Description != nil { + description = *result.Description + } + + // Resolve tool name using routing table to ensure it matches routing table keys + resolvedName := resolveToolName(routingTable, result.MCPServerID, result.ToolName) + + tool := map[string]any{ + "name": resolvedName, + "description": description, + "input_schema": inputSchema, + "backend_id": result.MCPServerID, + "similarity_score": result.Similarity, + "token_count": result.TokenCount, + } + responseTools = append(responseTools, tool) + totalReturnedTokens += result.TokenCount + } + + return responseTools, totalReturnedTokens +} + +// createFindToolHandler creates the handler for optim_find_tool +func (o *OptimizerIntegration) createFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + logger.Debugw("optim_find_tool called", "request", request) + + // Extract parameters from request arguments + args, ok := request.Params.Arguments.(map[string]any) + if !ok { + return mcp.NewToolResultError("invalid arguments: expected object"), nil + } + + // Extract and validate parameters + toolDescription, toolKeywords, limit, err := extractFindToolParams(args) + if err != nil { + return err, nil + } + + // Perform hybrid search using database operations + if o.ingestionService == nil { + return mcp.NewToolResultError("backend tool operations not initialized"), nil + } + backendToolOps := o.ingestionService.GetBackendToolOps() + if backendToolOps == nil { + return mcp.NewToolResultError("backend tool operations not initialized"), nil + } + + // Configure hybrid search + hybridConfig := &db.HybridSearchConfig{ + SemanticRatio: o.config.HybridSearchRatio, + Limit: limit, + ServerID: nil, // Search across all servers + } + + // Execute hybrid search + queryText := toolDescription + if toolKeywords != "" { + queryText = toolDescription + " " + toolKeywords + } + results, err2 := backendToolOps.SearchHybrid(ctx, queryText, hybridConfig) + if err2 != nil { + logger.Errorw("Hybrid search failed", + "error", err2, + "tool_description", toolDescription, + "tool_keywords", toolKeywords, + "query_text", queryText) + return mcp.NewToolResultError(fmt.Sprintf("search failed: %v", err2)), nil + } + + // Get routing table from context to resolve tool names + var routingTable *vmcp.RoutingTable + if capabilities, ok := discovery.DiscoveredCapabilitiesFromContext(ctx); ok && capabilities != nil { + routingTable = capabilities.RoutingTable + } + + // Convert results to response format, resolving tool names to match routing table + responseTools, totalReturnedTokens := convertSearchResultsToResponse(results, routingTable) + + // Calculate token metrics + baselineTokens := o.ingestionService.GetTotalToolTokens(ctx) + tokensSaved := baselineTokens - totalReturnedTokens + savingsPercentage := 0.0 + if baselineTokens > 0 { + savingsPercentage = (float64(tokensSaved) / float64(baselineTokens)) * 100.0 + } + + tokenMetrics := map[string]any{ + "baseline_tokens": baselineTokens, + "returned_tokens": totalReturnedTokens, + "tokens_saved": tokensSaved, + "savings_percentage": savingsPercentage, + } + + // Record OpenTelemetry metrics for token savings + o.recordTokenMetrics(ctx, baselineTokens, totalReturnedTokens, tokensSaved, savingsPercentage) + + // Build response + response := map[string]any{ + "tools": responseTools, + "token_metrics": tokenMetrics, + } + + // Marshal to JSON for the result + responseJSON, err3 := json.Marshal(response) + if err3 != nil { + logger.Errorw("Failed to marshal response", "error", err3) + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal response: %v", err3)), nil + } + + logger.Infow("optim_find_tool completed", + "query", toolDescription, + "results_count", len(responseTools), + "tokens_saved", tokensSaved, + "savings_percentage", fmt.Sprintf("%.2f%%", savingsPercentage)) + + return mcp.NewToolResultText(string(responseJSON)), nil + } +} + +// recordTokenMetrics records OpenTelemetry metrics for token savings +func (*OptimizerIntegration) recordTokenMetrics( + ctx context.Context, + baselineTokens int, + returnedTokens int, + tokensSaved int, + savingsPercentage float64, +) { + // Get meter from global OpenTelemetry provider + meter := otel.Meter("github.com/stacklok/toolhive/pkg/vmcp/optimizer") + + // Create metrics if they don't exist (they'll be cached by the meter) + baselineCounter, err := meter.Int64Counter( + "toolhive_vmcp_optimizer_baseline_tokens", + metric.WithDescription("Total tokens for all tools in the optimizer database (baseline)"), + ) + if err != nil { + logger.Debugw("Failed to create baseline_tokens counter", "error", err) + return + } + + returnedCounter, err := meter.Int64Counter( + "toolhive_vmcp_optimizer_returned_tokens", + metric.WithDescription("Total tokens for tools returned by optim_find_tool"), + ) + if err != nil { + logger.Debugw("Failed to create returned_tokens counter", "error", err) + return + } + + savedCounter, err := meter.Int64Counter( + "toolhive_vmcp_optimizer_tokens_saved", + metric.WithDescription("Number of tokens saved by filtering tools with optim_find_tool"), + ) + if err != nil { + logger.Debugw("Failed to create tokens_saved counter", "error", err) + return + } + + savingsGauge, err := meter.Float64Gauge( + "toolhive_vmcp_optimizer_savings_percentage", + metric.WithDescription("Percentage of tokens saved by filtering tools (0-100)"), + metric.WithUnit("%"), + ) + if err != nil { + logger.Debugw("Failed to create savings_percentage gauge", "error", err) + return + } + + // Record metrics with attributes + attrs := metric.WithAttributes( + attribute.String("operation", "find_tool"), + ) + + baselineCounter.Add(ctx, int64(baselineTokens), attrs) + returnedCounter.Add(ctx, int64(returnedTokens), attrs) + savedCounter.Add(ctx, int64(tokensSaved), attrs) + savingsGauge.Record(ctx, savingsPercentage, attrs) + + logger.Debugw("Token metrics recorded", + "baseline_tokens", baselineTokens, + "returned_tokens", returnedTokens, + "tokens_saved", tokensSaved, + "savings_percentage", savingsPercentage) +} + +// CreateCallToolHandler creates the handler for optim_call_tool +// Exported for testing purposes +func (o *OptimizerIntegration) CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return o.createCallToolHandler() +} + +// createCallToolHandler creates the handler for optim_call_tool +func (o *OptimizerIntegration) createCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + logger.Debugw("optim_call_tool called", "request", request) + + // Extract parameters from request arguments + args, ok := request.Params.Arguments.(map[string]any) + if !ok { + return mcp.NewToolResultError("invalid arguments: expected object"), nil + } + + // Extract backend_id (required) + backendID, ok := args["backend_id"].(string) + if !ok || backendID == "" { + return mcp.NewToolResultError("backend_id is required and must be a non-empty string"), nil + } + + // Extract tool_name (required) + toolName, ok := args["tool_name"].(string) + if !ok || toolName == "" { + return mcp.NewToolResultError("tool_name is required and must be a non-empty string"), nil + } + + // Extract parameters (required) + parameters, ok := args["parameters"].(map[string]any) + if !ok { + return mcp.NewToolResultError("parameters is required and must be an object"), nil + } + + // Get routing table from context via discovered capabilities + capabilities, ok := discovery.DiscoveredCapabilitiesFromContext(ctx) + if !ok || capabilities == nil { + return mcp.NewToolResultError("routing information not available in context"), nil + } + + if capabilities.RoutingTable == nil || capabilities.RoutingTable.Tools == nil { + return mcp.NewToolResultError("routing table not initialized"), nil + } + + // Find the tool in the routing table + target, exists := capabilities.RoutingTable.Tools[toolName] + if !exists { + return mcp.NewToolResultError(fmt.Sprintf("tool not found in routing table: %s", toolName)), nil + } + + // Verify the tool belongs to the specified backend + if target.WorkloadID != backendID { + return mcp.NewToolResultError(fmt.Sprintf( + "tool %s belongs to backend %s, not %s", + toolName, + target.WorkloadID, + backendID, + )), nil + } + + // Get the backend capability name (handles renamed tools) + backendToolName := target.GetBackendCapabilityName(toolName) + + logger.Infow("Calling tool via optimizer", + "backend_id", backendID, + "tool_name", toolName, + "backend_tool_name", backendToolName, + "workload_name", target.WorkloadName) + + // Call the tool on the backend using the backend client + result, err := o.backendClient.CallTool(ctx, target, backendToolName, parameters) + if err != nil { + logger.Errorw("Tool call failed", + "error", err, + "backend_id", backendID, + "tool_name", toolName, + "backend_tool_name", backendToolName) + return mcp.NewToolResultError(fmt.Sprintf("tool call failed: %v", err)), nil + } + + // Convert result to JSON + resultJSON, err := json.Marshal(result) + if err != nil { + logger.Errorw("Failed to marshal tool result", "error", err) + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + logger.Infow("optim_call_tool completed successfully", + "backend_id", backendID, + "tool_name", toolName) + + return mcp.NewToolResultText(string(resultJSON)), nil + } +} + +// IngestInitialBackends ingests all discovered backends and their tools at startup. +// This should be called after backends are discovered during server initialization. +func (o *OptimizerIntegration) IngestInitialBackends(ctx context.Context, backends []vmcp.Backend) error { + if o == nil || o.ingestionService == nil { + // Optimizer disabled - log that embedding time is 0 + logger.Infow("Optimizer disabled, embedding time: 0ms") + return nil + } + + // Reset embedding time before starting ingestion + o.ingestionService.ResetEmbeddingTime() + + // Create a span for the entire ingestion process + ctx, span := o.tracer.Start(ctx, "optimizer.ingestion.ingest_initial_backends", + trace.WithAttributes( + attribute.Int("backends.count", len(backends)), + )) + defer span.End() + + start := time.Now() + logger.Infof("Ingesting %d discovered backends into optimizer", len(backends)) + + ingestedCount := 0 + totalToolsIngested := 0 + for _, backend := range backends { + // Create a span for each backend ingestion + backendCtx, backendSpan := o.tracer.Start(ctx, "optimizer.ingestion.ingest_backend", + trace.WithAttributes( + attribute.String("backend.id", backend.ID), + attribute.String("backend.name", backend.Name), + )) + defer backendSpan.End() + + // Convert Backend to BackendTarget for client API + target := vmcp.BackendToTarget(&backend) + if target == nil { + logger.Warnf("Failed to convert backend %s to target", backend.Name) + backendSpan.RecordError(fmt.Errorf("failed to convert backend to target")) + backendSpan.SetStatus(codes.Error, "conversion failed") + continue + } + + // Query backend capabilities to get its tools + capabilities, err := o.backendClient.ListCapabilities(backendCtx, target) + if err != nil { + logger.Warnf("Failed to query capabilities for backend %s: %v", backend.Name, err) + backendSpan.RecordError(err) + backendSpan.SetStatus(codes.Error, err.Error()) + continue // Skip this backend but continue with others + } + + // Extract tools from capabilities + // Note: For ingestion, we only need name and description (for generating embeddings) + // InputSchema is not used by the ingestion service + var tools []mcp.Tool + for _, tool := range capabilities.Tools { + tools = append(tools, mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + // InputSchema not needed for embedding generation + }) + } + + // Get description from metadata (may be empty) + var description *string + if backend.Metadata != nil { + if desc := backend.Metadata["description"]; desc != "" { + description = &desc + } + } + + backendSpan.SetAttributes( + attribute.Int("tools.count", len(tools)), + ) + + // Ingest this backend's tools (IngestServer will create its own spans) + if err := o.ingestionService.IngestServer( + backendCtx, + backend.ID, + backend.Name, + description, + tools, + ); err != nil { + logger.Warnf("Failed to ingest backend %s: %v", backend.Name, err) + backendSpan.RecordError(err) + backendSpan.SetStatus(codes.Error, err.Error()) + continue // Log but don't fail startup + } + ingestedCount++ + totalToolsIngested += len(tools) + backendSpan.SetAttributes( + attribute.Int("tools.ingested", len(tools)), + ) + backendSpan.SetStatus(codes.Ok, "backend ingested successfully") + } + + // Get total embedding time + totalEmbeddingTime := o.ingestionService.GetTotalEmbeddingTime() + totalDuration := time.Since(start) + + span.SetAttributes( + attribute.Int64("ingestion.duration_ms", totalDuration.Milliseconds()), + attribute.Int64("embedding.duration_ms", totalEmbeddingTime.Milliseconds()), + attribute.Int("backends.ingested", ingestedCount), + attribute.Int("tools.ingested", totalToolsIngested), + ) + + logger.Infow("Initial backend ingestion completed", + "servers_ingested", ingestedCount, + "tools_ingested", totalToolsIngested, + "total_duration_ms", totalDuration.Milliseconds(), + "total_embedding_time_ms", totalEmbeddingTime.Milliseconds(), + "embedding_time_percentage", fmt.Sprintf("%.2f%%", float64(totalEmbeddingTime)/float64(totalDuration)*100)) + + return nil +} + +// Close cleans up optimizer resources. +func (o *OptimizerIntegration) Close() error { + if o == nil || o.ingestionService == nil { + return nil + } + return o.ingestionService.Close() +} - // Parameters are the arguments to pass to the tool. - Parameters map[string]any `json:"parameters" description:"Parameters to pass to the tool"` +// IngestToolsForTesting manually ingests tools for testing purposes. +// This is a test helper that bypasses the normal ingestion flow. +func (o *OptimizerIntegration) IngestToolsForTesting( + ctx context.Context, + serverID string, + serverName string, + description *string, + tools []mcp.Tool, +) error { + if o == nil || o.ingestionService == nil { + return fmt.Errorf("optimizer integration not initialized") + } + return o.ingestionService.IngestServer(ctx, serverID, serverName, description, tools) } diff --git a/pkg/vmcp/optimizer/optimizer_handlers_test.go b/pkg/vmcp/optimizer/optimizer_handlers_test.go new file mode 100644 index 0000000000..6adee847ee --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer_handlers_test.go @@ -0,0 +1,1029 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +// mockMCPServerWithSession implements AddSessionTools for testing +type mockMCPServerWithSession struct { + *server.MCPServer + toolsAdded map[string][]server.ServerTool +} + +func newMockMCPServerWithSession() *mockMCPServerWithSession { + return &mockMCPServerWithSession{ + MCPServer: server.NewMCPServer("test-server", "1.0"), + toolsAdded: make(map[string][]server.ServerTool), + } +} + +func (m *mockMCPServerWithSession) AddSessionTools(sessionID string, tools ...server.ServerTool) error { + m.toolsAdded[sessionID] = tools + return nil +} + +// mockBackendClientWithCallTool implements CallTool for testing +type mockBackendClientWithCallTool struct { + callToolResult map[string]any + callToolError error +} + +func (*mockBackendClientWithCallTool) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + return &vmcp.CapabilityList{}, nil +} + +func (m *mockBackendClientWithCallTool) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) { + if m.callToolError != nil { + return nil, m.callToolError + } + return m.callToolResult, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClientWithCallTool) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (string, error) { + return "", nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClientWithCallTool) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) ([]byte, error) { + return nil, nil +} + +// TestCreateFindToolHandler_InvalidArguments tests error handling for invalid arguments +func TestCreateFindToolHandler_InvalidArguments(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Setup optimizer integration + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateFindToolHandler() + + // Test with invalid arguments type + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: "not a map", + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for invalid arguments") + + // Test with missing tool_description + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "limit": 10, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing tool_description") + + // Test with empty tool_description + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "", + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for empty tool_description") + + // Test with non-string tool_description + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": 123, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for non-string tool_description") +} + +// TestCreateFindToolHandler_WithKeywords tests find_tool with keywords +func TestCreateFindToolHandler_WithKeywords(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Ingest a tool for testing + tools := []mcp.Tool{ + { + Name: "test_tool", + Description: "A test tool for searching", + }, + } + + err = integration.IngestToolsForTesting(ctx, "server-1", "TestServer", nil, tools) + require.NoError(t, err) + + handler := integration.CreateFindToolHandler() + + // Test with keywords + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "search tool", + "tool_keywords": "test search", + "limit": 10, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.False(t, result.IsError, "Should not return error") + + // Verify response structure + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok) + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + + _, ok = response["tools"] + require.True(t, ok, "Response should have tools") + + _, ok = response["token_metrics"] + require.True(t, ok, "Response should have token_metrics") +} + +// TestCreateFindToolHandler_Limit tests limit parameter handling +func TestCreateFindToolHandler_Limit(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateFindToolHandler() + + // Test with custom limit + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "test", + "limit": 5, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.False(t, result.IsError) + + // Test with float64 limit (from JSON) + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "test", + "limit": float64(3), + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.False(t, result.IsError) +} + +// TestCreateFindToolHandler_BackendToolOpsNil tests error when backend tool ops is nil +func TestCreateFindToolHandler_BackendToolOpsNil(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create integration with nil ingestion service to trigger error path + integration := &OptimizerIntegration{ + config: &Config{Enabled: true}, + ingestionService: nil, // This will cause GetBackendToolOps to return nil + } + + handler := integration.CreateFindToolHandler() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "test", + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when backend tool ops is nil") +} + +// TestCreateCallToolHandler_InvalidArguments tests error handling for invalid arguments +func TestCreateCallToolHandler_InvalidArguments(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Test with invalid arguments type + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: "not a map", + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for invalid arguments") + + // Test with missing backend_id + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing backend_id") + + // Test with empty backend_id + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "", + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for empty backend_id") + + // Test with missing tool_name + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "parameters": map[string]any{}, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing tool_name") + + // Test with missing parameters + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing parameters") + + // Test with invalid parameters type + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": "not a map", + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for invalid parameters type") +} + +// TestCreateCallToolHandler_NoRoutingTable tests error when routing table is missing +func TestCreateCallToolHandler_NoRoutingTable(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Test without routing table in context + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when routing table is missing") +} + +// TestCreateCallToolHandler_ToolNotFound tests error when tool is not found +func TestCreateCallToolHandler_ToolNotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Create context with routing table but tool not found + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "nonexistent_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when tool is not found") +} + +// TestCreateCallToolHandler_BackendMismatch tests error when backend doesn't match +func TestCreateCallToolHandler_BackendMismatch(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Create context with routing table where tool belongs to different backend + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": { + WorkloadID: "backend-2", // Different backend + WorkloadName: "Backend 2", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", // Requesting backend-1 + "tool_name": "test_tool", // But tool belongs to backend-2 + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when backend doesn't match") +} + +// TestCreateCallToolHandler_Success tests successful tool call +func TestCreateCallToolHandler_Success(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{ + callToolResult: map[string]any{ + "result": "success", + }, + } + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Create context with routing table + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "Backend 1", + BaseURL: "http://localhost:8000", + } + + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": target, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": map[string]any{ + "param1": "value1", + }, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.False(t, result.IsError, "Should not return error") + + // Verify response + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok) + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Equal(t, "success", response["result"]) +} + +// TestCreateCallToolHandler_CallToolError tests error handling when CallTool fails +func TestCreateCallToolHandler_CallToolError(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{ + callToolError: assert.AnError, + } + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "Backend 1", + BaseURL: "http://localhost:8000", + } + + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": target, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when CallTool fails") +} + +// TestCreateFindToolHandler_InputSchemaUnmarshalError tests error handling for invalid input schema +func TestCreateFindToolHandler_InputSchemaUnmarshalError(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateFindToolHandler() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "test", + }, + }, + } + + // The handler should handle invalid input schema gracefully + result, err := handler(ctx, request) + require.NoError(t, err) + // Should not error even if some tools have invalid schemas + require.False(t, result.IsError) +} + +// TestOnRegisterSession_DuplicateSession tests duplicate session handling +func TestOnRegisterSession_DuplicateSession(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + session := &mockSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{} + + // First call + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Second call with same session ID (should be skipped) + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err, "Should handle duplicate session gracefully") +} + +// TestIngestInitialBackends_ErrorHandling tests error handling during ingestion +func TestIngestInitialBackends_ErrorHandling(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{ + err: assert.AnError, // Simulate error when listing capabilities + } + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + backends := []vmcp.Backend{ + { + ID: "backend-1", + Name: "Backend 1", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + // Should not fail even if backend query fails + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err, "Should handle backend query errors gracefully") +} + +// TestIngestInitialBackends_NilIntegration tests nil integration handling +func TestIngestInitialBackends_NilIntegration(t *testing.T) { + t.Parallel() + ctx := context.Background() + + var integration *OptimizerIntegration = nil + backends := []vmcp.Backend{} + + err := integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err, "Should handle nil integration gracefully") +} diff --git a/pkg/vmcp/optimizer/optimizer_integration_test.go b/pkg/vmcp/optimizer/optimizer_integration_test.go new file mode 100644 index 0000000000..bb3ecf9583 --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer_integration_test.go @@ -0,0 +1,439 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +// mockBackendClient implements vmcp.BackendClient for integration testing +type mockIntegrationBackendClient struct { + backends map[string]*vmcp.CapabilityList +} + +func newMockIntegrationBackendClient() *mockIntegrationBackendClient { + return &mockIntegrationBackendClient{ + backends: make(map[string]*vmcp.CapabilityList), + } +} + +func (m *mockIntegrationBackendClient) addBackend(backendID string, caps *vmcp.CapabilityList) { + m.backends[backendID] = caps +} + +func (m *mockIntegrationBackendClient) ListCapabilities(_ context.Context, target *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + if caps, exists := m.backends[target.WorkloadID]; exists { + return caps, nil + } + return &vmcp.CapabilityList{}, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) { + return nil, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (string, error) { + return "", nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) ([]byte, error) { + return nil, nil +} + +// mockIntegrationSession implements server.ClientSession for testing +type mockIntegrationSession struct { + sessionID string +} + +func (m *mockIntegrationSession) SessionID() string { + return m.sessionID +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Send(_ interface{}) error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Close() error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Initialize() { + // No-op for testing +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Initialized() bool { + return true +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + // Return a dummy channel for testing + ch := make(chan mcp.JSONRPCNotification, 1) + return ch +} + +// TestOptimizerIntegration_WithVMCP tests the complete integration with vMCP +func TestOptimizerIntegration_WithVMCP(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Create MCP server + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + + // Create mock backend client + mockClient := newMockIntegrationBackendClient() + mockClient.addBackend("github", &vmcp.CapabilityList{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + }, + }, + }) + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Configure optimizer + optimizerConfig := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + } + + // Create optimizer integration + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Ingest backends + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err) + + // Simulate session registration + session := &mockIntegrationSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + BackendID: "github", + }, + }, + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "create_issue": { + WorkloadID: "github", + WorkloadName: "GitHub", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Note: We don't test RegisterTools here because it requires the session + // to be properly registered with the MCP server, which is beyond the scope + // of this integration test. The RegisterTools method is tested separately + // in unit tests where we can properly mock the MCP server behavior. +} + +// TestOptimizerIntegration_EmbeddingTimeTracking tests that embedding time is tracked and logged +func TestOptimizerIntegration_EmbeddingTimeTracking(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Create MCP server + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + + // Create mock backend client + mockClient := newMockIntegrationBackendClient() + mockClient.addBackend("github", &vmcp.CapabilityList{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + }, + { + Name: "get_repo", + Description: "Get repository information", + }, + }, + }) + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Configure optimizer + optimizerConfig := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + } + + // Create optimizer integration + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Verify embedding time starts at 0 + embeddingTime := integration.ingestionService.GetTotalEmbeddingTime() + require.Equal(t, time.Duration(0), embeddingTime, "Initial embedding time should be 0") + + // Ingest backends + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err) + + // After ingestion, embedding time should be tracked + // Note: The actual time depends on Ollama performance, but it should be > 0 + finalEmbeddingTime := integration.ingestionService.GetTotalEmbeddingTime() + require.Greater(t, finalEmbeddingTime, time.Duration(0), + "Embedding time should be tracked after ingestion") +} + +// TestOptimizerIntegration_DisabledEmbeddingTime tests that embedding time is 0 when optimizer is disabled +func TestOptimizerIntegration_DisabledEmbeddingTime(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create optimizer integration with disabled optimizer + optimizerConfig := &Config{ + Enabled: false, + } + + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + mockClient := newMockIntegrationBackendClient() + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.Nil(t, integration, "Integration should be nil when optimizer is disabled") + + // Try to ingest backends - should return nil without error + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + // This should handle nil integration gracefully + var nilIntegration *OptimizerIntegration + err = nilIntegration.IngestInitialBackends(ctx, backends) + require.NoError(t, err, "Should handle nil integration gracefully") +} + +// TestOptimizerIntegration_TokenMetrics tests that token metrics are calculated and returned in optim_find_tool +func TestOptimizerIntegration_TokenMetrics(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Create MCP server + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + + // Create mock backend client with multiple tools + mockClient := newMockIntegrationBackendClient() + mockClient.addBackend("github", &vmcp.CapabilityList{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + }, + { + Name: "get_pull_request", + Description: "Get a pull request from GitHub", + }, + { + Name: "list_repositories", + Description: "List repositories from GitHub", + }, + }, + }) + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Configure optimizer + optimizerConfig := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + } + + // Create optimizer integration + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Ingest backends + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err) + + // Get the find_tool handler + handler := integration.CreateFindToolHandler() + require.NotNil(t, handler) + + // Call optim_find_tool + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "create issue", + "limit": 5, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify result contains token_metrics + require.NotNil(t, result.Content) + require.Len(t, result.Content, 1) + textResult, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok, "Result should be TextContent") + + // Parse JSON response + var response map[string]any + err = json.Unmarshal([]byte(textResult.Text), &response) + require.NoError(t, err) + + // Verify token_metrics exist + tokenMetrics, ok := response["token_metrics"].(map[string]any) + require.True(t, ok, "Response should contain token_metrics") + + // Verify token metrics fields + baselineTokens, ok := tokenMetrics["baseline_tokens"].(float64) + require.True(t, ok, "token_metrics should contain baseline_tokens") + require.Greater(t, baselineTokens, float64(0), "baseline_tokens should be greater than 0") + + returnedTokens, ok := tokenMetrics["returned_tokens"].(float64) + require.True(t, ok, "token_metrics should contain returned_tokens") + require.GreaterOrEqual(t, returnedTokens, float64(0), "returned_tokens should be >= 0") + + tokensSaved, ok := tokenMetrics["tokens_saved"].(float64) + require.True(t, ok, "token_metrics should contain tokens_saved") + require.GreaterOrEqual(t, tokensSaved, float64(0), "tokens_saved should be >= 0") + + savingsPercentage, ok := tokenMetrics["savings_percentage"].(float64) + require.True(t, ok, "token_metrics should contain savings_percentage") + require.GreaterOrEqual(t, savingsPercentage, float64(0), "savings_percentage should be >= 0") + require.LessOrEqual(t, savingsPercentage, float64(100), "savings_percentage should be <= 100") + + // Verify tools are returned + tools, ok := response["tools"].([]any) + require.True(t, ok, "Response should contain tools") + require.Greater(t, len(tools), 0, "Should return at least one tool") +} diff --git a/pkg/vmcp/optimizer/optimizer_unit_test.go b/pkg/vmcp/optimizer/optimizer_unit_test.go new file mode 100644 index 0000000000..c764d54aeb --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer_unit_test.go @@ -0,0 +1,338 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +// mockBackendClient implements vmcp.BackendClient for testing +type mockBackendClient struct { + capabilities *vmcp.CapabilityList + err error +} + +func (m *mockBackendClient) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + if m.err != nil { + return nil, m.err + } + return m.capabilities, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) { + return nil, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (string, error) { + return "", nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) ([]byte, error) { + return nil, nil +} + +// mockSession implements server.ClientSession for testing +type mockSession struct { + sessionID string +} + +func (m *mockSession) SessionID() string { + return m.sessionID +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Send(_ interface{}) error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Close() error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Initialize() { + // No-op for testing +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Initialized() bool { + return true +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + // Return a dummy channel for testing + ch := make(chan mcp.JSONRPCNotification, 1) + return ch +} + +// TestNewIntegration_Disabled tests that nil is returned when optimizer is disabled +func TestNewIntegration_Disabled(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Test with nil config + integration, err := NewIntegration(ctx, nil, nil, nil, nil) + require.NoError(t, err) + assert.Nil(t, integration, "Should return nil when config is nil") + + // Test with disabled config + config := &Config{Enabled: false} + integration, err = NewIntegration(ctx, config, nil, nil, nil) + require.NoError(t, err) + assert.Nil(t, integration, "Should return nil when optimizer is disabled") +} + +// TestNewIntegration_Enabled tests successful creation +func TestNewIntegration_Enabled(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + defer func() { _ = integration.Close() }() +} + +// TestOnRegisterSession tests session registration +func TestOnRegisterSession(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + session := &mockSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{ + { + Name: "test_tool", + Description: "A test tool", + BackendID: "backend-1", + }, + }, + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": { + WorkloadID: "backend-1", + WorkloadName: "Test Backend", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + err = integration.OnRegisterSession(ctx, session, capabilities) + assert.NoError(t, err) +} + +// TestOnRegisterSession_NilIntegration tests nil integration handling +func TestOnRegisterSession_NilIntegration(t *testing.T) { + t.Parallel() + ctx := context.Background() + + var integration *OptimizerIntegration = nil + session := &mockSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{} + + err := integration.OnRegisterSession(ctx, session, capabilities) + assert.NoError(t, err) +} + +// TestRegisterTools tests tool registration behavior +func TestRegisterTools(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + session := &mockSession{sessionID: "test-session"} + // RegisterTools will fail with "session not found" because the mock session + // is not actually registered with the MCP server. This is expected behavior. + // We're just testing that the method executes without panicking. + _ = integration.RegisterTools(ctx, session) +} + +// TestRegisterTools_NilIntegration tests nil integration handling +func TestRegisterTools_NilIntegration(t *testing.T) { + t.Parallel() + ctx := context.Background() + + var integration *OptimizerIntegration = nil + session := &mockSession{sessionID: "test-session"} + + err := integration.RegisterTools(ctx, session) + assert.NoError(t, err) +} + +// TestClose tests cleanup +func TestClose(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + + err = integration.Close() + assert.NoError(t, err) + + // Multiple closes should be safe + err = integration.Close() + assert.NoError(t, err) +} + +// TestClose_NilIntegration tests nil integration close +func TestClose_NilIntegration(t *testing.T) { + t.Parallel() + + var integration *OptimizerIntegration = nil + err := integration.Close() + assert.NoError(t, err) +} diff --git a/pkg/vmcp/schema/reflect_test.go b/pkg/vmcp/schema/reflect_test.go index 55d9491019..2e0da8ed28 100644 --- a/pkg/vmcp/schema/reflect_test.go +++ b/pkg/vmcp/schema/reflect_test.go @@ -8,10 +8,85 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" ) +// FindToolInput represents the input schema for optim_find_tool +// This matches the schema defined in pkg/vmcp/optimizer/optimizer.go +type FindToolInput struct { + ToolDescription string `json:"tool_description" description:"Natural language description of the tool you're looking for"` + ToolKeywords string `json:"tool_keywords,omitempty" description:"Optional space-separated keywords for keyword-based search"` + Limit int `json:"limit,omitempty" description:"Maximum number of tools to return (default: 10)"` +} + +// CallToolInput represents the input schema for optim_call_tool +// This matches the schema defined in pkg/vmcp/optimizer/optimizer.go +type CallToolInput struct { + BackendID string `json:"backend_id" description:"Backend ID from find_tool results"` + ToolName string `json:"tool_name" description:"Tool name to invoke"` + Parameters map[string]any `json:"parameters" description:"Parameters to pass to the tool"` +} + +func TestGenerateSchema_AllTypes(t *testing.T) { + t.Parallel() + + type TestStruct struct { + StringField string `json:"string_field,omitempty"` + IntField int `json:"int_field"` + FloatField float64 `json:"float_field,omitempty"` + BoolField bool `json:"bool_field"` + OptionalStr string `json:"optional_str,omitempty"` + SliceField []int `json:"slice_field"` + MapField map[string]string `json:"map_field"` + StructField struct { + RequiredField string `json:"field"` + OptionalField string `json:"optional_field,omitempty"` + } `json:"struct_field"` + PointerField *int `json:"pointer_field"` + } + + expected := map[string]any{ + "type": "object", + "properties": map[string]any{ + "string_field": map[string]any{"type": "string"}, + "int_field": map[string]any{"type": "integer"}, + "float_field": map[string]any{"type": "number"}, + "bool_field": map[string]any{"type": "boolean"}, + "optional_str": map[string]any{"type": "string"}, + "slice_field": map[string]any{ + "type": "array", + "items": map[string]any{"type": "integer"}, + }, + "map_field": map[string]any{"type": "object"}, + "struct_field": map[string]any{ + "type": "object", + "properties": map[string]any{ + "field": map[string]any{"type": "string"}, + "optional_field": map[string]any{"type": "string"}, + }, + "required": []string{"field"}, + }, + "pointer_field": map[string]any{ + "type": "integer", + }, + }, + "required": []string{ + "int_field", + "bool_field", + "map_field", + "struct_field", + "pointer_field", + "slice_field", + }, + } + + actual, err := GenerateSchema[TestStruct]() + require.NoError(t, err) + + require.Equal(t, expected["type"], actual["type"]) + require.Equal(t, expected["properties"], actual["properties"]) + require.ElementsMatch(t, expected["required"], actual["required"]) +} + func TestGenerateSchema_FindToolInput(t *testing.T) { t.Parallel() @@ -20,18 +95,21 @@ func TestGenerateSchema_FindToolInput(t *testing.T) { "properties": map[string]any{ "tool_description": map[string]any{ "type": "string", - "description": "Natural language description of the tool to find", + "description": "Natural language description of the tool you're looking for", }, "tool_keywords": map[string]any{ - "type": "array", - "items": map[string]any{"type": "string"}, - "description": "Optional keywords to narrow search", + "type": "string", + "description": "Optional space-separated keywords for keyword-based search", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of tools to return (default: 10)", }, }, "required": []string{"tool_description"}, } - actual, err := GenerateSchema[optimizer.FindToolInput]() + actual, err := GenerateSchema[FindToolInput]() require.NoError(t, err) require.Equal(t, expected, actual) @@ -43,19 +121,23 @@ func TestGenerateSchema_CallToolInput(t *testing.T) { expected := map[string]any{ "type": "object", "properties": map[string]any{ + "backend_id": map[string]any{ + "type": "string", + "description": "Backend ID from find_tool results", + }, "tool_name": map[string]any{ "type": "string", - "description": "Name of the tool to call", + "description": "Tool name to invoke", }, "parameters": map[string]any{ "type": "object", "description": "Parameters to pass to the tool", }, }, - "required": []string{"tool_name", "parameters"}, + "required": []string{"backend_id", "tool_name", "parameters"}, } - actual, err := GenerateSchema[optimizer.CallToolInput]() + actual, err := GenerateSchema[CallToolInput]() require.NoError(t, err) require.Equal(t, expected, actual) @@ -66,15 +148,17 @@ func TestTranslate_FindToolInput(t *testing.T) { input := map[string]any{ "tool_description": "find a tool to read files", - "tool_keywords": []any{"file", "read"}, + "tool_keywords": "file read", + "limit": 5, } - result, err := Translate[optimizer.FindToolInput](input) + result, err := Translate[FindToolInput](input) require.NoError(t, err) - require.Equal(t, optimizer.FindToolInput{ + require.Equal(t, FindToolInput{ ToolDescription: "find a tool to read files", - ToolKeywords: []string{"file", "read"}, + ToolKeywords: "file read", + Limit: 5, }, result) } @@ -82,16 +166,18 @@ func TestTranslate_CallToolInput(t *testing.T) { t.Parallel() input := map[string]any{ - "tool_name": "read_file", + "backend_id": "backend-123", + "tool_name": "read_file", "parameters": map[string]any{ "path": "/etc/hosts", }, } - result, err := Translate[optimizer.CallToolInput](input) + result, err := Translate[CallToolInput](input) require.NoError(t, err) - require.Equal(t, optimizer.CallToolInput{ + require.Equal(t, CallToolInput{ + BackendID: "backend-123", ToolName: "read_file", Parameters: map[string]any{"path": "/etc/hosts"}, }, result) @@ -104,12 +190,13 @@ func TestTranslate_PartialInput(t *testing.T) { "tool_description": "find a file reader", } - result, err := Translate[optimizer.FindToolInput](input) + result, err := Translate[FindToolInput](input) require.NoError(t, err) - require.Equal(t, optimizer.FindToolInput{ + require.Equal(t, FindToolInput{ ToolDescription: "find a file reader", - ToolKeywords: nil, + ToolKeywords: "", + Limit: 0, }, result) } @@ -118,68 +205,7 @@ func TestTranslate_InvalidInput(t *testing.T) { input := make(chan int) - _, err := Translate[optimizer.FindToolInput](input) + _, err := Translate[FindToolInput](input) require.Error(t, err) assert.Contains(t, err.Error(), "failed to marshal input") } - -func TestGenerateSchema_AllTypes(t *testing.T) { - t.Parallel() - - type TestStruct struct { - StringField string `json:"string_field,omitempty"` - IntField int `json:"int_field"` - FloatField float64 `json:"float_field,omitempty"` - BoolField bool `json:"bool_field"` - OptionalStr string `json:"optional_str,omitempty"` - SliceField []int `json:"slice_field"` - MapField map[string]string `json:"map_field"` - StructField struct { - RequiredField string `json:"field"` - OptionalField string `json:"optional_field,omitempty"` - } `json:"struct_field"` - PointerField *int `json:"pointer_field"` - } - - expected := map[string]any{ - "type": "object", - "properties": map[string]any{ - "string_field": map[string]any{"type": "string"}, - "int_field": map[string]any{"type": "integer"}, - "float_field": map[string]any{"type": "number"}, - "bool_field": map[string]any{"type": "boolean"}, - "optional_str": map[string]any{"type": "string"}, - "slice_field": map[string]any{ - "type": "array", - "items": map[string]any{"type": "integer"}, - }, - "map_field": map[string]any{"type": "object"}, - "struct_field": map[string]any{ - "type": "object", - "properties": map[string]any{ - "field": map[string]any{"type": "string"}, - "optional_field": map[string]any{"type": "string"}, - }, - "required": []string{"field"}, - }, - "pointer_field": map[string]any{ - "type": "integer", - }, - }, - "required": []string{ - "int_field", - "bool_field", - "map_field", - "struct_field", - "pointer_field", - "slice_field", - }, - } - - actual, err := GenerateSchema[TestStruct]() - require.NoError(t, err) - - require.Equal(t, expected["type"], actual["type"]) - require.Equal(t, expected["properties"], actual["properties"]) - require.ElementsMatch(t, expected["required"], actual["required"]) -} diff --git a/pkg/vmcp/server/adapter/capability_adapter.go b/pkg/vmcp/server/adapter/capability_adapter.go index 875ecbd9b0..e3b488dacc 100644 --- a/pkg/vmcp/server/adapter/capability_adapter.go +++ b/pkg/vmcp/server/adapter/capability_adapter.go @@ -4,6 +4,7 @@ package adapter import ( + "context" "encoding/json" "fmt" @@ -14,6 +15,17 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp" ) +// OptimizerHandlerProvider provides handlers for optimizer tools. +// This interface allows the adapter to create optimizer tools without +// depending on the optimizer package implementation. +type OptimizerHandlerProvider interface { + // CreateFindToolHandler returns the handler for optim_find_tool + CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) + + // CreateCallToolHandler returns the handler for optim_call_tool + CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) +} + // CapabilityAdapter converts aggregator domain models to SDK types. // // This is the Anti-Corruption Layer between: @@ -208,3 +220,15 @@ func (a *CapabilityAdapter) ToCompositeToolSDKTools( return sdkTools, nil } + +// CreateOptimizerTools creates SDK tools for optimizer mode. +// +// When optimizer is enabled, only optim_find_tool and optim_call_tool are exposed +// to clients instead of all backend tools. This method delegates to the standalone +// CreateOptimizerTools function in optimizer_adapter.go for consistency. +// +// This keeps optimizer tool creation consistent with other tool types (backend, +// composite) by going through the adapter layer. +func (*CapabilityAdapter) CreateOptimizerTools(provider OptimizerHandlerProvider) ([]server.ServerTool, error) { + return CreateOptimizerTools(provider) +} diff --git a/pkg/vmcp/server/adapter/optimizer_adapter.go b/pkg/vmcp/server/adapter/optimizer_adapter.go index 07a6f4cb72..d38d2fa514 100644 --- a/pkg/vmcp/server/adapter/optimizer_adapter.go +++ b/pkg/vmcp/server/adapter/optimizer_adapter.go @@ -4,15 +4,11 @@ package adapter import ( - "context" "encoding/json" "fmt" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" - "github.com/stacklok/toolhive/pkg/vmcp/schema" ) // OptimizerToolNames defines the tool names exposed when optimizer is enabled. @@ -24,80 +20,88 @@ const ( // Pre-generated schemas for optimizer tools. // Generated at package init time so any schema errors panic at startup. var ( - findToolInputSchema = mustGenerateSchema[optimizer.FindToolInput]() - callToolInputSchema = mustGenerateSchema[optimizer.CallToolInput]() + findToolInputSchema = mustMarshalSchema(findToolSchema) + callToolInputSchema = mustMarshalSchema(callToolSchema) +) + +// Tool schemas defined once to eliminate duplication. +var ( + findToolSchema = mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "tool_description": map[string]any{ + "type": "string", + "description": "Natural language description of the tool you're looking for", + }, + "tool_keywords": map[string]any{ + "type": "string", + "description": "Optional space-separated keywords for keyword-based search", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of tools to return (default: 10)", + "default": 10, + }, + }, + Required: []string{"tool_description"}, + } + + callToolSchema = mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "backend_id": map[string]any{ + "type": "string", + "description": "Backend ID from find_tool results", + }, + "tool_name": map[string]any{ + "type": "string", + "description": "Tool name to invoke", + }, + "parameters": map[string]any{ + "type": "object", + "description": "Parameters to pass to the tool", + }, + }, + Required: []string{"backend_id", "tool_name", "parameters"}, + } ) // CreateOptimizerTools creates the SDK tools for optimizer mode. // When optimizer is enabled, only these two tools are exposed to clients // instead of all backend tools. -func CreateOptimizerTools(opt optimizer.Optimizer) []server.ServerTool { +// +// This function uses the OptimizerHandlerProvider interface to get handlers, +// allowing it to work with OptimizerIntegration without direct dependency. +func CreateOptimizerTools(provider OptimizerHandlerProvider) ([]server.ServerTool, error) { + if provider == nil { + return nil, fmt.Errorf("optimizer handler provider cannot be nil") + } + return []server.ServerTool{ { Tool: mcp.Tool{ Name: FindToolName, - Description: "Search for tools by description. Returns matching tools ranked by relevance.", + Description: "Semantic search across all backend tools using natural language description and optional keywords", RawInputSchema: findToolInputSchema, }, - Handler: createFindToolHandler(opt), + Handler: provider.CreateFindToolHandler(), }, { Tool: mcp.Tool{ Name: CallToolName, - Description: "Call a tool by name with the given parameters.", + Description: "Dynamically invoke any tool on any backend using the backend_id from find_tool", RawInputSchema: callToolInputSchema, }, - Handler: createCallToolHandler(opt), + Handler: provider.CreateCallToolHandler(), }, - } -} - -// createFindToolHandler creates a handler for the find_tool optimizer operation. -func createFindToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - input, err := schema.Translate[optimizer.FindToolInput](request.Params.Arguments) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil - } - - output, err := opt.FindTool(ctx, input) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("find_tool failed: %v", err)), nil - } - - return mcp.NewToolResultStructuredOnly(output), nil - } -} - -// createCallToolHandler creates a handler for the call_tool optimizer operation. -func createCallToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - input, err := schema.Translate[optimizer.CallToolInput](request.Params.Arguments) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil - } - - result, err := opt.CallTool(ctx, input) - if err != nil { - // Exposing the error to the MCP client is important if you want it to correct its behavior. - // Without information on the failure, the model is pretty much hopeless in figuring out the problem. - return mcp.NewToolResultError(fmt.Sprintf("call_tool failed: %v", err)), nil - } - - return result, nil - } + }, nil } // mustMarshalSchema marshals a schema to JSON, panicking on error. // This is safe because schemas are generated from known types at startup. // This should NOT be called by runtime code. -func mustGenerateSchema[T any]() json.RawMessage { - s, err := schema.GenerateSchema[T]() - if err != nil { - panic(fmt.Sprintf("failed to generate schema: %v", err)) - } - - data, err := json.Marshal(s) +func mustMarshalSchema(schema mcp.ToolInputSchema) json.RawMessage { + data, err := json.Marshal(schema) if err != nil { panic(fmt.Sprintf("failed to marshal schema: %v", err)) } diff --git a/pkg/vmcp/server/adapter/optimizer_adapter_test.go b/pkg/vmcp/server/adapter/optimizer_adapter_test.go index b5ad7e066a..4272a978c4 100644 --- a/pkg/vmcp/server/adapter/optimizer_adapter_test.go +++ b/pkg/vmcp/server/adapter/optimizer_adapter_test.go @@ -9,65 +9,76 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" ) -// mockOptimizer implements optimizer.Optimizer for testing. -type mockOptimizer struct { - findToolFunc func(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) - callToolFunc func(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) +// mockOptimizerHandlerProvider implements OptimizerHandlerProvider for testing. +type mockOptimizerHandlerProvider struct { + findToolHandler func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) + callToolHandler func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) } -func (m *mockOptimizer) FindTool(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { - if m.findToolFunc != nil { - return m.findToolFunc(ctx, input) +func (m *mockOptimizerHandlerProvider) CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if m.findToolHandler != nil { + return m.findToolHandler + } + return func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("ok"), nil } - return &optimizer.FindToolOutput{}, nil } -func (m *mockOptimizer) CallTool(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { - if m.callToolFunc != nil { - return m.callToolFunc(ctx, input) +func (m *mockOptimizerHandlerProvider) CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if m.callToolHandler != nil { + return m.callToolHandler + } + return func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("ok"), nil } - return mcp.NewToolResultText("ok"), nil } func TestCreateOptimizerTools(t *testing.T) { t.Parallel() - opt := &mockOptimizer{} - tools := CreateOptimizerTools(opt) + provider := &mockOptimizerHandlerProvider{} + tools, err := CreateOptimizerTools(provider) + require.NoError(t, err) require.Len(t, tools, 2) require.Equal(t, FindToolName, tools[0].Tool.Name) require.Equal(t, CallToolName, tools[1].Tool.Name) } +func TestCreateOptimizerTools_NilProvider(t *testing.T) { + t.Parallel() + + tools, err := CreateOptimizerTools(nil) + + require.Error(t, err) + require.Nil(t, tools) + require.Contains(t, err.Error(), "cannot be nil") +} + func TestFindToolHandler(t *testing.T) { t.Parallel() - opt := &mockOptimizer{ - findToolFunc: func(_ context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { - require.Equal(t, "read files", input.ToolDescription) - return &optimizer.FindToolOutput{ - Tools: []optimizer.ToolMatch{ - { - Name: "read_file", - Description: "Read a file", - Score: 1.0, - }, - }, - }, nil + provider := &mockOptimizerHandlerProvider{ + findToolHandler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args, ok := req.Params.Arguments.(map[string]any) + require.True(t, ok) + require.Equal(t, "read files", args["tool_description"]) + return mcp.NewToolResultText("found tools"), nil }, } - tools := CreateOptimizerTools(opt) + tools, err := CreateOptimizerTools(provider) + require.NoError(t, err) handler := tools[0].Handler - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]any{ - "tool_description": "read files", + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "tool_description": "read files", + }, + }, } result, err := handler(context.Background(), request) @@ -80,22 +91,29 @@ func TestFindToolHandler(t *testing.T) { func TestCallToolHandler(t *testing.T) { t.Parallel() - opt := &mockOptimizer{ - callToolFunc: func(_ context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { - require.Equal(t, "read_file", input.ToolName) - require.Equal(t, "/etc/hosts", input.Parameters["path"]) + provider := &mockOptimizerHandlerProvider{ + callToolHandler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args, ok := req.Params.Arguments.(map[string]any) + require.True(t, ok) + require.Equal(t, "read_file", args["tool_name"]) + params := args["parameters"].(map[string]any) + require.Equal(t, "/etc/hosts", params["path"]) return mcp.NewToolResultText("file contents here"), nil }, } - tools := CreateOptimizerTools(opt) + tools, err := CreateOptimizerTools(provider) + require.NoError(t, err) handler := tools[1].Handler - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]any{ - "tool_name": "read_file", - "parameters": map[string]any{ - "path": "/etc/hosts", + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "tool_name": "read_file", + "parameters": map[string]any{ + "path": "/etc/hosts", + }, + }, }, } diff --git a/pkg/vmcp/server/optimizer_test.go b/pkg/vmcp/server/optimizer_test.go new file mode 100644 index 0000000000..56cfeff396 --- /dev/null +++ b/pkg/vmcp/server/optimizer_test.go @@ -0,0 +1,362 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package server + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" + "github.com/stacklok/toolhive/pkg/vmcp/mocks" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + "github.com/stacklok/toolhive/pkg/vmcp/router" +) + +// TestNew_OptimizerEnabled tests server creation with optimizer enabled +func TestNew_OptimizerEnabled(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockBackendClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + AnyTimes() + + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT(). + Discover(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + AnyTimes() + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + tmpDir := t.TempDir() + + // Try to use Ollama if available + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: &optimizer.Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + HybridSearchRatio: 70, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + }, + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{ + { + ID: "backend-1", + Name: "Backend 1", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err) + require.NotNil(t, srv) + defer func() { _ = srv.Stop(context.Background()) }() + + // Verify optimizer integration was created + // We can't directly access optimizerIntegration, but we can verify server was created successfully +} + +// TestNew_OptimizerDisabled tests server creation with optimizer disabled +func TestNew_OptimizerDisabled(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: &optimizer.Config{ + Enabled: false, // Disabled + }, + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{} + + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err) + require.NotNil(t, srv) + defer func() { _ = srv.Stop(context.Background()) }() +} + +// TestNew_OptimizerConfigNil tests server creation with nil optimizer config +func TestNew_OptimizerConfigNil(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: nil, // Nil config + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{} + + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err) + require.NotNil(t, srv) + defer func() { _ = srv.Stop(context.Background()) }() +} + +// TestNew_OptimizerIngestionError tests error handling during optimizer ingestion +func TestNew_OptimizerIngestionError(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + // Return error when listing capabilities + mockBackendClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(nil, assert.AnError). + AnyTimes() + + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: &optimizer.Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + }, + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{ + { + ID: "backend-1", + Name: "Backend 1", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + // Should not fail even if ingestion fails + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err, "Server should be created even if optimizer ingestion fails") + require.NotNil(t, srv) + defer func() { _ = srv.Stop(context.Background()) }() +} + +// TestNew_OptimizerHybridRatio tests hybrid ratio configuration +func TestNew_OptimizerHybridRatio(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockBackendClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + AnyTimes() + + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT(). + Discover(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + AnyTimes() + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: &optimizer.Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + HybridSearchRatio: 50, // Custom ratio + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + }, + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{} + + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err) + require.NotNil(t, srv) + defer func() { _ = srv.Stop(context.Background()) }() +} + +// TestServer_Stop_OptimizerCleanup tests optimizer cleanup on server stop +func TestServer_Stop_OptimizerCleanup(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockBackendClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + AnyTimes() + + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT(). + Discover(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + AnyTimes() + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: &optimizer.Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + }, + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{} + + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err) + require.NotNil(t, srv) + + // Stop should clean up optimizer + err = srv.Stop(context.Background()) + require.NoError(t, err) +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 3ccecdf39c..910494218f 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -125,9 +125,15 @@ type Config struct { // Used for /readyz endpoint to gate readiness on cache sync. Watcher Watcher - // OptimizerFactory builds an optimizer from a list of tools. - // If not set, the optimizer is disabled. - OptimizerFactory func([]server.ServerTool) optimizer.Optimizer + // OptimizerIntegration is the optional optimizer integration. + // If nil, optimizer is disabled and backend tools are exposed directly. + // If set, this takes precedence over OptimizerConfig. + OptimizerIntegration optimizer.Integration + + // OptimizerConfig is the optional optimizer configuration (for backward compatibility). + // If OptimizerIntegration is set, this is ignored. + // If both are nil, optimizer is disabled. + OptimizerConfig *optimizer.Config // StatusReporter enables vMCP runtime to report operational status. // In Kubernetes mode: Updates VirtualMCPServer.Status (requires RBAC) @@ -340,7 +346,15 @@ func New( if cfg.HealthMonitorConfig != nil { // Get initial backends list from registry for health monitoring setup initialBackends := backendRegistry.List(ctx) - healthMon, err = health.NewMonitor(backendClient, initialBackends, *cfg.HealthMonitorConfig) + + // Construct server's own URL for self-check detection + // Use http:// as default scheme (most common for local development) + var selfURL string + if cfg.Host != "" && cfg.Port > 0 { + selfURL = fmt.Sprintf("http://%s:%d", cfg.Host, cfg.Port) + } + + healthMon, err = health.NewMonitor(backendClient, initialBackends, *cfg.HealthMonitorConfig, selfURL) if err != nil { return nil, fmt.Errorf("failed to create health monitor: %w", err) } @@ -533,6 +547,23 @@ func (s *Server) Start(ctx context.Context) error { } } + // Initialize optimizer integration if configured + if s.config.OptimizerIntegration == nil && s.config.OptimizerConfig != nil && s.config.OptimizerConfig.Enabled { + // Create optimizer integration from config (for backward compatibility) + optimizerInteg, err := optimizer.NewIntegration(ctx, s.config.OptimizerConfig, s.mcpServer, s.backendClient, s.sessionManager) + if err != nil { + return fmt.Errorf("failed to create optimizer integration: %w", err) + } + s.config.OptimizerIntegration = optimizerInteg + } + + // Initialize optimizer if configured (registers tools and ingests backends) + if s.config.OptimizerIntegration != nil { + if err := s.config.OptimizerIntegration.Initialize(ctx, s.mcpServer, s.backendRegistry); err != nil { + return fmt.Errorf("failed to initialize optimizer: %w", err) + } + } + // Start status reporter if configured if s.statusReporter != nil { shutdown, err := s.statusReporter.Start(ctx) @@ -592,6 +623,13 @@ func (s *Server) Stop(ctx context.Context) error { } } + // Stop optimizer integration if configured + if s.config.OptimizerIntegration != nil { + if err := s.config.OptimizerIntegration.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close optimizer integration: %w", err)) + } + } + // Run shutdown functions (e.g., status reporter, future components) for _, shutdown := range s.shutdownFuncs { if err := shutdown(ctx); err != nil { @@ -746,7 +784,6 @@ func (s *Server) Ready() <-chan struct{} { // - No previous capabilities exist, so no deletion needed // - Capabilities are IMMUTABLE for the session lifetime (see limitation below) // - Discovery middleware does not re-run for subsequent requests -// - If injectOptimizerCapabilities is called, this should not be called again. // // LIMITATION: Session capabilities are fixed at creation time. // If backends change (new tools added, resources removed), existing sessions won't see updates. @@ -820,54 +857,6 @@ func (s *Server) injectCapabilities( return nil } -// injectOptimizerCapabilities injects all capabilities into the session, including optimizer tools. -// It should not be called if not in optimizer mode and replaces injectCapabilities. -// -// When optimizer mode is enabled, instead of exposing all backend tools directly, -// vMCP exposes only two meta-tools: -// - find_tool: Search for tools by description -// - call_tool: Invoke a tool by name with parameters -// -// This method: -// 1. Converts all tools (backend + composite) to SDK format with handlers -// 2. Injects the optimizer capabilities into the session -func (s *Server) injectOptimizerCapabilities( - sessionID string, - caps *aggregator.AggregatedCapabilities, -) error { - - tools := append([]vmcp.Tool{}, caps.Tools...) - tools = append(tools, caps.CompositeTools...) - - sdkTools, err := s.capabilityAdapter.ToSDKTools(tools) - if err != nil { - return fmt.Errorf("failed to convert tools to SDK format: %w", err) - } - - // Create optimizer tools (find_tool, call_tool) - optimizerTools := adapter.CreateOptimizerTools(s.config.OptimizerFactory(sdkTools)) - - logger.Debugw("created optimizer tools for session", - "session_id", sessionID, - "backend_tool_count", len(caps.Tools), - "composite_tool_count", len(caps.CompositeTools), - "total_tools_indexed", len(sdkTools)) - - // Clear tools from caps - they're now wrapped by optimizer - // Resources and prompts are preserved and handled normally - capsCopy := *caps - capsCopy.Tools = nil - capsCopy.CompositeTools = nil - - // Manually add the optimizer tools, since we don't want to bother converting - // optimizer tools into `vmcp.Tool`s as well. - if err := s.mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil { - return fmt.Errorf("failed to add session tools: %w", err) - } - - return s.injectCapabilities(sessionID, &capsCopy) -} - // handleSessionRegistration processes a new MCP session registration. // // This hook fires AFTER the session is registered in the SDK (unlike AfterInitialize which @@ -880,7 +869,7 @@ func (s *Server) injectOptimizerCapabilities( // 1. Retrieves discovered capabilities from context // 2. Adds composite tools from configuration // 3. Stores routing table in VMCPSession for request routing -// 4. Injects capabilities into the SDK session +// 4. Injects capabilities into the SDK session (or delegates to optimizer if enabled) // // IMPORTANT: Session capabilities are immutable after injection. // - Capabilities discovered during initialize are fixed for the session lifetime @@ -955,16 +944,26 @@ func (s *Server) handleSessionRegistration( "resource_count", len(caps.RoutingTable.Resources), "prompt_count", len(caps.RoutingTable.Prompts)) - if s.config.OptimizerFactory != nil { - err = s.injectOptimizerCapabilities(sessionID, caps) + // Delegate to optimizer integration if enabled + if s.config.OptimizerIntegration != nil { + handled, err := s.config.OptimizerIntegration.HandleSessionRegistration( + ctx, + sessionID, + caps, + s.mcpServer, + s.capabilityAdapter.ToSDKResources, + ) if err != nil { - logger.Errorw("failed to create optimizer tools", + logger.Errorw("failed to handle session registration with optimizer", "error", err, "session_id", sessionID) - } else { - logger.Infow("optimizer capabilities injected") + return } - return + if handled { + // Optimizer handled the registration, we're done + return + } + // If optimizer didn't handle it, fall through to normal registration } // Inject capabilities into SDK session diff --git a/test/e2e/thv-operator/virtualmcp/helpers.go b/test/e2e/thv-operator/virtualmcp/helpers.go index ca73e206f2..e48d3fdea5 100644 --- a/test/e2e/thv-operator/virtualmcp/helpers.go +++ b/test/e2e/thv-operator/virtualmcp/helpers.go @@ -89,8 +89,9 @@ func checkPodsReady(ctx context.Context, c client.Client, namespace string, labe } for _, pod := range podList.Items { + // Skip pods that are not running (e.g., Succeeded, Failed from old deployments) if pod.Status.Phase != corev1.PodRunning { - return fmt.Errorf("pod %s is in phase %s", pod.Name, pod.Status.Phase) + continue } containerReady := false @@ -114,6 +115,17 @@ func checkPodsReady(ctx context.Context, c client.Client, namespace string, labe return fmt.Errorf("pod %s not ready", pod.Name) } } + + // After filtering, ensure we found at least one running pod + runningPods := 0 + for _, pod := range podList.Items { + if pod.Status.Phase == corev1.PodRunning { + runningPods++ + } + } + if runningPods == 0 { + return fmt.Errorf("no running pods found with labels %v", labels) + } return nil } diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go index e7e33fd623..18af2c94df 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go @@ -1162,12 +1162,28 @@ with socketserver.TCPServer(("", PORT), OIDCHandler) as httpd: } It("should list and call tools from all backends with discovered auth", func() { + By("Verifying vMCP pods are still running and ready before health check") + vmcpLabels := map[string]string{ + "app.kubernetes.io/name": "virtualmcpserver", + "app.kubernetes.io/instance": vmcpServerName, + } + WaitForPodsReady(ctx, k8sClient, testNamespace, vmcpLabels, 30*time.Second, 2*time.Second) + + // Create HTTP client with timeout for health checks + healthCheckClient := &http.Client{ + Timeout: 10 * time.Second, + } + By("Verifying HTTP connectivity to VirtualMCPServer health endpoint") Eventually(func() error { + // Re-check pod readiness before each health check attempt + if err := checkPodsReady(ctx, k8sClient, testNamespace, vmcpLabels); err != nil { + return fmt.Errorf("pods not ready: %w", err) + } url := fmt.Sprintf("http://localhost:%d/health", vmcpNodePort) - resp, err := http.Get(url) + resp, err := healthCheckClient.Get(url) if err != nil { - return err + return fmt.Errorf("health check failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go index 67610b043f..b08039b94e 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go @@ -72,8 +72,9 @@ var _ = Describe("VirtualMCPServer Optimizer Mode", Ordered, func() { Config: vmcpconfig.Config{ Group: mcpGroupName, Optimizer: &vmcpconfig.OptimizerConfig{ - // EmbeddingService is required but not used by DummyOptimizer - EmbeddingService: "dummy-embedding-service", + // EmbeddingURL is required for optimizer configuration + // For in-cluster services, use the full service DNS name with port + EmbeddingURL: "http://dummy-embedding-service.default.svc.cluster.local:11434", }, // Define a composite tool that calls fetch twice CompositeTools: []vmcpconfig.CompositeToolConfig{