Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 95 additions & 19 deletions src/amplihack/llm/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
_COPILOT_SDK_OK = False
try:
from copilot import CopilotClient # type: ignore[import-not-found]
from copilot.types import MessageOptions, SessionConfig # type: ignore[import-not-found]
from copilot.session import PermissionHandler # type: ignore[import-not-found]

_COPILOT_SDK_OK = True
except ImportError:
Expand All @@ -46,14 +46,61 @@

SDK_AVAILABLE = _CLAUDE_SDK_OK or _COPILOT_SDK_OK

# Env vars that explicitly select an LLM provider, in priority order.
# An explicit env override always beats file-based launcher detection so
# that embedded callers (e.g. Simard's OODA daemon, which is a Rust binary
# and never goes through `amplihack copilot` and therefore never writes a
# launcher_context.json) can pin the SDK without faking a launcher context.
#
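# Recognized values (case-insensitive, surrounding whitespace ignored):
# "copilot" and "claude", plus the aliases accepted by _provider_from_env().
# Anything else is ignored.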
_PROVIDER_ENV_VARS = (
"AMPLIHACK_LLM_PROVIDER",
"SIMARD_LLM_PROVIDER",
)

__all__ = ["completion", "SDK_AVAILABLE"]


def _provider_from_env() -> str | None:
    """Resolve an explicit provider override from the environment.

    The variables named in ``_PROVIDER_ENV_VARS`` are consulted in order and
    the first one holding a recognized value wins. Matching ignores case and
    surrounding whitespace.

    Accepted spellings:
        * copilot: "copilot", "github-copilot", "gh-copilot", "rustyclawd"
        * claude:  "claude", "anthropic", "claude-code"

    Returns:
        "copilot" or "claude" when an override is present, otherwise None
        (unset, empty, or unrecognized values are all skipped).
    """
    copilot_aliases = ("copilot", "github-copilot", "gh-copilot", "rustyclawd")
    claude_aliases = ("claude", "anthropic", "claude-code")

    for env_name in _PROVIDER_ENV_VARS:
        # A missing variable behaves like an empty one: nothing matches.
        normalized = os.environ.get(env_name, "").strip().lower()
        if normalized in copilot_aliases:
            return "copilot"
        if normalized in claude_aliases:
            return "claude"
        # Not a known spelling: try the next variable before giving up.
    return None


def _detect_launcher(project_root: Path) -> str:
"""Detect launcher type, cached per process."""
"""Detect launcher type, cached per process.

Order:
1. Explicit env var override (AMPLIHACK_LLM_PROVIDER /
SIMARD_LLM_PROVIDER) — wins unconditionally.
2. File-based detection via LauncherDetector (reads
<project_root>/.claude/runtime/launcher_context.json).
3. Fall back to "claude" if both are absent.
"""
global _detector_cache
if _detector_cache is not None:
return _detector_cache

override = _provider_from_env()
if override is not None:
_detector_cache = override
return override

try:
from amplihack.context.adaptive.detector import LauncherDetector

Expand Down Expand Up @@ -92,11 +139,34 @@ async def completion(
"""
project_root = _get_project_root()
launcher = _detect_launcher(project_root)
explicit_override = _provider_from_env()

# Build a single prompt from the messages list
prompt = _messages_to_prompt(messages)

try:
if explicit_override == "copilot":
if not _COPILOT_SDK_OK:
print(
"WARNING: AMPLIHACK_LLM_PROVIDER/SIMARD_LLM_PROVIDER=copilot but "
"the copilot SDK is not importable. Refusing to silently fall back "
"to Claude.",
file=sys.stderr,
)
return ""
return await _query_copilot(prompt, project_root)
if explicit_override == "claude":
if not _CLAUDE_SDK_OK:
print(
"WARNING: AMPLIHACK_LLM_PROVIDER/SIMARD_LLM_PROVIDER=claude but "
"the claude SDK is not importable. Refusing to silently fall back "
"to Copilot.",
file=sys.stderr,
)
return ""
return await _query_claude(prompt, project_root)

# No explicit override — use detected launcher with cross-SDK fallback.
if launcher == "copilot" and _COPILOT_SDK_OK:
return await _query_copilot(prompt, project_root)
if _CLAUDE_SDK_OK:
Expand Down Expand Up @@ -150,21 +220,27 @@ async def _query_claude(prompt: str, project_root: Path) -> str:


async def _query_copilot(prompt: str, project_root: Path) -> str:
"""Query via GitHub Copilot SDK."""
client = CopilotClient()
try:
await client.start()
session = await client.create_session(SessionConfig())
async with asyncio.timeout(QUERY_TIMEOUT):
event = await session.send_and_wait(
MessageOptions(prompt=prompt),
timeout=float(QUERY_TIMEOUT),
)
if event and hasattr(event, "data") and event.data:
return event.data.content or ""
return ""
finally:
"""Query via GitHub Copilot SDK (copilot >= 0.1.0)."""
async with CopilotClient() as client:
session = await client.create_session(
on_permission_request=PermissionHandler.approve_all,
working_directory=str(project_root),
)
try:
await client.stop()
except Exception:
pass
async with asyncio.timeout(QUERY_TIMEOUT):
event = await session.send_and_wait(
prompt,
timeout=float(QUERY_TIMEOUT),
)
if event is None:
return ""
data = getattr(event, "data", None)
if data is None:
return ""
content = getattr(data, "content", None)
return content or ""
finally:
try:
await session.destroy()
except Exception:
pass
Empty file added tests/llm/__init__.py
Empty file.
132 changes: 132 additions & 0 deletions tests/llm/test_provider_env_override.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Tests for amplihack.llm.client provider env-var override.

Regression coverage for the bug where embedded callers (Simard's Rust OODA
daemon and any other host that imports amplihack directly without going
through `amplihack copilot`) silently fell back to the bundled Claude Code
CLI — which is "Not logged in" by default — and produced empty completions
that were swallowed by metacognition_grader's JSON-parse-fail path.

The fix: AMPLIHACK_LLM_PROVIDER / SIMARD_LLM_PROVIDER env vars take priority
over file-based launcher detection.
"""

from __future__ import annotations

import importlib

import pytest


def _reload_client():
    """Return a freshly reloaded ``amplihack.llm.client`` module.

    The module is reloaded and its launcher cache is then cleared explicitly,
    so a detection result from an earlier test cannot leak into the next one.
    """
    module = importlib.reload(importlib.import_module("amplihack.llm.client"))
    module._detector_cache = None
    return module


@pytest.fixture(autouse=True)
def _clear_provider_env(monkeypatch):
    """Remove both provider override variables before every test runs."""
    monkeypatch.delenv("AMPLIHACK_LLM_PROVIDER", raising=False)
    monkeypatch.delenv("SIMARD_LLM_PROVIDER", raising=False)
    yield


def test_provider_from_env_returns_none_when_unset():
    """With neither override variable set, no provider is reported."""
    module = _reload_client()
    override = module._provider_from_env()
    assert override is None


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        # Copilot spellings, including case and whitespace variations.
        ("copilot", "copilot"),
        ("Copilot", "copilot"),
        (" COPILOT ", "copilot"),
        ("github-copilot", "copilot"),
        ("gh-copilot", "copilot"),
        ("rustyclawd", "copilot"),
        # Claude spellings.
        ("claude", "claude"),
        ("Claude-Code", "claude"),
        ("anthropic", "claude"),
    ],
)
def test_provider_from_env_recognized(monkeypatch, value, expected):
    """Every accepted spelling normalizes to its canonical provider name."""
    monkeypatch.setenv("AMPLIHACK_LLM_PROVIDER", value)
    module = _reload_client()
    resolved = module._provider_from_env()
    assert resolved == expected


def test_simard_env_var_also_honored(monkeypatch):
    """SIMARD_LLM_PROVIDER alone is enough to select a provider."""
    monkeypatch.setenv("SIMARD_LLM_PROVIDER", "copilot")
    module = _reload_client()
    resolved = module._provider_from_env()
    assert resolved == "copilot"


def test_amplihack_env_takes_priority_over_simard(monkeypatch):
    """When both variables are set, AMPLIHACK_LLM_PROVIDER decides."""
    conflicting = {
        "AMPLIHACK_LLM_PROVIDER": "claude",
        "SIMARD_LLM_PROVIDER": "copilot",
    }
    for env_name, env_value in conflicting.items():
        monkeypatch.setenv(env_name, env_value)
    module = _reload_client()
    assert module._provider_from_env() == "claude"


def test_unrecognized_value_falls_through(monkeypatch):
    """A value that is not a known spelling yields no override."""
    monkeypatch.setenv("AMPLIHACK_LLM_PROVIDER", "ollama")
    module = _reload_client()
    override = module._provider_from_env()
    assert override is None


def test_detect_launcher_uses_env_override_when_present(monkeypatch, tmp_path):
    """_detect_launcher reports the env override for an empty project root."""
    monkeypatch.setenv("SIMARD_LLM_PROVIDER", "copilot")
    module = _reload_client()
    detected = module._detect_launcher(tmp_path)
    assert detected == "copilot"


def test_detect_launcher_uses_env_override_over_file(monkeypatch, tmp_path):
    """An env override beats a launcher_context.json naming another launcher."""
    context_file = tmp_path / ".claude" / "runtime" / "launcher_context.json"
    context_file.parent.mkdir(parents=True)
    # The on-disk context says "claude"; the environment says "copilot".
    context_file.write_text(
        '{"launcher": "claude", "version": "1", "timestamp": "2025-01-01T00:00:00"}'
    )
    monkeypatch.setenv("AMPLIHACK_LLM_PROVIDER", "copilot")
    module = _reload_client()
    detected = module._detect_launcher(tmp_path)
    assert detected == "copilot"


@pytest.mark.asyncio
async def test_completion_explicit_copilot_no_silent_claude_fallback(
    monkeypatch,
):
    """Explicit copilot + missing copilot SDK must not be served by Claude.

    The Claude stub records its calls instead of raising. completion()
    dispatches inside a try block, so an exception raised from a stub may be
    swallowed there and turned into the same empty result this test expects;
    a raised sentinel therefore cannot prove the stub was never reached.
    Recording the call and asserting afterwards can.
    """
    monkeypatch.setenv("AMPLIHACK_LLM_PROVIDER", "copilot")
    client = _reload_client()
    monkeypatch.setattr(client, "_COPILOT_SDK_OK", False)
    monkeypatch.setattr(client, "_CLAUDE_SDK_OK", True)

    claude_calls = []

    async def _record_claude(prompt, project_root):
        claude_calls.append(prompt)
        # Non-empty on purpose: a fallback would also break the `out` check.
        return "unexpected claude answer"

    monkeypatch.setattr(client, "_query_claude", _record_claude)

    out = await client.completion(
        messages=[{"role": "user", "content": "hi"}],
    )
    assert claude_calls == [], "should not silently fall back to claude"
    assert out == ""


@pytest.mark.asyncio
async def test_completion_explicit_claude_no_silent_copilot_fallback(
    monkeypatch,
):
    """Explicit claude + missing claude SDK must not be served by Copilot.

    The Copilot stub records its calls instead of raising. completion()
    dispatches inside a try block, so an exception raised from a stub may be
    swallowed there and turned into the same empty result this test expects;
    a raised sentinel therefore cannot prove the stub was never reached.
    Recording the call and asserting afterwards can.
    """
    monkeypatch.setenv("AMPLIHACK_LLM_PROVIDER", "claude")
    client = _reload_client()
    monkeypatch.setattr(client, "_CLAUDE_SDK_OK", False)
    monkeypatch.setattr(client, "_COPILOT_SDK_OK", True)

    copilot_calls = []

    async def _record_copilot(prompt, project_root):
        copilot_calls.append(prompt)
        # Non-empty on purpose: a fallback would also break the `out` check.
        return "unexpected copilot answer"

    monkeypatch.setattr(client, "_query_copilot", _record_copilot)

    out = await client.completion(
        messages=[{"role": "user", "content": "hi"}],
    )
    assert copilot_calls == [], "should not silently fall back to copilot"
    assert out == ""
Loading