Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 71 additions & 3 deletions astrbot/core/agent/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,34 @@
from .run_context import TContext
from .tool import FunctionTool


class _McpSseNoiseFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
try:
msg = record.getMessage().strip()
except Exception:
return True
if msg.startswith("Unknown SSE event:"):
event_name = msg.split(":", 1)[1].strip()
if event_name in {"stream", "connection"}:
return False
return True


def _install_mcp_noise_filters() -> None:
for logger_name in ("mcp.client.streamable_http", "mcp.client.sse"):
log = logging.getLogger(logger_name)
if any(isinstance(f, _McpSseNoiseFilter) for f in log.filters):
continue
log.addFilter(_McpSseNoiseFilter())


try:
import anyio
import mcp
from mcp.client.sse import sse_client

_install_mcp_noise_filters()
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
Expand All @@ -47,6 +71,8 @@ def _prepare_config(config: dict) -> dict:

async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""Quick test MCP server connectivity"""
import json

import aiohttp

cfg = _prepare_config(config.copy())
Expand All @@ -55,6 +81,40 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)

async def _format_http_error(response: aiohttp.ClientResponse) -> str:
reason = response.reason or ""
detail = ""
try:
raw = await response.content.read(2048)
if raw:
text = raw.decode(errors="replace").strip()
if text:
try:
data = json.loads(text)
except Exception:
detail = text
else:
if isinstance(data, dict):
msg = (
data.get("message")
or data.get("error")
or data.get("detail")
)
code = data.get("code")
if msg is not None:
detail = (
f"{code}: {msg}" if code is not None else str(msg)
)
else:
detail = text
else:
detail = text
except Exception:
detail = ""
if detail:
return f"HTTP {response.status}: {reason} ({detail})"
return f"HTTP {response.status}: {reason}"

try:
if "transport" in cfg:
transport_type = cfg["transport"]
Expand All @@ -70,7 +130,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"protocolVersion": mcp.types.LATEST_PROTOCOL_VERSION,
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
Expand All @@ -87,7 +147,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
return False, await _format_http_error(response)
else:
async with session.get(
url,
Expand All @@ -99,7 +159,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
return False, await _format_http_error(response)

except asyncio.TimeoutError:
return False, f"Connection timeout: {timeout} seconds"
Expand Down Expand Up @@ -152,6 +212,14 @@ def logging_callback(
if msg.level in ("warning", "error", "critical", "alert", "emergency"):
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
self.server_errlogs.append(log_msg)
return
normalized = msg.strip()
if normalized.startswith("Unknown SSE event:"):
event_name = normalized.split(":", 1)[1].strip()
if event_name in {"stream", "connection"}:
return
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)

if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
Expand Down
Loading
Loading