From 5c6b33aea9864470371be457afa9ac1a26ff0e44 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 19 Sep 2025 11:32:22 -0700 Subject: [PATCH 1/3] Remove token/sid associations when server is exiting This allows existing tokens to reconnect to redis after a hot or cold reload of the app. Otherwise, the old associations for the token remain in place and when the same client reconnects, it is given a new_token, since the requested token is already "taken" in redis. --- reflex/app_mixins/lifespan.py | 11 +++++++++++ reflex/utils/token_manager.py | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index 6079fdd4d8d..43e3574a442 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -60,6 +60,17 @@ async def _run_lifespan_tasks(self, app: Starlette): for task in running_tasks: console.debug(f"Canceling lifespan task: {task}") task.cancel(msg="lifespan_cleanup") + # Disassociate sid / token pairings so they can be reconnected properly. + try: + event_namespace = self.event_namespace # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass + else: + try: + if event_namespace: + await event_namespace._token_manager.disconnect_all() + except Exception as e: + console.error(f"Error during lifespan cleanup: {e}") def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): """Register a task to run during the lifespan of the app. diff --git a/reflex/utils/token_manager.py b/reflex/utils/token_manager.py index ebc5e36cc11..b60b77e9743 100644 --- a/reflex/utils/token_manager.py +++ b/reflex/utils/token_manager.py @@ -66,6 +66,16 @@ def create(cls) -> TokenManager: return LocalTokenManager() + async def disconnect_all(self): + """Disconnect all tracked tokens when the server is going down.""" + token_sid_pairs: set[tuple[str, str]] = set(self.token_to_sid.items()) + token_sid_pairs.update( + ((token, sid) for sid, token in self.sid_to_token.items()) + ) + # Perform the disconnection logic here + for token, sid in token_sid_pairs: + await self.disconnect_token(token, sid) + class LocalTokenManager(TokenManager): """Token manager using local in-memory dictionaries (single worker).""" From 6ba230a5795a77dacf27e95f265247efe510fe95 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 22 Sep 2025 12:32:27 -0700 Subject: [PATCH 2/3] test_connection_banner: assert that token/sid association removed on shutdown --- tests/integration/test_connection_banner.py | 50 ++++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 6ff246cadfa..6e5dc8ab609 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -8,7 +8,9 @@ from reflex import constants from reflex.environment import environment +from reflex.istate.manager import StateManagerRedis from reflex.testing import AppHarness, WebDriver +from reflex.utils.token_manager import RedisTokenManager from .utils import SessionStorage @@ -127,17 +129,21 @@ def has_cloud_banner(driver: WebDriver) -> bool: return True -def _assert_token(connection_banner, driver): +def _assert_token(connection_banner, driver) -> str: """Poll for backend to be up. Args: connection_banner: AppHarness instance. driver: Selenium webdriver instance. + + Returns: + The token if found, raises an assertion error otherwise. """ ss = SessionStorage(driver) assert connection_banner._poll_for(lambda: ss.get("token") is not None), ( "token not found" ) + return ss.get("token") @pytest.mark.asyncio @@ -151,9 +157,25 @@ async def test_connection_banner(connection_banner: AppHarness): assert connection_banner.backend is not None driver = connection_banner.frontend() - _assert_token(connection_banner, driver) + token = _assert_token(connection_banner, driver) AppHarness.expect(lambda: not has_error_modal(driver)) + # Check that the token association was established. + app_event_namespace = connection_banner.app_instance.event_namespace + assert app_event_namespace is not None + app_token_manager = app_event_namespace._token_manager + assert app_token_manager is not None + assert token in app_token_manager.token_to_sid + sid_before = app_token_manager.token_to_sid[token] + if isinstance(connection_banner.state_manager, StateManagerRedis): + assert isinstance(app_token_manager, RedisTokenManager) + assert ( + await connection_banner.state_manager.redis.get( + app_token_manager._get_redis_key(token) + ) + == b"1" + ) + delay_button = driver.find_element(By.ID, "delay") increment_button = driver.find_element(By.ID, "increment") counter_element = driver.find_element(By.ID, "counter") @@ -176,6 +198,17 @@ async def test_connection_banner(connection_banner: AppHarness): # Error modal should now be displayed AppHarness.expect(lambda: has_error_modal(driver)) + # The token association should have been removed when the server exited. + assert token not in app_token_manager.token_to_sid + if isinstance(connection_banner.state_manager, StateManagerRedis): + assert isinstance(app_token_manager, RedisTokenManager) + assert ( + await connection_banner.state_manager.redis.get( + app_token_manager._get_redis_key(token) + ) + is None + ) + # Increment the counter with backend down increment_button.click() assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1" @@ -189,6 +222,19 @@ async def test_connection_banner(connection_banner: AppHarness): # Banner should be gone now AppHarness.expect(lambda: not has_error_modal(driver)) + # After reconnecting, the token association should be re-established. + if isinstance(connection_banner.state_manager, StateManagerRedis): + assert isinstance(app_token_manager, RedisTokenManager) + assert ( + await connection_banner.state_manager.redis.get( + app_token_manager._get_redis_key(token) + ) + == b"1" + ) + # Make sure the new connection has a different websocket sid. + sid_after = app_token_manager.token_to_sid[token] + assert sid_before != sid_after + # Count should have incremented after coming back up assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2" From 224c0f9ddce6bae1c9bd1343fab92c761ad7eff8 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 22 Sep 2025 13:28:09 -0700 Subject: [PATCH 3/3] Re-fetch the token_manager after restarting backend --- reflex/testing.py | 14 ++++++++++++++ tests/integration/test_connection_banner.py | 6 ++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/reflex/testing.py b/reflex/testing.py index 53f9f82d6ae..2f04b08ad1f 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -47,6 +47,7 @@ ) from reflex.utils import console, js_runtimes from reflex.utils.export import export +from reflex.utils.token_manager import TokenManager from reflex.utils.types import ASGIApp try: @@ -774,6 +775,19 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: self.app_instance._state_manager = app_state_manager await self.state_manager.close() + def token_manager(self) -> TokenManager: + """Get the token manager for the app instance. + + Returns: + The current token_manager attached to the app's EventNamespace. + """ + assert self.app_instance is not None + app_event_namespace = self.app_instance.event_namespace + assert app_event_namespace is not None + app_token_manager = app_event_namespace._token_manager + assert app_token_manager is not None + return app_token_manager + def poll_for_content( self, element: WebElement, diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 6e5dc8ab609..044d431edb7 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -161,10 +161,7 @@ async def test_connection_banner(connection_banner: AppHarness): AppHarness.expect(lambda: not has_error_modal(driver)) # Check that the token association was established. - app_event_namespace = connection_banner.app_instance.event_namespace - assert app_event_namespace is not None - app_token_manager = app_event_namespace._token_manager - assert app_token_manager is not None + app_token_manager = connection_banner.token_manager() assert token in app_token_manager.token_to_sid sid_before = app_token_manager.token_to_sid[token] if isinstance(connection_banner.state_manager, StateManagerRedis): @@ -223,6 +220,7 @@ async def test_connection_banner(connection_banner: AppHarness): AppHarness.expect(lambda: not has_error_modal(driver)) # After reconnecting, the token association should be re-established. + app_token_manager = connection_banner.token_manager() if isinstance(connection_banner.state_manager, StateManagerRedis): assert isinstance(app_token_manager, RedisTokenManager) assert (