diff --git a/reflex/__init__.py b/reflex/__init__.py index 066df110f02..e80421fcb01 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -342,6 +342,7 @@ "utils.imports": ["ImportDict", "ImportVar"], "utils.misc": ["run_in_thread"], "utils.serializers": ["serializer"], + "utils.token_manager": ["get_token_manager"], "vars": ["Var", "field", "Field"], } diff --git a/reflex/app.py b/reflex/app.py index 4ff412ef863..c32f0ba6519 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -2080,7 +2080,11 @@ async def on_connect(self, sid: str, environ: dict): query_params = urllib.parse.parse_qs(environ.get("QUERY_STRING", "")) token_list = query_params.get("token", []) if token_list: - await self.link_token_to_sid(sid, token_list[0]) + token = token_list[0] + await self.link_token_to_sid(sid, token) + # Notify lifecycle watchers that this token/sid has connected. + actual_token = self._token_manager.sid_to_token.get(sid, token) + self._token_manager._notify_connect(actual_token, sid) else: console.warn(f"No token provided in connection for session {sid}") @@ -2102,6 +2106,8 @@ def on_disconnect(self, sid: str) -> asyncio.Task | None: # Get token before cleaning up disconnect_token = self.sid_to_token.get(sid) if disconnect_token: + # Notify lifecycle watchers before cleanup removes the mappings. + self._token_manager._notify_disconnect(disconnect_token, sid) # Use async cleanup through token manager task = asyncio.create_task( self._token_manager.disconnect_token(disconnect_token, sid), diff --git a/reflex/utils/token_manager.py b/reflex/utils/token_manager.py index 514d641cea6..6ee705d333f 100644 --- a/reflex/utils/token_manager.py +++ b/reflex/utils/token_manager.py @@ -20,6 +20,10 @@ from redis.asyncio import Redis +class _TokenNotConnectedError(Exception): + """Raised when a token is not connected.""" + + def _get_new_token() -> str: """Generate a new unique token. @@ -56,6 +60,10 @@ def __init__(self): self.token_to_socket: dict[str, SocketRecord] = {} # Keep a mapping between socket ID and client token. self.sid_to_token: dict[str, str] = {} + # Lifecycle events for connect/disconnect notifications. + self._token_disconnect_events: dict[str, list[asyncio.Event]] = {} + self._sid_disconnect_events: dict[str, list[asyncio.Event]] = {} + self._token_connect_events: dict[str, list[asyncio.Event]] = {} @property def token_to_sid(self) -> MappingProxyType[str, str]: @@ -124,6 +132,145 @@ async def disconnect_all(self): for token, sid in token_sid_pairs: await self.disconnect_token(token, sid) + def _notify_connect(self, token: str, sid: str) -> None: + """Notify lifecycle watchers that a token/sid has connected. + + Args: + token: The client token. + sid: The Socket.IO session ID. + """ + for event in self._token_connect_events.pop(token, []): + event.set() + + def _notify_disconnect(self, token: str, sid: str) -> None: + """Notify lifecycle watchers that a token/sid has disconnected. + + Args: + token: The client token. + sid: The Socket.IO session ID. + """ + for event in self._token_disconnect_events.pop(token, []): + event.set() + for event in self._sid_disconnect_events.pop(sid, []): + event.set() + + async def session_is_connected(self, sid: str) -> AsyncIterator[str]: + """Yield the client token, then block until the session disconnects. + + Yields the client token once, then suspends until the session + disconnects. Use with ``async for`` or ``contextlib.aclosing``. + + Args: + sid: The Socket.IO session ID. + + Yields: + The client token associated with the session. + + Raises: + _TokenNotConnectedError: If the session is not currently connected. + """ + token = self.sid_to_token.get(sid) + if token is None: + raise _TokenNotConnectedError( + f"Session {sid!r} is not currently connected." + ) + disconnect_event = asyncio.Event() + self._sid_disconnect_events.setdefault(sid, []).append(disconnect_event) + try: + yield token + await disconnect_event.wait() + finally: + events = self._sid_disconnect_events.get(sid, []) + if disconnect_event in events: + events.remove(disconnect_event) + + async def token_is_connected(self, client_token: str) -> AsyncIterator[str]: + """Yield the session ID, then block until the token disconnects. + + Yields the session ID once, then suspends until the token + disconnects. Use with ``async for`` or ``contextlib.aclosing``. + + Args: + client_token: The client token. + + Yields: + The session ID associated with the token. + + Raises: + _TokenNotConnectedError: If the token is not currently connected. + """ + socket_record = self.token_to_socket.get(client_token) + if socket_record is None: + raise _TokenNotConnectedError( + f"Token {client_token!r} is not currently connected." + ) + disconnect_event = asyncio.Event() + self._token_disconnect_events.setdefault(client_token, []).append( + disconnect_event + ) + try: + yield socket_record.sid + await disconnect_event.wait() + finally: + events = self._token_disconnect_events.get(client_token, []) + if disconnect_event in events: + events.remove(disconnect_event) + + def when_session_disconnects(self, sid: str) -> asyncio.Event: + """Return an asyncio.Event that is set when the session disconnects. + + Args: + sid: The Socket.IO session ID. + + Returns: + An asyncio.Event that will be set on disconnect. + """ + event = asyncio.Event() + if sid not in self.sid_to_token: + # Already disconnected, set immediately. + event.set() + else: + self._sid_disconnect_events.setdefault(sid, []).append(event) + return event + + def when_token_disconnects(self, client_token: str) -> asyncio.Event: + """Return an asyncio.Event that is set when the token disconnects. + + Args: + client_token: The client token. + + Returns: + An asyncio.Event that will be set on disconnect. + """ + event = asyncio.Event() + if client_token not in self.token_to_socket: + # Already disconnected, set immediately. + event.set() + else: + self._token_disconnect_events.setdefault(client_token, []).append( + event + ) + return event + + def when_token_connects(self, client_token: str) -> asyncio.Event: + """Return an asyncio.Event that is set when the token connects. + + Args: + client_token: The client token. + + Returns: + An asyncio.Event that will be set on connect. + """ + event = asyncio.Event() + if client_token in self.token_to_socket: + # Already connected, set immediately. + event.set() + else: + self._token_connect_events.setdefault(client_token, []).append( + event + ) + return event + class LocalTokenManager(TokenManager): """Token manager using local in-memory dictionaries (single worker).""" @@ -464,3 +611,21 @@ async def emit_lost_and_found( else: return True return False + + +def get_token_manager() -> TokenManager: + """Get the token manager for the currently running app. + + Returns: + The active TokenManager instance. + + Raises: + RuntimeError: If the app or event namespace is not initialized. + """ + app_mod = prerequisites.get_and_validate_app() + app = app_mod.app + event_namespace = app.event_namespace + if event_namespace is None: + msg = "Event namespace is not initialized. Is the app running?" + raise RuntimeError(msg) + return event_namespace._token_manager diff --git a/tests/units/utils/test_token_manager.py b/tests/units/utils/test_token_manager.py index 9f740a29f37..a2b8b5e8b80 100644 --- a/tests/units/utils/test_token_manager.py +++ b/tests/units/utils/test_token_manager.py @@ -18,6 +18,8 @@ RedisTokenManager, SocketRecord, TokenManager, + _TokenNotConnectedError, + get_token_manager, ) @@ -221,6 +223,243 @@ async def test_enumerate_tokens(self, manager): assert not found_tokens +class TestTokenManagerLifecycle: + """Tests for TokenManager lifecycle APIs (issue #5669).""" + + @pytest.fixture + def manager(self): + """Create a LocalTokenManager instance. + + Returns: + A LocalTokenManager instance for testing. + """ + return LocalTokenManager() + + async def test_when_token_disconnects_connected(self, manager): + """Event not set while connected, set after disconnect. + + Args: + manager: LocalTokenManager fixture instance. + """ + await manager.link_token_to_sid("tok1", "sid1") + evt = manager.when_token_disconnects("tok1") + assert not evt.is_set() + manager._notify_disconnect("tok1", "sid1") + assert evt.is_set() + + async def test_when_token_disconnects_already_disconnected(self, manager): + """Event set immediately for unknown token. + + Args: + manager: LocalTokenManager fixture instance. + """ + evt = manager.when_token_disconnects("nonexistent") + assert evt.is_set() + + async def test_when_session_disconnects_connected(self, manager): + """Event not set while connected, set after disconnect. + + Args: + manager: LocalTokenManager fixture instance. + """ + await manager.link_token_to_sid("tok2", "sid2") + evt = manager.when_session_disconnects("sid2") + assert not evt.is_set() + manager._notify_disconnect("tok2", "sid2") + assert evt.is_set() + + async def test_when_session_disconnects_already_disconnected(self, manager): + """Event set immediately for unknown sid. + + Args: + manager: LocalTokenManager fixture instance. + """ + evt = manager.when_session_disconnects("nonexistent") + assert evt.is_set() + + async def test_when_token_connects_not_yet(self, manager): + """Event set after connect. + + Args: + manager: LocalTokenManager fixture instance. + """ + evt = manager.when_token_connects("future_tok") + assert not evt.is_set() + await manager.link_token_to_sid("future_tok", "sid_f") + manager._notify_connect("future_tok", "sid_f") + assert evt.is_set() + + async def test_when_token_connects_already_connected(self, manager): + """Event set immediately for already connected token. + + Args: + manager: LocalTokenManager fixture instance. + """ + await manager.link_token_to_sid("tok3", "sid3") + evt = manager.when_token_connects("tok3") + assert evt.is_set() + + async def test_session_is_connected_yields_and_stops(self, manager): + """Yields token once, then awaits disconnect. + + Args: + manager: LocalTokenManager fixture instance. + """ + import contextlib + + await manager.link_token_to_sid("tok4", "sid4") + async with contextlib.aclosing( + manager.session_is_connected("sid4") + ) as gen: + token = await gen.__anext__() + assert token == "tok4" + # Trigger disconnect so the await inside the iterator completes. + manager._notify_disconnect("tok4", "sid4") + + async def test_session_is_connected_raises_for_unknown(self, manager): + """Raises for unknown sid. + + Args: + manager: LocalTokenManager fixture instance. + """ + with pytest.raises(_TokenNotConnectedError): + async for _ in manager.session_is_connected("unknown_sid"): + pass + + async def test_token_is_connected_yields_and_stops(self, manager): + """Yields sid once, then awaits disconnect. + + Args: + manager: LocalTokenManager fixture instance. + """ + import contextlib + + await manager.link_token_to_sid("tok5", "sid5") + async with contextlib.aclosing( + manager.token_is_connected("tok5") + ) as gen: + sid = await gen.__anext__() + assert sid == "sid5" + # Trigger disconnect so the await inside the iterator completes. + manager._notify_disconnect("tok5", "sid5") + + async def test_token_is_connected_raises_for_unknown(self, manager): + """Raises for unknown token. + + Args: + manager: LocalTokenManager fixture instance. + """ + with pytest.raises(_TokenNotConnectedError): + async for _ in manager.token_is_connected("unknown_tok"): + pass + + async def test_multiple_watchers_token_disconnect(self, manager): + """Multiple watchers on same token all get notified. + + Args: + manager: LocalTokenManager fixture instance. + """ + await manager.link_token_to_sid("tok6", "sid6") + evt1 = manager.when_token_disconnects("tok6") + evt2 = manager.when_token_disconnects("tok6") + assert not evt1.is_set() and not evt2.is_set() + manager._notify_disconnect("tok6", "sid6") + assert evt1.is_set() and evt2.is_set() + + async def test_multiple_watchers_session_disconnect(self, manager): + """Multiple session watchers all get notified. + + Args: + manager: LocalTokenManager fixture instance. + """ + await manager.link_token_to_sid("tok7", "sid7") + evt1 = manager.when_session_disconnects("sid7") + evt2 = manager.when_session_disconnects("sid7") + manager._notify_disconnect("tok7", "sid7") + assert evt1.is_set() and evt2.is_set() + + async def test_notify_connect_only_matching(self, manager): + """_notify_connect only fires for matching token. + + Args: + manager: LocalTokenManager fixture instance. + """ + evt_a = manager.when_token_connects("tok_a") + evt_b = manager.when_token_connects("tok_b") + await manager.link_token_to_sid("tok_a", "sid_a") + manager._notify_connect("tok_a", "sid_a") + assert evt_a.is_set() + assert not evt_b.is_set() + + async def test_notify_disconnect_only_matching(self, manager): + """_notify_disconnect only fires for matching token. + + Args: + manager: LocalTokenManager fixture instance. + """ + await manager.link_token_to_sid("tok_c", "sid_c") + await manager.link_token_to_sid("tok_d", "sid_d") + evt_c = manager.when_token_disconnects("tok_c") + evt_d = manager.when_token_disconnects("tok_d") + manager._notify_disconnect("tok_c", "sid_c") + assert evt_c.is_set() + assert not evt_d.is_set() + + async def test_cleanup_after_disconnect_notify(self, manager): + """Events dict cleaned up after notify. + + Args: + manager: LocalTokenManager fixture instance. + """ + await manager.link_token_to_sid("tok8", "sid8") + manager.when_token_disconnects("tok8") + assert "tok8" in manager._token_disconnect_events + manager._notify_disconnect("tok8", "sid8") + assert "tok8" not in manager._token_disconnect_events + + async def test_session_iterator_cleanup_with_aclosing(self, manager): + """Events list cleaned up when using aclosing. + + Args: + manager: LocalTokenManager fixture instance. + """ + import contextlib + + await manager.link_token_to_sid("tok9", "sid9") + async with contextlib.aclosing( + manager.session_is_connected("sid9") + ) as gen: + async for _token in gen: + break + assert len(manager._sid_disconnect_events.get("sid9", [])) == 0 + + async def test_token_iterator_cleanup_with_aclosing(self, manager): + """Events list cleaned up when using aclosing. + + Args: + manager: LocalTokenManager fixture instance. + """ + import contextlib + + await manager.link_token_to_sid("tok10", "sid10") + async with contextlib.aclosing( + manager.token_is_connected("tok10") + ) as gen: + async for _sid in gen: + break + assert len(manager._token_disconnect_events.get("tok10", [])) == 0 + + def test_get_token_manager_callable(self): + """get_token_manager is importable and callable.""" + assert callable(get_token_manager) + + def test_rx_mapping_has_get_token_manager(self): + """rx.__init__ has get_token_manager in its lazy mapping.""" + from reflex import _MAPPING + + assert "get_token_manager" in _MAPPING.get("utils.token_manager", []) + + class TestRedisTokenManager: """Tests for RedisTokenManager."""