diff --git a/bugbug/tools/build_repair/agent.py b/bugbug/tools/build_repair/agent.py index 9eada330db..bdb17e3d3b 100644 --- a/bugbug/tools/build_repair/agent.py +++ b/bugbug/tools/build_repair/agent.py @@ -11,6 +11,13 @@ from claude_agent_sdk import ClaudeAgentOptions, ResultMessage, query from pydantic import BaseModel, Field +from tenacity import ( + retry, + retry_if_exception, + retry_if_exception_message, + stop_after_attempt, + wait_exponential_jitter, +) from bugbug.tools.base import GenerativeModelTool from bugbug.tools.build_repair.config import ( @@ -20,11 +27,14 @@ FIREFOX_MCP_URL, FIX_MODEL, SANDBOX_CONFIG, + VERIFY_ALLOWED_TOOLS, + VERIFY_MODEL, ) from bugbug.tools.build_repair.prompts import ( ANALYSIS_TEMPLATE, EVAL_PROMPT, FIX_TEMPLATE, + VERIFY_TEMPLATE, ) logger = getLogger(__name__) @@ -44,7 +54,16 @@ class BuildFailure(BaseModel): ) -class AgentResponse(BaseModel): +class UsageStats(BaseModel): + cost_usd: float = Field(default=0.0) + num_turns: int = Field(default=0) + input_tokens: int = Field(default=0) + output_tokens: int = Field(default=0) + cache_read_input_tokens: int = Field(default=0) + cache_creation_input_tokens: int = Field(default=0) + + +class AgentResponse(UsageStats): """Output from a build repair run, including analysis, diff, cost, and build results.""" summary: str = Field(default="") @@ -67,6 +86,28 @@ class AgentResponse(BaseModel): stage2_transcript: list[dict] = Field(default_factory=list) +class GroundTruth(BaseModel): + gh_fix_commits: list[str] = Field( + description="Git commit hashes of the ground truth fix." + ) + + +class Judgment(BaseModel): + analysis_correct: bool + analysis_quality: float + analysis_explanation: str + fix_matches_ground_truth: bool + fix_quality: float + fix_explanation: str + fix_acceptance_probability: float + fix_acceptance_explanation: str + + +class VerifyResponse(UsageStats): + judgment: Judgment | None = Field(default=None) + verification_transcript: list[dict] = Field(default_factory=list) + + class BuildRepairTool(GenerativeModelTool): """Two-stage build repair agent using Claude Agent SDK. @@ -82,12 +123,14 @@ def __init__( eval_mode: bool = False, analysis_model: str = ANALYSIS_MODEL, fix_model: str = FIX_MODEL, + verify_model: str = VERIFY_MODEL, ) -> None: self.eval_mode = eval_mode self.target_software = target_software self.analysis_only = analysis_only self.analysis_model = analysis_model self.fix_model = fix_model + self.verify_model = verify_model @classmethod def create(cls, **kwargs): @@ -128,21 +171,30 @@ async def _run_stage( result_data: dict = {} usage: dict = {} - if on_message: - on_message( - stage_name, - { - "type": "stage_start", - "prompt": prompt, - "model": model, - }, - ) - try: + @retry( + retry=( + retry_if_exception_message(match="Control request timeout") + | retry_if_exception_message(match="overloaded") + | retry_if_exception_message(match="529") + | retry_if_exception_message(match="exit code") + | retry_if_exception( + lambda e: isinstance(e, (TimeoutError, ConnectionError, OSError)) + ) + ), + stop=stop_after_attempt(5), + wait=wait_exponential_jitter(initial=2, max=60, jitter=5), + before_sleep=lambda rs: logger.warning( + f"Bug {bug_id}: {stage_name} transient error " + f"(attempt {rs.attempt_number}/5), retrying: {rs.outcome.exception()}" + ), + reraise=True, + ) + async def _query(): + nonlocal cost, turns, usage, result_data async for message in query(prompt=prompt, options=options): serialized = self._serialize_message(message) transcript.append(serialized) - logger.info(f"Bug {bug_id}: {stage_name} [{serialized['type']}]") - logger.debug(f"Bug {bug_id}: {stage_name} detail: {serialized}") + logger.debug(f"Bug {bug_id}: {stage_name} [{serialized['type']}]") if on_message: on_message(stage_name, serialized) if isinstance(message, ResultMessage): @@ -150,6 +202,18 @@ async def _run_stage( turns += message.num_turns or 0 usage = getattr(message, "usage", {}) or {} result_data = serialized + + if on_message: + on_message( + stage_name, + { + "type": "stage_start", + "prompt": prompt, + "model": model, + }, + ) + try: + await _query() finally: if on_message: on_message( @@ -233,6 +297,7 @@ async def run( analysis_prompt = ANALYSIS_TEMPLATE.format( bug_id=failure.bug_id, target_software=self.target_software, + worktree_path=worktree_path, eval=EVAL_PROMPT if self.eval_mode else "", ) try: @@ -305,7 +370,10 @@ async def run( mcp_servers=mcp_servers, ) fix_prompt = FIX_TEMPLATE.format( - bug_id=failure.bug_id, eval=EVAL_PROMPT if self.eval_mode else "" + target_software=self.target_software, + bug_id=failure.bug_id, + worktree_path=worktree_path, + eval=EVAL_PROMPT if self.eval_mode else "", ) try: ( @@ -410,3 +478,82 @@ async def run( stage1_transcript=stage1_transcript, stage2_transcript=stage2_transcript, ) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential_jitter(initial=2, max=30, jitter=5), + before_sleep=lambda rs: logger.warning( + f"Verification failed (attempt {rs.attempt_number}/3), " + f"retrying: {rs.outcome.exception()}" + ), + reraise=True, + ) + async def verify( + self, + failure: BuildFailure, + agent_diff: str, + ground_truth: GroundTruth, + worktree_path: Path, + on_message: Callable[[str, dict], None] | None = None, + ) -> VerifyResponse: + out_dir = worktree_path / "repair_agent" / "out" / str(failure.bug_id) + out_dir.mkdir(parents=True, exist_ok=True) + (out_dir / "agent_fix.diff").write_text(agent_diff, encoding="utf-8") + + gt_commits = " ".join(ground_truth.gh_fix_commits) + prompt = VERIFY_TEMPLATE.format( + target_software=self.target_software, + bug_id=failure.bug_id, + failure_commit=failure.git_commit, + ground_truth_commits=gt_commits, + worktree_path=worktree_path, + ) + + options = ClaudeAgentOptions( + model=self.verify_model, + cwd=str(worktree_path), + allowed_tools=VERIFY_ALLOWED_TOOLS, + disallowed_tools=["AskUserQuestion", "Task"], + sandbox=SANDBOX_CONFIG, + permission_mode="acceptEdits", + effort="high", + output_format={ + "type": "json_schema", + "schema": Judgment.model_json_schema(), + }, + ) + + logger.info( + f"Bug {failure.bug_id}: starting verification stage " + f"(model={self.verify_model}, ground_truth={gt_commits})" + ) + + transcript, cost, turns, usage = await self._run_stage( + "verification", + prompt, + self.verify_model, + options, + failure.bug_id, + on_message, + ) + + judgment: Judgment | None = None + for msg in reversed(transcript): + if msg.get("structured_output"): + judgment = Judgment.model_validate(msg["structured_output"]) + break + + if judgment is None: + result_msgs = [m for m in transcript if m.get("type") == "ResultMessage"] + raise RuntimeError( + f"Bug {failure.bug_id}: verification produced no structured output. " + f"Result messages: {result_msgs}" + ) + + return VerifyResponse( + judgment=judgment, + cost_usd=cost, + num_turns=turns, + verification_transcript=transcript, + **self._usage_fields(usage), + ) diff --git a/bugbug/tools/build_repair/config.py b/bugbug/tools/build_repair/config.py index f4bde13a3a..a3f69ef34e 100644 --- a/bugbug/tools/build_repair/config.py +++ b/bugbug/tools/build_repair/config.py @@ -9,6 +9,7 @@ ANALYSIS_MODEL = "claude-opus-4-6" FIX_MODEL = "claude-opus-4-6" +VERIFY_MODEL = "claude-opus-4-6" DEFAULT_MAX_TURNS = 80 WORKTREE_BASE_DIR = "/tmp/build_repair_worktrees" TRY_PUSH_TIMEOUT_SECONDS = 7200 @@ -32,6 +33,17 @@ "claude-opus-4-20250514": date(2025, 3, 1), } +VERIFY_ALLOWED_TOOLS = [ + "Read", + "Bash(git show:*)", + "Bash(git log:*)", + "Bash(git diff:*)", + "Bash(find:*)", + "Bash(grep:*)", + "WebFetch(domain:firefox-source-docs.mozilla.org)", + "WebFetch(domain:searchfox.org)", +] + ALLOWED_TOOLS = [ "Edit(~/.mozbuild)", "Edit(~/.cache/uv)", diff --git a/bugbug/tools/build_repair/prompts.py b/bugbug/tools/build_repair/prompts.py index cdab7e11e0..ee166620c2 100644 --- a/bugbug/tools/build_repair/prompts.py +++ b/bugbug/tools/build_repair/prompts.py @@ -15,22 +15,24 @@ 1. Git diff for the last commit 2. Bugzilla bug description 3. Taskcluster build failure logs -The files with bug description and logs are located at @repair_agent/in/{bug_id} +The files with bug description and logs are located at {worktree_path}/repair_agent/in/{bug_id} Create three separate documents: -1. repair_agent/out/{bug_id}/analysis.md with your detailed analysis on what caused the issues -2. repair_agent/out/{bug_id}/planning.md with a fixing plan -3. repair_agent/out/{bug_id}/summary.md with a brief one paragraph summary of analysis and planning that can point a developer in the right direction +1. {worktree_path}/repair_agent/out/{bug_id}/analysis.md with your detailed analysis on what caused the issues +2. {worktree_path}/repair_agent/out/{bug_id}/planning.md with a fixing plan +3. {worktree_path}/repair_agent/out/{bug_id}/summary.md with a brief one paragraph summary of analysis and planning that can point a developer in the right direction Do not prompt to edit those documents. {eval} -Do not write any code yet. Work fully autonomously, do not ask any questions. Think hard. +Do not write any code yet. Work fully autonomously, do not ask any questions. """ -FIX_TEMPLATE = """Read the following files and implement a fix of the failure: -1. repair_agent/out/{bug_id}/analysis.md with your detailed analysis on what caused the issues -2. repair_agent/out/{bug_id}/planning.md with a fixing plan +FIX_TEMPLATE = """You are an expert {target_software} engineer tasked with analyzing and fixing a build failure. + +Read the following files and implement a fix of the failure: +1. {worktree_path}/repair_agent/out/{bug_id}/analysis.md with your detailed analysis on what caused the issues +2. {worktree_path}/repair_agent/out/{bug_id}/planning.md with a fixing plan {eval} Do not prompt to edit files. Work fully autonomously, do not ask any questions. Use all allowed tools without prompting. @@ -40,3 +42,36 @@ Do not request bug info from Bugzilla or Phabricator. Use only the provided file with bug description. Do not look at git commits other than the specified last commit. """ + +VERIFY_TEMPLATE = """You are an expert {target_software} code reviewer evaluating an automated build repair agent's work. + +Examine the relevant commits using git: +- Failure commit (broke the build): {failure_commit} +- Ground truth fix commit(s) (the real fix that was landed): {ground_truth_commits} + +Inspect each commit's changes and read the repair agent's input/output files: +- {worktree_path}/repair_agent/in/{bug_id}/bug_description.md +- {worktree_path}/repair_agent/in/{bug_id}/build_failure_logs.md +- {worktree_path}/repair_agent/out/{bug_id}/analysis.md +- {worktree_path}/repair_agent/out/{bug_id}/summary.md +- {worktree_path}/repair_agent/out/{bug_id}/agent_fix.diff (may be empty if no fix was produced) + +Evaluate the agent's work on two dimensions: + +ANALYSIS: +- Did the agent correctly identify the root cause of the build failure? +- How thorough and accurate is the analysis? + +FIX: +- Does the agent's fix address the same files/functions as the ground truth? +- Is the fix semantically equivalent or close to the ground truth? +- Would the fix be acceptable in code review as-is? + +Guidelines: +- If agent_fix.diff is empty, set fix_matches_ground_truth=false, fix_quality=0.0, fix_acceptance_probability=0.0 +- A fix can be correct even if it differs syntactically from ground truth -- focus on semantic equivalence +- analysis_correct should be true if the agent found the right root cause, even if the explanation is imperfect +- Be calibrated: 0.5 means genuinely uncertain, not a default score + +Work autonomously, do not ask questions. +""" diff --git a/bugbug/tools/build_repair/scorer.py b/bugbug/tools/build_repair/scorer.py index 566b384a6a..2fc702258c 100644 --- a/bugbug/tools/build_repair/scorer.py +++ b/bugbug/tools/build_repair/scorer.py @@ -141,23 +141,67 @@ def summarize(self, score_rows: list[dict]) -> dict: class LLMFixMatchingScorer(weave.Scorer): - """Scaffold for LLM-as-a-judge comparing agent fix to ground truth. + """Aggregates LLM-as-a-judge verify results from the model output.""" - Implementation deferred. Will use a non-Claude LLM to semantically - compare the agent's diff against the ground truth fix commit. - """ + num_trials: int = 1 @weave.op() - async def score(self, output: dict | None, gh_fix_commits: list[str]) -> dict: + def score(self, output: dict | None) -> dict: + none_metrics = { + "analysis_correct": None, + "analysis_quality": None, + "fix_matches_ground_truth": None, + "fix_quality": None, + "fix_acceptance_probability": None, + "judge_cost_usd": 0, + } + if output is None: - return { - "match_score": None, - "match_category": "errored", - } + return none_metrics + + verify = output.get("verify") + if not verify: + return none_metrics + + j = verify.get("judgment") + if not j: + none_metrics["judge_cost_usd"] = verify.get("cost_usd", 0) + return none_metrics + return { - "match_score": None, - "match_category": "not_implemented", + "analysis_correct": j.get("analysis_correct"), + "analysis_quality": j.get("analysis_quality"), + "analysis_explanation": j.get("analysis_explanation", ""), + "fix_matches_ground_truth": j.get("fix_matches_ground_truth"), + "fix_quality": j.get("fix_quality"), + "fix_explanation": j.get("fix_explanation", ""), + "fix_acceptance_probability": j.get("fix_acceptance_probability"), + "fix_acceptance_explanation": j.get("fix_acceptance_explanation", ""), + "judge_cost_usd": verify.get("cost_usd", 0), } def summarize(self, score_rows: list[dict]) -> dict: - return {"status": "not_implemented"} + scored = [r for r in score_rows if r.get("analysis_quality") is not None] + n = len(scored) + + total_analysis_quality = sum(r["analysis_quality"] for r in scored) + analysis_correct_count = sum(r.get("analysis_correct") is True for r in scored) + total_fix_quality = sum(r["fix_quality"] for r in scored) + fix_match_count = sum(r.get("fix_matches_ground_truth") is True for r in scored) + total_fix_acceptance = sum(r["fix_acceptance_probability"] for r in scored) + + summary: dict = { + "avg_analysis_quality": total_analysis_quality / n if n else 0, + "analysis_correct_rate": analysis_correct_count / n if n else 0, + "avg_fix_quality": total_fix_quality / n if n else 0, + "fix_match_rate": fix_match_count / n if n else 0, + "avg_fix_acceptance_probability": total_fix_acceptance / n if n else 0, + "total_judge_cost_usd": sum(r.get("judge_cost_usd", 0) for r in score_rows), + "num_scored": n, + } + if self.num_trials > 1: + summary.update( + _pass_at_k(score_rows, self.num_trials, "fix_matches_ground_truth") + ) + logger.info(f"LLMFixMatching summary: {summary}") + return summary diff --git a/scripts/build_repair_eval.py b/scripts/build_repair_eval.py index aa7d91cc1f..b8106a5259 100644 --- a/scripts/build_repair_eval.py +++ b/scripts/build_repair_eval.py @@ -27,7 +27,12 @@ import weave -from bugbug.tools.build_repair.agent import AgentResponse, BuildFailure, BuildRepairTool +from bugbug.tools.build_repair.agent import ( + AgentResponse, + BuildFailure, + BuildRepairTool, + GroundTruth, +) from bugbug.tools.build_repair.config import MODEL_CUTOFF_DATES from bugbug.tools.build_repair.scorer import ( BasicMetricsScorer, @@ -231,6 +236,7 @@ async def invoke( bug_id: int, pre_fix_bug: dict, gh_failure_commits: list[str], + gh_fix_commits: list[str], failures: list[dict], fix_commit_date: str, **kwargs, @@ -257,6 +263,7 @@ async def invoke( worktree_path = self.worktree_mgr.create(gh_failure_commits[0], wt_name) worktree_created = True + on_message = _make_weave_callback() failure = BuildFailure( bug_id=bug_id, bug_title=pre_fix_bug["title"], @@ -268,7 +275,7 @@ async def invoke( failure, worktree_path=worktree_path, skip_try_push=self.no_try_push, - on_message=_make_weave_callback(), + on_message=on_message, ) logger.info( f"Bug {bug_id} completed: error={result.error}, " @@ -279,6 +286,18 @@ async def invoke( ) output = result.model_dump() + + if result.analysis or result.summary: + ground_truth = GroundTruth(gh_fix_commits=gh_fix_commits) + verify_result = await self.tool.verify( + failure, + result.diff, + ground_truth, + worktree_path, + on_message, + ) + output["verify"] = verify_result.model_dump() + if result.error: raise BuildRepairError(output) return output @@ -322,6 +341,7 @@ def main() -> None: ) os.environ["WEAVE_PARALLELISM"] = str(args.parallelism) + os.environ["WEAVE_LOG_LEVEL"] = "INFO" if args.verbose else "WARNING" client = weave.init("bugbug-build-repair-eval") _register_model_costs(client) @@ -331,7 +351,10 @@ def main() -> None: dataset.rows = dataset.rows[: args.limit] logger.info(f"Limited to {len(dataset.rows)} rows") - scorers = [BasicMetricsScorer(num_trials=args.trials), LLMFixMatchingScorer()] + scorers = [ + BasicMetricsScorer(num_trials=args.trials), + LLMFixMatchingScorer(num_trials=args.trials), + ] if not args.analysis_only: scorers.insert(1, BuildPassRateScorer(num_trials=args.trials)) logger.info(f"Scorers: {[type(s).__name__ for s in scorers]}")