diff --git a/packages/reflex-base/news/6644.bugfix.md b/packages/reflex-base/news/6644.bugfix.md new file mode 100644 index 00000000000..63ba91423ba --- /dev/null +++ b/packages/reflex-base/news/6644.bugfix.md @@ -0,0 +1 @@ +Frontend-only events (e.g. `rx.toast`, `rx.redirect`) returned from a middleware's `preprocess` are now emitted to the client instead of being enqueued on the backend event queue, where they had no registered handler and raised `KeyError`. The frontend/backend split that already applied to handler-yielded events is now shared via a `_route_events` helper and applied to middleware-preprocess updates too. diff --git a/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py b/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py index c8ad3e5d589..f361b2f3969 100644 --- a/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py +++ b/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py @@ -21,7 +21,7 @@ from reflex_base.utils.format import format_event_handler if TYPE_CHECKING: - from reflex.event import EventHandler, EventSpec + from reflex.event import Event, EventHandler, EventSpec from reflex.state import BaseState @@ -163,6 +163,29 @@ def _transform_event_payload( return transformed +async def _route_events(ctx: EventContext, events: Sequence[Event]) -> None: + """Emit frontend events to the client and queue backend events. + + Events whose name starts with ``_`` are frontend-only specs (e.g. + ``_redirect``, ``_call_function``) with no registered backend handler. + + Args: + ctx: The event context to emit/enqueue through. + events: The events to route. + """ + frontend_events: list[Event] = [] + backend_events: list[Event] = [] + for ev in events: + if ev.name.startswith("_"): + frontend_events.append(ev) + else: + backend_events.append(ev) + if frontend_events: + await ctx.emit_event(*frontend_events) + if backend_events: + await ctx.enqueue(*backend_events) + + async def chain_updates( events: EventSpec | list[EventSpec] | None, handler_name: str, @@ -195,11 +218,7 @@ async def chain_updates( if fixed_events := Event.from_event_type( _check_valid_yield(events, handler_name=handler_name), ): - # Frontend events. - if frontend_events := [e for e in fixed_events if e.name.startswith("_")]: - await ctx.emit_event(*frontend_events) - # Backend events. - await ctx.enqueue(*(e for e in fixed_events if not e.name.startswith("_"))) + await _route_events(ctx, fixed_events) async def process_event( @@ -348,7 +367,7 @@ async def _execute_event( if update.delta: await ctx.emit_delta(update.delta) if update.events: - await ctx.enqueue(*update.events) + await _route_events(ctx, update.events) return # Get the event's substate. diff --git a/tests/units/reflex_base/event/processor/test_base_state_processor.py b/tests/units/reflex_base/event/processor/test_base_state_processor.py index e414369a40e..90fc86bb6fd 100644 --- a/tests/units/reflex_base/event/processor/test_base_state_processor.py +++ b/tests/units/reflex_base/event/processor/test_base_state_processor.py @@ -12,11 +12,13 @@ from reflex_base.event.processor import BaseStateEventProcessor from reflex_base.registry import RegistrationContext +import reflex as rx from reflex import event from reflex.app import App from reflex.event import Event from reflex.istate.manager.memory import StateManagerMemory -from reflex.state import OnLoadInternalState, State +from reflex.middleware.middleware import Middleware +from reflex.state import OnLoadInternalState, State, StateUpdate @pytest.fixture @@ -103,9 +105,30 @@ async def emit_event_impl(token: str, *events: Event) -> None: # noqa: RUF029 await state_manager.close() -async def test_rehydrate_sets_is_hydrated_on_fresh_token( +@pytest.fixture +def wired_app( app_module_mock, real_base_state_processor: BaseStateEventProcessor, +) -> App: + """An App registered as the app module's app and sharing the processor's state manager. + + Args: + app_module_mock: The mock app module fixture. + real_base_state_processor: The unmocked BaseStateEventProcessor. + + Returns: + The wired App instance. + """ + OnLoadInternalState._app_ref = None + app = app_module_mock.app = App() + assert real_base_state_processor._root_context is not None + app._state_manager = real_base_state_processor._root_context.state_manager + return app + + +async def test_rehydrate_sets_is_hydrated_on_fresh_token( + wired_app: App, + real_base_state_processor: BaseStateEventProcessor, emitted_deltas: list[tuple[str, Mapping[str, Mapping[str, Any]]]], token: str, ): @@ -117,7 +140,7 @@ async def test_rehydrate_sets_is_hydrated_on_fresh_token( hydrate sets is_hydrated=True directly. Args: - app_module_mock: The mock app module fixture. + wired_app: The App wired to the processor's state manager. real_base_state_processor: The unmocked BaseStateEventProcessor. emitted_deltas: List to capture emitted deltas. token: The client token. @@ -128,11 +151,6 @@ class MyState(State): def noop(self): pass - OnLoadInternalState._app_ref = None - app = app_module_mock.app = App() - assert real_base_state_processor._root_context is not None - app._state_manager = real_base_state_processor._root_context.state_manager - async with real_base_state_processor as processor: await processor.enqueue( token, @@ -150,3 +168,50 @@ def noop(self): assert len(hydrated_deltas) >= 1, ( f"Expected at least one delta with is_hydrated=True, got deltas: {emitted_deltas}" ) + + +async def test_preprocess_update_routes_frontend_events_to_client( + wired_app: App, + real_base_state_processor: BaseStateEventProcessor, + emitted_events: list[tuple[str, tuple[Event, ...]]], + token: str, +): + """Frontend-only events in a middleware preprocess update reach the client. + + Regression: a blocking middleware (e.g. an auth gate) returns a + ``StateUpdate`` whose events are frontend specs like ``rx.toast`` + (``_call_function``) or ``rx.redirect`` (``_redirect``). Those have no + registered backend handler, so they must be emitted to the client instead + of enqueued on the backend queue (where they raise ``KeyError``). + + Args: + wired_app: The App wired to the processor's state manager. + real_base_state_processor: The unmocked BaseStateEventProcessor. + emitted_events: List to capture events emitted to the client. + token: The client token. + """ + + class GatedState(State): + @event + def do_thing(self): + pass + + class BlockingMiddleware(Middleware): + async def preprocess(self, app, state, event) -> StateUpdate: + return StateUpdate( + events=Event.from_event_type([ + rx.toast("Action not allowed"), + rx.redirect("/login"), + ]) + ) + + wired_app.add_middleware(BlockingMiddleware()) + real_base_state_processor.middleware = wired_app + + async with real_base_state_processor as p: + await p.enqueue(token, Event.from_event_type(GatedState.do_thing())[0]) + await p.join(1) + + client_event_names = {e.name for _, events in emitted_events for e in events} + assert "_call_function" in client_event_names + assert "_redirect" in client_event_names