Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 48 additions & 17 deletions reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ export const connect = async (
) => {
// Get backend URL object from the endpoint.
const endpoint = getBackendURL(EVENTURL);
const on_hydrated_queue = [];

// Create the socket.
socket.current = io(endpoint.href, {
Expand All @@ -552,7 +553,17 @@ export const connect = async (

function checkVisibility() {
if (document.visibilityState === "visible") {
if (!socket.current.connected) {
if (!socket.current) {
connect(
socket,
dispatch,
transports,
setConnectErrors,
client_storage,
navigate,
params,
);
} else if (!socket.current.connected) {
console.log("Socket is disconnected, attempting to reconnect ");
socket.current.connect();
} else {
Expand Down Expand Up @@ -593,6 +604,7 @@ export const connect = async (

// When the socket disconnects reset the event_processing flag
socket.current.on("disconnect", () => {
socket.current = null; // allow reconnect to occur automatically

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Setting socket to null on disconnect enables automatic reconnection, but ensure this doesn't cause issues with concurrent access to socket.current in other parts of the codebase

event_processing = false;
window.removeEventListener("unload", disconnectTrigger);
window.removeEventListener("beforeunload", disconnectTrigger);
Expand All @@ -603,6 +615,14 @@ export const connect = async (
socket.current.on("event", async (update) => {
for (const substate in update.delta) {
dispatch[substate](update.delta[substate]);
// handle events waiting for `is_hydrated`
if (
substate === state_name &&
update.delta[substate]?.is_hydrated_rx_state_
) {
queueEvents(on_hydrated_queue, socket, false, navigate, params);
on_hydrated_queue.length = 0;
}
}
applyClientStorageDelta(client_storage, update.delta);
event_processing = !update.final;
Expand All @@ -612,7 +632,8 @@ export const connect = async (
});
socket.current.on("reload", async (event) => {
event_processing = false;
queueEvents([...initialEvents(), event], socket, true, navigate, params);
on_hydrated_queue.push(event);
queueEvents(initialEvents(), socket, true, navigate, params);
});
socket.current.on("new_token", async (new_token) => {
token = new_token;
Expand Down Expand Up @@ -774,10 +795,32 @@ export const useEventLoop = (
}
}, [paramsR]);

const ensureSocketConnected = useCallback(async () => {
// only use websockets if state is present and backend is not disabled (reflex cloud).
if (
Object.keys(initialState).length > 1 &&
!isBackendDisabled() &&
!socket.current
) {
// Initialize the websocket connection.
await connect(
socket,
dispatch,
["websocket"],
setConnectErrors,
client_storage,
navigate,
() => params.current,
);
}
}, [socket, dispatch, setConnectErrors, client_storage, navigate, params]);

// Function to add new events to the event queue.
const addEvents = useCallback((events, args, event_actions) => {
const _events = events.filter((e) => e !== undefined && e !== null);

ensureSocketConnected();

if (!(args instanceof Array)) {
args = [args];
}
Expand Down Expand Up @@ -870,21 +913,8 @@ export const useEventLoop = (

// Handle socket connect/disconnect.
useEffect(() => {
// only use websockets if state is present and backend is not disabled (reflex cloud).
if (Object.keys(initialState).length > 1 && !isBackendDisabled()) {
// Initialize the websocket connection.
if (!socket.current) {
connect(
socket,
dispatch,
["websocket"],
setConnectErrors,
client_storage,
navigate,
() => params.current,
);
}
}
// Initialize the websocket connection.
ensureSocketConnected();

// Cleanup function.
return () => {
Expand All @@ -903,6 +933,7 @@ export const useEventLoop = (
(async () => {
// Process all outstanding events.
while (event_queue.length > 0 && !event_processing) {
await ensureSocketConnected();
await processEvent(socket.current, navigate, () => params.current);
}
})();
Expand Down
20 changes: 10 additions & 10 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
State,
StateManager,
StateUpdate,
_split_substate_key,
_substate_key,
all_base_state_classes,
code_uses_state_contexts,
Expand Down Expand Up @@ -1559,7 +1560,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
state._clean()
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
sid=state.router.session.session_id,
token=token,
)

def _process_background(
Expand Down Expand Up @@ -1599,7 +1600,7 @@ async def _coro():
# Send the update to the client.
await self.event_namespace.emit_update(
update=update,
sid=state.router.session.session_id,
token=event.token,
)

task = asyncio.create_task(
Expand Down Expand Up @@ -2061,20 +2062,19 @@ def on_disconnect(self, sid: str):
and console.error(f"Token cleanup error: {t.exception()}")
)

async def emit_update(self, update: StateUpdate, sid: str) -> None:
async def emit_update(self, update: StateUpdate, token: str) -> None:
"""Emit an update to the client.

Args:
update: The state update to send.
sid: The Socket.IO session id.
token: The client token (tab) associated with the event.
"""
if not sid:
client_token, _ = _split_substate_key(token)
sid = self.token_to_sid.get(client_token)
if sid is None:
# If the sid is None, we are not connected to a client. Prevent sending
# updates to all clients.
return
token = self.sid_to_token.get(sid)
if token is None:
console.warn(f"Attempting to send delta to disconnected websocket {sid}")
console.warn(f"Attempting to send delta to disconnected client {token!r}")
return
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
Expand Down Expand Up @@ -2165,7 +2165,7 @@ async def on_event(self, sid: str, data: Any):
# Process the events.
async for update in updates_gen:
# Emit the update from processing the event.
await self.emit_update(update=update, sid=sid)
await self.emit_update(update=update, token=event.token)

async def on_ping(self, sid: str):
"""Event for testing the API endpoint.
Expand Down
16 changes: 7 additions & 9 deletions reflex/istate/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,15 @@ def __init__(
state_instance: The state instance to proxy.
parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
"""
from reflex.state import _substate_key

super().__init__(state_instance)
# compile is not relevant to backend logic
self._self_app = prerequisites.get_and_validate_app().app
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_substate_token = _substate_key(
state_instance.router.session.client_token,
self._self_substate_path,
)
self._self_actx = None
self._self_mutable = False
self._self_actx_lock = asyncio.Lock()
Expand Down Expand Up @@ -127,16 +132,9 @@ async def __aenter__(self) -> StateProxy:
msg = "The state is already mutable. Do not nest `async with self` blocks."
raise ImmutableStateError(msg)

from reflex.state import _substate_key

await self._self_actx_lock.acquire()
self._self_actx_lock_holder = current_task
self._self_actx = self._self_app.modify_state(
token=_substate_key(
self.__wrapped__.router.session.client_token,
self._self_substate_path,
)
)
self._self_actx = self._self_app.modify_state(token=self._self_substate_token)
mutable_state = await self._self_actx.__aenter__()
super().__setattr__(
"__wrapped__", mutable_state.get_substate(self._self_substate_path)
Expand Down
56 changes: 56 additions & 0 deletions tests/integration/test_background_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ async def yield_in_async_with_self(self):
yield
self.counter += 1

@rx.event(background=True)
async def disconnect_reconnect_background(self):
async with self:
self.counter += 1
yield rx.call_script("socket.disconnect()")
await asyncio.sleep(0.5)
async with self:
self.counter += 1

class OtherState(rx.State):
@rx.event(background=True)
async def get_other_state(self):
Expand All @@ -134,6 +143,9 @@ def index() -> rx.Component:
rx.input(
id="token", value=State.router.session.client_token, is_read_only=True
),
rx.input(
id="sid", value=State.router.session.session_id, is_read_only=True
),
rx.hstack(
rx.heading(State.counter, id="counter"),
rx.text(State.counter_async_cv, size="1", id="counter-async-cv"),
Expand Down Expand Up @@ -185,6 +197,11 @@ def index() -> rx.Component:
on_click=State.yield_in_async_with_self,
id="yield-in-async-with-self",
),
rx.button(
"Disconnect / Reconnect Background",
on_click=State.disconnect_reconnect_background,
id="disconnect-reconnect-background",
),
rx.button("Reset", on_click=State.reset_counter, id="reset"),
)

Expand Down Expand Up @@ -395,3 +412,42 @@ def test_yield_in_async_with_self(

yield_in_async_with_self_button.click()
AppHarness.expect(lambda: counter.text == "2", timeout=5)


@pytest.mark.parametrize(
"button_id",
[
"disconnect-reconnect-background",
],
)
def test_disconnect_reconnect(
background_task: AppHarness,
driver: WebDriver,
token: str,
button_id: str,
):
"""Test that disconnecting and reconnecting works as expected.

Args:
background_task: harness for BackgroundTask app.
driver: WebDriver instance.
token: The token for the connected client.
button_id: The ID of the button to click.
"""
counter = driver.find_element(By.ID, "counter")
button = driver.find_element(By.ID, button_id)
increment_button = driver.find_element(By.ID, "increment")
sid_input = driver.find_element(By.ID, "sid")
sid = background_task.poll_for_value(sid_input, timeout=5)
assert sid is not None

AppHarness.expect(lambda: counter.text == "0", timeout=5)
button.click()
AppHarness.expect(lambda: counter.text == "1", timeout=5)
increment_button.click()
# should get a new sid after the reconnect
assert (
background_task.poll_for_value(sid_input, timeout=5, exp_not_equal=sid) != sid
)
# Final update should come through on the new websocket connection
AppHarness.expect(lambda: counter.text == "3", timeout=5)
2 changes: 2 additions & 0 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2005,6 +2005,7 @@ async def test_state_proxy(
namespace = mock_app.event_namespace
assert namespace is not None
namespace.sid_to_token[router_data.session.session_id] = token
namespace.token_to_sid[token] = router_data.session.session_id
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
mock_app.state_manager.states[parent_state.router.session.client_token] = (
parent_state
Expand Down Expand Up @@ -2214,6 +2215,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
namespace = mock_app.event_namespace
assert namespace is not None
namespace.sid_to_token[sid] = token
namespace.token_to_sid[token] = sid
mock_app.state_manager.state = mock_app._state = BackgroundTaskState
async for update in rx.app.process(
mock_app,
Expand Down
Loading