diff --git a/docs/reference/tools.mdx b/docs/reference/tools.mdx index bf5b208c..14e13211 100644 --- a/docs/reference/tools.mdx +++ b/docs/reference/tools.mdx @@ -152,6 +152,10 @@ async with hud.eval(task) as ctx: # All subagent activity appears in this single trace ``` +To create a separate trace per AgentTool call (useful for parallel subagents), +set `trace_subagent=True`. The tool result includes the sub-trace id in +`result.meta["trace_id"]`. + **See Also:** [Ops Diagnostics Cookbook](/cookbook/ops-diagnostics) for a complete hierarchical agent example. --- diff --git a/hud/__init__.py b/hud/__init__.py index 1fb747b1..828f43fd 100644 --- a/hud/__init__.py +++ b/hud/__init__.py @@ -13,6 +13,7 @@ from .eval import EvalContext from .eval import run_eval as eval from .telemetry.instrument import instrument +from .telemetry.parallel_group import parallel_agent_group def trace(*args: object, **kwargs: object) -> EvalContext: @@ -34,6 +35,7 @@ def trace(*args: object, **kwargs: object) -> EvalContext: "EvalContext", "eval", "instrument", + "parallel_agent_group", "trace", # Deprecated alias for eval ] diff --git a/hud/telemetry/__init__.py b/hud/telemetry/__init__.py index e237673b..0a0c84c2 100644 --- a/hud/telemetry/__init__.py +++ b/hud/telemetry/__init__.py @@ -3,6 +3,7 @@ This module provides: - @instrument decorator for recording function calls - High-performance span export to HUD API +- parallel_agent_group context manager for tracking parallel agent execution Usage: import hud @@ -14,14 +15,33 @@ async def my_function(): # Within an eval context, calls are recorded async with hud.eval(task) as ctx: result = await my_function() + + # Track parallel agents + from hud.telemetry import parallel_agent_group + + async with parallel_agent_group( + title="Deep Research", + description="Collect profiles...", + agents=[{"name": "Worker 1"}, {"name": "Worker 2"}], + ) as group: + # Run agents in parallel... + pass """ from hud.telemetry.exporter import flush, queue_span, shutdown from hud.telemetry.instrument import instrument +from hud.telemetry.parallel_group import ( + ParallelAgentGroup, + ParallelAgentInfo, + parallel_agent_group, +) __all__ = [ + "ParallelAgentGroup", + "ParallelAgentInfo", "flush", "instrument", + "parallel_agent_group", "queue_span", "shutdown", ] diff --git a/hud/telemetry/parallel_group.py b/hud/telemetry/parallel_group.py new file mode 100644 index 00000000..950eb8be --- /dev/null +++ b/hud/telemetry/parallel_group.py @@ -0,0 +1,310 @@ +"""Parallel agent group telemetry for HUD. + +This module provides a context manager for tracking parallel agent execution +with real-time progress updates in the HUD platform UI. + +Usage: + from hud.telemetry import parallel_agent_group + + async with parallel_agent_group( + title="Deep Research", + description="Collect profiles...", + agents=[{"name": "Worker 1"}, {"name": "Worker 2"}], + ) as group: + async def run_worker(agent_info): + group.update_status(agent_info.id, "running") + try: + result = await do_work() + group.mark_completed(agent_info.id) + return result + except Exception: + group.mark_failed(agent_info.id) + raise + + await asyncio.gather(*[run_worker(a) for a in group.agents]) +""" + +from __future__ import annotations + +import uuid +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any, Literal + +from hud.telemetry.exporter import queue_span +from hud.types import TraceStep + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + +AgentStatus = Literal["pending", "running", "completed", "failed"] + + +def _now_iso() -> str: + """Get current time as ISO-8601 string.""" + return datetime.now(UTC).isoformat().replace("+00:00", "Z") + + +def _normalize_trace_id(trace_id: str) -> str: + """Normalize trace_id to 32-character hex string.""" + clean = trace_id.replace("-", "") + return clean[:32].ljust(32, "0") + + +def _get_trace_id() -> str | None: + """Get current trace ID from eval context.""" + from hud.eval.context import get_current_trace_id + + return get_current_trace_id() + + +@dataclass +class ParallelAgentInfo: + """Individual agent in a parallel group.""" + + id: str + name: str + status: AgentStatus = "pending" + trace_id: str | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "id": self.id, + "name": self.name, + "status": self.status, + "trace_id": self.trace_id, + } + + def to_status_dict(self) -> dict[str, Any]: + """Convert to minimal status dictionary.""" + return { + "id": self.id, + "status": self.status, + } + + +@dataclass +class ParallelAgentGroup: + """Manages a group of parallel agents with telemetry tracking. + + This class tracks the status of multiple agents running in parallel + and emits telemetry spans to the HUD platform. + """ + + title: str + description: str + agents: list[ParallelAgentInfo] = field(default_factory=list) + _span_id: str = field(default_factory=lambda: uuid.uuid4().hex[:16]) + _start_time: str = field(default_factory=_now_iso) + _task_run_id: str | None = field(default=None) + + def update_status( + self, + agent_id: str, + status: AgentStatus, + trace_id: str | None = None, + ) -> None: + """Update the status of an agent. + + Args: + agent_id: The ID of the agent to update + status: New status ("pending", "running", "completed", "failed") + trace_id: Optional trace ID linking to the agent's execution trace + """ + for agent in self.agents: + if agent.id == agent_id: + agent.status = status + if trace_id: + agent.trace_id = trace_id + self._emit_update() + return + raise ValueError(f"Agent with id '{agent_id}' not found in group") + + def mark_running(self, agent_id: str, trace_id: str | None = None) -> None: + """Mark an agent as running. + + Args: + agent_id: The ID of the agent + trace_id: Optional trace ID for the agent's execution + """ + self.update_status(agent_id, "running", trace_id) + + def mark_completed(self, agent_id: str, trace_id: str | None = None) -> None: + """Mark an agent as completed. + + Args: + agent_id: The ID of the agent + trace_id: Optional trace ID for the agent's execution + """ + self.update_status(agent_id, "completed", trace_id) + + def mark_failed(self, agent_id: str, trace_id: str | None = None) -> None: + """Mark an agent as failed. + + Args: + agent_id: The ID of the agent + trace_id: Optional trace ID for the agent's execution + """ + self.update_status(agent_id, "failed", trace_id) + + @property + def completed_count(self) -> int: + """Number of agents that have completed (successfully or failed).""" + return sum(1 for a in self.agents if a.status in ("completed", "failed")) + + @property + def total_count(self) -> int: + """Total number of agents in the group.""" + return len(self.agents) + + @property + def success_count(self) -> int: + """Number of agents that completed successfully.""" + return sum(1 for a in self.agents if a.status == "completed") + + @property + def failure_count(self) -> int: + """Number of agents that failed.""" + return sum(1 for a in self.agents if a.status == "failed") + + def _build_span(self, final: bool = False) -> dict[str, Any]: + """Build a HudSpan-compatible span record.""" + task_run_id = self._task_run_id or _get_trace_id() + if not task_run_id: + return {} + + now = _now_iso() + end_time = now + + # Build attributes using TraceStep + attributes = TraceStep( + task_run_id=task_run_id, + category="parallel-agent-group", + type="CLIENT", + start_timestamp=self._start_time, + end_timestamp=end_time, + request={ + "title": self.title, + "description": self.description, + "agents": [a.to_dict() for a in self.agents], + }, + result={ + "completed": self.completed_count, + "total": self.total_count, + "success": self.success_count, + "failed": self.failure_count, + "agents": [a.to_status_dict() for a in self.agents], + }, + ) + + # Determine status + has_failures = self.failure_count > 0 + status_code = "ERROR" if has_failures and final else "OK" + + span: dict[str, Any] = { + "name": "parallel_agent_group", + "trace_id": _normalize_trace_id(task_run_id), + "span_id": self._span_id, + "parent_span_id": None, + "start_time": self._start_time, + "end_time": end_time, + "status_code": status_code, + "status_message": None, + "attributes": attributes.model_dump(mode="json", exclude_none=True), + "internal_type": "parallel-agent-group", + } + + return span + + def _emit_update(self) -> None: + """Emit a span update to the telemetry backend.""" + span = self._build_span(final=False) + if span: + queue_span(span) + + def _emit_final(self) -> None: + """Emit the final span when the group completes.""" + span = self._build_span(final=True) + if span: + queue_span(span) + + +@asynccontextmanager +async def parallel_agent_group( + title: str, + description: str, + agents: list[dict[str, str]], +) -> AsyncIterator[ParallelAgentGroup]: + """Context manager for parallel agent execution with automatic telemetry. + + Creates a ParallelAgentGroup that tracks multiple agents running in parallel. + Emits spans with category="parallel-agent-group" that the HUD platform + renders as a visual card showing all agents and their progress. + + Args: + title: Display title for the group (e.g., "Deep Research") + description: Description of the parallel task + agents: List of agent configurations, each with at least a "name" key + + Yields: + ParallelAgentGroup instance for tracking agent status + + Example: + async with parallel_agent_group( + title="Deep Research", + description="Collect profiles for 250 researchers", + agents=[{"name": f"Worker {i}"} for i in range(10)], + ) as group: + async def run_worker(agent_info): + group.mark_running(agent_info.id) + try: + result = await do_research(agent_info.name) + group.mark_completed(agent_info.id) + return result + except Exception: + group.mark_failed(agent_info.id) + raise + + results = await asyncio.gather( + *[run_worker(a) for a in group.agents], + return_exceptions=True, + ) + """ + # Create agent info objects + agent_infos = [ + ParallelAgentInfo( + id=str(uuid.uuid4()), + name=agent_config.get("name", f"Agent {i}"), + status="pending", + ) + for i, agent_config in enumerate(agents) + ] + + # Create the group + task_run_id = _get_trace_id() + group = ParallelAgentGroup( + title=title, + description=description, + agents=agent_infos, + _task_run_id=task_run_id, + ) + + # Emit initial span + group._emit_update() + + try: + yield group + finally: + # Emit final span with completion status + group._emit_final() + + +__all__ = [ + "AgentStatus", + "ParallelAgentGroup", + "ParallelAgentInfo", + "parallel_agent_group", +] diff --git a/hud/telemetry/tests/test_parallel_group.py b/hud/telemetry/tests/test_parallel_group.py new file mode 100644 index 00000000..b83a94ee --- /dev/null +++ b/hud/telemetry/tests/test_parallel_group.py @@ -0,0 +1,366 @@ +"""Tests for hud.telemetry.parallel_group module.""" + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import patch + +import pytest + +from hud.telemetry.parallel_group import ( + ParallelAgentGroup, + ParallelAgentInfo, + parallel_agent_group, +) + + +class TestParallelAgentInfo: + """Tests for ParallelAgentInfo dataclass.""" + + def test_default_values(self) -> None: + """Test default values are set correctly.""" + agent = ParallelAgentInfo(id="test-id", name="Test Agent") + assert agent.id == "test-id" + assert agent.name == "Test Agent" + assert agent.status == "pending" + assert agent.trace_id is None + + def test_to_dict(self) -> None: + """Test to_dict serialization.""" + agent = ParallelAgentInfo( + id="agent-1", + name="Worker 1", + status="completed", + trace_id="trace-123", + ) + result = agent.to_dict() + assert result == { + "id": "agent-1", + "name": "Worker 1", + "status": "completed", + "trace_id": "trace-123", + } + + def test_to_status_dict(self) -> None: + """Test to_status_dict minimal serialization.""" + agent = ParallelAgentInfo( + id="agent-1", + name="Worker 1", + status="running", + trace_id="trace-123", + ) + result = agent.to_status_dict() + assert result == { + "id": "agent-1", + "status": "running", + } + + +class TestParallelAgentGroup: + """Tests for ParallelAgentGroup class.""" + + def test_creation(self) -> None: + """Test group creation with agents.""" + agents = [ + ParallelAgentInfo(id="a1", name="Agent 1"), + ParallelAgentInfo(id="a2", name="Agent 2"), + ] + group = ParallelAgentGroup( + title="Test Group", + description="Test description", + agents=agents, + ) + assert group.title == "Test Group" + assert group.description == "Test description" + assert len(group.agents) == 2 + assert group.total_count == 2 + assert group.completed_count == 0 + + def test_update_status(self) -> None: + """Test updating agent status.""" + agents = [ + ParallelAgentInfo(id="a1", name="Agent 1"), + ParallelAgentInfo(id="a2", name="Agent 2"), + ] + group = ParallelAgentGroup( + title="Test", + description="Test", + agents=agents, + ) + + with patch.object(group, "_emit_update"): + group.update_status("a1", "running") + assert group.agents[0].status == "running" + + group.update_status("a1", "completed", trace_id="trace-123") + assert group.agents[0].status == "completed" + assert group.agents[0].trace_id == "trace-123" + + def test_update_status_invalid_id(self) -> None: + """Test updating non-existent agent raises error.""" + group = ParallelAgentGroup( + title="Test", + description="Test", + agents=[ParallelAgentInfo(id="a1", name="Agent 1")], + ) + + with pytest.raises(ValueError, match="Agent with id 'invalid' not found"): + group.update_status("invalid", "running") + + def test_mark_helpers(self) -> None: + """Test mark_running, mark_completed, mark_failed helpers.""" + agents = [ + ParallelAgentInfo(id="a1", name="Agent 1"), + ParallelAgentInfo(id="a2", name="Agent 2"), + ParallelAgentInfo(id="a3", name="Agent 3"), + ] + group = ParallelAgentGroup( + title="Test", + description="Test", + agents=agents, + ) + + with patch.object(group, "_emit_update"): + group.mark_running("a1") + assert group.agents[0].status == "running" + + group.mark_completed("a2", trace_id="trace-2") + assert group.agents[1].status == "completed" + assert group.agents[1].trace_id == "trace-2" + + group.mark_failed("a3") + assert group.agents[2].status == "failed" + + def test_count_properties(self) -> None: + """Test count properties calculate correctly.""" + agents = [ + ParallelAgentInfo(id="a1", name="Agent 1", status="completed"), + ParallelAgentInfo(id="a2", name="Agent 2", status="failed"), + ParallelAgentInfo(id="a3", name="Agent 3", status="running"), + ParallelAgentInfo(id="a4", name="Agent 4", status="pending"), + ] + group = ParallelAgentGroup( + title="Test", + description="Test", + agents=agents, + ) + + assert group.total_count == 4 + assert group.completed_count == 2 # completed + failed + assert group.success_count == 1 + assert group.failure_count == 1 + + def test_build_span_without_trace_id(self) -> None: + """Test _build_span returns empty dict when no trace_id.""" + group = ParallelAgentGroup( + title="Test", + description="Test", + agents=[], + ) + + with patch("hud.telemetry.parallel_group._get_trace_id", return_value=None): + span = group._build_span() + assert span == {} + + def test_build_span_with_trace_id(self) -> None: + """Test _build_span builds correct span structure.""" + agents = [ + ParallelAgentInfo(id="a1", name="Agent 1", status="completed"), + ParallelAgentInfo(id="a2", name="Agent 2", status="running"), + ] + group = ParallelAgentGroup( + title="Test Group", + description="Test description", + agents=agents, + _task_run_id="test-trace-id-123456789012", + ) + + span = group._build_span(final=False) + + assert span["name"] == "parallel_agent_group" + # Trace ID is normalized to 32 hex chars (dashes removed, padded/truncated) + assert len(span["trace_id"]) == 32 + assert span["trace_id"].startswith("testtraceid") + assert span["status_code"] == "OK" + assert span["internal_type"] == "parallel-agent-group" + + attrs = span["attributes"] + assert attrs["category"] == "parallel-agent-group" + assert attrs["request"]["title"] == "Test Group" + assert attrs["request"]["description"] == "Test description" + assert len(attrs["request"]["agents"]) == 2 + assert attrs["result"]["completed"] == 1 + assert attrs["result"]["total"] == 2 + + def test_build_span_final_with_failures(self) -> None: + """Test _build_span sets ERROR status when there are failures.""" + agents = [ + ParallelAgentInfo(id="a1", name="Agent 1", status="failed"), + ] + group = ParallelAgentGroup( + title="Test", + description="Test", + agents=agents, + _task_run_id="test-trace-id-123456789012", + ) + + span = group._build_span(final=True) + assert span["status_code"] == "ERROR" + + +class TestParallelAgentGroupContextManager: + """Tests for parallel_agent_group context manager.""" + + @pytest.mark.asyncio + async def test_basic_usage(self) -> None: + """Test basic context manager usage.""" + queued_spans: list[dict[str, Any]] = [] + + with ( + patch( + "hud.telemetry.parallel_group.queue_span", + side_effect=lambda s: queued_spans.append(s), + ), + patch( + "hud.telemetry.parallel_group._get_trace_id", + return_value="test-trace-12345678901234567890", + ), + ): + async with parallel_agent_group( + title="Test Group", + description="Test description", + agents=[{"name": "Worker 1"}, {"name": "Worker 2"}], + ) as group: + assert len(group.agents) == 2 + assert group.agents[0].name == "Worker 1" + assert group.agents[1].name == "Worker 2" + assert all(a.status == "pending" for a in group.agents) + + # Should have emitted at least 2 spans (initial + final) + assert len(queued_spans) >= 2 + + @pytest.mark.asyncio + async def test_status_updates_emit_spans(self) -> None: + """Test that status updates emit spans.""" + queued_spans: list[dict[str, Any]] = [] + + with ( + patch( + "hud.telemetry.parallel_group.queue_span", + side_effect=lambda s: queued_spans.append(s), + ), + patch( + "hud.telemetry.parallel_group._get_trace_id", + return_value="test-trace-12345678901234567890", + ), + ): + async with parallel_agent_group( + title="Test", + description="Test", + agents=[{"name": "Worker 1"}], + ) as group: + initial_count = len(queued_spans) + group.mark_running(group.agents[0].id) + assert len(queued_spans) == initial_count + 1 + + group.mark_completed(group.agents[0].id) + assert len(queued_spans) == initial_count + 2 + + # Final span emitted on exit + assert len(queued_spans) >= 3 + + @pytest.mark.asyncio + async def test_agent_name_defaults(self) -> None: + """Test that agent names default correctly.""" + with ( + patch("hud.telemetry.parallel_group.queue_span"), + patch( + "hud.telemetry.parallel_group._get_trace_id", + return_value="test-trace-12345678901234567890", + ), + ): + async with parallel_agent_group( + title="Test", + description="Test", + agents=[{}, {"name": "Custom Name"}], + ) as group: + assert group.agents[0].name == "Agent 0" + assert group.agents[1].name == "Custom Name" + + @pytest.mark.asyncio + async def test_parallel_execution_pattern(self) -> None: + """Test typical parallel execution pattern.""" + results: list[str] = [] + + with ( + patch("hud.telemetry.parallel_group.queue_span"), + patch( + "hud.telemetry.parallel_group._get_trace_id", + return_value="test-trace-12345678901234567890", + ), + ): + async with parallel_agent_group( + title="Parallel Work", + description="Do work in parallel", + agents=[{"name": f"Worker {i}"} for i in range(3)], + ) as group: + + async def do_work(agent_info: ParallelAgentInfo) -> str: + group.mark_running(agent_info.id) + await asyncio.sleep(0.01) # Simulate work + group.mark_completed(agent_info.id) + return f"Result from {agent_info.name}" + + results = await asyncio.gather(*[do_work(a) for a in group.agents]) + + assert len(results) == 3 + assert all("Result from Worker" in r for r in results) + + @pytest.mark.asyncio + async def test_exception_handling(self) -> None: + """Test that exceptions propagate and final span is still emitted.""" + queued_spans: list[dict[str, Any]] = [] + + with ( + patch( + "hud.telemetry.parallel_group.queue_span", + side_effect=lambda s: queued_spans.append(s), + ), + patch( + "hud.telemetry.parallel_group._get_trace_id", + return_value="test-trace-12345678901234567890", + ), + pytest.raises(ValueError, match="Test error"), + ): + async with parallel_agent_group( + title="Test", + description="Test", + agents=[{"name": "Worker"}], + ): + raise ValueError("Test error") + + # Final span should still be emitted + assert len(queued_spans) >= 2 + + +class TestModuleExports: + """Test that module exports are correct.""" + + def test_imports(self) -> None: + """Test that all expected symbols are importable.""" + from hud.telemetry import ( + ParallelAgentGroup, + ParallelAgentInfo, + parallel_agent_group, + ) + + assert ParallelAgentGroup is not None + assert ParallelAgentInfo is not None + assert parallel_agent_group is not None + + def test_hud_import(self) -> None: + """Test that parallel_agent_group is importable from hud.""" + from hud import parallel_agent_group + + assert parallel_agent_group is not None diff --git a/hud/tests/test_init_module.py b/hud/tests/test_init_module.py index 607dbfae..a4592bfa 100644 --- a/hud/tests/test_init_module.py +++ b/hud/tests/test_init_module.py @@ -25,6 +25,7 @@ def test_all_exports(self): "EvalContext", "eval", "instrument", + "parallel_agent_group", "trace", # Deprecated alias for eval ] diff --git a/hud/tools/agent.py b/hud/tools/agent.py index 2085ce25..b61b8a7f 100644 --- a/hud/tools/agent.py +++ b/hud/tools/agent.py @@ -2,7 +2,9 @@ from __future__ import annotations +import contextlib import inspect +import uuid from typing import TYPE_CHECKING, Any, Union, get_args, get_origin from fastmcp.tools.tool import FunctionTool, ToolResult @@ -88,6 +90,7 @@ def __init__( name: str | None = None, description: str | None = None, trace: bool = False, + trace_subagent: bool = False, ) -> None: if not model and agent is None: raise ValueError("Must provide either 'model' or 'agent'") @@ -99,6 +102,7 @@ def __init__( self._agent_cls = agent self._agent_params = agent_params or {} self._trace = trace + self._trace_subagent = trace_subagent # Get visible params from scenario function self._visible_params: set[str] = set() @@ -196,19 +200,42 @@ async def __call__(self, **kwargs: Any) -> ToolResult: # Tool calls are still recorded via the shared trace_id's context is_nested = parent_trace_id is not None - # Trace if explicitly requested AND not nested (nested uses parent trace) - should_trace = self._trace and not is_nested + # Decide how this tool call should be traced: + # - inherit: reuse parent trace_id but skip enter/exit registration + # - new: start a new trace for this subagent + # - none: no tracing + if is_nested: + trace_mode = "new" if self._trace_subagent else "inherit" + else: + trace_mode = "new" if (self._trace or self._trace_subagent) else "none" + + trace_id: str | None + trace_enabled: bool + if trace_mode == "inherit": + trace_id = parent_trace_id + trace_enabled = False + elif trace_mode == "new": + trace_id = str(uuid.uuid4()) + trace_enabled = True + else: + trace_id = None + trace_enabled = False # Wrap execution with instrumentation to mark as subagent # Platform uses category="subagent" to detect and render subagent tool calls @instrument(category="subagent", name=self.name) async def _run_subagent() -> ToolResult: + nonlocal trace_id async with run_eval( task, - trace=should_trace, - trace_id=parent_trace_id, + trace=trace_enabled, + trace_id=trace_id, quiet=True, ) as ctx: + # Only update trace_id from ctx when creating a new trace (not reusing parent) + if trace_mode == "new": + trace_id = ctx.trace_id + if self._model: from hud.agents import create_agent @@ -218,6 +245,15 @@ async def _run_subagent() -> ToolResult: result = await agent.run(ctx) content = result.content if hasattr(result, "content") and result.content else "" - return ToolResult(content=[TextContent(type="text", text=content)]) - - return await _run_subagent() + return ToolResult( + content=[TextContent(type="text", text=content)], + meta={"trace_id": trace_id} if trace_id else None, + ) + + try: + return await _run_subagent() + except Exception as e: + if trace_id: + with contextlib.suppress(Exception): + e.trace_id = trace_id # type: ignore[attr-defined] + raise