diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index 6079fdd4d8d..43e3574a442 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -60,6 +60,17 @@ async def _run_lifespan_tasks(self, app: Starlette): for task in running_tasks: console.debug(f"Canceling lifespan task: {task}") task.cancel(msg="lifespan_cleanup") + # Disassociate sid / token pairings so they can be reconnected properly. + try: + event_namespace = self.event_namespace # pyright: ignore[reportAttributeAccessIssue] + except AttributeError: + pass + else: + try: + if event_namespace: + await event_namespace._token_manager.disconnect_all() + except Exception as e: + console.error(f"Error during lifespan cleanup: {e}") def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): """Register a task to run during the lifespan of the app. diff --git a/reflex/testing.py b/reflex/testing.py index 53f9f82d6ae..2f04b08ad1f 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -47,6 +47,7 @@ ) from reflex.utils import console, js_runtimes from reflex.utils.export import export +from reflex.utils.token_manager import TokenManager from reflex.utils.types import ASGIApp try: @@ -774,6 +775,19 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: self.app_instance._state_manager = app_state_manager await self.state_manager.close() + def token_manager(self) -> TokenManager: + """Get the token manager for the app instance. + + Returns: + The current token_manager attached to the app's EventNamespace. + """ + assert self.app_instance is not None + app_event_namespace = self.app_instance.event_namespace + assert app_event_namespace is not None + app_token_manager = app_event_namespace._token_manager + assert app_token_manager is not None + return app_token_manager + def poll_for_content( self, element: WebElement, diff --git a/reflex/utils/token_manager.py b/reflex/utils/token_manager.py index ebc5e36cc11..b60b77e9743 100644 --- a/reflex/utils/token_manager.py +++ b/reflex/utils/token_manager.py @@ -66,6 +66,16 @@ def create(cls) -> TokenManager: return LocalTokenManager() + async def disconnect_all(self): + """Disconnect all tracked tokens when the server is going down.""" + token_sid_pairs: set[tuple[str, str]] = set(self.token_to_sid.items()) + token_sid_pairs.update( + ((token, sid) for sid, token in self.sid_to_token.items()) + ) + # Perform the disconnection logic here + for token, sid in token_sid_pairs: + await self.disconnect_token(token, sid) + class LocalTokenManager(TokenManager): """Token manager using local in-memory dictionaries (single worker).""" diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 6ff246cadfa..044d431edb7 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -8,7 +8,9 @@ from reflex import constants from reflex.environment import environment +from reflex.istate.manager import StateManagerRedis from reflex.testing import AppHarness, WebDriver +from reflex.utils.token_manager import RedisTokenManager from .utils import SessionStorage @@ -127,17 +129,21 @@ def has_cloud_banner(driver: WebDriver) -> bool: return True -def _assert_token(connection_banner, driver): +def _assert_token(connection_banner, driver) -> str: """Poll for backend to be up. Args: connection_banner: AppHarness instance. driver: Selenium webdriver instance. + + Returns: + The token if found, raises an assertion error otherwise. """ ss = SessionStorage(driver) assert connection_banner._poll_for(lambda: ss.get("token") is not None), ( "token not found" ) + return ss.get("token") @pytest.mark.asyncio @@ -151,9 +157,22 @@ async def test_connection_banner(connection_banner: AppHarness): assert connection_banner.backend is not None driver = connection_banner.frontend() - _assert_token(connection_banner, driver) + token = _assert_token(connection_banner, driver) AppHarness.expect(lambda: not has_error_modal(driver)) + # Check that the token association was established. + app_token_manager = connection_banner.token_manager() + assert token in app_token_manager.token_to_sid + sid_before = app_token_manager.token_to_sid[token] + if isinstance(connection_banner.state_manager, StateManagerRedis): + assert isinstance(app_token_manager, RedisTokenManager) + assert ( + await connection_banner.state_manager.redis.get( + app_token_manager._get_redis_key(token) + ) + == b"1" + ) + delay_button = driver.find_element(By.ID, "delay") increment_button = driver.find_element(By.ID, "increment") counter_element = driver.find_element(By.ID, "counter") @@ -176,6 +195,17 @@ async def test_connection_banner(connection_banner: AppHarness): # Error modal should now be displayed AppHarness.expect(lambda: has_error_modal(driver)) + # The token association should have been removed when the server exited. + assert token not in app_token_manager.token_to_sid + if isinstance(connection_banner.state_manager, StateManagerRedis): + assert isinstance(app_token_manager, RedisTokenManager) + assert ( + await connection_banner.state_manager.redis.get( + app_token_manager._get_redis_key(token) + ) + is None + ) + # Increment the counter with backend down increment_button.click() assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1" @@ -189,6 +219,20 @@ async def test_connection_banner(connection_banner: AppHarness): # Banner should be gone now AppHarness.expect(lambda: not has_error_modal(driver)) + # After reconnecting, the token association should be re-established. + app_token_manager = connection_banner.token_manager() + if isinstance(connection_banner.state_manager, StateManagerRedis): + assert isinstance(app_token_manager, RedisTokenManager) + assert ( + await connection_banner.state_manager.redis.get( + app_token_manager._get_redis_key(token) + ) + == b"1" + ) + # Make sure the new connection has a different websocket sid. + sid_after = app_token_manager.token_to_sid[token] + assert sid_before != sid_after + # Count should have incremented after coming back up assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2"