Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
State,
StateManager,
StateUpdate,
_split_substate_key,
_substate_key,
all_base_state_classes,
code_uses_state_contexts,
Expand Down Expand Up @@ -1559,7 +1560,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
state._clean()
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
sid=state.router.session.session_id,
token=token,
)

def _process_background(
Expand Down Expand Up @@ -1599,7 +1600,7 @@ async def _coro():
# Send the update to the client.
await self.event_namespace.emit_update(
update=update,
sid=state.router.session.session_id,
token=event.token,
)

task = asyncio.create_task(
Expand Down Expand Up @@ -2061,20 +2062,19 @@ def on_disconnect(self, sid: str):
and console.error(f"Token cleanup error: {t.exception()}")
)

async def emit_update(self, update: StateUpdate, sid: str) -> None:
async def emit_update(self, update: StateUpdate, token: str) -> None:
"""Emit an update to the client.

Args:
update: The state update to send.
sid: The Socket.IO session id.
token: The client token (tab) associated with the event.
"""
if not sid:
client_token, _ = _split_substate_key(token)
sid = self.token_to_sid.get(client_token)
if sid is None:
# If the sid is None, we are not connected to a client. Prevent sending
# updates to all clients.
return
token = self.sid_to_token.get(sid)
if token is None:
console.warn(f"Attempting to send delta to disconnected websocket {sid}")
console.warn(f"Attempting to send delta to disconnected client {token!r}")
return
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
Expand Down Expand Up @@ -2165,7 +2165,7 @@ async def on_event(self, sid: str, data: Any):
# Process the events.
async for update in updates_gen:
# Emit the update from processing the event.
await self.emit_update(update=update, sid=sid)
await self.emit_update(update=update, token=event.token)

async def on_ping(self, sid: str):
"""Event for testing the API endpoint.
Expand Down
16 changes: 7 additions & 9 deletions reflex/istate/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,15 @@ def __init__(
state_instance: The state instance to proxy.
parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
"""
from reflex.state import _substate_key

super().__init__(state_instance)
# compile is not relevant to backend logic
self._self_app = prerequisites.get_and_validate_app().app
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_substate_token = _substate_key(
state_instance.router.session.client_token,
self._self_substate_path,
)
self._self_actx = None
self._self_mutable = False
self._self_actx_lock = asyncio.Lock()
Expand Down Expand Up @@ -127,16 +132,9 @@ async def __aenter__(self) -> StateProxy:
msg = "The state is already mutable. Do not nest `async with self` blocks."
raise ImmutableStateError(msg)

from reflex.state import _substate_key

await self._self_actx_lock.acquire()
self._self_actx_lock_holder = current_task
self._self_actx = self._self_app.modify_state(
token=_substate_key(
self.__wrapped__.router.session.client_token,
self._self_substate_path,
)
)
self._self_actx = self._self_app.modify_state(token=self._self_substate_token)
mutable_state = await self._self_actx.__aenter__()
super().__setattr__(
"__wrapped__", mutable_state.get_substate(self._self_substate_path)
Expand Down
2 changes: 2 additions & 0 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2005,6 +2005,7 @@ async def test_state_proxy(
namespace = mock_app.event_namespace
assert namespace is not None
namespace.sid_to_token[router_data.session.session_id] = token
namespace.token_to_sid[token] = router_data.session.session_id
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
mock_app.state_manager.states[parent_state.router.session.client_token] = (
parent_state
Expand Down Expand Up @@ -2214,6 +2215,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
namespace = mock_app.event_namespace
assert namespace is not None
namespace.sid_to_token[sid] = token
namespace.token_to_sid[token] = sid
mock_app.state_manager.state = mock_app._state = BackgroundTaskState
async for update in rx.app.process(
mock_app,
Expand Down
Loading