diff --git a/README.md b/README.md index 1ded369..18df988 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,7 @@ Create **.env** file under the repo root directory with your own API keys: ``` OPENAI_API_KEY=sk-proj-... DEEPSEEK_API_KEY=sk-... +MINIMAX_API_KEY=your-minimax-api-key ``` ## Record @@ -141,6 +142,16 @@ python run_replay.py --model-provider openai --wap_replay_list data_processed/ex ``` For **smart-replay**, replace the path with a smart‑replay JSON to test this mode. +Supported `--model-provider` values: `azure`, `anthropic`, `openai`, `ollama`, `minimax`. + +To use **MiniMax** (M2.7, 204K context): +```bash +# Set MINIMAX_API_KEY in your .env file, then run: +python run_replay.py --model-provider minimax --wap_replay_list data_processed/smart_replay/wap_smart_replay_list_.json --max-concurrent 1 +``` + +To use MiniMax in the **MCP client** (`mcp_client.py`), set `LLM_PROVIDER=minimax` and `MINIMAX_API_KEY` in your `.env` file. + ## Convert to MCP Server ```bash diff --git a/mcp_client.py b/mcp_client.py index 2c75ba9..cdc4811 100644 --- a/mcp_client.py +++ b/mcp_client.py @@ -225,6 +225,12 @@ class LLMClient: def __init__(self, api_key: str) -> None: self.api_key: str = api_key + self.provider: str = os.getenv("LLM_PROVIDER", "openai").lower() + # Resolve MiniMax key at init so tests can patch the environment once + if self.provider == "minimax": + self._minimax_key: str = os.getenv("MINIMAX_API_KEY", api_key) + else: + self._minimax_key = "" def get_response(self, messages: list[dict[str, str]]) -> str: """Get a response from the LLM. @@ -238,15 +244,22 @@ def get_response(self, messages: list[dict[str, str]]) -> str: Raises: httpx.RequestError: If the request to the LLM fails. """ - url = "https://api.openai.com/v1/chat/completions" + if self.provider == "minimax": + url = "https://api.minimax.io/v1/chat/completions" + api_key = self._minimax_key + model = "MiniMax-M2.7" + else: + url = "https://api.openai.com/v1/chat/completions" + api_key = self.api_key + model = "gpt-4o" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", + "Authorization": f"Bearer {api_key}", } payload = { "messages": messages, - "model": "gpt-4o" + "model": model } try: diff --git a/run_replay.py b/run_replay.py index a766f22..f692aca 100644 --- a/run_replay.py +++ b/run_replay.py @@ -112,6 +112,14 @@ def get_llm_model_generator( elif model_provider == "ollama": llm = ChatOllama(model="ota-preview-v16", num_ctx=20000,temperature=0) yield llm + elif model_provider == "minimax": + llm = ChatOpenAI( + model="MiniMax-M2.7", + temperature=0.01, + openai_api_key=os.getenv("MINIMAX_API_KEY", ""), + openai_api_base="https://api.minimax.io/v1", + ) + yield llm else: raise ValueError(f"Invalid model provider: {model_provider}") @@ -258,6 +266,7 @@ async def process_with_semaphore( "anthropic", "openai", "ollama", + "minimax", ], ) parser.add_argument( diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_minimax_provider.py b/tests/test_minimax_provider.py new file mode 100644 index 0000000..909b479 --- /dev/null +++ b/tests/test_minimax_provider.py @@ -0,0 +1,386 @@ +"""Unit and integration tests for MiniMax provider support.""" + +import os +import sys +import types +import unittest +from unittest.mock import MagicMock, patch + +# --------------------------------------------------------------------------- +# Stub heavy dependencies so tests can run without installing them +# --------------------------------------------------------------------------- + +def _make_stub(name): + mod = types.ModuleType(name) + sys.modules[name] = mod + return mod + + +for _dep in [ + "browser_use", + "browser_use.browser", + "browser_use.browser.context", + "langchain_anthropic", + "langchain_ollama", + "dotenv", +]: + if _dep not in sys.modules: + _make_stub(_dep) + +# Stub dotenv.load_dotenv +sys.modules["dotenv"].load_dotenv = lambda *a, **kw: None # type: ignore + +# Stub browser_use classes +for _cls in ("Agent", "Browser", "BrowserConfig"): + setattr(sys.modules["browser_use"], _cls, MagicMock) +sys.modules["browser_use.browser.context"].BrowserContextConfig = MagicMock + +# Stub langchain_anthropic.ChatAnthropic +sys.modules["langchain_anthropic"].ChatAnthropic = MagicMock # type: ignore + +# Stub langchain_ollama.ChatOllama +sys.modules["langchain_ollama"].ChatOllama = MagicMock # type: ignore + +# Stub langchain_openai — we need a real-ish ChatOpenAI to assert on +_lo = _make_stub("langchain_openai") + + +class _FakeChatOpenAI: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class _FakeAzureChatOpenAI(_FakeChatOpenAI): + pass + + +_lo.ChatOpenAI = _FakeChatOpenAI # type: ignore +_lo.AzureChatOpenAI = _FakeAzureChatOpenAI # type: ignore + +# Now import the module under test +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +import run_replay # noqa: E402 + + +# --------------------------------------------------------------------------- +# Unit tests – run_replay.py +# --------------------------------------------------------------------------- + + +class TestMiniMaxProvider(unittest.TestCase): + + def _next(self, provider): + gen = run_replay.get_llm_model_generator(provider) + return next(gen) + + # -- happy-path --------------------------------------------------------- + + def test_minimax_yields_chat_openai_instance(self): + model = self._next("minimax") + self.assertIsInstance(model, _FakeChatOpenAI) + + def test_minimax_base_url(self): + model = self._next("minimax") + self.assertEqual(model.openai_api_base, "https://api.minimax.io/v1") + + def test_minimax_model_name(self): + model = self._next("minimax") + self.assertEqual(model.model, "MiniMax-M2.7") + + def test_minimax_temperature_positive(self): + model = self._next("minimax") + self.assertGreater(model.temperature, 0.0) + + def test_minimax_temperature_below_one(self): + model = self._next("minimax") + self.assertLessEqual(model.temperature, 1.0) + + def test_minimax_uses_env_api_key(self): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key-123"}): + model = self._next("minimax") + self.assertEqual(model.openai_api_key, "test-key-123") + + # -- generator is infinite ---------------------------------------------- + + def test_minimax_generator_yields_multiple(self): + gen = run_replay.get_llm_model_generator("minimax") + m1 = next(gen) + m2 = next(gen) + self.assertIsInstance(m1, _FakeChatOpenAI) + self.assertIsInstance(m2, _FakeChatOpenAI) + + # -- existing providers not broken -------------------------------------- + + def test_openai_provider_still_works(self): + model = self._next("openai") + self.assertIsInstance(model, _FakeChatOpenAI) + + def test_anthropic_provider_still_works(self): + model = self._next("anthropic") + # ChatAnthropic is mocked + self.assertIsNotNone(model) + + def test_ollama_provider_still_works(self): + model = self._next("ollama") + self.assertIsNotNone(model) + + # -- invalid provider --------------------------------------------------- + + def test_invalid_provider_raises(self): + with self.assertRaises(ValueError): + next(run_replay.get_llm_model_generator("unknown-provider")) + + # -- argparse choices --------------------------------------------------- + + def test_minimax_in_argparse_choices(self): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-provider", + choices=["azure", "anthropic", "openai", "ollama", "minimax"], + ) + args = parser.parse_args(["--model-provider", "minimax"]) + self.assertEqual(args.model_provider, "minimax") + + def test_minimax_choice_accepted_without_error(self): + """Verify run_replay's own parser accepts 'minimax'.""" + import importlib + import io + import contextlib + + # Parse only the choices list by inspecting argparse actions + # We do this without executing main() + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-provider", + type=str, + default="azure", + choices=["azure", "anthropic", "openai", "ollama", "minimax"], + ) + ns, _ = parser.parse_known_args(["--model-provider", "minimax"]) + self.assertEqual(ns.model_provider, "minimax") + + # -- temperature constraint documented ---------------------------------- + + def test_temperature_is_not_zero(self): + """MiniMax requires temperature > 0.""" + model = self._next("minimax") + self.assertNotEqual(model.temperature, 0.0) + + +# --------------------------------------------------------------------------- +# Unit tests – mcp_client.py (LLMClient with MiniMax) +# --------------------------------------------------------------------------- + +# Stub mcp dependencies +for _dep in ["mcp", "mcp.client", "mcp.client.stdio"]: + if _dep not in sys.modules: + _make_stub(_dep) + +sys.modules["mcp"].ClientSession = MagicMock # type: ignore +sys.modules["mcp"].StdioServerParameters = MagicMock # type: ignore +sys.modules["mcp.client.stdio"].stdio_client = MagicMock # type: ignore + +import mcp_client # noqa: E402 + + +class TestMCPLLMClientMiniMax(unittest.TestCase): + + def _make_client(self, provider, minimax_key="", llm_key="default-key"): + env = {"LLM_PROVIDER": provider} + if minimax_key: + env["MINIMAX_API_KEY"] = minimax_key + with patch.dict(os.environ, env, clear=False): + return mcp_client.LLMClient(api_key=llm_key) + + def _ok_response(self, text="Hello"): + mock_resp = MagicMock() + mock_resp.json.return_value = { + "choices": [{"message": {"content": text}}] + } + mock_resp.raise_for_status = MagicMock() + return mock_resp + + # -- provider routing --------------------------------------------------- + + def test_minimax_provider_set(self): + client = self._make_client("minimax") + self.assertEqual(client.provider, "minimax") + + def test_openai_provider_default(self): + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("LLM_PROVIDER", None) + client = mcp_client.LLMClient(api_key="k") + self.assertEqual(client.provider, "openai") + + def test_minimax_uses_minimax_endpoint(self): + client = self._make_client("minimax", minimax_key="mmkey") + messages = [{"role": "user", "content": "hi"}] + + captured = {} + + def fake_post(url, headers=None, json=None): + captured["url"] = url + captured["auth"] = headers.get("Authorization", "") + captured["model"] = json.get("model", "") + return self._ok_response() + + with patch("httpx.Client") as mock_ctx: + mock_client = MagicMock() + mock_ctx.return_value.__enter__.return_value = mock_client + mock_client.post.side_effect = fake_post + client.get_response(messages) + + self.assertIn("minimax.io", captured["url"]) + self.assertIn("mmkey", captured["auth"]) + self.assertEqual(captured["model"], "MiniMax-M2.7") + + def test_openai_uses_openai_endpoint(self): + client = self._make_client("openai", llm_key="oai-key") + captured = {} + + def fake_post(url, headers=None, json=None): + captured["url"] = url + return self._ok_response() + + with patch("httpx.Client") as mock_ctx: + mock_client = MagicMock() + mock_ctx.return_value.__enter__.return_value = mock_client + mock_client.post.side_effect = fake_post + client.get_response([{"role": "user", "content": "hi"}]) + + self.assertIn("openai.com", captured["url"]) + + def test_minimax_model_m27(self): + client = self._make_client("minimax", minimax_key="k") + captured = {} + + def fake_post(url, headers=None, json=None): + captured["model"] = json.get("model") + return self._ok_response() + + with patch("httpx.Client") as mock_ctx: + mock_client = MagicMock() + mock_ctx.return_value.__enter__.return_value = mock_client + mock_client.post.side_effect = fake_post + client.get_response([{"role": "user", "content": "test"}]) + + self.assertEqual(captured["model"], "MiniMax-M2.7") + + def test_get_response_returns_content(self): + client = self._make_client("minimax", minimax_key="k") + + with patch("httpx.Client") as mock_ctx: + mock_client = MagicMock() + mock_ctx.return_value.__enter__.return_value = mock_client + mock_client.post.return_value = self._ok_response("Hello MiniMax!") + result = client.get_response([{"role": "user", "content": "hi"}]) + + self.assertEqual(result, "Hello MiniMax!") + + def test_error_returns_friendly_message(self): + import httpx as _httpx + + client = self._make_client("minimax", minimax_key="k") + + with patch("httpx.Client") as mock_ctx: + mock_client = MagicMock() + mock_ctx.return_value.__enter__.return_value = mock_client + mock_client.post.side_effect = _httpx.RequestError("timeout") + result = client.get_response([{"role": "user", "content": "hi"}]) + + self.assertIn("error", result.lower()) + + # -- fallback to api_key when MINIMAX_API_KEY absent -------------------- + + def test_minimax_falls_back_to_api_key(self): + env = {"LLM_PROVIDER": "minimax"} + env.pop("MINIMAX_API_KEY", None) + with patch.dict(os.environ, env, clear=False): + os.environ.pop("MINIMAX_API_KEY", None) + client = mcp_client.LLMClient(api_key="fallback-key") + + captured = {} + + def fake_post(url, headers=None, json=None): + captured["auth"] = headers.get("Authorization", "") + return self._ok_response() + + with patch("httpx.Client") as mock_ctx: + mock_client = MagicMock() + mock_ctx.return_value.__enter__.return_value = mock_client + mock_client.post.side_effect = fake_post + client.get_response([{"role": "user", "content": "hi"}]) + + self.assertIn("fallback-key", captured["auth"]) + + +# --------------------------------------------------------------------------- +# Integration test – live API (skipped when MINIMAX_API_KEY not set) +# --------------------------------------------------------------------------- + + +class TestMiniMaxIntegration(unittest.TestCase): + + @unittest.skipUnless(os.getenv("MINIMAX_API_KEY"), "MINIMAX_API_KEY not set") + def test_live_chat_completion(self): + """Call MiniMax API and verify a non-empty response.""" + import httpx + + api_key = os.getenv("MINIMAX_API_KEY", "") + url = "https://api.minimax.io/v1/chat/completions" + payload = { + "model": "MiniMax-M2.7", + "messages": [{"role": "user", "content": "Say 'ok' in one word."}], + "temperature": 0.01, + "max_tokens": 10, + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + with httpx.Client(timeout=30) as client: + resp = client.post(url, headers=headers, json=payload) + self.assertEqual(resp.status_code, 200) + data = resp.json() + content = data["choices"][0]["message"]["content"] + self.assertIsInstance(content, str) + self.assertGreater(len(content), 0) + + @unittest.skipUnless(os.getenv("MINIMAX_API_KEY"), "MINIMAX_API_KEY not set") + def test_live_mcp_client_minimax(self): + """LLMClient with MiniMax provider returns a non-empty string.""" + with patch.dict(os.environ, {"LLM_PROVIDER": "minimax"}): + client = mcp_client.LLMClient(api_key=os.getenv("MINIMAX_API_KEY", "")) + result = client.get_response([{"role": "user", "content": "Say hi."}]) + self.assertIsInstance(result, str) + self.assertNotIn("error", result.lower()) + + @unittest.skipUnless(os.getenv("MINIMAX_API_KEY"), "MINIMAX_API_KEY not set") + def test_live_temperature_constraint(self): + """MiniMax rejects temperature=0; our default (0.01) must succeed.""" + import httpx + + api_key = os.getenv("MINIMAX_API_KEY", "") + url = "https://api.minimax.io/v1/chat/completions" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + # Our default: 0.01 + payload = { + "model": "MiniMax-M2.7", + "messages": [{"role": "user", "content": "hi"}], + "temperature": 0.01, + "max_tokens": 5, + } + with httpx.Client(timeout=30) as client: + resp = client.post(url, headers=headers, json=payload) + self.assertEqual(resp.status_code, 200) + + +if __name__ == "__main__": + unittest.main()