diff --git a/reflex/app.py b/reflex/app.py index 9c0d2af8db0..b22d3bdb4ca 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1399,7 +1399,19 @@ async def modify_state( token = BaseStateToken.from_legacy_token(token, root_state=self._state) # Ensure Reflex contexts are available (e.g. when called from an API route). - with self.set_contexts(): + with self.set_contexts(), contextlib.ExitStack() as rebind: + # Rebind the EventContext to the token being modified so consumers + # running inside (delta resolution, computed vars) observe this token + # rather than the event context the caller inherited -- e.g. the + # shared-state fan-out runs in a task that copied the triggering + # event's context for a different client. No-op without an EventContext. + try: + forked_context = EventContext.get().fork(token=token.ident) + except LookupError: + pass + else: + reset_token = EventContext.set(forked_context) + rebind.callback(EventContext.reset, reset_token) # Get exclusive access to the state. async with self.state_manager.modify_state_with_links( token, previous_dirty_vars=previous_dirty_vars, **context diff --git a/reflex/state.py b/reflex/state.py index e7c96018654..634eeef8558 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -266,14 +266,25 @@ def get_var_for_field(cls: type[BaseState], name: str, f: Field) -> Var: ) +# Sentinel a delta-value coroutine may resolve to in order to suppress its key: +# when ``_resolve_delta`` awaits a coroutine value and gets this object back, it +# drops the key from the delta instead of writing it. Lets a value whose +# inclusion can only be decided asynchronously be deferred into the delta as a +# coroutine and then omitted post-hoc. Compared by identity (the object itself is +# the contract); never serialized into a delta sent to the client. +_DROP_FROM_DELTA: Final = object() + + async def _resolve_delta(delta: Delta) -> Delta: - """Await all coroutines in the delta. + """Await all coroutines in the delta, dropping keys that resolve to the drop sentinel. Args: delta: The delta to process. Returns: - The same delta dict with all coroutines resolved to their return value. + The same delta dict with all coroutines resolved to their return value, + and any key whose coroutine resolved to ``_DROP_FROM_DELTA`` removed + (along with any state subdict left empty by such removals). """ tasks = {} for state_name, state_delta in delta.items(): @@ -284,7 +295,13 @@ async def _resolve_delta(delta: Delta) -> Delta: name=f"reflex_resolve_delta|{state_name}|{var_name}|{time.time()}", ) for (state_name, var_name), task in tasks.items(): - delta[state_name][var_name] = await task + resolved = await task + if resolved is _DROP_FROM_DELTA: + del delta[state_name][var_name] + if not delta[state_name]: + del delta[state_name] + else: + delta[state_name][var_name] = resolved return delta diff --git a/tests/units/test_app.py b/tests/units/test_app.py index c59bb62875a..955582ae76c 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -3449,6 +3449,37 @@ def _test(): isolated_context.run(_test) +async def test_modify_state_rebinds_event_context_to_token( + app_with_processor: App, +): + """modify_state(token) rebinds EventContext.token to the modified token. + + Out-of-band ``modify_state`` (e.g. the shared-state fan-out that recomputes + another client's delta) runs in a task that copied the triggering event's + EventContext. Without rebinding, code inside (``get_delta``, computed vars) + would read the *actor's* token, not the token whose state is being modified. + """ + app_with_processor._state_manager = StateManagerMemory() + app_with_processor._event_namespace = AsyncMock() + assert app_with_processor._event_processor is not None + root_context = app_with_processor._event_processor._root_context + assert root_context is not None + + # Simulate the actor (token-A) event context the way the processor sets it + # (via ``set``, which the fan-out task then inherits by copying the context). + actor_token = EventContext.set(root_context.fork(token="token-A")) + try: + assert EventContext.get().token == "token-A" + async with app_with_processor.modify_state( + BaseStateToken(ident="token-B", cls=EmptyState) + ): + assert EventContext.get().token == "token-B" + # The actor's context is restored after modify_state exits. + assert EventContext.get().token == "token-A" + finally: + EventContext.reset(actor_token) + + def test_set_contexts_no_event_processor(isolated_context: contextvars.Context): """When event processor is None, EventContext should not be touched.""" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 83aa823598a..ce67474ade2 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -4985,3 +4985,39 @@ def child_view(self) -> int: parent_deps = ParentDescState._var_dependencies.get("_shared", set()) assert (ChildDescState.get_full_name(), "child_view") in child_deps assert (ParentDescState.get_full_name(), "parent_view") in parent_deps + + +async def test_resolve_delta_awaits_coroutines_and_keeps_plain_values(): + """_resolve_delta awaits coroutine values and leaves plain values untouched.""" + from reflex.state import _resolve_delta + + async def _coro(value): # noqa: RUF029 - a trivial coroutine value for the delta + return value + + delta = {"s1": {"a": _coro(1), "b": 2}} + resolved = await _resolve_delta(delta) + assert resolved == {"s1": {"a": 1, "b": 2}} + + +async def test_resolve_delta_drops_keys_resolving_to_sentinel(): + """A coroutine resolving to _DROP_FROM_DELTA removes its key from the delta.""" + from reflex.state import _DROP_FROM_DELTA, _resolve_delta + + async def _coro(value): # noqa: RUF029 - a trivial coroutine value for the delta + return value + + delta = {"s1": {"gone": _coro(_DROP_FROM_DELTA), "stay": _coro("kept"), "plain": 3}} + resolved = await _resolve_delta(delta) + assert resolved == {"s1": {"stay": "kept", "plain": 3}} + + +async def test_resolve_delta_pops_subdict_emptied_by_drops(): + """A state subdict left empty after dropping all its keys is removed entirely.""" + from reflex.state import _DROP_FROM_DELTA, _resolve_delta + + async def _coro(value): # noqa: RUF029 - a trivial coroutine value for the delta + return value + + delta = {"s1": {"only": _coro(_DROP_FROM_DELTA)}, "s2": {"keep": 1}} + resolved = await _resolve_delta(delta) + assert resolved == {"s2": {"keep": 1}}