Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 93 additions & 30 deletions py/selenium/webdriver/remote/websocket_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@
import dataclasses
import json
import logging
import queue
import threading
from ssl import CERT_NONE
from threading import Thread
from time import sleep

from websocket import WebSocketApp

from selenium.common import WebDriverException

# Sentinel pushed onto the event queue to tell the dispatcher thread to stop.
_DISPATCHER_SHUTDOWN = object()


def _snake_to_camel(name: str) -> str:
"""Convert snake_case field name to camelCase for BiDi protocol."""
Expand Down Expand Up @@ -89,17 +92,41 @@ def __init__(self, url, timeout, interval):

self.url = url
self.response_wait_timeout = timeout
# Retained for backwards compatibility; the connection no longer
# busy-waits, so the interval no longer influences response latency.
self.response_wait_interval = interval

self.callbacks = {}
self.session_id = None
self._ws = None
self._ws_thread = None

self._id = 0
self._id_lock = threading.Lock()

# Command responses keyed by id, alongside a per-request ``Event`` the
# receive thread sets when the matching response arrives. Both are
# guarded by ``_response_lock`` so caller threads and the receive thread
# share them safely instead of relying on the GIL.
self._messages = {}
self._started = False
self._response_events = {}
self._response_lock = threading.Lock()

# Event callbacks, guarded by ``_callbacks_lock``. Incoming events are
# handed to a single long-lived dispatcher thread: this preserves event
# ordering, bounds thread usage to one regardless of event volume (no
# thread-per-event exhaustion), and lets us surface callback exceptions
# instead of losing them on an orphaned thread.
self.callbacks = {}
self._callbacks_lock = threading.Lock()
self._dispatch_queue = queue.Queue()
self._dispatcher_thread = Thread(target=self._dispatch_events, daemon=True, name="BiDi-event-dispatcher")
self._dispatcher_thread.start()

self._open_event = threading.Event()

self._start_ws()
self._wait_until(lambda: self._started)
if not self._open_event.wait(self.response_wait_timeout):
raise WebDriverException("Timed out waiting for the BiDi websocket connection to open")

def close(self):
# Close the socket first so ``run_forever`` returns; only then join the
Expand All @@ -112,7 +139,25 @@ def close(self):
logger.debug(f"Error while closing websocket connection: {e}")
if self._ws_thread is not None:
self._ws_thread.join(timeout=self.response_wait_timeout)
self._started = False

# Stop the dispatcher thread now the receive thread is done producing events.
self._dispatch_queue.put(_DISPATCHER_SHUTDOWN)
if self._dispatcher_thread is not None:
self._dispatcher_thread.join(timeout=self.response_wait_timeout)

# Drop registered handlers so nothing fires after close, and wake any
# callers still blocked on a response so they fail fast rather than
# waiting out the full timeout.
with self._callbacks_lock:
self.callbacks.clear()
with self._response_lock:
self._messages.clear()
pending = list(self._response_events.values())
self._response_events.clear()
for response_event in pending:
response_event.set()

self._open_event.clear()
self._ws = None

def execute(self, command):
Expand All @@ -126,12 +171,21 @@ def execute(self, command):

data = json.dumps(payload, cls=_BiDiEncoder)
logger.debug(f"-> {data}"[: self._max_log_message_size])

# Register the waiter before sending so a fast response can't arrive
# before we are ready to receive it.
response_event = threading.Event()
with self._response_lock:
self._response_events[current_id] = response_event

self._ws.send(data)

self._wait_until(lambda: current_id in self._messages)
if current_id not in self._messages:
response_event.wait(self.response_wait_timeout)
with self._response_lock:
self._response_events.pop(current_id, None)
response = self._messages.pop(current_id, None)
if response is None:
raise WebDriverException(f"Timed out waiting for response to BiDi command {current_id}")
response = self._messages.pop(current_id)

if "error" in response:
error = response["error"]
Expand All @@ -146,21 +200,20 @@ def execute(self, command):

def add_callback(self, event, callback):
event_name = event.event_class
if event_name not in self.callbacks:
self.callbacks[event_name] = []

def _callback(params):
callback(event.from_json(params))

self.callbacks[event_name].append(_callback)
with self._callbacks_lock:
self.callbacks.setdefault(event_name, []).append(_callback)
return id(_callback)

on = add_callback

def remove_callback(self, event, callback_id):
event_name = event.event_class
if event_name in self.callbacks:
for callback in self.callbacks[event_name]:
with self._callbacks_lock:
for callback in self.callbacks.get(event_name, []):
if id(callback) == callback_id:
self.callbacks[event_name].remove(callback)
return
Expand All @@ -177,7 +230,7 @@ def _deserialize_result(self, result, command):

def _start_ws(self):
def on_open(ws):
self._started = True
self._open_event.set()

def on_message(ws, message):
self._process_message(message)
Expand All @@ -201,21 +254,31 @@ def _process_message(self, message):
logger.debug(f"<- {message}"[: self._max_log_message_size])

if "id" in message:
self._messages[message["id"]] = message
message_id = message["id"]
with self._response_lock:
self._messages[message_id] = message
response_event = self._response_events.get(message_id)
if response_event is not None:
response_event.set()

if "method" in message:
params = message["params"]
for callback in self.callbacks.get(message["method"], []):
Thread(target=callback, args=(params,), daemon=True).start()

def _wait_until(self, condition):
timeout = self.response_wait_timeout
interval = self.response_wait_interval

while timeout > 0:
result = condition()
if result:
return result
else:
timeout -= interval
sleep(interval)
# Hand events to the single dispatcher thread instead of spawning a
# thread per event; this keeps ordering and avoids the receive thread
# being blocked by a slow callback.
self._dispatch_queue.put((message["method"], message["params"]))

def _dispatch_events(self):
while True:
item = self._dispatch_queue.get()
if item is _DISPATCHER_SHUTDOWN:
break
method, params = item
with self._callbacks_lock:
callbacks = list(self.callbacks.get(method, []))
for callback in callbacks:
try:
callback(params)
except Exception:
# Never let one handler's failure kill the dispatcher or
# silently vanish: log it and keep delivering other events.
logger.error(f"Unhandled exception in BiDi event callback for '{method}'", exc_info=True)
67 changes: 43 additions & 24 deletions py/test/selenium/webdriver/common/bidi_browsing_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,10 @@ def on_context_created(info):
# Create a new context to trigger the event
context_id = driver.browsing_context.create(type=WindowTypes.TAB)

# Verify the event was received (might be > 1 since default context is also included)
assert len(events_received) >= 1
assert events_received[0].context == context_id or events_received[1].context == context_id
# context_created is a global event delivered asynchronously, and other contexts may also be
# reported, so wait for the event for the context we created rather than indexing positionally.
WebDriverWait(driver, 5).until(lambda d: any(e.context == context_id for e in events_received))
assert any(e.context == context_id for e in events_received)

driver.browsing_context.close(context_id)
driver.browsing_context.remove_event_handler("context_created", callback_id)
Expand All @@ -640,8 +641,8 @@ def on_context_destroyed(info):
context_id = driver.browsing_context.create(type=WindowTypes.TAB)
driver.browsing_context.close(context_id)

assert len(events_received) == 1
assert events_received[0].context == context_id
WebDriverWait(driver, 5).until(lambda d: any(e.context == context_id for e in events_received))
assert any(e.context == context_id for e in events_received)

driver.browsing_context.remove_event_handler("context_destroyed", callback_id)

Expand All @@ -661,6 +662,7 @@ def on_navigation_committed(info):
url = pages.url("simpleTest.html")
driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE)

WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1)
assert len(events_received) >= 1
assert any(url in event.url for event in events_received)

Expand All @@ -682,6 +684,7 @@ def on_dom_content_loaded(info):
url = pages.url("simpleTest.html")
driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE)

WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1)
assert len(events_received) == 1
assert any("simpleTest" in event.url for event in events_received)

Expand All @@ -702,6 +705,7 @@ def on_load(info):
url = pages.url("simpleTest.html")
driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE)

WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1)
assert len(events_received) == 1
assert any("simpleTest" in event.url for event in events_received)

Expand All @@ -722,6 +726,7 @@ def on_navigation_started(info):
url = pages.url("simpleTest.html")
driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE)

WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1)
assert len(events_received) == 1
assert any("simpleTest" in event.url for event in events_received)

Expand All @@ -747,6 +752,7 @@ def on_fragment_navigated(info):
fragment_url = url + "#link"
driver.browsing_context.navigate(context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE)

WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1)
assert len(events_received) == 1
assert any("link" in event.url for event in events_received)

Expand All @@ -772,6 +778,7 @@ def on_navigation_failed(info):
# Expect an exception due to navigation failure
pass

WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1)
assert len(events_received) == 1
assert events_received[0].url == "http://invalid-domain-that-does-not-exist.test/"
assert events_received[0].context == context_id
Expand Down Expand Up @@ -822,6 +829,7 @@ def on_user_prompt_closed(info):
context=driver.current_window_handle, accept=True, user_text="test input"
)

WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1)
assert len(events_received) == 1
assert events_received[0].accepted is True
assert events_received[0].user_text == "test input"
Expand Down Expand Up @@ -940,6 +948,7 @@ def on_context_created(info):
# Create another context (should trigger event)
new_context_id = driver.browsing_context.create(type=WindowTypes.TAB)

WebDriverWait(driver, 5).until(lambda d: len(events_received) >= 1)
assert len(events_received) >= 1

driver.browsing_context.close(context_id)
Expand All @@ -959,16 +968,18 @@ def on_context_created(info):
# Create a context to trigger the event
context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB)

initial_events = len(events_received)
# Wait until the first context's event is observed (delivered asynchronously)
WebDriverWait(driver, 5).until(lambda d: any(e.context == context_id_1 for e in events_received))

# Remove the event handler
driver.browsing_context.remove_event_handler("context_created", callback_id)

# Create another context (should not trigger event after removal)
# Create another context. remove_event_handler unsubscribes synchronously, so with the handler
# gone this context must never be observed. Asserting on this specific context avoids relying on
# event counts, which are unreliable because context_created may report more than one context.
context_id_2 = driver.browsing_context.create(type=WindowTypes.TAB)

# Verify no new events were received after removal
assert len(events_received) == initial_events
assert not any(e.context == context_id_2 for e in events_received)

driver.browsing_context.close(context_id_1)
driver.browsing_context.close(context_id_2)
Expand All @@ -992,10 +1003,13 @@ def on_context_created_2(info):
# Create a context to trigger both handlers
context_id = driver.browsing_context.create(type=WindowTypes.TAB)

# Verify both handlers received the event
assert len(events_received_1) >= 1
assert len(events_received_2) >= 1
# Check any of the events has the required context ID
# Verify both handlers received the created context's event (delivered asynchronously)
WebDriverWait(driver, 5).until(
lambda d: (
any(e.context == context_id for e in events_received_1)
and any(e.context == context_id for e in events_received_2)
)
)
assert any(event.context == context_id for event in events_received_1)
assert any(event.context == context_id for event in events_received_2)

Expand Down Expand Up @@ -1023,22 +1037,22 @@ def on_context_created_2(info):
context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB)

# Verify both handlers received the event
WebDriverWait(driver, 5).until(lambda d: len(events_received_1) >= 1 and len(events_received_2) >= 1)
assert len(events_received_1) >= 1
assert len(events_received_2) >= 1

# store the initial event counts
initial_count_1 = len(events_received_1)
initial_count_2 = len(events_received_2)

# Remove only the first handler
driver.browsing_context.remove_event_handler("context_created", callback_id_1)

# Create another context
context_id_2 = driver.browsing_context.create(type=WindowTypes.TAB)

# Verify only the second handler received the new event
assert len(events_received_1) == initial_count_1 # No new events
assert len(events_received_2) == initial_count_2 + 1 # 1 new event
# Only the second (still-registered) handler should observe the new context. Waiting for it to
# see that context's event also guarantees the dispatcher has caught up before we assert the
# removed handler saw nothing for it.
WebDriverWait(driver, 5).until(lambda d: any(e.context == context_id_2 for e in events_received_2))
assert not any(e.context == context_id_2 for e in events_received_1) # removed handler: no new event
assert any(e.context == context_id_2 for e in events_received_2) # remaining handler: new event

driver.browsing_context.close(context_id_1)
driver.browsing_context.close(context_id_2)
Expand Down Expand Up @@ -1145,6 +1159,9 @@ def test_event_callback_data_consistency(driver):
for ctx in test_contexts:
driver.browsing_context.close(ctx)

# 3 contexts created x 5 registered handlers; events are delivered asynchronously, so wait
# for all of them before asserting on the collected data.
WebDriverWait(driver, 10).until(lambda d: len(helper.events_received) >= 15)
with helper.data_lock:
assert not helper.consistency_errors, "Consistency errors: " + str(helper.consistency_errors)
assert len(helper.events_received) > 0, "No events received"
Expand Down Expand Up @@ -1179,15 +1196,17 @@ def test_no_event_after_handler_removal(driver):
context = driver.browsing_context.create(type=WindowTypes.TAB)
driver.browsing_context.close(context)

events_before = len(helper.events_received)
# Wait until the created context's event has been delivered to the handlers (async delivery)
WebDriverWait(driver, 10).until(lambda d: any(e.context == context for e in helper.events_received))

for i, callback_id in enumerate(helper.callback_ids):
helper.remove_handler(callback_id, f"rem-{i}")

# With every handler removed (and unsubscribed), a newly created context must not be observed.
# Asserting on this specific context avoids relying on event counts, which are unreliable
# because context_created may report more than one context per creation.
post_context = driver.browsing_context.create(type=WindowTypes.TAB)
driver.browsing_context.close(post_context)

with helper.data_lock:
new_events = len(helper.events_received) - events_before

assert new_events == 0, f"Expected 0 new events after removal, got {new_events}"
assert not any(e.context == post_context for e in helper.events_received), "Handlers fired after removal"
Loading
Loading