diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 5d3f75f761542..0bcab5a98eecd 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -18,15 +18,18 @@ import dataclasses import json import logging +import queue import threading from ssl import CERT_NONE from threading import Thread -from time import sleep from websocket import WebSocketApp from selenium.common import WebDriverException +# Sentinel pushed onto the event queue to tell the dispatcher thread to stop. +_DISPATCHER_SHUTDOWN = object() + def _snake_to_camel(name: str) -> str: """Convert snake_case field name to camelCase for BiDi protocol.""" @@ -89,17 +92,41 @@ def __init__(self, url, timeout, interval): self.url = url self.response_wait_timeout = timeout + # Retained for backwards compatibility; the connection no longer + # busy-waits, so the interval no longer influences response latency. self.response_wait_interval = interval - self.callbacks = {} self.session_id = None + self._ws = None + self._ws_thread = None + self._id = 0 self._id_lock = threading.Lock() + + # Command responses keyed by id, alongside a per-request ``Event`` the + # receive thread sets when the matching response arrives. Both are + # guarded by ``_response_lock`` so caller threads and the receive thread + # share them safely instead of relying on the GIL. self._messages = {} - self._started = False + self._response_events = {} + self._response_lock = threading.Lock() + + # Event callbacks, guarded by ``_callbacks_lock``. Incoming events are + # handed to a single long-lived dispatcher thread: this preserves event + # ordering, bounds thread usage to one regardless of event volume (no + # thread-per-event exhaustion), and lets us surface callback exceptions + # instead of losing them on an orphaned thread. + self.callbacks = {} + self._callbacks_lock = threading.Lock() + self._dispatch_queue = queue.Queue() + self._dispatcher_thread = Thread(target=self._dispatch_events, daemon=True, name="BiDi-event-dispatcher") + self._dispatcher_thread.start() + + self._open_event = threading.Event() self._start_ws() - self._wait_until(lambda: self._started) + if not self._open_event.wait(self.response_wait_timeout): + raise WebDriverException("Timed out waiting for the BiDi websocket connection to open") def close(self): # Close the socket first so ``run_forever`` returns; only then join the @@ -112,7 +139,25 @@ def close(self): logger.debug(f"Error while closing websocket connection: {e}") if self._ws_thread is not None: self._ws_thread.join(timeout=self.response_wait_timeout) - self._started = False + + # Stop the dispatcher thread now the receive thread is done producing events. + self._dispatch_queue.put(_DISPATCHER_SHUTDOWN) + if self._dispatcher_thread is not None: + self._dispatcher_thread.join(timeout=self.response_wait_timeout) + + # Drop registered handlers so nothing fires after close, and wake any + # callers still blocked on a response so they fail fast rather than + # waiting out the full timeout. + with self._callbacks_lock: + self.callbacks.clear() + with self._response_lock: + self._messages.clear() + pending = list(self._response_events.values()) + self._response_events.clear() + for response_event in pending: + response_event.set() + + self._open_event.clear() self._ws = None def execute(self, command): @@ -126,12 +171,21 @@ def execute(self, command): data = json.dumps(payload, cls=_BiDiEncoder) logger.debug(f"-> {data}"[: self._max_log_message_size]) + + # Register the waiter before sending so a fast response can't arrive + # before we are ready to receive it. + response_event = threading.Event() + with self._response_lock: + self._response_events[current_id] = response_event + self._ws.send(data) - self._wait_until(lambda: current_id in self._messages) - if current_id not in self._messages: + response_event.wait(self.response_wait_timeout) + with self._response_lock: + self._response_events.pop(current_id, None) + response = self._messages.pop(current_id, None) + if response is None: raise WebDriverException(f"Timed out waiting for response to BiDi command {current_id}") - response = self._messages.pop(current_id) if "error" in response: error = response["error"] @@ -146,21 +200,20 @@ def execute(self, command): def add_callback(self, event, callback): event_name = event.event_class - if event_name not in self.callbacks: - self.callbacks[event_name] = [] def _callback(params): callback(event.from_json(params)) - self.callbacks[event_name].append(_callback) + with self._callbacks_lock: + self.callbacks.setdefault(event_name, []).append(_callback) return id(_callback) on = add_callback def remove_callback(self, event, callback_id): event_name = event.event_class - if event_name in self.callbacks: - for callback in self.callbacks[event_name]: + with self._callbacks_lock: + for callback in self.callbacks.get(event_name, []): if id(callback) == callback_id: self.callbacks[event_name].remove(callback) return @@ -177,7 +230,7 @@ def _deserialize_result(self, result, command): def _start_ws(self): def on_open(ws): - self._started = True + self._open_event.set() def on_message(ws, message): self._process_message(message) @@ -201,21 +254,31 @@ def _process_message(self, message): logger.debug(f"<- {message}"[: self._max_log_message_size]) if "id" in message: - self._messages[message["id"]] = message + message_id = message["id"] + with self._response_lock: + self._messages[message_id] = message + response_event = self._response_events.get(message_id) + if response_event is not None: + response_event.set() if "method" in message: - params = message["params"] - for callback in self.callbacks.get(message["method"], []): - Thread(target=callback, args=(params,), daemon=True).start() - - def _wait_until(self, condition): - timeout = self.response_wait_timeout - interval = self.response_wait_interval - - while timeout > 0: - result = condition() - if result: - return result - else: - timeout -= interval - sleep(interval) + # Hand events to the single dispatcher thread instead of spawning a + # thread per event; this keeps ordering and avoids the receive thread + # being blocked by a slow callback. + self._dispatch_queue.put((message["method"], message["params"])) + + def _dispatch_events(self): + while True: + item = self._dispatch_queue.get() + if item is _DISPATCHER_SHUTDOWN: + break + method, params = item + with self._callbacks_lock: + callbacks = list(self.callbacks.get(method, [])) + for callback in callbacks: + try: + callback(params) + except Exception: + # Never let one handler's failure kill the dispatcher or + # silently vanish: log it and keep delivering other events. + logger.error(f"Unhandled exception in BiDi event callback for '{method}'", exc_info=True) diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index 86e3d11af0341..ba44725073e5f 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -618,9 +618,10 @@ def on_context_created(info): # Create a new context to trigger the event context_id = driver.browsing_context.create(type=WindowTypes.TAB) - # Verify the event was received (might be > 1 since default context is also included) - assert len(events_received) >= 1 - assert events_received[0].context == context_id or events_received[1].context == context_id + # context_created is a global event delivered asynchronously, and other contexts may also be + # reported, so wait for the event for the context we created rather than indexing positionally. + WebDriverWait(driver, 5).until(lambda d: any(e.context == context_id for e in events_received)) + assert any(e.context == context_id for e in events_received) driver.browsing_context.close(context_id) driver.browsing_context.remove_event_handler("context_created", callback_id) @@ -640,8 +641,8 @@ def on_context_destroyed(info): context_id = driver.browsing_context.create(type=WindowTypes.TAB) driver.browsing_context.close(context_id) - assert len(events_received) == 1 - assert events_received[0].context == context_id + WebDriverWait(driver, 5).until(lambda d: any(e.context == context_id for e in events_received)) + assert any(e.context == context_id for e in events_received) driver.browsing_context.remove_event_handler("context_destroyed", callback_id) @@ -661,6 +662,7 @@ def on_navigation_committed(info): url = pages.url("simpleTest.html") driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1) assert len(events_received) >= 1 assert any(url in event.url for event in events_received) @@ -682,6 +684,7 @@ def on_dom_content_loaded(info): url = pages.url("simpleTest.html") driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -702,6 +705,7 @@ def on_load(info): url = pages.url("simpleTest.html") driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -722,6 +726,7 @@ def on_navigation_started(info): url = pages.url("simpleTest.html") driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -747,6 +752,7 @@ def on_fragment_navigated(info): fragment_url = url + "#link" driver.browsing_context.navigate(context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE) + WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1) assert len(events_received) == 1 assert any("link" in event.url for event in events_received) @@ -772,6 +778,7 @@ def on_navigation_failed(info): # Expect an exception due to navigation failure pass + WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1) assert len(events_received) == 1 assert events_received[0].url == "http://invalid-domain-that-does-not-exist.test/" assert events_received[0].context == context_id @@ -822,6 +829,7 @@ def on_user_prompt_closed(info): context=driver.current_window_handle, accept=True, user_text="test input" ) + WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1) assert len(events_received) == 1 assert events_received[0].accepted is True assert events_received[0].user_text == "test input" @@ -940,6 +948,7 @@ def on_context_created(info): # Create another context (should trigger event) new_context_id = driver.browsing_context.create(type=WindowTypes.TAB) + WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1) assert len(events_received) >= 1 driver.browsing_context.close(context_id) @@ -959,16 +968,18 @@ def on_context_created(info): # Create a context to trigger the event context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) - initial_events = len(events_received) + # Wait until the first context's event is observed (delivered asynchronously) + WebDriverWait(driver, 5).until(lambda d: any(e.context == context_id_1 for e in events_received)) # Remove the event handler driver.browsing_context.remove_event_handler("context_created", callback_id) - # Create another context (should not trigger event after removal) + # Create another context. remove_event_handler unsubscribes synchronously, so with the handler + # gone this context must never be observed. Asserting on this specific context avoids relying on + # event counts, which are unreliable because context_created may report more than one context. context_id_2 = driver.browsing_context.create(type=WindowTypes.TAB) - # Verify no new events were received after removal - assert len(events_received) == initial_events + assert not any(e.context == context_id_2 for e in events_received) driver.browsing_context.close(context_id_1) driver.browsing_context.close(context_id_2) @@ -992,10 +1003,13 @@ def on_context_created_2(info): # Create a context to trigger both handlers context_id = driver.browsing_context.create(type=WindowTypes.TAB) - # Verify both handlers received the event - assert len(events_received_1) >= 1 - assert len(events_received_2) >= 1 - # Check any of the events has the required context ID + # Verify both handlers received the created context's event (delivered asynchronously) + WebDriverWait(driver, 5).until( + lambda d: ( + any(e.context == context_id for e in events_received_1) + and any(e.context == context_id for e in events_received_2) + ) + ) assert any(event.context == context_id for event in events_received_1) assert any(event.context == context_id for event in events_received_2) @@ -1023,22 +1037,22 @@ def on_context_created_2(info): context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) # Verify both handlers received the event + WebDriverWait(driver, 5).until(lambda d: len(events_received_1) >= 1 and len(events_received_2) >= 1) assert len(events_received_1) >= 1 assert len(events_received_2) >= 1 - # store the initial event counts - initial_count_1 = len(events_received_1) - initial_count_2 = len(events_received_2) - # Remove only the first handler driver.browsing_context.remove_event_handler("context_created", callback_id_1) # Create another context context_id_2 = driver.browsing_context.create(type=WindowTypes.TAB) - # Verify only the second handler received the new event - assert len(events_received_1) == initial_count_1 # No new events - assert len(events_received_2) == initial_count_2 + 1 # 1 new event + # Only the second (still-registered) handler should observe the new context. Waiting for it to + # see that context's event also guarantees the dispatcher has caught up before we assert the + # removed handler saw nothing for it. + WebDriverWait(driver, 5).until(lambda d: any(e.context == context_id_2 for e in events_received_2)) + assert not any(e.context == context_id_2 for e in events_received_1) # removed handler: no new event + assert any(e.context == context_id_2 for e in events_received_2) # remaining handler: new event driver.browsing_context.close(context_id_1) driver.browsing_context.close(context_id_2) @@ -1145,6 +1159,9 @@ def test_event_callback_data_consistency(driver): for ctx in test_contexts: driver.browsing_context.close(ctx) + # 3 contexts created x 5 registered handlers; events are delivered asynchronously, so wait + # for all of them before asserting on the collected data. + WebDriverWait(driver, 10).until(lambda d: len(helper.events_received) >= 15) with helper.data_lock: assert not helper.consistency_errors, "Consistency errors: " + str(helper.consistency_errors) assert len(helper.events_received) > 0, "No events received" @@ -1179,15 +1196,17 @@ def test_no_event_after_handler_removal(driver): context = driver.browsing_context.create(type=WindowTypes.TAB) driver.browsing_context.close(context) - events_before = len(helper.events_received) + # Wait until the created context's event has been delivered to the handlers (async delivery) + WebDriverWait(driver, 10).until(lambda d: any(e.context == context for e in helper.events_received)) for i, callback_id in enumerate(helper.callback_ids): helper.remove_handler(callback_id, f"rem-{i}") + # With every handler removed (and unsubscribed), a newly created context must not be observed. + # Asserting on this specific context avoids relying on event counts, which are unreliable + # because context_created may report more than one context per creation. post_context = driver.browsing_context.create(type=WindowTypes.TAB) driver.browsing_context.close(post_context) with helper.data_lock: - new_events = len(helper.events_received) - events_before - - assert new_events == 0, f"Expected 0 new events after removal, got {new_events}" + assert not any(e.context == post_context for e in helper.events_received), "Handlers fired after removal" diff --git a/py/test/selenium/webdriver/common/bidi_input_tests.py b/py/test/selenium/webdriver/common/bidi_input_tests.py index 97ab7f0848870..f1bfb0cbe1f6f 100644 --- a/py/test/selenium/webdriver/common/bidi_input_tests.py +++ b/py/test/selenium/webdriver/common/bidi_input_tests.py @@ -209,7 +209,9 @@ def test_wheel_scroll(driver, pages): driver.input.perform_actions(driver.current_window_handle, [wheel_actions]) - # Verify the page scrolled by checking scroll position + # Verify the page scrolled by checking scroll position. The scroll is applied asynchronously + # by the browser, so wait for it to settle rather than reading immediately. + WebDriverWait(driver, 5).until(lambda d: d.execute_script("return window.pageYOffset;") == 100) scroll_y = driver.execute_script("return window.pageYOffset;") assert scroll_y == 100 @@ -601,6 +603,7 @@ def test_wheel_scroll_negative_delta(driver, pages): driver.input.perform_actions(driver.current_window_handle, [wheel_actions_down]) + WebDriverWait(driver, 5).until(lambda d: d.execute_script("return window.pageYOffset;") > 0) scroll_y_down = driver.execute_script("return window.pageYOffset;") assert scroll_y_down > 0 @@ -612,6 +615,7 @@ def test_wheel_scroll_negative_delta(driver, pages): driver.input.perform_actions(driver.current_window_handle, [wheel_actions_up]) + WebDriverWait(driver, 5).until(lambda d: d.execute_script("return window.pageYOffset;") < scroll_y_down) scroll_y_up = driver.execute_script("return window.pageYOffset;") assert scroll_y_up < scroll_y_down @@ -636,6 +640,7 @@ def test_wheel_scroll_with_duration(driver, pages): driver.input.perform_actions(driver.current_window_handle, [wheel_actions]) + WebDriverWait(driver, 5).until(lambda d: d.execute_script("return window.pageYOffset;") == 100) scroll_y = driver.execute_script("return window.pageYOffset;") assert scroll_y == 100 @@ -849,6 +854,7 @@ def test_combined_keyboard_and_wheel_actions(driver, pages): driver.input.perform_actions(driver.current_window_handle, [key_actions, wheel_actions]) + WebDriverWait(driver, 5).until(lambda d: d.execute_script("return window.pageYOffset;") == 100) scroll_y = driver.execute_script("return window.pageYOffset;") assert scroll_y == 100 diff --git a/py/test/unit/selenium/webdriver/common/websocket_connection_tests.py b/py/test/unit/selenium/webdriver/common/websocket_connection_tests.py new file mode 100644 index 0000000000000..a4168feacb540 --- /dev/null +++ b/py/test/unit/selenium/webdriver/common/websocket_connection_tests.py @@ -0,0 +1,252 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Transport-level unit tests for :class:`WebSocketConnection`. + +These exercise the concurrency contract of the BiDi transport (per-request +response routing, locked shared state, single-threaded event dispatch, and +clean teardown) without a real browser. Only the network boundary +(``_start_ws``) is replaced with an in-memory fake; all transport logic runs +for real. +""" + +import json +import logging +import threading +import time + +import pytest + +from selenium.common import WebDriverException +from selenium.webdriver.remote.websocket_connection import WebSocketConnection + + +class FakeWebSocketApp: + """In-memory stand-in for ``websocket.WebSocketApp``. + + Records every payload sent so a test can learn which command ids were + written and feed matching responses back through ``_process_message``. + """ + + def __init__(self): + self.sent = [] + self._lock = threading.Lock() + + def send(self, data): + with self._lock: + self.sent.append(json.loads(data)) + + def sent_ids(self): + with self._lock: + return [payload["id"] for payload in self.sent] + + def close(self): + pass + + +class StubConnection(WebSocketConnection): + """``WebSocketConnection`` wired to an in-memory socket. + + Overriding only ``_start_ws`` replaces the network boundary; locking, + response routing, and event dispatch are the real implementations. + """ + + def _start_ws(self): + self._ws = FakeWebSocketApp() + self._ws_thread = None + self._open_event.set() + + +class FakeEvent: + """Minimal event descriptor matching what ``add_callback`` expects.""" + + def __init__(self, name): + self.event_class = name + + def from_json(self, params): + return params + + +def _make_command(method): + """Build a BiDi-style command generator that echoes its result.""" + + def command(): + result = yield {"method": method, "params": {}} + return result + + return command() + + +def _feed_response(conn, message_id, result): + conn._process_message(json.dumps({"id": message_id, "result": result})) + + +def _feed_event(conn, method, params=None): + conn._process_message(json.dumps({"method": method, "params": params or {}})) + + +def _wait_for(predicate, timeout=5.0): + deadline = time.time() + timeout + while time.time() < deadline: + if predicate(): + return True + time.sleep(0.01) + return False + + +@pytest.fixture +def conn(): + connection = StubConnection("ws://localhost:9222", 5, 0.1) + yield connection + connection.close() + + +def test_execute_returns_matching_response(conn): + sent_id = [] + + def respond(): + assert _wait_for(lambda: conn._ws.sent_ids()) + message_id = conn._ws.sent_ids()[0] + sent_id.append(message_id) + _feed_response(conn, message_id, {"value": 42}) + + responder = threading.Thread(target=respond) + responder.start() + result = conn.execute(_make_command("session.status")) + responder.join() + + assert result == {"value": 42} + + +def test_concurrent_execute_routes_each_response_to_its_caller(conn): + count = 25 + results = {} + barrier = threading.Barrier(count) + + def worker(index): + barrier.wait() # maximise overlap on the send path + results[index] = conn.execute(_make_command(f"cmd-{index}")) + + workers = [threading.Thread(target=worker, args=(i,)) for i in range(count)] + for worker_thread in workers: + worker_thread.start() + + # Wait for every command to be written, then answer them in reverse order + # so a correct routing implementation cannot rely on FIFO ordering. + assert _wait_for(lambda: len(conn._ws.sent_ids()) == count) + for payload in reversed(list(conn._ws.sent)): + _feed_response(conn, payload["id"], {"echo": payload["method"]}) + + for worker_thread in workers: + worker_thread.join(timeout=5) + + assert len(results) == count + for index in range(count): + assert results[index] == {"echo": f"cmd-{index}"} + + +def test_execute_times_out_when_no_response(): + connection = StubConnection("ws://localhost:9222", 0.2, 0.1) + try: + with pytest.raises(WebDriverException, match="Timed out waiting for response"): + connection.execute(_make_command("session.status")) + finally: + connection.close() + + +def test_events_dispatch_on_single_thread(conn): + seen_threads = [] + done = threading.Event() + event = FakeEvent("log.entryAdded") + + def callback(_params): + seen_threads.append(threading.current_thread()) + if len(seen_threads) == 5: + done.set() + + conn.add_callback(event, callback) + for _ in range(5): + _feed_event(conn, "log.entryAdded") + + assert done.wait(5) + assert len(set(seen_threads)) == 1 + assert seen_threads[0] is conn._dispatcher_thread + + +def test_callback_exception_is_logged_and_dispatch_continues(conn, caplog): + delivered = [] + second_ran = threading.Event() + event = FakeEvent("log.entryAdded") + + def boom(_params): + raise ValueError("handler blew up") + + def good(_params): + delivered.append(_params) + second_ran.set() + + conn.add_callback(event, boom) + conn.add_callback(event, good) + + with caplog.at_level(logging.ERROR): + _feed_event(conn, "log.entryAdded", {"n": 1}) + assert second_ran.wait(5) + + # The failing handler must not stop the next handler in the same event... + assert delivered == [{"n": 1}] + # ...nor kill the dispatcher for subsequent events. + second_ran.clear() + _feed_event(conn, "log.entryAdded", {"n": 2}) + assert second_ran.wait(5) + assert delivered == [{"n": 1}, {"n": 2}] + + assert any(record.levelno == logging.ERROR for record in caplog.records) + assert "log.entryAdded" in caplog.text + + +def test_close_clears_callbacks_and_stops_dispatcher(): + connection = StubConnection("ws://localhost:9222", 5, 0.1) + connection.add_callback(FakeEvent("log.entryAdded"), lambda _p: None) + assert connection.callbacks + + connection.close() + + assert connection.callbacks == {} + assert _wait_for(lambda: not connection._dispatcher_thread.is_alive()) + + +def test_close_wakes_pending_callers(): + connection = StubConnection("ws://localhost:9222", 30, 0.1) + error = [] + + def worker(): + try: + connection.execute(_make_command("session.status")) + except WebDriverException as exc: + error.append(exc) + + caller = threading.Thread(target=worker) + caller.start() + assert _wait_for(lambda: connection._ws.sent_ids()) + + connection.close() + caller.join(timeout=5) + + # The blocked caller is released by close() rather than waiting out the + # 30s timeout, and surfaces a WebDriverException. + assert not caller.is_alive() + assert len(error) == 1