diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 290628eb745..35ed0c16dc8 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -531,6 +531,7 @@ export const connect = async ( ) => { // Get backend URL object from the endpoint. const endpoint = getBackendURL(EVENTURL); + const on_hydrated_queue = []; // Create the socket. socket.current = io(endpoint.href, { @@ -552,7 +553,17 @@ export const connect = async ( function checkVisibility() { if (document.visibilityState === "visible") { - if (!socket.current.connected) { + if (!socket.current) { + connect( + socket, + dispatch, + transports, + setConnectErrors, + client_storage, + navigate, + params, + ); + } else if (!socket.current.connected) { console.log("Socket is disconnected, attempting to reconnect "); socket.current.connect(); } else { @@ -593,6 +604,7 @@ export const connect = async ( // When the socket disconnects reset the event_processing flag socket.current.on("disconnect", () => { + socket.current = null; // allow reconnect to occur automatically event_processing = false; window.removeEventListener("unload", disconnectTrigger); window.removeEventListener("beforeunload", disconnectTrigger); @@ -603,6 +615,14 @@ export const connect = async ( socket.current.on("event", async (update) => { for (const substate in update.delta) { dispatch[substate](update.delta[substate]); + // handle events waiting for `is_hydrated` + if ( + substate === state_name && + update.delta[substate]?.is_hydrated_rx_state_ + ) { + queueEvents(on_hydrated_queue, socket, false, navigate, params); + on_hydrated_queue.length = 0; + } } applyClientStorageDelta(client_storage, update.delta); event_processing = !update.final; @@ -612,7 +632,8 @@ export const connect = async ( }); socket.current.on("reload", async (event) => { event_processing = false; - queueEvents([...initialEvents(), event], socket, true, navigate, params); + on_hydrated_queue.push(event); + queueEvents(initialEvents(), socket, true, navigate, params); }); socket.current.on("new_token", async (new_token) => { token = new_token; @@ -774,10 +795,32 @@ export const useEventLoop = ( } }, [paramsR]); + const ensureSocketConnected = useCallback(async () => { + // only use websockets if state is present and backend is not disabled (reflex cloud). + if ( + Object.keys(initialState).length > 1 && + !isBackendDisabled() && + !socket.current + ) { + // Initialize the websocket connection. + await connect( + socket, + dispatch, + ["websocket"], + setConnectErrors, + client_storage, + navigate, + () => params.current, + ); + } + }, [socket, dispatch, setConnectErrors, client_storage, navigate, params]); + // Function to add new events to the event queue. const addEvents = useCallback((events, args, event_actions) => { const _events = events.filter((e) => e !== undefined && e !== null); + ensureSocketConnected(); + if (!(args instanceof Array)) { args = [args]; } @@ -870,21 +913,8 @@ export const useEventLoop = ( // Handle socket connect/disconnect. useEffect(() => { - // only use websockets if state is present and backend is not disabled (reflex cloud). - if (Object.keys(initialState).length > 1 && !isBackendDisabled()) { - // Initialize the websocket connection. - if (!socket.current) { - connect( - socket, - dispatch, - ["websocket"], - setConnectErrors, - client_storage, - navigate, - () => params.current, - ); - } - } + // Initialize the websocket connection. + ensureSocketConnected(); // Cleanup function. return () => { @@ -903,6 +933,7 @@ export const useEventLoop = ( (async () => { // Process all outstanding events. while (event_queue.length > 0 && !event_processing) { + await ensureSocketConnected(); await processEvent(socket.current, navigate, () => params.current); } })(); diff --git a/reflex/app.py b/reflex/app.py index 8792188fefd..61e76324e0d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -97,6 +97,7 @@ State, StateManager, StateUpdate, + _split_substate_key, _substate_key, all_base_state_classes, code_uses_state_contexts, @@ -1559,7 +1560,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: state._clean() await self.event_namespace.emit_update( update=StateUpdate(delta=delta), - sid=state.router.session.session_id, + token=token, ) def _process_background( @@ -1599,7 +1600,7 @@ async def _coro(): # Send the update to the client. await self.event_namespace.emit_update( update=update, - sid=state.router.session.session_id, + token=event.token, ) task = asyncio.create_task( @@ -2061,20 +2062,19 @@ def on_disconnect(self, sid: str): and console.error(f"Token cleanup error: {t.exception()}") ) - async def emit_update(self, update: StateUpdate, sid: str) -> None: + async def emit_update(self, update: StateUpdate, token: str) -> None: """Emit an update to the client. Args: update: The state update to send. - sid: The Socket.IO session id. + token: The client token (tab) associated with the event. """ - if not sid: + client_token, _ = _split_substate_key(token) + sid = self.token_to_sid.get(client_token) + if sid is None: # If the sid is None, we are not connected to a client. Prevent sending # updates to all clients. - return - token = self.sid_to_token.get(sid) - if token is None: - console.warn(f"Attempting to send delta to disconnected websocket {sid}") + console.warn(f"Attempting to send delta to disconnected client {token!r}") return # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( @@ -2165,7 +2165,7 @@ async def on_event(self, sid: str, data: Any): # Process the events. async for update in updates_gen: # Emit the update from processing the event. - await self.emit_update(update=update, sid=sid) + await self.emit_update(update=update, token=event.token) async def on_ping(self, sid: str): """Event for testing the API endpoint. diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index e152036cac7..5c61237685a 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -71,10 +71,15 @@ def __init__( state_instance: The state instance to proxy. parent_state_proxy: The parent state proxy, for linked mutability and context tracking. """ + from reflex.state import _substate_key + super().__init__(state_instance) - # compile is not relevant to backend logic self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) + self._self_substate_token = _substate_key( + state_instance.router.session.client_token, + self._self_substate_path, + ) self._self_actx = None self._self_mutable = False self._self_actx_lock = asyncio.Lock() @@ -127,16 +132,9 @@ async def __aenter__(self) -> StateProxy: msg = "The state is already mutable. Do not nest `async with self` blocks." raise ImmutableStateError(msg) - from reflex.state import _substate_key - await self._self_actx_lock.acquire() self._self_actx_lock_holder = current_task - self._self_actx = self._self_app.modify_state( - token=_substate_key( - self.__wrapped__.router.session.client_token, - self._self_substate_path, - ) - ) + self._self_actx = self._self_app.modify_state(token=self._self_substate_token) mutable_state = await self._self_actx.__aenter__() super().__setattr__( "__wrapped__", mutable_state.get_substate(self._self_substate_path) diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index c5d7fe8f844..b95fa0cf925 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -109,6 +109,15 @@ async def yield_in_async_with_self(self): yield self.counter += 1 + @rx.event(background=True) + async def disconnect_reconnect_background(self): + async with self: + self.counter += 1 + yield rx.call_script("socket.disconnect()") + await asyncio.sleep(0.5) + async with self: + self.counter += 1 + class OtherState(rx.State): @rx.event(background=True) async def get_other_state(self): @@ -134,6 +143,9 @@ def index() -> rx.Component: rx.input( id="token", value=State.router.session.client_token, is_read_only=True ), + rx.input( + id="sid", value=State.router.session.session_id, is_read_only=True + ), rx.hstack( rx.heading(State.counter, id="counter"), rx.text(State.counter_async_cv, size="1", id="counter-async-cv"), @@ -185,6 +197,11 @@ def index() -> rx.Component: on_click=State.yield_in_async_with_self, id="yield-in-async-with-self", ), + rx.button( + "Disconnect / Reconnect Background", + on_click=State.disconnect_reconnect_background, + id="disconnect-reconnect-background", + ), rx.button("Reset", on_click=State.reset_counter, id="reset"), ) @@ -395,3 +412,42 @@ def test_yield_in_async_with_self( yield_in_async_with_self_button.click() AppHarness.expect(lambda: counter.text == "2", timeout=5) + + +@pytest.mark.parametrize( + "button_id", + [ + "disconnect-reconnect-background", + ], +) +def test_disconnect_reconnect( + background_task: AppHarness, + driver: WebDriver, + token: str, + button_id: str, +): + """Test that disconnecting and reconnecting works as expected. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + token: The token for the connected client. + button_id: The ID of the button to click. + """ + counter = driver.find_element(By.ID, "counter") + button = driver.find_element(By.ID, button_id) + increment_button = driver.find_element(By.ID, "increment") + sid_input = driver.find_element(By.ID, "sid") + sid = background_task.poll_for_value(sid_input, timeout=5) + assert sid is not None + + AppHarness.expect(lambda: counter.text == "0", timeout=5) + button.click() + AppHarness.expect(lambda: counter.text == "1", timeout=5) + increment_button.click() + # should get a new sid after the reconnect + assert ( + background_task.poll_for_value(sid_input, timeout=5, exp_not_equal=sid) != sid + ) + # Final update should come through on the new websocket connection + AppHarness.expect(lambda: counter.text == "3", timeout=5) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index d1c728d5e81..12db8328e8f 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2005,6 +2005,7 @@ async def test_state_proxy( namespace = mock_app.event_namespace assert namespace is not None namespace.sid_to_token[router_data.session.session_id] = token + namespace.token_to_sid[token] = router_data.session.session_id if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): mock_app.state_manager.states[parent_state.router.session.client_token] = ( parent_state @@ -2214,6 +2215,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): namespace = mock_app.event_namespace assert namespace is not None namespace.sid_to_token[sid] = token + namespace.token_to_sid[token] = sid mock_app.state_manager.state = mock_app._state = BackgroundTaskState async for update in rx.app.process( mock_app,