Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions reflex/app_mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ async def _run_lifespan_tasks(self, app: Starlette):
for task in running_tasks:
console.debug(f"Canceling lifespan task: {task}")
task.cancel(msg="lifespan_cleanup")
# Disassociate sid / token pairings so they can be reconnected properly.
try:
event_namespace = self.event_namespace # pyright: ignore[reportAttributeAccessIssue]
except AttributeError:
pass
else:
try:
if event_namespace:
await event_namespace._token_manager.disconnect_all()
except Exception as e:
console.error(f"Error during lifespan cleanup: {e}")

def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
"""Register a task to run during the lifespan of the app.
Expand Down
14 changes: 14 additions & 0 deletions reflex/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from reflex.utils import console, js_runtimes
from reflex.utils.export import export
from reflex.utils.token_manager import TokenManager
from reflex.utils.types import ASGIApp

try:
Expand Down Expand Up @@ -774,6 +775,19 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
self.app_instance._state_manager = app_state_manager
await self.state_manager.close()

def token_manager(self) -> TokenManager:
"""Get the token manager for the app instance.

Returns:
The current token_manager attached to the app's EventNamespace.
"""
assert self.app_instance is not None
app_event_namespace = self.app_instance.event_namespace
assert app_event_namespace is not None
app_token_manager = app_event_namespace._token_manager
assert app_token_manager is not None
return app_token_manager

def poll_for_content(
self,
element: WebElement,
Expand Down
10 changes: 10 additions & 0 deletions reflex/utils/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ def create(cls) -> TokenManager:

return LocalTokenManager()

async def disconnect_all(self):
"""Disconnect all tracked tokens when the server is going down."""
token_sid_pairs: set[tuple[str, str]] = set(self.token_to_sid.items())
token_sid_pairs.update(
((token, sid) for sid, token in self.sid_to_token.items())
)
# Perform the disconnection logic here
for token, sid in token_sid_pairs:
await self.disconnect_token(token, sid)
Comment thread
masenf marked this conversation as resolved.


class LocalTokenManager(TokenManager):
"""Token manager using local in-memory dictionaries (single worker)."""
Expand Down
48 changes: 46 additions & 2 deletions tests/integration/test_connection_banner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

from reflex import constants
from reflex.environment import environment
from reflex.istate.manager import StateManagerRedis
from reflex.testing import AppHarness, WebDriver
from reflex.utils.token_manager import RedisTokenManager

from .utils import SessionStorage

Expand Down Expand Up @@ -127,17 +129,21 @@ def has_cloud_banner(driver: WebDriver) -> bool:
return True


def _assert_token(connection_banner, driver):
def _assert_token(connection_banner, driver) -> str:
"""Poll for backend to be up.

Args:
connection_banner: AppHarness instance.
driver: Selenium webdriver instance.

Returns:
The token if found, raises an assertion error otherwise.
"""
ss = SessionStorage(driver)
assert connection_banner._poll_for(lambda: ss.get("token") is not None), (
"token not found"
)
return ss.get("token")


@pytest.mark.asyncio
Expand All @@ -151,9 +157,22 @@ async def test_connection_banner(connection_banner: AppHarness):
assert connection_banner.backend is not None
driver = connection_banner.frontend()

_assert_token(connection_banner, driver)
token = _assert_token(connection_banner, driver)
AppHarness.expect(lambda: not has_error_modal(driver))

# Check that the token association was established.
app_token_manager = connection_banner.token_manager()
assert token in app_token_manager.token_to_sid
sid_before = app_token_manager.token_to_sid[token]
if isinstance(connection_banner.state_manager, StateManagerRedis):
assert isinstance(app_token_manager, RedisTokenManager)
assert (
await connection_banner.state_manager.redis.get(
app_token_manager._get_redis_key(token)
)
== b"1"
)

delay_button = driver.find_element(By.ID, "delay")
increment_button = driver.find_element(By.ID, "increment")
counter_element = driver.find_element(By.ID, "counter")
Expand All @@ -176,6 +195,17 @@ async def test_connection_banner(connection_banner: AppHarness):
# Error modal should now be displayed
AppHarness.expect(lambda: has_error_modal(driver))

# The token association should have been removed when the server exited.
assert token not in app_token_manager.token_to_sid
if isinstance(connection_banner.state_manager, StateManagerRedis):
assert isinstance(app_token_manager, RedisTokenManager)
assert (
await connection_banner.state_manager.redis.get(
app_token_manager._get_redis_key(token)
)
is None
)

# Increment the counter with backend down
increment_button.click()
assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1"
Expand All @@ -189,6 +219,20 @@ async def test_connection_banner(connection_banner: AppHarness):
# Banner should be gone now
AppHarness.expect(lambda: not has_error_modal(driver))

# After reconnecting, the token association should be re-established.
app_token_manager = connection_banner.token_manager()
if isinstance(connection_banner.state_manager, StateManagerRedis):
assert isinstance(app_token_manager, RedisTokenManager)
assert (
await connection_banner.state_manager.redis.get(
app_token_manager._get_redis_key(token)
)
== b"1"
)
# Make sure the new connection has a different websocket sid.
sid_after = app_token_manager.token_to_sid[token]
assert sid_before != sid_after

# Count should have incremented after coming back up
assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2"

Expand Down
Loading