diff --git a/reflex/app.py b/reflex/app.py index 8792188fefd..61e76324e0d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -97,6 +97,7 @@ State, StateManager, StateUpdate, + _split_substate_key, _substate_key, all_base_state_classes, code_uses_state_contexts, @@ -1559,7 +1560,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: state._clean() await self.event_namespace.emit_update( update=StateUpdate(delta=delta), - sid=state.router.session.session_id, + token=token, ) def _process_background( @@ -1599,7 +1600,7 @@ async def _coro(): # Send the update to the client. await self.event_namespace.emit_update( update=update, - sid=state.router.session.session_id, + token=event.token, ) task = asyncio.create_task( @@ -2061,20 +2062,19 @@ def on_disconnect(self, sid: str): and console.error(f"Token cleanup error: {t.exception()}") ) - async def emit_update(self, update: StateUpdate, sid: str) -> None: + async def emit_update(self, update: StateUpdate, token: str) -> None: """Emit an update to the client. Args: update: The state update to send. - sid: The Socket.IO session id. + token: The client token (tab) associated with the event. """ - if not sid: + client_token, _ = _split_substate_key(token) + sid = self.token_to_sid.get(client_token) + if sid is None: # If the sid is None, we are not connected to a client. Prevent sending # updates to all clients. - return - token = self.sid_to_token.get(sid) - if token is None: - console.warn(f"Attempting to send delta to disconnected websocket {sid}") + console.warn(f"Attempting to send delta to disconnected client {token!r}") return # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( @@ -2165,7 +2165,7 @@ async def on_event(self, sid: str, data: Any): # Process the events. async for update in updates_gen: # Emit the update from processing the event. - await self.emit_update(update=update, sid=sid) + await self.emit_update(update=update, token=event.token) async def on_ping(self, sid: str): """Event for testing the API endpoint. diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index e152036cac7..5c61237685a 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -71,10 +71,15 @@ def __init__( state_instance: The state instance to proxy. parent_state_proxy: The parent state proxy, for linked mutability and context tracking. """ + from reflex.state import _substate_key + super().__init__(state_instance) - # compile is not relevant to backend logic self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) + self._self_substate_token = _substate_key( + state_instance.router.session.client_token, + self._self_substate_path, + ) self._self_actx = None self._self_mutable = False self._self_actx_lock = asyncio.Lock() @@ -127,16 +132,9 @@ async def __aenter__(self) -> StateProxy: msg = "The state is already mutable. Do not nest `async with self` blocks." raise ImmutableStateError(msg) - from reflex.state import _substate_key - await self._self_actx_lock.acquire() self._self_actx_lock_holder = current_task - self._self_actx = self._self_app.modify_state( - token=_substate_key( - self.__wrapped__.router.session.client_token, - self._self_substate_path, - ) - ) + self._self_actx = self._self_app.modify_state(token=self._self_substate_token) mutable_state = await self._self_actx.__aenter__() super().__setattr__( "__wrapped__", mutable_state.get_substate(self._self_substate_path) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index d1c728d5e81..12db8328e8f 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2005,6 +2005,7 @@ async def test_state_proxy( namespace = mock_app.event_namespace assert namespace is not None namespace.sid_to_token[router_data.session.session_id] = token + namespace.token_to_sid[token] = router_data.session.session_id if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): mock_app.state_manager.states[parent_state.router.session.client_token] = ( parent_state @@ -2214,6 +2215,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): namespace = mock_app.event_namespace assert namespace is not None namespace.sid_to_token[sid] = token + namespace.token_to_sid[token] = sid mock_app.state_manager.state = mock_app._state = BackgroundTaskState async for update in rx.app.process( mock_app,