From 57d0bacb5b9d89c918ede6a0b32e683b0bccf664 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 12 Jun 2026 10:34:14 +0200 Subject: [PATCH 1/3] Add WebSocket support with httpx2.websocket(), Client.websocket() and AsyncClient.websocket() Vendor httpx-ws (MIT, by @frankie567) into an isolated httpx2/_websockets package, exposing WebSocketSession, AsyncWebSocketSession, ASGIWebSocketTransport and the WebSocket exception hierarchy from the top-level httpx2 namespace. --- pyproject.toml | 2 + src/httpx2/httpx2/__init__.py | 10 + src/httpx2/httpx2/_api.py | 63 ++ src/httpx2/httpx2/_client.py | 145 ++++ src/httpx2/httpx2/_websockets/__init__.py | 20 + src/httpx2/httpx2/_websockets/_exceptions.py | 66 ++ src/httpx2/httpx2/_websockets/_ping.py | 36 + src/httpx2/httpx2/_websockets/_session.py | 842 +++++++++++++++++++ src/httpx2/httpx2/_websockets/_transport.py | 296 +++++++ src/httpx2/pyproject.toml | 3 +- tests/httpx2/websockets/__init__.py | 0 tests/httpx2/websockets/conftest.py | 67 ++ tests/httpx2/websockets/test_session.py | 833 ++++++++++++++++++ tests/httpx2/websockets/test_transport.py | 379 +++++++++ uv.lock | 99 ++- 15 files changed, 2859 insertions(+), 2 deletions(-) create mode 100644 src/httpx2/httpx2/_websockets/__init__.py create mode 100644 src/httpx2/httpx2/_websockets/_exceptions.py create mode 100644 src/httpx2/httpx2/_websockets/_ping.py create mode 100644 src/httpx2/httpx2/_websockets/_session.py create mode 100644 src/httpx2/httpx2/_websockets/_transport.py create mode 100644 tests/httpx2/websockets/__init__.py create mode 100644 tests/httpx2/websockets/conftest.py create mode 100644 tests/httpx2/websockets/test_session.py create mode 100644 tests/httpx2/websockets/test_transport.py diff --git a/pyproject.toml b/pyproject.toml index 08859dbc..c1b3f74d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,10 +24,12 @@ dev = [ "pytest-codspeed>=4.1.1", "pytest-httpbin==2.0.0", "pytest-trio==0.8.0", + "starlette>=0.49", "trio==0.31.0", "trio-typing==0.10.0", "trustme==1.2.1", "uvicorn>=0.35", + "websockets>=15", "werkzeug>=3.1.6", # Linting "mypy==1.17.1", diff --git a/src/httpx2/httpx2/__init__.py b/src/httpx2/httpx2/__init__.py index 068e0a25..d7a8d8d0 100644 --- a/src/httpx2/httpx2/__init__.py +++ b/src/httpx2/httpx2/__init__.py @@ -10,16 +10,19 @@ from ._transports import * from ._types import * from ._urls import * +from ._websockets import * __all__ = [ "__description__", "__title__", "__version__", "ASGITransport", + "ASGIWebSocketTransport", "AsyncBaseTransport", "AsyncByteStream", "AsyncClient", "AsyncHTTPTransport", + "AsyncWebSocketSession", "Auth", "BaseTransport", "BasicAuth", @@ -78,6 +81,13 @@ "UnsupportedProtocol", "URL", "USE_CLIENT_DEFAULT", + "websocket", + "WebSocketDisconnect", + "WebSocketException", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketSession", + "WebSocketUpgradeError", "WriteError", "WriteTimeout", "WSGITransport", diff --git a/src/httpx2/httpx2/_api.py b/src/httpx2/httpx2/_api.py index 25171cbc..05316c95 100644 --- a/src/httpx2/httpx2/_api.py +++ b/src/httpx2/httpx2/_api.py @@ -19,6 +19,13 @@ TimeoutTypes, ) from ._urls import URL +from ._websockets._session import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, + WebSocketSession, +) if typing.TYPE_CHECKING: import ssl # pragma: no cover @@ -34,6 +41,7 @@ "put", "request", "stream", + "websocket", ] @@ -159,6 +167,61 @@ def stream( yield response +@contextmanager +def websocket( + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + verify: ssl.SSLContext | str | bool = True, + trust_env: bool = True, + subprotocols: list[str] | None = None, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, +) -> Generator[WebSocketSession]: + """ + Open a WebSocket session. + + The session is closed automatically when exiting the context manager. + + ``` + >>> import httpx2 + >>> with httpx2.websocket("wss://echo.websocket.org") as ws: + ... ws.send_text("Hello!") + ... message = ws.receive_text() + ``` + + **Parameters**: See `httpx2.request` and `httpx2.Client.websocket`. + """ + with Client( + cookies=cookies, + proxy=proxy, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) as client: + with client.websocket( + url, + params=params, + headers=headers, + auth=auth, + follow_redirects=follow_redirects, + subprotocols=subprotocols, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + ) as session: + yield session + + def get( url: URL | str, *, diff --git a/src/httpx2/httpx2/_client.py b/src/httpx2/httpx2/_client.py index 18720ee6..f935d8d3 100644 --- a/src/httpx2/httpx2/_client.py +++ b/src/httpx2/httpx2/_client.py @@ -48,6 +48,16 @@ ) from ._urls import URL, QueryParams from ._utils import URLPattern, get_environment_proxies +from ._websockets._session import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, + AsyncWebSocketSession, + WebSocketSession, + aconnect_ws, + connect_ws, +) if typing.TYPE_CHECKING: import ssl # pragma: no cover @@ -845,6 +855,71 @@ def stream( finally: response.close() + @contextmanager + def websocket( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + subprotocols: list[str] | None = None, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + ) -> Generator[WebSocketSession]: + """ + Open a WebSocket session. + + The session is closed automatically when exiting the context manager. + + ```python + with httpx2.Client() as client: + with client.websocket("wss://example.com/ws") as ws: + ws.send_text("Hello!") + message = ws.receive_text() + ``` + + **Parameters**: See `httpx2.request` for the request parameters, plus: + + * **subprotocols** - *(optional)* A list of subprotocols to negotiate with the server. + * **max_message_size_bytes** - Message size in bytes to receive from the server. Defaults to 64 KiB. + * **queue_size** - Size of the queue where the received messages will be held + until they are consumed. If the queue is full, the client will stop receiving + messages from the server until the queue has room available. Defaults to 512. + * **keepalive_ping_interval_seconds** - Interval at which the client will automatically + send a Ping event to keep the connection alive. Set it to `None` to disable + this mechanism. Defaults to 20 seconds. + * **keepalive_ping_timeout_seconds** - Maximum delay the client will wait for an answer + to its Ping event. If the delay is exceeded, `httpx2.WebSocketNetworkError` will be + raised and the connection closed. Defaults to 20 seconds. + + Raises `httpx2.WebSocketUpgradeError` if the connection didn't correctly + upgrade to a WebSocket session. + """ + with connect_ws( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + subprotocols=subprotocols, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + ) as session: + yield session + def send( self, request: Request, @@ -1548,6 +1623,76 @@ async def stream( finally: await response.aclose() + @asynccontextmanager + async def websocket( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + subprotocols: list[str] | None = None, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + ) -> AsyncGenerator[AsyncWebSocketSession]: + """ + Open a WebSocket session. + + The session is closed automatically when exiting the context manager. + + ```python + async with httpx2.AsyncClient() as client: + async with client.websocket("wss://example.com/ws") as ws: + await ws.send_text("Hello!") + message = await ws.receive_text() + ``` + + Internally, the session uses an anyio task group to manage background tasks. + As a result, exceptions that are not caught inside the context manager and + propagate out of the `async with` block will be wrapped in an `ExceptionGroup`. + Use the `except*` syntax to handle them. + + **Parameters**: See `httpx2.request` for the request parameters, plus: + + * **subprotocols** - *(optional)* A list of subprotocols to negotiate with the server. + * **max_message_size_bytes** - Message size in bytes to receive from the server. Defaults to 64 KiB. + * **queue_size** - Size of the queue where the received messages will be held + until they are consumed. If the queue is full, the client will stop receiving + messages from the server until the queue has room available. Defaults to 512. + * **keepalive_ping_interval_seconds** - Interval at which the client will automatically + send a Ping event to keep the connection alive. Set it to `None` to disable + this mechanism. Defaults to 20 seconds. + * **keepalive_ping_timeout_seconds** - Maximum delay the client will wait for an answer + to its Ping event. If the delay is exceeded, `httpx2.WebSocketNetworkError` will be + raised and the connection closed. Defaults to 20 seconds. + + Raises `httpx2.WebSocketUpgradeError` if the connection didn't correctly + upgrade to a WebSocket session. + """ + async with aconnect_ws( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + subprotocols=subprotocols, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + ) as session: + yield session + async def send( self, request: Request, diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py new file mode 100644 index 00000000..1a227924 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -0,0 +1,20 @@ +from ._exceptions import ( + WebSocketDisconnect, + WebSocketException, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, +) +from ._session import AsyncWebSocketSession, WebSocketSession +from ._transport import ASGIWebSocketTransport + +__all__ = [ + "ASGIWebSocketTransport", + "AsyncWebSocketSession", + "WebSocketDisconnect", + "WebSocketException", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketSession", + "WebSocketUpgradeError", +] diff --git a/src/httpx2/httpx2/_websockets/_exceptions.py b/src/httpx2/httpx2/_websockets/_exceptions.py new file mode 100644 index 00000000..28d5aa35 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_exceptions.py @@ -0,0 +1,66 @@ +""" +Our exception hierarchy: + +* WebSocketException + x WebSocketUpgradeError + x WebSocketDisconnect + x WebSocketInvalidTypeReceived + x WebSocketNetworkError +""" + +from __future__ import annotations + +import typing + +if typing.TYPE_CHECKING: + import wsproto + + from .._models import Response # pragma: no cover + +__all__ = [ + "WebSocketDisconnect", + "WebSocketException", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketUpgradeError", +] + + +class WebSocketException(Exception): + """ + Base class for all WebSocket exceptions. + """ + + +class WebSocketUpgradeError(WebSocketException): + """ + The initial connection didn't correctly upgrade to a WebSocket session. + """ + + def __init__(self, response: Response) -> None: + self.response = response + + +class WebSocketDisconnect(WebSocketException): + """ + The server closed the WebSocket session. + """ + + def __init__(self, code: int = 1000, reason: str | None = None) -> None: + self.code = code + self.reason = reason or "" + + +class WebSocketInvalidTypeReceived(WebSocketException): + """ + A received event was not of the expected type. + """ + + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + +class WebSocketNetworkError(WebSocketException): + """ + A network error occurred, typically because the underlying stream has closed or timed out. + """ diff --git a/src/httpx2/httpx2/_websockets/_ping.py b/src/httpx2/httpx2/_websockets/_ping.py new file mode 100644 index 00000000..1aab637e --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_ping.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import secrets +import threading + +import anyio + + +class PingManager: + def __init__(self) -> None: + self._pings: dict[bytes, threading.Event] = {} + + def create(self, ping_id: bytes | None = None) -> tuple[bytes, threading.Event]: + ping_id = secrets.token_bytes() if not ping_id else ping_id + event = threading.Event() + self._pings[ping_id] = event + return ping_id, event + + def ack(self, ping_id: bytes | bytearray) -> None: + event = self._pings.pop(bytes(ping_id)) + event.set() + + +class AsyncPingManager: + def __init__(self) -> None: + self._pings: dict[bytes, anyio.Event] = {} + + def create(self, ping_id: bytes | None = None) -> tuple[bytes, anyio.Event]: + ping_id = secrets.token_bytes() if not ping_id else ping_id + event = anyio.Event() + self._pings[ping_id] = event + return ping_id, event + + def ack(self, ping_id: bytes | bytearray) -> None: + event = self._pings.pop(bytes(ping_id)) + event.set() diff --git a/src/httpx2/httpx2/_websockets/_session.py b/src/httpx2/httpx2/_websockets/_session.py new file mode 100644 index 00000000..5e6f2846 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_session.py @@ -0,0 +1,842 @@ +from __future__ import annotations + +import base64 +import concurrent.futures +import contextlib +import json +import queue +import secrets +import threading +import typing +from collections.abc import AsyncGenerator, Generator +from types import TracebackType + +import anyio +import wsproto +import wsproto.utilities +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from wsproto.frame_protocol import CloseReason + +import httpcore2 +from httpcore2 import AsyncNetworkStream, NetworkStream + +from .._models import Headers +from .._urls import URL +from ._exceptions import ( + WebSocketDisconnect, + WebSocketException, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, +) +from ._ping import AsyncPingManager, PingManager +from ._transport import ASGIWebSocketAsyncNetworkStream + +if typing.TYPE_CHECKING: + from .._client import AsyncClient, Client, UseClientDefault + from .._models import Response + from .._types import AuthTypes, CookieTypes, HeaderTypes, QueryParamTypes, RequestExtensions, TimeoutTypes + +JSONMode = typing.Literal["text", "binary"] +TaskResult = typing.TypeVar("TaskResult") + +DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536 +DEFAULT_QUEUE_SIZE = 512 +DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = 20.0 +DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 + + +class ShouldClose(Exception): + pass + + +class EndOfStream(Exception): + pass + + +class WebSocketSession: + """ + Sync context manager representing an opened WebSocket session. + + Attributes: + subprotocol: Optional protocol that has been accepted by the server. + response: The WebSocket handshake response. + """ + + subprotocol: str | None + response: Response | None + + def __init__( + self, + stream: NetworkStream, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + response: Response | None = None, + ) -> None: + self.stream = stream + self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) + self.response = response + if self.response is not None: + self.subprotocol = self.response.headers.get("sec-websocket-protocol") + else: + self.subprotocol = None + + self._events: queue.Queue[wsproto.events.Event | WebSocketException] = queue.Queue(queue_size) + + self._ping_manager = PingManager() + self._should_close = threading.Event() + self._write_lock = threading.Lock() + self._should_close_task: concurrent.futures.Future[bool] | None = None + self._executor: concurrent.futures.ThreadPoolExecutor | None = None + + self._max_message_size_bytes = max_message_size_bytes + self._queue_size = queue_size + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + + def _get_executor_should_close_task( + self, + ) -> tuple[concurrent.futures.ThreadPoolExecutor, concurrent.futures.Future[bool]]: + if self._should_close_task is None: + self._executor = concurrent.futures.ThreadPoolExecutor() + self._should_close_task = self._executor.submit(self._should_close.wait) + assert self._executor is not None + return self._executor, self._should_close_task + + def __enter__(self) -> WebSocketSession: + self._background_receive_task = threading.Thread( + target=self._background_receive, args=(self._max_message_size_bytes,) + ) + self._background_receive_task.start() + + self._background_keepalive_ping_task: threading.Thread | None = None + if self._keepalive_ping_interval_seconds is not None: + self._background_keepalive_ping_task = threading.Thread( + target=self._background_keepalive_ping, + args=( + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, + ), + ) + self._background_keepalive_ping_task.start() + + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close() + self._background_receive_task.join() + if self._background_keepalive_ping_task is not None: + self._background_keepalive_ping_task.join() + + def ping(self, payload: bytes = b"") -> threading.Event: + """ + Send a Ping message. + + The payload is used internally to track this specific event. + If left empty, a random one will be generated. + + Returns an event that can be used to wait for the corresponding Pong response: + + ```python + pong_callback = ws.ping() + pong_callback.wait() + ``` + """ + ping_id, callback = self._ping_manager.create(payload) + event = wsproto.events.Ping(ping_id) + self.send(event) + return callback + + def send(self, event: wsproto.events.Event) -> None: + """ + Send a raw `wsproto.events.Event`. + + Mainly useful to send events that are not supported by the library. + Most of the time, `ping()`, `send_text()`, `send_bytes()` and `send_json()` are preferred. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + try: + data = self.connection.send(event) + with self._write_lock: + self.stream.write(data) + except httpcore2.WriteError as e: + self.close(CloseReason.INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e + + def send_text(self, data: str) -> None: + """ + Send a text message. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + event = wsproto.events.TextMessage(data=data) + self.send(event) + + def send_bytes(self, data: bytes) -> None: + """ + Send a bytes message. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + event = wsproto.events.BytesMessage(data=data) + self.send(event) + + def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + """ + Send JSON data, serialized with `json.dumps()`, in `'text'` or `'binary'` mode. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + assert mode in ["text", "binary"] + serialized_data = json.dumps(data) + if mode == "text": + self.send_text(serialized_data) + else: + self.send_bytes(serialized_data.encode("utf-8")) + + def receive(self, timeout: float | None = None) -> wsproto.events.Event: + """ + Receive a raw `wsproto.events.Event` from the server. + + Mainly useful to receive raw events. Most of the time, + `receive_text()`, `receive_bytes()`, and `receive_json()` are preferred. + + If `timeout` is `None`, this blocks until an event is available. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + """ + try: + event = self._events.get(block=True, timeout=timeout) + except queue.Empty as e: + raise TimeoutError from e + if isinstance(event, WebSocketException): + raise event + if isinstance(event, wsproto.events.CloseConnection): + raise WebSocketDisconnect(event.code, event.reason) + return event + + def receive_text(self, timeout: float | None = None) -> str: + """ + Receive text from the server. + + If `timeout` is `None`, this blocks until an event is available. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received event was not a text message. + """ + event = self.receive(timeout) + if isinstance(event, wsproto.events.TextMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + def receive_bytes(self, timeout: float | None = None) -> bytes: + """ + Receive bytes from the server. + + If `timeout` is `None`, this blocks until an event is available. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received event was not a bytes message. + """ + event = self.receive(timeout) + if isinstance(event, wsproto.events.BytesMessage): + return bytes(event.data) + raise WebSocketInvalidTypeReceived(event) + + def receive_json(self, timeout: float | None = None, mode: JSONMode = "text") -> typing.Any: + """ + Receive JSON data from the server, parsed with `json.loads()`, in `'text'` or `'binary'` mode. + + If `timeout` is `None`, this blocks until an event is available. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received event didn't correspond to the specified mode. + """ + assert mode in ["text", "binary"] + data: str | bytes + if mode == "text": + data = self.receive_text(timeout) + else: + data = self.receive_bytes(timeout) + return json.loads(data) + + def close(self, code: int = 1000, reason: str | None = None) -> None: + """ + Close the WebSocket session. + + Internally, it'll send the `wsproto.events.CloseConnection` event. + + *This method is automatically called when exiting the context manager.* + """ + self._should_close.set() + if self._executor is not None: + self._executor.shutdown(False) + if self.connection.state not in { + wsproto.connection.ConnectionState.LOCAL_CLOSING, + wsproto.connection.ConnectionState.CLOSED, + }: + event = wsproto.events.CloseConnection(code, reason) + data = self.connection.send(event) + try: + with self._write_lock: + self.stream.write(data) + except httpcore2.WriteError: + pass + self.stream.close() + + def _background_receive(self, max_bytes: int) -> None: + """ + Background thread listening for data from the server. + + Internally, it'll: + + * Answer to Ping events. + * Acknowledge Pong events. + * Put other events in the `_events` queue that'll eventually be consumed by the user. + """ + partial_message_buffer: str | bytes | None = None + try: + while not self._should_close.is_set(): + data = self._wait_until_closed(self._read_stream, max_bytes) + self.connection.receive_data(data) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + data = self.connection.send(event.response()) + with self._write_lock: + self.stream.write(data) + continue + if isinstance(event, wsproto.events.Pong): + self._ping_manager.ack(event.payload) + continue + if isinstance(event, wsproto.events.CloseConnection): + self._should_close.set() + if isinstance(event, wsproto.events.Message): + # Unfinished message: bufferize + if not event.message_finished: + if partial_message_buffer is None: + partial_message_buffer = event.data + else: + partial_message_buffer += event.data + # Finished message but no buffer: just emit the event + elif partial_message_buffer is None: + self._events.put(event) + # Finished message with buffer: emit the full event + else: + event_type = type(event) + full_message_event = event_type(partial_message_buffer + event.data) + partial_message_buffer = None + self._events.put(full_message_event) + continue + self._events.put(event) + except (httpcore2.ReadError, httpcore2.WriteError, EndOfStream): + self.close(CloseReason.INTERNAL_ERROR, "Stream error") + self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass + + def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: float | None = None) -> None: + try: + while not self._should_close.is_set(): + should_close = self._wait_until_closed(self._should_close.wait, interval_seconds) + if should_close: + raise ShouldClose() + pong_callback = self.ping() + if timeout_seconds is not None: + acknowledged = self._wait_until_closed(pong_callback.wait, timeout_seconds) + if not acknowledged: + self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") + self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass + + def _wait_until_closed( + self, callable: typing.Callable[..., TaskResult], *args: typing.Any, **kwargs: typing.Any + ) -> TaskResult: + try: + executor, should_close_task = self._get_executor_should_close_task() + todo_task = executor.submit(callable, *args, **kwargs) + except RuntimeError as e: + raise ShouldClose() from e + else: + done, _ = concurrent.futures.wait( + (todo_task, should_close_task), # type: ignore[misc] + return_when=concurrent.futures.FIRST_COMPLETED, + ) + if should_close_task in done: + raise ShouldClose() + assert todo_task in done + result = todo_task.result() + return result + + def _read_stream(self, max_bytes: int) -> bytes: + data = self.stream.read(max_bytes) + if data == b"": + raise EndOfStream() + return data + + +class AsyncWebSocketSession(anyio.AsyncContextManagerMixin): + """ + Async context manager representing an opened WebSocket session. + + Internally, this session uses an anyio task group to manage background tasks. + As a result, exceptions that are not caught inside the context manager + and propagate out of the `async with` block will be wrapped in an `ExceptionGroup`. + + To handle them, use the `except*` syntax: + + ```python + async with AsyncWebSocketSession(stream) as ws: + try: + data = await ws.receive_text() + except WebSocketDisconnect: + # Caught inside the context manager: plain exception. + print("Connection closed") + + # If not caught inside: + try: + async with AsyncWebSocketSession(stream) as ws: + data = await ws.receive_text() + except* WebSocketDisconnect: + # Propagated out of the context manager: wrapped in ExceptionGroup. + print("Connection closed") + ``` + + Attributes: + subprotocol: Optional protocol that has been accepted by the server. + response: The WebSocket handshake response. + """ + + subprotocol: str | None + response: Response | None + _send_event: MemoryObjectSendStream[wsproto.events.Event | WebSocketException] + _receive_event: MemoryObjectReceiveStream[wsproto.events.Event | WebSocketException] + + def __init__( + self, + stream: AsyncNetworkStream, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + response: Response | None = None, + ) -> None: + self.stream = stream + self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) + self.response = response + if self.response is not None: + self.subprotocol = self.response.headers.get("sec-websocket-protocol") + else: + self.subprotocol = None + + self._ping_manager = AsyncPingManager() + self._should_close = anyio.Event() + self._write_lock = anyio.Lock() + + self._max_message_size_bytes = max_message_size_bytes + self._queue_size = queue_size + + # Always disable keepalive ping when emulating ASGI + if isinstance(stream, ASGIWebSocketAsyncNetworkStream): + self._keepalive_ping_interval_seconds = None + self._keepalive_ping_timeout_seconds = None + else: + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + + @contextlib.asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[AsyncWebSocketSession]: + self._send_event, self._receive_event = anyio.create_memory_object_stream[ + wsproto.events.Event | WebSocketException + ]() + self._background_task_group = anyio.create_task_group() + + async with self._send_event, self._receive_event, self._background_task_group: + self._background_task_group.start_soon(self._background_receive, self._max_message_size_bytes) + if self._keepalive_ping_interval_seconds is not None: + self._background_task_group.start_soon( + self._background_keepalive_ping, + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, + ) + + try: + yield self + finally: + self._background_task_group.cancel_scope.cancel() + with anyio.CancelScope(shield=True): + await self.close() + + async def ping(self, payload: bytes = b"") -> anyio.Event: + """ + Send a Ping message. + + The payload is used internally to track this specific event. + If left empty, a random one will be generated. + + Returns an event that can be used to wait for the corresponding Pong response: + + ```python + pong_callback = await ws.ping() + await pong_callback.wait() + ``` + """ + ping_id, callback = self._ping_manager.create(payload) + event = wsproto.events.Ping(ping_id) + await self.send(event) + return callback + + async def send(self, event: wsproto.events.Event) -> None: + """ + Send a raw `wsproto.events.Event`. + + Mainly useful to send events that are not supported by the library. + Most of the time, `ping()`, `send_text()`, `send_bytes()` and `send_json()` are preferred. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + try: + data = self.connection.send(event) + async with self._write_lock: + await self.stream.write(data) + except httpcore2.WriteError as e: + await self.close(CloseReason.INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e + + async def send_text(self, data: str) -> None: + """ + Send a text message. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + event = wsproto.events.TextMessage(data=data) + await self.send(event) + + async def send_bytes(self, data: bytes) -> None: + """ + Send a bytes message. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + event = wsproto.events.BytesMessage(data=data) + await self.send(event) + + async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + """ + Send JSON data, serialized with `json.dumps()`, in `'text'` or `'binary'` mode. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + assert mode in ["text", "binary"] + serialized_data = json.dumps(data) + if mode == "text": + await self.send_text(serialized_data) + else: + await self.send_bytes(serialized_data.encode("utf-8")) + + async def receive(self, timeout: float | None = None) -> wsproto.events.Event: + """ + Receive a raw `wsproto.events.Event` from the server. + + Mainly useful to receive raw events. Most of the time, + `receive_text()`, `receive_bytes()`, and `receive_json()` are preferred. + + If `timeout` is `None`, this blocks until an event is available. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + """ + with anyio.fail_after(timeout): + event = await self._receive_event.receive() + if isinstance(event, WebSocketException): + raise event + if isinstance(event, wsproto.events.CloseConnection): + raise WebSocketDisconnect(event.code, event.reason) + return event + + async def receive_text(self, timeout: float | None = None) -> str: + """ + Receive text from the server. + + If `timeout` is `None`, this blocks until an event is available. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received event was not a text message. + """ + event = await self.receive(timeout) + if isinstance(event, wsproto.events.TextMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + async def receive_bytes(self, timeout: float | None = None) -> bytes: + """ + Receive bytes from the server. + + If `timeout` is `None`, this blocks until an event is available. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received event was not a bytes message. + """ + event = await self.receive(timeout) + if isinstance(event, wsproto.events.BytesMessage): + return bytes(event.data) + raise WebSocketInvalidTypeReceived(event) + + async def receive_json(self, timeout: float | None = None, mode: JSONMode = "text") -> typing.Any: + """ + Receive JSON data from the server, parsed with `json.loads()`, in `'text'` or `'binary'` mode. + + If `timeout` is `None`, this blocks until an event is available. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received event didn't correspond to the specified mode. + """ + assert mode in ["text", "binary"] + data: str | bytes + if mode == "text": + data = await self.receive_text(timeout) + else: + data = await self.receive_bytes(timeout) + return json.loads(data) + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + """ + Close the WebSocket session. + + Internally, it'll send the `wsproto.events.CloseConnection` event. + + *This method is automatically called when exiting the context manager.* + """ + self._should_close.set() + if self.connection.state not in { + wsproto.connection.ConnectionState.LOCAL_CLOSING, + wsproto.connection.ConnectionState.CLOSED, + }: + event = wsproto.events.CloseConnection(code, reason) + data = self.connection.send(event) + try: + async with self._write_lock: + await self.stream.write(data) + except httpcore2.WriteError: + pass + await self.stream.aclose() + + async def _background_receive(self, max_bytes: int) -> None: + """ + Background task listening for data from the server. + + Internally, it'll: + + * Answer to Ping events. + * Acknowledge Pong events. + * Put other events in the `_events` queue that'll eventually be consumed by the user. + """ + partial_message_buffer: str | bytes | None = None + try: + while not self._should_close.is_set(): + data = await self._read_stream(max_bytes) + self.connection.receive_data(data) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + data = self.connection.send(event.response()) + async with self._write_lock: + await self.stream.write(data) + continue + if isinstance(event, wsproto.events.Pong): + self._ping_manager.ack(event.payload) + continue + if isinstance(event, wsproto.events.CloseConnection): + self._should_close.set() + if isinstance(event, wsproto.events.Message): + # Unfinished message: bufferize + if not event.message_finished: + if partial_message_buffer is None: + partial_message_buffer = event.data + else: + partial_message_buffer += event.data + # Finished message but no buffer: just emit the event + elif partial_message_buffer is None: + await self._send_event.send(event) + # Finished message with buffer: emit the full event + else: + event_type = type(event) + full_message_event = event_type(partial_message_buffer + event.data) + partial_message_buffer = None + await self._send_event.send(full_message_event) + continue + await self._send_event.send(event) + except (httpcore2.ReadError, httpcore2.WriteError, EndOfStream): + await self.close(CloseReason.INTERNAL_ERROR, "Stream error") + await self._send_event.send(WebSocketNetworkError()) + + async def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: float | None = None) -> None: + while not self._should_close.is_set(): + await anyio.sleep(interval_seconds) + + try: + pong_callback = await self.ping() + # Connection is closing, exit the task + except wsproto.utilities.LocalProtocolError: + return + + if timeout_seconds is not None: + try: + with anyio.fail_after(timeout_seconds): + await pong_callback.wait() + except TimeoutError: + await self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") + await self._send_event.send(WebSocketNetworkError()) + + async def _read_stream(self, max_bytes: int) -> bytes: + data = await self.stream.read(max_bytes) + if data == b"": + raise EndOfStream() + return data + + +def _get_headers(subprotocols: list[str] | None) -> dict[str, str]: + headers = { + "connection": "upgrade", + "upgrade": "websocket", + "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), + "sec-websocket-version": "13", + } + if subprotocols is not None: + headers["sec-websocket-protocol"] = ", ".join(subprotocols) + return headers + + +def _get_url(url: URL | str) -> URL: + url = URL(url) + if url.scheme == "ws": + return url.copy_with(scheme="http") + if url.scheme == "wss": + return url.copy_with(scheme="https") + return url + + +@contextlib.contextmanager +def connect_ws( + client: Client, + url: URL | str, + *, + params: QueryParamTypes | None, + headers: HeaderTypes | None, + cookies: CookieTypes | None, + auth: AuthTypes | UseClientDefault | None, + follow_redirects: bool | UseClientDefault, + timeout: TimeoutTypes | UseClientDefault, + extensions: RequestExtensions | None, + subprotocols: list[str] | None, + max_message_size_bytes: int, + queue_size: int, + keepalive_ping_interval_seconds: float | None, + keepalive_ping_timeout_seconds: float | None, +) -> Generator[WebSocketSession]: + merged_headers = Headers(headers) + merged_headers.update(_get_headers(subprotocols)) + + with client.stream( + "GET", + _get_url(url), + params=params, + headers=merged_headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + session = WebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + response=response, + ) + with session: + yield session + + +@contextlib.asynccontextmanager +async def aconnect_ws( + client: AsyncClient, + url: URL | str, + *, + params: QueryParamTypes | None, + headers: HeaderTypes | None, + cookies: CookieTypes | None, + auth: AuthTypes | UseClientDefault | None, + follow_redirects: bool | UseClientDefault, + timeout: TimeoutTypes | UseClientDefault, + extensions: RequestExtensions | None, + subprotocols: list[str] | None, + max_message_size_bytes: int, + queue_size: int, + keepalive_ping_interval_seconds: float | None, + keepalive_ping_timeout_seconds: float | None, +) -> AsyncGenerator[AsyncWebSocketSession]: + merged_headers = Headers(headers) + merged_headers.update(_get_headers(subprotocols)) + + async with client.stream( + "GET", + _get_url(url), + params=params, + headers=merged_headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + session = AsyncWebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + response=response, + ) + async with session: + yield session diff --git a/src/httpx2/httpx2/_websockets/_transport.py b/src/httpx2/httpx2/_websockets/_transport.py new file mode 100644 index 00000000..607275c8 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_transport.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +import contextlib +import math +import typing +from types import TracebackType + +import anyio +import anyio.abc +import anyio.streams.stapled +import wsproto +from wsproto.frame_protocol import CloseReason + +from httpcore2 import AsyncNetworkStream + +from .._models import Request, Response +from .._transports.asgi import ASGITransport +from .._types import AsyncByteStream +from ._exceptions import WebSocketDisconnect, WebSocketUpgradeError + +Scope = typing.MutableMapping[str, typing.Any] +Message = typing.MutableMapping[str, typing.Any] +Receive = typing.Callable[[], typing.Awaitable[Message]] +Send = typing.Callable[[Message], typing.Awaitable[None]] +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] + + +class ASGIWebSocketTransportError(Exception): + pass + + +class UnhandledASGIMessageType(ASGIWebSocketTransportError): + def __init__(self, message: Message) -> None: + self.message = message + + +class UnhandledWebSocketEvent(ASGIWebSocketTransportError): + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + +class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): + def __init__( + self, + app: ASGIApp, + scope: Scope, + task_group: anyio.abc.TaskGroup, + initial_receive_timeout: float = 1.0, + ) -> None: + self.app = app + self.scope = scope + self._receive_queue = anyio.streams.stapled.StapledObjectStream( + *anyio.create_memory_object_stream[Message](max_buffer_size=math.inf) + ) + self._send_queue = anyio.streams.stapled.StapledObjectStream( + *anyio.create_memory_object_stream[Message](max_buffer_size=math.inf) + ) + self._task_group = task_group + self._initial_receive_timeout = initial_receive_timeout + self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER) + self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) + self._aentered = False + + async def __aenter__(self) -> tuple[ASGIWebSocketAsyncNetworkStream, bytes]: + if self._aentered: + raise RuntimeError("Cannot use ASGIWebSocketAsyncNetworkStream in a context manager twice") + self._aentered = True + self._task_group.start_soon(self._run) + async with contextlib.AsyncExitStack() as stack: + stack.push_async_callback(self.aclose) + + await self.send({"type": "websocket.connect"}) + + try: + message = await self.receive(self._initial_receive_timeout) + except TimeoutError as e: + raise RuntimeError( + "WebSocket didn't accept the connection in time. Did you forget to call accept()?" + ) from e + + if message["type"] == "websocket.close": + await stack.aclose() + raise WebSocketDisconnect(message["code"], message.get("reason")) + + # Websocket Denial Response extension + # Ref: https://asgi.readthedocs.io/en/latest/extensions.html#websocket-denial-response + if message["type"] == "websocket.http.response.start": + status_code: int = message["status"] + headers: list[tuple[bytes, bytes]] = message["headers"] + body: list[bytes] = [] + while True: + message = await self.receive() + assert message["type"] == "websocket.http.response.body" + body.append(message["body"]) + if not message.get("more_body", False): + break + + await stack.aclose() + raise WebSocketUpgradeError(Response(status_code, headers=headers, content=b"".join(body))) + + assert message["type"] == "websocket.accept" + retval = self, self._build_accept_response(message) + self._exit_stack = stack.pop_all() + return retval + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + message = await self.receive(timeout=timeout) + message_type = message["type"] + + if message_type not in {"websocket.send", "websocket.close"}: + raise UnhandledASGIMessageType(message) + + event: wsproto.events.Event + if message_type == "websocket.send": + data_str: str | None = message.get("text") + if data_str is not None: + event = wsproto.events.TextMessage(data_str) + data_bytes: bytes | None = message.get("bytes") + if data_bytes is not None: + event = wsproto.events.BytesMessage(data_bytes) + else: + event = wsproto.events.CloseConnection(message["code"], message["reason"]) + + return self.connection.send(event) + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Request): + pass + elif isinstance(event, wsproto.events.CloseConnection): + await self.send( + { + "type": "websocket.disconnect", + "code": event.code, + "reason": event.reason, + } + ) + elif isinstance(event, wsproto.events.TextMessage): + await self.send({"type": "websocket.receive", "text": event.data}) + elif isinstance(event, wsproto.events.BytesMessage): + await self.send({"type": "websocket.receive", "bytes": event.data}) + else: + raise UnhandledWebSocketEvent(event) + + async def aclose(self) -> None: + with contextlib.suppress(anyio.ClosedResourceError): + await self.send({"type": "websocket.disconnect"}) + await self._receive_queue.aclose() + await self._send_queue.aclose() + + async def send(self, message: Message) -> None: + await self._receive_queue.send(message) + + async def receive(self, timeout: float | None = None) -> Message: + if timeout is None: + timeout = math.inf + with anyio.fail_after(timeout): + return await self._send_queue.receive() + + async def _run(self) -> None: + """ + The task in which the websocket session runs. + """ + scope = self.scope + receive = self._receive_queue.receive + send = self._send_queue.send + try: + await self.app(scope, receive, send) + except Exception as e: + message: Message = { + "type": "websocket.close", + "code": CloseReason.INTERNAL_ERROR, + "reason": str(e), + } + with contextlib.suppress(anyio.ClosedResourceError): + await send(message) + + def _build_accept_response(self, message: Message) -> bytes: + subprotocol = message.get("subprotocol", None) + headers = message.get("headers", []) + return self.connection.send( + wsproto.events.AcceptConnection( + subprotocol=subprotocol, + extra_headers=headers, + ) + ) + + +class ASGIWebSocketTransport(ASGITransport): + """ + A custom `ASGITransport` that handles WebSocket upgrade requests + by emulating the WebSocket protocol against the ASGI app. + + Plain HTTP requests are handled as usual by `ASGITransport`. + + ```python + transport = httpx2.ASGIWebSocketTransport(app=app) + client = httpx2.AsyncClient(transport=transport) + ``` + """ + + scope: Scope + + def __init__( + self, + app: ASGIApp, + raise_app_exceptions: bool = True, + root_path: str = "", + client: tuple[str, int] = ("127.0.0.1", 123), + initial_receive_timeout: float = 1.0, + ) -> None: + super().__init__(app, raise_app_exceptions, root_path, client) + self._exit_stack: contextlib.AsyncExitStack | None = None + self._initial_receive_timeout = initial_receive_timeout + + async def __aenter__(self) -> ASGIWebSocketTransport: + async with contextlib.AsyncExitStack() as stack: + self._task_group = await stack.enter_async_context(anyio.create_task_group()) + self._exit_stack = stack.pop_all() + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_val: BaseException | None = None, + exc_tb: TracebackType | None = None, + ) -> None: + await super().__aexit__(exc_type, exc_val, exc_tb) + assert self._exit_stack is not None + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + + async def handle_async_request(self, request: Request) -> Response: + scheme = request.url.scheme + headers = request.headers + + if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": + subprotocols: list[str] = [] + if (subprotocols_header := headers.get("sec-websocket-protocol")) is not None: + subprotocols = subprotocols_header.split(",") + + scope: Scope = { + "type": "websocket", + "path": request.url.path, + "raw_path": request.url.raw_path, + "root_path": self.root_path, + "scheme": {"http": "ws", "https": "wss"}.get(scheme, scheme), + "query_string": request.url.query, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "client": self.client, + "server": (request.url.host, request.url.port), + "subprotocols": subprotocols, + } + return await self._handle_ws_request(request, scope) + + return await super().handle_async_request(request) + + async def _create_asgi_websocket_async_network_stream( + self, + *, + task_status: anyio.abc.TaskStatus[tuple[ASGIWebSocketAsyncNetworkStream, bytes]], + ) -> None: + stream = ASGIWebSocketAsyncNetworkStream( + self.app, + self.scope, + self._task_group, + self._initial_receive_timeout, + ) + assert self._exit_stack is not None + result = await self._exit_stack.enter_async_context(stream) + task_status.started(result) + + async def _handle_ws_request(self, request: Request, scope: Scope) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + self.scope = scope + stream, accept_response = await self._task_group.start(self._create_asgi_websocket_async_network_stream) + accept_response_lines = accept_response.decode("utf-8").splitlines() + headers = [ + typing.cast(tuple[str, str], line.split(": ", 1)) for line in accept_response_lines[1:] if line.strip() + ] + + return Response( + status_code=101, + headers=headers, + extensions={"network_stream": stream}, + ) diff --git a/src/httpx2/pyproject.toml b/src/httpx2/pyproject.toml index dc194f7f..5612cf18 100644 --- a/src/httpx2/pyproject.toml +++ b/src/httpx2/pyproject.toml @@ -46,9 +46,10 @@ dynamic = ["readme", "version", "dependencies"] dependencies = [ "truststore>=0.10", "httpcore2=={{ version }}", - "anyio", + "anyio>=4.10", "idna>=3.18", "typing_extensions>=4.5.0; python_version < '3.13'", + "wsproto>=1.2", ] [project.optional-dependencies] diff --git a/tests/httpx2/websockets/__init__.py b/tests/httpx2/websockets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py new file mode 100644 index 00000000..732bf1f7 --- /dev/null +++ b/tests/httpx2/websockets/conftest.py @@ -0,0 +1,67 @@ +import contextlib +import pathlib +import queue +import tempfile +import time +import typing +from unittest.mock import MagicMock + +import pytest +import uvicorn +from anyio.from_thread import start_blocking_portal +from starlette.applications import Starlette +from starlette.routing import WebSocketRoute +from starlette.websockets import WebSocket + +WebSocketEndpoint = typing.Callable[[WebSocket], typing.Awaitable[None]] + + +@pytest.fixture +def on_receive_message() -> MagicMock: + return MagicMock() + + +@pytest.fixture(params=("wsproto", "websockets-sansio")) +def websocket_implementation(request: pytest.FixtureRequest) -> typing.Literal["wsproto", "websockets-sansio"]: + return request.param # type: ignore[no-any-return] + + +class ServerFactoryFixture(typing.Protocol): + def __call__(self, endpoint: WebSocketEndpoint) -> contextlib.AbstractContextManager[str]: ... + + +@pytest.fixture +def server_factory(websocket_implementation: typing.Literal["wsproto", "websockets-sansio"]) -> ServerFactoryFixture: + @contextlib.contextmanager + def _server_factory(endpoint: WebSocketEndpoint) -> typing.Iterator[str]: + shutdown_queue: queue.Queue[bool] = queue.Queue() + + def create_app() -> Starlette: + routes = [ + WebSocketRoute("/ws", endpoint=endpoint), + ] + return Starlette(routes=routes) + + def create_server(app: Starlette, socket: str) -> uvicorn.Server: + config = uvicorn.Config(app, uds=socket, ws=websocket_implementation, lifespan="off") + return uvicorn.Server(config) + + def on_server_stopped(_task: object) -> None: + shutdown_queue.put(True) + + with start_blocking_portal(backend="asyncio") as portal: + with tempfile.TemporaryDirectory() as socket_directory: + socket = str(pathlib.Path(socket_directory) / "socket.sock") + app = create_app() + server = create_server(app, socket) + task = portal.start_task_soon(server.serve) + task.add_done_callback(on_server_stopped) + while not server.started and not task.done(): + time.sleep(0.01) + if task.done() and task.exception() is not None: # pragma: no cover + raise typing.cast(BaseException, task.exception()) + yield socket + server.should_exit = True + shutdown_queue.get(True) + + return _server_factory diff --git a/tests/httpx2/websockets/test_session.py b/tests/httpx2/websockets/test_session.py new file mode 100644 index 00000000..55fe977f --- /dev/null +++ b/tests/httpx2/websockets/test_session.py @@ -0,0 +1,833 @@ +import concurrent.futures +import queue +import threading +import time +from unittest.mock import MagicMock, call, patch + +import anyio +import pytest +import wsproto +from starlette.websockets import WebSocket, WebSocketDisconnect as StarletteWebSocketDisconnect + +import httpcore2 +import httpx2 +from httpcore2 import AsyncNetworkStream, NetworkStream +from httpx2 import ( + AsyncWebSocketSession, + WebSocketDisconnect, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketSession, + WebSocketUpgradeError, +) +from httpx2._websockets._session import JSONMode +from tests.httpx2.websockets.conftest import ServerFactoryFixture + + +@pytest.mark.anyio +async def test_upgrade_error() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(400) + + with httpx2.Client(base_url="http://localhost:8000", transport=httpx2.MockTransport(handler)) as client: + with pytest.raises(WebSocketUpgradeError): + with client.websocket("http://socket/ws"): + pass # pragma: no cover + + async with httpx2.AsyncClient(base_url="http://localhost:8000", transport=httpx2.MockTransport(handler)) as aclient: + with pytest.raises(WebSocketUpgradeError): + async with aclient.websocket("http://socket/ws"): + pass # pragma: no cover + + +def test_top_level_websocket() -> None: + with patch("httpx2._api.Client") as mock_client_cls: + mock_client = mock_client_cls.return_value.__enter__.return_value + with httpx2.websocket("ws://socket/ws", subprotocols=["custom_protocol"]): + pass + mock_client.websocket.assert_called_once() + assert mock_client.websocket.call_args[1]["subprotocols"] == ["custom_protocol"] + + +@pytest.mark.anyio +class TestSend: + async def test_send_error(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: + time.sleep(0.1) + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise httpcore2.WriteError() + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession(stream) as websocket_session: + websocket_session.send(wsproto.events.Ping()) + + async def test_async_send_error(self) -> None: + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: + await anyio.sleep(0.1) + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise httpcore2.WriteError() + + async def aclose(self) -> None: + self._should_close = True + + stream = AsyncMockNetworkStream() + with pytest.RaisesGroup(WebSocketNetworkError): + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.send(wsproto.events.Ping()) + + async def test_send( + self, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_text() + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + await aws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) + + on_receive_message.assert_has_calls([call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")]) + + async def test_send_text( + self, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_text() + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ws.send_text("CLIENT_MESSAGE") + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + await aws.send_text("CLIENT_MESSAGE") + + on_receive_message.assert_has_calls([call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")]) + + async def test_send_bytes( + self, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_bytes() + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ws.send_bytes(b"CLIENT_MESSAGE") + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + await aws.send_bytes(b"CLIENT_MESSAGE") + + on_receive_message.assert_has_calls([call(b"CLIENT_MESSAGE"), call(b"CLIENT_MESSAGE")]) + + @pytest.mark.parametrize("mode", ["text", "binary"]) + async def test_send_json( + self, + mode: JSONMode, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_json(mode=mode) + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + await aws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + + on_receive_message.assert_has_calls([call({"message": "CLIENT_MESSAGE"}), call({"message": "CLIENT_MESSAGE"})]) + + +@pytest.mark.anyio +class TestReceive: + async def test_receive_error(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession(stream) as websocket_session: + websocket_session.receive() + + def test_receive_closed_socket(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + return b"" + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession(stream) as websocket_session: + websocket_session.receive() + + def test_receive_timeout(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + time.sleep(0.2) + return b"" + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with pytest.raises(TimeoutError): + with WebSocketSession(stream) as websocket_session: + websocket_session.receive(timeout=0.1) + + async def test_async_receive_error(self) -> None: + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + pass + + stream = AsyncMockNetworkStream() + with pytest.RaisesGroup(WebSocketNetworkError): + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.receive() + + async def test_async_receive_closed_socket(self) -> None: + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + return b"" + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + pass + + stream = AsyncMockNetworkStream() + with pytest.RaisesGroup(WebSocketNetworkError): + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.receive() + + async def test_receive(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + event = ws.receive() + assert isinstance(event, wsproto.events.TextMessage) + assert event.data == "SERVER_MESSAGE" + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + event = await aws.receive() + assert isinstance(event, wsproto.events.TextMessage) + assert event.data == "SERVER_MESSAGE" + + @pytest.mark.parametrize( + "full_message,send_method", + [ + pytest.param(b"A" * 1024 * 4, "send_bytes", id="bytes"), + pytest.param("A" * 1024 * 4, "send_text", id="text"), + ], + ) + async def test_receive_oversized_message( + self, + full_message: str | bytes, + send_method: str, + server_factory: ServerFactoryFixture, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + method = getattr(websocket, send_method) + await method(full_message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws", max_message_size_bytes=1024) as ws: + event = ws.receive() + assert isinstance(event, wsproto.events.Message) + assert event.data == full_message + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws", max_message_size_bytes=1024) as aws: + event = await aws.receive() + assert isinstance(event, wsproto.events.Message) + assert event.data == full_message + + async def test_receive_text(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + data = ws.receive_text() + assert data == "SERVER_MESSAGE" + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + data = await aws.receive_text() + assert data == "SERVER_MESSAGE" + + async def test_receive_text_invalid_type(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_bytes(b"SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + with pytest.raises(WebSocketInvalidTypeReceived): + ws.receive_text() + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + with pytest.raises(WebSocketInvalidTypeReceived): + await aws.receive_text() + + async def test_receive_bytes(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_bytes(b"SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + data = ws.receive_bytes() + assert data == b"SERVER_MESSAGE" + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + data = await aws.receive_bytes() + assert data == b"SERVER_MESSAGE" + + async def test_receive_bytes_invalid_type(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + with pytest.raises(WebSocketInvalidTypeReceived): + ws.receive_bytes() + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + with pytest.raises(WebSocketInvalidTypeReceived): + await aws.receive_bytes() + + @pytest.mark.parametrize("mode", ["text", "binary"]) + async def test_receive_json(self, mode: JSONMode, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_json({"message": "SERVER_MESSAGE"}, mode=mode) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + data = ws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + data = await aws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} + + +@pytest.mark.anyio +class TestReceivePing: + async def test_receive_ping(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self.events_to_send = [ + wsproto.events.Ping(b"SERVER_PING"), + wsproto.events.CloseConnection(1000), + ] + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + event = self.events_to_send.pop(0) + return self.connection.send(event) + except IndexError: + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.connection.receive_data(buffer) + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with WebSocketSession(stream): + await anyio.sleep(0.1) + + received_events = list(stream.connection.events()) + assert received_events == [ + wsproto.events.Pong(b"SERVER_PING"), + wsproto.events.CloseConnection(1000, ""), + ] + + async def test_async_receive_ping(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self.events_to_send = [ + wsproto.events.Ping(b"SERVER_PING"), + wsproto.events.CloseConnection(1000), + ] + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + event = self.events_to_send.pop(0) + return self.connection.send(event) + except IndexError: + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.connection.receive_data(buffer) + + async def aclose(self) -> None: + pass + + stream = MockAsyncNetworkStream() + async with AsyncWebSocketSession(stream): + await anyio.sleep(0.1) + + received_events = list(stream.connection.events()) + assert received_events == [ + wsproto.events.Pong(b"SERVER_PING"), + wsproto.events.CloseConnection(1000, ""), + ] + + +@pytest.mark.anyio +class TestKeepalivePing: + async def test_keepalive_ping(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + self.ping_received = 0 + self.ping_answered = 0 + self.events_to_send: queue.Queue[wsproto.events.Event] = queue.Queue() + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: + try: + event = self.events_to_send.get_nowait() + self.ping_answered += 1 + return self.connection.send(event) + except queue.Empty: + pass + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + self.ping_received += 1 + self.events_to_send.put(event.response()) + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + with WebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ): + await anyio.sleep(0.2) + + assert stream.ping_received >= 1 + assert stream.ping_answered >= 1 + + async def test_keepalive_ping_timeout(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: + time.sleep(0.1) + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ) as websocket_session: + websocket_session.receive() + + async def test_async_keepalive_ping(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + self.ping_received = 0 + self.ping_answered = 0 + ( + self.send_events, + self.receive_events, + ) = anyio.create_memory_object_stream[wsproto.events.Event]() + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: + try: + event = self.receive_events.receive_nowait() + self.ping_answered += 1 + return self.connection.send(event) + except anyio.WouldBlock: + await anyio.sleep(0.1) + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + self.ping_received += 1 + await self.send_events.send(event.response()) + + async def aclose(self) -> None: + self._should_close = True + self.send_events.close() + self.receive_events.close() + + stream = MockAsyncNetworkStream() + async with AsyncWebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ): + await anyio.sleep(0.3) + + assert stream.ping_received >= 1 + assert stream.ping_answered >= 1 + + async def test_async_keepalive_ping_timeout(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: + await anyio.sleep(0.1) + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + self._should_close = True + + stream = MockAsyncNetworkStream() + with pytest.RaisesGroup(WebSocketNetworkError): + async with AsyncWebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ) as websocket_session: + await websocket_session.receive() + + +@pytest.mark.anyio +async def test_ping_pong(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + try: + await websocket.receive_text() + except StarletteWebSocketDisconnect: + pass + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ping_callback = ws.ping() + result = ping_callback.wait() + assert result is True + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + aping_callback = await aws.ping() + await aping_callback.wait() + assert aping_callback.is_set() + + +@pytest.mark.anyio +async def test_send_close(server_factory: ServerFactoryFixture, on_receive_message: MagicMock) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + try: + await websocket.receive_text() + except StarletteWebSocketDisconnect as e: + on_receive_message(e.code, e.reason) + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ws.close(code=1001, reason="CLOSE_REASON") + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + await aws.close(code=1001, reason="CLOSE_REASON") + + on_receive_message.assert_has_calls([call(1001, "CLOSE_REASON"), call(1001, "CLOSE_REASON")]) + + +@pytest.mark.anyio +async def test_receive_close(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + with pytest.raises(WebSocketDisconnect): + ws.receive() + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + with pytest.raises(WebSocketDisconnect): + await aws.receive() + + +@pytest.mark.anyio +async def test_subprotocol_and_response() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + assert request.headers["sec-websocket-protocol"] == "custom_protocol, unsupported_protocol" + + return httpx2.Response( + 101, + headers={"sec-websocket-protocol": "custom_protocol"}, + extensions={"network_stream": MagicMock(spec=NetworkStream)}, + ) + + def async_handler(request: httpx2.Request) -> httpx2.Response: + assert request.headers["sec-websocket-protocol"] == "custom_protocol, unsupported_protocol" + + return httpx2.Response( + 101, + headers={"sec-websocket-protocol": "custom_protocol"}, + extensions={"network_stream": MagicMock(spec=AsyncNetworkStream)}, + ) + + with httpx2.Client(base_url="http://localhost:8000", transport=httpx2.MockTransport(handler)) as client: + with client.websocket( + "http://socket/ws", + subprotocols=["custom_protocol", "unsupported_protocol"], + ) as ws: + assert isinstance(ws.response, httpx2.Response) + assert ws.subprotocol == "custom_protocol" + assert ws.response.headers["sec-websocket-protocol"] == ws.subprotocol + + async with httpx2.AsyncClient( + base_url="http://localhost:8000", transport=httpx2.MockTransport(async_handler) + ) as aclient: + async with aclient.websocket( + "http://socket/ws", + subprotocols=["custom_protocol", "unsupported_protocol"], + ) as aws: + assert isinstance(aws.response, httpx2.Response) + assert aws.subprotocol == "custom_protocol" + assert aws.response.headers["sec-websocket-protocol"] == aws.subprotocol + + +@pytest.mark.anyio +async def test_threads_wont_hang(server_factory: ServerFactoryFixture) -> None: + """ + Check that all threads spawned in WebSocketSession are properly terminated during + a series of messages exchange. This used to be the cause of a memory leak in the + connect_ws client, see https://github.com/frankie567/httpx-ws/issues/76. + """ + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + for _ in range(50): + await websocket.send_text("SERVER_MESSAGE") + await websocket.receive_text() + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + initial_threads_count = threading.active_count() + with client.websocket("http://socket/ws", keepalive_ping_interval_seconds=None) as ws: + for _ in range(50): + ws.receive() + ws.send_text("CLIENT_MESSAGE") + time.sleep(0.1) # Let the websocket endpoint finish its handling. + threads_count = threading.active_count() + assert initial_threads_count + 2 == threads_count + time.sleep(0.1) + final_threads_count = threading.active_count() + assert initial_threads_count == final_threads_count + + +@pytest.mark.anyio +async def test_concurrency_write(server_factory: ServerFactoryFixture) -> None: + """ + Check that there is no error because of two tasks trying to write the stream at the + same time. Typically, this is when a background ping tries to send a ping while the + main task is sending a message. + + See: https://github.com/frankie567/httpx-ws/issues/29 + """ + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + async for message in websocket.iter_text(): + await websocket.send_text(message) + + with server_factory(websocket_endpoint) as socket: + # Added for completeness, but were not able to reproduce the issue with the sync client + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + for _ in range(10): + executor.submit(ws.send_text, "CLIENT_MESSAGE") + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + async with anyio.create_task_group() as tg: + for _ in range(10): + tg.start_soon(aws.send_text, "CLIENT_MESSAGE") + + +@pytest.mark.anyio +async def test_client_websocket_with_mock_stream() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(101, extensions={"network_stream": MagicMock(spec=NetworkStream)}) + + def async_handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(101, extensions={"network_stream": MagicMock(spec=AsyncNetworkStream)}) + + with httpx2.Client(base_url="http://localhost:8000", transport=httpx2.MockTransport(handler)) as client: + with client.websocket("http://socket/ws") as ws: + assert isinstance(ws.response, httpx2.Response) + + async with httpx2.AsyncClient( + base_url="http://localhost:8000", transport=httpx2.MockTransport(async_handler) + ) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + assert isinstance(aws.response, httpx2.Response) diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py new file mode 100644 index 00000000..49e0cd09 --- /dev/null +++ b/tests/httpx2/websockets/test_transport.py @@ -0,0 +1,379 @@ +import base64 +import secrets +import sys +import typing + +import anyio +import pytest +import wsproto +from anyio import CancelScope, ClosedResourceError, create_task_group +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import PlainTextResponse +from starlette.routing import Route, WebSocketRoute +from starlette.websockets import WebSocket + +import httpx2 +from httpx2 import ASGIWebSocketTransport, WebSocketDisconnect, WebSocketUpgradeError +from httpx2._websockets._transport import ( + ASGIWebSocketAsyncNetworkStream, + Message, + Receive, + Scope, + Send, + UnhandledASGIMessageType, + UnhandledWebSocketEvent, +) + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + + +@pytest.fixture +def websocket_request_headers() -> dict[str, str]: + return { + "connection": "upgrade", + "upgrade": "websocket", + "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), + "sec-websocket-version": "13", + } + + +@pytest.fixture +def scope(websocket_request_headers: dict[str, str]) -> Scope: + return { + "type": "websocket", + "path": "/ws", + "raw_path": "/ws", + "root_path": "/", + "scheme": "ws", + "headers": [ + ("host", "localhost"), + *websocket_request_headers.items(), + ], + "subprotocols": [], + "server": ("localhost", 8000), + } + + +@pytest.mark.anyio +class TestASGIWebSocketAsyncNetworkStream: + async def test_write(self, scope: Scope) -> None: + received_messages: list[Message] = [] + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + message = await receive() + received_messages.append(message) + while message["type"] != "websocket.disconnect": + message = await receive() + received_messages.append(message) + + connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + text_event = wsproto.events.TextMessage("CLIENT_MESSAGE") + await stream.write(connection.send(text_event)) + + bytes_event = wsproto.events.BytesMessage(b"CLIENT_MESSAGE") + await stream.write(connection.send(bytes_event)) + + close_event = wsproto.events.CloseConnection(1000) + await stream.write(connection.send(close_event)) + + # Add a small delay to ensure the app has processed all messages + await anyio.sleep(0.1) + + assert received_messages == [ + {"type": "websocket.connect"}, + {"type": "websocket.receive", "text": "CLIENT_MESSAGE"}, + {"type": "websocket.receive", "bytes": b"CLIENT_MESSAGE"}, + {"type": "websocket.disconnect", "code": 1000, "reason": ""}, + ] + + async def test_write_unhandled_event(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await receive() + + connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + with pytest.raises(UnhandledWebSocketEvent): + ping_event = wsproto.events.Ping(b"PING") + await stream.write(connection.send(ping_event)) + + async def test_read(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await send({"type": "websocket.send", "text": "SERVER_MESSAGE"}) + await send({"type": "websocket.send", "bytes": b"SERVER_MESSAGE"}) + await send({"type": "websocket.close", "code": 1000, "reason": ""}) + + connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + for _ in range(3): + data = await stream.read(4096) + connection.receive_data(data) + + events = list(connection.events()) + assert events == [ + wsproto.events.TextMessage("SERVER_MESSAGE"), + wsproto.events.BytesMessage(bytearray(b"SERVER_MESSAGE")), + wsproto.events.CloseConnection(1000, ""), + ] + + async def test_read_unhandled_asgi_message(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await send({"type": "websocket.foo"}) + + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + with pytest.raises(UnhandledASGIMessageType): + await stream.read(4096) + + async def test_close_immediately(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.close", "code": 1000, "reason": ""}) + + with pytest.raises(ExceptionGroup) as excinfo: + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg), + ): + pass # pragma: no cover + assert excinfo.group_contains(WebSocketDisconnect) + + async def test_denial_response(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.http.response.start", "status": 401, "headers": []}) + await send({"type": "websocket.http.response.body", "body": b"Unauthorized"}) + + with pytest.raises(ExceptionGroup) as excinfo: + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg), + ): + pass # pragma: no cover + assert excinfo.group_contains(WebSocketUpgradeError) + upgrade_error = excinfo.value.exceptions[0] + assert isinstance(upgrade_error, WebSocketUpgradeError) + assert upgrade_error.response.status_code == 401 + assert upgrade_error.response.content == b"Unauthorized" + + async def test_exception(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + raise Exception("Error") + + with pytest.raises(ExceptionGroup) as excinfo: + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg), + ): + pass # pragma: no cover + assert excinfo.group_contains(WebSocketDisconnect) + disconnect = excinfo.value.exceptions[0] + assert isinstance(disconnect, WebSocketDisconnect) + assert disconnect.code == 1011 + assert disconnect.reason == "Error" + + async def test_never_accepts(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + return + + with pytest.raises(ExceptionGroup) as excinfo: + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg), + ): + pass # pragma: no cover + + assert excinfo.group_contains(RuntimeError) + + async def test_app_exception_with_closed_send_queue(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await receive() + raise Exception("App error") + + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + await stream._send_queue.aclose() + await stream.send({"type": "websocket.receive", "text": "trigger"}) + + +@pytest.fixture +def test_app() -> Starlette: + async def http_endpoint(request: Request) -> PlainTextResponse: + return PlainTextResponse("Hello, world!") + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.receive_text() + await websocket.close() + + routes = [ + Route("/http", endpoint=http_endpoint), + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + + return Starlette(routes=routes) + + +@pytest.mark.anyio +class TestASGIWebSocketTransport: + async def test_http(self, test_app: Starlette) -> None: + async with ASGIWebSocketTransport(app=test_app) as transport: + request = httpx2.Request("GET", "http://localhost:8000/http") + response = await transport.handle_async_request(request) + assert response.status_code == 200 + + @pytest.mark.parametrize( + "url,headers", + [ + ("ws://localhost:8000/ws", {}), + ("wss://localhost:8000/ws", {}), + ("http://localhost:8000/ws", {"upgrade": "websocket"}), + ], + ) + async def test_websocket( + self, + url: str, + headers: dict[str, typing.Any], + test_app: Starlette, + websocket_request_headers: dict[str, str], + ) -> None: + async with ASGIWebSocketTransport(app=test_app) as transport: + request = httpx2.Request("GET", url, headers={**websocket_request_headers, **headers}) + response = await transport.handle_async_request(request) + assert response.status_code == 101 + + assert isinstance(response.extensions["network_stream"], ASGIWebSocketAsyncNetworkStream) + + @pytest.mark.parametrize("stream_count", [1, 3]) + async def test_transport_exit_closes_stream_queues( + self, + stream_count: int, + test_app: Starlette, + websocket_request_headers: dict[str, str], + ) -> None: + async with ASGIWebSocketTransport(app=test_app) as transport: + streams: list[ASGIWebSocketAsyncNetworkStream] = [] + for _ in range(stream_count): + request = httpx2.Request( + "GET", + "ws://localhost:8000/ws", + headers=websocket_request_headers, + ) + response = await transport.handle_async_request(request) + streams.append(response.extensions["network_stream"]) + + for stream in streams: + with pytest.raises(ClosedResourceError): + await stream._receive_queue.send({}) + with pytest.raises(ClosedResourceError): + await stream._send_queue.send({}) + + async def test_aclose_after_transport_exit_does_not_raise( + self, + test_app: Starlette, + websocket_request_headers: dict[str, str], + ) -> None: + async with ASGIWebSocketTransport(app=test_app) as transport: + request = httpx2.Request("GET", "ws://localhost:8000/ws", headers=websocket_request_headers) + response = await transport.handle_async_request(request) + stream = response.extensions["network_stream"] + + await stream.aclose() + + +@pytest.mark.anyio +async def test_subprotocol_support() -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept("custom_protocol") + assert websocket.scope.get("subprotocols") == ["custom_protocol"] + await websocket.send_text("SERVER_MESSAGE") + await websocket.close() + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx2.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + async with client.websocket("ws://localhost:8000/ws", subprotocols=["custom_protocol"]) as ws: + await ws.receive_text() + assert ws.subprotocol == "custom_protocol" + + +@pytest.mark.anyio +async def test_keepalive_ping_disabled() -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.receive_text() + await websocket.close() + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx2.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + async with client.websocket("ws://localhost:8000/ws") as ws: + assert ws._keepalive_ping_interval_seconds is None + + +@pytest.mark.anyio +async def test_cancel_scope_integrity() -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.receive_text() + await websocket.close() + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx2.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + with CancelScope(): + async with client.websocket("ws://localhost:8000/ws"): + pass + + +@pytest.mark.anyio +async def test_receive() -> None: + messages: list[str] = [] + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + messages.append(await websocket.receive_text()) + await websocket.close() + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx2.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + async with client.websocket("ws://localhost:8000/ws") as ws: + await ws.send_text("RESULT") + + assert len(messages) == 1 + assert messages[0] == "RESULT" diff --git a/uv.lock b/uv.lock index ef708e99..a4e62128 100644 --- a/uv.lock +++ b/uv.lock @@ -41,11 +41,13 @@ dev = [ { name = "pytest-httpbin", specifier = "==2.0.0" }, { name = "pytest-trio", specifier = "==0.8.0" }, { name = "ruff", specifier = "==0.15.13" }, + { name = "starlette", specifier = ">=0.49" }, { name = "trio", specifier = "==0.31.0" }, { name = "trio-typing", specifier = "==0.10.0" }, { name = "trustme", specifier = "==1.2.1" }, { name = "twine", specifier = "==6.1.0" }, { name = "uvicorn", specifier = ">=0.35" }, + { name = "websockets", specifier = ">=15" }, { name = "werkzeug", specifier = ">=3.1.6" }, ] docs = [ @@ -1354,6 +1356,7 @@ dependencies = [ { name = "idna" }, { name = "truststore" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "wsproto" }, ] [package.optional-dependencies] @@ -1378,7 +1381,7 @@ zstd = [ [package.metadata] requires-dist = [ - { name = "anyio" }, + { name = "anyio", specifier = ">=4.10" }, { name = "brotli", marker = "platform_python_implementation == 'CPython' and extra == 'brotli'" }, { name = "brotlicffi", marker = "platform_python_implementation != 'CPython' and extra == 'brotli'" }, { name = "click", marker = "extra == 'cli'", specifier = "==8.*" }, @@ -1390,6 +1393,7 @@ requires-dist = [ { name = "socksio", marker = "extra == 'socks'", specifier = "==1.*" }, { name = "truststore", specifier = ">=0.10" }, { name = "typing-extensions", marker = "python_full_version < '3.13'", specifier = ">=4.5.0" }, + { name = "wsproto", specifier = ">=1.2" }, { name = "zstandard", marker = "python_full_version < '3.14' and extra == 'zstd'", specifier = ">=0.18.0" }, ] provides-extras = ["brotli", "cli", "http2", "socks", "zstd"] @@ -3175,6 +3179,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, ] +[[package]] +name = "starlette" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/44/ec35f1b6e83094b997da438a02c8c9b0ade2b1e84cfc48bd4656780760a6/starlette-1.2.1.tar.gz", hash = "sha256:9b9b5ebb992e67d6093741e63c2f59e4f6fff986f81163c087867bd7b924b3f6", size = 2701854, upload-time = "2026-05-31T01:07:51.847Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/54/196d0c1db10af76baa4f64894448505d60d3cdf70ef92cbb35f46a4e4c71/starlette-1.2.1-py3-none-any.whl", hash = "sha256:4de0082d08c8f6764a85a54cf1120d6939507a19905c7768acad2a9f875d2b89", size = 73350, upload-time = "2026-05-31T01:07:50.09Z" }, +] + [[package]] name = "tomli" version = "2.4.1" @@ -3370,6 +3387,74 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "websockets" +version = "16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/74/221f58decd852f4b59cc3354cccaf87e8ef695fede361d03dc9a7396573b/websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a", size = 177343, upload-time = "2026-01-10T09:22:21.28Z" }, + { url = "https://files.pythonhosted.org/packages/19/0f/22ef6107ee52ab7f0b710d55d36f5a5d3ef19e8a205541a6d7ffa7994e5a/websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0", size = 175021, upload-time = "2026-01-10T09:22:22.696Z" }, + { url = "https://files.pythonhosted.org/packages/10/40/904a4cb30d9b61c0e278899bf36342e9b0208eb3c470324a9ecbaac2a30f/websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957", size = 175320, upload-time = "2026-01-10T09:22:23.94Z" }, + { url = "https://files.pythonhosted.org/packages/9d/2f/4b3ca7e106bc608744b1cdae041e005e446124bebb037b18799c2d356864/websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72", size = 183815, upload-time = "2026-01-10T09:22:25.469Z" }, + { url = "https://files.pythonhosted.org/packages/86/26/d40eaa2a46d4302becec8d15b0fc5e45bdde05191e7628405a19cf491ccd/websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde", size = 185054, upload-time = "2026-01-10T09:22:27.101Z" }, + { url = "https://files.pythonhosted.org/packages/b0/ba/6500a0efc94f7373ee8fefa8c271acdfd4dca8bd49a90d4be7ccabfc397e/websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3", size = 184565, upload-time = "2026-01-10T09:22:28.293Z" }, + { url = "https://files.pythonhosted.org/packages/04/b4/96bf2cee7c8d8102389374a2616200574f5f01128d1082f44102140344cc/websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3", size = 183848, upload-time = "2026-01-10T09:22:30.394Z" }, + { url = "https://files.pythonhosted.org/packages/02/8e/81f40fb00fd125357814e8c3025738fc4ffc3da4b6b4a4472a82ba304b41/websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9", size = 178249, upload-time = "2026-01-10T09:22:32.083Z" }, + { url = "https://files.pythonhosted.org/packages/b4/5f/7e40efe8df57db9b91c88a43690ac66f7b7aa73a11aa6a66b927e44f26fa/websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35", size = 178685, upload-time = "2026-01-10T09:22:33.345Z" }, + { url = "https://files.pythonhosted.org/packages/f2/db/de907251b4ff46ae804ad0409809504153b3f30984daf82a1d84a9875830/websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8", size = 177340, upload-time = "2026-01-10T09:22:34.539Z" }, + { url = "https://files.pythonhosted.org/packages/f3/fa/abe89019d8d8815c8781e90d697dec52523fb8ebe308bf11664e8de1877e/websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad", size = 175022, upload-time = "2026-01-10T09:22:36.332Z" }, + { url = "https://files.pythonhosted.org/packages/58/5d/88ea17ed1ded2079358b40d31d48abe90a73c9e5819dbcde1606e991e2ad/websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d", size = 175319, upload-time = "2026-01-10T09:22:37.602Z" }, + { url = "https://files.pythonhosted.org/packages/d2/ae/0ee92b33087a33632f37a635e11e1d99d429d3d323329675a6022312aac2/websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe", size = 184631, upload-time = "2026-01-10T09:22:38.789Z" }, + { url = "https://files.pythonhosted.org/packages/c8/c5/27178df583b6c5b31b29f526ba2da5e2f864ecc79c99dae630a85d68c304/websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b", size = 185870, upload-time = "2026-01-10T09:22:39.893Z" }, + { url = "https://files.pythonhosted.org/packages/87/05/536652aa84ddc1c018dbb7e2c4cbcd0db884580bf8e95aece7593fde526f/websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5", size = 185361, upload-time = "2026-01-10T09:22:41.016Z" }, + { url = "https://files.pythonhosted.org/packages/6d/e2/d5332c90da12b1e01f06fb1b85c50cfc489783076547415bf9f0a659ec19/websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64", size = 184615, upload-time = "2026-01-10T09:22:42.442Z" }, + { url = "https://files.pythonhosted.org/packages/77/fb/d3f9576691cae9253b51555f841bc6600bf0a983a461c79500ace5a5b364/websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6", size = 178246, upload-time = "2026-01-10T09:22:43.654Z" }, + { url = "https://files.pythonhosted.org/packages/54/67/eaff76b3dbaf18dcddabc3b8c1dba50b483761cccff67793897945b37408/websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac", size = 178684, upload-time = "2026-01-10T09:22:44.941Z" }, + { url = "https://files.pythonhosted.org/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00", size = 177365, upload-time = "2026-01-10T09:22:46.787Z" }, + { url = "https://files.pythonhosted.org/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79", size = 175038, upload-time = "2026-01-10T09:22:47.999Z" }, + { url = "https://files.pythonhosted.org/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39", size = 175328, upload-time = "2026-01-10T09:22:49.809Z" }, + { url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" }, + { url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" }, + { url = "https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" }, + { url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" }, + { url = "https://files.pythonhosted.org/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89", size = 178261, upload-time = "2026-01-10T09:22:56.251Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea", size = 178693, upload-time = "2026-01-10T09:22:57.478Z" }, + { url = "https://files.pythonhosted.org/packages/cc/9c/baa8456050d1c1b08dd0ec7346026668cbc6f145ab4e314d707bb845bf0d/websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9", size = 177364, upload-time = "2026-01-10T09:22:59.333Z" }, + { url = "https://files.pythonhosted.org/packages/7e/0c/8811fc53e9bcff68fe7de2bcbe75116a8d959ac699a3200f4847a8925210/websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230", size = 175039, upload-time = "2026-01-10T09:23:01.171Z" }, + { url = "https://files.pythonhosted.org/packages/aa/82/39a5f910cb99ec0b59e482971238c845af9220d3ab9fa76dd9162cda9d62/websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c", size = 175323, upload-time = "2026-01-10T09:23:02.341Z" }, + { url = "https://files.pythonhosted.org/packages/bd/28/0a25ee5342eb5d5f297d992a77e56892ecb65e7854c7898fb7d35e9b33bd/websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5", size = 184975, upload-time = "2026-01-10T09:23:03.756Z" }, + { url = "https://files.pythonhosted.org/packages/f9/66/27ea52741752f5107c2e41fda05e8395a682a1e11c4e592a809a90c6a506/websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82", size = 186203, upload-time = "2026-01-10T09:23:05.01Z" }, + { url = "https://files.pythonhosted.org/packages/37/e5/8e32857371406a757816a2b471939d51c463509be73fa538216ea52b792a/websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8", size = 185653, upload-time = "2026-01-10T09:23:06.301Z" }, + { url = "https://files.pythonhosted.org/packages/9b/67/f926bac29882894669368dc73f4da900fcdf47955d0a0185d60103df5737/websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f", size = 184920, upload-time = "2026-01-10T09:23:07.492Z" }, + { url = "https://files.pythonhosted.org/packages/3c/a1/3d6ccdcd125b0a42a311bcd15a7f705d688f73b2a22d8cf1c0875d35d34a/websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a", size = 178255, upload-time = "2026-01-10T09:23:09.245Z" }, + { url = "https://files.pythonhosted.org/packages/6b/ae/90366304d7c2ce80f9b826096a9e9048b4bb760e44d3b873bb272cba696b/websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156", size = 178689, upload-time = "2026-01-10T09:23:10.483Z" }, + { url = "https://files.pythonhosted.org/packages/f3/1d/e88022630271f5bd349ed82417136281931e558d628dd52c4d8621b4a0b2/websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0", size = 177406, upload-time = "2026-01-10T09:23:12.178Z" }, + { url = "https://files.pythonhosted.org/packages/f2/78/e63be1bf0724eeb4616efb1ae1c9044f7c3953b7957799abb5915bffd38e/websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904", size = 175085, upload-time = "2026-01-10T09:23:13.511Z" }, + { url = "https://files.pythonhosted.org/packages/bb/f4/d3c9220d818ee955ae390cf319a7c7a467beceb24f05ee7aaaa2414345ba/websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4", size = 175328, upload-time = "2026-01-10T09:23:14.727Z" }, + { url = "https://files.pythonhosted.org/packages/63/bc/d3e208028de777087e6fb2b122051a6ff7bbcca0d6df9d9c2bf1dd869ae9/websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e", size = 185044, upload-time = "2026-01-10T09:23:15.939Z" }, + { url = "https://files.pythonhosted.org/packages/ad/6e/9a0927ac24bd33a0a9af834d89e0abc7cfd8e13bed17a86407a66773cc0e/websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4", size = 186279, upload-time = "2026-01-10T09:23:17.148Z" }, + { url = "https://files.pythonhosted.org/packages/b9/ca/bf1c68440d7a868180e11be653c85959502efd3a709323230314fda6e0b3/websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1", size = 185711, upload-time = "2026-01-10T09:23:18.372Z" }, + { url = "https://files.pythonhosted.org/packages/c4/f8/fdc34643a989561f217bb477cbc47a3a07212cbda91c0e4389c43c296ebf/websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3", size = 184982, upload-time = "2026-01-10T09:23:19.652Z" }, + { url = "https://files.pythonhosted.org/packages/dd/d1/574fa27e233764dbac9c52730d63fcf2823b16f0856b3329fc6268d6ae4f/websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8", size = 177915, upload-time = "2026-01-10T09:23:21.458Z" }, + { url = "https://files.pythonhosted.org/packages/8a/f1/ae6b937bf3126b5134ce1f482365fde31a357c784ac51852978768b5eff4/websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d", size = 178381, upload-time = "2026-01-10T09:23:22.715Z" }, + { url = "https://files.pythonhosted.org/packages/06/9b/f791d1db48403e1f0a27577a6beb37afae94254a8c6f08be4a23e4930bc0/websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244", size = 177737, upload-time = "2026-01-10T09:23:24.523Z" }, + { url = "https://files.pythonhosted.org/packages/bd/40/53ad02341fa33b3ce489023f635367a4ac98b73570102ad2cdd770dacc9a/websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e", size = 175268, upload-time = "2026-01-10T09:23:25.781Z" }, + { url = "https://files.pythonhosted.org/packages/74/9b/6158d4e459b984f949dcbbb0c5d270154c7618e11c01029b9bbd1bb4c4f9/websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641", size = 175486, upload-time = "2026-01-10T09:23:27.033Z" }, + { url = "https://files.pythonhosted.org/packages/e5/2d/7583b30208b639c8090206f95073646c2c9ffd66f44df967981a64f849ad/websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8", size = 185331, upload-time = "2026-01-10T09:23:28.259Z" }, + { url = "https://files.pythonhosted.org/packages/45/b0/cce3784eb519b7b5ad680d14b9673a31ab8dcb7aad8b64d81709d2430aa8/websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e", size = 186501, upload-time = "2026-01-10T09:23:29.449Z" }, + { url = "https://files.pythonhosted.org/packages/19/60/b8ebe4c7e89fb5f6cdf080623c9d92789a53636950f7abacfc33fe2b3135/websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944", size = 186062, upload-time = "2026-01-10T09:23:31.368Z" }, + { url = "https://files.pythonhosted.org/packages/88/a8/a080593f89b0138b6cba1b28f8df5673b5506f72879322288b031337c0b8/websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206", size = 185356, upload-time = "2026-01-10T09:23:32.627Z" }, + { url = "https://files.pythonhosted.org/packages/c2/b6/b9afed2afadddaf5ebb2afa801abf4b0868f42f8539bfe4b071b5266c9fe/websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6", size = 178085, upload-time = "2026-01-10T09:23:33.816Z" }, + { url = "https://files.pythonhosted.org/packages/9f/3e/28135a24e384493fa804216b79a6a6759a38cc4ff59118787b9fb693df93/websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd", size = 178531, upload-time = "2026-01-10T09:23:35.016Z" }, + { url = "https://files.pythonhosted.org/packages/72/07/c98a68571dcf256e74f1f816b8cc5eae6eb2d3d5cfa44d37f801619d9166/websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d", size = 174947, upload-time = "2026-01-10T09:23:36.166Z" }, + { url = "https://files.pythonhosted.org/packages/7e/52/93e166a81e0305b33fe416338be92ae863563fe7bce446b0f687b9df5aea/websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03", size = 175260, upload-time = "2026-01-10T09:23:37.409Z" }, + { url = "https://files.pythonhosted.org/packages/56/0c/2dbf513bafd24889d33de2ff0368190a0e69f37bcfa19009ef819fe4d507/websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da", size = 176071, upload-time = "2026-01-10T09:23:39.158Z" }, + { url = "https://files.pythonhosted.org/packages/a5/8f/aea9c71cc92bf9b6cc0f7f70df8f0b420636b6c96ef4feee1e16f80f75dd/websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c", size = 176968, upload-time = "2026-01-10T09:23:41.031Z" }, + { url = "https://files.pythonhosted.org/packages/9a/3f/f70e03f40ffc9a30d817eef7da1be72ee4956ba8d7255c399a01b135902a/websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767", size = 178735, upload-time = "2026-01-10T09:23:42.259Z" }, + { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, +] + [[package]] name = "werkzeug" version = "3.1.8" @@ -3382,6 +3467,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/93/8c/2e650f2afeb7ee576912636c23ddb621c91ac6a98e66dc8d29c3c69446e1/werkzeug-3.1.8-py3-none-any.whl", hash = "sha256:63a77fb8892bf28ebc3178683445222aa500e48ebad5ec77b0ad80f8726b1f50", size = 226459, upload-time = "2026-04-02T18:49:12.72Z" }, ] +[[package]] +name = "wsproto" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" }, +] + [[package]] name = "yarl" version = "1.23.0" From d37fd914d91e1e3b12a228ac8a3f7cd77d84a5b2 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 12 Jun 2026 11:16:54 +0200 Subject: [PATCH 2/3] Migrate WebSocket support from wsproto to the websockets package Replace the wsproto state machine with the websockets sans-IO Protocol in both the sessions and the ASGI transport. The raw event-based send()/receive() API is replaced by a message-level API: receive() now returns str | bytes, keeping the sans-IO library out of the public surface. Also set ws="none" on the HTTP test server: with websockets installed, uvicorn's ws="auto" selects the deprecated websockets.legacy implementation, whose import-time DeprecationWarning is an error under filterwarnings=error and silently kills the server thread. --- pyproject.toml | 2 +- src/httpx2/httpx2/_websockets/_exceptions.py | 8 +- src/httpx2/httpx2/_websockets/_ping.py | 4 +- src/httpx2/httpx2/_websockets/_session.py | 382 +++++++++---------- src/httpx2/httpx2/_websockets/_transport.py | 76 ++-- src/httpx2/pyproject.toml | 2 +- tests/httpx2/conftest.py | 4 +- tests/httpx2/websockets/test_session.py | 326 ++++++++++------ tests/httpx2/websockets/test_transport.py | 69 ++-- uv.lock | 6 +- 10 files changed, 492 insertions(+), 387 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c1b3f74d..1b3ddcd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dev = [ "trio-typing==0.10.0", "trustme==1.2.1", "uvicorn>=0.35", - "websockets>=15", + "wsproto>=1.2", "werkzeug>=3.1.6", # Linting "mypy==1.17.1", diff --git a/src/httpx2/httpx2/_websockets/_exceptions.py b/src/httpx2/httpx2/_websockets/_exceptions.py index 28d5aa35..a3ed3fa3 100644 --- a/src/httpx2/httpx2/_websockets/_exceptions.py +++ b/src/httpx2/httpx2/_websockets/_exceptions.py @@ -13,8 +13,6 @@ import typing if typing.TYPE_CHECKING: - import wsproto - from .._models import Response # pragma: no cover __all__ = [ @@ -53,11 +51,11 @@ def __init__(self, code: int = 1000, reason: str | None = None) -> None: class WebSocketInvalidTypeReceived(WebSocketException): """ - A received event was not of the expected type. + A received message was not of the expected type. """ - def __init__(self, event: wsproto.events.Event) -> None: - self.event = event + def __init__(self, message: str | bytes) -> None: + self.message = message class WebSocketNetworkError(WebSocketException): diff --git a/src/httpx2/httpx2/_websockets/_ping.py b/src/httpx2/httpx2/_websockets/_ping.py index 1aab637e..f7460e67 100644 --- a/src/httpx2/httpx2/_websockets/_ping.py +++ b/src/httpx2/httpx2/_websockets/_ping.py @@ -16,7 +16,7 @@ def create(self, ping_id: bytes | None = None) -> tuple[bytes, threading.Event]: self._pings[ping_id] = event return ping_id, event - def ack(self, ping_id: bytes | bytearray) -> None: + def ack(self, ping_id: bytes | bytearray | memoryview) -> None: event = self._pings.pop(bytes(ping_id)) event.set() @@ -31,6 +31,6 @@ def create(self, ping_id: bytes | None = None) -> tuple[bytes, anyio.Event]: self._pings[ping_id] = event return ping_id, event - def ack(self, ping_id: bytes | bytearray) -> None: + def ack(self, ping_id: bytes | bytearray | memoryview) -> None: event = self._pings.pop(bytes(ping_id)) event.set() diff --git a/src/httpx2/httpx2/_websockets/_session.py b/src/httpx2/httpx2/_websockets/_session.py index 5e6f2846..18eb7603 100644 --- a/src/httpx2/httpx2/_websockets/_session.py +++ b/src/httpx2/httpx2/_websockets/_session.py @@ -12,10 +12,10 @@ from types import TracebackType import anyio -import wsproto -import wsproto.utilities from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from wsproto.frame_protocol import CloseReason +from websockets.exceptions import InvalidState +from websockets.frames import Close, Frame, Opcode +from websockets.protocol import Protocol, Side, State import httpcore2 from httpcore2 import AsyncNetworkStream, NetworkStream @@ -45,6 +45,8 @@ DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = 20.0 DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 +INTERNAL_ERROR = 1011 + class ShouldClose(Exception): pass @@ -54,6 +56,28 @@ class EndOfStream(Exception): pass +class MessageAssembler: + """ + Assembles data frames, possibly fragmented, into complete messages. + """ + + def __init__(self) -> None: + self._buffer = bytearray() + self._text = False + + def feed(self, frame: Frame) -> str | bytes | None: + if frame.opcode is Opcode.TEXT or frame.opcode is Opcode.BINARY: + self._buffer = bytearray(frame.data) + self._text = frame.opcode is Opcode.TEXT + else: + self._buffer += frame.data + if not frame.fin: + return None + data = bytes(self._buffer) + self._buffer = bytearray() + return data.decode("utf-8") if self._text else data + + class WebSocketSession: """ Sync context manager representing an opened WebSocket session. @@ -77,14 +101,15 @@ def __init__( response: Response | None = None, ) -> None: self.stream = stream - self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) + self.protocol = Protocol(Side.CLIENT, state=State.OPEN, max_size=None) self.response = response if self.response is not None: self.subprotocol = self.response.headers.get("sec-websocket-protocol") else: self.subprotocol = None - self._events: queue.Queue[wsproto.events.Event | WebSocketException] = queue.Queue(queue_size) + self._events: queue.Queue[str | bytes | WebSocketException] = queue.Queue(queue_size) + self._assembler = MessageAssembler() self._ping_manager = PingManager() self._should_close = threading.Event() @@ -151,35 +176,16 @@ def ping(self, payload: bytes = b"") -> threading.Event: ``` """ ping_id, callback = self._ping_manager.create(payload) - event = wsproto.events.Ping(ping_id) - self.send(event) + self._send(self.protocol.send_ping, ping_id) return callback - def send(self, event: wsproto.events.Event) -> None: - """ - Send a raw `wsproto.events.Event`. - - Mainly useful to send events that are not supported by the library. - Most of the time, `ping()`, `send_text()`, `send_bytes()` and `send_json()` are preferred. - - Raises `WebSocketNetworkError` if a network error occurred. - """ - try: - data = self.connection.send(event) - with self._write_lock: - self.stream.write(data) - except httpcore2.WriteError as e: - self.close(CloseReason.INTERNAL_ERROR, "Stream write error") - raise WebSocketNetworkError() from e - def send_text(self, data: str) -> None: """ Send a text message. Raises `WebSocketNetworkError` if a network error occurred. """ - event = wsproto.events.TextMessage(data=data) - self.send(event) + self._send(self.protocol.send_text, data.encode("utf-8")) def send_bytes(self, data: bytes) -> None: """ @@ -187,8 +193,7 @@ def send_bytes(self, data: bytes) -> None: Raises `WebSocketNetworkError` if a network error occurred. """ - event = wsproto.events.BytesMessage(data=data) - self.send(event) + self._send(self.protocol.send_binary, data) def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: """ @@ -203,17 +208,14 @@ def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: else: self.send_bytes(serialized_data.encode("utf-8")) - def receive(self, timeout: float | None = None) -> wsproto.events.Event: + def receive(self, timeout: float | None = None) -> str | bytes: """ - Receive a raw `wsproto.events.Event` from the server. + Receive a message from the server, either text or bytes. - Mainly useful to receive raw events. Most of the time, - `receive_text()`, `receive_bytes()`, and `receive_json()` are preferred. - - If `timeout` is `None`, this blocks until an event is available. + If `timeout` is `None`, this blocks until a message is available. Raises: - TimeoutError: No event was received before the timeout delay. + TimeoutError: No message was received before the timeout delay. WebSocketDisconnect: The server closed the WebSocket. WebSocketNetworkError: A network error occurred. """ @@ -223,55 +225,53 @@ def receive(self, timeout: float | None = None) -> wsproto.events.Event: raise TimeoutError from e if isinstance(event, WebSocketException): raise event - if isinstance(event, wsproto.events.CloseConnection): - raise WebSocketDisconnect(event.code, event.reason) return event def receive_text(self, timeout: float | None = None) -> str: """ Receive text from the server. - If `timeout` is `None`, this blocks until an event is available. + If `timeout` is `None`, this blocks until a message is available. Raises: - TimeoutError: No event was received before the timeout delay. + TimeoutError: No message was received before the timeout delay. WebSocketDisconnect: The server closed the WebSocket. WebSocketNetworkError: A network error occurred. - WebSocketInvalidTypeReceived: The received event was not a text message. + WebSocketInvalidTypeReceived: The received message was not a text message. """ - event = self.receive(timeout) - if isinstance(event, wsproto.events.TextMessage): - return event.data - raise WebSocketInvalidTypeReceived(event) + message = self.receive(timeout) + if isinstance(message, str): + return message + raise WebSocketInvalidTypeReceived(message) def receive_bytes(self, timeout: float | None = None) -> bytes: """ Receive bytes from the server. - If `timeout` is `None`, this blocks until an event is available. + If `timeout` is `None`, this blocks until a message is available. Raises: - TimeoutError: No event was received before the timeout delay. + TimeoutError: No message was received before the timeout delay. WebSocketDisconnect: The server closed the WebSocket. WebSocketNetworkError: A network error occurred. - WebSocketInvalidTypeReceived: The received event was not a bytes message. + WebSocketInvalidTypeReceived: The received message was not a bytes message. """ - event = self.receive(timeout) - if isinstance(event, wsproto.events.BytesMessage): - return bytes(event.data) - raise WebSocketInvalidTypeReceived(event) + message = self.receive(timeout) + if isinstance(message, bytes): + return message + raise WebSocketInvalidTypeReceived(message) def receive_json(self, timeout: float | None = None, mode: JSONMode = "text") -> typing.Any: """ Receive JSON data from the server, parsed with `json.loads()`, in `'text'` or `'binary'` mode. - If `timeout` is `None`, this blocks until an event is available. + If `timeout` is `None`, this blocks until a message is available. Raises: - TimeoutError: No event was received before the timeout delay. + TimeoutError: No message was received before the timeout delay. WebSocketDisconnect: The server closed the WebSocket. WebSocketNetworkError: A network error occurred. - WebSocketInvalidTypeReceived: The received event didn't correspond to the specified mode. + WebSocketInvalidTypeReceived: The received message didn't correspond to the specified mode. """ assert mode in ["text", "binary"] data: str | bytes @@ -285,72 +285,74 @@ def close(self, code: int = 1000, reason: str | None = None) -> None: """ Close the WebSocket session. - Internally, it'll send the `wsproto.events.CloseConnection` event. + Internally, it'll send a Close frame. *This method is automatically called when exiting the context manager.* """ self._should_close.set() if self._executor is not None: self._executor.shutdown(False) - if self.connection.state not in { - wsproto.connection.ConnectionState.LOCAL_CLOSING, - wsproto.connection.ConnectionState.CLOSED, - }: - event = wsproto.events.CloseConnection(code, reason) - data = self.connection.send(event) - try: - with self._write_lock: - self.stream.write(data) - except httpcore2.WriteError: - pass + try: + with self._write_lock: + if self.protocol.state is State.OPEN: + self.protocol.send_close(code, reason or "") + self._write_protocol_data() + except (httpcore2.WriteError, InvalidState): + pass self.stream.close() + def _send(self, send_event: typing.Callable[[bytes], None], data: bytes) -> None: + try: + with self._write_lock: + send_event(data) + self._write_protocol_data() + except httpcore2.WriteError as e: + self.close(INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e + + def _write_protocol_data(self) -> None: + for data in self.protocol.data_to_send(): + if data: + self.stream.write(data) + def _background_receive(self, max_bytes: int) -> None: """ Background thread listening for data from the server. Internally, it'll: - * Answer to Ping events. - * Acknowledge Pong events. - * Put other events in the `_events` queue that'll eventually be consumed by the user. + * Answer to Ping frames. + * Acknowledge Pong frames. + * Put messages in the `_events` queue that'll eventually be consumed by the user. """ - partial_message_buffer: str | bytes | None = None try: while not self._should_close.is_set(): data = self._wait_until_closed(self._read_stream, max_bytes) - self.connection.receive_data(data) - for event in self.connection.events(): - if isinstance(event, wsproto.events.Ping): - data = self.connection.send(event.response()) - with self._write_lock: - self.stream.write(data) + self.protocol.receive_data(data) + try: + with self._write_lock: + self._write_protocol_data() + except httpcore2.WriteError: + # Tolerate failing to reply once the peer started the closing handshake. + if self.protocol.state is State.OPEN: + raise + for frame in self.protocol.events_received(): + assert isinstance(frame, Frame) + if frame.opcode is Opcode.PING: continue - if isinstance(event, wsproto.events.Pong): - self._ping_manager.ack(event.payload) + if frame.opcode is Opcode.PONG: + self._ping_manager.ack(frame.data) continue - if isinstance(event, wsproto.events.CloseConnection): + if frame.opcode is Opcode.CLOSE: self._should_close.set() - if isinstance(event, wsproto.events.Message): - # Unfinished message: bufferize - if not event.message_finished: - if partial_message_buffer is None: - partial_message_buffer = event.data - else: - partial_message_buffer += event.data - # Finished message but no buffer: just emit the event - elif partial_message_buffer is None: - self._events.put(event) - # Finished message with buffer: emit the full event - else: - event_type = type(event) - full_message_event = event_type(partial_message_buffer + event.data) - partial_message_buffer = None - self._events.put(full_message_event) + close = Close.parse(frame.data) + self._events.put(WebSocketDisconnect(close.code, close.reason)) continue - self._events.put(event) + message = self._assembler.feed(frame) + if message is not None: + self._events.put(message) except (httpcore2.ReadError, httpcore2.WriteError, EndOfStream): - self.close(CloseReason.INTERNAL_ERROR, "Stream error") + self.close(INTERNAL_ERROR, "Stream error") self._events.put(WebSocketNetworkError()) except ShouldClose: pass @@ -361,11 +363,17 @@ def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: f should_close = self._wait_until_closed(self._should_close.wait, interval_seconds) if should_close: raise ShouldClose() - pong_callback = self.ping() + + try: + pong_callback = self.ping() + # Connection is closing, exit the task + except InvalidState: + return + if timeout_seconds is not None: acknowledged = self._wait_until_closed(pong_callback.wait, timeout_seconds) if not acknowledged: - self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") + self.close(INTERNAL_ERROR, "Keepalive ping timeout") self._events.put(WebSocketNetworkError()) except ShouldClose: pass @@ -430,8 +438,8 @@ class AsyncWebSocketSession(anyio.AsyncContextManagerMixin): subprotocol: str | None response: Response | None - _send_event: MemoryObjectSendStream[wsproto.events.Event | WebSocketException] - _receive_event: MemoryObjectReceiveStream[wsproto.events.Event | WebSocketException] + _send_event: MemoryObjectSendStream[str | bytes | WebSocketException] + _receive_event: MemoryObjectReceiveStream[str | bytes | WebSocketException] def __init__( self, @@ -444,7 +452,7 @@ def __init__( response: Response | None = None, ) -> None: self.stream = stream - self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) + self.protocol = Protocol(Side.CLIENT, state=State.OPEN, max_size=None) self.response = response if self.response is not None: self.subprotocol = self.response.headers.get("sec-websocket-protocol") @@ -454,6 +462,7 @@ def __init__( self._ping_manager = AsyncPingManager() self._should_close = anyio.Event() self._write_lock = anyio.Lock() + self._assembler = MessageAssembler() self._max_message_size_bytes = max_message_size_bytes self._queue_size = queue_size @@ -468,9 +477,7 @@ def __init__( @contextlib.asynccontextmanager async def __asynccontextmanager__(self) -> AsyncGenerator[AsyncWebSocketSession]: - self._send_event, self._receive_event = anyio.create_memory_object_stream[ - wsproto.events.Event | WebSocketException - ]() + self._send_event, self._receive_event = anyio.create_memory_object_stream[str | bytes | WebSocketException]() self._background_task_group = anyio.create_task_group() async with self._send_event, self._receive_event, self._background_task_group: @@ -504,35 +511,16 @@ async def ping(self, payload: bytes = b"") -> anyio.Event: ``` """ ping_id, callback = self._ping_manager.create(payload) - event = wsproto.events.Ping(ping_id) - await self.send(event) + await self._send(self.protocol.send_ping, ping_id) return callback - async def send(self, event: wsproto.events.Event) -> None: - """ - Send a raw `wsproto.events.Event`. - - Mainly useful to send events that are not supported by the library. - Most of the time, `ping()`, `send_text()`, `send_bytes()` and `send_json()` are preferred. - - Raises `WebSocketNetworkError` if a network error occurred. - """ - try: - data = self.connection.send(event) - async with self._write_lock: - await self.stream.write(data) - except httpcore2.WriteError as e: - await self.close(CloseReason.INTERNAL_ERROR, "Stream write error") - raise WebSocketNetworkError() from e - async def send_text(self, data: str) -> None: """ Send a text message. Raises `WebSocketNetworkError` if a network error occurred. """ - event = wsproto.events.TextMessage(data=data) - await self.send(event) + await self._send(self.protocol.send_text, data.encode("utf-8")) async def send_bytes(self, data: bytes) -> None: """ @@ -540,8 +528,7 @@ async def send_bytes(self, data: bytes) -> None: Raises `WebSocketNetworkError` if a network error occurred. """ - event = wsproto.events.BytesMessage(data=data) - await self.send(event) + await self._send(self.protocol.send_binary, data) async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: """ @@ -556,17 +543,14 @@ async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: else: await self.send_bytes(serialized_data.encode("utf-8")) - async def receive(self, timeout: float | None = None) -> wsproto.events.Event: + async def receive(self, timeout: float | None = None) -> str | bytes: """ - Receive a raw `wsproto.events.Event` from the server. - - Mainly useful to receive raw events. Most of the time, - `receive_text()`, `receive_bytes()`, and `receive_json()` are preferred. + Receive a message from the server, either text or bytes. - If `timeout` is `None`, this blocks until an event is available. + If `timeout` is `None`, this blocks until a message is available. Raises: - TimeoutError: No event was received before the timeout delay. + TimeoutError: No message was received before the timeout delay. WebSocketDisconnect: The server closed the WebSocket. WebSocketNetworkError: A network error occurred. """ @@ -574,55 +558,53 @@ async def receive(self, timeout: float | None = None) -> wsproto.events.Event: event = await self._receive_event.receive() if isinstance(event, WebSocketException): raise event - if isinstance(event, wsproto.events.CloseConnection): - raise WebSocketDisconnect(event.code, event.reason) return event async def receive_text(self, timeout: float | None = None) -> str: """ Receive text from the server. - If `timeout` is `None`, this blocks until an event is available. + If `timeout` is `None`, this blocks until a message is available. Raises: - TimeoutError: No event was received before the timeout delay. + TimeoutError: No message was received before the timeout delay. WebSocketDisconnect: The server closed the WebSocket. WebSocketNetworkError: A network error occurred. - WebSocketInvalidTypeReceived: The received event was not a text message. + WebSocketInvalidTypeReceived: The received message was not a text message. """ - event = await self.receive(timeout) - if isinstance(event, wsproto.events.TextMessage): - return event.data - raise WebSocketInvalidTypeReceived(event) + message = await self.receive(timeout) + if isinstance(message, str): + return message + raise WebSocketInvalidTypeReceived(message) async def receive_bytes(self, timeout: float | None = None) -> bytes: """ Receive bytes from the server. - If `timeout` is `None`, this blocks until an event is available. + If `timeout` is `None`, this blocks until a message is available. Raises: - TimeoutError: No event was received before the timeout delay. + TimeoutError: No message was received before the timeout delay. WebSocketDisconnect: The server closed the WebSocket. WebSocketNetworkError: A network error occurred. - WebSocketInvalidTypeReceived: The received event was not a bytes message. + WebSocketInvalidTypeReceived: The received message was not a bytes message. """ - event = await self.receive(timeout) - if isinstance(event, wsproto.events.BytesMessage): - return bytes(event.data) - raise WebSocketInvalidTypeReceived(event) + message = await self.receive(timeout) + if isinstance(message, bytes): + return message + raise WebSocketInvalidTypeReceived(message) async def receive_json(self, timeout: float | None = None, mode: JSONMode = "text") -> typing.Any: """ Receive JSON data from the server, parsed with `json.loads()`, in `'text'` or `'binary'` mode. - If `timeout` is `None`, this blocks until an event is available. + If `timeout` is `None`, this blocks until a message is available. Raises: - TimeoutError: No event was received before the timeout delay. + TimeoutError: No message was received before the timeout delay. WebSocketDisconnect: The server closed the WebSocket. WebSocketNetworkError: A network error occurred. - WebSocketInvalidTypeReceived: The received event didn't correspond to the specified mode. + WebSocketInvalidTypeReceived: The received message didn't correspond to the specified mode. """ assert mode in ["text", "binary"] data: str | bytes @@ -636,70 +618,72 @@ async def close(self, code: int = 1000, reason: str | None = None) -> None: """ Close the WebSocket session. - Internally, it'll send the `wsproto.events.CloseConnection` event. + Internally, it'll send a Close frame. *This method is automatically called when exiting the context manager.* """ self._should_close.set() - if self.connection.state not in { - wsproto.connection.ConnectionState.LOCAL_CLOSING, - wsproto.connection.ConnectionState.CLOSED, - }: - event = wsproto.events.CloseConnection(code, reason) - data = self.connection.send(event) - try: - async with self._write_lock: - await self.stream.write(data) - except httpcore2.WriteError: - pass + try: + async with self._write_lock: + if self.protocol.state is State.OPEN: + self.protocol.send_close(code, reason or "") + await self._write_protocol_data() + except (httpcore2.WriteError, InvalidState): + pass await self.stream.aclose() + async def _send(self, send_event: typing.Callable[[bytes], None], data: bytes) -> None: + try: + async with self._write_lock: + send_event(data) + await self._write_protocol_data() + except httpcore2.WriteError as e: + await self.close(INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e + + async def _write_protocol_data(self) -> None: + for data in self.protocol.data_to_send(): + if data: + await self.stream.write(data) + async def _background_receive(self, max_bytes: int) -> None: """ Background task listening for data from the server. Internally, it'll: - * Answer to Ping events. - * Acknowledge Pong events. - * Put other events in the `_events` queue that'll eventually be consumed by the user. + * Answer to Ping frames. + * Acknowledge Pong frames. + * Put messages in the `_events` queue that'll eventually be consumed by the user. """ - partial_message_buffer: str | bytes | None = None try: while not self._should_close.is_set(): data = await self._read_stream(max_bytes) - self.connection.receive_data(data) - for event in self.connection.events(): - if isinstance(event, wsproto.events.Ping): - data = self.connection.send(event.response()) - async with self._write_lock: - await self.stream.write(data) + self.protocol.receive_data(data) + try: + async with self._write_lock: + await self._write_protocol_data() + except httpcore2.WriteError: + # Tolerate failing to reply once the peer started the closing handshake. + if self.protocol.state is State.OPEN: + raise + for frame in self.protocol.events_received(): + assert isinstance(frame, Frame) + if frame.opcode is Opcode.PING: continue - if isinstance(event, wsproto.events.Pong): - self._ping_manager.ack(event.payload) + if frame.opcode is Opcode.PONG: + self._ping_manager.ack(frame.data) continue - if isinstance(event, wsproto.events.CloseConnection): + if frame.opcode is Opcode.CLOSE: self._should_close.set() - if isinstance(event, wsproto.events.Message): - # Unfinished message: bufferize - if not event.message_finished: - if partial_message_buffer is None: - partial_message_buffer = event.data - else: - partial_message_buffer += event.data - # Finished message but no buffer: just emit the event - elif partial_message_buffer is None: - await self._send_event.send(event) - # Finished message with buffer: emit the full event - else: - event_type = type(event) - full_message_event = event_type(partial_message_buffer + event.data) - partial_message_buffer = None - await self._send_event.send(full_message_event) + close = Close.parse(frame.data) + await self._send_event.send(WebSocketDisconnect(close.code, close.reason)) continue - await self._send_event.send(event) + message = self._assembler.feed(frame) + if message is not None: + await self._send_event.send(message) except (httpcore2.ReadError, httpcore2.WriteError, EndOfStream): - await self.close(CloseReason.INTERNAL_ERROR, "Stream error") + await self.close(INTERNAL_ERROR, "Stream error") await self._send_event.send(WebSocketNetworkError()) async def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: float | None = None) -> None: @@ -709,7 +693,7 @@ async def _background_keepalive_ping(self, interval_seconds: float, timeout_seco try: pong_callback = await self.ping() # Connection is closing, exit the task - except wsproto.utilities.LocalProtocolError: + except InvalidState: return if timeout_seconds is not None: @@ -717,7 +701,7 @@ async def _background_keepalive_ping(self, interval_seconds: float, timeout_seco with anyio.fail_after(timeout_seconds): await pong_callback.wait() except TimeoutError: - await self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") + await self.close(INTERNAL_ERROR, "Keepalive ping timeout") await self._send_event.send(WebSocketNetworkError()) async def _read_stream(self, max_bytes: int) -> bytes: diff --git a/src/httpx2/httpx2/_websockets/_transport.py b/src/httpx2/httpx2/_websockets/_transport.py index 607275c8..17481b1e 100644 --- a/src/httpx2/httpx2/_websockets/_transport.py +++ b/src/httpx2/httpx2/_websockets/_transport.py @@ -8,8 +8,9 @@ import anyio import anyio.abc import anyio.streams.stapled -import wsproto -from wsproto.frame_protocol import CloseReason +from websockets.frames import Close, Frame, Opcode +from websockets.protocol import Protocol, Side, State +from websockets.utils import accept_key from httpcore2 import AsyncNetworkStream @@ -24,6 +25,8 @@ Send = typing.Callable[[Message], typing.Awaitable[None]] ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] +INTERNAL_ERROR = 1011 + class ASGIWebSocketTransportError(Exception): pass @@ -34,9 +37,9 @@ def __init__(self, message: Message) -> None: self.message = message -class UnhandledWebSocketEvent(ASGIWebSocketTransportError): - def __init__(self, event: wsproto.events.Event) -> None: - self.event = event +class UnhandledWebSocketFrame(ASGIWebSocketTransportError): + def __init__(self, frame: Frame) -> None: + self.frame = frame class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): @@ -57,8 +60,9 @@ def __init__( ) self._task_group = task_group self._initial_receive_timeout = initial_receive_timeout - self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER) - self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) + self.protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + headers = {key.lower(): value for key, value in scope["headers"]} + self._websocket_key: bytes = headers[b"sec-websocket-key"] self._aentered = False async def __aenter__(self) -> tuple[ASGIWebSocketAsyncNetworkStream, bytes]: @@ -118,38 +122,37 @@ async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: if message_type not in {"websocket.send", "websocket.close"}: raise UnhandledASGIMessageType(message) - event: wsproto.events.Event if message_type == "websocket.send": data_str: str | None = message.get("text") if data_str is not None: - event = wsproto.events.TextMessage(data_str) + self.protocol.send_text(data_str.encode("utf-8")) data_bytes: bytes | None = message.get("bytes") if data_bytes is not None: - event = wsproto.events.BytesMessage(data_bytes) + self.protocol.send_binary(data_bytes) else: - event = wsproto.events.CloseConnection(message["code"], message["reason"]) + self.protocol.send_close(message["code"], message.get("reason") or "") - return self.connection.send(event) + return b"".join(data for data in self.protocol.data_to_send() if data) async def write(self, buffer: bytes, timeout: float | None = None) -> None: - self.connection.receive_data(buffer) - for event in self.connection.events(): - if isinstance(event, wsproto.events.Request): - pass - elif isinstance(event, wsproto.events.CloseConnection): + self.protocol.receive_data(buffer) + for frame in self.protocol.events_received(): + assert isinstance(frame, Frame) + if frame.opcode is Opcode.CLOSE: + close = Close.parse(frame.data) await self.send( { "type": "websocket.disconnect", - "code": event.code, - "reason": event.reason, + "code": close.code, + "reason": close.reason, } ) - elif isinstance(event, wsproto.events.TextMessage): - await self.send({"type": "websocket.receive", "text": event.data}) - elif isinstance(event, wsproto.events.BytesMessage): - await self.send({"type": "websocket.receive", "bytes": event.data}) + elif frame.opcode is Opcode.TEXT: + await self.send({"type": "websocket.receive", "text": bytes(frame.data).decode("utf-8")}) + elif frame.opcode is Opcode.BINARY: + await self.send({"type": "websocket.receive", "bytes": bytes(frame.data)}) else: - raise UnhandledWebSocketEvent(event) + raise UnhandledWebSocketFrame(frame) async def aclose(self) -> None: with contextlib.suppress(anyio.ClosedResourceError): @@ -178,20 +181,29 @@ async def _run(self) -> None: except Exception as e: message: Message = { "type": "websocket.close", - "code": CloseReason.INTERNAL_ERROR, + "code": INTERNAL_ERROR, "reason": str(e), } with contextlib.suppress(anyio.ClosedResourceError): await send(message) def _build_accept_response(self, message: Message) -> bytes: - subprotocol = message.get("subprotocol", None) - headers = message.get("headers", []) - return self.connection.send( - wsproto.events.AcceptConnection( - subprotocol=subprotocol, - extra_headers=headers, - ) + subprotocol: str | None = message.get("subprotocol", None) + headers: list[tuple[bytes, bytes]] = message.get("headers", []) + response_headers = [ + (b"Upgrade", b"websocket"), + (b"Connection", b"Upgrade"), + (b"Sec-WebSocket-Accept", accept_key(self._websocket_key.decode("utf-8")).encode("utf-8")), + ] + if subprotocol is not None: + response_headers.append((b"Sec-WebSocket-Protocol", subprotocol.encode("utf-8"))) + response_headers.extend(headers) + return b"".join( + [ + b"HTTP/1.1 101 Switching Protocols\r\n", + b"".join(key + b": " + value + b"\r\n" for key, value in response_headers), + b"\r\n", + ] ) diff --git a/src/httpx2/pyproject.toml b/src/httpx2/pyproject.toml index 5612cf18..e2a38c63 100644 --- a/src/httpx2/pyproject.toml +++ b/src/httpx2/pyproject.toml @@ -49,7 +49,7 @@ dependencies = [ "anyio>=4.10", "idna>=3.18", "typing_extensions>=4.5.0; python_version < '3.13'", - "wsproto>=1.2", + "websockets>=15", ] [project.optional-dependencies] diff --git a/tests/httpx2/conftest.py b/tests/httpx2/conftest.py index 156c5e3a..363193ce 100644 --- a/tests/httpx2/conftest.py +++ b/tests/httpx2/conftest.py @@ -274,6 +274,8 @@ def serve_in_thread(server: TestServer) -> typing.Iterator[TestServer]: @pytest.fixture(scope="session") def server(free_tcp_port_factory: typing.Callable[[], int]) -> typing.Iterator[TestServer]: - config = Config(app=app, lifespan="off", loop="asyncio", port=free_tcp_port_factory()) + # `ws="auto"` would pick uvicorn's implementation based on the deprecated `websockets.legacy`, + # which warns on import, and warnings are errors here. This server only handles plain HTTP. + config = Config(app=app, lifespan="off", loop="asyncio", port=free_tcp_port_factory(), ws="none") server = TestServer(config=config) yield from serve_in_thread(server) diff --git a/tests/httpx2/websockets/test_session.py b/tests/httpx2/websockets/test_session.py index 55fe977f..5f4e4ac5 100644 --- a/tests/httpx2/websockets/test_session.py +++ b/tests/httpx2/websockets/test_session.py @@ -6,8 +6,9 @@ import anyio import pytest -import wsproto from starlette.websockets import WebSocket, WebSocketDisconnect as StarletteWebSocketDisconnect +from websockets.frames import Frame, Opcode +from websockets.protocol import Protocol, Side, State import httpcore2 import httpx2 @@ -24,6 +25,16 @@ from tests.httpx2.websockets.conftest import ServerFactoryFixture +def wire(protocol: Protocol) -> bytes: + return b"".join(data for data in protocol.data_to_send() if data) + + +def mock_network_stream(spec: type) -> MagicMock: + stream = MagicMock(spec=spec) + stream.read.return_value = b"" + return stream + + @pytest.mark.anyio async def test_upgrade_error() -> None: def handler(request: httpx2.Request) -> httpx2.Response: @@ -54,7 +65,6 @@ class TestSend: async def test_send_error(self) -> None: class MockNetworkStream(NetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self._should_close = False def read(self, max_bytes: int, timeout: float | None = None) -> bytes: @@ -71,18 +81,17 @@ def close(self) -> None: stream = MockNetworkStream() with pytest.raises(WebSocketNetworkError): with WebSocketSession(stream) as websocket_session: - websocket_session.send(wsproto.events.Ping()) + websocket_session.send_text("CLIENT_MESSAGE") async def test_async_send_error(self) -> None: class AsyncMockNetworkStream(AsyncNetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self._should_close = False async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: while not self._should_close: await anyio.sleep(0.1) - raise httpcore2.ReadError() + raise httpcore2.ReadError() # pragma: no cover async def write(self, buffer: bytes, timeout: float | None = None) -> None: raise httpcore2.WriteError() @@ -93,31 +102,7 @@ async def aclose(self) -> None: stream = AsyncMockNetworkStream() with pytest.RaisesGroup(WebSocketNetworkError): async with AsyncWebSocketSession(stream) as websocket_session: - await websocket_session.send(wsproto.events.Ping()) - - async def test_send( - self, - server_factory: ServerFactoryFixture, - on_receive_message: MagicMock, - ) -> None: - async def websocket_endpoint(websocket: WebSocket) -> None: - await websocket.accept() - - message = await websocket.receive_text() - on_receive_message(message) - - await websocket.close() - - with server_factory(websocket_endpoint) as socket: - with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: - with client.websocket("http://socket/ws") as ws: - ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) - - async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: - async with aclient.websocket("http://socket/ws") as aws: - await aws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) - - on_receive_message.assert_has_calls([call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")]) + await websocket_session.send_text("CLIENT_MESSAGE") async def test_send_text( self, @@ -198,9 +183,6 @@ async def websocket_endpoint(websocket: WebSocket) -> None: class TestReceive: async def test_receive_error(self) -> None: class MockNetworkStream(NetworkStream): - def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) - def read(self, max_bytes: int, timeout: float | None = None) -> bytes: raise httpcore2.ReadError() @@ -217,9 +199,6 @@ def close(self) -> None: def test_receive_closed_socket(self) -> None: class MockNetworkStream(NetworkStream): - def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) - def read(self, max_bytes: int, timeout: float | None = None) -> bytes: return b"" @@ -236,9 +215,6 @@ def close(self) -> None: def test_receive_timeout(self) -> None: class MockNetworkStream(NetworkStream): - def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) - def read(self, max_bytes: int, timeout: float | None = None) -> bytes: time.sleep(0.2) return b"" @@ -256,9 +232,6 @@ def close(self) -> None: async def test_async_receive_error(self) -> None: class AsyncMockNetworkStream(AsyncNetworkStream): - def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) - async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: raise httpcore2.ReadError() @@ -275,9 +248,6 @@ async def aclose(self) -> None: async def test_async_receive_closed_socket(self) -> None: class AsyncMockNetworkStream(AsyncNetworkStream): - def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) - async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: return b"" @@ -303,15 +273,13 @@ async def websocket_endpoint(websocket: WebSocket) -> None: with server_factory(websocket_endpoint) as socket: with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: with client.websocket("http://socket/ws") as ws: - event = ws.receive() - assert isinstance(event, wsproto.events.TextMessage) - assert event.data == "SERVER_MESSAGE" + message = ws.receive() + assert message == "SERVER_MESSAGE" async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: async with aclient.websocket("http://socket/ws") as aws: - event = await aws.receive() - assert isinstance(event, wsproto.events.TextMessage) - assert event.data == "SERVER_MESSAGE" + message = await aws.receive() + assert message == "SERVER_MESSAGE" @pytest.mark.parametrize( "full_message,send_method", @@ -337,15 +305,65 @@ async def websocket_endpoint(websocket: WebSocket) -> None: with server_factory(websocket_endpoint) as socket: with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: with client.websocket("http://socket/ws", max_message_size_bytes=1024) as ws: - event = ws.receive() - assert isinstance(event, wsproto.events.Message) - assert event.data == full_message + message = ws.receive() + assert message == full_message async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: async with aclient.websocket("http://socket/ws", max_message_size_bytes=1024) as aws: - event = await aws.receive() - assert isinstance(event, wsproto.events.Message) - assert event.data == full_message + message = await aws.receive() + assert message == full_message + + async def test_receive_fragmented_message(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + protocol.send_text(b"SERVER", fin=False) + first = wire(protocol) + protocol.send_continuation(b"_MESSAGE", fin=True) + second = wire(protocol) + self.data_to_send = [first, second] + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + return self.data_to_send.pop(0) + except IndexError: + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with WebSocketSession(stream) as websocket_session: + assert websocket_session.receive() == "SERVER_MESSAGE" + + async def test_async_receive_fragmented_message(self) -> None: + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + protocol.send_text(b"SERVER", fin=False) + first = wire(protocol) + protocol.send_continuation(b"_MESSAGE", fin=True) + second = wire(protocol) + self.data_to_send = [first, second] + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + return self.data_to_send.pop(0) + except IndexError: + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + pass + + stream = AsyncMockNetworkStream() + async with AsyncWebSocketSession(stream) as websocket_session: + assert await websocket_session.receive() == "SERVER_MESSAGE" async def test_receive_text(self, server_factory: ServerFactoryFixture) -> None: async def websocket_endpoint(websocket: WebSocket) -> None: @@ -449,21 +467,21 @@ class TestReceivePing: async def test_receive_ping(self) -> None: class MockNetworkStream(NetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) - self.events_to_send = [ - wsproto.events.Ping(b"SERVER_PING"), - wsproto.events.CloseConnection(1000), - ] + self.protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + self.received_frames: list[Frame] = [] + self.protocol.send_ping(b"SERVER_PING") + self.protocol.send_close(1000) + self.data_to_send = [wire(self.protocol)] def read(self, max_bytes: int, timeout: float | None = None) -> bytes: try: - event = self.events_to_send.pop(0) - return self.connection.send(event) - except IndexError: + return self.data_to_send.pop(0) + except IndexError: # pragma: no cover raise httpcore2.ReadError() def write(self, buffer: bytes, timeout: float | None = None) -> None: - self.connection.receive_data(buffer) + self.protocol.receive_data(buffer) + self.received_frames.extend(e for e in self.protocol.events_received() if isinstance(e, Frame)) def close(self) -> None: pass @@ -472,30 +490,27 @@ def close(self) -> None: with WebSocketSession(stream): await anyio.sleep(0.1) - received_events = list(stream.connection.events()) - assert received_events == [ - wsproto.events.Pong(b"SERVER_PING"), - wsproto.events.CloseConnection(1000, ""), - ] + assert [frame.opcode for frame in stream.received_frames] == [Opcode.PONG, Opcode.CLOSE] + assert bytes(stream.received_frames[0].data) == b"SERVER_PING" async def test_async_receive_ping(self) -> None: class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) - self.events_to_send = [ - wsproto.events.Ping(b"SERVER_PING"), - wsproto.events.CloseConnection(1000), - ] + self.protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + self.received_frames: list[Frame] = [] + self.protocol.send_ping(b"SERVER_PING") + self.protocol.send_close(1000) + self.data_to_send = [wire(self.protocol)] async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: try: - event = self.events_to_send.pop(0) - return self.connection.send(event) - except IndexError: + return self.data_to_send.pop(0) + except IndexError: # pragma: no cover raise httpcore2.ReadError() async def write(self, buffer: bytes, timeout: float | None = None) -> None: - self.connection.receive_data(buffer) + self.protocol.receive_data(buffer) + self.received_frames.extend(e for e in self.protocol.events_received() if isinstance(e, Frame)) async def aclose(self) -> None: pass @@ -504,40 +519,119 @@ async def aclose(self) -> None: async with AsyncWebSocketSession(stream): await anyio.sleep(0.1) - received_events = list(stream.connection.events()) - assert received_events == [ - wsproto.events.Pong(b"SERVER_PING"), - wsproto.events.CloseConnection(1000, ""), - ] + assert [frame.opcode for frame in stream.received_frames] == [Opcode.PONG, Opcode.CLOSE] + assert bytes(stream.received_frames[0].data) == b"SERVER_PING" + + async def test_receive_ping_reply_write_error(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + protocol.send_ping(b"SERVER_PING") + self.data_to_send = [wire(protocol)] + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + return self.data_to_send.pop(0) + except IndexError: # pragma: no cover + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise httpcore2.WriteError() + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession(stream) as websocket_session: + websocket_session.receive() + + async def test_async_receive_ping_reply_write_error(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + protocol.send_ping(b"SERVER_PING") + self.data_to_send = [wire(protocol)] + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + return self.data_to_send.pop(0) + except IndexError: # pragma: no cover + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise httpcore2.WriteError() + + async def aclose(self) -> None: + pass + + stream = MockAsyncNetworkStream() + with pytest.RaisesGroup(WebSocketNetworkError): + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.receive() + + +class NoopNetworkStream(NetworkStream): + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise NotImplementedError + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise NotImplementedError + + def close(self) -> None: + pass + + +class NoopAsyncNetworkStream(AsyncNetworkStream): + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise NotImplementedError + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise NotImplementedError + + async def aclose(self) -> None: + pass @pytest.mark.anyio class TestKeepalivePing: + async def test_keepalive_ping_closing_connection(self) -> None: + session = WebSocketSession(NoopNetworkStream()) + session.protocol.receive_eof() + session._background_keepalive_ping(0.01) + session.close() + + async def test_async_keepalive_ping_closing_connection(self) -> None: + session = AsyncWebSocketSession(NoopAsyncNetworkStream()) + session.protocol.receive_eof() + await session._background_keepalive_ping(0.01) + await session.close() + async def test_keepalive_ping(self) -> None: class MockNetworkStream(NetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self.protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) self._should_close = False self.ping_received = 0 self.ping_answered = 0 - self.events_to_send: queue.Queue[wsproto.events.Event] = queue.Queue() + self.data_to_send: queue.Queue[bytes] = queue.Queue() def read(self, max_bytes: int, timeout: float | None = None) -> bytes: while not self._should_close: try: - event = self.events_to_send.get_nowait() + data = self.data_to_send.get_nowait() self.ping_answered += 1 - return self.connection.send(event) + return data except queue.Empty: pass raise httpcore2.ReadError() def write(self, buffer: bytes, timeout: float | None = None) -> None: - self.connection.receive_data(buffer) - for event in self.connection.events(): - if isinstance(event, wsproto.events.Ping): + self.protocol.receive_data(buffer) + for frame in self.protocol.events_received(): + if isinstance(frame, Frame) and frame.opcode is Opcode.PING: self.ping_received += 1 - self.events_to_send.put(event.response()) + self.data_to_send.put(wire(self.protocol)) def close(self) -> None: self._should_close = True @@ -556,7 +650,6 @@ def close(self) -> None: async def test_keepalive_ping_timeout(self) -> None: class MockNetworkStream(NetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self._should_close = False def read(self, max_bytes: int, timeout: float | None = None) -> bytes: @@ -582,36 +675,36 @@ def close(self) -> None: async def test_async_keepalive_ping(self) -> None: class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self.protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) self._should_close = False self.ping_received = 0 self.ping_answered = 0 ( - self.send_events, - self.receive_events, - ) = anyio.create_memory_object_stream[wsproto.events.Event]() + self.send_data, + self.receive_data, + ) = anyio.create_memory_object_stream[bytes]() async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: while not self._should_close: try: - event = self.receive_events.receive_nowait() + data = self.receive_data.receive_nowait() self.ping_answered += 1 - return self.connection.send(event) + return data except anyio.WouldBlock: await anyio.sleep(0.1) - raise httpcore2.ReadError() + raise httpcore2.ReadError() # pragma: no cover async def write(self, buffer: bytes, timeout: float | None = None) -> None: - self.connection.receive_data(buffer) - for event in self.connection.events(): - if isinstance(event, wsproto.events.Ping): + self.protocol.receive_data(buffer) + for frame in self.protocol.events_received(): + if isinstance(frame, Frame) and frame.opcode is Opcode.PING: self.ping_received += 1 - await self.send_events.send(event.response()) + await self.send_data.send(wire(self.protocol)) async def aclose(self) -> None: self._should_close = True - self.send_events.close() - self.receive_events.close() + self.send_data.close() + self.receive_data.close() stream = MockAsyncNetworkStream() async with AsyncWebSocketSession( @@ -627,13 +720,12 @@ async def aclose(self) -> None: async def test_async_keepalive_ping_timeout(self) -> None: class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self._should_close = False async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: while not self._should_close: await anyio.sleep(0.1) - raise httpcore2.ReadError() + raise httpcore2.ReadError() # pragma: no cover async def write(self, buffer: bytes, timeout: float | None = None) -> None: pass @@ -721,7 +813,7 @@ def handler(request: httpx2.Request) -> httpx2.Response: return httpx2.Response( 101, headers={"sec-websocket-protocol": "custom_protocol"}, - extensions={"network_stream": MagicMock(spec=NetworkStream)}, + extensions={"network_stream": mock_network_stream(NetworkStream)}, ) def async_handler(request: httpx2.Request) -> httpx2.Response: @@ -730,7 +822,7 @@ def async_handler(request: httpx2.Request) -> httpx2.Response: return httpx2.Response( 101, headers={"sec-websocket-protocol": "custom_protocol"}, - extensions={"network_stream": MagicMock(spec=AsyncNetworkStream)}, + extensions={"network_stream": mock_network_stream(AsyncNetworkStream)}, ) with httpx2.Client(base_url="http://localhost:8000", transport=httpx2.MockTransport(handler)) as client: @@ -815,19 +907,21 @@ async def websocket_endpoint(websocket: WebSocket) -> None: @pytest.mark.anyio -async def test_client_websocket_with_mock_stream() -> None: +async def test_client_websocket_with_wss_scheme() -> None: def handler(request: httpx2.Request) -> httpx2.Response: - return httpx2.Response(101, extensions={"network_stream": MagicMock(spec=NetworkStream)}) + assert request.url.scheme == "https" + return httpx2.Response(101, extensions={"network_stream": mock_network_stream(NetworkStream)}) def async_handler(request: httpx2.Request) -> httpx2.Response: - return httpx2.Response(101, extensions={"network_stream": MagicMock(spec=AsyncNetworkStream)}) + assert request.url.scheme == "https" + return httpx2.Response(101, extensions={"network_stream": mock_network_stream(AsyncNetworkStream)}) with httpx2.Client(base_url="http://localhost:8000", transport=httpx2.MockTransport(handler)) as client: - with client.websocket("http://socket/ws") as ws: + with client.websocket("wss://socket/ws") as ws: assert isinstance(ws.response, httpx2.Response) async with httpx2.AsyncClient( base_url="http://localhost:8000", transport=httpx2.MockTransport(async_handler) ) as aclient: - async with aclient.websocket("http://socket/ws") as aws: + async with aclient.websocket("wss://socket/ws") as aws: assert isinstance(aws.response, httpx2.Response) diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index 49e0cd09..2ead09e5 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -5,13 +5,14 @@ import anyio import pytest -import wsproto from anyio import CancelScope, ClosedResourceError, create_task_group from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route, WebSocketRoute from starlette.websockets import WebSocket +from websockets.frames import Frame, Opcode +from websockets.protocol import Protocol, Side, State import httpx2 from httpx2 import ASGIWebSocketTransport, WebSocketDisconnect, WebSocketUpgradeError @@ -22,11 +23,15 @@ Scope, Send, UnhandledASGIMessageType, - UnhandledWebSocketEvent, + UnhandledWebSocketFrame, ) if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup + from exceptiongroup import ExceptionGroup # pragma: no cover + + +def wire(protocol: Protocol) -> bytes: + return b"".join(data for data in protocol.data_to_send() if data) @pytest.fixture @@ -48,8 +53,8 @@ def scope(websocket_request_headers: dict[str, str]) -> Scope: "root_path": "/", "scheme": "ws", "headers": [ - ("host", "localhost"), - *websocket_request_headers.items(), + (b"host", b"localhost"), + *((key.encode("utf-8"), value.encode("utf-8")) for key, value in websocket_request_headers.items()), ], "subprotocols": [], "server": ("localhost", 8000), @@ -69,19 +74,19 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: message = await receive() received_messages.append(message) - connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + protocol = Protocol(Side.CLIENT, state=State.OPEN, max_size=None) async with ( create_task_group() as tg, ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), ): - text_event = wsproto.events.TextMessage("CLIENT_MESSAGE") - await stream.write(connection.send(text_event)) + protocol.send_text(b"CLIENT_MESSAGE") + await stream.write(wire(protocol)) - bytes_event = wsproto.events.BytesMessage(b"CLIENT_MESSAGE") - await stream.write(connection.send(bytes_event)) + protocol.send_binary(b"CLIENT_MESSAGE") + await stream.write(wire(protocol)) - close_event = wsproto.events.CloseConnection(1000) - await stream.write(connection.send(close_event)) + protocol.send_close(1000) + await stream.write(wire(protocol)) # Add a small delay to ensure the app has processed all messages await anyio.sleep(0.1) @@ -98,14 +103,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.accept"}) await receive() - connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + protocol = Protocol(Side.CLIENT, state=State.OPEN, max_size=None) async with ( create_task_group() as tg, ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), ): - with pytest.raises(UnhandledWebSocketEvent): - ping_event = wsproto.events.Ping(b"PING") - await stream.write(connection.send(ping_event)) + with pytest.raises(UnhandledWebSocketFrame): + protocol.send_ping(b"PING") + await stream.write(wire(protocol)) async def test_read(self, scope: Scope) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: @@ -114,21 +119,19 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.send", "bytes": b"SERVER_MESSAGE"}) await send({"type": "websocket.close", "code": 1000, "reason": ""}) - connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + protocol = Protocol(Side.CLIENT, state=State.OPEN, max_size=None) async with ( create_task_group() as tg, ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), ): for _ in range(3): data = await stream.read(4096) - connection.receive_data(data) + protocol.receive_data(data) - events = list(connection.events()) - assert events == [ - wsproto.events.TextMessage("SERVER_MESSAGE"), - wsproto.events.BytesMessage(bytearray(b"SERVER_MESSAGE")), - wsproto.events.CloseConnection(1000, ""), - ] + frames = [event for event in protocol.events_received() if isinstance(event, Frame)] + assert [frame.opcode for frame in frames] == [Opcode.TEXT, Opcode.BINARY, Opcode.CLOSE] + assert bytes(frames[0].data) == b"SERVER_MESSAGE" + assert bytes(frames[1].data) == b"SERVER_MESSAGE" async def test_read_unhandled_asgi_message(self, scope: Scope) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: @@ -200,6 +203,18 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert excinfo.group_contains(RuntimeError) + async def test_context_manager_twice(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await receive() + + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + with pytest.raises(RuntimeError): + await stream.__aenter__() + async def test_app_exception_with_closed_send_queue(self, scope: Scope) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.accept"}) @@ -222,7 +237,7 @@ async def http_endpoint(request: Request) -> PlainTextResponse: async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.receive_text() - await websocket.close() + await websocket.close() # pragma: no cover routes = [ Route("/http", endpoint=http_endpoint), @@ -324,7 +339,7 @@ async def test_keepalive_ping_disabled() -> None: async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.receive_text() - await websocket.close() + await websocket.close() # pragma: no cover app = Starlette( routes=[ @@ -342,7 +357,7 @@ async def test_cancel_scope_integrity() -> None: async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.receive_text() - await websocket.close() + await websocket.close() # pragma: no cover app = Starlette( routes=[ diff --git a/uv.lock b/uv.lock index a4e62128..847e4f90 100644 --- a/uv.lock +++ b/uv.lock @@ -47,8 +47,8 @@ dev = [ { name = "trustme", specifier = "==1.2.1" }, { name = "twine", specifier = "==6.1.0" }, { name = "uvicorn", specifier = ">=0.35" }, - { name = "websockets", specifier = ">=15" }, { name = "werkzeug", specifier = ">=3.1.6" }, + { name = "wsproto", specifier = ">=1.2" }, ] docs = [ { name = "mkdocstrings", extras = ["python"], specifier = ">=0.27" }, @@ -1356,7 +1356,7 @@ dependencies = [ { name = "idna" }, { name = "truststore" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, - { name = "wsproto" }, + { name = "websockets" }, ] [package.optional-dependencies] @@ -1393,7 +1393,7 @@ requires-dist = [ { name = "socksio", marker = "extra == 'socks'", specifier = "==1.*" }, { name = "truststore", specifier = ">=0.10" }, { name = "typing-extensions", marker = "python_full_version < '3.13'", specifier = ">=4.5.0" }, - { name = "wsproto", specifier = ">=1.2" }, + { name = "websockets", specifier = ">=15" }, { name = "zstandard", marker = "python_full_version < '3.14' and extra == 'zstd'", specifier = ">=0.18.0" }, ] provides-extras = ["brotli", "cli", "http2", "socks", "zstd"] From 890df71a9244fa08c163d1cd79cbef41c2745b5d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 12 Jun 2026 13:29:41 +0200 Subject: [PATCH 3/3] Keep httpcore2 lazily imported and serialize WebSocket protocol access The _websockets modules imported httpcore2 eagerly, defeating httpx2's lazy loading of httpcore2; import it inside the methods that need its exceptions and drop the AsyncNetworkStream base class from the ASGI stream. The websockets Protocol is not thread-safe: the sync session's background thread called receive_data() outside the write lock, racing send_close() in close() and tripping an assertion inside the protocol. All protocol access now happens under the write lock. Also add websockets to the dev dependency group, avoid attribute traversal when patching in test_top_level_websocket (test_httpcore_lazy_loading re-imports httpx2, leaving the fresh module without submodule attributes), and make the thread-leak test assert that session threads terminate instead of comparing exact thread counts. --- pyproject.toml | 1 + src/httpx2/httpx2/_websockets/_session.py | 55 +++++++++++++-------- src/httpx2/httpx2/_websockets/_transport.py | 9 ++-- tests/httpx2/websockets/test_session.py | 29 +++++------ uv.lock | 1 + 5 files changed, 58 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1b3ddcd9..6b2079b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dev = [ "trio-typing==0.10.0", "trustme==1.2.1", "uvicorn>=0.35", + "websockets>=15", "wsproto>=1.2", "werkzeug>=3.1.6", # Linting diff --git a/src/httpx2/httpx2/_websockets/_session.py b/src/httpx2/httpx2/_websockets/_session.py index 18eb7603..fb1e66fb 100644 --- a/src/httpx2/httpx2/_websockets/_session.py +++ b/src/httpx2/httpx2/_websockets/_session.py @@ -17,9 +17,6 @@ from websockets.frames import Close, Frame, Opcode from websockets.protocol import Protocol, Side, State -import httpcore2 -from httpcore2 import AsyncNetworkStream, NetworkStream - from .._models import Headers from .._urls import URL from ._exceptions import ( @@ -33,6 +30,8 @@ from ._transport import ASGIWebSocketAsyncNetworkStream if typing.TYPE_CHECKING: + from httpcore2 import AsyncNetworkStream, NetworkStream + from .._client import AsyncClient, Client, UseClientDefault from .._models import Response from .._types import AuthTypes, CookieTypes, HeaderTypes, QueryParamTypes, RequestExtensions, TimeoutTypes @@ -289,6 +288,8 @@ def close(self, code: int = 1000, reason: str | None = None) -> None: *This method is automatically called when exiting the context manager.* """ + import httpcore2 + self._should_close.set() if self._executor is not None: self._executor.shutdown(False) @@ -302,6 +303,8 @@ def close(self, code: int = 1000, reason: str | None = None) -> None: self.stream.close() def _send(self, send_event: typing.Callable[[bytes], None], data: bytes) -> None: + import httpcore2 + try: with self._write_lock: send_event(data) @@ -325,18 +328,23 @@ def _background_receive(self, max_bytes: int) -> None: * Acknowledge Pong frames. * Put messages in the `_events` queue that'll eventually be consumed by the user. """ + import httpcore2 + try: while not self._should_close.is_set(): data = self._wait_until_closed(self._read_stream, max_bytes) - self.protocol.receive_data(data) - try: - with self._write_lock: + # The protocol is not thread-safe: keep every interaction with it + # under the write lock, so it can't race user sends and closes. + with self._write_lock: + self.protocol.receive_data(data) + frames = self.protocol.events_received() + try: self._write_protocol_data() - except httpcore2.WriteError: - # Tolerate failing to reply once the peer started the closing handshake. - if self.protocol.state is State.OPEN: - raise - for frame in self.protocol.events_received(): + except httpcore2.WriteError: + # Tolerate failing to reply once the peer started the closing handshake. + if self.protocol.state is State.OPEN: + raise + for frame in frames: assert isinstance(frame, Frame) if frame.opcode is Opcode.PING: continue @@ -361,7 +369,7 @@ def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: f try: while not self._should_close.is_set(): should_close = self._wait_until_closed(self._should_close.wait, interval_seconds) - if should_close: + if should_close: # pragma: no cover raise ShouldClose() try: @@ -622,6 +630,8 @@ async def close(self, code: int = 1000, reason: str | None = None) -> None: *This method is automatically called when exiting the context manager.* """ + import httpcore2 + self._should_close.set() try: async with self._write_lock: @@ -633,6 +643,8 @@ async def close(self, code: int = 1000, reason: str | None = None) -> None: await self.stream.aclose() async def _send(self, send_event: typing.Callable[[bytes], None], data: bytes) -> None: + import httpcore2 + try: async with self._write_lock: send_event(data) @@ -656,18 +668,21 @@ async def _background_receive(self, max_bytes: int) -> None: * Acknowledge Pong frames. * Put messages in the `_events` queue that'll eventually be consumed by the user. """ + import httpcore2 + try: while not self._should_close.is_set(): data = await self._read_stream(max_bytes) - self.protocol.receive_data(data) - try: - async with self._write_lock: + async with self._write_lock: + self.protocol.receive_data(data) + frames = self.protocol.events_received() + try: await self._write_protocol_data() - except httpcore2.WriteError: - # Tolerate failing to reply once the peer started the closing handshake. - if self.protocol.state is State.OPEN: - raise - for frame in self.protocol.events_received(): + except httpcore2.WriteError: + # Tolerate failing to reply once the peer started the closing handshake. + if self.protocol.state is State.OPEN: + raise + for frame in frames: assert isinstance(frame, Frame) if frame.opcode is Opcode.PING: continue diff --git a/src/httpx2/httpx2/_websockets/_transport.py b/src/httpx2/httpx2/_websockets/_transport.py index 17481b1e..ba6dd39c 100644 --- a/src/httpx2/httpx2/_websockets/_transport.py +++ b/src/httpx2/httpx2/_websockets/_transport.py @@ -12,8 +12,6 @@ from websockets.protocol import Protocol, Side, State from websockets.utils import accept_key -from httpcore2 import AsyncNetworkStream - from .._models import Request, Response from .._transports.asgi import ASGITransport from .._types import AsyncByteStream @@ -42,7 +40,12 @@ def __init__(self, frame: Frame) -> None: self.frame = frame -class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): +class ASGIWebSocketAsyncNetworkStream: + """ + An `httpcore2.AsyncNetworkStream` lookalike that translates reads and writes + into ASGI messages exchanged with the wrapped app. + """ + def __init__( self, app: ASGIApp, diff --git a/tests/httpx2/websockets/test_session.py b/tests/httpx2/websockets/test_session.py index 5f4e4ac5..dbe200d5 100644 --- a/tests/httpx2/websockets/test_session.py +++ b/tests/httpx2/websockets/test_session.py @@ -20,6 +20,7 @@ WebSocketNetworkError, WebSocketSession, WebSocketUpgradeError, + _api, ) from httpx2._websockets._session import JSONMode from tests.httpx2.websockets.conftest import ServerFactoryFixture @@ -52,7 +53,7 @@ def handler(request: httpx2.Request) -> httpx2.Response: def test_top_level_websocket() -> None: - with patch("httpx2._api.Client") as mock_client_cls: + with patch.object(_api, "Client") as mock_client_cls: mock_client = mock_client_cls.return_value.__enter__.return_value with httpx2.websocket("ws://socket/ws", subprotocols=["custom_protocol"]): pass @@ -67,7 +68,7 @@ class MockNetworkStream(NetworkStream): def __init__(self) -> None: self._should_close = False - def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: # pragma: no cover while not self._should_close: time.sleep(0.1) raise httpcore2.ReadError() @@ -88,10 +89,10 @@ class AsyncMockNetworkStream(AsyncNetworkStream): def __init__(self) -> None: self._should_close = False - async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: # pragma: no cover while not self._should_close: await anyio.sleep(0.1) - raise httpcore2.ReadError() # pragma: no cover + raise httpcore2.ReadError() async def write(self, buffer: bytes, timeout: float | None = None) -> None: raise httpcore2.WriteError() @@ -652,7 +653,7 @@ class MockNetworkStream(NetworkStream): def __init__(self) -> None: self._should_close = False - def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: # pragma: no cover while not self._should_close: time.sleep(0.1) raise httpcore2.ReadError() @@ -722,10 +723,10 @@ class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: self._should_close = False - async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: # pragma: no cover while not self._should_close: await anyio.sleep(0.1) - raise httpcore2.ReadError() # pragma: no cover + raise httpcore2.ReadError() async def write(self, buffer: bytes, timeout: float | None = None) -> None: pass @@ -863,17 +864,17 @@ async def websocket_endpoint(websocket: WebSocket) -> None: with server_factory(websocket_endpoint) as socket: with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: - initial_threads_count = threading.active_count() + initial_threads = set(threading.enumerate()) with client.websocket("http://socket/ws", keepalive_ping_interval_seconds=None) as ws: for _ in range(50): ws.receive() ws.send_text("CLIENT_MESSAGE") - time.sleep(0.1) # Let the websocket endpoint finish its handling. - threads_count = threading.active_count() - assert initial_threads_count + 2 == threads_count - time.sleep(0.1) - final_threads_count = threading.active_count() - assert initial_threads_count == final_threads_count + session_threads = set(threading.enumerate()) - initial_threads + assert session_threads + deadline = time.time() + 5 + while any(thread.is_alive() for thread in session_threads) and time.time() < deadline: + time.sleep(0.01) # pragma: no cover + assert not any(thread.is_alive() for thread in session_threads) @pytest.mark.anyio diff --git a/uv.lock b/uv.lock index 847e4f90..feeae90d 100644 --- a/uv.lock +++ b/uv.lock @@ -47,6 +47,7 @@ dev = [ { name = "trustme", specifier = "==1.2.1" }, { name = "twine", specifier = "==6.1.0" }, { name = "uvicorn", specifier = ">=0.35" }, + { name = "websockets", specifier = ">=15" }, { name = "werkzeug", specifier = ">=3.1.6" }, { name = "wsproto", specifier = ">=1.2" }, ]