diff --git a/bin/test-compaction b/bin/test-compaction
new file mode 100755
index 0000000000..c59d457afb
--- /dev/null
+++ b/bin/test-compaction
@@ -0,0 +1,53 @@
+#!/usr/bin/env bash
+# Automated E2E test for agent history compaction.
+# This script temporarily lowers the token limit, runs the agent, spams it to force compaction,
+# and verifies it recovers without an API crash.
+
+set -e
+
+# 1. Temporarily patch the agent configuration to a very low 500-token limit
+echo "Temporarily patching AgentConfig to max_history_tokens=500..."
+# macOS requires -i '' for sed
+sed -i '' 's/max_history_tokens: int | None = None/max_history_tokens: int | None = 500/g' dimos/agents/agent.py
+
+# Ensure we revert the change even if the script fails or is aborted (Ctrl+C)
+cleanup() {
+    echo "Cleaning up..."
+    dimos stop --force || true
+    git checkout dimos/agents/agent.py
+    echo "Reverted dimos/agents/agent.py to original state."
+}
+trap cleanup EXIT
+
+# 2. Launch the agentic blueprint in daemon (background) mode
+echo "Starting Go2 Agentic Simulation in the background..."
+dimos --simulation run unitree-go2-agentic --daemon
+
+# Wait for the agent and all RPC modules to initialize
+echo "Waiting 15 seconds for modules to boot and connect..."
+sleep 15
+
+# 3. Send a physical command that inherently requires a tool call
+echo "Sending initial tool command to start execution..."
+dimos agent-send "Move forward 0.5 meters."
+sleep 10
+
+# 4. Spam the context window to force LangChain trim_messages() to drop the old tool history
+echo "Spamming the agent to force history compaction..."
+for i in {1..4}; do
+    echo "Sending spam block $i..."
+    # 300 repetitions of "spam dummy word" (~900 filler words per message)
+    SPAM_TEXT="Please explicitly acknowledge this message. Context block $i: $(printf 'spam dummy word %.0s' {1..300})"
+    dimos agent-send "$SPAM_TEXT"
+    sleep 10
+done
+
+# 5. Send one final physical tool command
+echo "Sending final tool command. If compaction breaks, LangGraph will crash now!"
+dimos agent-send "Now turn left 90 degrees."
+sleep 15
+
+echo "--------------------------------------------------------"
+echo "Test Completed Successfully! (No early exit from crashes)"
+echo "Run 'dimos log' to review the actual LLM tool outputs."
+echo "--------------------------------------------------------"
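For one-off debugging outside this script, the same low limit can be set without sed-patching the source. A minimal sketch, assuming `AgentConfig` can be instantiated directly in a harness (the blueprint wiring that consumes it is not part of this diff):

```python
# Hypothetical harness: exercise the new AgentConfig field directly instead of
# patching the default value in dimos/agents/agent.py.
from dimos.agents.agent import AgentConfig

config = AgentConfig(max_history_tokens=500)  # same value the script patches in
assert config.max_history_tokens == 500
```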
+echo "--------------------------------------------------------" diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index 37e1a4757c..d307aa4b04 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -43,6 +43,7 @@ class AgentConfig(ModuleConfig): system_prompt: str | None = SYSTEM_PROMPT model: str = "gpt-4o" model_fixture: str | None = None + max_history_tokens: int | None = None class Agent(Module[AgentConfig]): @@ -132,11 +133,31 @@ def _thread_loop(self) -> None: def _process_message( self, state_graph: CompiledStateGraph[Any, Any, Any, Any], message: BaseMessage ) -> None: + from langchain_core.messages import trim_messages + from dimos.agents.utils import estimate_tokens + self.agent_idle.publish(False) self._history.append(message) pretty_print_langchain_message(message) self.agent.publish(message) + if self.config.max_history_tokens is not None: + trimmed_history = trim_messages( + self._history, + max_tokens=self.config.max_history_tokens, + strategy="last", + token_counter=estimate_tokens, + include_system=True, + allow_partial=False, + ) + # Ensure it's a list since trim_messages can return other types sometimes + trimmed_history = list(trimmed_history) + else: + trimmed_history = self._history + + # We replace the internal history with the pruned one so it doesn't grow indefinitely in RAM + self._history = trimmed_history.copy() + for update in state_graph.stream({"messages": self._history}, stream_mode="updates"): for node_output in update.values(): for msg in node_output.get("messages", []): diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index 7c5eda5302..4800758450 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -45,6 +45,7 @@ class McpClientConfig(ModuleConfig): model: str = "gpt-4o" model_fixture: str | None = None mcp_server_url: str = "http://localhost:9990/mcp" + max_history_tokens: int | None = None class McpClient(Module[McpClientConfig]): @@ -213,11 +214,30 @@ def _thread_loop(self) -> None: def _process_message( self, state_graph: CompiledStateGraph[Any, Any, Any, Any], message: BaseMessage ) -> None: + from langchain_core.messages import trim_messages + from dimos.agents.utils import estimate_tokens + self.agent_idle.publish(False) self._history.append(message) pretty_print_langchain_message(message) self.agent.publish(message) + if self.config.max_history_tokens is not None: + trimmed_history = trim_messages( + self._history, + max_tokens=self.config.max_history_tokens, + strategy="last", + token_counter=estimate_tokens, + include_system=True, + allow_partial=False, + ) + # Ensure it's a list since trim_messages can return other types sometimes + trimmed_history = list(trimmed_history) + else: + trimmed_history = self._history + + self._history = trimmed_history.copy() + for update in state_graph.stream({"messages": self._history}, stream_mode="updates"): for node_output in update.values(): for msg in node_output.get("messages", []): diff --git a/dimos/agents/test_compaction.py b/dimos/agents/test_compaction.py new file mode 100644 index 0000000000..8a52a45c8c --- /dev/null +++ b/dimos/agents/test_compaction.py @@ -0,0 +1,73 @@ +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from dimos.agents.utils import estimate_tokens +from langchain_core.messages import trim_messages + +def test_estimate_tokens(): + msgs = [ + SystemMessage(content="system prompt"), + HumanMessage(content="Hello world!"), # length 12 + AIMessage(content="", 
diff --git a/dimos/agents/test_compaction.py b/dimos/agents/test_compaction.py
new file mode 100644
index 0000000000..8a52a45c8c
--- /dev/null
+++ b/dimos/agents/test_compaction.py
@@ -0,0 +1,73 @@
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
+from dimos.agents.utils import estimate_tokens
+from langchain_core.messages import trim_messages
+
+def test_estimate_tokens():
+    msgs = [
+        SystemMessage(content="system prompt"),
+        HumanMessage(content="Hello world!"),  # length 12
+        AIMessage(content="",
+                  tool_calls=[{"name": "test_tool", "args": {}, "id": "call_123"}])  # 1 tool call
+    ]
+
+    # System: 13 chars // 4 + 10 = 13
+    # Human:  12 chars // 4 + 10 = 13
+    # AI:      0 chars // 4 + 10, plus 50 per tool call = 60
+    # Total ~ 86
+    tokens = estimate_tokens(msgs)
+    assert tokens > 0
+    assert tokens < 150  # Roughly sane heuristic test
+
+def test_trim_history_preserves_system_prompt():
+    msgs = [
+        SystemMessage(content="System"),
+        HumanMessage(content="A" * 100),
+        AIMessage(content="B" * 100),
+        HumanMessage(content="C" * 100),
+    ]
+
+    # Need a threshold that drops something but keeps the system prompt.
+    # Each block of 100 chars is ~35 tokens (100 // 4 + 10).
+    # "System" is ~11 tokens, so the total is ~116.
+    # With the limit at 60, only System and "C" should survive.
+    trimmed = trim_messages(
+        msgs,
+        max_tokens=60,
+        strategy="last",
+        token_counter=estimate_tokens,
+        include_system=True,
+        allow_partial=False
+    )
+
+    assert isinstance(trimmed, list)
+    assert len(trimmed) == 2
+    assert trimmed[0].content == "System"
+    assert trimmed[1].content == "C" * 100
+
+def test_trim_history_does_not_break_toolcalls():
+    msgs = [
+        SystemMessage(content="System"),
+        HumanMessage(content="trigger tool"),
+        # Tool call sequence
+        AIMessage(content="", tool_calls=[{"name": "test_tool", "args": {}, "id": "call_123"}]),
+        ToolMessage(content="tool result", tool_call_id="call_123"),
+        # New prompt
+        HumanMessage(content="D" * 200),
+    ]
+
+    # Totals: System ~11, Human ~13, AI ~60, Tool ~12, Human ~60 => ~156.
+    # With max_tokens=75 the trimmer must keep "D" (60 tokens); only 15 remain,
+    # but the AIMessage + ToolMessage pair is ~72 tokens, so both must be dropped.
+    # It must not keep just the ToolMessage or just the AIMessage.
+    trimmed = trim_messages(
+        msgs,
+        max_tokens=75,
+        strategy="last",
+        token_counter=estimate_tokens,
+        include_system=True,
+        allow_partial=False
+    )
+
+    assert len(trimmed) == 2
+    assert trimmed[0].content == "System"
+    assert trimmed[1].content == "D" * 200
diff --git a/dimos/agents/utils.py b/dimos/agents/utils.py
index 5084c65b1f..245ef2dd1d 100644
--- a/dimos/agents/utils.py
+++ b/dimos/agents/utils.py
@@ -16,6 +16,17 @@
 from typing import Any
 
 from langchain_core.messages.base import BaseMessage
+import json
+
+def estimate_tokens(msgs: list[BaseMessage]) -> int:
+    """Safely estimates token counts for agent history compaction."""
+    count = 0
+    for m in msgs:
+        content_str = json.dumps(m.content) if not isinstance(m.content, str) else m.content
+        count += len(content_str) // 4 + 10
+        if getattr(m, "tool_calls", None):
+            count += 50 * len(m.tool_calls)  # type: ignore
+    return count
 
 
 from dimos.utils.logging_config import setup_logger
diff --git a/pyproject.toml b/pyproject.toml
index 017562a78a..858129b008 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -202,7 +202,7 @@ manipulation = [
 
     # Hardware SDKs
     "piper-sdk",
-    "pyrealsense2",
+    "pyrealsense2; sys_platform != 'darwin'",
     "xarm-python-sdk>=1.17.0",
 
     # Visualization (Optional)
@@ -226,7 +226,7 @@ cuda = [
     "cupy-cuda12x==13.6.0; platform_machine == 'x86_64'",
     "nvidia-nvimgcodec-cu12[all]; platform_machine == 'x86_64'",
    "onnxruntime-gpu>=1.17.1; platform_machine == 'x86_64'",  # Only versions supporting both cuda11 and cuda12
-    "ctransformers[cuda]==0.2.27",
+    "ctransformers[cuda]==0.2.27; sys_platform != 'darwin'",
     "xformers>=0.0.20; platform_machine == 'x86_64'",
 ]
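The pyproject.toml changes add PEP 508 environment markers so that macOS installs skip pyrealsense2 and the CUDA build of ctransformers. A quick way to sanity-check markers like these is the `packaging` library (my choice here; any PEP 508 evaluator behaves the same):

```python
# Sketch: evaluate the new environment markers the way a resolver would.
from packaging.markers import Marker

marker = Marker("sys_platform != 'darwin'")
print(marker.evaluate({"sys_platform": "darwin"}))  # False -> dependency skipped on macOS
print(marker.evaluate({"sys_platform": "linux"}))   # True  -> dependency installed on Linux
```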