Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion cmd/embedder/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,14 @@ func main() {
llmClient = ollama.New(cfg.llmConfig.BaseURL, cfg.llmConfig.Model)
}

// clientFactory builds a per-request LLM client when the caller overrides the model.
clientFactory := llm.ClientFactory(func(cfg llm.Config) llm.Client {
if cfg.Provider == "openai" {
return llmopenai.NewFromConfig(cfg)
}
return ollama.NewFromConfig(cfg)
})

var reranker llm.Reranker
if cfg.rerankerModel != "" {
rerankerBase := cfg.rerankerBaseURL
Expand All @@ -267,7 +275,7 @@ func main() {
slog.Info("reranker enabled", "model", cfg.rerankerModel, "base_url", rerankerBase)
}

chatSvc := service.NewChatService(retrieveSvc, llmClient, reranker, cfg.chatTopK)
chatSvc := service.NewChatService(retrieveSvc, llmClient, reranker, cfg.chatTopK, cfg.llmConfig, clientFactory)

// ── HTTP server ───────────────────────────────────────────────────────────

Expand All @@ -284,6 +292,7 @@ func main() {
worker.Trigger,
cfg.googleOAuthClientID,
cfg.googleOAuthClientSecret,
cfg.llmConfig.BaseURL,
)
srv := &http.Server{
Addr: cfg.httpAddr,
Expand Down
1 change: 1 addition & 0 deletions docker/cube-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ services:
EMBEDDER_RCLONE_BINARY: /usr/bin/rclone
EMBEDDER_RCLONE_CONFIG_DIR: /etc/cube/rclone
EMBEDDER_RCLONE_TIMEOUT: 2m
EMBEDDER_RCLONE_PREFLIGHT: "false"
EMBEDDER_LOG_LEVEL: ${EMBEDDER_LOG_LEVEL}
EMBEDDER_LLM_PROVIDER: ${EMBEDDER_LLM_PROVIDER}
EMBEDDER_LLM_BASE_URL: ${EMBEDDER_LLM_BASE_URL}
Expand Down
3 changes: 3 additions & 0 deletions internal/embedder/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func NewRouter(
trigger func(),
googleOAuthClientID string,
googleOAuthClientSecret string,
ollamaBaseURL string,
) http.Handler {
r := chi.NewRouter()
r.Use(chimw.Recoverer)
Expand All @@ -39,6 +40,8 @@ func NewRouter(
r.Get("/health", healthHandler)
r.Handle("/metrics", promhttp.Handler())

transport.MountModels(r, ollamaBaseURL)

r.Group(func(r chi.Router) {
r.Use(auth.Middleware(authenticator))
transport.MountSources(r, sourcesSvc, sourceSyncSvc, trigger, googleOAuthClientID, googleOAuthClientSecret)
Expand Down
3 changes: 2 additions & 1 deletion internal/embedder/api/transport/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type chatRequest struct {
Messages []domain.ChatMessage `json:"messages"`
RecordIDs []string `json:"record_ids,omitempty"`
ConversationID string `json:"conversation_id,omitempty"`
Model *domain.ModelConfig `json:"model,omitempty"`
}

func chatHandler(svc domain.ChatService, conversations domain.ConversationRepository) http.HandlerFunc {
Expand Down Expand Up @@ -58,7 +59,7 @@ func chatHandler(svc domain.ChatService, conversations domain.ConversationReposi
_ = conversations.AppendMessages(r.Context(), convID, toDomainMessages(req.Messages))
}

events, err := svc.Chat(r.Context(), domainID, req.Messages, req.RecordIDs)
events, err := svc.Chat(r.Context(), domainID, req.Messages, req.RecordIDs, req.Model)
if err != nil {
writeJSON(w, http.StatusInternalServerError, errBody("chat failed: "+err.Error()))
return
Expand Down
52 changes: 52 additions & 0 deletions internal/embedder/api/transport/models.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0

package transport

import (
"encoding/json"
"net/http"
"strings"

"github.com/go-chi/chi/v5"
)

// MountModels attaches the model-discovery routes to r.
//
// Currently this is a single route, GET /api/v1/models/ollama, served by a
// handler that queries the Ollama server at ollamaBaseURL.
func MountModels(r chi.Router, ollamaBaseURL string) {
	listModels := listOllamaModelsHandler(ollamaBaseURL)
	r.Get("/api/v1/models/ollama", listModels)
}

// ollamaTag is one entry of the "models" array in an Ollama tags response.
// Only the model name is decoded; all other fields are ignored.
type ollamaTag struct {
	Name string `json:"name"`
}

// ollamaTagsResponse is the subset of the Ollama /api/tags response body
// that the model-listing handler reads.
type ollamaTagsResponse struct {
	Models []ollamaTag `json:"models"`
}

// listOllamaModelsHandler returns a handler that asks the Ollama server at
// baseURL for its installed models (GET {baseURL}/api/tags) and responds with
// a JSON object of the form {"models": ["name", ...]}.
//
// Models whose lower-cased name contains "embed" or "starcoder" are left out;
// this is a name-based heuristic for models that do not support chat.
//
// Every upstream failure (Ollama unreachable, a non-200 status, or a body
// that cannot be decoded) is reported to the caller as 502 Bad Gateway.
func listOllamaModelsHandler(baseURL string) http.HandlerFunc {
	// Tolerate a configured base URL that ends in a slash so the upstream
	// path never becomes "//api/tags".
	tagsURL := strings.TrimRight(baseURL, "/") + "/api/tags"

	return func(w http.ResponseWriter, r *http.Request) {
		// Tie the upstream call to the incoming request so it is cancelled
		// when the client disconnects instead of running unbounded.
		req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, tagsURL, nil)
		if err != nil {
			writeJSON(w, http.StatusBadGateway, errBody("ollama unreachable: "+err.Error()))
			return
		}

		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			writeJSON(w, http.StatusBadGateway, errBody("ollama unreachable: "+err.Error()))
			return
		}
		defer resp.Body.Close()

		// An upstream error page must not be mistaken for an empty model list.
		if resp.StatusCode != http.StatusOK {
			writeJSON(w, http.StatusBadGateway, errBody("ollama returned unexpected status: "+resp.Status))
			return
		}

		var tags ollamaTagsResponse
		if err := json.NewDecoder(resp.Body).Decode(&tags); err != nil {
			writeJSON(w, http.StatusBadGateway, errBody("ollama response invalid: "+err.Error()))
			return
		}

		// Start from an empty (non-nil) slice so the response encodes as
		// "models": [] rather than null when nothing qualifies.
		names := make([]string, 0, len(tags.Models))
		for _, m := range tags.Models {
			// Skip models that don't support chat (embeddings, code-completion).
			name := strings.ToLower(m.Name)
			if strings.Contains(name, "embed") || strings.Contains(name, "starcoder") {
				continue
			}
			names = append(names, m.Name)
		}

		writeJSON(w, http.StatusOK, map[string]any{"models": names})
	}
}
23 changes: 22 additions & 1 deletion internal/embedder/domain/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,31 @@ type ChatEvent struct {
ConversationID string `json:"conversation_id,omitempty"`
}

// ModelConfig carries per-request LLM overrides sent by the client.
// The chat service only builds a per-request client when Model is non-empty;
// otherwise the server-configured default client is used and the remaining
// fields are ignored.
type ModelConfig struct {
// Provider selects the LLM backend: "ollama" or "openai".
// Use "openai" for any OpenAI-compatible API (OpenAI, Anthropic, etc.).
// NOTE(review): the client factory wired up in cmd/embedder/main.go treats
// every value other than "openai" (including the empty string) as Ollama,
// not as "use the server default" — confirm that is intended.
Provider string `json:"provider"`
// BaseURL overrides the server-configured endpoint.
// Leave empty to use the server default.
BaseURL string `json:"base_url,omitempty"`
// Model is the model identifier (e.g. "llama3.1:8b", "gpt-4o").
// An empty Model disables the override entirely.
Model string `json:"model"`
// APIKey is required for OpenAI-compatible providers.
// Never logged or persisted on the server.
APIKey string `json:"api_key,omitempty"`
// Temperature controls response randomness (0–1).
// NOTE(review): 0 is forwarded as-is rather than treated as "unset"; the
// Ollama request sends it explicitly, while the OpenAI request drops it
// because that field is tagged omitempty.
Temperature float64 `json:"temperature"`
// MaxTokens caps the response length (0 = provider default).
MaxTokens int `json:"max_tokens,omitempty"`
}

// ChatService orchestrates the full RAG pipeline for a query.
type ChatService interface {
// Chat embeds the query, retrieves relevant chunks, calls the LLM, and
// streams events to the returned channel. The channel is closed when the
// stream ends (either successfully or after an error event).
Chat(ctx context.Context, domainID string, messages []ChatMessage, recordIDs []string) (<-chan ChatEvent, error)
// modelCfg is optional; nil means use the server-configured default.
Chat(ctx context.Context, domainID string, messages []ChatMessage, recordIDs []string, modelCfg *ModelConfig) (<-chan ChatEvent, error)
}
15 changes: 11 additions & 4 deletions internal/embedder/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@ type Client interface {

// Config describes how to connect to an LLM provider.
type Config struct {
Provider string // "openai" | "ollama"
BaseURL string
Model string
APIKey string // required for openai
Provider string // "openai" | "ollama"
BaseURL string
Model string
APIKey string // required for openai
Temperature float64 // 0 = provider default
MaxTokens int // 0 = provider default
}

// ClientFactory builds an LLM client from a Config.
// Injected at startup to avoid circular imports between the llm and
// concrete provider packages.
//
// The signature has no error return, so an implementation cannot reject an
// invalid Config at construction time; problems surface when the returned
// Client is used.
type ClientFactory func(cfg Config) Client
22 changes: 18 additions & 4 deletions internal/embedder/llm/ollama/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ import (

// Client streams chat completions from an Ollama server.
type Client struct {
baseURL string
model string
http *http.Client
baseURL string
model string
temperature float64
maxTokens int
http *http.Client
}

// New returns an Ollama chat streaming client.
Expand All @@ -30,13 +32,25 @@ func New(baseURL, model string) *Client {
}
}

// NewFromConfig builds an Ollama client from cfg, copying its base URL,
// model, temperature and max-token settings. The underlying HTTP client is
// created with a zero Timeout, i.e. no overall request deadline.
func NewFromConfig(cfg llm.Config) *Client {
	c := new(Client)
	c.baseURL = cfg.BaseURL
	c.model = cfg.Model
	c.temperature = cfg.Temperature
	c.maxTokens = cfg.MaxTokens
	// The zero value of http.Client has Timeout == 0 (no timeout).
	c.http = &http.Client{}
	return c
}

// ollamaMessage is the JSON representation of a single chat message:
// a role and its text content.
type ollamaMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}

// chatOptions holds the generation options placed in the "options" field of
// an Ollama chat request.
type chatOptions struct {
// Temperature has no omitempty tag, so a value of 0 is always sent
// explicitly in the request body.
Temperature float64 `json:"temperature"`
// NumPredict is omitted from the request body when it is 0.
NumPredict int `json:"num_predict,omitempty"`
}

type chatRequest struct {
Expand Down Expand Up @@ -66,7 +80,7 @@ func (c *Client) StreamChat(ctx context.Context, messages []llm.Message, out cha
Model: c.model,
Messages: msgs,
Stream: true,
Options: chatOptions{Temperature: 0},
Options: chatOptions{Temperature: c.temperature, NumPredict: c.maxTokens},
})

req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/chat", bytes.NewReader(body))
Expand Down
38 changes: 28 additions & 10 deletions internal/embedder/llm/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ import (

// Client streams chat completions from the OpenAI API.
type Client struct {
baseURL string
model string
apiKey string
http *http.Client
baseURL string
model string
apiKey string
temperature float64
maxTokens int
http *http.Client
}

// New returns an OpenAI chat streaming client.
Expand All @@ -33,15 +35,29 @@ func New(baseURL, model, apiKey string) *Client {
}
}

// NewFromConfig builds an OpenAI client from cfg, copying its base URL,
// model, API key, temperature and max-token settings. The underlying HTTP
// client is created with a zero Timeout, i.e. no overall request deadline.
func NewFromConfig(cfg llm.Config) *Client {
	c := new(Client)
	c.baseURL = cfg.BaseURL
	c.model = cfg.Model
	c.apiKey = cfg.APIKey
	c.temperature = cfg.Temperature
	c.maxTokens = cfg.MaxTokens
	// The zero value of http.Client has Timeout == 0 (no timeout).
	c.http = &http.Client{}
	return c
}

// openAIMessage is the JSON representation of a single chat message:
// a role and its text content.
type openAIMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}

type streamRequest struct {
Model string `json:"model"`
Messages []openAIMessage `json:"messages"`
Stream bool `json:"stream"`
Model string `json:"model"`
Messages []openAIMessage `json:"messages"`
Stream bool `json:"stream"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
}

type streamChunk struct {
Expand All @@ -63,9 +79,11 @@ func (c *Client) StreamChat(ctx context.Context, messages []llm.Message, out cha
}

body, err := json.Marshal(streamRequest{
Model: c.model,
Messages: msgs,
Stream: true,
Model: c.model,
Messages: msgs,
Stream: true,
Temperature: c.temperature,
MaxTokens: c.maxTokens,
})
if err != nil {
return fmt.Errorf("openai chat marshal request: %w", err)
Expand Down
40 changes: 32 additions & 8 deletions internal/embedder/service/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package service
import (
"context"
"fmt"
"log/slog"
"sort"
"strings"

Expand All @@ -14,22 +15,45 @@ import (
)

type chatService struct {
retrieve domain.VectorRetrieveService
llm llm.Client
reranker llm.Reranker // nil = disabled
topK int
retrieve domain.VectorRetrieveService
llm llm.Client
reranker llm.Reranker // nil = disabled
topK int
defaultCfg llm.Config
factory llm.ClientFactory // builds a client from a per-request config
}

// NewChatService returns a ChatService that retrieves context chunks then
// streams the LLM response. reranker may be nil to skip re-ranking.
func NewChatService(retrieve domain.VectorRetrieveService, llmClient llm.Client, reranker llm.Reranker, topK int) domain.ChatService {
// factory is called to build a temporary client when the request overrides
// the server-default model; it may be nil if per-request overrides are not needed.
func NewChatService(retrieve domain.VectorRetrieveService, llmClient llm.Client, reranker llm.Reranker, topK int, defaultCfg llm.Config, factory llm.ClientFactory) domain.ChatService {
if topK <= 0 {
topK = 15
}
return &chatService{retrieve: retrieve, llm: llmClient, reranker: reranker, topK: topK}
return &chatService{retrieve: retrieve, llm: llmClient, reranker: reranker, topK: topK, defaultCfg: defaultCfg, factory: factory}
}

func (s *chatService) Chat(ctx context.Context, domainID string, messages []domain.ChatMessage, recordIDs []string) (<-chan domain.ChatEvent, error) {
func (s *chatService) Chat(ctx context.Context, domainID string, messages []domain.ChatMessage, recordIDs []string, modelCfg *domain.ModelConfig) (<-chan domain.ChatEvent, error) {
// Build a per-request client if the caller overrides the model.
llmClient := s.llm
if modelCfg != nil && modelCfg.Model != "" && s.factory != nil {
baseURL := modelCfg.BaseURL
if baseURL == "" {
baseURL = s.defaultCfg.BaseURL
}
llmClient = s.factory(llm.Config{
Provider: modelCfg.Provider,
BaseURL: baseURL,
Model: modelCfg.Model,
APIKey: modelCfg.APIKey,
Temperature: modelCfg.Temperature,
MaxTokens: modelCfg.MaxTokens,
})
slog.Info("chat: using per-request model", "provider", modelCfg.Provider, "model", modelCfg.Model)
} else {
slog.Info("chat: using default model", "provider", s.defaultCfg.Provider, "model", s.defaultCfg.Model)
}
query := ""
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
Expand Down Expand Up @@ -147,7 +171,7 @@ func (s *chatService) Chat(ctx context.Context, domainID string, messages []doma
errCh := make(chan error, 1)

go func() {
errCh <- s.llm.StreamChat(ctx, llmMessages, tokenCh)
errCh <- llmClient.StreamChat(ctx, llmMessages, tokenCh)
}()

for {
Expand Down
Loading
Loading