Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ async def __call__(
async def _default_message_handler(
    message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
    """Fallback message handler: log exceptions, then yield to the event loop.

    Exceptions delivered through the message stream would otherwise vanish
    silently; everything else is intentionally ignored.
    """
    is_error = isinstance(message, Exception)
    if is_error:
        logger.warning("Unhandled exception in message handler: %s", message)
    # Always checkpoint so the handler cooperates with the scheduler even
    # when it has nothing to do.
    await anyio.lowlevel.checkpoint()


Expand Down
44 changes: 29 additions & 15 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

# Reconnection defaults
DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry
MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up
MAX_RECONNECTION_ATTEMPTS = 5 # Max retry attempts before giving up


class StreamableHTTPError(Exception):
Expand Down Expand Up @@ -197,7 +197,9 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")

received_events = False
async for sse in event_source.aiter_sse():
received_events = True
# Track last event ID for reconnection
if sse.id:
last_event_id = sse.id
Expand All @@ -207,8 +209,9 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:

await self._handle_sse_event(sse, read_stream_writer)

# Stream ended normally (server closed) - reset attempt counter
attempt = 0
# Only reset attempts if we actually received events;
# empty connections count toward MAX_RECONNECTION_ATTEMPTS
attempt = 0 if received_events else attempt + 1

except Exception: # pragma: lax no cover
logger.debug("GET stream error", exc_info=True)
Expand Down Expand Up @@ -364,25 +367,36 @@ async def _handle_sse_response(
await response.aclose()
return # Normal completion, no reconnect needed
except Exception:
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover
logger.debug("SSE stream error", exc_info=True)

# Stream ended without response - reconnect if we received an event with ID
if last_event_id is not None: # pragma: no branch
# Stream ended without a complete response — attempt reconnection if possible
if last_event_id is not None:
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
if await self._handle_reconnection(ctx, last_event_id, retry_interval_ms):
return # Reconnection delivered the response

# No response delivered — unblock the waiting request with an error
error_data = ErrorData(code=INTERNAL_ERROR, message="SSE stream ended without a response")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
await ctx.read_stream_writer.send(error_msg)

async def _handle_reconnection(
self,
ctx: RequestContext,
last_event_id: str,
retry_interval_ms: int | None = None,
attempt: int = 0,
) -> None:
"""Reconnect with Last-Event-ID to resume stream after server disconnect."""
) -> bool:
"""Reconnect with Last-Event-ID to resume stream after server disconnect.

Returns:
True if the response was successfully delivered, False if max
reconnection attempts were exceeded without delivering a response.
"""
# Bail if max retries exceeded
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
if attempt >= MAX_RECONNECTION_ATTEMPTS:
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
return
return False

# Always wait - use server value or default
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
Expand Down Expand Up @@ -419,15 +433,15 @@ async def _handle_reconnection(
)
if is_complete:
await event_source.response.aclose()
return
return True

# Stream ended again without response - reconnect again (reset attempt counter)
# Stream ended again without response - reconnect again
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
return await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, attempt + 1)
except Exception as e: # pragma: no cover
logger.debug(f"Reconnection failed: {e}")
# Try to reconnect again if we still have an event ID
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
return await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)

async def post_writer(
self,
Expand Down
85 changes: 84 additions & 1 deletion tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@

from mcp import MCPError, types
from mcp.client.session import ClientSession
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
from mcp.client.streamable_http import (
MAX_RECONNECTION_ATTEMPTS,
StreamableHTTPTransport,
streamable_http_client,
)
from mcp.client.streamable_http import (
RequestContext as TransportRequestContext,
)
from mcp.server import Server, ServerRequestContext
from mcp.server.streamable_http import (
MCP_PROTOCOL_VERSION_HEADER,
Expand Down Expand Up @@ -2247,3 +2254,79 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(

assert "content-type" in headers_data
assert headers_data["content-type"] == "application/json"


@pytest.mark.anyio
async def test_sse_read_timeout_propagates_error(basic_server: None, basic_server_url: str):
    """SSE read timeout should propagate MCPError instead of hanging."""
    # 0.5s read timeout — deliberately shorter than the 2s the slow resource takes.
    timeout = httpx.Timeout(30.0, read=0.5)
    async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
        endpoint = f"{basic_server_url}/mcp"
        async with streamable_http_client(endpoint, http_client=client) as (read_stream, write_stream):
            async with ClientSession(read_stream, write_stream) as session:  # pragma: no branch
                await session.initialize()

                # The read must surface MCPError promptly rather than block forever.
                with pytest.raises(MCPError), anyio.fail_after(10):  # pragma: no branch
                    await session.read_resource("slow://test")


@pytest.mark.anyio
async def test_sse_error_when_reconnection_exhausted(
    event_server: tuple[SimpleEventStore, str],
    monkeypatch: pytest.MonkeyPatch,
):
    """When SSE stream closes after events and reconnection fails, MCPError is raised."""
    _, server_url = event_server

    # Stub out reconnection so it always reports failure to deliver a response.
    async def _never_reconnect(
        self: Any, ctx: Any, last_event_id: Any, retry_interval_ms: Any = None, attempt: int = 0
    ) -> bool:
        return False

    monkeypatch.setattr(StreamableHTTPTransport, "_handle_reconnection", _never_reconnect)

    async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:  # pragma: no branch
            await session.initialize()

            # tool_with_stream_close sends a priming event (setting last_event_id),
            # then closes the SSE stream. With reconnection patched to fail,
            # _handle_sse_response falls through to send the error.
            with pytest.raises(MCPError), anyio.fail_after(10):  # pragma: no branch
                await session.call_tool("tool_with_stream_close", {})


@pytest.mark.anyio
async def test_handle_reconnection_returns_false_on_max_attempts():
    """_handle_reconnection returns False when max attempts exceeded."""
    http_transport = StreamableHTTPTransport(url="http://localhost:9999/mcp")

    writer, reader = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    request = JSONRPCRequest(jsonrpc="2.0", id=42, method="tools/call", params={"name": "test"})
    ctx = TransportRequestContext(
        client=httpx.AsyncClient(),
        session_id="test-session",
        session_message=SessionMessage(request),
        metadata=None,
        read_stream_writer=writer,
    )

    try:
        with anyio.fail_after(5):
            delivered = await http_transport._handle_reconnection(
                ctx, last_event_id="evt-1", retry_interval_ms=None, attempt=MAX_RECONNECTION_ATTEMPTS
            )
        # Attempt counter already at the ceiling: must report failure immediately.
        assert delivered is False
    finally:
        await writer.aclose()
        await reader.aclose()
        await ctx.client.aclose()