From 924b8432e2a2bbfd1985f81424b2556bb2d0b6f9 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 19 Sep 2025 16:33:46 -0700 Subject: [PATCH 1/5] emit_update: take `token` instead of `sid` This allows the app to be more resilient in the face of websocket reconnects. The event is processed against a token, so there's no reason to maintain websocket affinity for event processing. Whenever the update is ready to send, it will be sent to the current websocket/sid associated. --- reflex/app.py | 20 ++++++++++---------- reflex/istate/proxy.py | 16 +++++++--------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 8792188fefd..61e76324e0d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -97,6 +97,7 @@ State, StateManager, StateUpdate, + _split_substate_key, _substate_key, all_base_state_classes, code_uses_state_contexts, @@ -1559,7 +1560,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: state._clean() await self.event_namespace.emit_update( update=StateUpdate(delta=delta), - sid=state.router.session.session_id, + token=token, ) def _process_background( @@ -1599,7 +1600,7 @@ async def _coro(): # Send the update to the client. await self.event_namespace.emit_update( update=update, - sid=state.router.session.session_id, + token=event.token, ) task = asyncio.create_task( @@ -2061,20 +2062,19 @@ def on_disconnect(self, sid: str): and console.error(f"Token cleanup error: {t.exception()}") ) - async def emit_update(self, update: StateUpdate, sid: str) -> None: + async def emit_update(self, update: StateUpdate, token: str) -> None: """Emit an update to the client. Args: update: The state update to send. - sid: The Socket.IO session id. + token: The client token (tab) associated with the event. """ - if not sid: + client_token, _ = _split_substate_key(token) + sid = self.token_to_sid.get(client_token) + if sid is None: # If the sid is None, we are not connected to a client. 
Prevent sending # updates to all clients. - return - token = self.sid_to_token.get(sid) - if token is None: - console.warn(f"Attempting to send delta to disconnected websocket {sid}") + console.warn(f"Attempting to send delta to disconnected client {token!r}") return # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( @@ -2165,7 +2165,7 @@ async def on_event(self, sid: str, data: Any): # Process the events. async for update in updates_gen: # Emit the update from processing the event. - await self.emit_update(update=update, sid=sid) + await self.emit_update(update=update, token=event.token) async def on_ping(self, sid: str): """Event for testing the API endpoint. diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index e152036cac7..5c61237685a 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -71,10 +71,15 @@ def __init__( state_instance: The state instance to proxy. parent_state_proxy: The parent state proxy, for linked mutability and context tracking. """ + from reflex.state import _substate_key + super().__init__(state_instance) - # compile is not relevant to backend logic self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) + self._self_substate_token = _substate_key( + state_instance.router.session.client_token, + self._self_substate_path, + ) self._self_actx = None self._self_mutable = False self._self_actx_lock = asyncio.Lock() @@ -127,16 +132,9 @@ async def __aenter__(self) -> StateProxy: msg = "The state is already mutable. Do not nest `async with self` blocks." 
raise ImmutableStateError(msg) - from reflex.state import _substate_key - await self._self_actx_lock.acquire() self._self_actx_lock_holder = current_task - self._self_actx = self._self_app.modify_state( - token=_substate_key( - self.__wrapped__.router.session.client_token, - self._self_substate_path, - ) - ) + self._self_actx = self._self_app.modify_state(token=self._self_substate_token) mutable_state = await self._self_actx.__aenter__() super().__setattr__( "__wrapped__", mutable_state.get_substate(self._self_substate_path) From 5ad988e9be1233c47edc0cdc8264018f3fa1e156 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 19 Sep 2025 16:38:38 -0700 Subject: [PATCH 2/5] Automatic websocket reconnect and reload handling * ensureSocketConnected is called when adding events or pumping the queue to trigger an automatic reconnection to the backend * when "reload" event is encountered, trigger a re-hydrate and wait until ALL on_load have finished processing and `is_hydrated` is True before requeue the event that caused the "reload" --- reflex/.templates/web/utils/state.js | 65 ++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 290628eb745..35ed0c16dc8 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -531,6 +531,7 @@ export const connect = async ( ) => { // Get backend URL object from the endpoint. const endpoint = getBackendURL(EVENTURL); + const on_hydrated_queue = []; // Create the socket. 
socket.current = io(endpoint.href, { @@ -552,7 +553,17 @@ export const connect = async ( function checkVisibility() { if (document.visibilityState === "visible") { - if (!socket.current.connected) { + if (!socket.current) { + connect( + socket, + dispatch, + transports, + setConnectErrors, + client_storage, + navigate, + params, + ); + } else if (!socket.current.connected) { console.log("Socket is disconnected, attempting to reconnect "); socket.current.connect(); } else { @@ -593,6 +604,7 @@ export const connect = async ( // When the socket disconnects reset the event_processing flag socket.current.on("disconnect", () => { + socket.current = null; // allow reconnect to occur automatically event_processing = false; window.removeEventListener("unload", disconnectTrigger); window.removeEventListener("beforeunload", disconnectTrigger); @@ -603,6 +615,14 @@ export const connect = async ( socket.current.on("event", async (update) => { for (const substate in update.delta) { dispatch[substate](update.delta[substate]); + // handle events waiting for `is_hydrated` + if ( + substate === state_name && + update.delta[substate]?.is_hydrated_rx_state_ + ) { + queueEvents(on_hydrated_queue, socket, false, navigate, params); + on_hydrated_queue.length = 0; + } } applyClientStorageDelta(client_storage, update.delta); event_processing = !update.final; @@ -612,7 +632,8 @@ export const connect = async ( }); socket.current.on("reload", async (event) => { event_processing = false; - queueEvents([...initialEvents(), event], socket, true, navigate, params); + on_hydrated_queue.push(event); + queueEvents(initialEvents(), socket, true, navigate, params); }); socket.current.on("new_token", async (new_token) => { token = new_token; @@ -774,10 +795,32 @@ export const useEventLoop = ( } }, [paramsR]); + const ensureSocketConnected = useCallback(async () => { + // only use websockets if state is present and backend is not disabled (reflex cloud). 
+ if ( + Object.keys(initialState).length > 1 && + !isBackendDisabled() && + !socket.current + ) { + // Initialize the websocket connection. + await connect( + socket, + dispatch, + ["websocket"], + setConnectErrors, + client_storage, + navigate, + () => params.current, + ); + } + }, [socket, dispatch, setConnectErrors, client_storage, navigate, params]); + // Function to add new events to the event queue. const addEvents = useCallback((events, args, event_actions) => { const _events = events.filter((e) => e !== undefined && e !== null); + ensureSocketConnected(); + if (!(args instanceof Array)) { args = [args]; } @@ -870,21 +913,8 @@ export const useEventLoop = ( // Handle socket connect/disconnect. useEffect(() => { - // only use websockets if state is present and backend is not disabled (reflex cloud). - if (Object.keys(initialState).length > 1 && !isBackendDisabled()) { - // Initialize the websocket connection. - if (!socket.current) { - connect( - socket, - dispatch, - ["websocket"], - setConnectErrors, - client_storage, - navigate, - () => params.current, - ); - } - } + // Initialize the websocket connection. + ensureSocketConnected(); // Cleanup function. return () => { @@ -903,6 +933,7 @@ export const useEventLoop = ( (async () => { // Process all outstanding events. 
while (event_queue.length > 0 && !event_processing) { + await ensureSocketConnected(); await processEvent(socket.current, navigate, () => params.current); } })(); From c09db7ffbe7aa4365ff326e5c72c1f3076ef370a Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 22 Sep 2025 09:19:52 -0700 Subject: [PATCH 3/5] Update mock token_to_sid mapping for test --- tests/units/test_state.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index d1c728d5e81..12db8328e8f 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2005,6 +2005,7 @@ async def test_state_proxy( namespace = mock_app.event_namespace assert namespace is not None namespace.sid_to_token[router_data.session.session_id] = token + namespace.token_to_sid[token] = router_data.session.session_id if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): mock_app.state_manager.states[parent_state.router.session.client_token] = ( parent_state @@ -2214,6 +2215,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): namespace = mock_app.event_namespace assert namespace is not None namespace.sid_to_token[sid] = token + namespace.token_to_sid[token] = sid mock_app.state_manager.state = mock_app._state = BackgroundTaskState async for update in rx.app.process( mock_app, From d16208be5423fb63da75c0064f2477b760f04f2b Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 22 Sep 2025 13:15:31 -0700 Subject: [PATCH 4/5] Add disconnect/reconnect test to test_background_task --- tests/integration/test_background_task.py | 69 +++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index c5d7fe8f844..6d7f904aeff 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -109,6 +109,22 @@ async def yield_in_async_with_self(self): yield self.counter += 1 + @rx.event + 
async def disconnect_reconnect(self): + self.counter += 1 + yield rx.call_script("socket.disconnect()") + await asyncio.sleep(0.5) + self.counter += 1 + + @rx.event(background=True) + async def disconnect_reconnect_background(self): + async with self: + self.counter += 1 + yield rx.call_script("socket.disconnect()") + await asyncio.sleep(0.5) + async with self: + self.counter += 1 + class OtherState(rx.State): @rx.event(background=True) async def get_other_state(self): @@ -134,6 +150,9 @@ def index() -> rx.Component: rx.input( id="token", value=State.router.session.client_token, is_read_only=True ), + rx.input( + id="sid", value=State.router.session.session_id, is_read_only=True + ), rx.hstack( rx.heading(State.counter, id="counter"), rx.text(State.counter_async_cv, size="1", id="counter-async-cv"), @@ -185,6 +204,16 @@ def index() -> rx.Component: on_click=State.yield_in_async_with_self, id="yield-in-async-with-self", ), + rx.button( + "Disconnect / Reconnect", + on_click=State.disconnect_reconnect_background, + id="disconnect-reconnect", + ), + rx.button( + "Disconnect / Reconnect Background", + on_click=State.disconnect_reconnect_background, + id="disconnect-reconnect-background", + ), rx.button("Reset", on_click=State.reset_counter, id="reset"), ) @@ -395,3 +424,43 @@ def test_yield_in_async_with_self( yield_in_async_with_self_button.click() AppHarness.expect(lambda: counter.text == "2", timeout=5) + + +@pytest.mark.parametrize( + "button_id", + [ + "disconnect-reconnect", + "disconnect-reconnect-background", + ], +) +def test_disconnect_reconnect( + background_task: AppHarness, + driver: WebDriver, + token: str, + button_id: str, +): + """Test that disconnecting and reconnecting works as expected. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + token: The token for the connected client. + button_id: The ID of the button to click. 
+ """ + counter = driver.find_element(By.ID, "counter") + button = driver.find_element(By.ID, button_id) + increment_button = driver.find_element(By.ID, "increment") + sid_input = driver.find_element(By.ID, "sid") + sid = background_task.poll_for_value(sid_input, timeout=5) + assert sid is not None + + AppHarness.expect(lambda: counter.text == "0", timeout=5) + button.click() + AppHarness.expect(lambda: counter.text == "1", timeout=5) + increment_button.click() + # should get a new sid after the reconnect + assert ( + background_task.poll_for_value(sid_input, timeout=5, exp_not_equal=sid) != sid + ) + # Final update should come through on the new websocket connection + AppHarness.expect(lambda: counter.text == "3", timeout=5) From 31ecf955a036a75eb9b0fe59516d66e22c91212b Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 22 Sep 2025 13:39:01 -0700 Subject: [PATCH 5/5] Remove non-background disconnect/reconnect test It doesn't really work, because the frontend will only process one non-background event at a time, so the disconnect ends up occurring after the event handler is already done.
--- tests/integration/test_background_task.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index 6d7f904aeff..b95fa0cf925 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -109,13 +109,6 @@ async def yield_in_async_with_self(self): yield self.counter += 1 - @rx.event - async def disconnect_reconnect(self): - self.counter += 1 - yield rx.call_script("socket.disconnect()") - await asyncio.sleep(0.5) - self.counter += 1 - @rx.event(background=True) async def disconnect_reconnect_background(self): async with self: @@ -204,11 +197,6 @@ def index() -> rx.Component: on_click=State.yield_in_async_with_self, id="yield-in-async-with-self", ), - rx.button( - "Disconnect / Reconnect", - on_click=State.disconnect_reconnect_background, - id="disconnect-reconnect", - ), rx.button( "Disconnect / Reconnect Background", on_click=State.disconnect_reconnect_background, @@ -429,7 +417,6 @@ def test_yield_in_async_with_self( @pytest.mark.parametrize( "button_id", [ - "disconnect-reconnect", "disconnect-reconnect-background", ], )