diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 1e8fdcf1..cd3be09d 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -170,7 +170,7 @@ async def _establish_new_connection( try: uri_and_metadata = await self._uri_and_metadata_factory() - ws = await websockets.connect(uri_and_metadata["uri"]) + ws = await websockets.connect(uri_and_metadata["uri"], max_size=None) session_id = ( self.generate_nanoid() if not old_session diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 99b30e82..9600f4cc 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1084,7 +1084,10 @@ async def _do_ensure_connected[HandshakeMetadata]( ws: ClientConnection | None = None try: uri_and_metadata = await uri_and_metadata_factory() - ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) + ws = await websockets.asyncio.client.connect( + uri_and_metadata["uri"], + max_size=None, + ) transition_connecting(ws) try: diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py index 2fa8d0e2..bea6d2e0 100644 --- a/tests/v2/test_v2_session_lifecycle.py +++ b/tests/v2/test_v2_session_lifecycle.py @@ -1,17 +1,26 @@ import asyncio +import logging from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict +import msgpack +import nanoid import pytest -from websockets import ConnectionClosedOK +from websockets import ConnectionClosed, ConnectionClosedOK from websockets.asyncio.server import ServerConnection, serve from websockets.typing import Data from replit_river.common_session import SessionState from replit_river.messages import parse_transport_msg from replit_river.rate_limiter import RateLimiter -from replit_river.rpc import TransportMessage +from replit_river.rpc import ( + ControlMessageHandshakeRequest, + ControlMessageHandshakeResponse, + HandShakeStatus, + TransportMessage, +) from replit_river.transport_options import TransportOptions, UriAndMetadata -from replit_river.v2.session import Session +from replit_river.v2.client import Client +from replit_river.v2.session import STREAM_CLOSED_BIT, Session class _PermissiveRateLimiter(RateLimiter): @@ -54,6 +63,8 @@ async def handle(websocket: ServerConnection) -> None: await recv.put(datagram) except ConnectionClosedOK: pass + except ConnectionClosed: + pass port: int | None = None if state["ipv4_laddr"]: @@ -65,7 +76,10 @@ async def handle(websocket: ServerConnection) -> None: state["ipv4_laddr"] = pair serve_forever = asyncio.create_task(server.serve_forever()) yield None - serve_forever.cancel() + server.close() + await server.wait_closed() + # "serve_forever" should always be done after wait_closed finishes + assert serve_forever.done() @pytest.fixture @@ -145,3 +159,89 @@ def close_session_callback(_session: Session) -> None: await connecting assert session._state == SessionState.CLOSED assert callcount == 1 + + +async def test_big_packet(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + client = Client( + client_id="CLIENT1", + server_id="SERVER", + transport_options=TransportOptions(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(client.ensure_connected()) + request_msg = parse_transport_msg(await recv.get()) + + assert not isinstance(request_msg, str) + assert (serverconn := conn()) + handshake_request: ControlMessageHandshakeRequest[None] = ( + ControlMessageHandshakeRequest(**request_msg.payload) + ) + + handshake_resp = ControlMessageHandshakeResponse( + status=HandShakeStatus( + ok=True, + ), + ) + handshake_request.sessionId + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_resp.model_dump(), + ) + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + async def handle_server_messages() -> None: + request_msg = parse_transport_msg(await recv.get()) + assert not isinstance(request_msg, str) + msg = TransportMessage( + from_=request_msg.to, + to=request_msg.from_, + streamId=request_msg.streamId, + controlFlags=STREAM_CLOSED_BIT, + id=nanoid.generate(), + seq=0, + ack=0, + payload={ + "ok": True, + "payload": { + "big": "a" * (2**20 + 1), # One more than the default max_size + }, + }, + ) + + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + stream_close_msg = msgpack.unpackb(await recv.get()) + assert stream_close_msg["controlFlags"] == STREAM_CLOSED_BIT + + stream_handler = asyncio.create_task(handle_server_messages()) + + try: + async for datagram in client.send_subscription( + "test", "bigstream", {}, lambda x: x, lambda x: x, lambda x: x + ): + print(datagram) + except Exception: + logging.exception("Interrupted") + + await client.close() + await connecting + + # Ensure we're listening to close messages as well + stream_handler.cancel() + await stream_handler