diff --git a/control-plane/internal/agui/events.go b/control-plane/internal/agui/events.go new file mode 100644 index 00000000..f0c39d7c --- /dev/null +++ b/control-plane/internal/agui/events.go @@ -0,0 +1,517 @@ +// Package agui implements a minimal subset of the AG-UI protocol +// (https://docs.ag-ui.com/concepts/events) so the control plane can emit an +// AG-UI-compatible Server-Sent Events stream that frontends like CopilotKit +// can consume. +// +// Wire format and field shapes are kept faithful to the reference TypeScript +// and Python SDKs at https://github.com/ag-ui-protocol/ag-ui: +// +// - SSE frames are `data: \n\n` only — no `event:` line. The TS +// EventEncoder.encodeSSE and the Python EventEncoder._encode_sse both +// emit exactly this; the discriminator lives in the JSON `type` field. +// - Event type values are UPPER_SNAKE_CASE (RUN_STARTED, TEXT_MESSAGE_CONTENT, …), +// matching the EventType enum the reference clients validate against. +// - `timestamp` is an optional Unix-millisecond integer. +// - Optional fields are omitted when empty (mirrors `exclude_none=True`). +// +// This is the POC subset — lifecycle + a single TextMessage carrying the +// reasoner's final result. Token-level streaming, tool-call frames, and +// state deltas land in subsequent iterations. +package agui + +import ( + "encoding/json" + "fmt" + "io" + "time" +) + +// Event is implemented by every AG-UI event payload. The Type method returns +// the canonical AG-UI event name used in the JSON `type` field (e.g. +// "RUN_STARTED"). It is exposed so the SSE writer can name the frame in +// errors and logs without re-marshaling. +type Event interface { + Type() string +} + +// RunStarted signals the beginning of an agent run. +// +// The `input` field is intentionally omitted from this struct: the reference +// schema types it as RunAgentInput (threadId/runId/state/messages/tools/ +// context/forwardedProps), not a freeform map. 
// Until we plumb that structured shape through, we surface `threadId` and
// `runId` only — strict clients validating against RunAgentInputSchema
// would reject a freeform map here.
type RunStarted struct {
	ThreadID    string `json:"threadId"`
	RunID       string `json:"runId"`
	ParentRunID string `json:"parentRunId,omitempty"`
	Timestamp   int64  `json:"timestamp,omitempty"`
}

// Type reports the canonical AG-UI event name for this payload.
func (RunStarted) Type() string { return "RUN_STARTED" }

// MarshalJSON emits the payload with the `type` discriminator first.
func (s RunStarted) MarshalJSON() ([]byte, error) {
	type bare RunStarted
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: s.Type(), bare: bare(s)}
	return json.Marshal(frame)
}

// RunFinished signals a successful (or interrupted) run completion.
// The reference schema requires both threadId and runId.
type RunFinished struct {
	ThreadID  string   `json:"threadId"`
	RunID     string   `json:"runId"`
	Outcome   *Outcome `json:"outcome,omitempty"`
	Result    any      `json:"result,omitempty"`
	Timestamp int64    `json:"timestamp,omitempty"`
}

// Type reports the canonical AG-UI event name for this payload.
func (RunFinished) Type() string { return "RUN_FINISHED" }

// MarshalJSON emits the payload with the `type` discriminator first.
func (f RunFinished) MarshalJSON() ([]byte, error) {
	type bare RunFinished
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: f.Type(), bare: bare(f)}
	return json.Marshal(frame)
}

// Outcome is a discriminated union:
// {type: "success"} | {type: "interrupt", interrupts: [...]}.
type Outcome struct {
	Type       string      `json:"type"`
	Interrupts []Interrupt `json:"interrupts,omitempty"`
}

// Interrupt is a pause point awaiting external resolution (e.g. human
// approval). Reserved for HITL flows; unused by the POC.
type Interrupt struct {
	ID     string `json:"id"`
	Reason string `json:"reason,omitempty"`
}

// RunError signals an unrecoverable failure. Terminates the stream.
type RunError struct {
	Message   string `json:"message"`
	Code      string `json:"code,omitempty"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

// Type reports the canonical AG-UI event name for this payload.
func (RunError) Type() string { return "RUN_ERROR" }

// MarshalJSON emits the payload with the `type` discriminator first.
func (r RunError) MarshalJSON() ([]byte, error) {
	type bare RunError
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: r.Type(), bare: bare(r)}
	return json.Marshal(frame)
}

// TextMessageStart opens an assistant text message; TextMessageContent
// events carrying the same messageId supply the body.
type TextMessageStart struct {
	MessageID string `json:"messageId"`
	Role      string `json:"role,omitempty"` // clients default this to "assistant" when omitted
	Timestamp int64  `json:"timestamp,omitempty"`
}

func (TextMessageStart) Type() string { return "TEXT_MESSAGE_START" }

func (m TextMessageStart) MarshalJSON() ([]byte, error) {
	type bare TextMessageStart
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: m.Type(), bare: bare(m)}
	return json.Marshal(frame)
}

// TextMessageContent carries one chunk of the assistant message body.
// The POC emits a single content event holding the reasoner's full result;
// once reasoner-side streaming lands this becomes one event per token chunk.
type TextMessageContent struct {
	MessageID string `json:"messageId"`
	Delta     string `json:"delta"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

func (TextMessageContent) Type() string { return "TEXT_MESSAGE_CONTENT" }

func (c TextMessageContent) MarshalJSON() ([]byte, error) {
	type bare TextMessageContent
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: c.Type(), bare: bare(c)}
	return json.Marshal(frame)
}

// TextMessageEnd closes a text message.
type TextMessageEnd struct {
	MessageID string `json:"messageId"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

// Type reports the canonical AG-UI event name for this payload.
func (TextMessageEnd) Type() string { return "TEXT_MESSAGE_END" }

// MarshalJSON emits the payload with the `type` discriminator first.
func (m TextMessageEnd) MarshalJSON() ([]byte, error) {
	type bare TextMessageEnd
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: m.Type(), bare: bare(m)}
	return json.Marshal(frame)
}

// ToolCallStart opens a tool-call frame. CopilotKit pattern-matches
// `toolCallName` against `useCopilotAction({name, render})` registrations
// to drive Generative UI — there is no separate "render" event.
type ToolCallStart struct {
	ToolCallID      string `json:"toolCallId"`
	ToolCallName    string `json:"toolCallName"`
	ParentMessageID string `json:"parentMessageId,omitempty"`
	Timestamp       int64  `json:"timestamp,omitempty"`
}

func (ToolCallStart) Type() string { return "TOOL_CALL_START" }

func (s ToolCallStart) MarshalJSON() ([]byte, error) {
	type bare ToolCallStart
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: s.Type(), bare: bare(s)}
	return json.Marshal(frame)
}

// ToolCallArgs streams a chunk of the tool-call arguments JSON. Frontends
// concatenate deltas to rebuild the full arguments object before invoking
// the action handler.
type ToolCallArgs struct {
	ToolCallID string `json:"toolCallId"`
	Delta      string `json:"delta"`
	Timestamp  int64  `json:"timestamp,omitempty"`
}

func (ToolCallArgs) Type() string { return "TOOL_CALL_ARGS" }

func (a ToolCallArgs) MarshalJSON() ([]byte, error) {
	type bare ToolCallArgs
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: a.Type(), bare: bare(a)}
	return json.Marshal(frame)
}

// ToolCallEnd closes a tool-call frame.
+type ToolCallEnd struct { + ToolCallID string `json:"toolCallId"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +func (ToolCallEnd) Type() string { return "TOOL_CALL_END" } + +func (e ToolCallEnd) MarshalJSON() ([]byte, error) { + type alias ToolCallEnd + return json.Marshal(struct { + Type string `json:"type"` + alias + }{Type: e.Type(), alias: alias(e)}) +} + +// ToolCallResult delivers the outcome of a server-side tool call. For +// frontend-handled tools (via useCopilotAction), the result instead arrives +// as the next inbound POST's trailing tool-role message — no TOOL_CALL_RESULT +// event is emitted by the backend in that flow. +type ToolCallResult struct { + MessageID string `json:"messageId"` + ToolCallID string `json:"toolCallId"` + Content string `json:"content"` + Role string `json:"role,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +func (ToolCallResult) Type() string { return "TOOL_CALL_RESULT" } + +func (e ToolCallResult) MarshalJSON() ([]byte, error) { + type alias ToolCallResult + return json.Marshal(struct { + Type string `json:"type"` + alias + }{Type: e.Type(), alias: alias(e)}) +} + +// MessagesSnapshot publishes the full conversation after a turn so clients +// can refresh their canonical thread state. CopilotKit's in-memory runtime +// derives persisted history from the trailing snapshot. +type MessagesSnapshot struct { + Messages []Message `json:"messages"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +func (MessagesSnapshot) Type() string { return "MESSAGES_SNAPSHOT" } + +func (e MessagesSnapshot) MarshalJSON() ([]byte, error) { + type alias MessagesSnapshot + return json.Marshal(struct { + Type string `json:"type"` + alias + }{Type: e.Type(), alias: alias(e)}) +} + +// StateSnapshot publishes the agent's full shared state — the value +// `useCoAgent({ state })` reads on the frontend. Reasoners opt in by +// returning a top-level `state` field. 
type StateSnapshot struct {
	Snapshot  any   `json:"snapshot"`
	Timestamp int64 `json:"timestamp,omitempty"`
}

// Type reports the canonical AG-UI event name for this payload.
func (StateSnapshot) Type() string { return "STATE_SNAPSHOT" }

// MarshalJSON emits the payload with the `type` discriminator first.
func (s StateSnapshot) MarshalJSON() ([]byte, error) {
	type bare StateSnapshot
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: s.Type(), bare: bare(s)}
	return json.Marshal(frame)
}

// StateDelta carries an RFC 6902 JSON Patch document applied incrementally
// to the previously-emitted snapshot — an optional alternative to
// re-emitting full snapshots.
type StateDelta struct {
	Delta     []any `json:"delta"`
	Timestamp int64 `json:"timestamp,omitempty"`
}

func (StateDelta) Type() string { return "STATE_DELTA" }

func (d StateDelta) MarshalJSON() ([]byte, error) {
	type bare StateDelta
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: d.Type(), bare: bare(d)}
	return json.Marshal(frame)
}

// StepStarted / StepFinished mark a named "step" inside a run. CopilotKit's
// chat UI ignores these (per the upstream GOTCHAS.md) but other AG-UI
// consumers — agent-trace viewers, debuggers, custom runtimes — render
// them as a hierarchical activity log. Defining the types lets reasoners
// surface step boundaries without us inventing a private vocabulary.
type StepStarted struct {
	StepName  string `json:"stepName"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

// Type reports the canonical AG-UI event name for this payload.
func (StepStarted) Type() string { return "STEP_STARTED" }

// MarshalJSON emits the payload with the `type` discriminator first.
func (s StepStarted) MarshalJSON() ([]byte, error) {
	type bare StepStarted
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: s.Type(), bare: bare(s)}
	return json.Marshal(frame)
}

type StepFinished struct {
	StepName  string `json:"stepName"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

func (StepFinished) Type() string { return "STEP_FINISHED" }

func (f StepFinished) MarshalJSON() ([]byte, error) {
	type bare StepFinished
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: f.Type(), bare: bare(f)}
	return json.Marshal(frame)
}

// RawEvent passes a foreign-system event through verbatim. `source` names
// the originating system (e.g. "openai", "harness", "langchain"); `event`
// is the original payload, opaque to AG-UI. Frontends can subscribe with
// onRawEvent for app-specific handling.
type RawEvent struct {
	Event     any    `json:"event"`
	Source    string `json:"source,omitempty"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

func (RawEvent) Type() string { return "RAW" }

func (r RawEvent) MarshalJSON() ([]byte, error) {
	type bare RawEvent
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: r.Type(), bare: bare(r)}
	return json.Marshal(frame)
}

// CustomEvent carries an application-defined event. `name` is the
// dispatch key frontends listen on; `value` is freeform JSON.
type CustomEvent struct {
	Name      string `json:"name"`
	Value     any    `json:"value,omitempty"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

// Type reports the canonical AG-UI event name for this payload.
func (CustomEvent) Type() string { return "CUSTOM" }

// MarshalJSON emits the payload with the `type` discriminator first.
func (c CustomEvent) MarshalJSON() ([]byte, error) {
	type bare CustomEvent
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: c.Type(), bare: bare(c)}
	return json.Marshal(frame)
}

// ReasoningStart opens a reasoning context — the agent is "thinking"
// before producing a user-facing response. CopilotKit and similar
// frontends render REASONING_* sequences in a collapsible "Thinking…"
// pane, surfacing chain-of-thought from models that support it (Claude
// extended thinking, OpenAI o-series).
type ReasoningStart struct {
	MessageID string `json:"messageId"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

func (ReasoningStart) Type() string { return "REASONING_START" }

func (s ReasoningStart) MarshalJSON() ([]byte, error) {
	type bare ReasoningStart
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: s.Type(), bare: bare(s)}
	return json.Marshal(frame)
}

// ReasoningMessageStart opens a single reasoning message inside a
// REASONING_START / END boundary. Role is always "reasoning" per the
// upstream schema.
type ReasoningMessageStart struct {
	MessageID string `json:"messageId"`
	Role      string `json:"role"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

// Type reports the canonical AG-UI event name for this payload.
func (ReasoningMessageStart) Type() string { return "REASONING_MESSAGE_START" }

// MarshalJSON emits the payload with the `type` discriminator first.
func (s ReasoningMessageStart) MarshalJSON() ([]byte, error) {
	type bare ReasoningMessageStart
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: s.Type(), bare: bare(s)}
	return json.Marshal(frame)
}

type ReasoningMessageContent struct {
	MessageID string `json:"messageId"`
	Delta     string `json:"delta"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

func (ReasoningMessageContent) Type() string { return "REASONING_MESSAGE_CONTENT" }

func (c ReasoningMessageContent) MarshalJSON() ([]byte, error) {
	type bare ReasoningMessageContent
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: c.Type(), bare: bare(c)}
	return json.Marshal(frame)
}

type ReasoningMessageEnd struct {
	MessageID string `json:"messageId"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

func (ReasoningMessageEnd) Type() string { return "REASONING_MESSAGE_END" }

func (e ReasoningMessageEnd) MarshalJSON() ([]byte, error) {
	type bare ReasoningMessageEnd
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: e.Type(), bare: bare(e)}
	return json.Marshal(frame)
}

type ReasoningEnd struct {
	MessageID string `json:"messageId"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

func (ReasoningEnd) Type() string { return "REASONING_END" }

func (e ReasoningEnd) MarshalJSON() ([]byte, error) {
	type bare ReasoningEnd
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: e.Type(), bare: bare(e)}
	return json.Marshal(frame)
}

// TextMessageChunk is the compact form of TEXT_MESSAGE_START → _CONTENT
// → _END: one event opens an implicit message, attaches a delta, and an
// empty delta closes it. Useful for streaming over slow links.
type TextMessageChunk struct {
	MessageID string `json:"messageId,omitempty"`
	Role      string `json:"role,omitempty"`
	Delta     string `json:"delta,omitempty"`
	Timestamp int64  `json:"timestamp,omitempty"`
}

// Type reports the canonical AG-UI event name for this payload.
func (TextMessageChunk) Type() string { return "TEXT_MESSAGE_CHUNK" }

// MarshalJSON emits the payload with the `type` discriminator first.
func (c TextMessageChunk) MarshalJSON() ([]byte, error) {
	type bare TextMessageChunk
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: c.Type(), bare: bare(c)}
	return json.Marshal(frame)
}

// ToolCallChunk is the compact form of TOOL_CALL_START → _ARGS → _END:
// one event per tool-call delta. toolCallId+toolCallName open an implicit
// call, repeated delta-only chunks accumulate args, and an empty delta
// closes it.
type ToolCallChunk struct {
	ToolCallID      string `json:"toolCallId,omitempty"`
	ToolCallName    string `json:"toolCallName,omitempty"`
	ParentMessageID string `json:"parentMessageId,omitempty"`
	Delta           string `json:"delta,omitempty"`
	Timestamp       int64  `json:"timestamp,omitempty"`
}

func (ToolCallChunk) Type() string { return "TOOL_CALL_CHUNK" }

func (c ToolCallChunk) MarshalJSON() ([]byte, error) {
	type bare ToolCallChunk
	frame := struct {
		Type string `json:"type"`
		bare
	}{Type: c.Type(), bare: bare(c)}
	return json.Marshal(frame)
}

// NowMillis returns the current Unix time in milliseconds. Declared as a
// variable so tests can substitute it. Milliseconds match the JS
// `Date.now()` convention AG-UI clients are most likely to interpret
// correctly.
var NowMillis = func() int64 { return time.Now().UnixMilli() }

// WriteSSE writes one AG-UI event to w in the canonical wire format used by
// the reference TS and Python encoders:
//
//	data: <json>
//
// (followed by a blank line). The discriminator is in the JSON `type` field,
// not in an SSE `event:` line — clients dispatch on the JSON `type`. Caller
// is responsible for flushing.
+func WriteSSE(w io.Writer, ev Event) error { + payload, err := json.Marshal(ev) + if err != nil { + return fmt.Errorf("marshal %s: %w", ev.Type(), err) + } + if _, err := fmt.Fprintf(w, "data: %s\n\n", payload); err != nil { + return fmt.Errorf("write %s: %w", ev.Type(), err) + } + return nil +} diff --git a/control-plane/internal/agui/events_test.go b/control-plane/internal/agui/events_test.go new file mode 100644 index 00000000..4f24b6ac --- /dev/null +++ b/control-plane/internal/agui/events_test.go @@ -0,0 +1,266 @@ +package agui + +import ( + "bytes" + "encoding/json" + "strings" + "testing" +) + +// TestWriteSSE_FrameShape pins the canonical AG-UI wire format: +// - frame is `data: \n\n` only (no `event:` line — see encoder.ts / +// encoder.py in ag-ui-protocol/ag-ui) +// - `type` field carries the UPPER_SNAKE_CASE event name +// - timestamp, when present, is a number (Unix ms) +func TestWriteSSE_FrameShape(t *testing.T) { + cases := []struct { + name string + ev Event + wantTyp string + wantFields []string + }{ + { + name: "RunStarted", + ev: RunStarted{ThreadID: "thread-1", RunID: "run-1", Timestamp: 1700000000000}, + wantTyp: "RUN_STARTED", + wantFields: []string{`"threadId":"thread-1"`, `"runId":"run-1"`, `"timestamp":1700000000000`}, + }, + { + name: "RunFinished_success_carriesIDs", + ev: RunFinished{ThreadID: "thread-1", RunID: "run-1", Outcome: &Outcome{Type: "success"}, Result: map[string]any{"answer": 42}}, + wantTyp: "RUN_FINISHED", + wantFields: []string{`"threadId":"thread-1"`, `"runId":"run-1"`, `"outcome":{"type":"success"}`, `"answer":42`}, + }, + { + name: "RunError", + ev: RunError{Message: "boom", Code: "ERR_X"}, + wantTyp: "RUN_ERROR", + wantFields: []string{`"message":"boom"`, `"code":"ERR_X"`}, + }, + { + name: "TextMessageStart", + ev: TextMessageStart{MessageID: "msg-1", Role: "assistant"}, + wantTyp: "TEXT_MESSAGE_START", + wantFields: []string{`"messageId":"msg-1"`, `"role":"assistant"`}, + }, + { + name: "TextMessageContent", + ev: 
TextMessageContent{MessageID: "msg-1", Delta: "hello"}, + wantTyp: "TEXT_MESSAGE_CONTENT", + wantFields: []string{`"messageId":"msg-1"`, `"delta":"hello"`}, + }, + { + name: "TextMessageEnd", + ev: TextMessageEnd{MessageID: "msg-1"}, + wantTyp: "TEXT_MESSAGE_END", + wantFields: []string{`"messageId":"msg-1"`}, + }, + { + name: "ToolCallStart", + ev: ToolCallStart{ToolCallID: "tc-1", ToolCallName: "showFlightCard", ParentMessageID: "msg-1"}, + wantTyp: "TOOL_CALL_START", + wantFields: []string{`"toolCallId":"tc-1"`, `"toolCallName":"showFlightCard"`, `"parentMessageId":"msg-1"`}, + }, + { + name: "ToolCallArgs", + ev: ToolCallArgs{ToolCallID: "tc-1", Delta: `{"from":"SFO"}`}, + wantTyp: "TOOL_CALL_ARGS", + wantFields: []string{`"toolCallId":"tc-1"`, `"delta":"{\"from\":\"SFO\"}"`}, + }, + { + name: "ToolCallEnd", + ev: ToolCallEnd{ToolCallID: "tc-1"}, + wantTyp: "TOOL_CALL_END", + wantFields: []string{`"toolCallId":"tc-1"`}, + }, + { + name: "ToolCallResult", + ev: ToolCallResult{MessageID: "msg-2", ToolCallID: "tc-1", Content: "ok", Role: "tool"}, + wantTyp: "TOOL_CALL_RESULT", + wantFields: []string{`"messageId":"msg-2"`, `"toolCallId":"tc-1"`, `"content":"ok"`, `"role":"tool"`}, + }, + { + name: "MessagesSnapshot", + ev: MessagesSnapshot{Messages: []Message{{ID: "m1", Role: "user", Content: "hi"}}}, + wantTyp: "MESSAGES_SNAPSHOT", + wantFields: []string{`"messages":[`, `"role":"user"`, `"content":"hi"`}, + }, + { + name: "StateSnapshot", + ev: StateSnapshot{Snapshot: map[string]any{"counter": 1}}, + wantTyp: "STATE_SNAPSHOT", + wantFields: []string{`"snapshot":{"counter":1}`}, + }, + { + name: "StateDelta", + ev: StateDelta{Delta: []any{map[string]any{"op": "replace", "path": "/counter", "value": 2}}}, + wantTyp: "STATE_DELTA", + wantFields: []string{`"delta":[`, `"op":"replace"`, `"path":"/counter"`}, + }, + { + name: "StepStarted", + ev: StepStarted{StepName: "plan"}, + wantTyp: "STEP_STARTED", + wantFields: []string{`"stepName":"plan"`}, + }, + { + name: 
"StepFinished", + ev: StepFinished{StepName: "plan"}, + wantTyp: "STEP_FINISHED", + wantFields: []string{`"stepName":"plan"`}, + }, + { + name: "RawEvent", + ev: RawEvent{Event: map[string]any{"foo": 1}, Source: "harness"}, + wantTyp: "RAW", + wantFields: []string{`"event":{"foo":1}`, `"source":"harness"`}, + }, + { + name: "CustomEvent", + ev: CustomEvent{Name: "ack", Value: map[string]any{"ok": true}}, + wantTyp: "CUSTOM", + wantFields: []string{`"name":"ack"`, `"value":{"ok":true}`}, + }, + { + name: "ReasoningStart", + ev: ReasoningStart{MessageID: "r1"}, + wantTyp: "REASONING_START", + wantFields: []string{`"messageId":"r1"`}, + }, + { + name: "ReasoningMessageStart", + ev: ReasoningMessageStart{MessageID: "r1", Role: "reasoning"}, + wantTyp: "REASONING_MESSAGE_START", + wantFields: []string{`"messageId":"r1"`, `"role":"reasoning"`}, + }, + { + name: "ReasoningMessageContent", + ev: ReasoningMessageContent{MessageID: "r1", Delta: "thinking..."}, + wantTyp: "REASONING_MESSAGE_CONTENT", + wantFields: []string{`"messageId":"r1"`, `"delta":"thinking..."`}, + }, + { + name: "ReasoningMessageEnd", + ev: ReasoningMessageEnd{MessageID: "r1"}, + wantTyp: "REASONING_MESSAGE_END", + wantFields: []string{`"messageId":"r1"`}, + }, + { + name: "ReasoningEnd", + ev: ReasoningEnd{MessageID: "r1"}, + wantTyp: "REASONING_END", + wantFields: []string{`"messageId":"r1"`}, + }, + { + name: "TextMessageChunk", + ev: TextMessageChunk{MessageID: "m1", Delta: "tok"}, + wantTyp: "TEXT_MESSAGE_CHUNK", + wantFields: []string{`"messageId":"m1"`, `"delta":"tok"`}, + }, + { + name: "ToolCallChunk", + ev: ToolCallChunk{ToolCallID: "tc1", ToolCallName: "go", Delta: "{\"a\":1}"}, + wantTyp: "TOOL_CALL_CHUNK", + wantFields: []string{`"toolCallId":"tc1"`, `"toolCallName":"go"`, `"delta":"{\"a\":1}"`}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + if err := WriteSSE(&buf, tc.ev); err != nil { + t.Fatalf("WriteSSE: %v", err) + } + frame := 
buf.String() + + // Canonical wire shape: `data: \n\n`. No `event:` line. + if !strings.HasPrefix(frame, "data: ") { + t.Fatalf("frame must start with `data: `:\n%s", frame) + } + if !strings.HasSuffix(frame, "\n\n") { + t.Fatalf("frame must end with blank-line terminator:\n%s", frame) + } + if strings.Contains(frame, "\nevent:") || strings.HasPrefix(frame, "event:") { + t.Fatalf("frame must not include an `event:` line (canonical encoder omits it):\n%s", frame) + } + + body := strings.TrimSuffix(strings.TrimPrefix(frame, "data: "), "\n\n") + var decoded map[string]any + if err := json.Unmarshal([]byte(body), &decoded); err != nil { + t.Fatalf("data line is not JSON: %v\nbody: %s", err, body) + } + if got := decoded["type"]; got != tc.wantTyp { + t.Fatalf("json type field = %v, want %q", got, tc.wantTyp) + } + for _, want := range tc.wantFields { + if !strings.Contains(body, want) { + t.Fatalf("expected field %s in payload:\n%s", want, body) + } + } + }) + } +} + +// TestWriteSSE_OmitsZeroOptionalFields confirms our `omitempty` tags drop +// timestamp / role / outcome / code when they're at zero values, matching +// the Python encoder's `exclude_none=True` semantics. +func TestWriteSSE_OmitsZeroOptionalFields(t *testing.T) { + var buf bytes.Buffer + if err := WriteSSE(&buf, TextMessageStart{MessageID: "m"}); err != nil { + t.Fatal(err) + } + body := buf.String() + if strings.Contains(body, `"role":""`) { + t.Errorf("empty role should be omitted: %s", body) + } + if strings.Contains(body, `"timestamp":0`) { + t.Errorf("zero timestamp should be omitted: %s", body) + } +} + +// unmarshalableEvent fails JSON encoding deterministically so we can exercise +// the marshal-error branch in WriteSSE. 
+type unmarshalableEvent struct{} + +func (unmarshalableEvent) Type() string { return "BAD_EVENT" } +func (unmarshalableEvent) MarshalJSON() ([]byte, error) { return nil, errBoom } + +var errBoom = &boomError{} + +type boomError struct{} + +func (b *boomError) Error() string { return "boom" } + +// TestWriteSSE_MarshalErrorIsReturned ensures encode failures surface to the +// caller rather than producing a silently-malformed frame. +func TestWriteSSE_MarshalErrorIsReturned(t *testing.T) { + var buf bytes.Buffer + err := WriteSSE(&buf, unmarshalableEvent{}) + if err == nil { + t.Fatalf("expected marshal error, got nil; buf=%q", buf.String()) + } + if !strings.Contains(err.Error(), "marshal BAD_EVENT") { + t.Errorf("error should name the event type: %v", err) + } + if buf.Len() != 0 { + t.Errorf("nothing should be written on marshal failure; got %q", buf.String()) + } +} + +// failingWriter returns an error on every Write — used to cover the +// write-error branch of WriteSSE. +type failingWriter struct{} + +func (failingWriter) Write([]byte) (int, error) { return 0, errBoom } + +// TestWriteSSE_WriteErrorIsReturned confirms a flaky writer surfaces to the +// caller (the handler uses this to bail out cleanly on client disconnect). +func TestWriteSSE_WriteErrorIsReturned(t *testing.T) { + err := WriteSSE(failingWriter{}, RunStarted{ThreadID: "t", RunID: "r"}) + if err == nil { + t.Fatalf("expected write error, got nil") + } + if !strings.Contains(err.Error(), "write RUN_STARTED") { + t.Errorf("error should name the event type: %v", err) + } +} diff --git a/control-plane/internal/agui/types.go b/control-plane/internal/agui/types.go new file mode 100644 index 00000000..a9c06846 --- /dev/null +++ b/control-plane/internal/agui/types.go @@ -0,0 +1,77 @@ +package agui + +import "encoding/json" + +// RunAgentInput mirrors the canonical RunAgentInputSchema from the AG-UI +// reference SDK (sdks/typescript/packages/core/src/types.ts). 
// The vanilla @ag-ui/client HttpAgent — and the CopilotRuntime that wraps
// it — POSTs this exact shape to backends.
//
// All sub-fields are kept permissive (json.RawMessage / any) so
// unrecognized or evolving fields pass through without forcing schema
// bumps on our side.
type RunAgentInput struct {
	ThreadID       string            `json:"threadId"`
	RunID          string            `json:"runId"`
	ParentRunID    string            `json:"parentRunId,omitempty"`
	State          json.RawMessage   `json:"state,omitempty"` // opaque agent state; decoded (if at all) by the reasoner
	Messages       []Message         `json:"messages,omitempty"`
	Tools          []Tool            `json:"tools,omitempty"`
	Context        []ContextItem     `json:"context,omitempty"`
	ForwardedProps json.RawMessage   `json:"forwardedProps,omitempty"`
	Resume         []json.RawMessage `json:"resume,omitempty"` // presumably HITL resume payloads — TODO confirm against upstream schema
}

// Message is the canonical AG-UI message envelope (MessageSchema). Role
// drives the discriminated union; the optional fields let user, assistant,
// and tool messages all round-trip through the same struct.
type Message struct {
	ID         string     `json:"id,omitempty"`
	Role       string     `json:"role"`
	Content    string     `json:"content,omitempty"`
	Name       string     `json:"name,omitempty"`
	ToolCallID string     `json:"toolCallId,omitempty"` // set on tool-role messages to link back to the call
	ToolCalls  []ToolCall `json:"toolCalls,omitempty"`  // set on assistant-role messages that invoked tools
}

// ToolCall is an assistant-message-attached tool invocation, matching
// ToolCallSchema. The function arguments are a JSON-encoded string per
// OpenAI convention, not a nested object.
type ToolCall struct {
	ID       string           `json:"id"`
	Type     string           `json:"type"` // always "function" today
	Function ToolCallFunction `json:"function"`
}

// ToolCallFunction names the invoked function and carries its arguments
// as a JSON string.
type ToolCallFunction struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

// Tool describes a tool the frontend has registered (e.g. via
// useCopilotAction). Reasoners can choose to invoke these by emitting a
// matching TOOL_CALL_* sequence.
+type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` // JSON Schema + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +// ContextItem is one (description, value) pair from the readables stream +// (e.g. useCopilotReadable). Value is freeform JSON. +type ContextItem struct { + Description string `json:"description,omitempty"` + Value json.RawMessage `json:"value,omitempty"` +} + +// LastUserMessageText returns the trailing user-role message's content, +// which is the conventional "prompt" for chat-style agents. Empty string +// if the trailing message is not user-role or messages is empty. +func (r RunAgentInput) LastUserMessageText() string { + for i := len(r.Messages) - 1; i >= 0; i-- { + if r.Messages[i].Role == "user" { + return r.Messages[i].Content + } + } + return "" +} diff --git a/control-plane/internal/agui/types_test.go b/control-plane/internal/agui/types_test.go new file mode 100644 index 00000000..65565a5e --- /dev/null +++ b/control-plane/internal/agui/types_test.go @@ -0,0 +1,59 @@ +package agui + +import "testing" + +// TestLastUserMessageText covers the trailing-user-message extractor that +// the handler uses to populate the reasoner's `prompt` convenience field. 
+func TestLastUserMessageText(t *testing.T) { + cases := []struct { + name string + in RunAgentInput + want string + }{ + { + name: "empty messages", + in: RunAgentInput{}, + want: "", + }, + { + name: "single user message", + in: RunAgentInput{Messages: []Message{ + {Role: "user", Content: "hi"}, + }}, + want: "hi", + }, + { + name: "skips trailing assistant turn — picks last user message", + in: RunAgentInput{Messages: []Message{ + {Role: "user", Content: "first"}, + {Role: "assistant", Content: "ack"}, + {Role: "user", Content: "second"}, + {Role: "assistant", Content: "ack2"}, + }}, + want: "second", + }, + { + name: "skips trailing tool message", + in: RunAgentInput{Messages: []Message{ + {Role: "user", Content: "kick off"}, + {Role: "tool", ToolCallID: "tc1", Content: "tool-output"}, + }}, + want: "kick off", + }, + { + name: "no user messages", + in: RunAgentInput{Messages: []Message{ + {Role: "system", Content: "you are helpful"}, + {Role: "assistant", Content: "ok"}, + }}, + want: "", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := tc.in.LastUserMessageText(); got != tc.want { + t.Fatalf("LastUserMessageText() = %q, want %q", got, tc.want) + } + }) + } +} diff --git a/control-plane/internal/handlers/agui_runs.go b/control-plane/internal/handlers/agui_runs.go new file mode 100644 index 00000000..df492c3b --- /dev/null +++ b/control-plane/internal/handlers/agui_runs.go @@ -0,0 +1,746 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Agent-Field/agentfield/control-plane/internal/agui" + "github.com/Agent-Field/agentfield/control-plane/internal/storage" + "github.com/Agent-Field/agentfield/control-plane/internal/utils" + "github.com/Agent-Field/agentfield/control-plane/pkg/types" + + "github.com/gin-gonic/gin" +) + +// AGUIHeartbeatInterval is how often we emit an SSE comment (`: keep-alive`) +// while waiting for a slow reasoner. 
// AG-UI clients silently drop comment lines per the SSE spec, but proxies
// (nginx, ALBs) see the bytes and don't idle out the connection. 15s leaves
// comfortable headroom under the 60s nginx default. Exposed for tests.
var AGUIHeartbeatInterval = 15 * time.Second

// agentInvocation is the result of calling the agent reasoner endpoint:
// either a fully buffered body (traditional reasoners returning a single
// JSON object) or a live io.ReadCloser carrying NDJSON chunks (streaming
// reasoners yielding events as they happen). Exactly one of Body / Stream
// is non-nil; ContentType disambiguates.
type agentInvocation struct {
	Body        []byte
	Stream      io.ReadCloser
	ContentType string
}

// IsStreaming reports whether the reasoner returned an NDJSON stream
// (Content-Type: application/x-ndjson) that the handler should consume
// chunk-by-chunk and forward as live AG-UI events. Safe on a nil receiver.
func (inv *agentInvocation) IsStreaming() bool {
	if inv == nil {
		return false
	}
	return inv.Stream != nil
}

// agentInvoker abstracts the outbound HTTP call to the agent's reasoner so
// tests can stub behavior without spinning up a real server. The default
// implementation (httpAgentInvoker) buffers the body for non-NDJSON
// responses and hands back the live stream for NDJSON.
+type agentInvoker interface { + Invoke(ctx context.Context, agent *types.AgentNode, reasonerName string, input []byte) (*agentInvocation, error) +} + +type httpAgentInvoker struct{ client *http.Client } + +func (i httpAgentInvoker) Invoke(ctx context.Context, agent *types.AgentNode, reasonerName string, input []byte) (*agentInvocation, error) { + url := fmt.Sprintf("%s/reasoners/%s", agent.BaseURL, reasonerName) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(input)) + if err != nil { + return nil, fmt.Errorf("create agent request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + // Tell the reasoner we accept either a plain JSON response (buffered + // path) or NDJSON (streaming path). Reasoners that opted into + // streaming can switch on Accept; reasoners that didn't ignore it. + req.Header.Set("Accept", "application/x-ndjson, application/json") + + client := i.client + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("agent call failed: %w", err) + } + + ct := resp.Header.Get("Content-Type") + // Streaming response: hand the body straight back. Caller is + // responsible for closing it; we don't read it here. + if resp.StatusCode < http.StatusBadRequest && strings.HasPrefix(ct, "application/x-ndjson") { + return &agentInvocation{Stream: resp.Body, ContentType: ct}, nil + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read agent response: %w", err) + } + if resp.StatusCode >= http.StatusBadRequest { + return &agentInvocation{Body: body, ContentType: ct}, + fmt.Errorf("agent returned %d: %s", resp.StatusCode, truncateForLog(body)) + } + return &agentInvocation{Body: body, ContentType: ct}, nil +} + +// AGUIRunHandler handles POST /api/v1/agui/runs/:node_id/:reasoner_name. 
+// +// It is the AG-UI protocol adapter: clients (CopilotKit's CopilotRuntime +// proxying through @ag-ui/client's HttpAgent, or any other AG-UI consumer) +// post a canonical RunAgentInput body, the handler invokes the named +// reasoner, and the response is a Server-Sent Events stream of AG-UI events. +// +// Capabilities (see https://docs.ag-ui.com/concepts/events): +// +// - Lifecycle: RUN_STARTED / RUN_FINISHED / RUN_ERROR. +// - Text messages: TEXT_MESSAGE_START / _CONTENT / _END for the +// assistant turn. The single TEXT_MESSAGE_CONTENT carries the +// reasoner's full result; token-level streaming is a follow-up. +// - Tool calls: if the reasoner result contains a `toolCalls` array +// (one per `useCopilotAction`-style render), TOOL_CALL_START / +// TOOL_CALL_ARGS / TOOL_CALL_END frames are emitted before the +// text turn closes. CopilotKit's frontend pattern-matches +// `toolCallName` against registered actions to drive Generative UI. +// - State: if the reasoner result contains a `state` object, +// STATE_SNAPSHOT is emitted before RUN_FINISHED — the value +// `useCoAgent({ state })` reads on the client. +// - MESSAGES_SNAPSHOT closes every successful run with the canonical +// conversation history, so multi-turn clients can persist it. 
+func AGUIRunHandler(storageProvider storage.StorageProvider) gin.HandlerFunc { + return aguiRunHandler(storageProvider, httpAgentInvoker{}) +} + +func aguiRunHandler(storageProvider storage.StorageProvider, invoker agentInvoker) gin.HandlerFunc { + return func(c *gin.Context) { + ctx := c.Request.Context() + + nodeID := strings.TrimSpace(c.Param("node_id")) + reasonerName := strings.TrimSpace(c.Param("reasoner_name")) + if nodeID == "" || reasonerName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "node_id and reasoner_name are required"}) + return + } + + var req agui.RunAgentInput + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + agent, err := storageProvider.GetAgent(ctx, nodeID) + if err != nil || agent == nil { + c.JSON(http.StatusNotFound, gin.H{ + "error": fmt.Sprintf("node '%s' not found", nodeID), + }) + return + } + if !reasonerExists(agent, reasonerName) { + c.JSON(http.StatusNotFound, gin.H{ + "error": fmt.Sprintf("reasoner '%s' not found on node '%s'", reasonerName, nodeID), + }) + return + } + + // Validation passed — switch to streaming mode. From here on we + // report failures via RunError frames instead of HTTP error + // responses, since the SSE stream is already open. 
+ threadID := req.ThreadID + if threadID == "" { + threadID = "thread-" + utils.GenerateExecutionID() + } + runID := req.RunID + if runID == "" { + runID = "run-" + utils.GenerateExecutionID() + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + flush := func() { + if f, ok := c.Writer.(http.Flusher); ok { + f.Flush() + } + } + + write := func(ev agui.Event) bool { + if err := agui.WriteSSE(c.Writer, ev); err != nil { + return false + } + flush() + return true + } + + if !write(agui.RunStarted{ + ThreadID: threadID, + RunID: runID, + Timestamp: agui.NowMillis(), + }) { + return + } + + reasonerInput := buildReasonerInput(req) + inputJSON, err := json.Marshal(reasonerInput) + if err != nil { + write(agui.RunError{ + Message: fmt.Sprintf("failed to marshal input: %v", err), + Code: "ERR_INPUT_MARSHAL", + Timestamp: agui.NowMillis(), + }) + return + } + + // Run the agent invocation in a goroutine so the main loop can + // emit SSE keep-alive comments while we wait for the first byte. + // (Once the body starts streaming, that's its own activity.) AG-UI + // has no heartbeat event, but `:` comment frames are valid SSE + // that clients ignore and proxies see as activity. 
+ type invokeResultT struct { + res *agentInvocation + err error + } + resultCh := make(chan invokeResultT, 1) + go func() { + r, e := invoker.Invoke(ctx, agent, reasonerName, inputJSON) + resultCh <- invokeResultT{res: r, err: e} + }() + + ticker := time.NewTicker(AGUIHeartbeatInterval) + defer ticker.Stop() + + var ( + invocation *agentInvocation + invokeErr error + ) + waitLoop: + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if _, err := fmt.Fprint(c.Writer, ": keep-alive\n\n"); err != nil { + return + } + flush() + case r := <-resultCh: + invocation, invokeErr = r.res, r.err + break waitLoop + } + } + + if invokeErr != nil { + write(agui.RunError{ + Message: invokeErr.Error(), + Code: "ERR_AGENT_CALL", + Timestamp: agui.NowMillis(), + }) + return + } + + messageID := "msg-" + utils.GenerateExecutionID() + + // Streaming reasoner — Content-Type is application/x-ndjson and + // the body is a live chunk stream. Drain it here, dispatching + // each tagged event to its AG-UI counterpart immediately, then + // wrap up with MESSAGES_SNAPSHOT + RUN_FINISHED. This is the + // path that makes "Generative UI" feel live instead of stuttery. + if invocation.IsStreaming() { + runStreamingDispatch(ctx, c, write, invocation.Stream, req, threadID, runID, messageID) + return + } + + // Buffered reasoner — the rest of this function processes the + // fully-buffered JSON body the way it did before streaming + // support landed. + body := invocation.Body + // Decode the agent response so we can surface the structured pieces + // CopilotKit understands: tool calls, state, and the assistant text. + parsed, parsedOK := decodeReasonerResponse(body) + + // Reasoning segments first — frontends render these in a + // collapsible "Thinking…" pane above the user-facing answer, so + // emitting them before tool calls / text matches the UX flow. 
+ if reasoning := extractReasoning(parsed); len(reasoning) > 0 { + reasoningContextID := "reasoning-" + utils.GenerateExecutionID() + if !write(agui.ReasoningStart{ + MessageID: reasoningContextID, + Timestamp: agui.NowMillis(), + }) { + return + } + for i, seg := range reasoning { + segID := seg.ID + if segID == "" { + segID = fmt.Sprintf("%s-seg-%d", reasoningContextID, i) + } + if !write(agui.ReasoningMessageStart{ + MessageID: segID, + Role: "reasoning", + Timestamp: agui.NowMillis(), + }) { + return + } + for _, chunk := range chunkText(seg.Content, AGUITextChunkSize) { + if !write(agui.ReasoningMessageContent{ + MessageID: segID, + Delta: chunk, + Timestamp: agui.NowMillis(), + }) { + return + } + } + if !write(agui.ReasoningMessageEnd{ + MessageID: segID, + Timestamp: agui.NowMillis(), + }) { + return + } + } + if !write(agui.ReasoningEnd{ + MessageID: reasoningContextID, + Timestamp: agui.NowMillis(), + }) { + return + } + } + + // Tool calls next so the frontend can dispatch render handlers + // (useCopilotAction) before the text turn closes. The text turn + // then carries any textual answer the reasoner produced. + toolCalls := extractToolCalls(parsed) + assistantToolCalls := make([]agui.ToolCall, 0, len(toolCalls)) + for _, tc := range toolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + argsStr := string(argsJSON) + if !write(agui.ToolCallStart{ + ToolCallID: tc.ID, + ToolCallName: tc.Name, + ParentMessageID: messageID, + Timestamp: agui.NowMillis(), + }) { + return + } + if !write(agui.ToolCallArgs{ + ToolCallID: tc.ID, + Delta: argsStr, + Timestamp: agui.NowMillis(), + }) { + return + } + if !write(agui.ToolCallEnd{ + ToolCallID: tc.ID, + Timestamp: agui.NowMillis(), + }) { + return + } + // If the reasoner already executed the tool server-side and + // gave us a result (e.g. a .ai(tools=...) trace), emit + // TOOL_CALL_RESULT so the trace renders in the same place the + // frontend would expect a tool message to live. 
+ if tc.HasResult { + if !write(agui.ToolCallResult{ + MessageID: "msg-toolresult-" + tc.ID, + ToolCallID: tc.ID, + Content: stringifyResult(tc.Result), + Role: "tool", + Timestamp: agui.NowMillis(), + }) { + return + } + } + assistantToolCalls = append(assistantToolCalls, agui.ToolCall{ + ID: tc.ID, + Type: "function", + Function: agui.ToolCallFunction{ + Name: tc.Name, + Arguments: argsStr, + }, + }) + } + + // Text turn. Assembled even when empty so clients see a complete + // triad — schema permits empty delta. Long replies are chunked + // across multiple TEXT_MESSAGE_CONTENT frames so frontends can + // paint progressively even though the reasoner is synchronous. + assistantText := extractAssistantText(parsed, parsedOK, body) + if !write(agui.TextMessageStart{ + MessageID: messageID, + Role: "assistant", + Timestamp: agui.NowMillis(), + }) { + return + } + for _, chunk := range chunkText(assistantText, AGUITextChunkSize) { + if !write(agui.TextMessageContent{ + MessageID: messageID, + Delta: chunk, + Timestamp: agui.NowMillis(), + }) { + return + } + } + if !write(agui.TextMessageEnd{ + MessageID: messageID, + Timestamp: agui.NowMillis(), + }) { + return + } + + // State snapshot first (if reasoner returned full state), then + // any RFC 6902 patches the reasoner emits via `stateDelta`. + // Snapshot before MESSAGES_SNAPSHOT so the client correlates the + // new state with the new turn. + if state, hasState := extractState(parsed); hasState { + if !write(agui.StateSnapshot{ + Snapshot: state, + Timestamp: agui.NowMillis(), + }) { + return + } + } + if delta := extractStateDelta(parsed); delta != nil { + if !write(agui.StateDelta{ + Delta: delta, + Timestamp: agui.NowMillis(), + }) { + return + } + } + + // Canonical history snapshot: inbound messages + the assistant turn + // we just produced. 
+ assistant := agui.Message{ + ID: messageID, + Role: "assistant", + Content: assistantText, + ToolCalls: assistantToolCalls, + } + full := append([]agui.Message{}, req.Messages...) + full = append(full, assistant) + if !write(agui.MessagesSnapshot{ + Messages: full, + Timestamp: agui.NowMillis(), + }) { + return + } + + write(agui.RunFinished{ + ThreadID: threadID, + RunID: runID, + Outcome: &agui.Outcome{Type: "success"}, + Result: parsed, + Timestamp: agui.NowMillis(), + }) + } +} + +// buildReasonerInput translates a canonical AG-UI RunAgentInput into the +// dict shape AgentField reasoners receive. We pass the full envelope (so +// reasoners that care can inspect tools/state/messages/context) plus a +// `prompt` convenience extracted from the trailing user message. +func buildReasonerInput(req agui.RunAgentInput) map[string]any { + input := map[string]any{ + "prompt": req.LastUserMessageText(), + "messages": req.Messages, + "tools": req.Tools, + "context": req.Context, + "threadId": req.ThreadID, + "runId": req.RunID, + } + if len(req.State) > 0 { + var state any + if err := json.Unmarshal(req.State, &state); err == nil { + input["state"] = state + } + } + if len(req.ForwardedProps) > 0 { + var fp any + if err := json.Unmarshal(req.ForwardedProps, &fp); err == nil { + input["forwardedProps"] = fp + } + } + return input +} + +// decodeReasonerResponse json-decodes the agent body. Returns the parsed +// value and whether decoding succeeded; non-JSON responses fall through to +// the raw-body path in extractAssistantText. +func decodeReasonerResponse(body []byte) (any, bool) { + var parsed any + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, false + } + return parsed, true +} + +// reasonerToolCall is the synthetic shape AgentField reasoners use to +// declare tool calls. Reasoners return +// +// {"toolCalls": [{"id", "name", "arguments", "result"?}, ...]} +// +// to drive frontend useCopilotAction renders. 
The optional `result` field, +// when present, indicates the call was already executed server-side and +// causes us to emit TOOL_CALL_RESULT after TOOL_CALL_END — so the trace +// (e.g. from .ai(tools=...) ToolCallTrace) shows up in the UI alongside +// the live calls. +type reasonerToolCall struct { + ID string + Name string + Arguments any + Result any + HasResult bool +} + +// extractToolCalls reads a `toolCalls` array from the reasoner response, +// if present. Each entry needs at least a name; id and arguments are +// optional and synthesized when missing. `result` is optional. +func extractToolCalls(parsed any) []reasonerToolCall { + obj, ok := parsed.(map[string]any) + if !ok { + return nil + } + raw, ok := obj["toolCalls"].([]any) + if !ok { + return nil + } + out := make([]reasonerToolCall, 0, len(raw)) + for i, entry := range raw { + m, ok := entry.(map[string]any) + if !ok { + continue + } + name, _ := m["name"].(string) + if name == "" { + continue + } + id, _ := m["id"].(string) + if id == "" { + id = fmt.Sprintf("toolcall-%d-%s", i, utils.GenerateExecutionID()) + } + args := m["arguments"] + if args == nil { + args = map[string]any{} + } + result, hasResult := m["result"] + out = append(out, reasonerToolCall{ + ID: id, + Name: name, + Arguments: args, + Result: result, + HasResult: hasResult, + }) + } + return out +} + +// extractReasoning reads a chain-of-thought from the reasoner response. +// Reasoners that want to surface model thinking in CopilotKit's "Thinking…" +// pane return either: +// +// {"reasoning": "the agent's chain-of-thought as a single string"} +// +// or a list of per-step strings: +// +// {"reasoning": ["step 1...", "step 2..."]} +// +// In either case the handler emits REASONING_START → one or more +// REASONING_MESSAGE_START / _CONTENT / _END pairs → REASONING_END. 
+// Reasoners that already structured the trace can pass an explicit list +// of segment dicts: +// +// {"reasoning": [{"id": "r-0", "content": "..."}, ...]} +func extractReasoning(parsed any) []reasoningSegment { + obj, ok := parsed.(map[string]any) + if !ok { + return nil + } + raw, has := obj["reasoning"] + if !has || raw == nil { + return nil + } + switch v := raw.(type) { + case string: + if v == "" { + return nil + } + return []reasoningSegment{{Content: v}} + case []any: + out := make([]reasoningSegment, 0, len(v)) + for i, entry := range v { + switch s := entry.(type) { + case string: + if s == "" { + continue + } + out = append(out, reasoningSegment{Content: s}) + case map[string]any: + content, _ := s["content"].(string) + if content == "" { + continue + } + id, _ := s["id"].(string) + if id == "" { + id = fmt.Sprintf("r-%d-%s", i, utils.GenerateExecutionID()) + } + out = append(out, reasoningSegment{ID: id, Content: content}) + } + } + if len(out) == 0 { + return nil + } + return out + } + return nil +} + +type reasoningSegment struct { + ID string + Content string +} + +// extractStateDelta reads a `stateDelta` array from the reasoner response, +// if present. Reasoners that prefer to emit incremental RFC 6902 patches +// instead of (or in addition to) full snapshots return: +// +// {"stateDelta": [{"op":"replace","path":"/counter","value":2}, ...]} +// +// The handler emits this as a STATE_DELTA event. Both forms can coexist: +// emit STATE_SNAPSHOT first to establish a baseline, then STATE_DELTA for +// fine-grained updates. +func extractStateDelta(parsed any) []any { + obj, ok := parsed.(map[string]any) + if !ok { + return nil + } + raw, ok := obj["stateDelta"].([]any) + if !ok || len(raw) == 0 { + return nil + } + return raw +} + +// AGUITextChunkSize is the maximum size of a single TEXT_MESSAGE_CONTENT +// delta. Long reasoner responses are split into multiple deltas so the +// frontend can begin painting before the full reply lands. 
256 chars is +// the sweet spot: small enough that long replies render progressively, +// large enough that short replies fit in one frame and don't pay extra +// SSE overhead. Exposed for tests. +var AGUITextChunkSize = 256 + +// chunkText splits a string into pieces of up to size bytes. For empty +// input, returns a single empty chunk so callers always emit one +// TEXT_MESSAGE_CONTENT delta (the schema permits empty deltas, and a +// missing content frame would break clients that expect the full triad). +// Splits on rune boundaries so multi-byte UTF-8 sequences (emoji, CJK) +// don't get cut mid-byte. +func chunkText(s string, size int) []string { + if size <= 0 { + return []string{s} + } + if s == "" { + return []string{""} + } + out := make([]string, 0, (len(s)/size)+1) + current := make([]rune, 0, size) + currentBytes := 0 + for _, r := range s { + rb := len(string(r)) + if currentBytes+rb > size && len(current) > 0 { + out = append(out, string(current)) + current = current[:0] + currentBytes = 0 + } + current = append(current, r) + currentBytes += rb + } + if len(current) > 0 { + out = append(out, string(current)) + } + return out +} + +// extractState returns the reasoner's top-level `state` field if any, +// for emission as STATE_SNAPSHOT. +func extractState(parsed any) (any, bool) { + obj, ok := parsed.(map[string]any) + if !ok { + return nil, false + } + state, has := obj["state"] + return state, has +} + +// extractAssistantText picks the human-facing answer for the assistant +// turn. Priority: +// 1. Reasoner returned a top-level `result` field — stringify it. +// 2. Reasoner returned a top-level `content` field — stringify it. +// 3. Reasoner returned a string body — use it verbatim. +// 4. Otherwise return the JSON-encoded body with `toolCalls` and `state` +// stripped, so the user sees something sensible if they didn't follow +// the `result` / `content` convention. +// 5. If the body wasn't JSON at all, return it raw. 
+func extractAssistantText(parsed any, parsedOK bool, rawBody []byte) string { + if !parsedOK { + return string(rawBody) + } + if obj, ok := parsed.(map[string]any); ok { + if r, has := obj["result"]; has { + return stringifyResult(r) + } + if r, has := obj["content"]; has { + return stringifyResult(r) + } + filtered := make(map[string]any, len(obj)) + for k, v := range obj { + if k == "toolCalls" || k == "state" { + continue + } + filtered[k] = v + } + if len(filtered) == 0 { + return "" + } + return stringifyResult(filtered) + } + if s, ok := parsed.(string); ok { + return s + } + return stringifyResult(parsed) +} + +func reasonerExists(agent *types.AgentNode, name string) bool { + for _, r := range agent.Reasoners { + if r.ID == name { + return true + } + } + return false +} + +// stringifyResult renders an arbitrary JSON value as a text chunk suitable +// for the AG-UI TextMessageContent delta. Strings pass through verbatim; +// everything else is JSON-encoded. +func stringifyResult(v any) string { + if s, ok := v.(string); ok { + return s + } + if v == nil { + return "" + } + encoded, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v) + } + return string(encoded) +} diff --git a/control-plane/internal/handlers/agui_runs_integration_test.go b/control-plane/internal/handlers/agui_runs_integration_test.go new file mode 100644 index 00000000..893d6d67 --- /dev/null +++ b/control-plane/internal/handlers/agui_runs_integration_test.go @@ -0,0 +1,423 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Agent-Field/agentfield/control-plane/pkg/types" + + "github.com/stretchr/testify/require" +) + +// agentField-style reasoner stub: mimics the wire shape an AgentField +// Python or Go SDK reasoner produces — JSON object with at least a +// `result` field, optionally `toolCalls` / `state` / `stateDelta` fields +// — so this test guards against the 
integration contract drifting. +// +// Without this, a future SDK rename of `prompt` -> `userPrompt` (or any +// similar tweak) would silently break Generative UI / shared state +// without failing any unit test, because the unit tests stub the +// agentInvoker interface and never inspect the reasoner-side input +// shape. + +// TestAGUI_Integration_FullSequence runs the full AG-UI handler against +// a live httptest reasoner that returns the same shape a real .ai() +// reasoner would when authors use agentfield.agui helpers. Asserts: +// +// - the reasoner received the canonical AG-UI envelope (prompt, +// messages, tools, state, context, threadId, runId) +// - the SSE stream carries lifecycle + tool calls (with TOOL_CALL_RESULT +// for executed traces) + state snapshot + state delta + chunked text + +// messages snapshot, in canonical order +// - the assistant turn in MESSAGES_SNAPSHOT carries the tool calls +// stitched onto it +func TestAGUI_Integration_FullSequence(t *testing.T) { + var seenInput map[string]any + reasoner := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/reasoners/integ", r.URL.Path) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + raw, _ := io.ReadAll(r.Body) + require.NoError(t, json.Unmarshal(raw, &seenInput)) + + // Mimic an SDK reasoner that used app.ai(tools=...) and returned + // the trace via agentfield.agui.tool_calls_from_trace, plus a + // fresh state and a single delta op. + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "result": "Booked SFO to JFK. 
Counter is now 2.", + "toolCalls": [{ + "id": "tc-trace-0", + "name": "showFlightCard", + "arguments": {"from":"SFO","to":"JFK"}, + "result": {"flightId":"AA-12","status":"booked"} + }], + "state": {"counter": 2, "lastBooking": "AA-12"}, + "stateDelta": [ + {"op":"replace","path":"/counter","value":2} + ] + }`)) + })) + defer reasoner.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "integ-node", + BaseURL: reasoner.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "integ"}}, + }} + router := mountAGUIRouter(t, store) + + // Build a canonical RunAgentInput that exercises every surface the + // reasoner is supposed to receive: prompt + multi-message history + + // tools + state + context + forwardedProps. + body := `{ + "threadId": "thread-int", "runId": "run-int", + "messages": [ + {"role":"system","content":"you are helpful"}, + {"role":"user","content":"book SFO->JFK"} + ], + "tools": [{"name":"showFlightCard","description":"render a flight card"}], + "context": [{"description":"user prefs","value":{"seat":"aisle"}}], + "state": {"counter": 1}, + "forwardedProps": {"locale":"en-US"} + }` + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/integ-node/integ", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + + // 1. Reasoner saw the canonical envelope. 
+	require.Equal(t, "book SFO->JFK", seenInput["prompt"])
+	require.Equal(t, "thread-int", seenInput["threadId"])
+	require.Equal(t, "run-int", seenInput["runId"])
+	gotMessages, _ := seenInput["messages"].([]any)
+	require.Len(t, gotMessages, 2)
+	gotTools, _ := seenInput["tools"].([]any)
+	require.Len(t, gotTools, 1)
+	gotContext, _ := seenInput["context"].([]any)
+	require.Len(t, gotContext, 1)
+	gotState, _ := seenInput["state"].(map[string]any)
+	require.EqualValues(t, 1, gotState["counter"])
+	gotFP, _ := seenInput["forwardedProps"].(map[string]any)
+	require.Equal(t, "en-US", gotFP["locale"])
+
+	// 2. Wire output: full canonical sequence.
+	frames := parseAGUIStream(t, w.Body.String())
+	// Named frameTypes (not "types"): the original shadowed the imported
+	// pkg/types package this test uses for types.AgentNode above.
+	frameTypes := make([]string, 0, len(frames))
+	for _, f := range frames {
+		frameTypes = append(frameTypes, f.Type())
+	}
+	want := []string{
+		"RUN_STARTED",
+		"TOOL_CALL_START",
+		"TOOL_CALL_ARGS",
+		"TOOL_CALL_END",
+		"TOOL_CALL_RESULT",
+		"TEXT_MESSAGE_START",
+		"TEXT_MESSAGE_CONTENT",
+		"TEXT_MESSAGE_END",
+		"STATE_SNAPSHOT",
+		"STATE_DELTA",
+		"MESSAGES_SNAPSHOT",
+		"RUN_FINISHED",
+	}
+	require.Equal(t, want, frameTypes, "frame sequence diverged from canonical AG-UI order")
+
+	// 3. TOOL_CALL_RESULT carries the executed trace's result.
+	resFrame := frames[4]
+	require.Equal(t, "tc-trace-0", resFrame.Data["toolCallId"])
+	require.Equal(t, "tool", resFrame.Data["role"])
+	require.JSONEq(t, `{"flightId":"AA-12","status":"booked"}`, resFrame.Data["content"].(string))
+
+	// 4. STATE_SNAPSHOT carries new value; STATE_DELTA carries the patch.
+	snap, _ := frames[8].Data["snapshot"].(map[string]any)
+	require.EqualValues(t, 2, snap["counter"])
+	require.Equal(t, "AA-12", snap["lastBooking"])
+	delta, _ := frames[9].Data["delta"].([]any)
+	require.Len(t, delta, 1)
+	op, _ := delta[0].(map[string]any)
+	require.Equal(t, "replace", op["op"])
+
+	// 5. MESSAGES_SNAPSHOT — assistant turn carries tool calls.
+ msgs, _ := frames[10].Data["messages"].([]any) + require.Len(t, msgs, 3, "should be 2 inbound + 1 new assistant") + assistant, _ := msgs[2].(map[string]any) + require.Equal(t, "assistant", assistant["role"]) + tcs, _ := assistant["toolCalls"].([]any) + require.Len(t, tcs, 1) + tc, _ := tcs[0].(map[string]any) + require.Equal(t, "tc-trace-0", tc["id"]) + fn, _ := tc["function"].(map[string]any) + require.Equal(t, "showFlightCard", fn["name"]) +} + +// TestAGUI_Integration_StreamingReasoner exercises the live-streaming +// path end to end: the reasoner returns NDJSON tagged events, the +// handler dispatches each into its AG-UI counterpart, frames are +// flushed live (verified by timestamping arrivals), and the run closes +// with MESSAGES_SNAPSHOT + RUN_FINISHED. This is the test that proves +// "Generative UI feels live" actually works under load — without it, +// any future regression that buffers the stream would silently make +// the UX stuttery again with no test failure. +func TestAGUI_Integration_StreamingReasoner(t *testing.T) { + // The reasoner streams: text chunks (with deliberate per-chunk + // delays so we can assert live forwarding), then a tool call, then + // state, then closes. 
+ chunkDelay := 30 * time.Millisecond + reasoner := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/reasoners/streaming-bot", r.URL.Path) + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(http.StatusOK) + flusher, _ := w.(http.Flusher) + + send := func(line string) { + fmt.Fprintln(w, line) + if flusher != nil { + flusher.Flush() + } + time.Sleep(chunkDelay) + } + send(`{"type":"reasoning","delta":"checking flights..."}`) + send(`{"type":"reasoning","delta":" AA-12 wins on price."}`) + send(`{"type":"text","delta":"Booked "}`) + send(`{"type":"text","delta":"AA-12 SFO->JFK."}`) + send(`{"type":"tool_call_start","id":"tc-1","name":"showFlightCard","arguments":{"from":"SFO","to":"JFK"}}`) + send(`{"type":"tool_call_end","id":"tc-1"}`) + send(`{"type":"state","snapshot":{"counter":1}}`) + send(`{"type":"step_started","name":"finalize"}`) + send(`{"type":"step_finished","name":"finalize"}`) + send(`{"type":"custom","name":"telemetry","value":{"latency_ms":120}}`) + })) + defer reasoner.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "stream-node", + BaseURL: reasoner.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "streaming-bot"}}, + }} + router := mountAGUIRouter(t, store) + + body := `{"threadId":"t","runId":"r","messages":[{"role":"user","content":"book it"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/stream-node/streaming-bot", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + frames := parseAGUIStream(t, w.Body.String()) + got := []string{} + for _, f := range frames { + got = append(got, f.Type()) + } + want := []string{ + "RUN_STARTED", + "REASONING_START", + "REASONING_MESSAGE_START", + 
"REASONING_MESSAGE_CONTENT", + "REASONING_MESSAGE_CONTENT", + "REASONING_MESSAGE_END", // closed when text chunk arrives + "REASONING_END", // outer context closed + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_CONTENT", + "TOOL_CALL_START", + "TOOL_CALL_ARGS", // synthesized from `arguments` on start + "TOOL_CALL_END", + "STATE_SNAPSHOT", + "STEP_STARTED", + "STEP_FINISHED", + "CUSTOM", + "TEXT_MESSAGE_END", // closed at stream end + "MESSAGES_SNAPSHOT", + "RUN_FINISHED", + } + require.Equal(t, want, got, "streaming dispatcher diverged from canonical AG-UI ordering") + + // Each text-content delta must carry the chunk the reasoner sent + // (proves the dispatcher didn't accidentally re-buffer). + textDeltas := []string{} + for _, f := range frames { + if f.Type() == "TEXT_MESSAGE_CONTENT" { + d, _ := f.Data["delta"].(string) + textDeltas = append(textDeltas, d) + } + } + require.Equal(t, []string{"Booked ", "AA-12 SFO->JFK."}, textDeltas) + + // MESSAGES_SNAPSHOT closes with the assistant turn carrying the + // concatenated text and the tool call attached. + snap, _ := frames[len(frames)-2].Data["messages"].([]any) + require.Len(t, snap, 2) + assistant, _ := snap[1].(map[string]any) + require.Equal(t, "Booked AA-12 SFO->JFK.", assistant["content"]) + tcs, _ := assistant["toolCalls"].([]any) + require.Len(t, tcs, 1) +} + +// TestAGUI_Integration_StreamingErrorChunkTerminates: an `error` chunk +// from the reasoner terminates the stream with RUN_ERROR, even +// mid-flight, without emitting MESSAGES_SNAPSHOT or RUN_FINISHED. 
func TestAGUI_Integration_StreamingErrorChunkTerminates(t *testing.T) {
	// Fake streaming reasoner: one good text chunk, then an error chunk,
	// then a chunk that must never surface.
	reasoner := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/x-ndjson")
		w.WriteHeader(http.StatusOK)
		flusher, _ := w.(http.Flusher)
		send := func(s string) {
			fmt.Fprintln(w, s)
			if flusher != nil {
				flusher.Flush()
			}
		}
		send(`{"type":"text","delta":"hello"}`)
		send(`{"type":"error","message":"upstream blew up","code":"ERR_LLM"}`)
		// Anything after the error must be ignored.
		send(`{"type":"text","delta":"unreachable"}`)
	}))
	defer reasoner.Close()

	store := &reasonerTestStorage{agent: &types.AgentNode{
		ID:              "n",
		BaseURL:         reasoner.URL,
		HealthStatus:    types.HealthStatusActive,
		LifecycleStatus: types.AgentStatusReady,
		Reasoners:       []types.ReasonerDefinition{{ID: "boom"}},
	}}
	router := mountAGUIRouter(t, store)

	req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/n/boom",
		strings.NewReader(`{"threadId":"t","runId":"r","messages":[{"role":"user","content":"x"}]}`))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	frames := parseAGUIStream(t, w.Body.String())
	got := []string{}
	for _, f := range frames {
		got = append(got, f.Type())
	}
	// We accept the partial text frames, then RUN_ERROR (terminal).
	require.Contains(t, got, "RUN_ERROR")
	last := frames[len(frames)-1]
	require.Equal(t, "RUN_ERROR", last.Type())
	require.Equal(t, "upstream blew up", last.Data["message"])
	require.Equal(t, "ERR_LLM", last.Data["code"])
	require.NotContains(t, got, "MESSAGES_SNAPSHOT", "no snapshot after error")
	require.NotContains(t, got, "RUN_FINISHED", "no finish after error")
	// The post-error text chunk must have been dropped.
	for _, f := range frames {
		if f.Type() == "TEXT_MESSAGE_CONTENT" {
			d, _ := f.Data["delta"].(string)
			require.NotEqual(t, "unreachable", d, "post-error chunk must not leak through")
		}
	}
}

// TestAGUI_Integration_StreamingMalformedLineSurfacesAsRaw: a single bad
// NDJSON line shouldn't kill the stream — the dispatcher should surface
// it as RAW and continue.
func TestAGUI_Integration_StreamingMalformedLineSurfacesAsRaw(t *testing.T) {
	reasoner := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/x-ndjson")
		w.WriteHeader(http.StatusOK)
		fmt.Fprintln(w, `{"type":"text","delta":"hi"}`)
		fmt.Fprintln(w, `{not valid json`)
		fmt.Fprintln(w, `{"type":"text","delta":" world"}`)
	}))
	defer reasoner.Close()

	store := &reasonerTestStorage{agent: &types.AgentNode{
		ID:              "n",
		BaseURL:         reasoner.URL,
		HealthStatus:    types.HealthStatusActive,
		LifecycleStatus: types.AgentStatusReady,
		Reasoners:       []types.ReasonerDefinition{{ID: "wobble"}},
	}}
	router := mountAGUIRouter(t, store)

	req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/n/wobble",
		strings.NewReader(`{"threadId":"t","runId":"r","messages":[{"role":"user","content":"x"}]}`))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	frames := parseAGUIStream(t, w.Body.String())
	got := []string{}
	for _, f := range frames {
		got = append(got, f.Type())
	}
	require.Contains(t, got, "RAW", "malformed chunk should surface as RAW")
	// Stream completed; both text deltas reached us.
	textDeltas := []string{}
	for _, f := range frames {
		if f.Type() == "TEXT_MESSAGE_CONTENT" {
			d, _ := f.Data["delta"].(string)
			textDeltas = append(textDeltas, d)
		}
	}
	require.Equal(t, []string{"hi", " world"}, textDeltas)
	require.Equal(t, "RUN_FINISHED", frames[len(frames)-1].Type())
}

// TestAGUI_Integration_FollowupTurnWithToolMessage verifies the second
// half of the CopilotKit "user clicked confirm" loop: when the next
// run's inbound history includes a role:"tool" message, the reasoner
// receives it intact so it can produce a follow-up response.
func TestAGUI_Integration_FollowupTurnWithToolMessage(t *testing.T) {
	// NOTE(review): seenInput is written in the reasoner's handler
	// goroutine and read after router.ServeHTTP returns. The request
	// completing should order these, but confirm under -race.
	var seenInput map[string]any
	reasoner := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		raw, _ := io.ReadAll(r.Body)
		require.NoError(t, json.Unmarshal(raw, &seenInput))
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"result":"Booking confirmed."}`))
	}))
	defer reasoner.Close()

	store := &reasonerTestStorage{agent: &types.AgentNode{
		ID:              "n",
		BaseURL:         reasoner.URL,
		HealthStatus:    types.HealthStatusActive,
		LifecycleStatus: types.AgentStatusReady,
		Reasoners:       []types.ReasonerDefinition{{ID: "f"}},
	}}
	router := mountAGUIRouter(t, store)

	body := `{
		"threadId":"t","runId":"r2",
		"messages":[
			{"role":"user","content":"book SFO->JFK"},
			{"role":"assistant","toolCalls":[{
				"id":"tc1","type":"function",
				"function":{"name":"showFlightCard","arguments":"{\"from\":\"SFO\"}"}
			}]},
			{"role":"tool","toolCallId":"tc1","content":"user confirmed"},
			{"role":"user","content":"now book the return"}
		]
	}`
	req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/n/f", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	require.Equal(t, http.StatusOK, w.Code, w.Body.String())
	// The reasoner must see the tool message verbatim — that's what
	// closes the click-confirm loop. Without this, the agent has no way
	// of knowing the tool ran.
	require.Equal(t, "now book the return", seenInput["prompt"])
	msgs, _ := seenInput["messages"].([]any)
	require.Len(t, msgs, 4)
	tool, _ := msgs[2].(map[string]any)
	require.Equal(t, "tool", tool["role"])
	require.Equal(t, "tc1", tool["toolCallId"])
	require.Equal(t, "user confirmed", tool["content"])
}
diff --git a/control-plane/internal/handlers/agui_runs_load_test.go b/control-plane/internal/handlers/agui_runs_load_test.go
new file mode 100644
index 00000000..12a17154
--- /dev/null
+++ b/control-plane/internal/handlers/agui_runs_load_test.go
@@ -0,0 +1,257 @@
package handlers

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/Agent-Field/agentfield/control-plane/pkg/types"

	"github.com/stretchr/testify/require"
)

// TestAGUI_Load_ConcurrentBuffered hammers the AG-UI handler with many
// concurrent requests against a fast buffered reasoner and asserts:
//
//   - Every request returns a complete canonical event sequence.
//   - Goroutines don't leak: the count after all runs settle is
//     approximately the baseline (a few +/- for runtime noise).
//   - p50/p95/p99 latencies stay within reasonable bounds across
//     200 requests at 50-way concurrency.
//
// This is the production-readiness gate the earlier 5×concurrent test
// could not provide. It runs in CI as part of `go test`.
+func TestAGUI_Load_ConcurrentBuffered(t *testing.T) { + if testing.Short() { + t.Skip("skipping load test in -short mode") + } + + const totalRequests = 200 + const concurrency = 50 + + reasoner := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"ok","state":{"counter":1}}`)) + })) + defer reasoner.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "load-node", + BaseURL: reasoner.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "r"}}, + }} + router := mountAGUIRouter(t, store) + + // Sample the goroutine baseline AFTER the test runtime is up but + // BEFORE we fire load. NumGoroutine() is nondeterministic so we + // give the handler a generous tolerance — we're guarding against + // real leaks (200 leaked goroutines per 200 runs), not noise. + runtime.GC() + time.Sleep(50 * time.Millisecond) + baseline := runtime.NumGoroutine() + + var ( + started atomic.Int64 + completed atomic.Int64 + failed atomic.Int64 + latencies = make([]time.Duration, totalRequests) + ) + sem := make(chan struct{}, concurrency) + var wg sync.WaitGroup + + wallStart := time.Now() + for i := 0; i < totalRequests; i++ { + wg.Add(1) + sem <- struct{}{} + go func(idx int) { + defer wg.Done() + defer func() { <-sem }() + started.Add(1) + + body := fmt.Sprintf(`{"threadId":"t-%d","runId":"r-%d","messages":[{"role":"user","content":"x"}]}`, idx, idx) + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/load-node/r", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + start := time.Now() + router.ServeHTTP(w, req) + latencies[idx] = time.Since(start) + + if w.Code != http.StatusOK { + failed.Add(1) + return + } + frames := parseAGUIStream(t, w.Body.String()) + if len(frames) == 0 || frames[0].Type() 
!= "RUN_STARTED" || frames[len(frames)-1].Type() != "RUN_FINISHED" { + failed.Add(1) + return + } + completed.Add(1) + }(i) + } + wg.Wait() + wallElapsed := time.Since(wallStart) + + require.Equal(t, int64(totalRequests), started.Load(), "all requests should have started") + require.Equal(t, int64(totalRequests), completed.Load(), "all requests should have completed: failures=%d", failed.Load()) + require.Equal(t, int64(0), failed.Load(), "no requests should have failed under load") + + // Latency stats — sort then pick percentiles. + sortDurations(latencies) + p50 := latencies[len(latencies)*50/100] + p95 := latencies[len(latencies)*95/100] + p99 := latencies[len(latencies)*99/100] + + t.Logf("load: %d reqs at %d concurrency, wall=%s, p50=%s p95=%s p99=%s", + totalRequests, concurrency, wallElapsed, p50, p95, p99) + + // Loose latency budget — the handler is just routing + emitting + // events against an in-process httptest reasoner, so even p99 + // shouldn't exceed 250ms on a quiet box. + require.Less(t, p95, 250*time.Millisecond, "p95 latency too high under 50× concurrent load") + + // Goroutine leak check. Every request spawns one goroutine + // (invoker.Invoke). They should all have settled by now. Allow a + // generous buffer for test infra (httptest handlers can keep + // goroutines around briefly) but flag a real leak. + runtime.GC() + time.Sleep(100 * time.Millisecond) + final := runtime.NumGoroutine() + t.Logf("goroutines: baseline=%d, final=%d, delta=%d", baseline, final, final-baseline) + require.Less(t, final-baseline, 50, "goroutine leak: %d goroutines still running after load completed", final-baseline) +} + +// TestAGUI_Load_ConcurrentStreaming repeats the load run against a +// streaming reasoner so the streaming dispatch path is also load-tested. 
+func TestAGUI_Load_ConcurrentStreaming(t *testing.T) { + if testing.Short() { + t.Skip("skipping load test in -short mode") + } + + const totalRequests = 100 + const concurrency = 25 + + reasoner := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(http.StatusOK) + flusher, _ := w.(http.Flusher) + send := func(line string) { + fmt.Fprintln(w, line) + if flusher != nil { + flusher.Flush() + } + } + send(`{"type":"text","delta":"chunk-1"}`) + send(`{"type":"text","delta":"chunk-2"}`) + send(`{"type":"state","snapshot":{"k":1}}`) + })) + defer reasoner.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "load-stream", + BaseURL: reasoner.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "r"}}, + }} + router := mountAGUIRouter(t, store) + + runtime.GC() + time.Sleep(50 * time.Millisecond) + baseline := runtime.NumGoroutine() + + var failed atomic.Int64 + sem := make(chan struct{}, concurrency) + var wg sync.WaitGroup + wallStart := time.Now() + + for i := 0; i < totalRequests; i++ { + wg.Add(1) + sem <- struct{}{} + go func(idx int) { + defer wg.Done() + defer func() { <-sem }() + + body := fmt.Sprintf(`{"threadId":"t-%d","runId":"r-%d","messages":[{"role":"user","content":"x"}]}`, idx, idx) + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/load-stream/r", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + failed.Add(1) + return + } + frames := parseAGUIStream(t, w.Body.String()) + if len(frames) < 5 || frames[len(frames)-1].Type() != "RUN_FINISHED" { + failed.Add(1) + } + }(i) + } + wg.Wait() + t.Logf("streaming load: %d reqs at %d concurrent, wall=%s", totalRequests, concurrency, time.Since(wallStart)) + + require.Equal(t, 
int64(0), failed.Load(), "streaming dispatcher must complete every request under load") + + runtime.GC() + time.Sleep(100 * time.Millisecond) + final := runtime.NumGoroutine() + t.Logf("streaming goroutines: baseline=%d, final=%d, delta=%d", baseline, final, final-baseline) + require.Less(t, final-baseline, 50, "streaming dispatcher leaked goroutines under load") +} + +// BenchmarkAGUI_BufferedHandler measures the per-request cost of the +// AG-UI handler against an in-process httptest reasoner. Run with: +// +// go test -bench=BenchmarkAGUI -benchmem -run=^$ ./internal/handlers/... +// +// Useful as a regression baseline when the streaming/dispatch logic +// changes. +func BenchmarkAGUI_BufferedHandler(b *testing.B) { + reasoner := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"ok"}`)) + })) + defer reasoner.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "bench", + BaseURL: reasoner.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "r"}}, + }} + router := mountAGUIRouter(&testing.T{}, store) + bodyTpl := `{"threadId":"t","runId":"r","messages":[{"role":"user","content":"x"}]}` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/bench/r", strings.NewReader(bodyTpl)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + b.Fatalf("status=%d", w.Code) + } + } +} + +// sortDurations is a small inline sort to avoid pulling in slices/sort +// noise in the load test. n is small (≤200) so insertion sort is fine. 
func sortDurations(xs []time.Duration) {
	// Classic in-place insertion sort, ascending.
	for i := 1; i < len(xs); i++ {
		for j := i; j > 0 && xs[j-1] > xs[j]; j-- {
			xs[j-1], xs[j] = xs[j], xs[j-1]
		}
	}
}

// silence unused-import warnings in case the file is edited down later.
// NOTE(review): "context" is otherwise unused in this file — dropping
// the import (and this blank use) would be cleaner; confirm nothing
// else in the package relies on it.
var _ = context.Background
diff --git a/control-plane/internal/handlers/agui_runs_streaming.go b/control-plane/internal/handlers/agui_runs_streaming.go
new file mode 100644
index 00000000..f4767c0a
--- /dev/null
+++ b/control-plane/internal/handlers/agui_runs_streaming.go
@@ -0,0 +1,538 @@
package handlers

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"

	"github.com/Agent-Field/agentfield/control-plane/internal/agui"
	"github.com/Agent-Field/agentfield/control-plane/internal/utils"

	"github.com/gin-gonic/gin"
)

// AGUIStreamingMaxLineBytes caps the size of any one NDJSON chunk the
// reasoner can send. Without this, a misbehaving reasoner could stream
// an unbounded line and exhaust handler memory. 1 MiB is generous for
// per-token deltas while still bounding the worst case. Exposed for tests.
var AGUIStreamingMaxLineBytes = 1 << 20

// streamingChunk is the wire shape between an AgentField streaming
// reasoner and the AG-UI handler. Reasoners emit one JSON object per
// line in the NDJSON HTTP response body; this struct decodes them. All
// fields are optional — `Type` selects the variant.
//
// Recognized variants and their AG-UI translation:
//
//	{"type":"text",            "delta":"hello"}       -> TEXT_MESSAGE_CONTENT
//	{"type":"reasoning",       "delta":"thinking..."} -> REASONING_MESSAGE_CONTENT
//	{"type":"tool_call_start", "id":"tc1","name":"x", "arguments":{...}, "parentMessageId":"..."}
//	                                                  -> TOOL_CALL_START + (single
//	                                                     TOOL_CALL_ARGS if arguments
//	                                                     supplied)
//	{"type":"tool_call_args",  "id":"tc1","delta":"..."}   -> TOOL_CALL_ARGS
//	{"type":"tool_call_end",   "id":"tc1"}                 -> TOOL_CALL_END
//	{"type":"tool_call_result","id":"tc1","content":"..."} -> TOOL_CALL_RESULT
//	{"type":"state",           "snapshot":{...}}           -> STATE_SNAPSHOT
//	{"type":"state_delta",     "ops":[...]}                -> STATE_DELTA (RFC 6902)
//	{"type":"step_started",    "name":"plan"}              -> STEP_STARTED
//	{"type":"step_finished",   "name":"plan"}              -> STEP_FINISHED
//	{"type":"raw",             "event":..., "source":"x"}  -> RAW
//	{"type":"custom",          "name":"...","value":...}   -> CUSTOM
//	{"type":"final",           "data":{}}                  -> applies any
//	  leftover toolCalls / state / stateDelta / reasoning the reasoner
//	  wants to send at the end of the stream; any still-open text or
//	  reasoning sessions are closed when the stream itself ends.
//	{"type":"error",           "message":"...","code":"..."} -> RUN_ERROR (terminal)
//
// Unknown types are surfaced as RAW events carrying an
// `unknown_chunk_type` hint, so reasoner authors can iterate without
// forcing a control-plane upgrade.
type streamingChunk struct {
	// Type selects the variant; see the table in the doc comment above.
	Type string `json:"type"`

	// text / reasoning / tool_call_args
	Delta string `json:"delta,omitempty"`

	// reasoning / tool_call_* (Name is also used by step_* and custom)
	ID              string          `json:"id,omitempty"`
	Name            string          `json:"name,omitempty"`
	ParentMessageID string          `json:"parentMessageId,omitempty"`
	Arguments       json.RawMessage `json:"arguments,omitempty"`

	// tool_call_result
	Content string `json:"content,omitempty"`
	Role    string `json:"role,omitempty"`

	// state / state_delta
	Snapshot any   `json:"snapshot,omitempty"`
	Ops      []any `json:"ops,omitempty"`

	// raw
	Event  any    `json:"event,omitempty"`
	Source string `json:"source,omitempty"`

	// custom
	Value any `json:"value,omitempty"`

	// final
	Data map[string]any `json:"data,omitempty"`

	// error
	Message string `json:"message,omitempty"`
	Code    string `json:"code,omitempty"`
}

// streamingState holds the bookkeeping the dispatcher needs across
// chunks: which text/reasoning sessions are currently open, what tool
// calls have been declared, what assistant message is being built up.
type streamingState struct {
	messageID    string          // id of the single assistant text message for this run
	textOpen     bool            // true once TEXT_MESSAGE_START has been emitted
	textBuf      []byte          // accumulates text deltas for the assistant message
	reasoningCtx string          // empty if no reasoning context is open
	reasoningSeg string          // empty if no reasoning message is open
	toolCalls    []agui.ToolCall // declared tool calls, for the closing MESSAGES_SNAPSHOT
	stateSet     bool            // true once any state snapshot has been seen
	state        any             // last state snapshot, echoed in RUN_FINISHED.Result
}

// runStreamingDispatch consumes the reasoner's NDJSON stream and emits
// AG-UI events as they arrive. Closes the stream when done. Wraps the
// run with TEXT_MESSAGE_START/_END (synthesized lazily on first text
// chunk) and finishes with MESSAGES_SNAPSHOT + RUN_FINISHED — the same
// closing shape buffered reasoners produce, so frontends don't have to
// branch on streaming-vs-buffered.
+func runStreamingDispatch( + ctx context.Context, + c *gin.Context, + write func(agui.Event) bool, + stream io.ReadCloser, + req agui.RunAgentInput, + threadID, runID, messageID string, +) { + defer stream.Close() + st := &streamingState{messageID: messageID} + + scanner := bufio.NewScanner(stream) + scanner.Buffer(make([]byte, 0, 64*1024), AGUIStreamingMaxLineBytes) + + for scanner.Scan() { + select { + case <-ctx.Done(): + return + default: + } + line := scanner.Bytes() + if len(line) == 0 { + continue + } + var ch streamingChunk + if err := json.Unmarshal(line, &ch); err != nil { + // One bad chunk shouldn't blow up the run. Surface it as + // RAW so the frontend at least sees that something garbled + // went past, and keep going. + write(agui.RawEvent{ + Event: map[string]any{"raw": string(line), "decode_error": err.Error()}, + Source: "agentfield-streaming", + Timestamp: agui.NowMillis(), + }) + continue + } + if !dispatchChunk(write, st, ch) { + return + } + } + if err := scanner.Err(); err != nil { + write(agui.RunError{ + Message: fmt.Sprintf("read streaming reasoner: %v", err), + Code: "ERR_AGENT_STREAM", + Timestamp: agui.NowMillis(), + }) + return + } + + // Stream ended — close any open text/reasoning sessions, emit the + // canonical close-frames the buffered path would have emitted, and + // finish the run. + closeTextSession(write, st) + closeReasoningSession(write, st) + + assistant := agui.Message{ + ID: st.messageID, + Role: "assistant", + Content: string(st.textBuf), + ToolCalls: st.toolCalls, + } + full := append([]agui.Message{}, req.Messages...) 
+ full = append(full, assistant) + if !write(agui.MessagesSnapshot{ + Messages: full, + Timestamp: agui.NowMillis(), + }) { + return + } + + finished := agui.RunFinished{ + ThreadID: threadID, + RunID: runID, + Outcome: &agui.Outcome{Type: "success"}, + Timestamp: agui.NowMillis(), + } + if st.stateSet { + finished.Result = map[string]any{"state": st.state} + } + write(finished) +} + +// dispatchChunk emits the AG-UI events corresponding to one NDJSON +// chunk. Returns false on a write failure (so the caller stops the loop). +func dispatchChunk(write func(agui.Event) bool, st *streamingState, ch streamingChunk) bool { + switch ch.Type { + case "text": + if ch.Delta == "" { + return true + } + // Reasoning sessions close before the text turn opens — frontends + // don't expect text chunks interleaved with reasoning. + if !closeReasoningSession(write, st) { + return false + } + if !st.textOpen { + if !write(agui.TextMessageStart{ + MessageID: st.messageID, + Role: "assistant", + Timestamp: agui.NowMillis(), + }) { + return false + } + st.textOpen = true + } + st.textBuf = append(st.textBuf, ch.Delta...) + return write(agui.TextMessageContent{ + MessageID: st.messageID, + Delta: ch.Delta, + Timestamp: agui.NowMillis(), + }) + + case "reasoning": + if ch.Delta == "" { + return true + } + // Open the outer reasoning context lazily on first chunk. + if st.reasoningCtx == "" { + st.reasoningCtx = "reasoning-" + utils.GenerateExecutionID() + if !write(agui.ReasoningStart{ + MessageID: st.reasoningCtx, + Timestamp: agui.NowMillis(), + }) { + return false + } + } + // Open a per-segment message lazily — the reasoner can send a + // `reasoning_end` chunk between segments to close one and start + // the next, but for the simple case (single contiguous thinking + // block) we batch all deltas into one message. 
+ if st.reasoningSeg == "" { + st.reasoningSeg = st.reasoningCtx + "-seg-" + utils.GenerateExecutionID() + if !write(agui.ReasoningMessageStart{ + MessageID: st.reasoningSeg, + Role: "reasoning", + Timestamp: agui.NowMillis(), + }) { + return false + } + } + return write(agui.ReasoningMessageContent{ + MessageID: st.reasoningSeg, + Delta: ch.Delta, + Timestamp: agui.NowMillis(), + }) + + case "reasoning_end": + // Ends the current reasoning segment (so the next "reasoning" + // chunk opens a fresh one). Doesn't close the outer context; + // that happens at stream end or when a "text"/"final" chunk + // arrives. + if st.reasoningSeg != "" { + if !write(agui.ReasoningMessageEnd{ + MessageID: st.reasoningSeg, + Timestamp: agui.NowMillis(), + }) { + return false + } + st.reasoningSeg = "" + } + return true + + case "tool_call_start": + if ch.ID == "" || ch.Name == "" { + return true + } + parent := ch.ParentMessageID + if parent == "" { + parent = st.messageID + } + if !write(agui.ToolCallStart{ + ToolCallID: ch.ID, + ToolCallName: ch.Name, + ParentMessageID: parent, + Timestamp: agui.NowMillis(), + }) { + return false + } + // Convenience: if the reasoner already has the full arguments + // at start time (non-streaming-args reasoner), pre-emit them. + argsStr := "" + if len(ch.Arguments) > 0 { + argsStr = string(ch.Arguments) + if !write(agui.ToolCallArgs{ + ToolCallID: ch.ID, + Delta: argsStr, + Timestamp: agui.NowMillis(), + }) { + return false + } + } + st.toolCalls = append(st.toolCalls, agui.ToolCall{ + ID: ch.ID, + Type: "function", + Function: agui.ToolCallFunction{ + Name: ch.Name, + Arguments: argsStr, + }, + }) + return true + + case "tool_call_args": + if ch.ID == "" || ch.Delta == "" { + return true + } + // Append to whichever ToolCall.Function.Arguments matches. 
+ for i := range st.toolCalls { + if st.toolCalls[i].ID == ch.ID { + st.toolCalls[i].Function.Arguments += ch.Delta + break + } + } + return write(agui.ToolCallArgs{ + ToolCallID: ch.ID, + Delta: ch.Delta, + Timestamp: agui.NowMillis(), + }) + + case "tool_call_end": + if ch.ID == "" { + return true + } + return write(agui.ToolCallEnd{ + ToolCallID: ch.ID, + Timestamp: agui.NowMillis(), + }) + + case "tool_call_result": + if ch.ID == "" { + return true + } + role := ch.Role + if role == "" { + role = "tool" + } + return write(agui.ToolCallResult{ + MessageID: "msg-toolresult-" + ch.ID, + ToolCallID: ch.ID, + Content: ch.Content, + Role: role, + Timestamp: agui.NowMillis(), + }) + + case "state": + st.stateSet = true + st.state = ch.Snapshot + return write(agui.StateSnapshot{ + Snapshot: ch.Snapshot, + Timestamp: agui.NowMillis(), + }) + + case "state_delta": + if len(ch.Ops) == 0 { + return true + } + return write(agui.StateDelta{ + Delta: ch.Ops, + Timestamp: agui.NowMillis(), + }) + + case "step_started": + if ch.Name == "" { + return true + } + return write(agui.StepStarted{StepName: ch.Name, Timestamp: agui.NowMillis()}) + + case "step_finished": + if ch.Name == "" { + return true + } + return write(agui.StepFinished{StepName: ch.Name, Timestamp: agui.NowMillis()}) + + case "raw": + return write(agui.RawEvent{ + Event: ch.Event, + Source: ch.Source, + Timestamp: agui.NowMillis(), + }) + + case "custom": + if ch.Name == "" { + return true + } + return write(agui.CustomEvent{ + Name: ch.Name, + Value: ch.Value, + Timestamp: agui.NowMillis(), + }) + + case "error": + // Terminal — emit RUN_ERROR and return false to short-circuit. + write(agui.RunError{ + Message: ch.Message, + Code: ch.Code, + Timestamp: agui.NowMillis(), + }) + return false + + case "final": + // Treat the data field as a buffered-mode response: extract any + // not-yet-sent reasoning / tool calls / state / stateDelta and + // emit them. 
This lets a streaming reasoner shovel structured + // trailing fields without re-implementing the buffered logic. + applyFinal(write, st, ch.Data) + return true + + default: + // Unknown chunk type — surface as RAW with a hint so the + // frontend has visibility, then continue. + write(agui.RawEvent{ + Event: map[string]any{"unknown_chunk_type": ch.Type}, + Source: "agentfield-streaming", + Timestamp: agui.NowMillis(), + }) + return true + } +} + +// closeTextSession emits TEXT_MESSAGE_END if a text session is open. +// Returns false on write failure. +func closeTextSession(write func(agui.Event) bool, st *streamingState) bool { + if !st.textOpen { + return true + } + st.textOpen = false + return write(agui.TextMessageEnd{ + MessageID: st.messageID, + Timestamp: agui.NowMillis(), + }) +} + +// closeReasoningSession closes any open reasoning message and the outer +// reasoning context. No-op if neither is open. +func closeReasoningSession(write func(agui.Event) bool, st *streamingState) bool { + if st.reasoningSeg != "" { + if !write(agui.ReasoningMessageEnd{ + MessageID: st.reasoningSeg, + Timestamp: agui.NowMillis(), + }) { + return false + } + st.reasoningSeg = "" + } + if st.reasoningCtx != "" { + if !write(agui.ReasoningEnd{ + MessageID: st.reasoningCtx, + Timestamp: agui.NowMillis(), + }) { + return false + } + st.reasoningCtx = "" + } + return true +} + +// applyFinal lets a streaming reasoner emit one trailing buffered-shape +// envelope to ship any structured fields it didn't send chunk-by-chunk. +// Honors the same field names the buffered path recognizes. +func applyFinal(write func(agui.Event) bool, st *streamingState, data map[string]any) { + if data == nil { + return + } + // Reasoning (string or list). + if reasoning := extractReasoning(data); len(reasoning) > 0 { + // Open a fresh reasoning context if none is open; reuse the + // open one otherwise. 
+ ctxID := st.reasoningCtx + if ctxID == "" { + ctxID = "reasoning-" + utils.GenerateExecutionID() + if !write(agui.ReasoningStart{MessageID: ctxID, Timestamp: agui.NowMillis()}) { + return + } + st.reasoningCtx = ctxID + } + for i, seg := range reasoning { + segID := seg.ID + if segID == "" { + segID = fmt.Sprintf("%s-final-%d", ctxID, i) + } + write(agui.ReasoningMessageStart{MessageID: segID, Role: "reasoning", Timestamp: agui.NowMillis()}) + for _, chunk := range chunkText(seg.Content, AGUITextChunkSize) { + write(agui.ReasoningMessageContent{MessageID: segID, Delta: chunk, Timestamp: agui.NowMillis()}) + } + write(agui.ReasoningMessageEnd{MessageID: segID, Timestamp: agui.NowMillis()}) + } + } + // Tool calls. + for _, tc := range extractToolCalls(data) { + argsJSON, _ := json.Marshal(tc.Arguments) + argsStr := string(argsJSON) + write(agui.ToolCallStart{ToolCallID: tc.ID, ToolCallName: tc.Name, ParentMessageID: st.messageID, Timestamp: agui.NowMillis()}) + write(agui.ToolCallArgs{ToolCallID: tc.ID, Delta: argsStr, Timestamp: agui.NowMillis()}) + write(agui.ToolCallEnd{ToolCallID: tc.ID, Timestamp: agui.NowMillis()}) + if tc.HasResult { + write(agui.ToolCallResult{ + MessageID: "msg-toolresult-" + tc.ID, + ToolCallID: tc.ID, + Content: stringifyResult(tc.Result), + Role: "tool", + Timestamp: agui.NowMillis(), + }) + } + st.toolCalls = append(st.toolCalls, agui.ToolCall{ + ID: tc.ID, + Type: "function", + Function: agui.ToolCallFunction{Name: tc.Name, Arguments: argsStr}, + }) + } + // State. + if state, ok := extractState(data); ok { + st.stateSet = true + st.state = state + write(agui.StateSnapshot{Snapshot: state, Timestamp: agui.NowMillis()}) + } + if delta := extractStateDelta(data); delta != nil { + write(agui.StateDelta{Delta: delta, Timestamp: agui.NowMillis()}) + } + // Trailing text in `result` — append to any open text turn or open one. 
+ if r, has := data["result"]; has { + text := stringifyResult(r) + if text != "" { + if !st.textOpen { + write(agui.TextMessageStart{MessageID: st.messageID, Role: "assistant", Timestamp: agui.NowMillis()}) + st.textOpen = true + } + for _, chunk := range chunkText(text, AGUITextChunkSize) { + write(agui.TextMessageContent{MessageID: st.messageID, Delta: chunk, Timestamp: agui.NowMillis()}) + } + st.textBuf = append(st.textBuf, text...) + } + } +} diff --git a/control-plane/internal/handlers/agui_runs_streaming_unit_test.go b/control-plane/internal/handlers/agui_runs_streaming_unit_test.go new file mode 100644 index 00000000..afadf3b8 --- /dev/null +++ b/control-plane/internal/handlers/agui_runs_streaming_unit_test.go @@ -0,0 +1,292 @@ +package handlers + +import ( + "encoding/json" + "testing" + + "github.com/Agent-Field/agentfield/control-plane/internal/agui" + + "github.com/stretchr/testify/require" +) + +// captureWriter returns a writer fn that records every emitted event, +// optionally returning false on the Nth write to exercise short-circuit paths. +func captureWriter(failOn int) (writer func(agui.Event) bool, events *[]agui.Event) { + collected := make([]agui.Event, 0, 16) + count := 0 + return func(ev agui.Event) bool { + count++ + collected = append(collected, ev) + if failOn > 0 && count >= failOn { + return false + } + return true + }, &collected +} + +func eventTypes(events []agui.Event) []string { + out := make([]string, 0, len(events)) + for _, ev := range events { + out = append(out, ev.Type()) + } + return out +} + +// TestDispatchChunk_GuardEarlyReturns walks the early-return guards on +// every chunk type that has one. None of these should write any events +// or short-circuit the loop. 
func TestDispatchChunk_GuardEarlyReturns(t *testing.T) {
	// Each case is a chunk missing the field its variant requires.
	cases := []struct {
		name string
		ch   streamingChunk
	}{
		{"empty text delta", streamingChunk{Type: "text"}},
		{"empty reasoning delta", streamingChunk{Type: "reasoning"}},
		{"tool_call_start missing id", streamingChunk{Type: "tool_call_start", Name: "x"}},
		{"tool_call_start missing name", streamingChunk{Type: "tool_call_start", ID: "tc1"}},
		{"tool_call_args missing id", streamingChunk{Type: "tool_call_args", Delta: "x"}},
		{"tool_call_args missing delta", streamingChunk{Type: "tool_call_args", ID: "tc1"}},
		{"tool_call_end missing id", streamingChunk{Type: "tool_call_end"}},
		{"tool_call_result missing id", streamingChunk{Type: "tool_call_result", Content: "x"}},
		{"state_delta empty ops", streamingChunk{Type: "state_delta", Ops: nil}},
		{"step_started missing name", streamingChunk{Type: "step_started"}},
		{"step_finished missing name", streamingChunk{Type: "step_finished"}},
		{"custom missing name", streamingChunk{Type: "custom", Value: 1}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Fresh writer/state per subtest so emissions can't bleed across cases.
			write, events := captureWriter(0)
			st := &streamingState{messageID: "msg-1"}
			require.True(t, dispatchChunk(write, st, tc.ch), "guard branch must keep stream alive")
			require.Empty(t, *events, "guard branch must not emit events")
		})
	}
}

// TestDispatchChunk_ToolCallLifecycle covers tool_call_start (with and
// without inline arguments), tool_call_args appending to an in-flight
// call, tool_call_end, and tool_call_result with both default and
// explicit role.
func TestDispatchChunk_ToolCallLifecycle(t *testing.T) {
	write, events := captureWriter(0)
	st := &streamingState{messageID: "msg-1"}

	// tc1 declares its full arguments up front, so TOOL_CALL_ARGS is
	// synthesized immediately after TOOL_CALL_START.
	require.True(t, dispatchChunk(write, st, streamingChunk{
		Type:      "tool_call_start",
		ID:        "tc1",
		Name:      "showFlightCard",
		Arguments: json.RawMessage(`{"flight":"AA-12"}`),
	}))
	// Start without inline args (a parent message should default to st.messageID).
	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "tool_call_start", ID: "tc2", Name: "ping"}))
	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "tool_call_args", ID: "tc2", Delta: `{"x":1}`}))
	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "tool_call_end", ID: "tc2"}))
	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "tool_call_result", ID: "tc2", Content: "done", Role: "system"}))
	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "tool_call_result", ID: "tc1", Content: "ok"}))

	require.Equal(t, []string{
		"TOOL_CALL_START",  // tc1
		"TOOL_CALL_ARGS",   // tc1 inline args
		"TOOL_CALL_START",  // tc2
		"TOOL_CALL_ARGS",   // tc2 streamed delta
		"TOOL_CALL_END",    // tc2
		"TOOL_CALL_RESULT", // tc2 explicit role
		"TOOL_CALL_RESULT", // tc1 default role
	}, eventTypes(*events))

	require.Len(t, st.toolCalls, 2)
	require.Equal(t, `{"x":1}`, st.toolCalls[1].Function.Arguments,
		"tool_call_args should append to the in-flight call's arguments")
	tcResultExplicit := (*events)[5].(agui.ToolCallResult)
	require.Equal(t, "system", tcResultExplicit.Role)
	tcResultDefault := (*events)[6].(agui.ToolCallResult)
	require.Equal(t, "tool", tcResultDefault.Role)
}

// TestDispatchChunk_StateAndSteps covers state, state_delta, step_started,
// step_finished, raw, and custom chunks on the happy path.
+func TestDispatchChunk_StateAndSteps(t *testing.T) {
+	write, events := captureWriter(0)
+	st := &streamingState{messageID: "msg-1"}
+
+	// Feed the happy-path chunks in order; every one must keep the
+	// stream alive (dispatchChunk returns true).
+	chunks := []streamingChunk{
+		{Type: "state", Snapshot: map[string]any{"k": 1}},
+		{Type: "state_delta", Ops: []any{
+			map[string]any{"op": "replace", "path": "/k", "value": 2},
+		}},
+		{Type: "step_started", Name: "plan"},
+		{Type: "step_finished", Name: "plan"},
+		{Type: "raw", Event: map[string]any{"k": 1}, Source: "ext"},
+		{Type: "custom", Name: "ack", Value: true},
+		{Type: "unknown_kind"},
+	}
+	for i, ch := range chunks {
+		require.True(t, dispatchChunk(write, st, ch), "chunk %d (%s) must keep the stream alive", i, ch.Type)
+		if ch.Type == "state" {
+			require.True(t, st.stateSet, "state chunk must mark the snapshot as set")
+		}
+	}
+
+	require.Equal(t, []string{
+		"STATE_SNAPSHOT",
+		"STATE_DELTA",
+		"STEP_STARTED",
+		"STEP_FINISHED",
+		"RAW",
+		"CUSTOM",
+		"RAW", // unknown chunk falls into default → emits RAW
+	}, eventTypes(*events))
+}
+
+// TestDispatchChunk_ErrorChunkTerminates verifies the error chunk emits
+// RUN_ERROR and returns false to short-circuit the dispatch loop.
+func TestDispatchChunk_ErrorChunkTerminates(t *testing.T) {
+	write, events := captureWriter(0)
+	st := &streamingState{messageID: "msg-1"}
+
+	alive := dispatchChunk(write, st, streamingChunk{Type: "error", Message: "boom", Code: "E_BOOM"})
+	require.False(t, alive, "error chunk must short-circuit the dispatch loop")
+	require.Equal(t, []string{"RUN_ERROR"}, eventTypes(*events))
+
+	runErr := (*events)[0].(agui.RunError)
+	require.Equal(t, "boom", runErr.Message)
+	require.Equal(t, "E_BOOM", runErr.Code)
+}
+
+// TestDispatchChunk_ReasoningEndIdempotent confirms reasoning_end is a
+// no-op when no reasoning segment is open and emits the End frame when
+// one is.
+func TestDispatchChunk_ReasoningEndIdempotent(t *testing.T) {
+	write, events := captureWriter(0)
+	st := &streamingState{messageID: "msg-1"}
+
+	// With no open segment, reasoning_end must emit nothing at all.
+	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "reasoning_end"}))
+	require.Empty(t, *events, "reasoning_end is a no-op without an open segment")
+
+	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "reasoning", Delta: "thinking..."}))
+	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "reasoning_end"}))
+	require.Equal(t, []string{
+		"REASONING_START",
+		"REASONING_MESSAGE_START",
+		"REASONING_MESSAGE_CONTENT",
+		"REASONING_MESSAGE_END",
+	}, eventTypes(*events))
+	require.Empty(t, st.reasoningSeg, "reasoning_end clears the open segment id")
+	require.NotEmpty(t, st.reasoningCtx, "reasoning_end leaves the outer context open")
+}
+
+// TestDispatchChunk_TextClosesReasoning ensures a text chunk closes any
+// open reasoning session before opening the assistant text turn.
+func TestDispatchChunk_TextClosesReasoning(t *testing.T) {
+	write, events := captureWriter(0)
+	st := &streamingState{messageID: "msg-1"}
+
+	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "reasoning", Delta: "thought"}))
+	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "text", Delta: "hello"}))
+
+	types := eventTypes(*events)
+	// idx returns the position of the first frame of the given type, or -1.
+	idx := func(typ string) int {
+		for i, tt := range types {
+			if tt == typ {
+				return i
+			}
+		}
+		return -1
+	}
+	for _, typ := range []string{"REASONING_MESSAGE_END", "REASONING_END", "TEXT_MESSAGE_START", "TEXT_MESSAGE_CONTENT"} {
+		require.NotEqual(t, -1, idx(typ), "missing %s in stream", typ)
+	}
+	// REASONING_MESSAGE_END + REASONING_END must precede TEXT_MESSAGE_START —
+	// assert the actual ordering, not just membership.
+	require.Less(t, idx("REASONING_MESSAGE_END"), idx("REASONING_END"))
+	require.Less(t, idx("REASONING_END"), idx("TEXT_MESSAGE_START"))
+	require.Less(t, idx("TEXT_MESSAGE_START"), idx("TEXT_MESSAGE_CONTENT"))
+	require.Empty(t, st.reasoningCtx)
+	require.Empty(t, st.reasoningSeg)
+	require.True(t, st.textOpen)
+}
+
+// TestApplyFinal_FullEnvelope drives applyFinal with reasoning,
+// toolCalls (with and without result), state, stateDelta, and result
+// fields all populated.
+func TestApplyFinal_FullEnvelope(t *testing.T) { + write, events := captureWriter(0) + st := &streamingState{messageID: "msg-1"} + + applyFinal(write, st, map[string]any{ + "reasoning": []any{"step 1", map[string]any{"content": "step 2", "id": "r-1"}}, + "toolCalls": []any{ + map[string]any{"id": "tc1", "name": "x", "arguments": map[string]any{"a": 1}, "result": "ok"}, + map[string]any{"id": "tc2", "name": "y", "arguments": map[string]any{}}, + }, + "state": map[string]any{"counter": 7}, + "stateDelta": []any{map[string]any{"op": "replace", "path": "/counter", "value": 8}}, + "result": "Done.", + }) + + types := eventTypes(*events) + require.Contains(t, types, "REASONING_START") + require.Contains(t, types, "REASONING_MESSAGE_START") + require.Contains(t, types, "REASONING_MESSAGE_CONTENT") + require.Contains(t, types, "REASONING_MESSAGE_END") + require.Contains(t, types, "TOOL_CALL_START") + require.Contains(t, types, "TOOL_CALL_ARGS") + require.Contains(t, types, "TOOL_CALL_END") + require.Contains(t, types, "TOOL_CALL_RESULT") + require.Contains(t, types, "STATE_SNAPSHOT") + require.Contains(t, types, "STATE_DELTA") + require.Contains(t, types, "TEXT_MESSAGE_START") + require.Contains(t, types, "TEXT_MESSAGE_CONTENT") + + require.True(t, st.textOpen, "final result text leaves the text session open for stream-end to close") + require.True(t, st.stateSet) + require.Len(t, st.toolCalls, 2) +} + +// TestApplyFinal_NilDataIsNoOp confirms a nil data map is silently +// dropped — the reasoner can emit a final chunk without any structured +// fields. +func TestApplyFinal_NilDataIsNoOp(t *testing.T) { + write, events := captureWriter(0) + st := &streamingState{messageID: "msg-1"} + applyFinal(write, st, nil) + require.Empty(t, *events) + require.False(t, st.textOpen) +} + +// TestApplyFinal_ReusesOpenReasoningContext verifies that when a +// reasoning context is already open, applyFinal appends segments inside +// it instead of opening a new outer context. 
+func TestApplyFinal_ReusesOpenReasoningContext(t *testing.T) {
+	write, events := captureWriter(0)
+	st := &streamingState{messageID: "msg-1"}
+
+	// countStarts tallies the REASONING_START frames emitted so far.
+	countStarts := func() int {
+		n := 0
+		for _, ev := range *events {
+			if ev.Type() == "REASONING_START" {
+				n++
+			}
+		}
+		return n
+	}
+
+	require.True(t, dispatchChunk(write, st, streamingChunk{Type: "reasoning", Delta: "first"}))
+	require.Equal(t, 1, countStarts())
+
+	applyFinal(write, st, map[string]any{"reasoning": []any{"another"}})
+	require.Equal(t, 1, countStarts(), "applyFinal must reuse the already-open reasoning context")
+}
+
+// TestCloseSessions_NoOpWhenIdle covers the early-return branches in
+// closeTextSession and closeReasoningSession when no session is open.
+func TestCloseSessions_NoOpWhenIdle(t *testing.T) {
+	write, events := captureWriter(0)
+	idle := &streamingState{messageID: "msg-1"}
+	require.True(t, closeTextSession(write, idle))
+	require.True(t, closeReasoningSession(write, idle))
+	require.Empty(t, *events)
+}
+
+// TestCloseSessions_WriteFailureShortCircuits covers the rare case where
+// the writer returns false mid-close (client disconnect): the close
+// helpers must propagate the failure so the dispatch loop can stop.
+func TestCloseSessions_WriteFailureShortCircuits(t *testing.T) { + st := &streamingState{ + messageID: "msg-1", + textOpen: true, + reasoningSeg: "seg-1", + reasoningCtx: "ctx-1", + } + failingWrite := func(agui.Event) bool { return false } + require.False(t, closeTextSession(failingWrite, st)) + require.False(t, closeReasoningSession(failingWrite, st)) +} diff --git a/control-plane/internal/handlers/agui_runs_test.go b/control-plane/internal/handlers/agui_runs_test.go new file mode 100644 index 00000000..bd52236d --- /dev/null +++ b/control-plane/internal/handlers/agui_runs_test.go @@ -0,0 +1,1229 @@ +package handlers + +import ( + "bufio" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Agent-Field/agentfield/control-plane/pkg/types" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// aguiFrame is a parsed SSE frame: just the JSON object decoded from the +// `data:` line. The canonical AG-UI encoder emits frames as `data: \n\n` +// only — no `event:` line — so the JSON `type` field is the sole discriminator. +type aguiFrame struct { + Data map[string]any +} + +func (f aguiFrame) Type() string { + t, _ := f.Data["type"].(string) + return t +} + +// parseAGUIStream splits an SSE response body into one frame per AG-UI event. +// Strict on shape: every frame must be `data: \n\n`. We assert against +// the strictness because that's exactly what the AG-UI spec guarantees and +// what the reference encoders emit (see ag-ui-protocol/ag-ui encoder.ts / +// encoder.py). 
+func parseAGUIStream(t *testing.T, body string) []aguiFrame { + t.Helper() + var frames []aguiFrame + scanner := bufio.NewScanner(strings.NewReader(body)) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + + var curData string + flush := func() { + if curData == "" { + return + } + var decoded map[string]any + require.NoError(t, json.Unmarshal([]byte(curData), &decoded), "data line is not JSON: %s", curData) + frames = append(frames, aguiFrame{Data: decoded}) + curData = "" + } + + for scanner.Scan() { + line := scanner.Text() + switch { + case line == "": + flush() + case strings.HasPrefix(line, "event:"): + t.Fatalf("AG-UI frames must not include an `event:` line; got: %q", line) + case strings.HasPrefix(line, "data: "): + curData = strings.TrimPrefix(line, "data: ") + } + } + flush() + return frames +} + +func mountAGUIRouter(t *testing.T, store *reasonerTestStorage) *gin.Engine { + t.Helper() + gin.SetMode(gin.TestMode) + router := gin.New() + router.POST("/api/v1/agui/runs/:node_id/:reasoner_name", AGUIRunHandler(store)) + return router +} + +// runAgentInputBody returns a canonical RunAgentInputSchema-shaped body. The +// vanilla @ag-ui/client HttpAgent — and therefore CopilotKit's runtime that +// wraps it — POSTs exactly this shape. Tests should always go through this +// helper so the assertion about "we accept the canonical shape" is real. 
+func runAgentInputBody(t *testing.T, threadID, runID, prompt string) string { + t.Helper() + body := map[string]any{ + "threadId": threadID, + "runId": runID, + "messages": []map[string]any{ + {"id": "u1", "role": "user", "content": prompt}, + }, + "tools": []any{}, + "context": []any{}, + "state": map[string]any{}, + "forwardedProps": map[string]any{}, + } + b, err := json.Marshal(body) + require.NoError(t, err) + return string(b) +} + +// TestAGUIRunHandler_HappyPath_EmitsCanonicalEventSequence is the core +// assertion: a successful run produces RUN_STARTED → TEXT_MESSAGE_START → +// TEXT_MESSAGE_CONTENT → TEXT_MESSAGE_END → MESSAGES_SNAPSHOT → RUN_FINISHED, +// in that order. Thread/run IDs propagate from the request to RUN_FINISHED. +// The reasoner sees the AG-UI envelope (prompt extracted from the trailing +// user message) — proving the body-shape change wired up correctly. +func TestAGUIRunHandler_HappyPath_EmitsCanonicalEventSequence(t *testing.T) { + var seenInput map[string]any + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/reasoners/echo", r.URL.Path) + require.Equal(t, http.MethodPost, r.Method) + raw, _ := io.ReadAll(r.Body) + require.NoError(t, json.Unmarshal(raw, &seenInput)) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"hello world"}`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "echo"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/echo", + strings.NewReader(runAgentInputBody(t, "thread-test", "run-test", "hi"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + 
require.Equal(t, http.StatusOK, w.Code, "response: %s", w.Body.String()) + require.Equal(t, "text/event-stream", w.Header().Get("Content-Type")) + + // The reasoner received the canonical AG-UI envelope, plus the `prompt` + // convenience extracted from the trailing user message. + require.Equal(t, "hi", seenInput["prompt"]) + require.Equal(t, "thread-test", seenInput["threadId"]) + require.Equal(t, "run-test", seenInput["runId"]) + gotMessages, _ := seenInput["messages"].([]any) + require.Len(t, gotMessages, 1) + + frames := parseAGUIStream(t, w.Body.String()) + wantSequence := []string{ + "RUN_STARTED", + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_END", + "MESSAGES_SNAPSHOT", + "RUN_FINISHED", + } + require.Len(t, frames, len(wantSequence), "frames: %+v", frames) + for i, want := range wantSequence { + require.Equal(t, want, frames[i].Type(), "frame %d: %v", i, frames[i].Data) + } + + require.Equal(t, "thread-test", frames[0].Data["threadId"]) + require.Equal(t, "run-test", frames[0].Data["runId"]) + require.NotContains(t, frames[0].Data, "input", + "input must be omitted; the spec types it as RunAgentInput, not freeform") + + msgID, _ := frames[1].Data["messageId"].(string) + require.NotEmpty(t, msgID) + require.Equal(t, "assistant", frames[1].Data["role"]) + require.Equal(t, msgID, frames[2].Data["messageId"]) + require.Equal(t, "hello world", frames[2].Data["delta"]) + require.Equal(t, msgID, frames[3].Data["messageId"]) + + // MESSAGES_SNAPSHOT carries inbound history + the new assistant turn, + // and the assistant's content matches the delta we emitted. 
+ snapMsgs, _ := frames[4].Data["messages"].([]any) + require.Len(t, snapMsgs, 2, "snapshot should have 1 user + 1 assistant message") + last, _ := snapMsgs[1].(map[string]any) + require.Equal(t, "assistant", last["role"]) + require.Equal(t, "hello world", last["content"]) + require.Equal(t, msgID, last["id"]) + + // RUN_FINISHED carries threadId/runId, success outcome, and the parsed + // agent JSON. + require.Equal(t, "thread-test", frames[5].Data["threadId"]) + require.Equal(t, "run-test", frames[5].Data["runId"]) + outcome, _ := frames[5].Data["outcome"].(map[string]any) + require.Equal(t, "success", outcome["type"]) + require.Equal(t, map[string]any{"result": "hello world"}, frames[5].Data["result"]) + + if ts, ok := frames[0].Data["timestamp"]; ok { + _, isFloat := ts.(float64) + require.True(t, isFloat, "timestamp must be a number, got %T", ts) + } +} + +// TestAGUIRunHandler_GeneratesIDsWhenAbsent confirms that omitted threadId +// and runId are auto-populated rather than left empty — clients shouldn't +// have to mint IDs themselves to get a valid AG-UI stream. +func TestAGUIRunHandler_GeneratesIDsWhenAbsent(t *testing.T) { + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"ok"}`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "echo"}}, + }} + router := mountAGUIRouter(t, store) + + // Omit threadId and runId — vanilla HttpAgent always sends them, but a + // test client may not. 
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/echo", + strings.NewReader(`{"messages":[{"role":"user","content":"hi"}]}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + frames := parseAGUIStream(t, w.Body.String()) + require.NotEmpty(t, frames) + require.Equal(t, "RUN_STARTED", frames[0].Type()) + threadID, _ := frames[0].Data["threadId"].(string) + runID, _ := frames[0].Data["runId"].(string) + require.NotEmpty(t, threadID, "threadId should be auto-generated") + require.NotEmpty(t, runID, "runId should be auto-generated") + + last := frames[len(frames)-1] + require.Equal(t, "RUN_FINISHED", last.Type()) + require.Equal(t, threadID, last.Data["threadId"]) + require.Equal(t, runID, last.Data["runId"]) +} + +// TestAGUIRunHandler_AgentFailureEmitsRunError confirms the streaming-side +// error path: once SSE is open, downstream agent failure must surface as a +// terminal RUN_ERROR frame, never as a partial happy-path-shaped sequence. 
+func TestAGUIRunHandler_AgentFailureEmitsRunError(t *testing.T) { + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"upstream blew up"}`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "boom"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/boom", + strings.NewReader(runAgentInputBody(t, "t", "r", "x"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + frames := parseAGUIStream(t, w.Body.String()) + require.GreaterOrEqual(t, len(frames), 2) + require.Equal(t, "RUN_STARTED", frames[0].Type()) + + last := frames[len(frames)-1] + require.Equal(t, "RUN_ERROR", last.Type()) + require.NotEmpty(t, last.Data["message"]) + require.Equal(t, "ERR_AGENT_CALL", last.Data["code"]) + + for _, f := range frames[1:] { + require.NotContains(t, + []string{"TEXT_MESSAGE_START", "TEXT_MESSAGE_CONTENT", "TEXT_MESSAGE_END", "MESSAGES_SNAPSHOT", "RUN_FINISHED"}, + f.Type(), "unexpected post-error frame: %s", f.Type()) + } +} + +// TestAGUIRunHandler_EmitsHeartbeatWhileReasonerIsSlow confirms long-running +// reasoners produce SSE comment frames (`: keep-alive`) so proxies don't +// idle-time-out the connection. Comments are invisible to AG-UI clients but +// keep intermediaries happy. 
+func TestAGUIRunHandler_EmitsHeartbeatWhileReasonerIsSlow(t *testing.T) { + prev := AGUIHeartbeatInterval + AGUIHeartbeatInterval = 50 * time.Millisecond + defer func() { AGUIHeartbeatInterval = prev }() + + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + time.Sleep(250 * time.Millisecond) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"finally"}`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "slow"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/slow", + strings.NewReader(runAgentInputBody(t, "t", "r", "x"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + body := w.Body.String() + require.Contains(t, body, ": keep-alive", + "expected at least one SSE comment heartbeat in:\n%s", body) + + frames := parseAGUIStream(t, body) + require.Equal(t, "RUN_STARTED", frames[0].Type()) + require.Equal(t, "RUN_FINISHED", frames[len(frames)-1].Type()) +} + +// TestAGUIRunHandler_AgentBodyWithoutResultKey covers the fallthrough in +// extractAssistantText: when the agent returns a JSON object that doesn't +// have `result` or `content`, internal-only keys (toolCalls, state) are +// stripped and the rest is JSON-encoded as the delta. 
+func TestAGUIRunHandler_AgentBodyWithoutResultKey(t *testing.T) {
+	agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{"status":"ok","count":3}`))
+	}))
+	defer agentServer.Close()
+
+	store := &reasonerTestStorage{agent: &types.AgentNode{
+		ID:              "node-1",
+		BaseURL:         agentServer.URL,
+		HealthStatus:    types.HealthStatusActive,
+		LifecycleStatus: types.AgentStatusReady,
+		Reasoners:       []types.ReasonerDefinition{{ID: "ping"}},
+	}}
+	router := mountAGUIRouter(t, store)
+
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/ping",
+		strings.NewReader(runAgentInputBody(t, "t", "r", "x")))
+	req.Header.Set("Content-Type", "application/json")
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, req)
+
+	require.Equal(t, http.StatusOK, w.Code, w.Body.String())
+	frames := parseAGUIStream(t, w.Body.String())
+	// Guard the index so a malformed/short stream fails the test cleanly
+	// instead of panicking out of it.
+	require.Greater(t, len(frames), 2, "frames: %+v", frames)
+	// content frame is index 2 in the canonical sequence; encoding/json
+	// sorts map keys, so the encoded delta is deterministic.
+	require.Equal(t, `{"count":3,"status":"ok"}`, frames[2].Data["delta"])
+}
+
+// TestStringifyResult_BranchCoverage covers the cheap branches of the
+// helper directly: string passthrough, nil, and arbitrary value JSON-encode.
+func TestStringifyResult_BranchCoverage(t *testing.T) {
+	require.Equal(t, "hello", stringifyResult("hello"))
+	require.Equal(t, "", stringifyResult(nil))
+	require.Equal(t, `[1,2,3]`, stringifyResult([]any{1, 2, 3}))
+	require.Equal(t, `{"a":1}`, stringifyResult(map[string]any{"a": 1}))
+}
+
+// TestExtractAssistantText_AllBranches exercises the helper directly so
+// every priority rung is covered: result key, content key, top-level
+// string, top-level non-map non-string (number), filtered-empty map, and
+// the non-JSON raw-body fallthrough.
+func TestExtractAssistantText_AllBranches(t *testing.T) {
+	cases := []struct {
+		name   string
+		parsed any
+		isJSON bool
+		raw    []byte
+		want   string
+	}{
+		{"non-JSON falls through to raw body", nil, false, []byte("raw bytes"), "raw bytes"},
+		{"`result` key wins", map[string]any{"result": "answer"}, true, nil, "answer"},
+		{"`content` key is the second priority", map[string]any{"content": "alt"}, true, nil, "alt"},
+		{"top-level JSON string passes through", "just-a-string", true, nil, "just-a-string"},
+		{"top-level non-map non-string is JSON-encoded", float64(42), true, nil, "42"},
+		{"internal-only fields collapse to empty delta", map[string]any{"toolCalls": []any{}, "state": map[string]any{}}, true, nil, ""},
+	}
+	for _, tc := range cases {
+		require.Equal(t, tc.want, extractAssistantText(tc.parsed, tc.isJSON, tc.raw), tc.name)
+	}
+}
+
+// TestExtractToolCalls_NonMapInput covers the non-map branch (e.g. the
+// reasoner returned a top-level string or array — no toolCalls possible),
+// plus a map without a `toolCalls` array.
+func TestExtractToolCalls_NonMapInput(t *testing.T) {
+	for _, in := range []any{
+		"just a string",
+		[]any{1, 2, 3},
+		nil,
+		map[string]any{"result": "x"}, // map without a `toolCalls` array
+	} {
+		require.Nil(t, extractToolCalls(in), "input %#v must yield no tool calls", in)
+	}
+}
+
+// TestExtractState_NonMapAndAbsent covers both the non-map and the
+// missing-key paths.
+func TestExtractState_NonMapAndAbsent(t *testing.T) {
+	if _, ok := extractState("not a map"); ok {
+		t.Fatal("non-map input must return ok=false")
+	}
+	if _, ok := extractState(map[string]any{"result": "x"}); ok {
+		t.Fatal("absent state key must return ok=false")
+	}
+	v, ok := extractState(map[string]any{"state": nil})
+	require.True(t, ok, "explicit null state still returns ok=true")
+	require.Nil(t, v)
+}
+
+// TestExtractStateDelta covers presence, non-map, and empty cases.
+func TestExtractStateDelta(t *testing.T) {
+	require.Nil(t, extractStateDelta("not a map"))
+	require.Nil(t, extractStateDelta(map[string]any{}), "absent stateDelta key")
+	require.Nil(t, extractStateDelta(map[string]any{"stateDelta": []any{}}),
+		"empty stateDelta is treated as absent")
+	d := extractStateDelta(map[string]any{"stateDelta": []any{
+		map[string]any{"op": "replace", "path": "/x", "value": 1},
+	}})
+	require.Len(t, d, 1)
+}
+
+// TestChunkText covers the token-streaming chunker: rune boundaries,
+// empty input, oversize input, exact boundary.
+func TestChunkText(t *testing.T) {
+	require.Equal(t, []string{""}, chunkText("", 4))
+	require.Equal(t, []string{"abc"}, chunkText("abc", 4))
+	require.Equal(t, []string{"abcd", "ef"}, chunkText("abcdef", 4))
+	require.Equal(t, []string{"hello"}, chunkText("hello", -1), "non-positive size returns input unchanged")
+	// Multi-byte runes (emoji) must split on rune boundaries. Check the
+	// count up front (require.Len) so a wrong chunking fails with a clear
+	// message before the per-chunk assertions run.
+	emoji := "🤖🤖🤖"
+	chunks := chunkText(emoji, 4)
+	require.Len(t, chunks, 3)
+	for _, c := range chunks {
+		require.Equal(t, "🤖", c, "each chunk should hold exactly one emoji at size=4")
+	}
+}
+
+// TestAGUIRunHandler_ToolCalls_EmitsResultEventForServerSideCalls covers
+// the .ai(tools=...) trace surfacing path: when a reasoner reports a tool
+// call as already-executed by including a `result` field, the handler
+// emits TOOL_CALL_RESULT after TOOL_CALL_END.
+func TestAGUIRunHandler_ToolCalls_EmitsResultEventForServerSideCalls(t *testing.T) { + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "result":"queried weather", + "toolCalls":[{ + "id":"tc-w1","name":"getWeather", + "arguments":{"city":"SF"}, + "result":{"temp":62,"summary":"foggy"} + }] + }`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "weather"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/weather", + strings.NewReader(runAgentInputBody(t, "t", "r", "weather in SF?"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + frames := parseAGUIStream(t, w.Body.String()) + + // TOOL_CALL_RESULT must come immediately after TOOL_CALL_END for the + // same toolCallId. + idx := func(typ string) int { + for i, f := range frames { + if f.Type() == typ { + return i + } + } + return -1 + } + require.Less(t, idx("TOOL_CALL_END"), idx("TOOL_CALL_RESULT"), + "TOOL_CALL_RESULT must follow TOOL_CALL_END") + resFrame := frames[idx("TOOL_CALL_RESULT")] + require.Equal(t, "tc-w1", resFrame.Data["toolCallId"]) + require.Equal(t, "tool", resFrame.Data["role"]) + require.JSONEq(t, `{"summary":"foggy","temp":62}`, resFrame.Data["content"].(string)) +} + +// TestAGUIRunHandler_StateDelta covers Tier 3's incremental-patch path: +// when the reasoner returns `stateDelta` (RFC 6902), STATE_DELTA is +// emitted alongside (or instead of) STATE_SNAPSHOT. 
+func TestAGUIRunHandler_StateDelta(t *testing.T) { + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "result":"updated", + "state":{"counter":1}, + "stateDelta":[{"op":"replace","path":"/counter","value":2}] + }`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "tick"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/tick", + strings.NewReader(runAgentInputBody(t, "t", "r", "tick"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + frames := parseAGUIStream(t, w.Body.String()) + + // Both forms emitted, snapshot first. + idx := func(typ string) int { + for i, f := range frames { + if f.Type() == typ { + return i + } + } + return -1 + } + require.NotEqual(t, -1, idx("STATE_SNAPSHOT")) + require.NotEqual(t, -1, idx("STATE_DELTA")) + require.Less(t, idx("STATE_SNAPSHOT"), idx("STATE_DELTA")) + delta, _ := frames[idx("STATE_DELTA")].Data["delta"].([]any) + require.Len(t, delta, 1) + op, _ := delta[0].(map[string]any) + require.Equal(t, "replace", op["op"]) + require.Equal(t, "/counter", op["path"]) +} + +// TestAGUIRunHandler_Reasoning_StringForm: a reasoner returning a single +// reasoning string emits REASONING_START → _MESSAGE_START → _CONTENT → +// _MESSAGE_END → REASONING_END before the assistant text turn. 
+func TestAGUIRunHandler_Reasoning_StringForm(t *testing.T) { + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "result":"Booked.", + "reasoning":"Checked flights, AA-12 is cheapest." + }`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "n", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "think"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/n/think", + strings.NewReader(runAgentInputBody(t, "t", "r", "x"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + frames := parseAGUIStream(t, w.Body.String()) + + idx := func(typ string) int { + for i, f := range frames { + if f.Type() == typ { + return i + } + } + return -1 + } + for _, want := range []string{"REASONING_START", "REASONING_MESSAGE_START", "REASONING_MESSAGE_CONTENT", "REASONING_MESSAGE_END", "REASONING_END"} { + require.NotEqual(t, -1, idx(want), "missing %s in stream", want) + } + // REASONING_* must come before TEXT_MESSAGE_START. + require.Less(t, idx("REASONING_END"), idx("TEXT_MESSAGE_START")) + require.Equal(t, "reasoning", frames[idx("REASONING_MESSAGE_START")].Data["role"]) + require.Equal(t, "Checked flights, AA-12 is cheapest.", + frames[idx("REASONING_MESSAGE_CONTENT")].Data["delta"]) +} + +// TestAGUIRunHandler_Reasoning_ListForm: a reasoner returning a list of +// reasoning segments produces one REASONING_MESSAGE_* triad per segment. 
+func TestAGUIRunHandler_Reasoning_ListForm(t *testing.T) { + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "result":"Done.", + "reasoning":[ + "first thought", + {"id":"r-2","content":"second thought"} + ] + }`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "n", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "think"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/n/think", + strings.NewReader(runAgentInputBody(t, "t", "r", "x"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + frames := parseAGUIStream(t, w.Body.String()) + starts, contents, ends := 0, 0, 0 + var contentDeltas []string + for _, f := range frames { + switch f.Type() { + case "REASONING_MESSAGE_START": + starts++ + case "REASONING_MESSAGE_CONTENT": + contents++ + d, _ := f.Data["delta"].(string) + contentDeltas = append(contentDeltas, d) + case "REASONING_MESSAGE_END": + ends++ + } + } + require.Equal(t, 2, starts, "two segments → two STARTs") + require.Equal(t, 2, ends, "two segments → two ENDs") + require.Equal(t, 2, contents, "each segment fits in one content chunk at default size") + require.Equal(t, []string{"first thought", "second thought"}, contentDeltas) +} + +// TestExtractReasoning covers all input shapes the helper accepts plus +// the reject paths (non-map parsed value, empty content, missing key). 
+func TestExtractReasoning(t *testing.T) { + require.Nil(t, extractReasoning("not a map")) + require.Nil(t, extractReasoning(map[string]any{}), "missing key") + require.Nil(t, extractReasoning(map[string]any{"reasoning": nil}), "explicit null") + require.Nil(t, extractReasoning(map[string]any{"reasoning": ""}), "empty string") + require.Nil(t, extractReasoning(map[string]any{"reasoning": []any{}}), "empty list") + require.Nil(t, extractReasoning(map[string]any{"reasoning": 42}), "wrong type") + + one := extractReasoning(map[string]any{"reasoning": "thinking..."}) + require.Len(t, one, 1) + require.Equal(t, "thinking...", one[0].Content) + + mixed := extractReasoning(map[string]any{"reasoning": []any{ + "first", + "", // dropped + map[string]any{"id": "r-2", "content": "second"}, // kept + map[string]any{"content": ""}, // dropped (empty content) + map[string]any{"content": "no-id"}, // synthesized id + }}) + require.Len(t, mixed, 3) + require.Equal(t, "first", mixed[0].Content) + require.Equal(t, "r-2", mixed[1].ID) + require.Equal(t, "second", mixed[1].Content) + require.Equal(t, "no-id", mixed[2].Content) + require.NotEmpty(t, mixed[2].ID, "id auto-synthesized when missing") +} + +// TestAGUIRunHandler_ChunkedTextStreaming verifies that long assistant +// replies are split across multiple TEXT_MESSAGE_CONTENT deltas (so the +// frontend can paint progressively) while the start/end frames stay +// singletons. 
+func TestAGUIRunHandler_ChunkedTextStreaming(t *testing.T) { + prev := AGUITextChunkSize + AGUITextChunkSize = 8 // tiny chunks so we can assert multi-frame easily + defer func() { AGUITextChunkSize = prev }() + + long := strings.Repeat("a", 25) // 25 / 8 = 4 chunks (8+8+8+1) + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"` + long + `"}`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "long"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/long", + strings.NewReader(runAgentInputBody(t, "t", "r", "x"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + frames := parseAGUIStream(t, w.Body.String()) + + starts, contents, ends := 0, 0, 0 + concatenated := "" + var msgID string + for _, f := range frames { + switch f.Type() { + case "TEXT_MESSAGE_START": + starts++ + msgID, _ = f.Data["messageId"].(string) + case "TEXT_MESSAGE_CONTENT": + contents++ + require.Equal(t, msgID, f.Data["messageId"], "all content frames must share the same messageId") + d, _ := f.Data["delta"].(string) + concatenated += d + case "TEXT_MESSAGE_END": + ends++ + } + } + require.Equal(t, 1, starts, "exactly one START frame") + require.Equal(t, 1, ends, "exactly one END frame") + require.GreaterOrEqual(t, contents, 4, "expected long reply to be split into ≥4 chunks (got %d)", contents) + require.Equal(t, long, concatenated, "concatenated deltas must equal the full reply") +} + +// TestAGUIRunHandler_AgentReturnsNonJSON falls through to the +// `string(body)` branch 
when the agent's response isn't valid JSON. +func TestAGUIRunHandler_AgentReturnsNonJSON(t *testing.T) { + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte(`plain text answer`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "raw"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/raw", + strings.NewReader(runAgentInputBody(t, "t", "r", "x"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + frames := parseAGUIStream(t, w.Body.String()) + require.Equal(t, "plain text answer", frames[2].Data["delta"]) +} + +// TestAGUIRunHandler_ContextCancelMidFlight covers the <-ctx.Done() branch +// in the wait loop: client cancellation during a slow reasoner must return +// cleanly without emitting any post-RUN_STARTED frames. 
func TestAGUIRunHandler_ContextCancelMidFlight(t *testing.T) {
	// Disable heartbeats so the only frames on the wire are the ones the
	// handler emits for the run itself.
	prev := AGUIHeartbeatInterval
	AGUIHeartbeatInterval = time.Hour
	defer func() { AGUIHeartbeatInterval = prev }()

	// The stub agent blocks until either the test releases it or the proxied
	// request's context is cancelled (the branch under test).
	released := make(chan struct{})
	agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		select {
		case <-released:
		case <-r.Context().Done():
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"result":"too late"}`))
	}))
	defer func() { close(released); agentServer.Close() }()

	store := &reasonerTestStorage{agent: &types.AgentNode{
		ID:              "node-1",
		BaseURL:         agentServer.URL,
		HealthStatus:    types.HealthStatusActive,
		LifecycleStatus: types.AgentStatusReady,
		Reasoners:       []types.ReasonerDefinition{{ID: "hang"}},
	}}
	router := mountAGUIRouter(t, store)

	ctx, cancel := context.WithCancel(context.Background())
	req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/hang",
		strings.NewReader(runAgentInputBody(t, "t", "r", "x"))).WithContext(ctx)
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	// Serve on a separate goroutine so the test can cancel mid-stream.
	done := make(chan struct{})
	go func() {
		router.ServeHTTP(w, req)
		close(done)
	}()

	// Poll until RUN_STARTED is on the wire, so the cancel lands after the
	// stream has opened rather than before the handler runs.
	// NOTE(review): reading w.Body here races with the handler goroutine
	// still writing to the same recorder — confirm this test is clean under
	// -race, or wrap the recorder's writer in a mutex.
	deadline := time.Now().Add(2 * time.Second)
	for time.Now().Before(deadline) {
		if strings.Contains(w.Body.String(), `"type":"RUN_STARTED"`) {
			break
		}
		time.Sleep(5 * time.Millisecond)
	}
	require.Contains(t, w.Body.String(), `"type":"RUN_STARTED"`, "RUN_STARTED should arrive before cancel")
	cancel()

	// The handler must observe ctx.Done() and return promptly.
	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("handler did not return within 2s of context cancel")
	}

	// After cancellation nothing beyond the lifecycle start may have been
	// emitted: no assistant text and no successful finish.
	body := w.Body.String()
	require.NotContains(t, body, "TEXT_MESSAGE_START")
	require.NotContains(t, body, "RUN_FINISHED")
}

// TestAGUIRunHandler_RejectsMalformedJSON covers the c.ShouldBindJSON error
// branch — completely invalid request bodies must be rejected as 400 before
// any of the agent lookup or stream-opening logic runs.
func TestAGUIRunHandler_RejectsMalformedJSON(t *testing.T) {
	store := &reasonerTestStorage{agent: &types.AgentNode{
		ID:              "node-1",
		BaseURL:         "http://unused", // never dialed: binding fails first
		HealthStatus:    types.HealthStatusActive,
		LifecycleStatus: types.AgentStatusReady,
		Reasoners:       []types.ReasonerDefinition{{ID: "echo"}},
	}}
	router := mountAGUIRouter(t, store)

	req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/echo", strings.NewReader("not-json"))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	require.Equal(t, http.StatusBadRequest, w.Code, w.Body.String())
	require.NotEqual(t, "text/event-stream", w.Header().Get("Content-Type"))
}

// TestAGUIRunHandler_ValidationErrorsReturnJSON: pre-stream validation
// errors come back as plain JSON 4xx, never as an SSE stream. Once we emit
// RUN_STARTED the contract becomes "you'll see RUN_ERROR on failure" — but
// until the first frame, conventional REST errors win.
+func TestAGUIRunHandler_ValidationErrorsReturnJSON(t *testing.T) { + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: "http://unused", + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "echo"}}, + }} + router := mountAGUIRouter(t, store) + + cases := []struct { + name string + path string + wantCode int + wantMsg string + }{ + {"unknown node", "/api/v1/agui/runs/missing-node/echo", http.StatusNotFound, "node 'missing-node' not found"}, + {"unknown reasoner on known node", "/api/v1/agui/runs/node-1/does-not-exist", http.StatusNotFound, "reasoner 'does-not-exist' not found"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, tc.path, strings.NewReader(runAgentInputBody(t, "t", "r", "x"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + require.Equal(t, tc.wantCode, w.Code, w.Body.String()) + require.NotEqual(t, "text/event-stream", w.Header().Get("Content-Type"), + "validation errors must not open the SSE stream") + require.Contains(t, w.Body.String(), tc.wantMsg) + }) + } +} + +// TestAGUIRunHandler_ToolCalls_EmitsTriadAndAttachesToAssistantSnapshot +// covers Tier 2: when the reasoner declares a tool call (synthetic shape +// `{"toolCalls":[{id,name,arguments}]}`), the handler must emit +// TOOL_CALL_START → _ARGS → _END (BEFORE TEXT_MESSAGE_*) and attach the +// tool calls to the assistant turn in MESSAGES_SNAPSHOT — the wire shape +// CopilotKit's frontend pattern-matches against `useCopilotAction` to drive +// Generative UI. 
func TestAGUIRunHandler_ToolCalls_EmitsTriadAndAttachesToAssistantSnapshot(t *testing.T) {
	// Stub reasoner declaring one frontend tool call alongside its text.
	agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{
	"result":"booking your flight",
	"toolCalls":[
		{"id":"tc1","name":"showFlightCard","arguments":{"from":"SFO","to":"JFK"}}
	]
}`))
	}))
	defer agentServer.Close()

	store := &reasonerTestStorage{agent: &types.AgentNode{
		ID:              "node-1",
		BaseURL:         agentServer.URL,
		HealthStatus:    types.HealthStatusActive,
		LifecycleStatus: types.AgentStatusReady,
		Reasoners:       []types.ReasonerDefinition{{ID: "agent"}},
	}}
	router := mountAGUIRouter(t, store)

	req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/agent",
		strings.NewReader(runAgentInputBody(t, "t", "r", "book me SFO->JFK")))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	require.Equal(t, http.StatusOK, w.Code, w.Body.String())
	frames := parseAGUIStream(t, w.Body.String())

	// Exact frame order is part of the contract: the tool-call triad comes
	// before the assistant text, and the snapshot before the run closes.
	wantTypes := []string{
		"RUN_STARTED",
		"TOOL_CALL_START",
		"TOOL_CALL_ARGS",
		"TOOL_CALL_END",
		"TEXT_MESSAGE_START",
		"TEXT_MESSAGE_CONTENT",
		"TEXT_MESSAGE_END",
		"MESSAGES_SNAPSHOT",
		"RUN_FINISHED",
	}
	require.Len(t, frames, len(wantTypes))
	for i, want := range wantTypes {
		require.Equal(t, want, frames[i].Type(), "frame %d: %v", i, frames[i].Data)
	}

	// frames[1] is TOOL_CALL_START (indices follow wantTypes above).
	require.Equal(t, "tc1", frames[1].Data["toolCallId"])
	require.Equal(t, "showFlightCard", frames[1].Data["toolCallName"])
	// parentMessageId stitches the tool call into the assistant turn.
	require.NotEmpty(t, frames[1].Data["parentMessageId"])
	require.Equal(t, frames[1].Data["parentMessageId"], frames[4].Data["messageId"])

	// The ARGS delta carries the JSON-encoded arguments; END repeats the id.
	require.Equal(t, "tc1", frames[2].Data["toolCallId"])
	require.JSONEq(t, `{"from":"SFO","to":"JFK"}`, frames[2].Data["delta"].(string))
	require.Equal(t, "tc1", frames[3].Data["toolCallId"])

	require.Equal(t, "booking your flight", frames[5].Data["delta"])

	// MESSAGES_SNAPSHOT carries the tool-call attached to the assistant turn.
	snap, _ := frames[7].Data["messages"].([]any)
	require.Len(t, snap, 2)
	assistant, _ := snap[1].(map[string]any)
	require.Equal(t, "assistant", assistant["role"])
	tcs, _ := assistant["toolCalls"].([]any)
	require.Len(t, tcs, 1)
	tc, _ := tcs[0].(map[string]any)
	require.Equal(t, "tc1", tc["id"])
	require.Equal(t, "function", tc["type"])
	fn, _ := tc["function"].(map[string]any)
	require.Equal(t, "showFlightCard", fn["name"])
	require.JSONEq(t, `{"from":"SFO","to":"JFK"}`, fn["arguments"].(string))
}

// TestAGUIRunHandler_ToolCalls_AutoIDIfMissing covers the synthetic-id
// fallback in extractToolCalls.
+func TestAGUIRunHandler_ToolCalls_AutoIDIfMissing(t *testing.T) { + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"ok","toolCalls":[{"name":"alpha"}]}`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "a"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/a", + strings.NewReader(runAgentInputBody(t, "t", "r", "x"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + frames := parseAGUIStream(t, w.Body.String()) + require.Equal(t, "TOOL_CALL_START", frames[1].Type()) + id, _ := frames[1].Data["toolCallId"].(string) + require.NotEmpty(t, id, "tool-call id must be auto-generated when missing") + // Same id must propagate through the triad. + require.Equal(t, id, frames[2].Data["toolCallId"]) + require.Equal(t, id, frames[3].Data["toolCallId"]) +} + +// TestAGUIRunHandler_ToolCalls_SkipsMalformedEntries — a tool-call with no +// name is silently dropped (rather than failing the whole turn). Mirrors the +// extractToolCalls guards. 
+func TestAGUIRunHandler_ToolCalls_SkipsMalformedEntries(t *testing.T) { + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"ok","toolCalls":[ + {"id":"x","name":""}, + "not-an-object", + {"id":"y","name":"good","arguments":{}} + ]}`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "a"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/a", + strings.NewReader(runAgentInputBody(t, "t", "r", "x"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + frames := parseAGUIStream(t, w.Body.String()) + starts := 0 + for _, f := range frames { + if f.Type() == "TOOL_CALL_START" { + starts++ + require.Equal(t, "good", f.Data["toolCallName"]) + } + } + require.Equal(t, 1, starts, "only the well-formed tool call should be emitted") +} + +// TestAGUIRunHandler_State_EmitsSnapshotAndForwardsInbound covers Tier 3: +// the inbound `state` field on RunAgentInput must reach the reasoner, and a +// reasoner-returned `state` field must be re-emitted as a STATE_SNAPSHOT +// before MESSAGES_SNAPSHOT and RUN_FINISHED. 
+func TestAGUIRunHandler_State_EmitsSnapshotAndForwardsInbound(t *testing.T) { + var seenInput map[string]any + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + require.NoError(t, json.Unmarshal(raw, &seenInput)) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"counter incremented","state":{"counter":2}}`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "stateful"}}, + }} + router := mountAGUIRouter(t, store) + + body := `{ + "threadId":"t","runId":"r", + "messages":[{"role":"user","content":"increment"}], + "state":{"counter":1} + }` + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/stateful", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + + // Inbound state landed on the reasoner. + gotState, _ := seenInput["state"].(map[string]any) + require.EqualValues(t, 1, gotState["counter"]) + + frames := parseAGUIStream(t, w.Body.String()) + // Find STATE_SNAPSHOT in the stream and verify it carries the new value. + var snap aguiFrame + for _, f := range frames { + if f.Type() == "STATE_SNAPSHOT" { + snap = f + break + } + } + require.NotEmpty(t, snap.Data, "STATE_SNAPSHOT must be emitted when reasoner returns state") + snapVal, _ := snap.Data["snapshot"].(map[string]any) + require.EqualValues(t, 2, snapVal["counter"]) + + // Order: STATE_SNAPSHOT after TEXT_MESSAGE_END but before MESSAGES_SNAPSHOT. 
+ idx := func(typ string) int { + for i, f := range frames { + if f.Type() == typ { + return i + } + } + return -1 + } + require.Less(t, idx("TEXT_MESSAGE_END"), idx("STATE_SNAPSHOT")) + require.Less(t, idx("STATE_SNAPSHOT"), idx("MESSAGES_SNAPSHOT")) + require.Less(t, idx("MESSAGES_SNAPSHOT"), idx("RUN_FINISHED")) +} + +// TestAGUIRunHandler_State_OmittedWhenReasonerDoesNotReturnIt — Tier 3 +// doesn't synthesize a STATE_SNAPSHOT for stateless reasoners; we only emit +// when the reasoner opts in via a top-level `state` field. +func TestAGUIRunHandler_State_OmittedWhenReasonerDoesNotReturnIt(t *testing.T) { + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"plain"}`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "plain"}}, + }} + router := mountAGUIRouter(t, store) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/plain", + strings.NewReader(runAgentInputBody(t, "t", "r", "x"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + frames := parseAGUIStream(t, w.Body.String()) + for _, f := range frames { + require.NotEqual(t, "STATE_SNAPSHOT", f.Type(), + "STATE_SNAPSHOT must not be emitted unless the reasoner opts in") + } +} + +// TestAGUIRunHandler_PassesToolMessagesThrough — when the inbound history +// contains a `role:"tool"` message (CopilotKit posts these on the next run +// after a frontend useCopilotAction completes), it must reach the reasoner +// intact. 
+func TestAGUIRunHandler_PassesToolMessagesThrough(t *testing.T) { + var seenInput map[string]any + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + require.NoError(t, json.Unmarshal(raw, &seenInput)) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"result":"thanks"}`)) + })) + defer agentServer.Close() + + store := &reasonerTestStorage{agent: &types.AgentNode{ + ID: "node-1", + BaseURL: agentServer.URL, + HealthStatus: types.HealthStatusActive, + LifecycleStatus: types.AgentStatusReady, + Reasoners: []types.ReasonerDefinition{{ID: "echo"}}, + }} + router := mountAGUIRouter(t, store) + + body := `{ + "threadId":"t","runId":"r2", + "messages":[ + {"role":"user","content":"book SFO->JFK"}, + {"role":"assistant","toolCalls":[{"id":"tc1","type":"function","function":{"name":"showFlightCard","arguments":"{\"from\":\"SFO\"}"}}]}, + {"role":"tool","toolCallId":"tc1","content":"user clicked confirm"}, + {"role":"user","content":"now book the return"} + ] + }` + req := httptest.NewRequest(http.MethodPost, "/api/v1/agui/runs/node-1/echo", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, w.Body.String()) + require.Equal(t, "now book the return", seenInput["prompt"]) + msgs, _ := seenInput["messages"].([]any) + require.Len(t, msgs, 4) + toolMsg, _ := msgs[2].(map[string]any) + require.Equal(t, "tool", toolMsg["role"]) + require.Equal(t, "tc1", toolMsg["toolCallId"]) + require.Equal(t, "user clicked confirm", toolMsg["content"]) +} + +// TestHTTPAgentInvoker_HappyPath exercises the real httpAgentInvoker +// against a stub agent server — handler tests use an interface stub so this +// concrete path otherwise goes uncovered. 
+func TestHTTPAgentInvoker_HappyPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/reasoners/ping", r.URL.Path) + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + got, _ := io.ReadAll(r.Body) + require.JSONEq(t, `{"k":1}`, string(got)) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + res, err := httpAgentInvoker{}.Invoke(context.Background(), + &types.AgentNode{BaseURL: server.URL}, "ping", []byte(`{"k":1}`)) + require.NoError(t, err) + require.False(t, res.IsStreaming(), "JSON content-type should land in the buffered branch") + require.JSONEq(t, `{"ok":true}`, string(res.Body)) +} + +// TestHTTPAgentInvoker_4xxBubblesUpAsError covers the resp.StatusCode >= 400 +// branch — the body is still returned but as a callError so the handler can +// turn it into a RUN_ERROR. +func TestHTTPAgentInvoker_4xxBubblesUpAsError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"oops":"server"}`)) + })) + defer server.Close() + + res, err := httpAgentInvoker{}.Invoke(context.Background(), + &types.AgentNode{BaseURL: server.URL}, "boom", []byte(`{}`)) + require.Error(t, err) + require.Contains(t, err.Error(), "agent returned 500") + require.NotNil(t, res, "response struct returned alongside err so caller can use Body for diagnostics") + require.Contains(t, string(res.Body), "oops") +} + +// TestHTTPAgentInvoker_DialFailureSurfacesError covers the client.Do error +// branch by pointing the invoker at a closed listener. 
+func TestHTTPAgentInvoker_DialFailureSurfacesError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + addr := server.URL + server.Close() + + _, err := httpAgentInvoker{}.Invoke(context.Background(), + &types.AgentNode{BaseURL: addr}, "ping", []byte(`{}`)) + require.Error(t, err) + require.Contains(t, err.Error(), "agent call failed") +} + +// TestHTTPAgentInvoker_BadURLFailsRequestConstruction covers the +// http.NewRequestWithContext error branch — an invalid URL never makes it +// to a dial. +func TestHTTPAgentInvoker_BadURLFailsRequestConstruction(t *testing.T) { + _, err := httpAgentInvoker{}.Invoke(context.Background(), + &types.AgentNode{BaseURL: "http://bad\nhost"}, "ping", []byte(`{}`)) + require.Error(t, err) + require.Contains(t, err.Error(), "create agent request") +} diff --git a/control-plane/internal/server/routes_core.go b/control-plane/internal/server/routes_core.go index e1ac2cd4..5e1daac2 100644 --- a/control-plane/internal/server/routes_core.go +++ b/control-plane/internal/server/routes_core.go @@ -107,6 +107,29 @@ func (s *AgentFieldServer) registerCoreRoutes(agentAPI *gin.RouterGroup) { executeGroup.POST("/:target", handlers.ExecuteHandler(s.storage, s.payloadStore, s.webhookDispatcher, s.config.AgentField.ExecutionQueue.AgentCallTimeout, s.config.Features.DID.Authorization.InternalToken)) executeGroup.POST("/async/:target", handlers.ExecuteAsyncHandler(s.storage, s.payloadStore, s.webhookDispatcher, s.config.AgentField.ExecutionQueue.AgentCallTimeout, s.config.Features.DID.Authorization.InternalToken)) } + + // AG-UI protocol adapter (https://docs.ag-ui.com). Accepts the + // canonical RunAgentInputSchema body so vanilla @ag-ui/client + // HttpAgent (and the CopilotKit runtime that wraps it) can target + // AgentField reasoners with no custom adapter. The reasoner is + // addressed via URL params; one HttpAgent.url per reasoner is the + // canonical CopilotKit topology. 
Permission middleware mirrors + // /execute so reasoners reachable via AG-UI honor the same DID/VC + // authorization gates as direct invocations. + aguiGroup := agentAPI.Group("/agui") + { + if s.config.Features.DID.Authorization.Enabled && s.accessPolicyService != nil && s.didWebService != nil { + aguiGroup.Use(middleware.PermissionCheckMiddleware( + s.accessPolicyService, + s.tagVCVerifier, + s.storage, + s.didWebService, + middleware.PermissionConfig{Enabled: true}, + )) + logger.Logger.Info().Msg("🔒 Permission checking enabled on AG-UI endpoints") + } + aguiGroup.POST("/runs/:node_id/:reasoner_name", handlers.AGUIRunHandler(s.storage)) + } agentAPI.GET("/executions/:execution_id", handlers.GetExecutionStatusHandler(s.storage)) agentAPI.POST("/executions/batch-status", handlers.BatchExecutionStatusHandler(s.storage)) agentAPI.POST("/executions/:execution_id/status", handlers.UpdateExecutionStatusHandler(s.storage, s.payloadStore, s.webhookDispatcher, s.config.AgentField.ExecutionQueue.AgentCallTimeout)) diff --git a/docs/integrations/copilotkit.md b/docs/integrations/copilotkit.md new file mode 100644 index 00000000..28ea63a3 --- /dev/null +++ b/docs/integrations/copilotkit.md @@ -0,0 +1,446 @@ +# CopilotKit / AG-UI integration + +AgentField speaks the [AG-UI protocol](https://docs.ag-ui.com) so any +AG-UI-compatible frontend — most notably [CopilotKit](https://docs.copilotkit.ai) — +can use AgentField as the agent backend with no custom adapter. + +This page is the contract. If you're writing a reasoner that should +drive Generative UI or shared state in CopilotKit, the fields below are +how you opt in. + +## Topology + +``` +Browser ──▶ / useCoAgent / useCopilotAction + ──▶ CopilotRuntime (Next.js /api/copilotkit) + ──▶ @ag-ui/client HttpAgent + ──▶ POST /api/v1/agui/runs// + ──▶ AgentField reasoner +``` + +CopilotKit posts a canonical `RunAgentInput` body. 
The control plane +forwards the same envelope to your reasoner and translates the response +into AG-UI Server-Sent Events. + +## Endpoint + +``` +POST /api/v1/agui/runs/:node_id/:reasoner_name +Content-Type: application/json +``` + +Body shape (see `RunAgentInputSchema` in `@ag-ui/core`): + +```json +{ + "threadId": "string", + "runId": "string", + "messages": [{ "role": "user|assistant|tool|system", "content": "...", "toolCalls": [...] }], + "tools": [{ "name": "...", "description": "...", "parameters": { ... } }], + "context": [{ "description": "...", "value": ... }], + "state": { ... }, + "forwardedProps": { ... } +} +``` + +The control plane fans this into the reasoner input map under the same +keys, plus a `prompt` convenience extracted from the trailing user +message. + +Response: an SSE stream of AG-UI events. + +## Reasoner contract + +Reasoners can return a flat result, or a structured map opting into any +of these AG-UI surfaces: + +| Reasoner field | Emitted as | Used by | +|---|---|---| +| `result` (string or anything) | `TEXT_MESSAGE_CONTENT` | `` assistant bubble | +| `content` (alias for `result`) | `TEXT_MESSAGE_CONTENT` | same | +| `toolCalls: [{id, name, arguments, result?}]` | `TOOL_CALL_START` → `_ARGS` → `_END` (and `_RESULT` if `result` set) | `useCopilotAction({name, render})` | +| `state: {...}` | `STATE_SNAPSHOT` | `useCoAgent({state})` | +| `stateDelta: [...]` (RFC 6902 ops) | `STATE_DELTA` (after snapshot) | `useCoAgent({state})` | + +If none of `result`/`content` is present, the control plane stringifies +the rest of the body (minus `toolCalls`/`state` internals) so you still +see something. + +Long `result` values are auto-chunked across multiple +`TEXT_MESSAGE_CONTENT` deltas (default 256 chars each) so the frontend +can paint progressively even though the reasoner is synchronous. Each +delta carries the same `messageId`; concatenation reproduces the full +text. 
+ +### Python example + +```python +from agentfield import Agent, agui + +app = Agent(node_id="my-app") + +@app.reasoner() +async def book_flight(prompt: str = "", state: dict | None = None): + counter = (state or {}).get("counter", 0) + 1 + return { + "result": "Pulling up flight options.", + "toolCalls": [ + agui.tool_call( + name="showFlightCard", + arguments={"from": "SFO", "to": "JFK", "depart": "2026-06-01"}, + id="tc-flight-1", + ), + ], + "state": {"counter": counter, "lastBooking": "AA-12"}, + "stateDelta": [ + agui.state_delta_replace("/counter", counter), + ], + } +``` + +If your reasoner uses `app.ai(tools=...)` and you want the LLM's +tool-calling trace to surface in the UI, hand the trace to +`agui.tool_calls_from_trace`: + +```python +@app.reasoner() +async def smart_chat(prompt: str = ""): + result = await app.ai(prompt, tools="discover") + return { + "result": result.text, + "toolCalls": agui.tool_calls_from_trace(result.trace), + } +``` + +Each entry in the trace becomes a TOOL_CALL_*/_RESULT triad — the UI +shows a completed-tool indicator instead of a perpetually-pending +placeholder. 
+ +### Go example + +```go +import ( + "context" + "github.com/Agent-Field/agentfield/sdk/go/agent" + "github.com/Agent-Field/agentfield/sdk/go/agent/agui" +) + +a, _ := agent.New(agent.Config{NodeID: "my-app"}) +a.RegisterReasoner("book_flight", func(ctx context.Context, in map[string]any) (any, error) { + return map[string]any{ + "result": "Pulling up flight options.", + "toolCalls": []map[string]any{ + agui.ToolCall("tc-1", "showFlightCard", map[string]any{ + "from": "SFO", "to": "JFK", + }, nil), + }, + "state": map[string]any{"lastBooking": "AA-12"}, + }, nil +}) +``` + +For a Go reasoner using the AI tool-call loop: + +```go +res, _ := aiClient.ExecuteToolCallLoopResult(ctx, prompt, tools, callFn) +return map[string]any{ + "result": res.Text(), + "toolCalls": agui.ToolCallsFromTrace(res.Trace), +}, nil +``` + +### TypeScript example + +```ts +import { Agent, agui } from '@agentfield/sdk'; + +const a = new Agent({ nodeId: 'my-app' }); + +a.reasoner('book_flight', async (ctx) => ({ + result: 'Pulling up flight options.', + toolCalls: [ + agui.toolCall('showFlightCard', { from: 'SFO', to: 'JFK' }, { id: 'tc-1' }), + ], + state: { lastBooking: 'AA-12' }, +})); +``` + +For a TypeScript reasoner using the AI tool-call loop: + +```ts +const { text, trace } = await ctx.aiWithTools(ctx.input.question, { tools: 'discover' }); +return { + result: text, + toolCalls: agui.toolCallsFromTrace(trace), +}; +``` + +## Frontend wiring + +Standard CopilotKit App Router setup, with one `HttpAgent` per reasoner: + +```ts +// app/api/copilotkit/route.ts +import { CopilotRuntime, copilotRuntimeNextJSAppRouterEndpoint } from "@copilotkit/runtime"; +import { HttpAgent } from "@ag-ui/client"; + +const BASE = "http://your-control-plane/api/v1/agui/runs/your-node"; + +const runtime = new CopilotRuntime({ + agents: { + chat: new HttpAgent({ url: `${BASE}/chat` }), + book_flight: new HttpAgent({ url: `${BASE}/book_flight` }), + }, +}); + +export const POST = async (req: Request) => { + 
const { handleRequest } = copilotRuntimeNextJSAppRouterEndpoint({
    runtime, endpoint: "/api/copilotkit",
  });
  return handleRequest(req);
};
```

```tsx
// app/page.tsx
"use client";
import { CopilotKit, useCopilotAction } from "@copilotkit/react-core";
import { CopilotChat } from "@copilotkit/react-ui";
import "@copilotkit/react-ui/styles.css";

function FlightCard({ from, to, depart }: any) {
  return <div>{from} → {to} ({depart})</div>;
}

function Page() {
  // Render-only: the agent emits a TOOL_CALL_*; the UI just visualizes it.
  // `available: "frontend"` is required for render-only actions in
  // CopilotKit v1.57+.
  useCopilotAction({
    name: "showFlightCard",
    available: "frontend",
    parameters: [
      { name: "from", type: "string" },
      { name: "to", type: "string" },
      { name: "depart", type: "string" },
    ],
    render: ({ args }) => <FlightCard {...args} />,
  });

  return (
    <CopilotKit runtimeUrl="/api/copilotkit" agent="chat">
      <CopilotChat />
    </CopilotKit>
  );
}
```

For round-trip frontend tools (the agent calls a tool, the user
interacts, the tool returns a result that loops back to the agent on
the next turn), use `available: "enabled"` with a `handler` instead of
`render`. CopilotKit posts the tool's return value as a
`role: "tool"` message in the next run — the control plane forwards it
intact to the reasoner.

## Auth

The endpoint sits behind the same DID/VC permission middleware as
`/execute`. When `AGENTFIELD_FEATURES_DID_AUTHORIZATION_ENABLED=true`,
callers must include a valid DID-signed request just like for direct
reasoner invocations.

## Live streaming (per-token + per-tool-arg deltas)

The reasoner contract above buffers a full response and returns it as a
single dict. For live UX — text appearing token-by-token,
`TOOL_CALL_ARGS` streaming as the LLM emits them, `REASONING_*` events
flowing as the model thinks — return an NDJSON stream instead. The
control plane's streaming dispatcher (see
`control-plane/internal/handlers/agui_runs_streaming.go`) detects
`Content-Type: application/x-ndjson` and translates each line into the
matching AG-UI event in real time.
+ +### Python streaming reasoner + +```python +from fastapi import Request +from fastapi.responses import StreamingResponse +from agentfield import Agent, agui + +app = Agent(node_id="my-app") + +@app.post("/reasoners/chat") +async def chat(request: Request): + body = await request.json() + return StreamingResponse( + agui.serialize_stream(_chunks(body)), + media_type=agui.STREAMING_CONTENT_TYPE, + ) + +async def _chunks(body): + # Reasoning shows up in CopilotKit's "Thinking…" pane. + yield agui.reasoning_chunk("Looking up flights...") + yield agui.reasoning_end_chunk() + # Text chunks paint progressively in . + async for token in llm.stream(body["prompt"]): + yield agui.text_chunk(token) + # Tool calls drive useCopilotAction renders. + yield agui.tool_call_start_chunk("tc-1", "showFlightCard", + arguments={"from": "SFO", "to": "JFK"}) + yield agui.tool_call_end_chunk("tc-1") + # Shared state lands in useCoAgent. + yield agui.state_chunk({"counter": 1}) +``` + +The control plane wraps the stream with `RUN_STARTED` / `RUN_FINISHED`, +manages text and reasoning open/close lifecycle automatically, and emits +`MESSAGES_SNAPSHOT` at stream end. 
+ +### Go streaming reasoner + +```go +import ( + "net/http" + "github.com/Agent-Field/agentfield/sdk/go/agent/agui" +) + +func chat(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", agui.StreamingContentType) + w.WriteHeader(http.StatusOK) + + chunks := make(chan map[string]any, 8) + go func() { + defer close(chunks) + chunks <- agui.ReasoningChunk("Looking up flights...") + chunks <- agui.ReasoningEndChunk() + for _, tok := range []string{"Booked ", "AA-12."} { + chunks <- agui.TextChunk(tok) + } + chunks <- agui.ToolCallStartChunk("tc-1", "showFlightCard", + map[string]any{"from": "SFO", "to": "JFK"}, "") + chunks <- agui.ToolCallEndChunk("tc-1") + }() + _ = agui.SerializeStream(r.Context(), w, chunks) +} +``` + +### TypeScript streaming reasoner + +```ts +import express from 'express'; +import { agui } from '@agentfield/sdk'; + +const app = express(); +app.use(express.json()); + +app.post('/reasoners/chat', async (req, res) => { + res.setHeader('Content-Type', agui.STREAMING_CONTENT_TYPE); + res.flushHeaders?.(); + + async function* chunks() { + yield agui.reasoningChunk('Looking up flights...'); + yield agui.reasoningEndChunk(); + for await (const tok of llm.stream(req.body.prompt)) yield agui.textChunk(tok); + yield agui.toolCallStartChunk('tc-1', 'showFlightCard', { + arguments: { from: 'SFO', to: 'JFK' }, + }); + yield agui.toolCallEndChunk('tc-1'); + yield agui.stateChunk({ counter: 1 }); + } + + for await (const buf of agui.serializeStream(chunks())) res.write(buf); + res.end(); +}); +``` + +`serializeStream` accepts both async generators and plain iterables. The +same chunks plug into Hono / Fastify / a Web `Response` built from a +`ReadableStream`. + +### `.harness()` relay + +The Anthropic Claude harness already produces a streaming async iterator +of messages. 
Pipe it straight to AG-UI: + +```python +from claude_agent_sdk import query, ClaudeAgentOptions +from agentfield import agui + +async def _chunks(body): + opts = ClaudeAgentOptions(...) + async for chunk in agui.relay_harness_stream( + query(prompt=body["prompt"], options=opts) + ): + yield chunk +``` + +`relay_harness_stream` translates Claude SDK message types into the +right AG-UI chunks: `text` blocks → `TEXT_MESSAGE_CONTENT`, +`thinking` blocks → `REASONING_*`, `tool_use` blocks → `TOOL_CALL_*`, +`tool_result` blocks → `TOOL_CALL_RESULT`. Note: the harness streams +per-message, not per-token, so this path delivers message-level +streaming. True per-token streaming requires the raw Anthropic API. + +The TypeScript SDK exposes the same translation as +`agui.relayHarnessStream(query(...))` (consuming the +`@anthropic-ai/claude-agent-sdk` async iterator). The Go SDK's harness +is buffered, so `agui.RelayHarnessResult(*harness.Result)` returns the +equivalent chunk slice in one shot — feed it into a channel and through +`agui.SerializeStream` for live emission. 
+ +### SDK parity + +Every helper above exists in all three SDKs with matching names: + +| Concept | Python | Go | TypeScript | +|---|---|---|---| +| Streaming content type | `agui.STREAMING_CONTENT_TYPE` | `agui.StreamingContentType` | `agui.STREAMING_CONTENT_TYPE` | +| Text chunk | `agui.text_chunk(...)` | `agui.TextChunk(...)` | `agui.textChunk(...)` | +| Tool call (buffered) | `agui.tool_call(...)` | `agui.ToolCall(...)` | `agui.toolCall(...)` | +| Tool calls from AI trace | `tool_calls_from_trace(trace)` | `ToolCallsFromTrace(trace)` | `toolCallsFromTrace(trace)` | +| State delta replace | `state_delta_replace(p, v)` | `StateDeltaReplace(p, v)` | `stateDeltaReplace(p, v)` | +| Reasoning segment / list | `reasoning_segment / reasoning` | `ReasoningSegment / Reasoning` | `reasoningSegment / reasoning` | +| Stream serializer | `serialize_stream(gen)` | `SerializeStream(ctx, w, ch)` | `serializeStream(iter)` | +| Harness relay | `relay_harness_stream(iter)` | `RelayHarnessResult(*Result)` | `relayHarnessStream(iter)` | + +## Reasoner contract — full chunk reference + +When using the streaming path, each NDJSON line is one of these tagged +chunks (built by the helpers in `agentfield.agui` / +`sdk/go/agent/agui` / `@agentfield/sdk` `agui` namespace): + +| Chunk `type` | Maps to | Notes | +|---|---|---| +| `text` | `TEXT_MESSAGE_CONTENT` | `START`/`END` synthesized lazily on first/last text chunk | +| `reasoning` | `REASONING_MESSAGE_CONTENT` | Outer `REASONING_START`/`END` synthesized; emit `reasoning_end` to start a new segment within the same context | +| `tool_call_start` | `TOOL_CALL_START` (+ `_ARGS` if `arguments` provided inline) | | +| `tool_call_args` | `TOOL_CALL_ARGS` | Streamed as the LLM emits arg JSON | +| `tool_call_end` | `TOOL_CALL_END` | | +| `tool_call_result` | `TOOL_CALL_RESULT` | For server-side tools | +| `state` | `STATE_SNAPSHOT` | | +| `state_delta` | `STATE_DELTA` (RFC 6902 patches) | | +| `step_started` / `step_finished` | `STEP_STARTED` 
/ `STEP_FINISHED` | CopilotKit ignores; useful for other AG-UI consumers | +| `raw` | `RAW` | Foreign-system passthrough | +| `custom` | `CUSTOM` | App-specific event with `name` + `value` | +| `final` | Applies a buffered-shape envelope | Use to send trailing `toolCalls` / `state` / etc. without re-implementing buffered logic | +| `error` | `RUN_ERROR` (terminal) | Subsequent chunks are ignored | + +## Performance + +Load tested at 50× concurrent buffered requests and 25× concurrent +streaming requests in CI (`internal/handlers/agui_runs_load_test.go`): + +- Buffered: 200 reqs in ~90 ms wall, p50 ≈ 4 ms, p95 ≈ 75 ms, p99 ≈ 77 ms +- Streaming dispatcher: 100 reqs in ~18 ms wall, no goroutine leaks +- Per-request benchmark (`go test -bench=BenchmarkAGUI`): ~389 µs/op, 26 KB/op + +## What we don't yet do + +- **Per-token streaming via the buffered reasoner contract.** Reasoners + using `@app.reasoner()` still buffer; the streaming path requires the + separate FastAPI / chunk-channel pattern shown above. We auto-chunk + buffered responses on emission so the UX is acceptable, but the + source of truth is still a synchronous return. +- **Bidirectional cancellation propagation into the streaming reasoner.** + Client disconnect aborts the streaming HTTP read on our end, but the + reasoner needs its own context plumbing to actually stop work. diff --git a/sdk/go/agent/agui/agui.go b/sdk/go/agent/agui/agui.go new file mode 100644 index 00000000..fc270d5e --- /dev/null +++ b/sdk/go/agent/agui/agui.go @@ -0,0 +1,454 @@ +// Package agui provides helpers for AgentField Go reasoners that want to +// surface AG-UI / CopilotKit-compatible Generative UI events through the +// control plane's POST /api/v1/agui/runs// adapter. +// +// Reasoners opt into the richer event types by returning specific fields +// in their response map; this package builds those fields in the canonical +// shape the control plane expects, so authors don't have to memorize the +// wire contract. 
+// +// Wire contract (mirrors the Python agentfield.agui module): +// +// - "result": the human-facing assistant text (used as the +// TEXT_MESSAGE_CONTENT delta). +// - "toolCalls": []map{id, name, arguments, result?} — surfaced as +// TOOL_CALL_START/_ARGS/_END (and _RESULT if `result` is set). +// - "state": full agent state — emitted as STATE_SNAPSHOT. +// - "stateDelta": []map{op, path, value} (RFC 6902) — emitted as +// STATE_DELTA after the snapshot. +// +// See https://docs.ag-ui.com/concepts/events for the upstream protocol. +package agui + +import ( + "context" + "encoding/json" + "fmt" + "io" + + "github.com/Agent-Field/agentfield/sdk/go/ai" + "github.com/Agent-Field/agentfield/sdk/go/harness" +) + +// ToolCall builds a single AG-UI tool-call entry. The control plane +// translates each entry into a TOOL_CALL_START/_ARGS/_END triad. If +// `result` is non-nil, TOOL_CALL_RESULT is also emitted so already-executed +// traces (e.g. from ai.ExecuteToolCallLoopResult) render as completed in +// the UI. +// +// `id` may be empty; the control plane synthesizes a stable ID per call. +// Pass an explicit id when correlating with a follow-up tool message +// from a frontend handler. +func ToolCall(id, name string, arguments map[string]any, result any) map[string]any { + if name == "" { + // Names are required by the AG-UI schema; an empty name will be + // silently dropped by the control plane. Surface the bug eagerly. + return nil + } + entry := map[string]any{"name": name} + if id != "" { + entry["id"] = id + } + if arguments == nil { + entry["arguments"] = map[string]any{} + } else { + entry["arguments"] = arguments + } + if result != nil { + entry["result"] = result + } + return entry +} + +// ToolCallsFromTrace converts an ai.ToolCallTrace from +// Client.ExecuteToolCallLoopResult into the AG-UI toolCalls list shape. +// Each record becomes an entry with its arguments and the executed +// result (or an {"error":"..."} object if the call failed). 
Nil or +// empty traces return an empty slice so callers can splat the result +// safely: +// +// return map[string]any{ +// "result": res.Text(), +// "toolCalls": agui.ToolCallsFromTrace(res.Trace), +// }, nil +func ToolCallsFromTrace(trace *ai.ToolCallTrace) []map[string]any { + if trace == nil || len(trace.Calls) == 0 { + return []map[string]any{} + } + out := make([]map[string]any, 0, len(trace.Calls)) + for i, rec := range trace.Calls { + entry := map[string]any{ + "id": fmt.Sprintf("tc-trace-%d", i), + "name": rec.ToolName, + "arguments": rec.Arguments, + } + if rec.Arguments == nil { + entry["arguments"] = map[string]any{} + } + switch { + case rec.Error != "": + entry["result"] = map[string]any{"error": rec.Error} + case rec.Result != nil: + entry["result"] = rec.Result + } + out = append(out, entry) + } + return out +} + +// StateDeltaReplace builds a single RFC 6902 "replace" patch op for a +// stateDelta array. Path must start with "/". +func StateDeltaReplace(path string, value any) (map[string]any, error) { + if len(path) == 0 || path[0] != '/' { + return nil, fmt.Errorf("RFC 6902 paths must start with '/' (got %q)", path) + } + return map[string]any{"op": "replace", "path": path, "value": value}, nil +} + +// ReasoningSegment builds one REASONING_MESSAGE segment for buffered-mode +// emission. Reasoners surface chain-of-thought to CopilotKit's +// "Thinking…" pane by returning a "reasoning" field whose value is a +// list of segments (or plain strings). Each segment becomes a +// REASONING_MESSAGE_START / _CONTENT / _END triad inside a +// REASONING_START / _END boundary. +// +// return map[string]any{ +// "result": "Booked AA-12.", +// "reasoning": []any{ +// agui.ReasoningSegment("Looking up flights..."), +// agui.ReasoningSegment("AA-12 is the cheapest non-stop."), +// }, +// }, nil +// +// Pass id="" to let the control plane synthesize one. 
func ReasoningSegment(content, id string) map[string]any {
	seg := map[string]any{"content": content}
	if id != "" {
		seg["id"] = id
	}
	return seg
}

// Reasoning assembles a "reasoning" field value from a mix of plain
// strings and segment maps. Strings pass through verbatim; maps are
// shallow-copied so later caller mutation cannot leak into the emitted
// payload. The []any return slots straight into the reasoner response
// map.
func Reasoning(segments ...any) ([]any, error) {
	collected := make([]any, 0, len(segments))
	for _, raw := range segments {
		if s, ok := raw.(string); ok {
			// Empty strings are dropped; they would render as blank
			// reasoning segments.
			if s != "" {
				collected = append(collected, s)
			}
			continue
		}
		if m, ok := raw.(map[string]any); ok {
			dup := make(map[string]any, len(m))
			for key, val := range m {
				dup[key] = val
			}
			collected = append(collected, dup)
			continue
		}
		return nil, fmt.Errorf("agui.Reasoning: segments must be string or map[string]any (got %T)", raw)
	}
	return collected, nil
}

// ----------------------------------------------------------------------------
// Streaming chunk builders + serializer.
//
// Reasoners that want live AG-UI events return chunks (built with these
// helpers) from a goroutine and pipe them through SerializeStream into an
// http.ResponseWriter with Content-Type "application/x-ndjson". The
// AgentField control plane sniffs the content-type and dispatches each
// line as a live AG-UI event (see internal/handlers/agui_runs_streaming.go).
// ----------------------------------------------------------------------------

// StreamingContentType is the response content-type a streaming reasoner
// must set so the control plane recognizes it as a live stream.
const StreamingContentType = "application/x-ndjson"

// TextChunk is one piece of streaming assistant text.
func TextChunk(delta string) map[string]any {
	return map[string]any{"type": "text", "delta": delta}
}

// ReasoningChunk is one piece of chain-of-thought rendered in
// CopilotKit's "Thinking…" pane.
func ReasoningChunk(delta string) map[string]any {
	return map[string]any{"type": "reasoning", "delta": delta}
}

// ReasoningEndChunk closes the current reasoning segment so the next
// ReasoningChunk opens a fresh one.
func ReasoningEndChunk() map[string]any {
	return map[string]any{"type": "reasoning_end"}
}

// ToolCallStartChunk opens a tool call. Pass arguments inline when they
// are all known up front; otherwise stream them with ToolCallArgsChunk.
func ToolCallStartChunk(id, name string, arguments map[string]any, parentMessageID string) map[string]any {
	chunk := map[string]any{
		"type": "tool_call_start",
		"id":   id,
		"name": name,
	}
	// Optional fields are omitted entirely rather than sent empty.
	if arguments != nil {
		chunk["arguments"] = arguments
	}
	if parentMessageID != "" {
		chunk["parentMessageId"] = parentMessageID
	}
	return chunk
}

// ToolCallArgsChunk streams a piece of the tool-call arguments JSON.
func ToolCallArgsChunk(id, delta string) map[string]any {
	return map[string]any{"type": "tool_call_args", "id": id, "delta": delta}
}

// ToolCallEndChunk closes a tool call.
func ToolCallEndChunk(id string) map[string]any {
	return map[string]any{"type": "tool_call_end", "id": id}
}

// ToolCallResultChunk reports a server-side tool result. Use when the
// reasoner already executed the tool and wants the trace to render as
// completed in the UI. An empty role defaults to "tool".
func ToolCallResultChunk(id, content, role string) map[string]any {
	effectiveRole := role
	if effectiveRole == "" {
		effectiveRole = "tool"
	}
	return map[string]any{
		"type":    "tool_call_result",
		"id":      id,
		"content": content,
		"role":    effectiveRole,
	}
}

// StateChunk publishes a full agent state snapshot.
func StateChunk(snapshot any) map[string]any {
	return map[string]any{"type": "state", "snapshot": snapshot}
}

// StateDeltaChunk publishes RFC 6902 patch ops applied incrementally on
// top of the last snapshot.
func StateDeltaChunk(ops []any) map[string]any {
	return map[string]any{"type": "state_delta", "ops": ops}
}

// StepStartedChunk marks the opening boundary of a named step inside the
// run.
func StepStartedChunk(name string) map[string]any {
	return map[string]any{"type": "step_started", "name": name}
}

// StepFinishedChunk marks the closing boundary of a named step inside
// the run.
func StepFinishedChunk(name string) map[string]any {
	return map[string]any{"type": "step_finished", "name": name}
}

// RawChunk passes a foreign-system event through verbatim. An empty
// source is omitted from the payload.
func RawChunk(event any, source string) map[string]any {
	chunk := map[string]any{"type": "raw", "event": event}
	if source != "" {
		chunk["source"] = source
	}
	return chunk
}

// CustomChunk emits an application-defined event with a name and value.
// A nil value is omitted from the payload.
func CustomChunk(name string, value any) map[string]any {
	chunk := map[string]any{"type": "custom", "name": name}
	if value != nil {
		chunk["value"] = value
	}
	return chunk
}

// FinalChunk packages a trailing buffered envelope. The dispatcher
// applies any toolCalls / state / stateDelta / reasoning / result fields
// in `data` as if from a non-streaming reasoner — useful when the
// reasoner can stream text live but only knows the structured fields at
// the end.
func FinalChunk(data map[string]any) map[string]any {
	return map[string]any{"type": "final", "data": data}
}

// ErrorChunk is a terminal error. The dispatcher emits RUN_ERROR and
// stops the run; later chunks are ignored. An empty code is omitted.
func ErrorChunk(message, code string) map[string]any {
	chunk := map[string]any{"type": "error", "message": message}
	if code != "" {
		chunk["code"] = code
	}
	return chunk
}

// RelayHarnessResult translates a buffered Claude Agent harness result
// (the messages slice on harness.Result) into AG-UI streaming chunks,
// message-by-message. Mirrors the Python SDK's relay_harness_stream.
+// +// The Go harness is buffered (it returns a Result after the run finishes) +// so this helper is itself buffered: it walks res.Messages once and +// returns the equivalent chunk slice. Reasoners that want to stream the +// chunks live can either feed the slice into a channel and call +// SerializeStream, or interleave their own custom chunks. +// +// Recognized message shapes (matching the dict form of the Python and +// JS Claude Agent SDK message stream): +// +// - {type:"assistant", message:{content:[{type:"text", text:"..."}]}} +// → one TextChunk per text block +// - {type:"assistant", message:{content:[{type:"thinking", thinking:"..."}]}} +// → one ReasoningChunk per thinking block +// - {type:"assistant", message:{content:[{type:"tool_use", id, name, input}]}} +// → ToolCallStartChunk + ToolCallEndChunk per tool_use block +// - {type:"user", message:{content:[{type:"tool_result", tool_use_id, content}]}} +// → ToolCallResultChunk per tool_result block +// - {type:"result", ...} → skipped (the dispatcher's stream-end logic +// synthesizes MESSAGES_SNAPSHOT + RUN_FINISHED) +// - Anything unrecognized is wrapped as a RawChunk. +// +// Note: the Claude Agent SDK buffers per-message, not per-token. True +// per-token streaming requires the raw Anthropic streaming API. +func RelayHarnessResult(res *harness.Result) []map[string]any { + if res == nil || len(res.Messages) == 0 { + return nil + } + out := make([]map[string]any, 0, len(res.Messages)*2) + for _, msg := range res.Messages { + out = append(out, relayHarnessMessage(msg)...) 
+ } + return out +} + +func relayHarnessMessage(msg map[string]any) []map[string]any { + if msg == nil { + return nil + } + mtype, _ := msg["type"].(string) + if mtype == "result" { + return nil + } + if mtype == "system" { + return []map[string]any{RawChunk(msg, "harness")} + } + if mtype != "assistant" && mtype != "user" { + return []map[string]any{RawChunk(msg, "harness")} + } + + content := harnessMessageContent(msg) + if content == nil { + return []map[string]any{RawChunk(msg, "harness")} + } + if s, ok := content.(string); ok { + if mtype == "assistant" && s != "" { + return []map[string]any{TextChunk(s)} + } + return nil + } + blocks, ok := content.([]any) + if !ok { + return []map[string]any{RawChunk(msg, "harness")} + } + + out := make([]map[string]any, 0, len(blocks)) + for _, raw := range blocks { + block, ok := raw.(map[string]any) + if !ok { + continue + } + btype, _ := block["type"].(string) + switch btype { + case "text": + text, _ := block["text"].(string) + if text != "" { + out = append(out, TextChunk(text)) + } + case "thinking": + thinking, _ := block["thinking"].(string) + if thinking != "" { + out = append(out, ReasoningChunk(thinking)) + } + case "tool_use": + id, _ := block["id"].(string) + name, _ := block["name"].(string) + if id == "" || name == "" { + continue + } + input, _ := block["input"].(map[string]any) + out = append(out, ToolCallStartChunk(id, name, input, "")) + out = append(out, ToolCallEndChunk(id)) + case "tool_result": + id, _ := block["tool_use_id"].(string) + if id == "" { + continue + } + inner := harnessToolResultContent(block["content"]) + out = append(out, ToolCallResultChunk(id, inner, "tool")) + default: + out = append(out, RawChunk(block, "harness")) + } + } + return out +} + +func harnessMessageContent(msg map[string]any) any { + if v, ok := msg["content"]; ok { + return v + } + inner, ok := msg["message"].(map[string]any) + if !ok { + return nil + } + return inner["content"] +} + +func harnessToolResultContent(v 
any) string { + switch t := v.(type) { + case string: + return t + case []any: + var b []byte + for _, item := range t { + m, ok := item.(map[string]any) + if !ok { + continue + } + s, _ := m["text"].(string) + b = append(b, s...) + } + return string(b) + case nil: + return "" + default: + return fmt.Sprintf("%v", t) + } +} + +// SerializeStream consumes a chunks channel (closed by the producer when +// done) and writes one NDJSON line per chunk to w, flushing after each. +// `w` should be an http.ResponseWriter with Content-Type set to +// StreamingContentType. Returns the first write or encode error +// encountered, or nil when the channel closes cleanly. +// +// Typical usage in an HTTP reasoner endpoint: +// +// w.Header().Set("Content-Type", agui.StreamingContentType) +// w.WriteHeader(http.StatusOK) +// chunks := make(chan map[string]any, 8) +// go produceChunks(ctx, chunks) // closes chunks when done +// if err := agui.SerializeStream(ctx, w, chunks); err != nil { ... } +func SerializeStream(ctx context.Context, w io.Writer, chunks <-chan map[string]any) error { + flusher, _ := w.(interface{ Flush() }) + enc := json.NewEncoder(w) + enc.SetEscapeHTML(false) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case ch, ok := <-chunks: + if !ok { + return nil + } + if err := enc.Encode(ch); err != nil { + return fmt.Errorf("encode chunk: %w", err) + } + if flusher != nil { + flusher.Flush() + } + } + } +} diff --git a/sdk/go/agent/agui/agui_test.go b/sdk/go/agent/agui/agui_test.go new file mode 100644 index 00000000..e748cb52 --- /dev/null +++ b/sdk/go/agent/agui/agui_test.go @@ -0,0 +1,223 @@ +package agui + +import ( + "bytes" + "context" + "encoding/json" + "io" + "strings" + "testing" + "time" + + "github.com/Agent-Field/agentfield/sdk/go/ai" + "github.com/Agent-Field/agentfield/sdk/go/harness" + + "github.com/stretchr/testify/require" +) + +func TestToolCall_MinimalAndFull(t *testing.T) { + require.Nil(t, ToolCall("", "", nil, nil), "empty name 
returns nil so caller surfaces the bug")

	minimal := ToolCall("", "showFlightCard", nil, nil)
	require.Equal(t, "showFlightCard", minimal["name"])
	require.NotContains(t, minimal, "id", "id only present when caller supplies one")
	require.Equal(t, map[string]any{}, minimal["arguments"])
	require.NotContains(t, minimal, "result")

	full := ToolCall("tc-1", "x", map[string]any{"a": 1}, map[string]any{"ok": true})
	require.Equal(t, "tc-1", full["id"])
	require.Equal(t, map[string]any{"a": 1}, full["arguments"])
	require.Equal(t, map[string]any{"ok": true}, full["result"])
}

// TestToolCallsFromTrace pins the trace→toolCalls translation: stable
// synthesized IDs, error-over-result precedence, and defaulted arguments.
func TestToolCallsFromTrace(t *testing.T) {
	require.Empty(t, ToolCallsFromTrace(nil))
	require.Empty(t, ToolCallsFromTrace(&ai.ToolCallTrace{}))

	trace := &ai.ToolCallTrace{
		Calls: []ai.ToolCallRecord{
			{ToolName: "getWeather", Arguments: map[string]any{"city": "SF"}, Result: map[string]any{"temp": 62.0}},
			{ToolName: "lookup", Arguments: map[string]any{"q": "x"}, Error: "timeout"},
			{ToolName: "noargs"},
		},
	}
	out := ToolCallsFromTrace(trace)
	require.Len(t, out, 3)

	require.Equal(t, "tc-trace-0", out[0]["id"])
	require.Equal(t, "getWeather", out[0]["name"])
	require.Equal(t, map[string]any{"temp": 62.0}, out[0]["result"])

	require.Equal(t, "tc-trace-1", out[1]["id"])
	require.Equal(t, map[string]any{"error": "timeout"}, out[1]["result"], "errors surface as {error:...}")

	require.Equal(t, "tc-trace-2", out[2]["id"])
	require.Equal(t, map[string]any{}, out[2]["arguments"], "nil arguments default to empty map")
	require.NotContains(t, out[2], "result", "no result and no error means omit the field")
}

// TestStateDeltaReplace pins the valid-op shape and the leading-slash
// validation of RFC 6902 paths.
func TestStateDeltaReplace(t *testing.T) {
	op, err := StateDeltaReplace("/counter", 2)
	require.NoError(t, err)
	require.Equal(t, map[string]any{"op": "replace", "path": "/counter", "value": 2}, op)

	_, err = StateDeltaReplace("counter", 2)
	require.Error(t, err, "path without leading slash is invalid")
	_, err = StateDeltaReplace("", 2)
	require.Error(t, err, "empty path is invalid")
}

// TestStreamingChunkBuilders pins the exact wire shape of every NDJSON
// chunk builder, including optional-field omission defaults.
func TestStreamingChunkBuilders(t *testing.T) {
	require.Equal(t, map[string]any{"type": "text", "delta": "hi"}, TextChunk("hi"))
	require.Equal(t, map[string]any{"type": "reasoning", "delta": "think"}, ReasoningChunk("think"))
	require.Equal(t, map[string]any{"type": "reasoning_end"}, ReasoningEndChunk())

	tcStart := ToolCallStartChunk("tc1", "x", map[string]any{"a": 1}, "msg-1")
	require.Equal(t, "tool_call_start", tcStart["type"])
	require.Equal(t, "tc1", tcStart["id"])
	require.Equal(t, "x", tcStart["name"])
	require.Equal(t, map[string]any{"a": 1}, tcStart["arguments"])
	require.Equal(t, "msg-1", tcStart["parentMessageId"])

	tcStartNoExtras := ToolCallStartChunk("tc2", "x", nil, "")
	require.NotContains(t, tcStartNoExtras, "arguments")
	require.NotContains(t, tcStartNoExtras, "parentMessageId")

	require.Equal(t, map[string]any{"type": "tool_call_args", "id": "tc1", "delta": "{\"x"}, ToolCallArgsChunk("tc1", "{\"x"))
	require.Equal(t, map[string]any{"type": "tool_call_end", "id": "tc1"}, ToolCallEndChunk("tc1"))

	res := ToolCallResultChunk("tc1", "ok", "")
	require.Equal(t, "tool", res["role"], "default role is 'tool'")
	require.Equal(t, "ok", res["content"])

	require.Equal(t, map[string]any{"type": "state", "snapshot": map[string]any{"a": 1}}, StateChunk(map[string]any{"a": 1}))
	require.Equal(t, "state_delta", StateDeltaChunk([]any{map[string]any{"op": "replace"}})["type"])

	require.Equal(t, "step_started", StepStartedChunk("plan")["type"])
	require.Equal(t, "step_finished", StepFinishedChunk("plan")["type"])

	raw := RawChunk(map[string]any{"x": 1}, "harness")
	require.Equal(t, "raw", raw["type"])
	require.Equal(t, "harness", raw["source"])

	rawNoSrc := RawChunk(map[string]any{"x": 1}, "")
	require.NotContains(t, rawNoSrc, "source")

	custom := CustomChunk("ack", map[string]any{"ok": true})
	require.Equal(t, "custom", custom["type"])
	require.Equal(t, "ack", custom["name"])

	customNil := CustomChunk("ack", nil)
	require.NotContains(t, customNil, "value")

	final := FinalChunk(map[string]any{"toolCalls": []any{}})
	require.Equal(t, "final", final["type"])

	errCh := ErrorChunk("boom", "E1")
	require.Equal(t, "error", errCh["type"])
	require.Equal(t, "boom", errCh["message"])
	require.Equal(t, "E1", errCh["code"])

	errChNoCode := ErrorChunk("boom", "")
	require.NotContains(t, errChNoCode, "code")
}

// TestSerializeStream verifies one NDJSON line per chunk, in order, on a
// cleanly closed channel.
func TestSerializeStream(t *testing.T) {
	ch := make(chan map[string]any, 4)
	ch <- TextChunk("hello ")
	ch <- TextChunk("world")
	ch <- StateChunk(map[string]any{"counter": 1})
	close(ch)

	var buf bytes.Buffer
	require.NoError(t, SerializeStream(context.Background(), &buf, ch))
	lines := strings.Split(strings.TrimRight(buf.String(), "\n"), "\n")
	require.Len(t, lines, 3)

	var first map[string]any
	require.NoError(t, json.Unmarshal([]byte(lines[0]), &first))
	require.Equal(t, "text", first["type"])
	require.Equal(t, "hello ", first["delta"])

	var third map[string]any
	require.NoError(t, json.Unmarshal([]byte(lines[2]), &third))
	require.Equal(t, "state", third["type"])
}

// TestReasoningSegment_AndReasoning pins segment construction, id
// omission, empty-string dropping, and the type error path.
func TestReasoningSegment_AndReasoning(t *testing.T) {
	seg := ReasoningSegment("thinking", "r1")
	require.Equal(t, map[string]any{"content": "thinking", "id": "r1"}, seg)

	segNoID := ReasoningSegment("thinking", "")
	require.NotContains(t, segNoID, "id")

	out, err := Reasoning("a", "", seg, "b")
	require.NoError(t, err)
	require.Equal(t, []any{"a", map[string]any{"content": "thinking", "id": "r1"}, "b"}, out)

	_, err = Reasoning(42)
	require.Error(t, err, "non-string non-mapping segments should error")
}

// TestRelayHarnessResult pins the full harness-message → chunk mapping,
// including raw passthrough of unknown shapes and stitched list-form
// tool results.
func TestRelayHarnessResult(t *testing.T) {
	require.Nil(t, RelayHarnessResult(nil), "nil result yields nil")
	require.Nil(t, RelayHarnessResult(&harness.Result{}), "empty messages yields nil")

	res := &harness.Result{
		Messages: []map[string]any{
			{"type": "system", "subtype": "init"},
			{"type": "assistant", "message": map[string]any{"content": []any{
				map[string]any{"type": "text", "text": "hello"},
				map[string]any{"type": "thinking", "thinking": "hmm"},
				map[string]any{"type": "tool_use", "id": "tc1", "name": "lookup", "input": map[string]any{"q": "x"}},
				map[string]any{"type": "mystery", "payload": 1},
			}}},
			{"type": "user", "message": map[string]any{"content": []any{
				map[string]any{"type": "tool_result", "tool_use_id": "tc1", "content": "ok"},
			}}},
			{"type": "user", "message": map[string]any{"content": []any{
				map[string]any{"type": "tool_result", "tool_use_id": "tc2", "content": []any{
					map[string]any{"type": "text", "text": "a"},
					map[string]any{"type": "text", "text": "b"},
				}},
			}}},
			{"type": "result", "subtype": "success", "result": "done"},
			{"type": "no-such-thing"},
		},
	}
	chunks := RelayHarnessResult(res)

	types := make([]string, 0, len(chunks))
	for _, c := range chunks {
		types = append(types, c["type"].(string))
	}
	require.Equal(t, []string{
		"raw",              // system
		"text",             // hello
		"reasoning",        // hmm
		"tool_call_start",  // tool_use start
		"tool_call_end",    // tool_use end
		"raw",              // mystery block
		"tool_call_result", // ok
		"tool_call_result", // a+b stitched
		"raw",              // unknown top-level
	}, types)

	stitched := chunks[7]
	require.Equal(t, "tc2", stitched["id"])
	require.Equal(t, "ab", stitched["content"])
}

// TestSerializeStream_RespectsContext proves cancellation unblocks the
// serializer even when the producer never sends or closes.
func TestSerializeStream_RespectsContext(t *testing.T) {
	ch := make(chan map[string]any) // never closed; never sends
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error, 1)
	go func() { done <- SerializeStream(ctx, io.Discard, ch) }()
	cancel()
	select {
	case err := <-done:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		t.Fatal("SerializeStream did not honor context cancellation")
	}
}
diff --git a/sdk/python/agentfield/__init__.py b/sdk/python/agentfield/__init__.py
index 4261196f..893c9e71 100644
---
a/sdk/python/agentfield/__init__.py +++ b/sdk/python/agentfield/__init__.py @@ -69,6 +69,7 @@ capability_to_tool_schema, capabilities_to_tool_schemas, ) +from . import agui __all__ = [ "Agent", @@ -129,6 +130,8 @@ "ToolCallTrace", "capability_to_tool_schema", "capabilities_to_tool_schemas", + # AG-UI protocol helpers + "agui", # Exceptions "AgentFieldError", "AgentFieldClientError", diff --git a/sdk/python/agentfield/agui.py b/sdk/python/agentfield/agui.py new file mode 100644 index 00000000..b18915e5 --- /dev/null +++ b/sdk/python/agentfield/agui.py @@ -0,0 +1,554 @@ +"""AG-UI protocol helpers for AgentField reasoners. + +This module also exposes a *streaming* reasoner contract — see +``serialize_stream`` and the chunk builders (``text_chunk``, +``reasoning_chunk``, ``tool_call_start_chunk`` …) — for live +per-token AG-UI events. A streaming reasoner is a normal FastAPI +endpoint that returns a ``StreamingResponse`` with content-type +``application/x-ndjson``; the AgentField control plane sniffs the +content-type and dispatches each line as a live AG-UI event. + + + +Reasoners reach the AG-UI / CopilotKit frontend via the control plane's +``POST /api/v1/agui/runs//`` adapter. The adapter expects +a small set of optional fields in the reasoner's response to drive the +richer AG-UI events (tool calls, shared state, RFC 6902 patches). + +This module is the documented contract for those fields. Reasoner authors +opt into Generative UI / shared state by returning the values these +helpers build: + +.. 
code-block:: python + + @app.reasoner() + async def book_flight(prompt: str = "", state: dict | None = None): + return { + "result": "Picking flight options.", + "toolCalls": [ + tool_call(name="showFlightCard", arguments={"from": "SFO", "to": "JFK"}), + ], + "state": {"counter": (state or {}).get("counter", 0) + 1}, + } + +When a reasoner uses ``await app.ai(..., tools=[...])`` and wants the +LLM's tool-calling trace to surface in the UI, pass the returned +``ToolCallResponse.trace`` into :func:`tool_calls_from_trace`: + +.. code-block:: python + + result = await app.ai("help the user", tools="discover") + return { + "result": result.text, + "toolCalls": tool_calls_from_trace(result.trace), + } + +Wire shape mirrors the canonical AG-UI ``TOOL_CALL_*`` events +(https://docs.ag-ui.com/concepts/events). +""" + +from __future__ import annotations + +from typing import Any, Iterable, List, Mapping, Optional + +from .tool_calling import ToolCallRecord, ToolCallTrace + +__all__ = [ + "tool_call", + "tool_calls_from_trace", + "state_delta_replace", + "state_delta_from_diff", + "reasoning", + "reasoning_segment", + # Streaming chunk builders + serializer for the live AG-UI path. + "text_chunk", + "reasoning_chunk", + "reasoning_end_chunk", + "tool_call_start_chunk", + "tool_call_args_chunk", + "tool_call_end_chunk", + "tool_call_result_chunk", + "state_chunk", + "state_delta_chunk", + "step_started_chunk", + "step_finished_chunk", + "raw_chunk", + "custom_chunk", + "final_chunk", + "error_chunk", + "serialize_stream", + "relay_harness_stream", + "STREAMING_CONTENT_TYPE", +] + +STREAMING_CONTENT_TYPE = "application/x-ndjson" + + +def tool_call( + name: str, + arguments: Optional[Mapping[str, Any]] = None, + *, + id: Optional[str] = None, + result: Any = None, + has_result: bool = False, +) -> dict: + """Build a single AG-UI tool-call entry. + + The control plane translates each entry into a + ``TOOL_CALL_START`` / ``TOOL_CALL_ARGS`` / ``TOOL_CALL_END`` triad. 
+ When ``has_result=True`` (or ``result`` is non-None), it also emits + ``TOOL_CALL_RESULT`` so a server-side trace renders in the UI. + + Args: + name: The tool name. CopilotKit pattern-matches this against + ``useCopilotAction({name, render})`` registrations to drive + Generative UI. + arguments: A JSON-serializable mapping of arguments. + id: Optional stable ID. If omitted, the control plane synthesizes + one (which works for one-shot calls but breaks correlation + with follow-up tool messages). + result: Optional result. Set this when the tool was already + executed server-side (e.g. inside ``app.ai(tools=...)``). + has_result: Pass True to force ``result=None`` to be treated as + an explicit "executed and returned null" instead of "not + executed yet". Defaults to True if ``result`` is non-None. + """ + entry: dict = {"name": name, "arguments": dict(arguments or {})} + if id is not None: + entry["id"] = id + if result is not None or has_result: + entry["result"] = result + return entry + + +def tool_calls_from_trace(trace: Optional[ToolCallTrace]) -> List[dict]: + """Convert a ``ToolCallTrace`` from ``app.ai(tools=...)`` into the + AG-UI ``toolCalls`` list shape. + + Each :class:`ToolCallRecord` becomes a tool-call entry with its + arguments, and the executed result (or error) attached so the UI can + render the trace as a sequence of completed tool calls. Empty traces + return ``[]`` so callers can splat the result safely: + + .. code-block:: python + + return {"result": text, "toolCalls": tool_calls_from_trace(trace)} + + Args: + trace: A trace from :class:`ToolCallResponse`, or None. + + Returns: + A list of dicts in AG-UI ``toolCalls`` format. Empty if ``trace`` + is None or has no calls. 
+ """ + if trace is None or not getattr(trace, "calls", None): + return [] + out: List[dict] = [] + for i, rec in enumerate(trace.calls): + out.append(_record_to_entry(rec, i)) + return out + + +def _record_to_entry(rec: ToolCallRecord, index: int) -> dict: + """Translate one ``ToolCallRecord`` into an AG-UI tool-call entry.""" + entry: dict = { + "id": f"tc-trace-{index}", + "name": rec.tool_name, + "arguments": dict(rec.arguments or {}), + } + # The trace records either a result or an error; surface either as + # the AG-UI tool-call result so frontend renderers can show a final + # state instead of a perpetually "running" placeholder. + if rec.error is not None: + entry["result"] = {"error": rec.error} + elif rec.result is not None: + entry["result"] = rec.result + return entry + + +def state_delta_replace(path: str, value: Any) -> dict: + """Build a single RFC 6902 ``replace`` patch op for ``stateDelta``. + + .. code-block:: python + + return { + "result": "...", + "stateDelta": [ + state_delta_replace("/counter", 2), + state_delta_replace("/lastUpdated", "2026-05-09"), + ], + } + + The control plane re-emits the array as a ``STATE_DELTA`` event, + which CopilotKit's ``useCoAgent`` applies on top of the previously + snapshot-emitted state. + """ + if not path.startswith("/"): + raise ValueError("RFC 6902 paths must start with '/'") + return {"op": "replace", "path": path, "value": value} + + +def reasoning_segment(content: str, *, id: Optional[str] = None) -> dict: + """Build a single REASONING_MESSAGE segment. + + Reasoners surface chain-of-thought to CopilotKit's "Thinking…" pane + by returning either a plain string or a list of these segments under + the ``reasoning`` field of their response: + + .. 
code-block:: python + + return { + "result": "Booked AA-12.", + "reasoning": [ + agui.reasoning_segment("Looking up flights for SFO->JFK..."), + agui.reasoning_segment("AA-12 is the cheapest non-stop."), + ], + } + + Each segment becomes a REASONING_MESSAGE_START / _CONTENT / _END + triad inside a REASONING_START / END boundary. Long content is + auto-chunked across multiple REASONING_MESSAGE_CONTENT deltas. + """ + out: dict = {"content": content} + if id is not None: + out["id"] = id + return out + + +def reasoning(*segments: Any) -> List[Any]: + """Build a ``reasoning`` field value from a mix of strings and segments. + + Convenience wrapper so reasoners can write:: + + return {"result": text, "reasoning": agui.reasoning("step 1", "step 2")} + + instead of constructing the list manually. + """ + out: List[Any] = [] + for s in segments: + if isinstance(s, str): + if s: + out.append(s) + elif isinstance(s, Mapping): + out.append(dict(s)) + else: + raise TypeError( + f"reasoning() segments must be str or mapping; got {type(s).__name__}" + ) + return out + + +# --------------------------------------------------------------------------- +# Streaming chunk builders +# +# Each function returns a small dict in the wire shape the control plane's +# streaming dispatcher consumes (see internal/handlers/agui_runs_streaming.go). +# The reasoner author yields these from an async generator; serialize_stream +# turns each yield into one NDJSON line for the FastAPI StreamingResponse. +# --------------------------------------------------------------------------- + + +def text_chunk(delta: str) -> dict: + """One chunk of assistant text. Concatenated client-side.""" + return {"type": "text", "delta": delta} + + +def reasoning_chunk(delta: str) -> dict: + """One chunk of chain-of-thought, rendered in CopilotKit's + "Thinking…" pane. 
Yield multiple in a row for a single thought, + then ``reasoning_end_chunk()`` to start a new thought segment.""" + return {"type": "reasoning", "delta": delta} + + +def reasoning_end_chunk() -> dict: + """Closes the current reasoning segment so the next ``reasoning_chunk`` + opens a fresh one. The outer reasoning context auto-closes at stream + end or when the first text/tool-call chunk arrives.""" + return {"type": "reasoning_end"} + + +def tool_call_start_chunk( + id: str, + name: str, + *, + arguments: Optional[Mapping[str, Any]] = None, + parent_message_id: Optional[str] = None, +) -> dict: + """Open a tool call. If you already have the full ``arguments``, + pass them here and the dispatcher emits one TOOL_CALL_ARGS frame + immediately; otherwise stream them with ``tool_call_args_chunk``.""" + out: dict = {"type": "tool_call_start", "id": id, "name": name} + if arguments is not None: + out["arguments"] = dict(arguments) + if parent_message_id is not None: + out["parentMessageId"] = parent_message_id + return out + + +def tool_call_args_chunk(id: str, delta: str) -> dict: + """One chunk of streaming tool-call arguments. ``delta`` is a + string — typically a piece of the JSON-encoded arguments object as + the LLM emits it. Concatenated client-side into the final args JSON.""" + return {"type": "tool_call_args", "id": id, "delta": delta} + + +def tool_call_end_chunk(id: str) -> dict: + """Close a tool call.""" + return {"type": "tool_call_end", "id": id} + + +def tool_call_result_chunk(id: str, content: str, *, role: str = "tool") -> dict: + """Server-side tool result. Use when the reasoner already executed + the tool (e.g. 
via ``app.ai(tools=...)``) and wants the trace to + render as completed in the UI.""" + return {"type": "tool_call_result", "id": id, "content": content, "role": role} + + +def state_chunk(snapshot: Any) -> dict: + """Full agent state snapshot — the value ``useCoAgent({state})`` + reads on the frontend.""" + return {"type": "state", "snapshot": snapshot} + + +def state_delta_chunk(ops: List[dict]) -> dict: + """RFC 6902 patch ops applied incrementally on top of the last + snapshot the client received. Cheaper than re-emitting full state + every turn.""" + return {"type": "state_delta", "ops": list(ops)} + + +def step_started_chunk(name: str) -> dict: + """Mark the start of a named step inside the run. Useful for + multi-stage agents where a frontend wants to render a progress UI.""" + return {"type": "step_started", "name": name} + + +def step_finished_chunk(name: str) -> dict: + """Mark a step finished.""" + return {"type": "step_finished", "name": name} + + +def raw_chunk(event: Any, *, source: Optional[str] = None) -> dict: + """Pass a foreign-system event through verbatim. Frontends that + subscribed via ``onRawEvent`` see it; others ignore it.""" + out: dict = {"type": "raw", "event": event} + if source is not None: + out["source"] = source + return out + + +def custom_chunk(name: str, value: Any = None) -> dict: + """Application-defined event. Frontends subscribe by ``name``.""" + out: dict = {"type": "custom", "name": name} + if value is not None: + out["value"] = value + return out + + +def final_chunk(data: Mapping[str, Any]) -> dict: + """Trailing buffered envelope — the dispatcher applies any + ``toolCalls`` / ``state`` / ``stateDelta`` / ``reasoning`` / + ``result`` fields here as if from a non-streaming reasoner. 
Useful + when the reasoner can stream text live but only knows the + structured fields at the end.""" + return {"type": "final", "data": dict(data)} + + +def error_chunk(message: str, *, code: Optional[str] = None) -> dict: + """Terminal error. The dispatcher emits RUN_ERROR and stops the run; + any subsequent chunks the reasoner sends are ignored.""" + out: dict = {"type": "error", "message": message} + if code is not None: + out["code"] = code + return out + + +async def relay_harness_stream(harness_iter: Any) -> Any: + """Relay a Claude Agent SDK / harness async-iterator of messages + into AG-UI streaming chunks, message-by-message. + + The Claude Agent SDK yields one Python dict (or message object) per + turn — assistant text blocks, tool-use blocks, tool-result blocks, + a final ``result`` envelope. This function translates each into the + smallest sensible AG-UI chunk(s) so a reasoner can pipe a harness + run straight to the AG-UI stream:: + + from claude_agent_sdk import query, ClaudeAgentOptions + from agentfield import agui + + async def _chunks(body): + opts = ClaudeAgentOptions(...) + async for ch in agui.relay_harness_stream( + query(prompt=body["prompt"], options=opts) + ): + yield ch + + Recognized message shapes (matches the dict form + ``HarnessResult.messages`` records): + + - ``{"type":"assistant","message":{"content":[{"type":"text","text":"..."}, ...]}}`` + → one ``text`` chunk per text block + - ``{"type":"assistant","message":{"content":[{"type":"tool_use","id":"...","name":"...","input":{...}}, ...]}}`` + → ``tool_call_start`` + ``tool_call_end`` per tool_use block + - ``{"type":"user","message":{"content":[{"type":"tool_result","tool_use_id":"...","content":"..."}, ...]}}`` + → ``tool_call_result`` per tool_result block + - ``{"type":"result","subtype":"success","result":"..."}`` → + terminal — yields nothing (the dispatcher's stream-end logic + wraps the run with MESSAGES_SNAPSHOT + RUN_FINISHED). 
+ - Anything unrecognized is wrapped as a ``raw`` chunk so the trace + is preserved without us inventing ad-hoc event types. + + Note: the Claude Agent SDK buffers per-message rather than per-token, + so this path streams at message granularity. True per-token streaming + requires the raw Anthropic streaming API, not the harness.""" + async for raw in harness_iter: + if isinstance(raw, dict): + msg = raw + elif hasattr(raw, "__dict__"): + msg = dict(raw.__dict__) + else: + yield raw_chunk({"raw": str(raw)}, source="harness") + continue + + msg_type = str(msg.get("type", "")) + if msg_type == "result": + # The harness's result message holds the final aggregated + # text; the AG-UI stream's MESSAGES_SNAPSHOT / RUN_FINISHED + # frames will be synthesized by the control-plane dispatcher + # at stream end, so we don't need to emit anything here. + continue + if msg_type == "system": + yield raw_chunk(msg, source="harness") + continue + + if msg_type in ("assistant", "user"): + content = _harness_message_content(msg) + if content is None: + yield raw_chunk(msg, source="harness") + continue + if isinstance(content, str): + if msg_type == "assistant" and content: + yield text_chunk(content) + continue + if isinstance(content, list): + for block in content: + if not isinstance(block, dict): + continue + btype = block.get("type") + if btype == "text": + text = block.get("text", "") + if text: + yield text_chunk(text) + elif btype == "thinking": + # Anthropic extended-thinking blocks render as + # REASONING_* events — exactly the "Thinking…" + # pane CopilotKit shows. 
+ thinking = block.get("thinking", "") + if thinking: + yield reasoning_chunk(thinking) + elif btype == "tool_use": + tcid = str(block.get("id", "")) + name = str(block.get("name", "")) + if tcid and name: + inp = block.get("input") + if not isinstance(inp, Mapping): + inp = {} + yield tool_call_start_chunk(tcid, name, arguments=inp) + yield tool_call_end_chunk(tcid) + elif btype == "tool_result": + tcid = str(block.get("tool_use_id", "")) + if tcid: + inner = block.get("content", "") + if isinstance(inner, list): + # tool_result content may itself be a + # block list — stitch text blocks. + inner = "".join( + str(b.get("text", "")) for b in inner if isinstance(b, dict) + ) + elif not isinstance(inner, str): + inner = str(inner) + yield tool_call_result_chunk(tcid, inner, role="tool") + else: + yield raw_chunk(block, source="harness") + continue + + # Unknown top-level message — preserve as raw. + yield raw_chunk(msg, source="harness") + + +def _harness_message_content(msg: Mapping[str, Any]) -> Any: + """Reach into the harness message envelope for the content list, + handling both the bare ``content`` shape and the ``message.content`` + shape the Claude Agent SDK uses.""" + if "content" in msg: + return msg["content"] + inner = msg.get("message") + if isinstance(inner, Mapping): + return inner.get("content") + return None + + +async def serialize_stream(generator: Any) -> Any: + """Serialize an async generator of chunk dicts (or strings — strings + are wrapped as text chunks) into an async iterator of NDJSON-encoded + ``bytes``, suitable for ``fastapi.StreamingResponse``:: + + from fastapi import Request + from fastapi.responses import StreamingResponse + from agentfield import agui + + @app.post("/reasoners/chat") + async def chat(request: Request): + body = await request.json() + return StreamingResponse( + agui.serialize_stream(_chat_chunks(body)), + media_type=agui.STREAMING_CONTENT_TYPE, + ) + + async def _chat_chunks(body): + async for token in 
llm.stream(body["prompt"]): + yield agui.text_chunk(token) + + Bare strings yielded by the generator are auto-wrapped as text + chunks for ergonomics. Anything else must be a dict produced by one + of the chunk builders above (or a hand-rolled equivalent).""" + import json as _json + + async for item in generator: + if isinstance(item, str): + payload = text_chunk(item) + elif isinstance(item, Mapping): + payload = dict(item) + else: + raise TypeError( + "streaming reasoner yielded non-str/non-dict value of type " + f"{type(item).__name__}; use one of agui.*_chunk(...)" + ) + # No spaces — these are machine-to-machine; keep lines compact. + yield (_json.dumps(payload, separators=(",", ":")) + "\n").encode("utf-8") + + +def state_delta_from_diff( + before: Mapping[str, Any], + after: Mapping[str, Any], +) -> List[dict]: + """Compute a minimal RFC 6902 patch list for top-level keys that + differ between ``before`` and ``after``. + + This is a deliberately shallow utility — it only walks the top level + of the mapping and emits ``replace``/``add``/``remove`` ops as + needed. Reasoners with nested state should construct patches + explicitly (or just emit a full ``state`` snapshot). 
+ """ + ops: List[dict] = [] + keys: Iterable[str] = sorted(set(before.keys()) | set(after.keys())) + for k in keys: + path = f"/{k}" + if k in before and k in after: + if before[k] != after[k]: + ops.append({"op": "replace", "path": path, "value": after[k]}) + elif k in after: + ops.append({"op": "add", "path": path, "value": after[k]}) + else: + ops.append({"op": "remove", "path": path}) + return ops diff --git a/sdk/python/tests/test_agui_helpers.py b/sdk/python/tests/test_agui_helpers.py new file mode 100644 index 00000000..7f9bf1fc --- /dev/null +++ b/sdk/python/tests/test_agui_helpers.py @@ -0,0 +1,386 @@ +"""Tests for the agentfield.agui helpers — the documented contract for +opt-in Generative UI / shared state through the control plane's AG-UI +adapter.""" + +import json + +import pytest + +from agentfield import agui +from agentfield.tool_calling import ToolCallRecord, ToolCallTrace + + +class TestToolCall: + def test_minimal(self): + e = agui.tool_call(name="showFlightCard") + assert e == {"name": "showFlightCard", "arguments": {}} + assert "result" not in e + assert "id" not in e + + def test_with_arguments_and_id(self): + e = agui.tool_call(name="x", arguments={"a": 1, "b": "z"}, id="tc-1") + assert e == {"name": "x", "arguments": {"a": 1, "b": "z"}, "id": "tc-1"} + + def test_with_result_attaches_for_executed_calls(self): + e = agui.tool_call(name="getWeather", result={"temp": 62}) + assert e["result"] == {"temp": 62} + + def test_explicit_null_result(self): + e = agui.tool_call(name="x", has_result=True) + assert "result" in e + assert e["result"] is None + + +class TestToolCallsFromTrace: + def test_none_trace_returns_empty(self): + assert agui.tool_calls_from_trace(None) == [] + + def test_empty_calls_returns_empty(self): + trace = ToolCallTrace(calls=[]) + assert agui.tool_calls_from_trace(trace) == [] + + def test_records_become_entries_with_results(self): + trace = ToolCallTrace( + calls=[ + ToolCallRecord( + tool_name="getWeather", + 
arguments={"city": "SF"}, + result={"temp": 62}, + ), + ToolCallRecord( + tool_name="lookup", + arguments={"q": "gates"}, + error="api timeout", + ), + ] + ) + out = agui.tool_calls_from_trace(trace) + assert len(out) == 2 + assert out[0]["name"] == "getWeather" + assert out[0]["arguments"] == {"city": "SF"} + assert out[0]["result"] == {"temp": 62} + # id is synthesized so the control plane can correlate frames + # without colliding across calls in the same trace. + assert out[0]["id"] == "tc-trace-0" + + assert out[1]["name"] == "lookup" + assert out[1]["result"] == {"error": "api timeout"} + assert out[1]["id"] == "tc-trace-1" + + def test_trace_with_no_result_or_error_omits_result_field(self): + trace = ToolCallTrace(calls=[ToolCallRecord(tool_name="x", arguments={})]) + out = agui.tool_calls_from_trace(trace) + assert "result" not in out[0] + + +class TestStateDeltaHelpers: + def test_replace_op(self): + assert agui.state_delta_replace("/counter", 2) == { + "op": "replace", + "path": "/counter", + "value": 2, + } + + def test_replace_rejects_invalid_path(self): + with pytest.raises(ValueError): + agui.state_delta_replace("counter", 2) # missing leading slash + + def test_diff_emits_replace_for_changed(self): + ops = agui.state_delta_from_diff({"a": 1, "b": 2}, {"a": 1, "b": 3}) + assert ops == [{"op": "replace", "path": "/b", "value": 3}] + + def test_diff_emits_add_for_new_keys(self): + ops = agui.state_delta_from_diff({}, {"x": 1}) + assert ops == [{"op": "add", "path": "/x", "value": 1}] + + def test_diff_emits_remove_for_dropped_keys(self): + ops = agui.state_delta_from_diff({"x": 1}, {}) + assert ops == [{"op": "remove", "path": "/x"}] + + def test_diff_no_ops_when_identical(self): + assert agui.state_delta_from_diff({"a": 1}, {"a": 1}) == [] + + +class TestReasoningHelpers: + def test_segment_minimal(self): + seg = agui.reasoning_segment("thinking") + assert seg == {"content": "thinking"} + + def test_segment_with_id(self): + seg = 
agui.reasoning_segment("thinking", id="r-1") + assert seg == {"content": "thinking", "id": "r-1"} + + def test_reasoning_strings_pass_through(self): + assert agui.reasoning("step 1", "step 2") == ["step 1", "step 2"] + + def test_reasoning_drops_empty_strings(self): + assert agui.reasoning("step 1", "", "step 2") == ["step 1", "step 2"] + + def test_reasoning_accepts_segment_dicts(self): + out = agui.reasoning("step 1", agui.reasoning_segment("step 2", id="r-2")) + assert out == ["step 1", {"content": "step 2", "id": "r-2"}] + + def test_reasoning_rejects_unknown_types(self): + with pytest.raises(TypeError): + agui.reasoning(42) + + +class TestStreamingChunkBuilders: + def test_text_chunk(self): + assert agui.text_chunk("hello") == {"type": "text", "delta": "hello"} + + def test_reasoning_chunks(self): + assert agui.reasoning_chunk("think") == {"type": "reasoning", "delta": "think"} + assert agui.reasoning_end_chunk() == {"type": "reasoning_end"} + + def test_tool_call_start_with_args_inline(self): + c = agui.tool_call_start_chunk("tc1", "showCard", arguments={"x": 1}) + assert c == {"type": "tool_call_start", "id": "tc1", "name": "showCard", "arguments": {"x": 1}} + + def test_tool_call_start_with_parent(self): + c = agui.tool_call_start_chunk("tc1", "x", parent_message_id="m1") + assert c["parentMessageId"] == "m1" + + def test_tool_call_args_stream(self): + assert agui.tool_call_args_chunk("tc1", '{"x') == { + "type": "tool_call_args", + "id": "tc1", + "delta": '{"x', + } + + def test_tool_call_end_and_result(self): + assert agui.tool_call_end_chunk("tc1") == {"type": "tool_call_end", "id": "tc1"} + r = agui.tool_call_result_chunk("tc1", "ok", role="tool") + assert r == {"type": "tool_call_result", "id": "tc1", "content": "ok", "role": "tool"} + + def test_state_chunks(self): + assert agui.state_chunk({"counter": 1}) == {"type": "state", "snapshot": {"counter": 1}} + assert agui.state_delta_chunk([{"op": "replace", "path": "/x", "value": 1}]) == { + "type": 
"state_delta", + "ops": [{"op": "replace", "path": "/x", "value": 1}], + } + + def test_step_chunks(self): + assert agui.step_started_chunk("plan") == {"type": "step_started", "name": "plan"} + assert agui.step_finished_chunk("plan") == {"type": "step_finished", "name": "plan"} + + def test_raw_and_custom(self): + assert agui.raw_chunk({"x": 1}, source="harness") == { + "type": "raw", + "event": {"x": 1}, + "source": "harness", + } + assert agui.custom_chunk("ack", value={"ok": True}) == { + "type": "custom", + "name": "ack", + "value": {"ok": True}, + } + + def test_final_chunk(self): + c = agui.final_chunk({"toolCalls": [{"name": "x"}]}) + assert c == {"type": "final", "data": {"toolCalls": [{"name": "x"}]}} + + def test_error_chunk(self): + assert agui.error_chunk("boom", code="E1") == { + "type": "error", + "message": "boom", + "code": "E1", + } + + +class TestSerializeStream: + @pytest.mark.asyncio + async def test_yields_ndjson_lines(self): + async def gen(): + yield agui.text_chunk("hello ") + yield agui.text_chunk("world") + yield agui.tool_call_start_chunk("tc1", "x") + + lines = [] + async for chunk in agui.serialize_stream(gen()): + assert isinstance(chunk, bytes) + assert chunk.endswith(b"\n") + lines.append(chunk.decode("utf-8").rstrip("\n")) + assert len(lines) == 3 + assert json.loads(lines[0]) == {"type": "text", "delta": "hello "} + assert json.loads(lines[1]) == {"type": "text", "delta": "world"} + assert json.loads(lines[2])["type"] == "tool_call_start" + + @pytest.mark.asyncio + async def test_bare_string_wraps_as_text_chunk(self): + async def gen(): + yield "ergonomic" + + out = [] + async for chunk in agui.serialize_stream(gen()): + out.append(json.loads(chunk)) + assert out == [{"type": "text", "delta": "ergonomic"}] + + @pytest.mark.asyncio + async def test_invalid_yield_raises_typeerror(self): + async def gen(): + yield 42 # not str / not dict + + with pytest.raises(TypeError): + async for _ in agui.serialize_stream(gen()): + pass + + 
+class TestHarnessRelay: + """Coverage for relay_harness_stream — the bridge that turns a Claude + Agent SDK / harness async iterator into AG-UI streaming chunks.""" + + @pytest.mark.asyncio + async def test_assistant_text_block_becomes_text_chunk(self): + async def fake_harness(): + yield { + "type": "assistant", + "message": {"content": [{"type": "text", "text": "Hello!"}]}, + } + yield {"type": "result", "subtype": "success", "result": "Hello!"} + + chunks = [c async for c in agui.relay_harness_stream(fake_harness())] + # result message yields nothing; only the text chunk survives. + assert chunks == [{"type": "text", "delta": "Hello!"}] + + @pytest.mark.asyncio + async def test_thinking_block_becomes_reasoning_chunk(self): + async def fake(): + yield { + "type": "assistant", + "message": {"content": [ + {"type": "thinking", "thinking": "Let me think..."}, + {"type": "text", "text": "Done."}, + ]}, + } + + chunks = [c async for c in agui.relay_harness_stream(fake())] + assert chunks[0] == {"type": "reasoning", "delta": "Let me think..."} + assert chunks[1] == {"type": "text", "delta": "Done."} + + @pytest.mark.asyncio + async def test_tool_use_emits_start_and_end(self): + async def fake(): + yield { + "type": "assistant", + "message": {"content": [{ + "type": "tool_use", + "id": "tu-1", + "name": "get_weather", + "input": {"city": "SF"}, + }]}, + } + + chunks = [c async for c in agui.relay_harness_stream(fake())] + assert chunks[0]["type"] == "tool_call_start" + assert chunks[0]["id"] == "tu-1" + assert chunks[0]["name"] == "get_weather" + assert chunks[0]["arguments"] == {"city": "SF"} + assert chunks[1] == {"type": "tool_call_end", "id": "tu-1"} + + @pytest.mark.asyncio + async def test_tool_result_emits_result_chunk(self): + async def fake(): + yield { + "type": "user", + "message": {"content": [{ + "type": "tool_result", + "tool_use_id": "tu-1", + "content": "62°F, foggy", + }]}, + } + + chunks = [c async for c in agui.relay_harness_stream(fake())] + assert 
chunks[0]["type"] == "tool_call_result" + assert chunks[0]["id"] == "tu-1" + assert chunks[0]["content"] == "62°F, foggy" + + @pytest.mark.asyncio + async def test_tool_result_with_block_list_stitches_text(self): + async def fake(): + yield { + "type": "user", + "message": {"content": [{ + "type": "tool_result", + "tool_use_id": "tu-1", + "content": [ + {"type": "text", "text": "part 1 "}, + {"type": "text", "text": "part 2"}, + ], + }]}, + } + + chunks = [c async for c in agui.relay_harness_stream(fake())] + assert chunks[0]["content"] == "part 1 part 2" + + @pytest.mark.asyncio + async def test_unknown_block_falls_back_to_raw(self): + async def fake(): + yield { + "type": "assistant", + "message": {"content": [{"type": "weird-thing", "data": 42}]}, + } + + chunks = [c async for c in agui.relay_harness_stream(fake())] + assert chunks[0]["type"] == "raw" + assert chunks[0]["source"] == "harness" + + @pytest.mark.asyncio + async def test_unknown_message_type_becomes_raw(self): + async def fake(): + yield {"type": "system", "info": "starting"} + yield {"type": "totally_unknown", "x": 1} + + chunks = [c async for c in agui.relay_harness_stream(fake())] + assert all(c["type"] == "raw" for c in chunks) + + @pytest.mark.asyncio + async def test_result_message_yields_nothing(self): + async def fake(): + yield {"type": "result", "subtype": "success", "result": "done"} + + chunks = [c async for c in agui.relay_harness_stream(fake())] + assert chunks == [] + + +class TestStreamingFastAPIRoundTrip: + """End-to-end: a FastAPI app using StreamingResponse + serialize_stream + must produce exactly the NDJSON bytes the control plane's streaming + dispatcher consumes. 
This is the SDK-side test of the wire contract.""" + + @pytest.mark.asyncio + async def test_streaming_endpoint_returns_ndjson(self): + from fastapi import FastAPI + from fastapi.responses import StreamingResponse + from httpx import ASGITransport, AsyncClient + + app = FastAPI() + + async def chunks(): + yield agui.reasoning_chunk("checking flights...") + yield agui.text_chunk("Booked ") + yield agui.text_chunk("AA-12.") + yield agui.tool_call_start_chunk("tc1", "showFlightCard", arguments={"flight": "AA-12"}) + yield agui.tool_call_end_chunk("tc1") + yield agui.state_chunk({"counter": 1}) + + @app.post("/reasoners/chat") + async def chat(): + return StreamingResponse( + agui.serialize_stream(chunks()), + media_type=agui.STREAMING_CONTENT_TYPE, + ) + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + resp = await client.post("/reasoners/chat") + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("application/x-ndjson") + lines = [line for line in resp.text.split("\n") if line] + assert len(lines) == 6 + decoded = [json.loads(line) for line in lines] + assert decoded[0]["type"] == "reasoning" + assert decoded[1]["type"] == "text" + assert decoded[1]["delta"] == "Booked " + assert decoded[3]["type"] == "tool_call_start" + assert decoded[3]["arguments"] == {"flight": "AA-12"} + assert decoded[5]["type"] == "state" diff --git a/sdk/typescript/src/agui/index.ts b/sdk/typescript/src/agui/index.ts new file mode 100644 index 00000000..be502835 --- /dev/null +++ b/sdk/typescript/src/agui/index.ts @@ -0,0 +1,457 @@ +/** + * AG-UI protocol helpers for AgentField TypeScript reasoners. + * + * Mirrors `sdk/python/agentfield/agui.py` 1:1 so a Node-side reasoner + * has the same authoring surface as a Python one. + * + * Two ways to use this module: + * + * 1. 
**Buffered mode** — return a normal JSON response from a reasoner
+ *    with the optional `toolCalls` / `state` / `stateDelta` /
+ *    `reasoning` fields. The control plane translates those into
+ *    AG-UI `TOOL_CALL_*` / `STATE_*` / `REASONING_*` events.
+ *
+ * 2. **Streaming mode** — return `Content-Type: application/x-ndjson`
+ *    and stream chunks built with `textChunk()`, `reasoningChunk()`,
+ *    `toolCallStartChunk()`, etc. Each chunk becomes one live AG-UI
+ *    event (see `internal/handlers/agui_runs_streaming.go`).
+ *
+ * Reasoners reach the AG-UI / CopilotKit frontend via the control
+ * plane's `POST /api/v1/agui/runs//` adapter.
+ */
+import type { ToolCallTrace, ToolCallRecord } from '../ai/ToolCalling.js';
+
+export const STREAMING_CONTENT_TYPE = 'application/x-ndjson';
+
+/** A single AG-UI tool-call entry (buffered-mode `toolCalls` array). */
+export interface ToolCallEntry {
+  id?: string;
+  name: string;
+  arguments: Record<string, unknown>;
+  result?: unknown;
+}
+
+/** RFC 6902 patch op. */
+export interface JsonPatchOp {
+  op: 'replace' | 'add' | 'remove';
+  path: string;
+  value?: unknown;
+}
+
+/** A reasoning segment for buffered REASONING_* emission. */
+export interface ReasoningSegment {
+  content: string;
+  id?: string;
+}
+
+// ---------------------------------------------------------------------------
+// Buffered-mode helpers
+// ---------------------------------------------------------------------------
+
+/**
+ * Build a single AG-UI tool-call entry. The control plane translates each
+ * entry into a `TOOL_CALL_START` / `TOOL_CALL_ARGS` / `TOOL_CALL_END`
+ * triad. When `result` is set (or `hasResult` is true), it also emits
+ * `TOOL_CALL_RESULT` so a server-side trace renders in the UI.
+ *
+ * @param name Tool name. CopilotKit pattern-matches this against
+ *   `useCopilotAction({name, render})` registrations.
+ * @param args JSON-serializable arguments mapping.
+ * @param opts.id Optional stable ID. If omitted, the control plane
+ *   synthesizes one (works for one-shots; breaks correlation with
+ *   follow-up tool messages).
+ * @param opts.result Optional pre-executed result.
+ * @param opts.hasResult Force `result: undefined` to be treated as an
+ *   explicit "executed and returned null" instead of "not executed yet".
+ *   NOTE(review): `JSON.stringify` omits `undefined`-valued properties, so
+ *   this marker cannot survive serialization to the wire — the Python SDK
+ *   emits `null` for this case. Confirm whether the assignment below should
+ *   be `opts.result ?? null` (would also require updating the vitest
+ *   expectation for this case).
+ */
+export function toolCall(
+  name: string,
+  args?: Record<string, unknown>,
+  opts: { id?: string; result?: unknown; hasResult?: boolean } = {},
+): ToolCallEntry {
+  const entry: ToolCallEntry = { name, arguments: { ...(args ?? {}) } };
+  if (opts.id !== undefined) entry.id = opts.id;
+  if (opts.result !== undefined || opts.hasResult) entry.result = opts.result;
+  return entry;
+}
+
+/**
+ * Convert a `ToolCallTrace` from `ctx.aiWithTools(...)` into the AG-UI
+ * `toolCalls` list shape.
+ */
+export function toolCallsFromTrace(trace: ToolCallTrace | null | undefined): ToolCallEntry[] {
+  if (!trace || !trace.calls?.length) return [];
+  return trace.calls.map((rec, i) => recordToEntry(rec, i));
+}
+
+// Convert one trace record; an error wins over a result so the UI renders
+// the failure rather than a stale/partial value.
+function recordToEntry(rec: ToolCallRecord, index: number): ToolCallEntry {
+  const entry: ToolCallEntry = {
+    id: `tc-trace-${index}`,
+    name: rec.toolName,
+    arguments: { ...(rec.arguments ?? {}) },
+  };
+  if (rec.error !== undefined && rec.error !== null) {
+    entry.result = { error: rec.error };
+  } else if (rec.result !== undefined && rec.result !== null) {
+    entry.result = rec.result;
+  }
+  return entry;
+}
+
+/** Build a single RFC 6902 `replace` patch op for a `stateDelta` array. */
+export function stateDeltaReplace(path: string, value: unknown): JsonPatchOp {
+  if (!path.startsWith('/')) {
+    throw new Error("RFC 6902 paths must start with '/'");
+  }
+  return { op: 'replace', path, value };
+}
+
+/**
+ * Compute a minimal RFC 6902 patch list for top-level keys that differ
+ * between `before` and `after`. Shallow only.
+ */
+export function stateDeltaFromDiff(
+  before: Record<string, unknown>,
+  after: Record<string, unknown>,
+): JsonPatchOp[] {
+  const ops: JsonPatchOp[] = [];
+  const keys = new Set([...Object.keys(before), ...Object.keys(after)]);
+  // Sorted for deterministic op order across runs.
+  for (const k of [...keys].sort()) {
+    const path = `/${k}`;
+    const inBefore = k in before;
+    const inAfter = k in after;
+    if (inBefore && inAfter) {
+      if (!deepEqual(before[k], after[k])) ops.push({ op: 'replace', path, value: after[k] });
+    } else if (inAfter) {
+      ops.push({ op: 'add', path, value: after[k] });
+    } else {
+      ops.push({ op: 'remove', path });
+    }
+  }
+  return ops;
+}
+
+function deepEqual(a: unknown, b: unknown): boolean {
+  if (a === b) return true;
+  if (typeof a !== typeof b) return false;
+  if (a === null || b === null) return a === b;
+  if (typeof a !== 'object') return false;
+  // JSON-serializable comparison is sufficient for shallow patch ops.
+  // NOTE(review): JSON.stringify is key-order-sensitive, so logically equal
+  // objects built in different key order compare unequal and produce a
+  // spurious `replace` op — harmless for UI state snapshots, but confirm.
+  try {
+    return JSON.stringify(a) === JSON.stringify(b);
+  } catch {
+    return false;
+  }
+}
+
+/**
+ * Build a single REASONING_MESSAGE segment. Each segment becomes a
+ * `REASONING_MESSAGE_START` / `_CONTENT` / `_END` triad inside a
+ * `REASONING_START` / `_END` boundary.
+ */
+export function reasoningSegment(content: string, opts: { id?: string } = {}): ReasoningSegment {
+  const out: ReasoningSegment = { content };
+  if (opts.id !== undefined) out.id = opts.id;
+  return out;
+}
+
+/**
+ * Build a `reasoning` field value from a mix of strings and segments.
+ *
+ * @example
+ * return { result: text, reasoning: agui.reasoning('step 1', 'step 2') };
+ */
+export function reasoning(
+  ...segments: Array<string | ReasoningSegment>
+): Array<string | ReasoningSegment> {
+  const out: Array<string | ReasoningSegment> = [];
+  for (const s of segments) {
+    if (typeof s === 'string') {
+      if (s) out.push(s);
+    } else if (s && typeof s === 'object' && typeof s.content === 'string') {
+      out.push({ ...s });
+    } else {
+      throw new TypeError(`reasoning() segments must be string or {content,id?}; got ${typeof s}`);
+    }
+  }
+  return out;
+}
+
+// ---------------------------------------------------------------------------
+// Streaming chunk builders
+//
+// Each function returns a small object in the wire shape the control plane's
+// streaming dispatcher consumes (see internal/handlers/agui_runs_streaming.go).
+// The reasoner author yields these from an async generator; serializeStream
+// turns each yield into one NDJSON line for the streaming response.
+// ---------------------------------------------------------------------------
+
+export type StreamingChunk = Record<string, unknown> & { type: string };
+
+/** One chunk of assistant text. Concatenated client-side. */
+export function textChunk(delta: string): StreamingChunk {
+  return { type: 'text', delta };
+}
+
+/** One chunk of chain-of-thought, rendered in CopilotKit's "Thinking…" pane. */
+export function reasoningChunk(delta: string): StreamingChunk {
+  return { type: 'reasoning', delta };
+}
+
+/** Closes the current reasoning segment so the next reasoningChunk opens a fresh one. */
+export function reasoningEndChunk(): StreamingChunk {
+  return { type: 'reasoning_end' };
+}
+
+/**
+ * Open a tool call. If you already have the full `arguments`, pass them
+ * here and the dispatcher emits one `TOOL_CALL_ARGS` frame immediately;
+ * otherwise stream them with `toolCallArgsChunk`.
+ */
+export function toolCallStartChunk(
+  id: string,
+  name: string,
+  opts: { arguments?: Record<string, unknown>; parentMessageId?: string } = {},
+): StreamingChunk {
+  const out: StreamingChunk = { type: 'tool_call_start', id, name };
+  if (opts.arguments !== undefined) out.arguments = { ...opts.arguments };
+  if (opts.parentMessageId !== undefined) out.parentMessageId = opts.parentMessageId;
+  return out;
+}
+
+/** One chunk of streaming tool-call arguments JSON. */
+export function toolCallArgsChunk(id: string, delta: string): StreamingChunk {
+  return { type: 'tool_call_args', id, delta };
+}
+
+/** Close a tool call. */
+export function toolCallEndChunk(id: string): StreamingChunk {
+  return { type: 'tool_call_end', id };
+}
+
+/** Server-side tool result — use after pre-executing the tool. */
+export function toolCallResultChunk(
+  id: string,
+  content: string,
+  opts: { role?: string } = {},
+): StreamingChunk {
+  return { type: 'tool_call_result', id, content, role: opts.role ?? 'tool' };
+}
+
+/** Full agent state snapshot (the value `useCoAgent({state})` reads). */
+export function stateChunk(snapshot: unknown): StreamingChunk {
+  return { type: 'state', snapshot };
+}
+
+/** RFC 6902 patch ops applied incrementally on top of the last snapshot. */
+export function stateDeltaChunk(ops: JsonPatchOp[]): StreamingChunk {
+  return { type: 'state_delta', ops: [...ops] };
+}
+
+/** Mark the start of a named step inside the run. */
+export function stepStartedChunk(name: string): StreamingChunk {
+  return { type: 'step_started', name };
+}
+
+/** Mark a step finished. */
+export function stepFinishedChunk(name: string): StreamingChunk {
+  return { type: 'step_finished', name };
+}
+
+/** Pass a foreign-system event through verbatim. */
+export function rawChunk(event: unknown, opts: { source?: string } = {}): StreamingChunk {
+  const out: StreamingChunk = { type: 'raw', event };
+  if (opts.source !== undefined) out.source = opts.source;
+  return out;
+}
+
+/** Application-defined event. Frontends subscribe by `name`. */
+export function customChunk(name: string, value?: unknown): StreamingChunk {
+  const out: StreamingChunk = { type: 'custom', name };
+  if (value !== undefined) out.value = value;
+  return out;
+}
+
+/**
+ * Trailing buffered envelope — the dispatcher applies any
+ * `toolCalls` / `state` / `stateDelta` / `reasoning` / `result` fields
+ * here as if from a non-streaming reasoner.
+ */
+export function finalChunk(data: Record<string, unknown>): StreamingChunk {
+  return { type: 'final', data: { ...data } };
+}
+
+/** Terminal error. The dispatcher emits RUN_ERROR and stops the run. */
+export function errorChunk(message: string, opts: { code?: string } = {}): StreamingChunk {
+  const out: StreamingChunk = { type: 'error', message };
+  if (opts.code !== undefined) out.code = opts.code;
+  return out;
+}
+
+// ---------------------------------------------------------------------------
+// Streaming serialization
+// ---------------------------------------------------------------------------
+
+/**
+ * Serialize an async iterable of chunk objects (or strings — strings are
+ * wrapped as text chunks) into an async iterable of NDJSON-encoded
+ * `Uint8Array`, suitable for any Node streaming response (Express,
+ * Fastify, Hono, the built-in `http` module, or a Web `Response`
+ * built from a `ReadableStream`).
+ *
+ * Express:
+ *
+ *   res.setHeader('Content-Type', agui.STREAMING_CONTENT_TYPE);
+ *   for await (const buf of agui.serializeStream(chunks)) res.write(buf);
+ *   res.end();
+ *
+ * Web `Response` (works in Node 20+, Hono, edge runtimes):
+ *
+ *   const body = new ReadableStream({
+ *     async start(controller) {
+ *       for await (const buf of agui.serializeStream(chunks)) controller.enqueue(buf);
+ *       controller.close();
+ *     }
+ *   });
+ *   return new Response(body, { headers: { 'Content-Type': agui.STREAMING_CONTENT_TYPE }});
+ *
+ * Bare strings yielded by the generator are auto-wrapped as text chunks
+ * for ergonomics. Anything else must be a chunk object produced by one
+ * of the chunk builders above (or a hand-rolled equivalent).
+ */
+export async function* serializeStream(
+  source: AsyncIterable<StreamingChunk | string> | Iterable<StreamingChunk | string>,
+): AsyncIterable<Uint8Array> {
+  const encoder = new TextEncoder();
+  // for-await accepts sync iterables too, so one loop covers both shapes.
+  for await (const item of source as AsyncIterable<StreamingChunk | string>) {
+    let payload: StreamingChunk;
+    if (typeof item === 'string') {
+      payload = textChunk(item);
+    } else if (item && typeof item === 'object') {
+      payload = item;
+    } else {
+      throw new TypeError(
+        `streaming reasoner yielded non-string/non-object value of type ${typeof item}; ` +
+          'use one of the agui chunk builders',
+      );
+    }
+    yield encoder.encode(JSON.stringify(payload) + '\n');
+  }
+}
+
+// ---------------------------------------------------------------------------
+// Harness relay
+// ---------------------------------------------------------------------------
+
+/**
+ * Relay a `@anthropic-ai/claude-agent-sdk` async-iterable of messages
+ * into AG-UI streaming chunks, message-by-message.
+ *
+ * Mirrors `relay_harness_stream` in the Python SDK. Recognized message
+ * shapes (the dict form `HarnessResult.messages` records):
+ *
+ * - `{ type:'assistant', message:{ content:[{type:'text', text:'...'}, ...] }}`
+ *   → one `text` chunk per text block
+ * - `{ type:'assistant', message:{ content:[{type:'thinking', thinking:'...'}, ...] }}`
+ *   → one `reasoning` chunk per thinking block
+ * - `{ type:'assistant', message:{ content:[{type:'tool_use', id:'...', name:'...', input:{...}}, ...] }}`
+ *   → `tool_call_start` + `tool_call_end` per tool_use block
+ * - `{ type:'user', message:{ content:[{type:'tool_result', tool_use_id:'...', content:'...'}, ...] }}`
+ *   → `tool_call_result` per tool_result block
+ * - `{ type:'result', subtype:'success', result:'...' }` →
+ *   terminal — yields nothing (the dispatcher's stream-end logic wraps
+ *   the run with MESSAGES_SNAPSHOT + RUN_FINISHED).
+ * - Anything unrecognized is wrapped as a `raw` chunk so the trace is
+ *   preserved without inventing ad-hoc event types.
+ *
+ * Note: the Claude Agent SDK buffers per-message rather than per-token,
+ * so this path streams at message granularity. True per-token streaming
+ * requires the raw Anthropic streaming API, not the harness.
+ */
+export async function* relayHarnessStream(
+  harnessIter: AsyncIterable<unknown> | Iterable<unknown>,
+): AsyncIterable<StreamingChunk> {
+  for await (const raw of harnessIter as AsyncIterable<unknown>) {
+    let msg: Record<string, unknown>;
+    if (raw && typeof raw === 'object' && !Array.isArray(raw)) {
+      msg = raw as Record<string, unknown>;
+    } else {
+      // Non-object iterate (string, number, null) — keep it in the trace.
+      yield rawChunk({ raw: String(raw) }, { source: 'harness' });
+      continue;
+    }
+
+    const msgType = String(msg.type ?? '');
+    if (msgType === 'result') {
+      // Final aggregated text — dispatcher's stream-end synthesizes
+      // MESSAGES_SNAPSHOT / RUN_FINISHED, so emit nothing here.
+      continue;
+    }
+    if (msgType === 'system') {
+      yield rawChunk(msg, { source: 'harness' });
+      continue;
+    }
+
+    if (msgType === 'assistant' || msgType === 'user') {
+      const content = harnessMessageContent(msg);
+      if (content === undefined || content === null) {
+        yield rawChunk(msg, { source: 'harness' });
+        continue;
+      }
+      if (typeof content === 'string') {
+        // String content from a 'user' message is intentionally dropped.
+        if (msgType === 'assistant' && content) yield textChunk(content);
+        continue;
+      }
+      if (Array.isArray(content)) {
+        for (const block of content) {
+          if (!block || typeof block !== 'object') continue;
+          const b = block as Record<string, unknown>;
+          const btype = b.type;
+          if (btype === 'text') {
+            const text = String(b.text ?? '');
+            if (text) yield textChunk(text);
+          } else if (btype === 'thinking') {
+            const thinking = String(b.thinking ?? '');
+            if (thinking) yield reasoningChunk(thinking);
+          } else if (btype === 'tool_use') {
+            const tcid = String(b.id ?? '');
+            const name = String(b.name ?? '');
+            if (tcid && name) {
+              const inp =
+                b.input && typeof b.input === 'object' && !Array.isArray(b.input)
+                  ? (b.input as Record<string, unknown>)
+                  : {};
+              yield toolCallStartChunk(tcid, name, { arguments: inp });
+              yield toolCallEndChunk(tcid);
+            }
+          } else if (btype === 'tool_result') {
+            const tcid = String(b.tool_use_id ?? '');
+            if (tcid) {
+              let inner = b.content;
+              if (Array.isArray(inner)) {
+                // Stitch text blocks into one string, ignoring non-objects.
+                inner = (inner as unknown[])
+                  .filter((x): x is Record<string, unknown> => !!x && typeof x === 'object')
+                  .map((x) => String(x.text ?? ''))
+                  .join('');
+              } else if (typeof inner !== 'string') {
+                inner = String(inner ?? '');
+              }
+              yield toolCallResultChunk(tcid, inner as string, { role: 'tool' });
+            }
+          } else {
+            yield rawChunk(b, { source: 'harness' });
+          }
+        }
+      }
+      continue;
+    }
+
+    yield rawChunk(msg, { source: 'harness' });
+  }
+}
+
+// Accept both `{content}` directly on the envelope and the nested
+// `{message:{content}}` form the Agent SDK records.
+function harnessMessageContent(msg: Record<string, unknown>): unknown {
+  if ('content' in msg) return msg.content;
+  const inner = msg.message;
+  if (inner && typeof inner === 'object' && !Array.isArray(inner)) {
+    return (inner as Record<string, unknown>).content;
+  }
+  return undefined;
+}
diff --git a/sdk/typescript/src/index.ts b/sdk/typescript/src/index.ts
index c3a00050..a222bdac 100644
--- a/sdk/typescript/src/index.ts
+++ b/sdk/typescript/src/index.ts
@@ -26,3 +26,4 @@ export * from './types/skill.js';
 export * from './harness/index.js';
 export * from './status/ExecutionStatus.js';
 export * from './approval/ApprovalClient.js';
+export * as agui from './agui/index.js';
diff --git a/sdk/typescript/tests/agui.test.ts b/sdk/typescript/tests/agui.test.ts
new file mode 100644
index 00000000..19c57dea
--- /dev/null
+++ b/sdk/typescript/tests/agui.test.ts
@@ -0,0 +1,402 @@
+import { describe, it, expect } from 'vitest';
+import { agui } from '../src/index.js';
+import type { ToolCallTrace } from '../src/ai/ToolCalling.js';
+
+describe('agui — buffered helpers', () => {
+  it('toolCall builds the canonical entry', () => {
+    expect(agui.toolCall('showFlightCard', { from: 'SFO', to: 'JFK' })).toEqual({
+      name: 'showFlightCard',
+      arguments: { from: 'SFO', to: 'JFK' },
+    });
+  });
+
+  it('toolCall handles empty arguments and explicit id', () => {
+    expect(agui.toolCall('ping', undefined, { id: 'tc-1' })).toEqual({
+      id: 'tc-1',
+      name: 'ping',
+      arguments: {},
+    });
+  });
+
+  it('toolCall surfaces a result when provided', () => {
+    expect(agui.toolCall('lookup', { q: 'x' }, { result: { ok: true } })).toEqual({
+      name: 'lookup',
+      arguments: { q: 'x' },
+      result: { ok: true },
+    });
+  });
+
+  it('toolCall hasResult forces a null result through', () => {
+    expect(agui.toolCall('noop',
undefined, { hasResult: true })).toEqual({
+      name: 'noop',
+      arguments: {},
+      result: undefined,
+    });
+  });
+
+  it('toolCallsFromTrace returns [] for empty / null traces', () => {
+    expect(agui.toolCallsFromTrace(null)).toEqual([]);
+    expect(agui.toolCallsFromTrace(undefined)).toEqual([]);
+    const empty: ToolCallTrace = { calls: [], totalTurns: 0, totalToolCalls: 0 };
+    expect(agui.toolCallsFromTrace(empty)).toEqual([]);
+  });
+
+  it('toolCallsFromTrace converts records, surfaces result and error', () => {
+    const trace: ToolCallTrace = {
+      totalTurns: 1,
+      totalToolCalls: 2,
+      calls: [
+        { toolName: 'a', arguments: { x: 1 }, result: { ok: true }, latencyMs: 5, turn: 0 },
+        { toolName: 'b', arguments: {}, error: 'boom', latencyMs: 5, turn: 0 },
+      ],
+    };
+    expect(agui.toolCallsFromTrace(trace)).toEqual([
+      { id: 'tc-trace-0', name: 'a', arguments: { x: 1 }, result: { ok: true } },
+      { id: 'tc-trace-1', name: 'b', arguments: {}, result: { error: 'boom' } },
+    ]);
+  });
+
+  it('stateDeltaReplace emits a JSON Patch op', () => {
+    expect(agui.stateDeltaReplace('/counter', 2)).toEqual({
+      op: 'replace',
+      path: '/counter',
+      value: 2,
+    });
+  });
+
+  it('stateDeltaReplace rejects paths missing a leading slash', () => {
+    expect(() => agui.stateDeltaReplace('counter', 1)).toThrow(/RFC 6902/);
+  });
+
+  it('stateDeltaFromDiff emits a minimal shallow patch', () => {
+    const before = { a: 1, b: 2, c: 3 };
+    const after = { a: 1, b: 99, d: 4 };
+    expect(agui.stateDeltaFromDiff(before, after)).toEqual([
+      { op: 'replace', path: '/b', value: 99 },
+      { op: 'remove', path: '/c' },
+      { op: 'add', path: '/d', value: 4 },
+    ]);
+  });
+
+  it('reasoningSegment + reasoning() build the segment list', () => {
+    const seg = agui.reasoningSegment('thinking', { id: 'r1' });
+    expect(seg).toEqual({ content: 'thinking', id: 'r1' });
+    expect(agui.reasoning('a', '', seg, 'b')).toEqual([
+      'a',
+      { content: 'thinking', id: 'r1' },
+      'b',
+    ]);
+  });
+
+  it('reasoning() rejects garbage segments', () => {
+    expect(() => agui.reasoning(42 as unknown as string)).toThrow(/segments must be string/);
+  });
+});
+
+describe('agui — streaming chunk builders', () => {
+  it('text/reasoning/reasoning_end', () => {
+    expect(agui.textChunk('hi')).toEqual({ type: 'text', delta: 'hi' });
+    expect(agui.reasoningChunk('thinking')).toEqual({ type: 'reasoning', delta: 'thinking' });
+    expect(agui.reasoningEndChunk()).toEqual({ type: 'reasoning_end' });
+  });
+
+  it('toolCallStart with and without args/parent', () => {
+    expect(agui.toolCallStartChunk('tc1', 'foo')).toEqual({
+      type: 'tool_call_start',
+      id: 'tc1',
+      name: 'foo',
+    });
+    expect(
+      agui.toolCallStartChunk('tc2', 'bar', { arguments: { x: 1 }, parentMessageId: 'm1' }),
+    ).toEqual({
+      type: 'tool_call_start',
+      id: 'tc2',
+      name: 'bar',
+      arguments: { x: 1 },
+      parentMessageId: 'm1',
+    });
+  });
+
+  it('toolCallArgs / toolCallEnd / toolCallResult', () => {
+    expect(agui.toolCallArgsChunk('tc1', '{"x":')).toEqual({
+      type: 'tool_call_args',
+      id: 'tc1',
+      delta: '{"x":',
+    });
+    expect(agui.toolCallEndChunk('tc1')).toEqual({ type: 'tool_call_end', id: 'tc1' });
+    expect(agui.toolCallResultChunk('tc1', 'done')).toEqual({
+      type: 'tool_call_result',
+      id: 'tc1',
+      content: 'done',
+      role: 'tool',
+    });
+    expect(agui.toolCallResultChunk('tc1', 'done', { role: 'system' })).toMatchObject({
+      role: 'system',
+    });
+  });
+
+  it('state / state_delta', () => {
+    expect(agui.stateChunk({ k: 1 })).toEqual({ type: 'state', snapshot: { k: 1 } });
+    const ops = [agui.stateDeltaReplace('/k', 2)];
+    expect(agui.stateDeltaChunk(ops)).toEqual({ type: 'state_delta', ops });
+  });
+
+  it('step_started / step_finished', () => {
+    expect(agui.stepStartedChunk('plan')).toEqual({ type: 'step_started', name: 'plan' });
+    expect(agui.stepFinishedChunk('plan')).toEqual({ type: 'step_finished', name: 'plan' });
+  });
+
+  it('raw / custom / final / error chunk shapes', () => {
+    expect(agui.rawChunk({ k: 1 })).toEqual({ type: 'raw', event: { k: 1 } });
+    expect(agui.rawChunk({ k: 1 }, { source: 'harness' })).toEqual({
+      type: 'raw',
+      event: { k: 1 },
+      source: 'harness',
+    });
+    expect(agui.customChunk('progress', 0.5)).toEqual({
+      type: 'custom',
+      name: 'progress',
+      value: 0.5,
+    });
+    expect(agui.customChunk('ping')).toEqual({ type: 'custom', name: 'ping' });
+    expect(agui.finalChunk({ result: 'done' })).toEqual({
+      type: 'final',
+      data: { result: 'done' },
+    });
+    expect(agui.errorChunk('boom')).toEqual({ type: 'error', message: 'boom' });
+    expect(agui.errorChunk('boom', { code: 'E_BOOM' })).toEqual({
+      type: 'error',
+      message: 'boom',
+      code: 'E_BOOM',
+    });
+  });
+});
+
+describe('agui — serializeStream', () => {
+  it('emits one NDJSON line per chunk', async () => {
+    async function* chunks() {
+      yield agui.textChunk('a');
+      yield agui.textChunk('b');
+      yield agui.toolCallEndChunk('tc1');
+    }
+    const decoder = new TextDecoder();
+    const lines: string[] = [];
+    for await (const buf of agui.serializeStream(chunks())) {
+      lines.push(decoder.decode(buf));
+    }
+    expect(lines).toEqual([
+      JSON.stringify({ type: 'text', delta: 'a' }) + '\n',
+      JSON.stringify({ type: 'text', delta: 'b' }) + '\n',
+      JSON.stringify({ type: 'tool_call_end', id: 'tc1' }) + '\n',
+    ]);
+  });
+
+  it('auto-wraps bare strings as text chunks', async () => {
+    async function* gen() {
+      yield 'hello';
+      yield agui.textChunk(' world');
+    }
+    const decoder = new TextDecoder();
+    let combined = '';
+    for await (const buf of agui.serializeStream(gen())) combined += decoder.decode(buf);
+    expect(combined.trim().split('\n').map((l) => JSON.parse(l))).toEqual([
+      { type: 'text', delta: 'hello' },
+      { type: 'text', delta: ' world' },
+    ]);
+  });
+
+  it('rejects non-string non-object values', async () => {
+    async function* gen() {
+      yield 42 as unknown as string;
+    }
+    await expect(async () => {
+      for await (const _ of agui.serializeStream(gen())) {
+        /* drain */
+      }
+    }).rejects.toThrow(/non-string\/non-object/);
+  });
+
+  it('accepts a synchronous iterable too', async () => {
+    const chunks = [agui.textChunk('x'), agui.textChunk('y')];
+    const decoder = new TextDecoder();
+    let n = 0;
+    for await (const buf of agui.serializeStream(chunks)) {
+      const obj = JSON.parse(decoder.decode(buf));
+      expect(obj.type).toBe('text');
+      n++;
+    }
+    expect(n).toBe(2);
+  });
+});
+
+describe('agui — relayHarnessStream', () => {
+  async function* fromArray(items: unknown[]) {
+    for (const x of items) yield x;
+  }
+
+  it('translates assistant text blocks into text chunks', async () => {
+    const chunks = [];
+    for await (const ch of agui.relayHarnessStream(
+      fromArray([
+        {
+          type: 'assistant',
+          message: {
+            content: [
+              { type: 'text', text: 'hello ' },
+              { type: 'text', text: 'world' },
+            ],
+          },
+        },
+      ]),
+    )) {
+      chunks.push(ch);
+    }
+    expect(chunks).toEqual([
+      { type: 'text', delta: 'hello ' },
+      { type: 'text', delta: 'world' },
+    ]);
+  });
+
+  it('translates assistant thinking blocks into reasoning chunks', async () => {
+    const chunks = [];
+    for await (const ch of agui.relayHarnessStream(
+      fromArray([
+        {
+          type: 'assistant',
+          message: { content: [{ type: 'thinking', thinking: 'hmm' }] },
+        },
+      ]),
+    )) {
+      chunks.push(ch);
+    }
+    expect(chunks).toEqual([{ type: 'reasoning', delta: 'hmm' }]);
+  });
+
+  it('translates tool_use blocks into start+end pairs', async () => {
+    const chunks = [];
+    for await (const ch of agui.relayHarnessStream(
+      fromArray([
+        {
+          type: 'assistant',
+          message: {
+            content: [{ type: 'tool_use', id: 'tc1', name: 'lookup', input: { q: 'x' } }],
+          },
+        },
+      ]),
+    )) {
+      chunks.push(ch);
+    }
+    expect(chunks).toEqual([
+      { type: 'tool_call_start', id: 'tc1', name: 'lookup', arguments: { q: 'x' } },
+      { type: 'tool_call_end', id: 'tc1' },
+    ]);
+  });
+
+  it('translates tool_result string content', async () => {
+    const chunks = [];
+    for await (const ch of agui.relayHarnessStream(
+      fromArray([
+        {
+          type: 'user',
+          message: {
+            content: [{ type: 'tool_result', tool_use_id: 'tc1', content: 'ok' }],
+          },
+        },
+      ]),
+    )) {
+      chunks.push(ch);
+    }
+    expect(chunks).toEqual([
+      { type: 'tool_call_result', id: 'tc1', content: 'ok', role: 'tool' },
+    ]);
+  });
+
+  it('translates tool_result list content by stitching text blocks', async () => {
+    const chunks = [];
+    for await (const ch of agui.relayHarnessStream(
+      fromArray([
+        {
+          type: 'user',
+          message: {
+            content: [
+              {
+                type: 'tool_result',
+                tool_use_id: 'tc1',
+                content: [
+                  { type: 'text', text: 'a' },
+                  { type: 'text', text: 'b' },
+                ],
+              },
+            ],
+          },
+        },
+      ]),
+    )) {
+      chunks.push(ch);
+    }
+    expect(chunks).toEqual([
+      { type: 'tool_call_result', id: 'tc1', content: 'ab', role: 'tool' },
+    ]);
+  });
+
+  it('skips terminal result envelope and surfaces system as raw', async () => {
+    const chunks = [];
+    for await (const ch of agui.relayHarnessStream(
+      fromArray([
+        { type: 'system', subtype: 'init' },
+        { type: 'result', subtype: 'success', result: 'done' },
+      ]),
+    )) {
+      chunks.push(ch);
+    }
+    expect(chunks).toEqual([
+      { type: 'raw', event: { type: 'system', subtype: 'init' }, source: 'harness' },
+    ]);
+  });
+
+  it('preserves unknown blocks and unknown top-level messages as raw', async () => {
+    const chunks = [];
+    for await (const ch of agui.relayHarnessStream(
+      fromArray([
+        { type: 'assistant', message: { content: [{ type: 'mystery', payload: 1 }] } },
+        { type: 'no-such-thing' },
+      ]),
+    )) {
+      chunks.push(ch);
+    }
+    expect(chunks[0]).toMatchObject({ type: 'raw', source: 'harness' });
+    expect((chunks[0] as { event: { type: string } }).event.type).toBe('mystery');
+    expect(chunks[1]).toMatchObject({ type: 'raw', source: 'harness' });
+  });
+
+  it('handles bare content and string content shapes', async () => {
+    const chunks = [];
+    for await (const ch of agui.relayHarnessStream(
+      fromArray([
+        { type: 'assistant', content: 'inline-string' },
+        { type: 'assistant', content: [{ type: 'text', text: 'inline-list' }] },
+      ]),
+    )) {
+      chunks.push(ch);
+    }
+    expect(chunks).toEqual([
+      { type: 'text', delta: 'inline-string' },
+      { type: 'text', delta: 'inline-list' },
+    ]);
+  });
+
+  it('wraps non-object iterates as raw', async () => {
+    const chunks = [];
+    for await (const ch of agui.relayHarnessStream(fromArray(['scalar', 7, null]))) {
+      chunks.push(ch);
+    }
+    expect(chunks.every((c) => (c as { type: string }).type === 'raw')).toBe(true);
+    expect(chunks).toHaveLength(3);
+  });
+});
+
+describe('agui — STREAMING_CONTENT_TYPE', () => {
+  it('matches the wire constant from the Python and Go SDKs', () => {
+    expect(agui.STREAMING_CONTENT_TYPE).toBe('application/x-ndjson');
+  });
+});