Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ docker/data/
*.backup

# Compiled binaries
embedder
/embedder

# Vite cache
ui/.vite/
21 changes: 19 additions & 2 deletions cmd/embedder/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/ultravioletrs/cube/internal/embedder/ingest/sources/rclone"
s3source "github.com/ultravioletrs/cube/internal/embedder/ingest/sources/s3"
"github.com/ultravioletrs/cube/internal/embedder/llm"
"github.com/ultravioletrs/cube/internal/embedder/llm/guardrails"
"github.com/ultravioletrs/cube/internal/embedder/llm/ollama"
llmopenai "github.com/ultravioletrs/cube/internal/embedder/llm/openai"
"github.com/ultravioletrs/cube/internal/embedder/postgres"
Expand All @@ -50,6 +51,7 @@ type config struct {
embeddingConfig embedding.Config
llmConfig llm.Config
chatTopK int
guardrailsURL string
rerankerModel string
rerankerBaseURL string
storageConfig objstore.Config
Expand Down Expand Up @@ -124,6 +126,7 @@ func loadConfig() config {
APIKey: env("EMBEDDER_LLM_API_KEY", ""),
},
chatTopK: envInt("EMBEDDER_CHAT_TOP_K", 15),
guardrailsURL: env("EMBEDDER_GUARDRAILS_URL", ""),
rerankerModel: env("EMBEDDER_RERANKER_MODEL", ""),
rerankerBaseURL: env("EMBEDDER_RERANKER_BASE_URL", ""),
chunkSize: envInt("EMBEDDER_CHUNK_SIZE", 512),
Expand Down Expand Up @@ -257,12 +260,25 @@ func main() {
llmClient = ollama.New(cfg.llmConfig.BaseURL, cfg.llmConfig.Model)
}

var guardrailsCtrl *guardrails.Controller
if cfg.guardrailsURL != "" {
guardrailsCtrl = guardrails.NewController(guardrails.New(cfg.guardrailsURL))
llmClient = guardrailsCtrl.Wrap(llmClient)
slog.Info("guardrails enabled", "url", cfg.guardrailsURL)
}

// clientFactory builds a per-request LLM client when the caller overrides the model.
clientFactory := llm.ClientFactory(func(cfg llm.Config) llm.Client {
var client llm.Client
if cfg.Provider == "openai" {
return llmopenai.NewFromConfig(cfg)
client = llmopenai.NewFromConfig(cfg)
} else {
client = ollama.NewFromConfig(cfg)
}
if guardrailsCtrl != nil {
return guardrailsCtrl.Wrap(client)
}
return ollama.NewFromConfig(cfg)
return client
})

var reranker llm.Reranker
Expand Down Expand Up @@ -293,6 +309,7 @@ func main() {
cfg.googleOAuthClientID,
cfg.googleOAuthClientSecret,
cfg.llmConfig.BaseURL,
guardrailsCtrl,
)
srv := &http.Server{
Addr: cfg.httpAddr,
Expand Down
2 changes: 2 additions & 0 deletions docker/cube-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ services:
EMBEDDER_LLM_MODEL: ${EMBEDDER_LLM_MODEL}
EMBEDDER_LLM_API_KEY: ${EMBEDDER_LLM_API_KEY}
EMBEDDER_CHAT_TOP_K: ${EMBEDDER_CHAT_TOP_K}
EMBEDDER_GUARDRAILS_URL: ${EMBEDDER_GUARDRAILS_URL:-http://guardrails:8001}
EMBEDDER_RCLONE_PREFLIGHT: "false"
EMBEDDER_RERANKER_MODEL: ${EMBEDDER_RERANKER_MODEL}
EMBEDDER_RERANKER_BASE_URL: ${EMBEDDER_RERANKER_BASE_URL}
volumes:
Expand Down
42 changes: 42 additions & 0 deletions guardrails/src/drivers/rest/routers/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,48 @@ def extract_guardrails_detections(res: Any, response_content: str, original_mess
return detections


@router.post("/validate", tags=["validation"])
async def validate_input(req: ChatRequest) -> Dict[str, Any]:
"""Fast input validation without LLM generation.

Runs only the Python pre-filter (substring pattern matching) — no NeMo,
no LLM call. Returns in <1 ms. Used by the embedder to enforce input
guardrails before starting a streaming Ollama response.
"""
start_time = time.time()

user_text = ""
for m in reversed(req.messages):
if m.role == "user":
user_text = (m.content or "").strip()
break

if not user_text:
return {
"decision": "BLOCK",
"refusal": "I didn't receive a valid message. Please try again.",
"violation_type": "invalid_input",
"latency_ms": 0.0,
}

violation = _check_input(user_text)
if violation:
vtype, refusal = violation
return {
"decision": "BLOCK",
"refusal": refusal,
"violation_type": vtype,
"latency_ms": (time.time() - start_time) * 1000,
}

return {
"decision": "ALLOW",
"refusal": "",
"violation_type": "",
"latency_ms": (time.time() - start_time) * 1000,
}


@router.post("/messages", tags=["chat"])
async def chat_completion(request: Request, req: ChatRequest, authorization: str = Header(None)) -> Dict[str, Any]:
start_time = time.time()
Expand Down
2 changes: 2 additions & 0 deletions internal/embedder/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func NewRouter(
googleOAuthClientID string,
googleOAuthClientSecret string,
ollamaBaseURL string,
guardrailsCtrl transport.GuardrailsController,
) http.Handler {
r := chi.NewRouter()
r.Use(chimw.Recoverer)
Expand All @@ -49,6 +50,7 @@ func NewRouter(
transport.MountRetrieve(r, retrieveSvc)
transport.MountChat(r, chatSvc, conversationsRepo)
transport.MountConversations(r, conversationsRepo)
transport.MountGuardrails(r, guardrailsCtrl)
})

return r
Expand Down
59 changes: 59 additions & 0 deletions internal/embedder/api/transport/guardrails.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0

package transport

import (
"encoding/json"
"net/http"

"github.com/go-chi/chi/v5"
)

// GuardrailsController is satisfied by *guardrails.GuardedClient.
// Defined here to avoid importing the guardrails package from the transport layer.
type GuardrailsController interface {
IsEnabled() bool
SetEnabled(bool)
}

// MountGuardrails registers the guardrails status endpoints.
// If ctrl is nil (guardrails not configured), the endpoints still respond but
// always report configured=false and ignore enable/disable requests.
func MountGuardrails(r chi.Router, ctrl GuardrailsController) {
r.Get("/api/v1/guardrails", guardrailsStatusHandler(ctrl))
r.Put("/api/v1/guardrails", guardrailsSetHandler(ctrl))
}

type guardrailsStatusResponse struct {
Enabled bool `json:"enabled"`
Configured bool `json:"configured"`
}

func guardrailsStatusHandler(ctrl GuardrailsController) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if ctrl == nil {
writeJSON(w, http.StatusOK, guardrailsStatusResponse{Enabled: false, Configured: false})
return
}
writeJSON(w, http.StatusOK, guardrailsStatusResponse{Enabled: ctrl.IsEnabled(), Configured: true})
}
}

func guardrailsSetHandler(ctrl GuardrailsController) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var body struct {
Enabled bool `json:"enabled"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeJSON(w, http.StatusBadRequest, errBody("invalid request body"))
return
}
if ctrl == nil {
writeJSON(w, http.StatusOK, guardrailsStatusResponse{Enabled: false, Configured: false})
return
}
ctrl.SetEnabled(body.Enabled)
writeJSON(w, http.StatusOK, guardrailsStatusResponse{Enabled: ctrl.IsEnabled(), Configured: true})
}
}
137 changes: 137 additions & 0 deletions internal/embedder/llm/guardrails/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0

package guardrails

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"sync/atomic"
"time"

"github.com/ultravioletrs/cube/internal/embedder/llm"
)

type Client struct {
baseURL string
http *http.Client
}

func New(baseURL string) *Client {
return &Client{
baseURL: baseURL,
http: &http.Client{Timeout: 5 * time.Second},
}
}

type validateMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}

type validateRequest struct {
Messages []validateMessage `json:"messages"`
}

type validateResponse struct {
Decision string `json:"decision"`
Refusal string `json:"refusal"`
ViolationType string `json:"violation_type"`
LatencyMs float64 `json:"latency_ms"`
}

// Check validates messages against guardrails input filters.
// Returns allow=true if the input is safe, or allow=false with the refusal
// message if it was blocked. Returns an error if the service is unreachable.
func (c *Client) Check(ctx context.Context, messages []llm.Message) (allow bool, refusal string, err error) {
msgs := make([]validateMessage, len(messages))
for i, m := range messages {
msgs[i] = validateMessage{Role: m.Role, Content: m.Content}
}

body, _ := json.Marshal(validateRequest{Messages: msgs})
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/guardrails/validate", bytes.NewReader(body))
if err != nil {
return false, "", err
}
req.Header.Set("Content-Type", "application/json")

resp, err := c.http.Do(req)
if err != nil {
return false, "", fmt.Errorf("guardrails validate: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return false, "", fmt.Errorf("guardrails validate status %d", resp.StatusCode)
}

var result validateResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return false, "", fmt.Errorf("guardrails decode: %w", err)
}

return result.Decision == "ALLOW", result.Refusal, nil
}

// Controller owns the runtime guardrails toggle shared by all wrapped LLM clients.
type Controller struct {
checker *Client
enabled atomic.Bool
}

// NewController returns a guardrails controller enabled by default.
func NewController(checker *Client) *Controller {
ctrl := &Controller{checker: checker}
ctrl.enabled.Store(true)
return ctrl
}

func (c *Controller) IsEnabled() bool { return c.enabled.Load() }
func (c *Controller) SetEnabled(v bool) { c.enabled.Store(v) }

// Wrap returns an LLM client guarded by this controller.
func (c *Controller) Wrap(inner llm.Client) *GuardedClient {
return &GuardedClient{inner: inner, controller: c}
}

// GuardedClient wraps any llm.Client with guardrails input validation.
// Blocked messages are returned as a single token (the refusal text) without
// calling the inner LLM. Allowed messages pass through to the inner client.
// The enabled flag can be toggled at runtime without restarting.
type GuardedClient struct {
inner llm.Client
controller *Controller
}

// NewGuardedClient returns a GuardedClient with guardrails enabled by default.
func NewGuardedClient(inner llm.Client, checker *Client) *GuardedClient {
return NewController(checker).Wrap(inner)
}

func (g *GuardedClient) IsEnabled() bool { return g.controller.IsEnabled() }
func (g *GuardedClient) SetEnabled(v bool) { g.controller.SetEnabled(v) }

func (g *GuardedClient) StreamChat(ctx context.Context, messages []llm.Message, out chan<- string) error {
if g.controller.IsEnabled() {
allow, refusal, err := g.controller.checker.Check(ctx, messages)
if err != nil {
defer close(out)
return fmt.Errorf("guardrails unavailable: %w", err)
}
if !allow {
defer close(out)
if refusal != "" {
select {
case out <- refusal:
case <-ctx.Done():
}
}
return nil
}
}
return g.inner.StreamChat(ctx, messages, out)
}
Loading
Loading