diff --git a/cagent-schema.json b/cagent-schema.json index 677be8384..e7afbf9ce 100644 --- a/cagent-schema.json +++ b/cagent-schema.json @@ -593,7 +593,8 @@ "api", "a2a", "lsp", - "user_prompt" + "user_prompt", + "switch_model" ] }, "instruction": { @@ -700,6 +701,17 @@ "description": "Timeout in seconds for the fetch tool", "minimum": 1 }, + "models": { + "type": "array", + "description": "List of allowed model references for the switch_model tool. If not specified, all models defined in the config are available.", + "items": { + "type": "string" + }, + "examples": [ + ["fast_model", "powerful_model"], + ["openai/gpt-4o-mini", "anthropic/claude-sonnet-4-0"] + ] + }, "url": { "type": "string", "description": "URL for the a2a tool", @@ -757,7 +769,8 @@ "memory", "script", "fetch", - "user_prompt" + "user_prompt", + "switch_model" ] } } diff --git a/e2e/switch_model_test.go b/e2e/switch_model_test.go new file mode 100644 index 000000000..0d86b4d07 --- /dev/null +++ b/e2e/switch_model_test.go @@ -0,0 +1,110 @@ +package e2e_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/agent" + "github.com/docker/cagent/pkg/chat" + "github.com/docker/cagent/pkg/config" + "github.com/docker/cagent/pkg/runtime" + "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/teamloader" +) + +// setupSwitchModelTest creates a runtime with model switching support. +func setupSwitchModelTest(t *testing.T) (runtime.Runtime, *agent.Agent) { + t.Helper() + + ctx := t.Context() + agentSource, err := config.Resolve("testdata/switch_model.yaml") + require.NoError(t, err) + + _, runConfig := startRecordingAIProxy(t) + loadResult, err := teamloader.LoadWithConfig(ctx, agentSource, runConfig) + require.NoError(t, err) + + modelSwitcherCfg := &runtime.ModelSwitcherConfig{ + Models: loadResult.Models, + Providers: loadResult.Providers, + ModelsGateway: runConfig.ModelsGateway, + EnvProvider: runConfig.EnvProvider(), + AgentDefaultModels: loadResult.AgentDefaultModels, + } + + rt, err := runtime.New(loadResult.Team, runtime.WithModelSwitcherConfig(modelSwitcherCfg)) + require.NoError(t, err) + + rootAgent, err := loadResult.Team.Agent("root") + require.NoError(t, err) + + return rt, rootAgent +} + +// findSwitchModelCall searches session messages for a switch_model tool call containing the given model name. +func findSwitchModelCall(sess *session.Session, modelName string) bool { + for _, msg := range sess.GetAllMessages() { + if msg.Message.Role != chat.MessageRoleAssistant || msg.Message.ToolCalls == nil { + continue + } + for _, tc := range msg.Message.ToolCalls { + if tc.Function.Name == "switch_model" && strings.Contains(tc.Function.Arguments, modelName) { + return true + } + } + } + return false +} + +// TestSwitchModel_AgentCanSwitchModels verifies that an agent can use the switch_model tool +// to change between models during a conversation. 
+func TestSwitchModel_AgentCanSwitchModels(t *testing.T) { + t.Parallel() + + ctx := t.Context() + rt, _ := setupSwitchModelTest(t) + + // Switch to smart model + sess := session.New(session.WithUserMessage("Switch to the smart model, then say hi")) + _, err := rt.Run(ctx, sess) + require.NoError(t, err) + + assert.True(t, findSwitchModelCall(sess, "smart"), "Expected switch_model tool call with 'smart' model") + assert.NotEmpty(t, sess.GetLastAssistantMessageContent(), "Expected a response after switching") + + // Switch back to fast model + sess.AddMessage(session.UserMessage("Now switch back to the fast model and say goodbye")) + _, err = rt.Run(ctx, sess) + require.NoError(t, err) + + assert.True(t, findSwitchModelCall(sess, "fast"), "Expected switch_model tool call with 'fast' model") + assert.NotEmpty(t, sess.GetLastAssistantMessageContent(), "Expected a response after switching back") +} + +// TestSwitchModel_ModelActuallyChanges verifies that after calling switch_model, +// the agent's model object is updated to the new model. +func TestSwitchModel_ModelActuallyChanges(t *testing.T) { + t.Parallel() + + ctx := t.Context() + rt, rootAgent := setupSwitchModelTest(t) + + assert.Contains(t, rootAgent.Model().ID(), "gpt-4o-mini", "Should start with gpt-4o-mini") + + // Switch to smart model + sess := session.New(session.WithUserMessage("Use the switch_model tool to switch to smart model, then just say 'done'")) + _, err := rt.Run(ctx, sess) + require.NoError(t, err) + + assert.Contains(t, rootAgent.Model().ID(), "claude", "Model should have changed to claude") + + // Verify the new model works + sess.AddMessage(session.UserMessage("What is 2+2? Answer with just the number.")) + _, err = rt.Run(ctx, sess) + require.NoError(t, err) + + assert.NotEmpty(t, sess.GetLastAssistantMessageContent()) +} diff --git a/e2e/testdata/cassettes/TestSwitchModel_AgentCanSwitchModels.yaml b/e2e/testdata/cassettes/TestSwitchModel_AgentCanSwitchModels.yaml new file mode 100644 index 000000000..1eb30dc59 --- /dev/null +++ b/e2e/testdata/cassettes/TestSwitchModel_AgentCanSwitchModels.yaml @@ -0,0 +1,190 @@ +--- +version: 2 +interactions: + - id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.openai.com + body: '{"messages":[{"content":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).\n","role":"system"},{"content":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q&A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is 
complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","role":"system"},{"content":"Switch to the smart model, then say hi","role":"user"}],"model":"gpt-4o-mini","max_tokens":16384,"parallel_tool_calls":true,"stream_options":{"include_usage":true},"tools":[{"function":{"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default) (current)\n- smart\n\nOnly the models listed above can be selected. Any other model will be rejected.","parameters":{"additionalProperties":false,"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"}},"type":"function"}],"stream":true}' + url: https://api.openai.com/v1/chat/completions + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","type":"function","function":{"name":"switch_model","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"47P"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"Brogprarad7mtQ"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"model"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"HyE8RFFAvAHg"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"fVnOy0tVTM1B"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"smart"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"OcEq07Zj7pQR"} + + data: 
{"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"8mxN23j9adGW6p"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null,"obfuscation":"Kk3toyFyDAr59xO"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[],"usage":{"prompt_tokens":424,"completion_tokens":14,"total_tokens":438,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"obfuscation":"hlCnskHF"} + + data: [DONE] + + headers: {} + status: 200 OK + code: 200 + duration: 925.604834ms + - id: 1 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.anthropic.com + body: '{"max_tokens":64000,"messages":[{"content":[{"text":"Switch to the smart model, then say hi","type":"text"}],"role":"user"},{"content":[{"id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","input":{"model":"smart"},"name":"switch_model","cache_control":{"type":"ephemeral"},"type":"tool_use"}],"role":"assistant"},{"content":[{"tool_use_id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","is_error":false,"cache_control":{"type":"ephemeral"},"content":[{"text":"Switched model from \"fast\" to \"smart\".","type":"text"}],"type":"tool_result"}],"role":"user"}],"model":"claude-sonnet-4-20250514","system":[{"text":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).","type":"text"},{"text":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q\u0026A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. 
Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","cache_control":{"type":"ephemeral"},"type":"text"}],"tools":[{"input_schema":{"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"},"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default)\n- smart (current)\n\nOnly the models listed above can be selected. Any other model will be rejected."}],"stream":true}' + url: https://api.anthropic.com/v1/messages + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + event: message_start + data: {"type":"message_start","message":{"model":"claude-sonnet-4-20250514","id":"msg_01MD9TRKxyqCVZAZ7ensFgbo","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":864,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":2,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + + event: ping + data: {"type": "ping"} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Switched"}} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to smart"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" model -"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" hi"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" there"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"!"} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":864,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":12} } + + event: message_stop + data: {"type":"message_stop" } + + headers: {} + status: 200 OK + code: 200 + duration: 2.564459334s + - id: 2 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.anthropic.com + body: '{"max_tokens":64000,"messages":[{"content":[{"text":"Switch to the smart model, then say hi","type":"text"}],"role":"user"},{"content":[{"id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","input":{"model":"smart"},"name":"switch_model","type":"tool_use"}],"role":"assistant"},{"content":[{"tool_use_id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","is_error":false,"content":[{"text":"Switched model from \"fast\" to \"smart\".","type":"text"}],"type":"tool_result"}],"role":"user"},{"content":[{"text":"Switched to smart model - hi there!","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"assistant"},{"content":[{"text":"Now switch back to the fast model and say goodbye","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"user"}],"model":"claude-sonnet-4-20250514","system":[{"text":"You are a helpful 
assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).","type":"text"},{"text":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q\u0026A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","cache_control":{"type":"ephemeral"},"type":"text"}],"tools":[{"input_schema":{"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"},"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default)\n- smart (current)\n\nOnly the models listed above can be selected. 
Any other model will be rejected."}],"stream":true}' + url: https://api.anthropic.com/v1/messages + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + event: message_start + data: {"type":"message_start","message":{"model":"claude-sonnet-4-20250514","id":"msg_01GwhF1NoMHrKCtnok3BqVmV","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":889,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":8,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_019FLZJHzXX2txzzYkeYERu3","name":"switch_model","input":{}} } + + event: ping + data: {"type": "ping"} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"model\": \"fast"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"\"}"} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":889,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":53} } + + event: message_stop + data: {"type":"message_stop" } + + headers: {} + status: 200 OK + code: 200 + duration: 1.50779325s + - id: 3 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.openai.com + body: '{"messages":[{"content":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).\n","role":"system"},{"content":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q&A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. 
When in doubt about task complexity, prefer the more capable model","role":"system"},{"content":"Switch to the smart model, then say hi","role":"user"},{"tool_calls":[{"id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","function":{"arguments":"{\"model\":\"smart\"}","name":"switch_model"},"type":"function"}],"role":"assistant"},{"content":"Switched model from \"fast\" to \"smart\".","tool_call_id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","role":"tool"},{"content":"Switched to smart model - hi there!","role":"assistant"},{"content":"Now switch back to the fast model and say goodbye","role":"user"},{"tool_calls":[{"id":"toolu_019FLZJHzXX2txzzYkeYERu3","function":{"arguments":"{\"model\": \"fast\"}","name":"switch_model"},"type":"function"}],"role":"assistant"},{"content":"Switched model from \"smart\" to \"fast\".","tool_call_id":"toolu_019FLZJHzXX2txzzYkeYERu3","role":"tool"}],"model":"gpt-4o-mini","max_tokens":16384,"parallel_tool_calls":true,"stream_options":{"include_usage":true},"tools":[{"function":{"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default) (current)\n- smart\n\nOnly the models listed above can be selected. Any other model will be rejected.","parameters":{"additionalProperties":false,"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"}},"type":"function"}],"stream":true}' + url: https://api.openai.com/v1/chat/completions + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"EBiZHklkg"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":"Sw"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"glrvUYASg"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":"itched"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"xLW9V"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" back"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"tLdkd3"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"C1BtunrF"} + + data: 
{"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" fast"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"eqHrjd"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" model"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"hKBvv"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" -"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"wUJj0lES4"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" goodbye"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"Saf"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"ZA7kxYZV20"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null,"obfuscation":"rwnJA"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[],"usage":{"prompt_tokens":517,"completion_tokens":10,"total_tokens":527,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"obfuscation":"xeWAz0J3"} + + data: [DONE] + + headers: {} + status: 200 OK + code: 200 + duration: 825.417875ms diff --git a/e2e/testdata/cassettes/TestSwitchModel_ModelActuallyChanges.yaml b/e2e/testdata/cassettes/TestSwitchModel_ModelActuallyChanges.yaml new file mode 100644 index 000000000..77d8210a2 --- /dev/null +++ b/e2e/testdata/cassettes/TestSwitchModel_ModelActuallyChanges.yaml @@ -0,0 +1,126 @@ +--- +version: 2 +interactions: + - id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.openai.com + body: '{"messages":[{"content":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).\n","role":"system"},{"content":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between 
them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q&A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","role":"system"},{"content":"Use the switch_model tool to switch to smart model, then just say ''done''","role":"user"}],"model":"gpt-4o-mini","max_tokens":16384,"parallel_tool_calls":true,"stream_options":{"include_usage":true},"tools":[{"function":{"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default) (current)\n- smart\n\nOnly the models listed above can be selected. Any other model will be rejected.","parameters":{"additionalProperties":false,"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"}},"type":"function"}],"stream":true}' + url: https://api.openai.com/v1/chat/completions + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_f4epGCysqn5sr2YvsxOSbSgd","type":"function","function":{"name":"switch_model","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"PSx"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"fzow185TGFrgOP"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"model"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"DdF9c1FUwkdS"} + + data: 
{"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"9sVOZKCWUxmi"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"smart"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"HqItFmdno1nY"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"ccabCw5RGwhlbI"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null,"obfuscation":"5XcPTFKUh1OVXGN"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[],"usage":{"prompt_tokens":432,"completion_tokens":14,"total_tokens":446,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"obfuscation":"Xr45WNpE"} + + data: [DONE] + + headers: {} + status: 200 OK + code: 200 + duration: 977.696458ms + - id: 1 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.anthropic.com + body: '{"max_tokens":64000,"messages":[{"content":[{"text":"Use the switch_model tool to switch to smart model, then just say ''done''","type":"text"}],"role":"user"},{"content":[{"id":"call_f4epGCysqn5sr2YvsxOSbSgd","input":{"model":"smart"},"name":"switch_model","cache_control":{"type":"ephemeral"},"type":"tool_use"}],"role":"assistant"},{"content":[{"tool_use_id":"call_f4epGCysqn5sr2YvsxOSbSgd","is_error":false,"cache_control":{"type":"ephemeral"},"content":[{"text":"Switched model from \"fast\" to \"smart\".","type":"text"}],"type":"tool_result"}],"role":"user"}],"model":"claude-sonnet-4-20250514","system":[{"text":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).","type":"text"},{"text":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q\u0026A, short summaries)\n- The current task doesn''t require advanced 
reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","cache_control":{"type":"ephemeral"},"type":"text"}],"tools":[{"input_schema":{"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"},"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default)\n- smart (current)\n\nOnly the models listed above can be selected. Any other model will be rejected."}],"stream":true}' + url: https://api.anthropic.com/v1/messages + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + event: message_start + data: {"type":"message_start","message":{"model":"claude-sonnet-4-20250514","id":"msg_01Vg6kQujtxt1Trp3nhDeLoU","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":874,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":1,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + + event: ping + data: {"type": "ping"} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Done"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"."} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":874,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":5} } + + event: message_stop + data: {"type":"message_stop" } + + headers: {} + status: 200 OK + code: 200 + duration: 1.468132875s + - id: 2 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.anthropic.com + body: '{"max_tokens":64000,"messages":[{"content":[{"text":"Use the switch_model tool to switch to smart model, then just say ''done''","type":"text"}],"role":"user"},{"content":[{"id":"call_f4epGCysqn5sr2YvsxOSbSgd","input":{"model":"smart"},"name":"switch_model","type":"tool_use"}],"role":"assistant"},{"content":[{"tool_use_id":"call_f4epGCysqn5sr2YvsxOSbSgd","is_error":false,"content":[{"text":"Switched model from \"fast\" to 
\"smart\".","type":"text"}],"type":"tool_result"}],"role":"user"},{"content":[{"text":"Done.","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"assistant"},{"content":[{"text":"What is 2+2? Answer with just the number.","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"user"}],"model":"claude-sonnet-4-20250514","system":[{"text":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).","type":"text"},{"text":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q\u0026A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","cache_control":{"type":"ephemeral"},"type":"text"}],"tools":[{"input_schema":{"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"},"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default)\n- smart (current)\n\nOnly the models listed above can be selected. 
Any other model will be rejected."}],"stream":true}' + url: https://api.anthropic.com/v1/messages + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + event: message_start + data: {"type":"message_start","message":{"model":"claude-sonnet-4-20250514","id":"msg_017Ko3yzvcdekdx9jBjSoPqY","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":895,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":1,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + + event: ping + data: {"type": "ping"} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"4"} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":895,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":5} } + + event: message_stop + data: {"type":"message_stop" } + + headers: {} + status: 200 OK + code: 200 + duration: 1.430636s diff --git a/e2e/testdata/switch_model.yaml b/e2e/testdata/switch_model.yaml new file mode 100644 index 000000000..20dfa515a --- /dev/null +++ b/e2e/testdata/switch_model.yaml @@ -0,0 +1,23 @@ +version: 4 + +models: + fast: + provider: openai + model: gpt-4o-mini + smart: + provider: anthropic + model: claude-sonnet-4-0 + +agents: + root: + model: fast + description: Agent with model switching capability + instruction: | + You are a helpful assistant that can switch between models. + When asked to switch to a better/smarter model, use the switch_model tool to switch to "smart". + When asked to switch to the default/fast model, use the switch_model tool to switch to "fast". + After switching, respond briefly confirming the switch and then answer any question. + Keep responses very short (one sentence max). + toolsets: + - type: switch_model + models: [fast, smart] diff --git a/examples/gopher.yaml b/examples/gopher.yaml index 8b138cbb3..e30caed45 100755 --- a/examples/gopher.yaml +++ b/examples/gopher.yaml @@ -1,7 +1,7 @@ #!/usr/bin/env cagent run models: - claude: + opus: provider: anthropic model: claude-opus-4-5 haiku: @@ -10,7 +10,7 @@ models: agents: root: - model: claude + model: haiku description: Expert Golang Developer specialized in implementing features and improving code quality. instruction: | **Goal:** @@ -49,8 +49,14 @@ agents: - Be proactive in identifying potential issues - Only ask for clarification if necessary, try your best to use all the tools to get the info you need - **Tools:** - - When needed and possible, call multiple tools concurrently. It's faster and cheaper. + **Tools:** + - When needed and possible, call multiple tools concurrently. It's faster and cheaper. 
+
+    **Model Switching Strategy:**
+    Before each task, you MUST decide which model to use based on task complexity:
+    - For simple tasks (formatting, basic Q&A, summaries), switch to a faster/cheaper model (haiku)
+    - For complex tasks (code generation, analysis, reasoning), use a more powerful model (opus)
+    - After completing a specialized task, consider switching back to the default model (haiku)
 
     add_date: true
     add_environment_info: true
@@ -59,6 +65,14 @@ agents:
     sub_agents:
       - librarian
     toolsets:
+      - type: switch_model
+        models: [haiku, opus]
+        instruction: |
+          {ORIGINAL_INSTRUCTIONS}
+
+          Choose the model based on task complexity:
+          - Use 'haiku' for simple tasks (formatting, basic Q&A, summaries)
+          - Use 'opus' for complex tasks (code generation, analysis, reasoning)
      - type: filesystem
      - type: shell
      - type: todo
@@ -134,7 +148,7 @@ agents:
       - **Tips**: General security best practices relevant to the changes
 
   planner:
-    model: claude
+    model: opus
     instruction: |
       You are a planning agent responsible for gathering user requirements and creating a development plan.
       Always ask clarifying questions to ensure you fully understand the user's needs before creating the plan.
@@ -185,6 +199,7 @@ permissions:
     - go_symbol_references
     - go_vulncheck
     - go_workspace
+    - switch_model
    - shell:cmd=gh --version
    - shell:cmd=gh pr view *
    - shell:cmd=gh pr diff *
diff --git a/examples/switch_model.yaml b/examples/switch_model.yaml
new file mode 100644
index 000000000..4723f9dd5
--- /dev/null
+++ b/examples/switch_model.yaml
@@ -0,0 +1,40 @@
+#!/usr/bin/env cagent run
+
+# This example demonstrates how an agent can dynamically switch between different
+# AI models during a conversation. This is useful for:
+# - Cost optimization: Use cheaper models for simple tasks
+# - Performance: Use more powerful models for complex reasoning
+# - Specialization: Use domain-specific models when appropriate
+
+models:
+  fast:
+    provider: anthropic
+    model: claude-haiku-4-5-20251001
+
+  powerful:
+    provider: anthropic
+    model: claude-sonnet-4-5-20250929
+
+agents:
+  root:
+    model: fast
+    description: An adaptive assistant that can switch models based on task complexity
+    instruction: |
+      You are a helpful assistant with the ability to switch between different AI models.
+
+      Use switch_model strategically:
+      - For simple tasks (formatting, basic Q&A, summaries), use the faster/cheaper model
+      - For complex tasks (code generation, analysis, reasoning), use the more powerful model
+      - The switch_model tool description shows available models and which one is current
+      - After completing a specialized task, consider switching back to the default model
+
+      Example workflow:
+      1. User asks a simple question -> switch to 'fast' for efficiency
+      2. User asks for complex code -> switch to 'powerful' for quality
+      3. 
Task complete -> switch back to 'fast' + + toolsets: + - type: filesystem + - type: shell + - type: switch_model + models: [fast, powerful] diff --git a/pkg/agent/opts.go b/pkg/agent/opts.go index 843370bf9..591226d9f 100644 --- a/pkg/agent/opts.go +++ b/pkg/agent/opts.go @@ -22,7 +22,7 @@ func WithToolSets(toolSet ...tools.ToolSet) Opt { var startableToolSet []*StartableToolSet for _, ts := range toolSet { startableToolSet = append(startableToolSet, &StartableToolSet{ - ToolSet: ts, + ToolSetWrapper: tools.ToolSetWrapper{ToolSet: ts}, }) } @@ -141,7 +141,7 @@ func WithThinkingConfigured(configured bool) Opt { } type StartableToolSet struct { - tools.ToolSet + tools.ToolSetWrapper mu sync.Mutex started bool diff --git a/pkg/agent/opts_test.go b/pkg/agent/opts_test.go index 6d5528b07..07fec1952 100644 --- a/pkg/agent/opts_test.go +++ b/pkg/agent/opts_test.go @@ -28,7 +28,7 @@ func (f *flakyStartToolset) Tools(context.Context) ([]tools.Tool, error) { retur func TestStartableToolSet_RetriesAfterFailure(t *testing.T) { ctx := t.Context() inner := &flakyStartToolset{} - ts := &StartableToolSet{ToolSet: inner} + ts := &StartableToolSet{ToolSetWrapper: tools.ToolSetWrapper{ToolSet: inner}} err := ts.Start(ctx) require.Error(t, err) diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index aab024847..b281361d1 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -269,6 +269,9 @@ type Toolset struct { // For the `fetch` tool Timeout int `json:"timeout,omitempty"` + + // For the `switch_model` tool - list of allowed models + Models []string `json:"models,omitempty"` } func (t *Toolset) UnmarshalYAML(unmarshal func(any) error) error { diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index 752c29987..50a9fb2bb 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -2,6 +2,7 @@ package latest import ( "errors" + "fmt" "strings" ) @@ -56,6 +57,9 @@ func (t *Toolset) validate() error { if t.Shared && t.Type != "todo" { return errors.New("shared can only be used with type 'todo'") } + if len(t.Models) > 0 && t.Type != "switch_model" { + return errors.New("models can only be used with type 'switch_model'") + } if t.Command != "" && t.Type != "mcp" && t.Type != "lsp" { return errors.New("command can only be used with type 'mcp' or 'lsp'") } @@ -86,6 +90,15 @@ func (t *Toolset) validate() error { if t.Sandbox != nil && len(t.Sandbox.Paths) == 0 { return errors.New("sandbox requires at least one path to be set") } + case "switch_model": + if len(t.Models) == 0 { + return errors.New("switch_model toolset requires at least one model") + } + for i, m := range t.Models { + if strings.TrimSpace(m) == "" { + return fmt.Errorf("switch_model toolset: model at index %d is empty", i) + } + } case "memory": if t.Path == "" { return errors.New("memory toolset requires a path to be set") diff --git a/pkg/config/latest/validate_test.go b/pkg/config/latest/validate_test.go index c68d95cf1..e700bff0e 100644 --- a/pkg/config/latest/validate_test.go +++ b/pkg/config/latest/validate_test.go @@ -189,3 +189,121 @@ agents: }) } } + +func TestToolset_Validate_SwitchModel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config string + wantErr string + }{ + { + name: "valid switch_model with models", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model + models: [fast, powerful] +`, + wantErr: "", + }, + { + name: "valid switch_model with single model", + config: ` 
+version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model + models: + - only_one +`, + wantErr: "", + }, + { + name: "switch_model without models", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model +`, + wantErr: "switch_model toolset requires at least one model", + }, + { + name: "switch_model with empty models list", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model + models: [] +`, + wantErr: "switch_model toolset requires at least one model", + }, + { + name: "switch_model with empty string in models", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model + models: [fast, "", powerful] +`, + wantErr: "switch_model toolset: model at index 1 is empty", + }, + { + name: "switch_model with whitespace-only string in models", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model + models: [fast, " ", powerful] +`, + wantErr: "switch_model toolset: model at index 1 is empty", + }, + { + name: "models on non-switch_model toolset", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: shell + models: [fast, powerful] +`, + wantErr: "models can only be used with type 'switch_model'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var cfg Config + err := yaml.Unmarshal([]byte(tt.config), &cfg) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index f84ea553f..04ccdc9a2 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -107,8 +107,9 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st // Check if this is an inline alloy spec (comma-separated provider/model specs) // e.g., "openai/gpt-4o,anthropic/claude-sonnet-4-0" - if isInlineAlloySpec(modelRef) { - providers, err := r.createProvidersFromInlineAlloy(ctx, modelRef) + inlineRefs, isInlineAlloy := parseInlineAlloyRefs(modelRef) + if isInlineAlloy { + providers, err := r.createProvidersFromInlineAlloy(ctx, inlineRefs) if err != nil { return fmt.Errorf("failed to create inline alloy model: %w", err) } @@ -117,16 +118,11 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st return nil } - // Try parsing as inline spec (provider/model) - providerName, modelName, ok := strings.Cut(modelRef, "/") - if !ok { - return fmt.Errorf("invalid model reference %q: expected a model name from config or 'provider/model' format", modelRef) + inlineCfg, err := parseModelRef(r.modelSwitcherCfg.Models, modelRef) + if err != nil { + return err } - inlineCfg := &latest.ModelConfig{ - Provider: providerName, - Model: modelName, - } prov, err := r.createProviderFromConfig(ctx, inlineCfg) if err != nil { return fmt.Errorf("failed to create inline model: %w", err) @@ -141,114 +137,127 @@ func isAlloyModelConfig(cfg latest.ModelConfig) bool { return cfg.Provider == "" && strings.Contains(cfg.Model, ",") } -// isInlineAlloySpec checks if a model reference is an inline alloy specification. -// An inline alloy is comma-separated provider/model specs like "openai/gpt-4o,anthropic/claude-sonnet-4-0". 
-func isInlineAlloySpec(modelRef string) bool { - if !strings.Contains(modelRef, ",") { - return false - } - // Check that each part looks like a provider/model spec - // and count valid parts (need at least 2 for an alloy) - validParts := 0 - for part := range strings.SplitSeq(modelRef, ",") { - part = strings.TrimSpace(part) - if part == "" { +type modelRefSource int + +const ( + modelRefInline modelRefSource = iota + modelRefConfig +) + +type modelRef struct { + Ref string + Source modelRefSource +} + +func parseModelRefs(refs string) []modelRef { + var result []modelRef + + for ref := range strings.SplitSeq(refs, ",") { + ref = strings.TrimSpace(ref) + if ref == "" { continue } - if !strings.Contains(part, "/") { - return false - } - validParts++ + + result = append(result, modelRef{ + Ref: ref, + Source: classifyModelRef(ref), + }) } - return validParts >= 2 + + return result } -// createProvidersFromInlineAlloy creates providers from an inline alloy spec. -// An inline alloy is comma-separated provider/model specs like "openai/gpt-4o,anthropic/claude-sonnet-4-0". -func (r *LocalRuntime) createProvidersFromInlineAlloy(ctx context.Context, modelRef string) ([]provider.Provider, error) { - var providers []provider.Provider +func parseInlineAlloyRefs(refs string) ([]string, bool) { + parsed := parseModelRefs(refs) + if len(parsed) < 2 { + return nil, false + } - for part := range strings.SplitSeq(modelRef, ",") { - part = strings.TrimSpace(part) - if part == "" { - continue + inlineRefs := make([]string, len(parsed)) + for i, ref := range parsed { + if ref.Source != modelRefInline { + return nil, false } + inlineRefs[i] = ref.Ref + } - // Check if this part exists as a named model in config - if modelCfg, exists := r.modelSwitcherCfg.Models[part]; exists { - prov, err := r.createProviderFromConfig(ctx, &modelCfg) - if err != nil { - return nil, fmt.Errorf("failed to create provider for %q: %w", part, err) - } - providers = append(providers, prov) - continue - } + return inlineRefs, true +} - // Parse as provider/model - providerName, modelName, ok := strings.Cut(part, "/") - if !ok { - return nil, fmt.Errorf("invalid model reference %q in inline alloy: expected 'provider/model' format", part) - } +func classifyModelRef(ref string) modelRefSource { + if strings.Contains(ref, "/") { + return modelRefInline + } + return modelRefConfig +} - inlineCfg := &latest.ModelConfig{ - Provider: providerName, - Model: modelName, - } - prov, err := r.createProviderFromConfig(ctx, inlineCfg) - if err != nil { - return nil, fmt.Errorf("failed to create provider for %q: %w", part, err) - } - providers = append(providers, prov) +// isInlineAlloySpec checks if a model reference is an inline alloy specification. +// An inline alloy is comma-separated provider/model specs like "openai/gpt-4o,anthropic/claude-sonnet-4-0". +func isInlineAlloySpec(modelRef string) bool { + _, ok := parseInlineAlloyRefs(modelRef) + return ok +} + +// parseModelRef resolves a model reference into either a named config or an inline model config. 
+func parseModelRef(models map[string]latest.ModelConfig, modelRef string) (*latest.ModelConfig, error) { + modelRef = strings.TrimSpace(modelRef) + if modelRef == "" { + return nil, fmt.Errorf("model reference cannot be empty") } - if len(providers) == 0 { - return nil, fmt.Errorf("inline alloy spec has no valid models") + if modelCfg, exists := models[modelRef]; exists { + return &modelCfg, nil } - return providers, nil + providerName, modelName, ok := strings.Cut(modelRef, "/") + if !ok { + return nil, fmt.Errorf("invalid model reference %q: expected a model name from config or 'provider/model' format", modelRef) + } + + return &latest.ModelConfig{ + Provider: providerName, + Model: modelName, + }, nil +} + +// createProvidersFromInlineAlloy creates providers from an inline alloy spec. +// An inline alloy is comma-separated provider/model specs like "openai/gpt-4o,anthropic/claude-sonnet-4-0". +func (r *LocalRuntime) createProvidersFromInlineAlloy(ctx context.Context, modelRefs []string) ([]provider.Provider, error) { + if !isInlineAlloySpec(strings.Join(modelRefs, ",")) { + return nil, fmt.Errorf("inline alloy spec must include at least two provider/model references") + } + return r.createProvidersFromRefs(ctx, modelRefs, "inline alloy spec") } // createProvidersFromAlloyConfig creates providers for each model in an alloy configuration. func (r *LocalRuntime) createProvidersFromAlloyConfig(ctx context.Context, alloyCfg latest.ModelConfig) ([]provider.Provider, error) { - var providers []provider.Provider - - for modelRef := range strings.SplitSeq(alloyCfg.Model, ",") { - modelRef = strings.TrimSpace(modelRef) - if modelRef == "" { - continue - } + refs := parseModelRefs(alloyCfg.Model) + modelRefs := make([]string, 0, len(refs)) + for _, ref := range refs { + modelRefs = append(modelRefs, ref.Ref) + } + return r.createProvidersFromRefs(ctx, modelRefs, "alloy config") +} - // Check if this model reference exists in the config - if modelCfg, exists := r.modelSwitcherCfg.Models[modelRef]; exists { - prov, err := r.createProviderFromConfig(ctx, &modelCfg) - if err != nil { - return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err) - } - providers = append(providers, prov) - continue - } +func (r *LocalRuntime) createProvidersFromRefs(ctx context.Context, modelRefs []string, sourceLabel string) ([]provider.Provider, error) { + if len(modelRefs) == 0 { + return nil, fmt.Errorf("%s has no valid models", sourceLabel) + } - // Try parsing as inline spec (provider/model) - providerName, modelName, ok := strings.Cut(modelRef, "/") - if !ok { - return nil, fmt.Errorf("invalid model reference %q in alloy config: expected 'provider/model' format", modelRef) + providers := make([]provider.Provider, 0, len(modelRefs)) + for _, ref := range modelRefs { + modelCfg, err := parseModelRef(r.modelSwitcherCfg.Models, ref) + if err != nil { + return nil, fmt.Errorf("invalid model reference %q in %s: %w", ref, sourceLabel, err) } - inlineCfg := &latest.ModelConfig{ - Provider: providerName, - Model: modelName, - } - prov, err := r.createProviderFromConfig(ctx, inlineCfg) + prov, err := r.createProviderFromConfig(ctx, modelCfg) if err != nil { - return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err) + return nil, fmt.Errorf("failed to create provider for %q: %w", ref, err) } providers = append(providers, prov) } - if len(providers) == 0 { - return nil, fmt.Errorf("alloy model config has no valid models") - } - return providers, nil } diff --git 
a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 219566228..31f642add 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -179,7 +179,7 @@ type LocalRuntime struct { modelsStore ModelStore sessionCompaction bool managedOAuth bool - startupInfoEmitted bool // Track if startup info has been emitted to avoid unnecessary duplication + startupInfoEmitted atomic.Bool // Track if startup info has been emitted to avoid unnecessary duplication elicitationRequestCh chan ElicitationResult // Channel for receiving elicitation responses elicitationEventsChannel chan Event // Current events channel for sending elicitation requests elicitationEventsChannelMux sync.RWMutex // Protects elicitationEventsChannel @@ -571,16 +571,15 @@ func (r *LocalRuntime) PermissionsInfo() *PermissionsInfo { // This should be called when replacing a session to allow re-emission of // agent, team, and toolset info to the UI. func (r *LocalRuntime) ResetStartupInfo() { - r.startupInfoEmitted = false + r.startupInfoEmitted.Store(false) } // EmitStartupInfo emits initial agent, team, and toolset information for immediate sidebar display func (r *LocalRuntime) EmitStartupInfo(ctx context.Context, events chan Event) { - // Prevent duplicate emissions - if r.startupInfoEmitted { + // Prevent duplicate emissions using atomic compare-and-swap + if !r.startupInfoEmitted.CompareAndSwap(false, true) { return } - r.startupInfoEmitted = true a := r.CurrentAgent() @@ -1010,7 +1009,7 @@ func (r *LocalRuntime) getTools(ctx context.Context, a *agent.Agent, sessionSpan return agentTools, nil } -// configureToolsetHandlers sets up elicitation and OAuth handlers for all toolsets of an agent. +// configureToolsetHandlers sets up elicitation, OAuth, and model switch handlers for all toolsets of an agent. func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events chan Event) { for _, toolset := range a.ToolSets() { toolset.SetElicitationHandler(r.elicitationHandler) @@ -1018,9 +1017,52 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events chan Even events <- Authorization(tools.ElicitationActionAccept, r.currentAgent) }) toolset.SetManagedOAuth(r.managedOAuth) + + // Configure switch_model callback if this is a SwitchModelToolset + if switchModelToolset := unwrapSwitchModelToolset(toolset); switchModelToolset != nil { + r.configureSwitchModelCallback(a, switchModelToolset, events) + } + } +} + +// unwrapSwitchModelToolset extracts a SwitchModelToolset from a potentially wrapped toolset. +// It recursively unwraps any toolset that implements the tools.Unwrapper interface. +// Returns the SwitchModelToolset if found, or nil if the toolset is not a SwitchModelToolset. +func unwrapSwitchModelToolset(toolset tools.ToolSet) *builtin.SwitchModelToolset { + for { + if switchModelTS, ok := toolset.(*builtin.SwitchModelToolset); ok { + return switchModelTS + } + + unwrapper, ok := toolset.(tools.Unwrapper) + if !ok { + return nil + } + toolset = unwrapper.Inner() } } +// configureSwitchModelCallback sets up the callback for the switch_model toolset +// so that when the model is switched, the agent's model override is updated +// and the TUI is notified via events. 
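+// If SetAgentModel fails, the error is returned from the callback and surfaces
+// to the model as a tool error, so a failed switch does not abort the run.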
+func (r *LocalRuntime) configureSwitchModelCallback(a *agent.Agent, switchModelToolset *builtin.SwitchModelToolset, events chan Event) {
+	switchModelToolset.SetOnSwitchCallback(func(newModel string) error {
+		ctx := context.Background()
+		if err := r.SetAgentModel(ctx, a.Name(), newModel); err != nil {
+			slog.Error("Failed to switch model via switch_model tool", "agent", a.Name(), "model", newModel, "error", err)
+			return err
+		}
+		slog.Debug("Model switched via switch_model tool", "agent", a.Name(), "model", newModel)
+
+		// Emit events to update the TUI sidebar with the new model
+		if events != nil {
+			events <- AgentInfo(a.Name(), getAgentModelID(a), a.Description(), a.WelcomeMessage())
+			events <- TeamInfo(r.agentDetailsFromTeam(), r.currentAgent)
+		}
+		return nil
+	})
+}
+
 // emitAgentWarningsWithSend emits agent warnings using the provided send function for context-aware sending.
 func (r *LocalRuntime) emitAgentWarningsWithSend(a *agent.Agent, send func(Event) bool) {
 	warnings := a.DrainWarnings()
diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go
index d9e5fe9c5..18afb806a 100644
--- a/pkg/runtime/runtime_test.go
+++ b/pkg/runtime/runtime_test.go
@@ -10,6 +10,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.opentelemetry.io/otel/trace"
 
@@ -26,6 +27,7 @@ import (
 	"github.com/docker/cagent/pkg/session"
 	"github.com/docker/cagent/pkg/team"
 	"github.com/docker/cagent/pkg/tools"
+	tools_builtin "github.com/docker/cagent/pkg/tools/builtin"
 )
 
 type stubToolSet struct {
@@ -1311,3 +1313,430 @@ func TestToolRejectionWithoutReason(t *testing.T) {
 	require.Equal(t, "The user rejected the tool call.", toolResponse.Response)
 	require.NotContains(t, toolResponse.Response, "Reason:")
 }
+
+func TestSwitchModelTool_IntegrationWithRuntime(t *testing.T) {
+	// This test verifies that the switch_model tool correctly updates the
+	// agent's model override when used through the runtime.
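+	// The wiring below mirrors production behavior: configureToolsetHandlers
+	// connects the runtime's SetAgentModel to the toolset's on-switch callback.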
+ + // Create a switch_model toolset + switchModelToolset, err := tools_builtin.NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + // Initial model + defaultModel := &mockProvider{id: "default-model", stream: &mockStream{}} + + // Create an agent with the switch_model toolset + root := agent.New("root", "You are a test agent", + agent.WithModel(defaultModel), + agent.WithToolSets(switchModelToolset), + ) + tm := team.New(team.WithAgents(root)) + + // Create runtime with model switcher config + modelSwitcherCfg := &ModelSwitcherConfig{ + Models: map[string]latest.ModelConfig{ + "fast": { + Provider: "openai", + Model: "gpt-4o-mini", + }, + "powerful": { + Provider: "openai", + Model: "gpt-4o", + }, + }, + Providers: nil, + EnvProvider: &mockEnvProvider{vars: map[string]string{"ANTHROPIC_API_KEY": "test-key"}}, + AgentDefaultModels: map[string]string{"root": "fast"}, + } + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithModelSwitcherConfig(modelSwitcherCfg), + ) + require.NoError(t, err) + + // Verify initial model + require.Equal(t, "default-model", root.Model().ID()) + require.False(t, root.HasModelOverride()) + + // Manually configure the toolset handlers (simulating what RunStream does) + events := make(chan Event, 10) + rt.configureToolsetHandlers(root, events) + + // Get the switch_model tool from the agent + agentTools, err := root.Tools(t.Context()) + require.NoError(t, err) + + var switchModelTool tools.Tool + for _, tool := range agentTools { + if tool.Name == "switch_model" { + switchModelTool = tool + break + } + } + require.NotEmpty(t, switchModelTool.Name, "switch_model tool should be available") + + // Call the switch_model tool to switch to "powerful" + toolCall := tools.ToolCall{ + ID: "test-call-1", + Function: tools.FunctionCall{ + Name: "switch_model", + Arguments: `{"model": "powerful"}`, + }, + } + + result, err := switchModelTool.Handler(t.Context(), toolCall) + require.NoError(t, err) + require.False(t, result.IsError, "switch_model should succeed: %s", result.Output) + require.Contains(t, result.Output, `Switched model from "fast" to "powerful"`) + + // Verify the model override was set + require.True(t, root.HasModelOverride(), "agent should have model override after switch") + require.Equal(t, "openai/gpt-4o", root.Model().ID(), "model should be switched to powerful") + + // Verify that AgentInfoEvent and TeamInfoEvent were emitted with the new model + var agentInfoEvent *AgentInfoEvent + var teamInfoEvent *TeamInfoEvent + for { + select { + case evt := <-events: + switch e := evt.(type) { + case *AgentInfoEvent: + agentInfoEvent = e + case *TeamInfoEvent: + teamInfoEvent = e + } + default: + // No more events in the channel + goto done + } + } +done: + require.NotNil(t, agentInfoEvent, "AgentInfoEvent should be emitted after model switch") + assert.Equal(t, "root", agentInfoEvent.AgentName) + assert.Equal(t, "openai/gpt-4o", agentInfoEvent.Model, "AgentInfoEvent should contain the new model ID") + + require.NotNil(t, teamInfoEvent, "TeamInfoEvent should be emitted after model switch") + require.Len(t, teamInfoEvent.AvailableAgents, 1) + assert.Equal(t, "gpt-4o", teamInfoEvent.AvailableAgents[0].Model, "TeamInfoEvent should contain the new model name") +} + +func TestSwitchModelTool_WithDescriptionToolSetWrapper(t *testing.T) { + // This test verifies that the switch_model tool works correctly even when + // wrapped in DescriptionToolSet (which happens when 
AddDescriptionParameter is enabled). + + // Create a switch_model toolset + switchModelToolset, err := tools_builtin.NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + // Wrap in DescriptionToolSet (simulating what teamloader does when AddDescriptionParameter is true) + wrappedToolset := tools.NewDescriptionToolSet(switchModelToolset) + + // Initial model + defaultModel := &mockProvider{id: "default-model", stream: &mockStream{}} + + // Create an agent with the wrapped toolset + root := agent.New("root", "You are a test agent", + agent.WithModel(defaultModel), + agent.WithToolSets(wrappedToolset), + ) + tm := team.New(team.WithAgents(root)) + + // Create runtime with model switcher config + modelSwitcherCfg := &ModelSwitcherConfig{ + Models: map[string]latest.ModelConfig{ + "fast": { + Provider: "openai", + Model: "gpt-4o-mini", + }, + "powerful": { + Provider: "openai", + Model: "gpt-4o", + }, + }, + Providers: nil, + EnvProvider: &mockEnvProvider{}, + AgentDefaultModels: map[string]string{"root": "fast"}, + } + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithModelSwitcherConfig(modelSwitcherCfg), + ) + require.NoError(t, err) + + // Verify initial model + require.Equal(t, "default-model", root.Model().ID()) + require.False(t, root.HasModelOverride()) + + // Manually configure the toolset handlers (simulating what RunStream does) + // This is where the bug was: the callback wasn't being set because + // unwrapSwitchModelToolset didn't handle the DescriptionToolSet wrapper. + events := make(chan Event, 10) + rt.configureToolsetHandlers(root, events) + + // Get the switch_model tool from the agent + agentTools, err := root.Tools(t.Context()) + require.NoError(t, err) + + var switchModelTool tools.Tool + for _, tool := range agentTools { + if tool.Name == "switch_model" { + switchModelTool = tool + break + } + } + require.NotEmpty(t, switchModelTool.Name, "switch_model tool should be available") + + // Call the switch_model tool to switch to "powerful" + toolCall := tools.ToolCall{ + ID: "test-call-1", + Function: tools.FunctionCall{ + Name: "switch_model", + Arguments: `{"model": "powerful"}`, + }, + } + + result, err := switchModelTool.Handler(t.Context(), toolCall) + require.NoError(t, err) + require.False(t, result.IsError, "switch_model should succeed: %s", result.Output) + require.Contains(t, result.Output, `Switched model from "fast" to "powerful"`) + + // Verify the model override was set - this is the key assertion! + // Before the fix, the callback wasn't set, so the override wouldn't be applied. + require.True(t, root.HasModelOverride(), "agent should have model override after switch") + require.Equal(t, "openai/gpt-4o", root.Model().ID(), "model should be switched to powerful") +} + +func TestSwitchModelTool_WithInstructionWrapper(t *testing.T) { + // This test verifies that the switch_model tool works correctly even when + // wrapped with a custom wrapper that implements the Unwrapper interface. + // This simulates what happens when the toolset has an 'instruction' field + // in the config (like in gopher.yaml), which causes teamloader to wrap + // the toolset with replaceInstruction. + + // Create a switch_model toolset + switchModelToolset, err := tools_builtin.NewSwitchModelToolset([]string{"haiku", "opus"}) + require.NoError(t, err) + + // Create a wrapper that implements the Unwrapper interface (simulating + // what teamloader's WithInstructions does when instruction is set). 
+ wrappedToolset := &testInstructionWrapper{ToolSet: switchModelToolset} + + // Verify it implements Unwrapper + _, ok := tools.ToolSet(wrappedToolset).(tools.Unwrapper) + require.True(t, ok, "testInstructionWrapper should implement Unwrapper") + + // Initial model + defaultModel := &mockProvider{id: "anthropic/claude-haiku-4-5", stream: &mockStream{}} + + // Create an agent with the wrapped toolset + root := agent.New("root", "You are a test agent", + agent.WithModel(defaultModel), + agent.WithToolSets(wrappedToolset), + ) + tm := team.New(team.WithAgents(root)) + + // Create runtime with model switcher config + modelSwitcherCfg := &ModelSwitcherConfig{ + Models: map[string]latest.ModelConfig{ + "haiku": { + Provider: "anthropic", + Model: "claude-haiku-4-5", + }, + "opus": { + Provider: "anthropic", + Model: "claude-opus-4-5", + }, + }, + EnvProvider: &mockEnvProvider{vars: map[string]string{"ANTHROPIC_API_KEY": "test-key"}}, + AgentDefaultModels: map[string]string{"root": "haiku"}, + } + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithModelSwitcherConfig(modelSwitcherCfg), + ) + require.NoError(t, err) + + // Verify initial model + require.Equal(t, "anthropic/claude-haiku-4-5", root.Model().ID()) + require.False(t, root.HasModelOverride()) + + // Configure toolset handlers - this should find the SwitchModelToolset + // by unwrapping the instructionWrapper using the Unwrapper interface. + events := make(chan Event, 10) + rt.configureToolsetHandlers(root, events) + + // Get the switch_model tool from the agent + agentTools, err := root.Tools(t.Context()) + require.NoError(t, err) + + var switchModelTool tools.Tool + for _, tool := range agentTools { + if tool.Name == "switch_model" { + switchModelTool = tool + break + } + } + require.NotEmpty(t, switchModelTool.Name, "switch_model tool should be available") + + // Call the switch_model tool to switch to "opus" + toolCall := tools.ToolCall{ + ID: "test-call-1", + Function: tools.FunctionCall{ + Name: "switch_model", + Arguments: `{"model": "opus"}`, + }, + } + + result, err := switchModelTool.Handler(t.Context(), toolCall) + require.NoError(t, err) + require.False(t, result.IsError, "switch_model should succeed: %s", result.Output) + require.Contains(t, result.Output, `Switched model from "haiku" to "opus"`) + + // Verify the model override was set - this is the key assertion! + // This works because the unwrapSwitchModelToolset function now uses the + // Unwrapper interface to recursively unwrap any wrapper that implements it. + require.True(t, root.HasModelOverride(), "agent should have model override after switch") + require.Equal(t, "anthropic/claude-opus-4-5", root.Model().ID(), "model should be switched to opus") +} + +// testInstructionWrapper is a test helper that simulates teamloader's replaceInstruction wrapper. +// It implements the tools.Unwrapper interface to allow unwrapping. 
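+// Embedding tools.ToolSet delegates every other ToolSet method to the wrapped
+// value, so the wrapper only needs to define Inner() itself.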
+type testInstructionWrapper struct { + tools.ToolSet +} + +func (w *testInstructionWrapper) Inner() tools.ToolSet { return w.ToolSet } + +func TestSwitchModelTool_MultiAgentScenario(t *testing.T) { + // This test simulates the multi-agent scenario from gopher.yaml + // where root has switch_model but librarian does not + + // Create mock providers + haikuModel := &mockProvider{id: "anthropic/claude-haiku-4-5"} + + // Create switch_model toolset for root + switchModelToolset, err := tools_builtin.NewSwitchModelToolset([]string{"haiku", "opus"}) + require.NoError(t, err) + + // Create agents: root with switch_model, librarian without + root := agent.New("root", "You are the root agent", + agent.WithModel(haikuModel), + agent.WithToolSets(switchModelToolset), + ) + + librarian := agent.New("librarian", "You are the librarian", + agent.WithModel(haikuModel), + ) + + tm := team.New(team.WithAgents(root, librarian)) + + // Create runtime with model switcher config + modelSwitcherCfg := &ModelSwitcherConfig{ + Models: map[string]latest.ModelConfig{ + "haiku": { + Provider: "anthropic", + Model: "claude-haiku-4-5", + }, + "opus": { + Provider: "anthropic", + Model: "claude-opus-4-5", + }, + }, + EnvProvider: &mockEnvProvider{vars: map[string]string{"ANTHROPIC_API_KEY": "test-key"}}, + AgentDefaultModels: map[string]string{"root": "haiku", "librarian": "haiku"}, + } + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithModelSwitcherConfig(modelSwitcherCfg), + ) + require.NoError(t, err) + + // Verify initial state + require.Equal(t, "anthropic/claude-haiku-4-5", root.Model().ID()) + require.False(t, root.HasModelOverride()) + + // Create events channel and configure handlers for root agent + events := make(chan Event, 20) + rt.configureToolsetHandlers(root, events) + + // Get tools for root agent (this creates the handler) + agentTools, err := root.Tools(t.Context()) + require.NoError(t, err) + + var switchModelTool tools.Tool + for _, tool := range agentTools { + if tool.Name == "switch_model" { + switchModelTool = tool + break + } + } + require.NotEmpty(t, switchModelTool.Name, "switch_model tool should be available") + + // Call the switch_model tool to switch to opus + toolCall := tools.ToolCall{ + ID: "test-multi-agent", + Function: tools.FunctionCall{ + Name: "switch_model", + Arguments: `{"model": "opus"}`, + }, + } + + t.Log("Calling switch_model tool in multi-agent scenario...") + result, err := switchModelTool.Handler(t.Context(), toolCall) + require.NoError(t, err) + require.False(t, result.IsError, "switch_model should succeed: %s", result.Output) + t.Logf("Result: %s", result.Output) + + // Verify model was switched + require.True(t, root.HasModelOverride(), "root should have model override") + require.Equal(t, "anthropic/claude-opus-4-5", root.Model().ID(), "root model should be switched to opus") + + // Verify events were emitted + var agentInfoEvent *AgentInfoEvent + var teamInfoEvent *TeamInfoEvent + for { + select { + case evt := <-events: + t.Logf("Event: %T", evt) + switch e := evt.(type) { + case *AgentInfoEvent: + agentInfoEvent = e + case *TeamInfoEvent: + teamInfoEvent = e + } + default: + goto done + } + } +done: + + require.NotNil(t, agentInfoEvent, "AgentInfoEvent should be emitted") + assert.Equal(t, "root", agentInfoEvent.AgentName) + assert.Equal(t, "anthropic/claude-opus-4-5", agentInfoEvent.Model) + + require.NotNil(t, teamInfoEvent, "TeamInfoEvent should be emitted") + // Team info should include both agents + 
require.Len(t, teamInfoEvent.AvailableAgents, 2) + + // Find root agent in team info and verify its model was updated + var rootAgentInfo *AgentDetails + for i := range teamInfoEvent.AvailableAgents { + if teamInfoEvent.AvailableAgents[i].Name == "root" { + rootAgentInfo = &teamInfoEvent.AvailableAgents[i] + break + } + } + require.NotNil(t, rootAgentInfo, "root agent should be in team info") + assert.Equal(t, "anthropic", rootAgentInfo.Provider) + assert.Equal(t, "claude-opus-4-5", rootAgentInfo.Model) + + t.Log("Multi-agent scenario test passed!") +} diff --git a/pkg/teamloader/filter.go b/pkg/teamloader/filter.go index 0690cbfdb..9b2736856 100644 --- a/pkg/teamloader/filter.go +++ b/pkg/teamloader/filter.go @@ -16,9 +16,9 @@ func WithToolsFilter(inner tools.ToolSet, toolNames ...string) tools.ToolSet { } return &filterTools{ - ToolSet: inner, - toolNames: toolNames, - exclude: false, + ToolSetWrapper: tools.ToolSetWrapper{ToolSet: inner}, + toolNames: toolNames, + exclude: false, } } @@ -30,14 +30,14 @@ func WithToolsExcludeFilter(inner tools.ToolSet, toolNames ...string) tools.Tool } return &filterTools{ - ToolSet: inner, - toolNames: toolNames, - exclude: true, + ToolSetWrapper: tools.ToolSetWrapper{ToolSet: inner}, + toolNames: toolNames, + exclude: true, } } type filterTools struct { - tools.ToolSet + tools.ToolSetWrapper toolNames []string exclude bool } diff --git a/pkg/teamloader/instructions.go b/pkg/teamloader/instructions.go index 4d09e0d79..e5ce2fa0a 100644 --- a/pkg/teamloader/instructions.go +++ b/pkg/teamloader/instructions.go @@ -12,13 +12,13 @@ func WithInstructions(inner tools.ToolSet, instruction string) tools.ToolSet { } return &replaceInstruction{ - ToolSet: inner, - instruction: instruction, + ToolSetWrapper: tools.ToolSetWrapper{ToolSet: inner}, + instruction: instruction, } } type replaceInstruction struct { - tools.ToolSet + tools.ToolSetWrapper instruction string } diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index 7374605cc..da70d0bcf 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -70,6 +70,7 @@ func NewDefaultToolsetRegistry() *ToolsetRegistry { r.Register("a2a", createA2ATool) r.Register("lsp", createLSPTool) r.Register("user_prompt", createUserPromptTool) + r.Register("switch_model", createSwitchModelTool) return r } @@ -284,3 +285,7 @@ func createLSPTool(ctx context.Context, toolset latest.Toolset, _ string, runCon func createUserPromptTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) { return builtin.NewUserPromptTool(), nil } + +func createSwitchModelTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) { + return builtin.NewSwitchModelToolset(toolset.Models) +} diff --git a/pkg/teamloader/toon.go b/pkg/teamloader/toon.go index 3cb333f89..e95557d45 100644 --- a/pkg/teamloader/toon.go +++ b/pkg/teamloader/toon.go @@ -12,7 +12,7 @@ import ( ) type toonTools struct { - tools.ToolSet + tools.ToolSetWrapper toolRegexps []*regexp.Regexp } @@ -67,7 +67,7 @@ func WithToon(inner tools.ToolSet, toon string) tools.ToolSet { toolRegexps = append(toolRegexps, regexp.MustCompile(strings.TrimSpace(toolName))) } return &toonTools{ - ToolSet: inner, - toolRegexps: toolRegexps, + ToolSetWrapper: tools.ToolSetWrapper{ToolSet: inner}, + toolRegexps: toolRegexps, } } diff --git a/pkg/tools/builtin/switch_model.go b/pkg/tools/builtin/switch_model.go new file mode 100644 index 000000000..1e96e1eee --- /dev/null +++ 
b/pkg/tools/builtin/switch_model.go
@@ -0,0 +1,189 @@
+package builtin
+
+import (
+	"context"
+	"fmt"
+	"slices"
+	"strings"
+	"sync"
+
+	"github.com/docker/cagent/pkg/tools"
+)
+
+const ToolNameSwitchModel = "switch_model"
+
+// ModelSwitchCallback is called when the model is switched.
+// It returns an error if the switch failed.
+type ModelSwitchCallback func(newModel string) error
+
+// SwitchModelToolset provides a tool that allows agents to switch between
+// a predefined set of models during a conversation.
+type SwitchModelToolset struct {
+	tools.BaseToolSet
+
+	mu           sync.RWMutex
+	models       []string
+	currentModel string              // currently selected model
+	onSwitch     ModelSwitchCallback // optional callback when model changes
+}
+
+var _ tools.ToolSet = (*SwitchModelToolset)(nil)
+
+type SwitchModelArgs struct {
+	Model string `json:"model" jsonschema:"The model to switch to. Must be one of the allowed models listed in the tool description."`
+}
+
+// NewSwitchModelToolset creates a new switch_model toolset with the given allowed models.
+// The first model in the list becomes the default and initially selected model.
+// Returns an error if models is empty or contains empty or whitespace-only strings.
+func NewSwitchModelToolset(models []string) (*SwitchModelToolset, error) {
+	if len(models) == 0 {
+		return nil, fmt.Errorf("switch_model toolset requires at least one model")
+	}
+	for i, m := range models {
+		if strings.TrimSpace(m) == "" {
+			return nil, fmt.Errorf("switch_model toolset: model at index %d is empty", i)
+		}
+	}
+
+	return &SwitchModelToolset{
+		models:       slices.Clone(models),
+		currentModel: models[0],
+	}, nil
+}
+
+// CurrentModel returns the currently selected model.
+func (t *SwitchModelToolset) CurrentModel() string {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+	return t.currentModel
+}
+
+// SetOnSwitchCallback sets a callback that will be invoked whenever the model is switched.
+// The callback receives the new model name. This allows the runtime to react to model changes.
+func (t *SwitchModelToolset) SetOnSwitchCallback(callback ModelSwitchCallback) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	t.onSwitch = callback
+}
+
+// Instructions returns guidance for when to use model switching.
+func (t *SwitchModelToolset) Instructions() string {
+	return `## Model Switching Guidelines
+
+You have access to multiple AI models and can switch between them strategically.
+
+### When to Consider Switching Models
+
+**Switch to a faster/cheaper model when:**
+- Performing simple, routine tasks (formatting, basic Q&A, short summaries)
+- The current task doesn't require advanced reasoning
+- Processing straightforward requests that any model can handle well
+- Optimizing for response speed or cost efficiency
+
+**Switch to a more powerful model when:**
+- Facing complex reasoning or multi-step problems
+- Writing or reviewing code that requires careful analysis
+- Handling nuanced or ambiguous requests
+- Generating detailed technical content
+- The current model is struggling with the task quality
+
+**Switch back to the default model when:**
+- A specialized task is complete
+- Returning to general conversation
+- The extra capability is no longer needed
+
+### Best Practices
+
+1. Check the tool description to see available models and which one is currently active
+2. Don't switch unnecessarily - there's overhead in changing models
+3. Consider switching proactively before a complex task rather than after struggling
+4. When in doubt about task complexity, prefer the more capable model`
+}
+
+// callTool handles the switch_model tool invocation.
+func (t *SwitchModelToolset) callTool(_ context.Context, params SwitchModelArgs) (*tools.ToolCallResult, error) {
+	requestedModel := strings.TrimSpace(params.Model)
+	if requestedModel == "" {
+		return tools.ResultError("model parameter is required and cannot be empty"), nil
+	}
+
+	// Check if the requested model is in the allowed list
+	if !slices.Contains(t.models, requestedModel) {
+		return tools.ResultError(fmt.Sprintf(
+			"model %q is not allowed. Available models: %s",
+			requestedModel,
+			strings.Join(t.models, ", "),
+		)), nil
+	}
+
+	t.mu.Lock()
+	previousModel := t.currentModel
+	t.currentModel = requestedModel
+	callback := t.onSwitch
+	t.mu.Unlock()
+
+	if previousModel == requestedModel {
+		return tools.ResultSuccess(fmt.Sprintf("Model is already set to %q.", requestedModel)), nil
+	}
+
+	// Notify the runtime about the model change
+	if callback != nil {
+		if err := callback(requestedModel); err != nil {
+			// Roll back the internal state change since the actual switch failed
+			t.mu.Lock()
+			t.currentModel = previousModel
+			t.mu.Unlock()
+			return tools.ResultError(fmt.Sprintf("Failed to switch model: %v", err)), nil
+		}
+	}
+
+	return tools.ResultSuccess(fmt.Sprintf("Switched model from %q to %q.", previousModel, requestedModel)), nil
+}
+
+// Tools returns the switch_model tool definition.
+func (t *SwitchModelToolset) Tools(context.Context) ([]tools.Tool, error) {
+	t.mu.RLock()
+	currentModel := t.currentModel
+	t.mu.RUnlock()
+
+	description := t.buildDescription(currentModel)
+
+	return []tools.Tool{
+		{
+			Name:         ToolNameSwitchModel,
+			Category:     "model",
+			Description:  description,
+			Parameters:   tools.MustSchemaFor[SwitchModelArgs](),
+			OutputSchema: tools.MustSchemaFor[string](),
+			Handler:      tools.NewHandler(t.callTool),
+			Annotations: tools.ToolAnnotations{
+				ReadOnlyHint: true,
+				Title:        "Switch Model",
+			},
+		},
+	}, nil
+}
+
+// buildDescription generates the tool description with current state.
+func (t *SwitchModelToolset) buildDescription(currentModel string) string {
+	var sb strings.Builder
+
+	sb.WriteString("Switch the AI model used for subsequent responses.\n\n")
+	sb.WriteString("**Available models:**\n")
+	for _, m := range t.models {
+		sb.WriteString(fmt.Sprintf("- %s", m))
+		if m == t.models[0] {
+			sb.WriteString(" (default)")
+		}
+		if m == currentModel {
+			sb.WriteString(" (current)")
+		}
+		sb.WriteString("\n")
+	}
+	sb.WriteString("\n")
+	sb.WriteString("Only the models listed above can be selected. 
") + sb.WriteString("Any other model will be rejected.") + + return sb.String() +} diff --git a/pkg/tools/builtin/switch_model_test.go b/pkg/tools/builtin/switch_model_test.go new file mode 100644 index 000000000..873af74af --- /dev/null +++ b/pkg/tools/builtin/switch_model_test.go @@ -0,0 +1,238 @@ +package builtin + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSwitchModelToolset(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + models []string + wantErr bool + }{ + {"valid models", []string{"fast", "powerful"}, false}, + {"empty list", []string{}, true}, + {"nil list", nil, true}, + {"empty model in list", []string{"fast", "", "powerful"}, true}, + {"whitespace-only model", []string{"fast", " ", "powerful"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + toolset, err := NewSwitchModelToolset(tt.models) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.models[0], toolset.CurrentModel()) + }) + } +} + +func TestSwitchModelToolset_callTool(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + model string + wantError bool + wantOutput string + wantCurrent string + }{ + { + name: "switches to allowed model", + model: "powerful", + wantOutput: "Switched model from \"fast\" to \"powerful\"", + wantCurrent: "powerful", + }, + { + name: "already on requested model", + model: "fast", + wantOutput: "Model is already set to \"fast\"", + wantCurrent: "fast", + }, + { + name: "rejects unknown model", + model: "unknown", + wantError: true, + wantOutput: "model \"unknown\" is not allowed", + wantCurrent: "fast", + }, + { + name: "rejects empty model", + model: "", + wantError: true, + wantOutput: "model parameter is required", + wantCurrent: "fast", + }, + { + name: "rejects whitespace-only model", + model: " ", + wantError: true, + wantOutput: "model parameter is required", + wantCurrent: "fast", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := t.Context() + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + result, err := toolset.callTool(ctx, SwitchModelArgs{Model: tt.model}) + + require.NoError(t, err) + assert.Equal(t, tt.wantError, result.IsError) + assert.Contains(t, result.Output, tt.wantOutput) + assert.Equal(t, tt.wantCurrent, toolset.CurrentModel()) + }) + } +} + +func TestSwitchModelToolset_Tools(t *testing.T) { + t.Parallel() + + ctx := t.Context() + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + tools, err := toolset.Tools(ctx) + require.NoError(t, err) + require.Len(t, tools, 1) + + tool := tools[0] + assert.Equal(t, ToolNameSwitchModel, tool.Name) + assert.Equal(t, "model", tool.Category) + assert.True(t, tool.Annotations.ReadOnlyHint) + assert.NotNil(t, tool.Handler) + + // Description includes model info + assert.Contains(t, tool.Description, "fast (default) (current)") + assert.Contains(t, tool.Description, "powerful") + assert.Contains(t, tool.Description, "rejected") + + // After switching, description updates + _, _ = toolset.callTool(ctx, SwitchModelArgs{Model: "powerful"}) + tools, _ = toolset.Tools(ctx) + assert.Contains(t, tools[0].Description, "fast (default)") + assert.Contains(t, tools[0].Description, "powerful (current)") +} + +func TestSwitchModelToolset_Instructions(t *testing.T) { + t.Parallel() + + 
toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + instructions := toolset.Instructions() + + assert.Contains(t, instructions, "Model Switching Guidelines") + assert.Contains(t, instructions, "Switch to a faster/cheaper model") + assert.Contains(t, instructions, "Switch to a more powerful model") + assert.Contains(t, instructions, "Best Practices") +} + +func TestSwitchModelToolset_OnSwitchCallback(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + initialModel string + targetModel string + wantCallbackCall bool + wantCallbackArg string + }{ + { + name: "callback called on successful switch", + initialModel: "fast", + targetModel: "powerful", + wantCallbackCall: true, + wantCallbackArg: "powerful", + }, + { + name: "callback not called when already on model", + initialModel: "fast", + targetModel: "fast", + wantCallbackCall: false, + }, + { + name: "callback not called for invalid model", + initialModel: "fast", + targetModel: "unknown", + wantCallbackCall: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + var callbackCalled bool + var callbackArg string + toolset.SetOnSwitchCallback(func(newModel string) error { + callbackCalled = true + callbackArg = newModel + return nil + }) + + _, _ = toolset.callTool(ctx, SwitchModelArgs{Model: tt.targetModel}) + + assert.Equal(t, tt.wantCallbackCall, callbackCalled, "callback called mismatch") + if tt.wantCallbackCall { + assert.Equal(t, tt.wantCallbackArg, callbackArg, "callback argument mismatch") + } + }) + } +} + +func TestSwitchModelToolset_OnSwitchCallback_NilCallback(t *testing.T) { + t.Parallel() + + ctx := t.Context() + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + // Ensure no panic when callback is nil + result, err := toolset.callTool(ctx, SwitchModelArgs{Model: "powerful"}) + require.NoError(t, err) + assert.False(t, result.IsError) + assert.Equal(t, "powerful", toolset.CurrentModel()) +} + +func TestSwitchModelToolset_OnSwitchCallback_WithError(t *testing.T) { + t.Parallel() + + ctx := t.Context() + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + // Set callback that returns an error + callbackErr := fmt.Errorf("API key not configured") + toolset.SetOnSwitchCallback(func(newModel string) error { + return callbackErr + }) + + // Call the tool - should fail because callback returns error + result, err := toolset.callTool(ctx, SwitchModelArgs{Model: "powerful"}) + require.NoError(t, err) // No Go error, but tool error + assert.True(t, result.IsError, "should be a tool error") + assert.Contains(t, result.Output, "Failed to switch model") + assert.Contains(t, result.Output, "API key not configured") + + // Verify internal state was rolled back + assert.Equal(t, "fast", toolset.CurrentModel(), "internal state should be rolled back to previous model") +} diff --git a/pkg/tools/description.go b/pkg/tools/description.go index 1cfe3be0c..25f49234d 100644 --- a/pkg/tools/description.go +++ b/pkg/tools/description.go @@ -13,16 +13,16 @@ const ( // DescriptionToolSet wraps a ToolSet and adds a "description" parameter to all tools. // This allows the LLM to provide context about what it's doing with each tool call. 
 type DescriptionToolSet struct {
-	inner ToolSet
+	ToolSetWrapper
 }
 
 // NewDescriptionToolSet creates a new DescriptionToolSet wrapping the given ToolSet.
 func NewDescriptionToolSet(inner ToolSet) *DescriptionToolSet {
-	return &DescriptionToolSet{inner: inner}
+	return &DescriptionToolSet{ToolSetWrapper: ToolSetWrapper{ToolSet: inner}}
 }
 
 func (f *DescriptionToolSet) Tools(ctx context.Context) ([]Tool, error) {
-	tools, err := f.inner.Tools(ctx)
+	tools, err := f.ToolSet.Tools(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -34,30 +34,6 @@ func (f *DescriptionToolSet) Tools(ctx context.Context) ([]Tool, error) {
 	return result, nil
 }
 
-func (f *DescriptionToolSet) Instructions() string {
-	return f.inner.Instructions()
-}
-
-func (f *DescriptionToolSet) Start(ctx context.Context) error {
-	return f.inner.Start(ctx)
-}
-
-func (f *DescriptionToolSet) Stop(ctx context.Context) error {
-	return f.inner.Stop(ctx)
-}
-
-func (f *DescriptionToolSet) SetElicitationHandler(handler ElicitationHandler) {
-	f.inner.SetElicitationHandler(handler)
-}
-
-func (f *DescriptionToolSet) SetOAuthSuccessHandler(handler func()) {
-	f.inner.SetOAuthSuccessHandler(handler)
-}
-
-func (f *DescriptionToolSet) SetManagedOAuth(managed bool) {
-	f.inner.SetManagedOAuth(managed)
-}
-
 func (f *DescriptionToolSet) addDescriptionParam(tool Tool) Tool {
 	if !tool.AddDescriptionParameter {
 		return tool
diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go
index a879009e5..ec3f29c91 100644
--- a/pkg/tools/tools.go
+++ b/pkg/tools/tools.go
@@ -106,6 +106,25 @@ type ToolSet interface {
 	SetManagedOAuth(managed bool)
 }
 
+// Unwrapper is an optional interface that wrapper toolsets can implement
+// to provide access to their inner toolset. This is useful for unwrapping
+// nested toolset wrappers to find specific toolset types (e.g., SwitchModelToolset).
+type Unwrapper interface {
+	Inner() ToolSet
+}
+
+// ToolSetWrapper is a base type for toolset decorators that wrap another toolset.
+// It implements both ToolSet (by delegation) and Unwrapper interfaces.
+// Embed this in wrapper types and override only the methods you need to customize.
+type ToolSetWrapper struct {
+	ToolSet
+}
+
+// Inner returns the wrapped toolset (implements Unwrapper).
+func (w *ToolSetWrapper) Inner() ToolSet {
+	return w.ToolSet
+}
+
 // NewHandler creates a type-safe tool handler from a function that accepts typed parameters.
 // It handles JSON unmarshaling of the tool call arguments into the specified type T.
func NewHandler[T any](fn func(context.Context, T) (*ToolCallResult, error)) ToolHandler { diff --git a/pkg/tui/components/sidebar/agent_info_test.go b/pkg/tui/components/sidebar/agent_info_test.go new file mode 100644 index 000000000..474303b9b --- /dev/null +++ b/pkg/tui/components/sidebar/agent_info_test.go @@ -0,0 +1,165 @@ +package sidebar + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/runtime" + "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/tui/service" +) + +func TestSidebar_SetAgentInfoUpdatesAvailableAgents(t *testing.T) { + t.Parallel() + + sess := session.New() + sessionState := service.NewSessionState(sess) + sb := New(sessionState) + + m := sb.(*model) + + // Set initial team info with original model + m.SetTeamInfo([]runtime.AgentDetails{ + { + Name: "root", + Description: "Test agent", + Provider: "openai", + Model: "gpt-4o-mini", + }, + }) + + // Verify initial state + require.Len(t, m.availableAgents, 1) + assert.Equal(t, "openai", m.availableAgents[0].Provider) + assert.Equal(t, "gpt-4o-mini", m.availableAgents[0].Model) + + // Now simulate a model switch via SetAgentInfo with new model + m.SetAgentInfo("root", "anthropic/claude-sonnet-4-0", "Test agent") + + // Verify the model was updated in availableAgents + require.Len(t, m.availableAgents, 1) + assert.Equal(t, "anthropic", m.availableAgents[0].Provider, "Provider should be updated") + assert.Equal(t, "claude-sonnet-4-0", m.availableAgents[0].Model, "Model should be updated") +} + +func TestSidebar_SetAgentInfoUpdatesCorrectAgent(t *testing.T) { + t.Parallel() + + sess := session.New() + sessionState := service.NewSessionState(sess) + sb := New(sessionState) + + m := sb.(*model) + + // Set up multiple agents + m.SetTeamInfo([]runtime.AgentDetails{ + { + Name: "root", + Description: "Root agent", + Provider: "openai", + Model: "gpt-4o", + }, + { + Name: "helper", + Description: "Helper agent", + Provider: "anthropic", + Model: "claude-sonnet-4-0", + }, + }) + + // Switch the model for the helper agent + m.SetAgentInfo("helper", "google/gemini-2.0-flash", "Helper agent") + + // Verify only the helper agent's model was updated + require.Len(t, m.availableAgents, 2) + assert.Equal(t, "openai", m.availableAgents[0].Provider, "Root provider should not change") + assert.Equal(t, "gpt-4o", m.availableAgents[0].Model, "Root model should not change") + assert.Equal(t, "google", m.availableAgents[1].Provider, "Helper provider should be updated") + assert.Equal(t, "gemini-2.0-flash", m.availableAgents[1].Model, "Helper model should be updated") +} + +func TestSidebar_SetAgentInfoWithModelIDWithoutProvider(t *testing.T) { + t.Parallel() + + sess := session.New() + sessionState := service.NewSessionState(sess) + sb := New(sessionState) + + m := sb.(*model) + + // Set initial team info + m.SetTeamInfo([]runtime.AgentDetails{ + { + Name: "root", + Description: "Test agent", + Provider: "openai", + Model: "gpt-4o-mini", + }, + }) + + // Switch to a model ID without provider prefix (shouldn't happen but handle gracefully) + m.SetAgentInfo("root", "some-model", "Test agent") + + // Verify the model was set (provider should remain unchanged) + require.Len(t, m.availableAgents, 1) + assert.Equal(t, "openai", m.availableAgents[0].Provider, "Provider should not change for non-prefixed model") + assert.Equal(t, "some-model", m.availableAgents[0].Model, "Model should be updated to the full ID") +} + +func 
TestSidebar_SetAgentInfoForNonExistentAgent(t *testing.T) { + t.Parallel() + + sess := session.New() + sessionState := service.NewSessionState(sess) + sb := New(sessionState) + + m := sb.(*model) + + // Set initial team info + m.SetTeamInfo([]runtime.AgentDetails{ + { + Name: "root", + Description: "Test agent", + Provider: "openai", + Model: "gpt-4o-mini", + }, + }) + + // Try to set info for a non-existent agent (should not panic or modify existing agents) + m.SetAgentInfo("nonexistent", "anthropic/claude-sonnet-4-0", "Some agent") + + // Verify the existing agent was not modified + require.Len(t, m.availableAgents, 1) + assert.Equal(t, "openai", m.availableAgents[0].Provider) + assert.Equal(t, "gpt-4o-mini", m.availableAgents[0].Model) +} + +func TestSidebar_SetAgentInfoWithEmptyModelID(t *testing.T) { + t.Parallel() + + sess := session.New() + sessionState := service.NewSessionState(sess) + sb := New(sessionState) + + m := sb.(*model) + + // Set initial team info + m.SetTeamInfo([]runtime.AgentDetails{ + { + Name: "root", + Description: "Test agent", + Provider: "openai", + Model: "gpt-4o-mini", + }, + }) + + // Call SetAgentInfo with empty modelID (should not modify availableAgents) + m.SetAgentInfo("root", "", "Test agent") + + // Verify the existing agent's model was not modified + require.Len(t, m.availableAgents, 1) + assert.Equal(t, "openai", m.availableAgents[0].Provider) + assert.Equal(t, "gpt-4o-mini", m.availableAgents[0].Model) +} diff --git a/pkg/tui/components/sidebar/sidebar.go b/pkg/tui/components/sidebar/sidebar.go index bca578821..8b4ae0529 100644 --- a/pkg/tui/components/sidebar/sidebar.go +++ b/pkg/tui/components/sidebar/sidebar.go @@ -219,23 +219,27 @@ func (m *model) SetTodos(result *tools.ToolCallResult) error { return m.todoComp.SetTodos(result) } -// SetAgentInfo sets the current agent information and updates the model in availableAgents +// SetAgentInfo sets the current agent information func (m *model) SetAgentInfo(agentName, modelID, description string) { m.currentAgent = agentName m.agentModel = modelID m.agentDescription = description m.reasoningSupported = modelsdev.ModelSupportsReasoning(context.Background(), modelID) - // Update the model in availableAgents for the current agent - // This is important when model routing selects a different model than configured - // Extract just the model name from "provider/model" format to match TeamInfoEvent format + // Update the model in availableAgents to ensure the sidebar displays the correct model + // when the model is switched via the switch_model tool. + if modelID == "" { + return + } for i := range m.availableAgents { - if m.availableAgents[i].Name == agentName && modelID != "" { - modelName := modelID - if idx := strings.LastIndex(modelName, "/"); idx != -1 { - modelName = modelName[idx+1:] + if m.availableAgents[i].Name == agentName { + // Parse the modelID to extract provider and model name + if prov, modelName, found := strings.Cut(modelID, "/"); found { + m.availableAgents[i].Provider = prov + m.availableAgents[i].Model = modelName + } else { + m.availableAgents[i].Model = modelID } - m.availableAgents[i].Model = modelName break } }
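
For reviewers, a minimal configuration exercising the new toolset might look like the sketch below. It is illustrative only: the "fast" and "smart" names are placeholders, and the toolset's models list can reference either entries from the config's models section or inline provider/model specs, per the schema examples above.

version: "3"
models:
  fast:
    provider: openai
    model: gpt-4o-mini
  smart:
    provider: anthropic
    model: claude-sonnet-4-0
agents:
  root:
    model: fast
    instruction: You are a helpful assistant.
    toolsets:
      - type: switch_model
        models: [fast, smart]

The first entry in the toolset's models list becomes the default and initially selected model; a switch_model tool call with arguments {"model": "smart"} then flips the agent's model override mid-session, as exercised by TestSwitchModelTool_IntegrationWithRuntime.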