diff --git a/pyproject.toml b/pyproject.toml index 08859dbc..f248d8c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,20 +14,23 @@ httpx2 = { workspace = true } [dependency-groups] dev = [ - "httpx2[brotli,cli,http2,socks,zstd]", + "httpx2[brotli,cli,http2,socks,ws,zstd]", "httpcore2[asyncio,trio,http2,socks]", # Tests "chardet==6.0.0.post1", "coverage[toml]==7.10.6", "cryptography==46.0.7", + "flaky>=3.8", "pytest>=9.0.3", "pytest-codspeed>=4.1.1", "pytest-httpbin==2.0.0", "pytest-trio==0.8.0", + "starlette>=0.49", "trio==0.31.0", "trio-typing==0.10.0", "trustme==1.2.1", "uvicorn>=0.35", + "websockets>=15", "werkzeug>=3.1.6", # Linting "mypy==1.17.1", @@ -85,7 +88,7 @@ omit = ["src/httpcore2/httpcore2/_sync/*", "tests/test_benchmark.py"] [tool.coverage.report] exclude_also = [ "if TYPE_CHECKING:", - "if typing.TYPE_CHECKING:", + "if (typing|_typing).TYPE_CHECKING:", "raise NotImplementedError", "@(typing\\.)?overload", ] diff --git a/src/httpx2/httpx2/__init__.py b/src/httpx2/httpx2/__init__.py index 2fa4256d..7abf8097 100644 --- a/src/httpx2/httpx2/__init__.py +++ b/src/httpx2/httpx2/__init__.py @@ -1,3 +1,5 @@ +import typing as _typing + from .__version__ import __description__, __title__, __version__ from ._api import * from ._auth import * @@ -12,15 +14,28 @@ from ._types import * from ._urls import * +if _typing.TYPE_CHECKING: + from ._websockets.api import AsyncWebSocketSession, WebSocketSession + from ._websockets.exceptions import ( + HTTPXWSException, + WebSocketDisconnect, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, + ) + from ._websockets.transport import ASGIWebSocketTransport + __all__ = [ "__description__", "__title__", "__version__", "ASGITransport", + "ASGIWebSocketTransport", "AsyncBaseTransport", "AsyncByteStream", "AsyncClient", "AsyncHTTPTransport", + "AsyncWebSocketSession", "Auth", "BaseTransport", "BasicAuth", @@ -44,6 +59,7 @@ "HTTPError", "HTTPStatusError", "HTTPTransport", + "HTTPXWSException", "InvalidURL", "Limits", "LocalProtocolError", @@ -82,20 +98,37 @@ "UnsupportedProtocol", "URL", "USE_CLIENT_DEFAULT", + "websocket", + "WebSocketDisconnect", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketSession", + "WebSocketUpgradeError", "WriteError", "WriteTimeout", "WSGITransport", ] +_WEBSOCKET_NAMES = { + "ASGIWebSocketTransport", + "AsyncWebSocketSession", + "HTTPXWSException", + "WebSocketDisconnect", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketSession", + "WebSocketUpgradeError", +} + __locals = locals() for __name in __all__: - if not __name.startswith("__"): + if not __name.startswith("__") and __name not in _WEBSOCKET_NAMES: setattr(__locals[__name], "__module__", "httpx2") # noqa -def __getattr__(name: str) -> object: # pragma: no cover - if name == "main": +def __getattr__(name: str) -> object: + if name == "main": # pragma: no cover import warnings warnings.warn( @@ -108,4 +141,18 @@ def __getattr__(name: str) -> object: # pragma: no cover return main + if name in _WEBSOCKET_NAMES: + from . import _websockets + from ._websockets.defaults import WS_EXTRA_INSTALL_MESSAGE + + try: + return getattr(_websockets, name) + except ImportError: # pragma: no cover + raise ImportError(WS_EXTRA_INSTALL_MESSAGE) + + if name == "_websockets": + import importlib + + return importlib.import_module(f"{__name__}._websockets") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/httpx2/httpx2/_api.py b/src/httpx2/httpx2/_api.py index 25171cbc..08fa5bdd 100644 --- a/src/httpx2/httpx2/_api.py +++ b/src/httpx2/httpx2/_api.py @@ -19,10 +19,18 @@ TimeoutTypes, ) from ._urls import URL +from ._websockets.defaults import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, +) if typing.TYPE_CHECKING: import ssl # pragma: no cover + from ._websockets.api import WebSocketSession + __all__ = [ "delete", @@ -34,6 +42,7 @@ "put", "request", "stream", + "websocket", ] @@ -424,3 +433,57 @@ def delete( timeout=timeout, trust_env=trust_env, ) + + +@contextmanager +def websocket( + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + follow_redirects: bool = False, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + verify: ssl.SSLContext | str | bool = True, + trust_env: bool = True, + subprotocols: list[str] | None = None, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, +) -> Generator[WebSocketSession]: + """ + Open a WebSocket session. + + The session is closed automatically when exiting the context manager. + + ```python + with httpx2.websocket("ws://localhost:8000/ws") as ws: + ws.send_text("Hello!") + message = ws.receive_text() + ``` + + **Parameters**: See `httpx2.request` and `httpx2.Client.websocket`. + """ + with Client( + cookies=cookies, + proxy=proxy, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) as client: + with client.websocket( + url, + params=params, + headers=headers, + auth=auth, + follow_redirects=follow_redirects, + subprotocols=subprotocols, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + ) as session: + yield session diff --git a/src/httpx2/httpx2/_client.py b/src/httpx2/httpx2/_client.py index 8d810bb6..0b6f4b24 100644 --- a/src/httpx2/httpx2/_client.py +++ b/src/httpx2/httpx2/_client.py @@ -49,10 +49,19 @@ ) from ._urls import URL, QueryParams from ._utils import URLPattern, get_environment_proxies +from ._websockets.defaults import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, + WS_EXTRA_INSTALL_MESSAGE, +) if typing.TYPE_CHECKING: import ssl # pragma: no cover + from ._websockets.api import AsyncWebSocketSession, WebSocketSession + __all__ = ["USE_CLIENT_DEFAULT", "AsyncClient", "Client"] # The type annotation for @classmethod and context managers here follows PEP 484 @@ -888,6 +897,59 @@ def sse( ) as response: yield EventSource(response) + @contextmanager + def websocket( + self, + url: URL | str, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> Generator[WebSocketSession]: + """ + Open a WebSocket session, using this client's configuration. + + The session is closed automatically when exiting the context manager. + + ```python + with httpx2.Client() as client: + with client.websocket("ws://localhost:8000/ws") as ws: + ws.send_text("Hello!") + message = ws.receive_text() + ``` + """ + try: + from ._websockets.api import connect_ws + except ImportError: # pragma: no cover + raise ImportError(WS_EXTRA_INSTALL_MESSAGE) + + with connect_ws( + str(url), + self, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as session: + yield session + def send( self, request: Request, @@ -1633,6 +1695,59 @@ async def sse( ) as response: yield EventSource(response) + @asynccontextmanager + async def websocket( + self, + url: URL | str, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + ) -> AsyncGenerator[AsyncWebSocketSession]: + """ + Open a WebSocket session, using this client's configuration. + + The session is closed automatically when exiting the context manager. + + ```python + async with httpx2.AsyncClient() as client: + async with client.websocket("ws://localhost:8000/ws") as ws: + await ws.send_text("Hello!") + message = await ws.receive_text() + ``` + """ + try: + from ._websockets.api import aconnect_ws + except ImportError: # pragma: no cover + raise ImportError(WS_EXTRA_INSTALL_MESSAGE) + + async with aconnect_ws( + str(url), + self, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as session: + yield session + async def send( self, request: Request, diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py new file mode 100644 index 00000000..7ff2cf31 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -0,0 +1,61 @@ +""" +WebSocket support, derived from httpx-ws (https://github.com/frankie567/httpx-ws). + +Copyright (c) 2021 François Voron, MIT License (https://github.com/frankie567/httpx-ws/blob/main/LICENSE). +""" + +from .defaults import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, +) + +__all__ = [ + "ASGIWebSocketTransport", + "AsyncWebSocketSession", + "DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS", + "DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS", + "DEFAULT_MAX_MESSAGE_SIZE_BYTES", + "DEFAULT_QUEUE_SIZE", + "HTTPXWSException", + "JSONMode", + "WebSocketDisconnect", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketSession", + "WebSocketUpgradeError", + "aconnect_ws", + "connect_ws", +] + +_API_NAMES = { + "AsyncWebSocketSession", + "JSONMode", + "WebSocketSession", + "aconnect_ws", + "connect_ws", +} +_EXCEPTION_NAMES = { + "HTTPXWSException", + "WebSocketDisconnect", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketUpgradeError", +} + + +def __getattr__(name: str) -> object: + if name in _API_NAMES: + from . import api + + return getattr(api, name) + if name in _EXCEPTION_NAMES: + from . import exceptions + + return getattr(exceptions, name) + if name == "ASGIWebSocketTransport": + from .transport import ASGIWebSocketTransport + + return ASGIWebSocketTransport + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") # pragma: no cover diff --git a/src/httpx2/httpx2/_websockets/api.py b/src/httpx2/httpx2/_websockets/api.py new file mode 100644 index 00000000..004a7482 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/api.py @@ -0,0 +1,1359 @@ +from __future__ import annotations + +import base64 +import concurrent.futures +import contextlib +import json +import queue +import secrets +import threading +import typing +from types import TracebackType + +import anyio +import wsproto +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from wsproto.frame_protocol import CloseReason + +from .._client import USE_CLIENT_DEFAULT +from .._models import Headers +from .defaults import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, +) +from .exceptions import ( + HTTPXWSException, + WebSocketDisconnect, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, +) +from .ping import AsyncPingManager, PingManager +from .transport import ASGIWebSocketAsyncNetworkStream + +if typing.TYPE_CHECKING: + from httpcore2 import AsyncNetworkStream, NetworkStream + + from .._client import AsyncClient, Client, UseClientDefault + from .._models import Response + from .._types import ( + AuthTypes, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestExtensions, + TimeoutTypes, + ) + +JSONMode = typing.Literal["text", "binary"] +TaskFunction = typing.TypeVar("TaskFunction") +TaskResult = typing.TypeVar("TaskResult") + + +class ShouldClose(Exception): + pass + + +class WebSocketSession: + """ + Sync context manager representing an opened WebSocket session. + + Attributes: + subprotocol (typing.Optional[str]): + Optional protocol that has been accepted by the server. + response (Response | None): + The webSocket handshake response. + """ + + subprotocol: str | None + response: Response | None + + def __init__( + self, + stream: NetworkStream, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + response: Response | None = None, + ) -> None: + self.stream = stream + self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) + self.response = response + if self.response is not None: + self.subprotocol = self.response.headers.get("sec-websocket-protocol") + else: + self.subprotocol = None + + self._events: queue.Queue[wsproto.events.Event | HTTPXWSException] = queue.Queue(queue_size) + + self._ping_manager = PingManager() + self._should_close = threading.Event() + self._should_close_task: concurrent.futures.Future[bool] | None = None + self._executor: concurrent.futures.ThreadPoolExecutor | None = None + + self._max_message_size_bytes = max_message_size_bytes + self._queue_size = queue_size + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + + def _get_executor_should_close_task( + self, + ) -> tuple[concurrent.futures.ThreadPoolExecutor, concurrent.futures.Future[bool]]: + if self._should_close_task is None: + self._executor = concurrent.futures.ThreadPoolExecutor() + self._should_close_task = self._executor.submit(self._should_close.wait) + assert self._executor is not None + return self._executor, self._should_close_task + + def __enter__(self) -> WebSocketSession: + self._background_receive_task = threading.Thread( + target=self._background_receive, args=(self._max_message_size_bytes,) + ) + self._background_receive_task.start() + + self._background_keepalive_ping_task: threading.Thread | None = None + if self._keepalive_ping_interval_seconds is not None: + self._background_keepalive_ping_task = threading.Thread( + target=self._background_keepalive_ping, + args=( + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, + ), + ) + self._background_keepalive_ping_task.start() + + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close() + self._background_receive_task.join() + if self._background_keepalive_ping_task is not None: + self._background_keepalive_ping_task.join() + + def ping(self, payload: bytes = b"") -> threading.Event: + """ + Send a Ping message. + + Args: + payload: + Payload to attach to the Ping event. + Internally, it's used to track this specific event. + If left empty, a random one will be generated. + + Returns: + An event that can be used to wait for the corresponding Pong response. + + Examples: + Send a Ping and wait for the Pong + + pong_callback = ws.ping() + # Will block until the corresponding Pong is received. + pong_callback.wait() + """ + ping_id, callback = self._ping_manager.create(payload) + event = wsproto.events.Ping(ping_id) + self.send(event) + return callback + + def send(self, event: wsproto.events.Event) -> None: + """ + Send an Event message. + + Mainly useful to send events that are not supported by the library. + Most of the time, [ping()][httpx_ws.WebSocketSession.ping], + [send_text()][httpx_ws.WebSocketSession.send_text], + [send_bytes()][httpx_ws.WebSocketSession.send_bytes] + and [send_json()][httpx_ws.WebSocketSession.send_json] are preferred. + + Args: + event: The event to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send an event. + + event = wsproto.events.Message(b"Hello!") + ws.send(event) + """ + import httpcore2 + + try: + data = self.connection.send(event) + self.stream.write(data) + except httpcore2.WriteError as e: + self.close(CloseReason.INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e + + def send_text(self, data: str) -> None: + """ + Send a text message. + + Args: + data: The text to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send a text message. + + ws.send_text("Hello!") + """ + event = wsproto.events.TextMessage(data=data) + self.send(event) + + def send_bytes(self, data: bytes) -> None: + """ + Send a bytes message. + + Args: + data: The data to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send a bytes message. + + ws.send_bytes(b"Hello!") + """ + event = wsproto.events.BytesMessage(data=data) + self.send(event) + + def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + """ + Send JSON data. + + Args: + data: + The data to send. Must be serializable by [json.dumps][json.dumps]. + mode: + The sending mode. Should either be `'text'` or `'bytes'`. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send JSON data. + + data = {"message": "Hello!"} + ws.send_json(data) + """ + assert mode in ["text", "binary"] + serialized_data = json.dumps(data) + if mode == "text": + self.send_text(serialized_data) + else: + self.send_bytes(serialized_data.encode("utf-8")) + + def receive(self, timeout: float | None = None) -> wsproto.events.Event: + """ + Receive an event from the server. + + Mainly useful to receive raw [wsproto.events.Event][wsproto.events.Event]. + Most of the time, [receive_text()][httpx_ws.WebSocketSession.receive_text], + [receive_bytes()][httpx_ws.WebSocketSession.receive_bytes], + and [receive_json()][httpx_ws.WebSocketSession.receive_json] are preferred. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + A raw [wsproto.events.Event][wsproto.events.Event]. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + + Examples: + Wait for an event until one is available. + + try: + event = ws.receive() + except WebSocketDisconnect: + print("Connection closed") + + Wait for an event for 2 seconds. + + try: + event = ws.receive(timeout=2.) + except queue.Empty: + print("No event received.") + except WebSocketDisconnect: + print("Connection closed") + """ + event = self._events.get(block=True, timeout=timeout) + if isinstance(event, HTTPXWSException): + raise event + if isinstance(event, wsproto.events.CloseConnection): + raise WebSocketDisconnect(event.code, event.reason) + return event + + def receive_text(self, timeout: float | None = None) -> str: + """ + Receive text from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Text data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event was not a text message. + + Examples: + Wait for text until available. + + try: + text = ws.receive_text() + except WebSocketDisconnect: + print("Connection closed") + + Wait for text for 2 seconds. + + try: + event = ws.receive_text(timeout=2.) + except queue.Empty: + print("No text received.") + except WebSocketDisconnect: + print("Connection closed") + """ + event = self.receive(timeout) + if isinstance(event, wsproto.events.TextMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + def receive_bytes(self, timeout: float | None = None) -> bytes: + """ + Receive bytes from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Bytes data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event was not a bytes message. + + Examples: + Wait for bytes until available. + + try: + data = ws.receive_bytes() + except WebSocketDisconnect: + print("Connection closed") + + Wait for bytes for 2 seconds. + + try: + data = ws.receive_bytes(timeout=2.) + except queue.Empty: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ + event = self.receive(timeout) + if isinstance(event, wsproto.events.BytesMessage): + return bytes(event.data) + raise WebSocketInvalidTypeReceived(event) + + def receive_json(self, timeout: float | None = None, mode: JSONMode = "text") -> typing.Any: + """ + Receive JSON data from the server. + + The received data should be parseable by [json.loads][json.loads]. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + mode: + Receive mode. Should either be `'text'` or `'bytes'`. + + Returns: + Parsed JSON data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event + didn't correspond to the specified mode. + + Examples: + Wait for data until available. + + try: + data = ws.receive_json() + except WebSocketDisconnect: + print("Connection closed") + + Wait for data for 2 seconds. + + try: + data = ws.receive_json(timeout=2.) + except queue.Empty: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ + assert mode in ["text", "binary"] + data: str | bytes + if mode == "text": + data = self.receive_text(timeout) + elif mode == "binary": + data = self.receive_bytes(timeout) + return json.loads(data) + + def close(self, code: int = 1000, reason: str | None = None) -> None: + """ + Close the WebSocket session. + + Internally, it'll send the + [CloseConnection][wsproto.events.CloseConnection] event. + + *This method is automatically called when exiting the context manager.* + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + + Examples: + Close the WebSocket session. + + ws.close() + """ + import httpcore2 + + self._should_close.set() + if self._executor is not None: + self._executor.shutdown(False) + if self.connection.state not in { + wsproto.connection.ConnectionState.LOCAL_CLOSING, + wsproto.connection.ConnectionState.CLOSED, + }: + event = wsproto.events.CloseConnection(code, reason) + data = self.connection.send(event) + try: + self.stream.write(data) + except httpcore2.WriteError: + pass + self.stream.close() + + def _background_receive(self, max_bytes: int) -> None: + """ + Background thread listening for data from the server. + + Internally, it'll: + + * Answer to Ping events. + * Acknowledge Pong events. + * Put other events in the [_events][_events] + queue that'll eventually be consumed by the user. + + Args: + max_bytes: The maximum chunk size to read at each iteration. + """ + import httpcore2 + + partial_message_buffer: str | bytes | None = None + try: + while not self._should_close.is_set(): + data = self._wait_until_closed(self.stream.read, max_bytes) + self.connection.receive_data(data) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + data = self.connection.send(event.response()) + self.stream.write(data) + continue + if isinstance(event, wsproto.events.Pong): + self._ping_manager.ack(event.payload) + continue + if isinstance(event, wsproto.events.CloseConnection): + self._should_close.set() + if isinstance(event, wsproto.events.Message): + # Unfinished message: bufferize + if not event.message_finished: + if partial_message_buffer is None: + partial_message_buffer = event.data + else: + partial_message_buffer += event.data + # Finished message but no buffer: just emit the event + elif partial_message_buffer is None: + self._events.put(event) + # Finished message with buffer: emit the full event + else: + event_type = type(event) + full_message_event = event_type(partial_message_buffer + event.data) + partial_message_buffer = None + self._events.put(full_message_event) + continue + self._events.put(event) + except (httpcore2.ReadError, httpcore2.WriteError): + self.close(CloseReason.INTERNAL_ERROR, "Stream error") + self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass + + def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: float | None = None) -> None: + try: + while not self._should_close.is_set(): + should_close = self._wait_until_closed(self._should_close.wait, interval_seconds) + if should_close: # pragma: no cover + raise ShouldClose() + pong_callback = self.ping() + if timeout_seconds is not None: + acknowledged = self._wait_until_closed(pong_callback.wait, timeout_seconds) + if not acknowledged: + self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") + self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass + + def _wait_until_closed( + self, callable: typing.Callable[..., TaskResult], *args: typing.Any, **kwargs: typing.Any + ) -> TaskResult: + try: + executor, should_close_task = self._get_executor_should_close_task() + todo_task = executor.submit(callable, *args, **kwargs) + except RuntimeError as e: + raise ShouldClose() from e + else: + done, _ = concurrent.futures.wait( + (todo_task, should_close_task), # type: ignore[misc] + return_when=concurrent.futures.FIRST_COMPLETED, + ) + if should_close_task in done: + raise ShouldClose() + assert todo_task in done + result = todo_task.result() + return result + + +class AsyncWebSocketSession: + """ + Async context manager representing an opened WebSocket session. + + Attributes: + subprotocol (typing.Optional[str]): + Optional protocol that has been accepted by the server. + response (Response | None): + The webSocket handshake response. + """ + + subprotocol: str | None + response: Response | None + _send_event: MemoryObjectSendStream[wsproto.events.Event | HTTPXWSException] + _receive_event: MemoryObjectReceiveStream[wsproto.events.Event | HTTPXWSException] + + def __init__( + self, + stream: AsyncNetworkStream, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + response: Response | None = None, + ) -> None: + self.stream = stream + self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) + self.response = response + if self.response is not None: + self.subprotocol = self.response.headers.get("sec-websocket-protocol") + else: + self.subprotocol = None + + self._ping_manager = AsyncPingManager() + self._should_close = anyio.Event() + + self._max_message_size_bytes = max_message_size_bytes + self._queue_size = queue_size + + # Always disable keepalive ping when emulating ASGI + if isinstance(stream, ASGIWebSocketAsyncNetworkStream): + self._keepalive_ping_interval_seconds = None + self._keepalive_ping_timeout_seconds = None + else: + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + + async def __aenter__(self) -> AsyncWebSocketSession: + async with contextlib.AsyncExitStack() as exit_stack: + self._send_event, self._receive_event = anyio.create_memory_object_stream[ + wsproto.events.Event | HTTPXWSException + ]() + exit_stack.enter_context(self._send_event) + exit_stack.enter_context(self._receive_event) + + self._background_task_group = anyio.create_task_group() + await exit_stack.enter_async_context(self._background_task_group) + + self._background_task_group.start_soon(self._background_receive, self._max_message_size_bytes) + if self._keepalive_ping_interval_seconds is not None: + self._background_task_group.start_soon( + self._background_keepalive_ping, + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, + ) + + exit_stack.callback(self._background_task_group.cancel_scope.cancel) + exit_stack.push_async_callback(self.close) + self._exit_stack = exit_stack.pop_all() + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + await self._exit_stack.aclose() + + async def ping(self, payload: bytes = b"") -> anyio.Event: + """ + Send a Ping message. + + Args: + payload: + Payload to attach to the Ping event. + Internally, it's used to track this specific event. + If left empty, a random one will be generated. + + Returns: + An event that can be used to wait for the corresponding Pong response. + + Examples: + Send a Ping and wait for the Pong + + pong_callback = await ws.ping() + # Will block until the corresponding Pong is received. + await pong_callback.wait() + """ + ping_id, callback = self._ping_manager.create(payload) + event = wsproto.events.Ping(ping_id) + await self.send(event) + return callback + + async def send(self, event: wsproto.events.Event) -> None: + """ + Send an Event message. + + Mainly useful to send events that are not supported by the library. + Most of the time, [ping()][httpx_ws.AsyncWebSocketSession.ping], + [send_text()][httpx_ws.AsyncWebSocketSession.send_text], + [send_bytes()][httpx_ws.AsyncWebSocketSession.send_bytes] + and [send_json()][httpx_ws.AsyncWebSocketSession.send_json] are preferred. + + Args: + event: The event to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send an event. + + event = await wsproto.events.Message(b"Hello!") + ws.send(event) + """ + import httpcore2 + + try: + data = self.connection.send(event) + await self.stream.write(data) + except httpcore2.WriteError as e: + await self.close(CloseReason.INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e + + async def send_text(self, data: str) -> None: + """ + Send a text message. + + Args: + data: The text to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send a text message. + + await ws.send_text("Hello!") + """ + event = wsproto.events.TextMessage(data=data) + await self.send(event) + + async def send_bytes(self, data: bytes) -> None: + """ + Send a bytes message. + + Args: + data: The data to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send a bytes message. + + await ws.send_bytes(b"Hello!") + """ + event = wsproto.events.BytesMessage(data=data) + await self.send(event) + + async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + """ + Send JSON data. + + Args: + data: + The data to send. Must be serializable by [json.dumps][json.dumps]. + mode: + The sending mode. Should either be `'text'` or `'bytes'`. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send JSON data. + + data = {"message": "Hello!"} + await ws.send_json(data) + """ + assert mode in ["text", "binary"] + serialized_data = json.dumps(data) + if mode == "text": + await self.send_text(serialized_data) + else: + await self.send_bytes(serialized_data.encode("utf-8")) + + async def receive(self, timeout: float | None = None) -> wsproto.events.Event: + """ + Receive an event from the server. + + Mainly useful to receive raw [wsproto.events.Event][wsproto.events.Event]. + Most of the time, [receive_text()][httpx_ws.AsyncWebSocketSession.receive_text], + [receive_bytes()][httpx_ws.AsyncWebSocketSession.receive_bytes], + and [receive_json()][httpx_ws.AsyncWebSocketSession.receive_json] are preferred. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + A raw [wsproto.events.Event][wsproto.events.Event]. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + + Examples: + Wait for an event until one is available. + + try: + event = await ws.receive() + except WebSocketDisconnect: + print("Connection closed") + + Wait for an event for 2 seconds. + + try: + event = await ws.receive(timeout=2.) + except TimeoutError: + print("No event received.") + except WebSocketDisconnect: + print("Connection closed") + """ + with anyio.fail_after(timeout): + event = await self._receive_event.receive() + if isinstance(event, HTTPXWSException): + raise event + if isinstance(event, wsproto.events.CloseConnection): + raise WebSocketDisconnect(event.code, event.reason) + return event + + async def receive_text(self, timeout: float | None = None) -> str: + """ + Receive text from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Text data. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event was not a text message. + + Examples: + Wait for text until available. + + try: + text = await ws.receive_text() + except WebSocketDisconnect: + print("Connection closed") + + Wait for text for 2 seconds. + + try: + event = await ws.receive_text(timeout=2.) + except TimeoutError: + print("No text received.") + except WebSocketDisconnect: + print("Connection closed") + """ + event = await self.receive(timeout) + if isinstance(event, wsproto.events.TextMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + async def receive_bytes(self, timeout: float | None = None) -> bytes: + """ + Receive bytes from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Bytes data. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event was not a bytes message. + + Examples: + Wait for bytes until available. + + try: + data = await ws.receive_bytes() + except WebSocketDisconnect: + print("Connection closed") + + Wait for bytes for 2 seconds. + + try: + data = await ws.receive_bytes(timeout=2.) + except TimeoutError: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ + event = await self.receive(timeout) + if isinstance(event, wsproto.events.BytesMessage): + return bytes(event.data) + raise WebSocketInvalidTypeReceived(event) + + async def receive_json(self, timeout: float | None = None, mode: JSONMode = "text") -> typing.Any: + """ + Receive JSON data from the server. + + The received data should be parseable by [json.loads][json.loads]. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + mode: + Receive mode. Should either be `'text'` or `'bytes'`. + + Returns: + Parsed JSON data. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event + didn't correspond to the specified mode. + + Examples: + Wait for data until available. + + try: + data = await ws.receive_json() + except WebSocketDisconnect: + print("Connection closed") + + Wait for data for 2 seconds. + + try: + data = await ws.receive_json(timeout=2.) + except TimeoutError: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ + assert mode in ["text", "binary"] + data: str | bytes + if mode == "text": + data = await self.receive_text(timeout) + elif mode == "binary": + data = await self.receive_bytes(timeout) + return json.loads(data) + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + """ + Close the WebSocket session. + + Internally, it'll send the + [CloseConnection][wsproto.events.CloseConnection] event. + + *This method is automatically called when exiting the context manager.* + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + + Examples: + Close the WebSocket session. + + await ws.close() + """ + import httpcore2 + + self._should_close.set() + if self.connection.state not in { + wsproto.connection.ConnectionState.LOCAL_CLOSING, + wsproto.connection.ConnectionState.CLOSED, + }: + event = wsproto.events.CloseConnection(code, reason) + data = self.connection.send(event) + try: + await self.stream.write(data) + except httpcore2.WriteError: + pass + await self.stream.aclose() + + async def _background_receive(self, max_bytes: int) -> None: + """ + Background task listening for data from the server. + + Internally, it'll: + + * Answer to Ping events. + * Acknowledge Pong events. + * Put other events in the [_events][_events] + queue that'll eventually be consumed by the user. + + Args: + max_bytes: The maximum chunk size to read at each iteration. + """ + import httpcore2 + + partial_message_buffer: str | bytes | None = None + try: + while not self._should_close.is_set(): + data = await self.stream.read(max_bytes=max_bytes) + self.connection.receive_data(data) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + data = self.connection.send(event.response()) + await self.stream.write(data) + continue + if isinstance(event, wsproto.events.Pong): + self._ping_manager.ack(event.payload) + continue + if isinstance(event, wsproto.events.CloseConnection): + self._should_close.set() + if isinstance(event, wsproto.events.Message): + # Unfinished message: bufferize + if not event.message_finished: + if partial_message_buffer is None: + partial_message_buffer = event.data + else: + partial_message_buffer += event.data + # Finished message but no buffer: just emit the event + elif partial_message_buffer is None: + await self._send_event.send(event) + # Finished message with buffer: emit the full event + else: + event_type = type(event) + full_message_event = event_type(partial_message_buffer + event.data) + partial_message_buffer = None + await self._send_event.send(full_message_event) + continue + await self._send_event.send(event) + except (httpcore2.ReadError, httpcore2.WriteError): + await self.close(CloseReason.INTERNAL_ERROR, "Stream error") + await self._send_event.send(WebSocketNetworkError()) + + async def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: float | None = None) -> None: + while not self._should_close.is_set(): + await anyio.sleep(interval_seconds) + if self._should_close.is_set(): + return + pong_callback = await self.ping() + if timeout_seconds is not None: + try: + with anyio.fail_after(timeout_seconds): + await pong_callback.wait() + except TimeoutError: + await self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") + await self._send_event.send(WebSocketNetworkError()) + + +def _get_headers( + subprotocols: list[str] | None, +) -> dict[str, typing.Any]: + headers = { + "connection": "upgrade", + "upgrade": "websocket", + "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), + "sec-websocket-version": "13", + } + if subprotocols is not None: + headers["sec-websocket-protocol"] = ", ".join(subprotocols) + return headers + + +@contextlib.contextmanager +def _connect_ws( + url: str, + client: Client, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, +) -> typing.Generator[WebSocketSession, None, None]: + with client.stream( + "GET", + url, + params=params, + headers=Headers(headers) | _get_headers(subprotocols), + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + with WebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + response=response, + ) as session: + yield session + + +@contextlib.contextmanager +def connect_ws( + url: str, + client: Client | None = None, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, +) -> typing.Generator[WebSocketSession, None, None]: + """ + Start a sync WebSocket session. + + It returns a context manager that'll automatically + call [close()][httpx_ws.WebSocketSession.close] when exiting. + + Args: + url: The WebSocket URL. + client: + HTTPX client to use. + If not provided, a default one will be initialized. + max_message_size_bytes: + Message size in bytes to receive from the server. + Defaults to 65 KiB. + queue_size: + Size of the queue where the received messages will be held + until they are consumed. + If the queue is full, the client will stop receive messages + from the server until the queue has room available. + Defaults to 512. + keepalive_ping_interval_seconds: + Interval at which the client will automatically send a Ping event + to keep the connection alive. Set it to `None` to disable this mechanism. + Defaults to 20 seconds. + keepalive_ping_timeout_seconds: + Maximum delay the client will wait for an answer to its Ping event. + If the delay is exceeded, + [WebSocketNetworkError][httpx_ws.WebSocketNetworkError] + will be raised and the connection closed. + Defaults to 20 seconds. + subprotocols: + Optional list of suprotocols to negotiate with the server. + params: + Query parameters to include in the handshake request. + headers: + Headers to include in the handshake request. + cookies: + Cookies to include in the handshake request. + auth: + Authentication to use for the handshake request. + follow_redirects: + Whether to follow redirects on the handshake request. + timeout: + Timeout configuration for the handshake request. + extensions: + Request extensions for the handshake request. + + Returns: + A [context manager][contextlib.AbstractContextManager] + for [WebSocketSession][httpx_ws.WebSocketSession]. + + Examples: + Without explicit HTTPX client. + + with connect_ws("http://localhost:8000/ws") as ws: + message = ws.receive_text() + print(message) + ws.send_text("Hello!") + + With explicit HTTPX client. + + with httpx2.Client() as client: + with connect_ws("http://localhost:8000/ws", client) as ws: + message = ws.receive_text() + print(message) + ws.send_text("Hello!") + """ + if client is None: + from .._client import Client + + owned_client: contextlib.AbstractContextManager[Client] = Client() + else: + owned_client = contextlib.nullcontext(client) + + with owned_client as client: + with _connect_ws( + url, + client=client, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as websocket: + yield websocket + + +@contextlib.asynccontextmanager +async def _aconnect_ws( + url: str, + client: AsyncClient, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, +) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: + async with client.stream( + "GET", + url, + params=params, + headers=Headers(headers) | _get_headers(subprotocols), + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + async with AsyncWebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + response=response, + ) as session: + yield session + + +@contextlib.asynccontextmanager +async def aconnect_ws( + url: str, + client: AsyncClient | None = None, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, +) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: + """ + Start an async WebSocket session. + + It returns an async context manager that'll automatically + call [close()][httpx_ws.AsyncWebSocketSession.close] when exiting. + + Args: + url: The WebSocket URL. + client: + HTTPX client to use. + If not provided, a default one will be initialized. + max_message_size_bytes: + Message size in bytes to receive from the server. + Defaults to 65 KiB. + queue_size: + Size of the queue where the received messages will be held + until they are consumed. + If the queue is full, the client will stop receive messages + from the server until the queue has room available. + Defaults to 512. + keepalive_ping_interval_seconds: + Interval at which the client will automatically send a Ping event + to keep the connection alive. Set it to `None` to disable this mechanism. + Defaults to 20 seconds. + keepalive_ping_timeout_seconds: + Maximum delay the client will wait for an answer to its Ping event. + If the delay is exceeded, + [WebSocketNetworkError][httpx_ws.WebSocketNetworkError] + will be raised and the connection closed. + Defaults to 20 seconds. + subprotocols: + Optional list of suprotocols to negotiate with the server. + params: + Query parameters to include in the handshake request. + headers: + Headers to include in the handshake request. + cookies: + Cookies to include in the handshake request. + auth: + Authentication to use for the handshake request. + follow_redirects: + Whether to follow redirects on the handshake request. + timeout: + Timeout configuration for the handshake request. + extensions: + Request extensions for the handshake request. + + Returns: + An [async context manager][contextlib.AbstractAsyncContextManager] + for [AsyncWebSocketSession][httpx_ws.AsyncWebSocketSession]. + + Examples: + Without explicit HTTPX client. + + async with aconnect_ws("http://localhost:8000/ws") as ws: + message = await ws.receive_text() + print(message) + await ws.send_text("Hello!") + + With explicit HTTPX client. + + async with httpx2.AsyncClient() as client: + async with aconnect_ws("http://localhost:8000/ws", client) as ws: + message = await ws.receive_text() + print(message) + await ws.send_text("Hello!") + """ + if client is None: + from .._client import AsyncClient + + owned_client: contextlib.AbstractAsyncContextManager[AsyncClient] = AsyncClient() + else: + owned_client = contextlib.nullcontext(client) + + async with owned_client as client: + async with _aconnect_ws( + url, + client=client, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as websocket: + yield websocket diff --git a/src/httpx2/httpx2/_websockets/defaults.py b/src/httpx2/httpx2/_websockets/defaults.py new file mode 100644 index 00000000..6adc202e --- /dev/null +++ b/src/httpx2/httpx2/_websockets/defaults.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536 +DEFAULT_QUEUE_SIZE = 512 +DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = 20.0 +DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 + +WS_EXTRA_INSTALL_MESSAGE = "WebSocket support requires the `wsproto` package. Install it with `pip install httpx2[ws]`." diff --git a/src/httpx2/httpx2/_websockets/exceptions.py b/src/httpx2/httpx2/_websockets/exceptions.py new file mode 100644 index 00000000..762643aa --- /dev/null +++ b/src/httpx2/httpx2/_websockets/exceptions.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import typing + +if typing.TYPE_CHECKING: + import wsproto + + from .._models import Response + + +class HTTPXWSException(Exception): + """ + Base exception class for HTTPX WS. + """ + + +class WebSocketUpgradeError(HTTPXWSException): + """ + Raised when the initial connection didn't correctly upgrade to a WebSocket session. + """ + + def __init__(self, response: Response) -> None: + self.response = response + + +class WebSocketDisconnect(HTTPXWSException): + """ + Raised when the server closed the WebSocket session. + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + """ + + def __init__(self, code: int = 1000, reason: str | None = None) -> None: + self.code = code + self.reason = reason or "" + + +class WebSocketInvalidTypeReceived(HTTPXWSException): + """ + Raised when a event is not of the expected type. + """ + + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + +class WebSocketNetworkError(HTTPXWSException): + """ + Raised when a network error occured, + typically if the underlying stream has closed or timeout. + """ diff --git a/src/httpx2/httpx2/_websockets/ping.py b/src/httpx2/httpx2/_websockets/ping.py new file mode 100644 index 00000000..b9116cb0 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/ping.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import secrets +import threading + +import anyio + + +class PingManagerBase: + def _generate_id(self) -> bytes: + return secrets.token_bytes() + + +class PingManager(PingManagerBase): + def __init__(self) -> None: + self._pings: dict[bytes, threading.Event] = {} + + def create(self, ping_id: bytes | None = None) -> tuple[bytes, threading.Event]: + ping_id = self._generate_id() if not ping_id else ping_id + event = threading.Event() + self._pings[ping_id] = event + return ping_id, event + + def ack(self, ping_id: bytes | bytearray) -> None: + event = self._pings.pop(bytes(ping_id)) + event.set() + + +class AsyncPingManager(PingManagerBase): + def __init__(self) -> None: + self._pings: dict[bytes, anyio.Event] = {} + + def create(self, ping_id: bytes | None = None) -> tuple[bytes, anyio.Event]: + ping_id = self._generate_id() if not ping_id else ping_id + event = anyio.Event() + self._pings[ping_id] = event + return ping_id, event + + def ack(self, ping_id: bytes | bytearray) -> None: + event = self._pings.pop(bytes(ping_id)) + event.set() diff --git a/src/httpx2/httpx2/_websockets/py.typed b/src/httpx2/httpx2/_websockets/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py new file mode 100644 index 00000000..621bd55c --- /dev/null +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import contextlib +import queue +import typing +from concurrent.futures import Future + +import anyio +import wsproto +from wsproto.frame_protocol import CloseReason + +from .._models import Request, Response +from .._transports.asgi import ASGITransport, _ASGIApp +from .._types import AsyncByteStream +from .exceptions import WebSocketDisconnect + +Scope = dict[str, typing.Any] +Message = dict[str, typing.Any] +Receive = typing.Callable[[], typing.Awaitable[Message]] +Send = typing.Callable[[Scope], typing.Coroutine[None, None, None]] +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Coroutine[None, None, None]] + + +class ASGIWebSocketTransportError(Exception): + pass + + +class UnhandledASGIMessageType(ASGIWebSocketTransportError): + def __init__(self, message: Message) -> None: + self.message = message + + +class UnhandledWebSocketEvent(ASGIWebSocketTransportError): + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + +class ASGIWebSocketAsyncNetworkStream: + def __init__(self, app: ASGIApp, scope: Scope) -> None: + self.app = app + self.scope = scope + self._receive_queue: queue.Queue[Message] = queue.Queue() + self._send_queue: queue.Queue[Message] = queue.Queue() + self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER) + self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) + + async def __aenter__( + self, + ) -> tuple[ASGIWebSocketAsyncNetworkStream, bytes]: + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context(anyio.from_thread.start_blocking_portal("asyncio")) + _: Future[None] = self.portal.start_task_soon(self._run) + + await self.send({"type": "websocket.connect"}) + message = await self.receive() + + if message["type"] == "websocket.close": + await self.aclose() + raise WebSocketDisconnect(message["code"], message.get("reason")) + + assert message["type"] == "websocket.accept" + return self, self._build_accept_response(message) + + async def __aexit__(self, *args: typing.Any) -> None: + await self.aclose() + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + message: Message = await self.receive(timeout=timeout) + type = message["type"] + + if type not in {"websocket.send", "websocket.close"}: + raise UnhandledASGIMessageType(message) + + event: wsproto.events.Event + if type == "websocket.send": + data_str: str | None = message.get("text") + if data_str is not None: + event = wsproto.events.TextMessage(data_str) + data_bytes: bytes | None = message.get("bytes") + if data_bytes is not None: + event = wsproto.events.BytesMessage(data_bytes) + elif type == "websocket.close": + event = wsproto.events.CloseConnection(message["code"], message["reason"]) + + return self.connection.send(event) + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Request): + pass + elif isinstance(event, wsproto.events.CloseConnection): + await self.send( + { + "type": "websocket.close", + "code": event.code, + "reason": event.reason, + } + ) + elif isinstance(event, wsproto.events.TextMessage): + await self.send({"type": "websocket.receive", "text": event.data}) + elif isinstance(event, wsproto.events.BytesMessage): + await self.send({"type": "websocket.receive", "bytes": event.data}) + else: + raise UnhandledWebSocketEvent(event) + + async def aclose(self) -> None: + await self.send({"type": "websocket.close"}) + self.exit_stack.close() + + async def send(self, message: Message) -> None: + self._receive_queue.put(message) + + async def receive(self, timeout: float | None = None) -> Message: + while self._send_queue.empty(): + await anyio.sleep(0) + return self._send_queue.get(timeout=timeout) + + async def _run(self) -> None: + """ + The sub-thread in which the websocket session runs. + """ + scope = self.scope + receive = self._asgi_receive + send = self._asgi_send + try: + await self.app(scope, receive, send) + except Exception as e: + message = { + "type": "websocket.close", + "code": CloseReason.INTERNAL_ERROR, + "reason": str(e), + } + await self._asgi_send(message) + + async def _asgi_receive(self) -> Message: + while self._receive_queue.empty(): + await anyio.sleep(0) + return self._receive_queue.get() + + async def _asgi_send(self, message: Message) -> None: + self._send_queue.put(message) + + def _build_accept_response(self, message: Message) -> bytes: + subprotocol = message.get("subprotocol", None) + headers = message.get("headers", []) + return self.connection.send( + wsproto.events.AcceptConnection( + subprotocol=subprotocol, + extra_headers=headers, + ) + ) + + +class ASGIWebSocketTransport(ASGITransport): + def __init__( + self, + app: _ASGIApp, + raise_app_exceptions: bool = True, + root_path: str = "", + client: tuple[str, int] = ("127.0.0.1", 123), + ) -> None: + super().__init__(app, raise_app_exceptions, root_path, client) + self.exit_stack: contextlib.AsyncExitStack | None = None + + async def handle_async_request(self, request: Request) -> Response: + scheme = request.url.scheme + headers = request.headers + + if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": + subprotocols: list[str] = [] + if (subprotocols_header := headers.get("sec-websocket-protocol")) is not None: + subprotocols = subprotocols_header.split(",") + + scope = { + "type": "websocket", + "path": request.url.path, + "raw_path": request.url.raw_path, + "root_path": self.root_path, + "scheme": scheme, + "query_string": request.url.query, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "client": self.client, + "server": (request.url.host, request.url.port), + "subprotocols": subprotocols, + } + return await self._handle_ws_request(request, scope) + + return await super().handle_async_request(request) + + async def _handle_ws_request( + self, + request: Request, + scope: Scope, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + self.scope = scope + self.exit_stack = contextlib.AsyncExitStack() + stream, accept_response = await self.exit_stack.enter_async_context( + ASGIWebSocketAsyncNetworkStream(self.app, self.scope) # type: ignore[arg-type] + ) + + accept_response_lines = accept_response.decode("utf-8").splitlines() + headers = [ + typing.cast(tuple[str, str], line.split(": ", 1)) + for line in accept_response_lines[1:] + if line.strip() != "" + ] + + return Response( + status_code=101, + headers=headers, + extensions={"network_stream": stream}, + ) + + async def aclose(self) -> None: + if self.exit_stack: + await self.exit_stack.aclose() diff --git a/src/httpx2/pyproject.toml b/src/httpx2/pyproject.toml index dc194f7f..2c3c54c0 100644 --- a/src/httpx2/pyproject.toml +++ b/src/httpx2/pyproject.toml @@ -67,6 +67,9 @@ http2 = [ socks = [ "socksio==1.*", ] +ws = [ + "wsproto>=1.2", +] # TODO(Marcelo): Remove when Python 3.13 reaches EOL. zstd = [ "zstandard>=0.18.0; python_version <= '3.13'", diff --git a/tests/httpx2/conftest.py b/tests/httpx2/conftest.py index 156c5e3a..5db7aa68 100644 --- a/tests/httpx2/conftest.py +++ b/tests/httpx2/conftest.py @@ -274,6 +274,6 @@ def serve_in_thread(server: TestServer) -> typing.Iterator[TestServer]: @pytest.fixture(scope="session") def server(free_tcp_port_factory: typing.Callable[[], int]) -> typing.Iterator[TestServer]: - config = Config(app=app, lifespan="off", loop="asyncio", port=free_tcp_port_factory()) + config = Config(app=app, lifespan="off", loop="asyncio", ws="none", port=free_tcp_port_factory()) server = TestServer(config=config) yield from serve_in_thread(server) diff --git a/tests/httpx2/test_exported_members.py b/tests/httpx2/test_exported_members.py index afa4d8e0..cd711074 100644 --- a/tests/httpx2/test_exported_members.py +++ b/tests/httpx2/test_exported_members.py @@ -3,7 +3,9 @@ def test_all_imports_are_exported() -> None: included_private_members = ["__description__", "__title__", "__version__"] - assert httpx2.__all__ == sorted( - (member for member in vars(httpx2).keys() if not member.startswith("_") or member in included_private_members), - key=str.casefold, - ) + # WebSocket members are exported lazily through `__getattr__` so they only + # appear in `vars(httpx2)` once accessed; force them in for the comparison. + lazy_members = httpx2._WEBSOCKET_NAMES + exported = {member for member in vars(httpx2) if not member.startswith("_") or member in included_private_members} + assert set(httpx2.__all__) == exported | lazy_members + assert httpx2.__all__ == sorted(httpx2.__all__, key=str.casefold) diff --git a/tests/httpx2/websockets/__init__.py b/tests/httpx2/websockets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py new file mode 100644 index 00000000..3699b550 --- /dev/null +++ b/tests/httpx2/websockets/conftest.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import contextlib +import pathlib +import queue +import tempfile +import time +import typing +from unittest.mock import MagicMock + +import pytest +import uvicorn +from anyio.from_thread import start_blocking_portal +from starlette.applications import Starlette +from starlette.routing import WebSocketRoute +from starlette.websockets import WebSocket + +WebSocketEndpoint = typing.Callable[[WebSocket], typing.Awaitable[None]] + + +@pytest.fixture +def on_receive_message() -> MagicMock: + return MagicMock() + + +@pytest.fixture(params=("wsproto", "websockets-sansio")) +def websocket_implementation(request: pytest.FixtureRequest) -> typing.Literal["wsproto", "websockets-sansio"]: + return request.param # type: ignore[no-any-return] + + +class ServerFactoryFixture(typing.Protocol): + def __call__(self, endpoint: WebSocketEndpoint) -> contextlib.AbstractContextManager[str]: ... + + +@pytest.fixture +def server_factory(websocket_implementation: typing.Literal["wsproto", "websockets-sansio"]) -> ServerFactoryFixture: + @contextlib.contextmanager + def _server_factory(endpoint: WebSocketEndpoint) -> typing.Iterator[str]: + shutdown_queue: queue.Queue[bool] = queue.Queue() + + def create_app() -> Starlette: + routes = [WebSocketRoute("/ws", endpoint=endpoint)] + return Starlette(routes=routes) + + def create_server(app: Starlette, socket: str) -> uvicorn.Server: + config = uvicorn.Config(app, uds=socket, ws=websocket_implementation, lifespan="off") + return uvicorn.Server(config) + + def on_server_stopped(_task: object) -> None: + shutdown_queue.put(True) + + with start_blocking_portal(backend="asyncio") as portal: + with tempfile.TemporaryDirectory() as socket_directory: + socket = str(pathlib.Path(socket_directory) / "socket.sock") + app = create_app() + server = create_server(app, socket) + task = portal.start_task_soon(server.serve) + task.add_done_callback(on_server_stopped) + while not server.started and not task.done(): + time.sleep(0.01) + if task.done() and task.exception() is not None: # pragma: no cover + raise typing.cast(BaseException, task.exception()) + yield socket + server.should_exit = True + shutdown_queue.get(True) + + return _server_factory diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py new file mode 100644 index 00000000..94143e84 --- /dev/null +++ b/tests/httpx2/websockets/test_api.py @@ -0,0 +1,839 @@ +from __future__ import annotations + +import contextlib +import queue +import threading +import time +from unittest.mock import MagicMock, call, patch + +import anyio +import pytest +import wsproto +from starlette.websockets import WebSocket, WebSocketDisconnect as StarletteWebSocketDisconnect + +import httpcore2 as httpcore +import httpx2 as httpx +from httpcore2 import AsyncNetworkStream, NetworkStream +from httpx2._websockets.api import ( + AsyncWebSocketSession, + JSONMode, + WebSocketSession, + aconnect_ws, + connect_ws, +) +from httpx2._websockets.exceptions import ( + WebSocketDisconnect, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, +) +from tests.httpx2.websockets.conftest import ServerFactoryFixture + + +@pytest.mark.anyio +async def test_upgrade_error() -> None: + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(400) + + with httpx.Client(base_url="http://localhost:8000", transport=httpx.MockTransport(handler)) as client: + with pytest.raises(WebSocketUpgradeError): + with connect_ws("http://socket/ws", client): + pass # pragma: no cover + + async with httpx.AsyncClient(base_url="http://localhost:8000", transport=httpx.MockTransport(handler)) as aclient: + with pytest.raises(WebSocketUpgradeError): + async with aconnect_ws("http://socket/ws", aclient): + pass # pragma: no cover + + +@pytest.mark.anyio +class TestSend: + async def test_send_error(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: # pragma: no cover + time.sleep(0.1) + raise httpcore.ReadError() # pragma: no cover + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise httpcore.WriteError() + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession(stream) as websocket_session: + websocket_session.send(wsproto.events.Ping()) + + async def test_async_send_error(self) -> None: + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: # pragma: no cover + await anyio.sleep(0.1) + raise httpcore.ReadError() # pragma: no cover + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise httpcore.WriteError() + + async def aclose(self) -> None: + self._should_close = True + + stream = AsyncMockNetworkStream() + with pytest.raises(WebSocketNetworkError): + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.send(wsproto.events.Ping()) + + async def test_send( + self, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_text() + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client) as ws: + ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) + except WebSocketDisconnect: # pragma: no cover + pass + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + await aws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) + except WebSocketDisconnect: # pragma: no cover + pass + + on_receive_message.assert_has_calls([call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")]) + + async def test_send_text( + self, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_text() + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client) as ws: + ws.send_text("CLIENT_MESSAGE") + except WebSocketDisconnect: # pragma: no cover + pass + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + await aws.send_text("CLIENT_MESSAGE") + except WebSocketDisconnect: # pragma: no cover + pass + + on_receive_message.assert_has_calls([call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")]) + + async def test_send_bytes( + self, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_bytes() + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client) as ws: + ws.send_bytes(b"CLIENT_MESSAGE") + except WebSocketDisconnect: # pragma: no cover + pass + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + await aws.send_bytes(b"CLIENT_MESSAGE") + except WebSocketDisconnect: # pragma: no cover + pass + + on_receive_message.assert_has_calls([call(b"CLIENT_MESSAGE"), call(b"CLIENT_MESSAGE")]) + + @pytest.mark.parametrize("mode", ["text", "binary"]) + async def test_send_json( + self, + mode: JSONMode, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_json(mode=mode) + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client) as ws: + ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + except WebSocketDisconnect: # pragma: no cover + pass + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + await aws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + except WebSocketDisconnect: # pragma: no cover + pass + + on_receive_message.assert_has_calls([call({"message": "CLIENT_MESSAGE"}), call({"message": "CLIENT_MESSAGE"})]) + + +@pytest.mark.anyio +class TestReceive: + async def test_receive_error(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise httpcore.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession(stream) as websocket_session: + websocket_session.receive() + + async def test_async_receive_error(self) -> None: + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise httpcore.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + pass + + stream = AsyncMockNetworkStream() + with pytest.raises(WebSocketNetworkError): + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.receive() + + async def test_receive(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client) as ws: + event = ws.receive() + assert isinstance(event, wsproto.events.TextMessage) + assert event.data == "SERVER_MESSAGE" + except WebSocketDisconnect: # pragma: no cover + pass + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + event = await aws.receive() + assert isinstance(event, wsproto.events.TextMessage) + assert event.data == "SERVER_MESSAGE" + except WebSocketDisconnect: # pragma: no cover + pass + + @pytest.mark.parametrize( + "full_message,send_method", + [ + pytest.param(b"A" * 1024 * 4, "send_bytes", id="bytes"), + pytest.param("A" * 1024 * 4, "send_text", id="text"), + ], + ) + async def test_receive_oversized_message( + self, + full_message: str | bytes, + send_method: str, + server_factory: ServerFactoryFixture, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + method = getattr(websocket, send_method) + await method(full_message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client, max_message_size_bytes=1024) as ws: + event = ws.receive() + assert isinstance(event, wsproto.events.Message) + assert event.data == full_message + except WebSocketDisconnect: # pragma: no cover + pass + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + try: + async with aconnect_ws( + "http://socket/ws", + aclient, + keepalive_ping_interval_seconds=None, + ) as aws: + event = await aws.receive() + assert isinstance(event, wsproto.events.Message) + assert event.data == full_message + except WebSocketDisconnect: # pragma: no cover + pass + + async def test_receive_text(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client) as ws: + data = ws.receive_text() + assert data == "SERVER_MESSAGE" + except WebSocketDisconnect: # pragma: no cover + pass + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + data = await aws.receive_text() + assert data == "SERVER_MESSAGE" + except WebSocketDisconnect: # pragma: no cover + pass + + async def test_receive_text_invalid_type(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_bytes(b"SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client) as ws: + with pytest.raises(WebSocketInvalidTypeReceived): + ws.receive_text() + except WebSocketDisconnect: # pragma: no cover + pass + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + with pytest.raises(WebSocketInvalidTypeReceived): + await aws.receive_text() + except WebSocketDisconnect: # pragma: no cover + pass + + async def test_receive_bytes(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_bytes(b"SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client) as ws: + data = ws.receive_bytes() + assert data == b"SERVER_MESSAGE" + except WebSocketDisconnect: # pragma: no cover + pass + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + data = await aws.receive_bytes() + assert data == b"SERVER_MESSAGE" + except WebSocketDisconnect: # pragma: no cover + pass + + async def test_receive_bytes_invalid_type(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws("http://socket/ws", client) as ws: + with pytest.raises(WebSocketInvalidTypeReceived): + ws.receive_bytes() + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + async with aconnect_ws("http://socket/ws", aclient) as aws: + with pytest.raises(WebSocketInvalidTypeReceived): + await aws.receive_bytes() + + @pytest.mark.parametrize("mode", ["text", "binary"]) + async def test_receive_json(self, mode: JSONMode, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_json({"message": "SERVER_MESSAGE"}, mode=mode) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client) as ws: + data = ws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} + except WebSocketDisconnect: # pragma: no cover + pass + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + data = await aws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} + except WebSocketDisconnect: # pragma: no cover + pass + + +@pytest.mark.anyio +class TestReceivePing: + async def test_receive_ping(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self.events_to_send = [ + wsproto.events.Ping(b"SERVER_PING"), + wsproto.events.CloseConnection(1000), + ] + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + event = self.events_to_send.pop(0) + return self.connection.send(event) + except IndexError: # pragma: no cover + raise httpcore.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.connection.receive_data(buffer) + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with WebSocketSession(stream): + await anyio.sleep(0.1) + + received_events = list(stream.connection.events()) + assert received_events == [ + wsproto.events.Pong(b"SERVER_PING"), + wsproto.events.CloseConnection(1000, ""), + ] + + async def test_async_receive_ping(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self.events_to_send = [ + wsproto.events.Ping(b"SERVER_PING"), + wsproto.events.CloseConnection(1000), + ] + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + event = self.events_to_send.pop(0) + return self.connection.send(event) + except IndexError: # pragma: no cover + raise httpcore.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.connection.receive_data(buffer) + + async def aclose(self) -> None: + pass + + stream = MockAsyncNetworkStream() + async with AsyncWebSocketSession(stream): + await anyio.sleep(0.1) + + received_events = list(stream.connection.events()) + assert received_events == [ + wsproto.events.Pong(b"SERVER_PING"), + wsproto.events.CloseConnection(1000, ""), + ] + + +@pytest.mark.anyio +class TestKeepalivePing: + async def test_keepalive_ping(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + self.ping_received = 0 + self.ping_answered = 0 + self.events_to_send: queue.Queue[wsproto.events.Event] = queue.Queue() + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: + try: + event = self.events_to_send.get_nowait() + self.ping_answered += 1 + return self.connection.send(event) + except queue.Empty: + pass + raise httpcore.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + self.ping_received += 1 + self.events_to_send.put(event.response()) + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + with WebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ): + await anyio.sleep(0.2) + + assert stream.ping_received >= 1 + assert stream.ping_answered >= 1 + + async def test_keepalive_ping_timeout(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: # pragma: no cover + time.sleep(0.1) + raise httpcore.ReadError() # pragma: no cover + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ) as websocket_session: + websocket_session.receive() + + @pytest.mark.flaky(max_runs=5, min_passes=1) + async def test_async_keepalive_ping(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + self.ping_received = 0 + self.ping_answered = 0 + ( + self.send_events, + self.receive_events, + ) = anyio.create_memory_object_stream[wsproto.events.Event]() + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: + try: + event = self.receive_events.receive_nowait() + self.ping_answered += 1 + return self.connection.send(event) + except anyio.WouldBlock: + await anyio.sleep(0.1) + raise httpcore.ReadError() # pragma: no cover + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + self.ping_received += 1 + await self.send_events.send(event.response()) + + async def aclose(self) -> None: + self._should_close = True + await self.send_events.aclose() + await self.receive_events.aclose() + + stream = MockAsyncNetworkStream() + async with AsyncWebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ): + await anyio.sleep(0.3) + + assert stream.ping_received >= 1 + assert stream.ping_answered >= 1 + + async def test_async_keepalive_ping_skips_when_closing(self) -> None: + writes = 0 + + class MockAsyncNetworkStream(AsyncNetworkStream): + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + return b"" # pragma: no cover + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + nonlocal writes + writes += 1 # pragma: no cover + + async def aclose(self) -> None: ... # pragma: no cover + + session = AsyncWebSocketSession(MockAsyncNetworkStream(), keepalive_ping_interval_seconds=0.2) + async with anyio.create_task_group() as tg: + tg.start_soon(session._background_keepalive_ping, 0.2) + await anyio.sleep(0.05) # Let the loop enter its sleep before closing. + session._should_close.set() + + assert writes == 0 + + async def test_async_keepalive_ping_timeout(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + self._should_close = False + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: # pragma: no cover + await anyio.sleep(0.1) + raise httpcore.ReadError() # pragma: no cover + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + self._should_close = True + + stream = MockAsyncNetworkStream() + with pytest.raises(WebSocketNetworkError): + async with AsyncWebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ) as websocket_session: + await websocket_session.receive() + + +@pytest.mark.anyio +async def test_ping_pong(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + try: + await websocket.receive_text() + except StarletteWebSocketDisconnect: + pass + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws("http://socket/ws", client) as ws: + ping_callback = ws.ping() + result = ping_callback.wait() + assert result is True + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + async with aconnect_ws("http://socket/ws", aclient) as aws: + aping_callback = await aws.ping() + await aping_callback.wait() + assert aping_callback.is_set() + + +@pytest.mark.anyio +async def test_send_close(server_factory: ServerFactoryFixture, on_receive_message: MagicMock) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + try: + await websocket.receive_text() + except StarletteWebSocketDisconnect as e: + print(e) + on_receive_message(e.code, e.reason) + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws("http://socket/ws", client) as ws: + ws.close(code=1001, reason="CLOSE_REASON") + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + async with aconnect_ws("http://socket/ws", aclient) as aws: + await aws.close(code=1001, reason="CLOSE_REASON") + + on_receive_message.assert_has_calls([call(1001, "CLOSE_REASON"), call(1001, "CLOSE_REASON")]) + + +@pytest.mark.anyio +async def test_receive_close(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws("http://socket/ws", client) as ws: + with pytest.raises(WebSocketDisconnect): + ws.receive() + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + async with aconnect_ws("http://socket/ws", aclient) as aws: + with pytest.raises(WebSocketDisconnect): + await aws.receive() + + +@pytest.mark.anyio +async def test_default_httpx_client() -> None: + mock_context = contextlib.ExitStack() + with patch("httpx2._websockets.api._connect_ws", return_value=mock_context) as mock_connect_ws: + with connect_ws("http://socket/ws"): + pass + mock_connect_ws.assert_called_once() + httpx_client = mock_connect_ws.call_args[1]["client"] + assert isinstance(httpx_client, httpx.Client) + assert httpx_client.is_closed + + mock_async_context = contextlib.AsyncExitStack() + with patch("httpx2._websockets.api._aconnect_ws", return_value=mock_async_context) as mock_aconnect_ws: + async with aconnect_ws("http://socket/ws"): + pass + mock_aconnect_ws.assert_called_once() + httpx_client = mock_aconnect_ws.call_args[1]["client"] + assert isinstance(httpx_client, httpx.AsyncClient) + assert httpx_client.is_closed + + +@pytest.mark.anyio +async def test_subprotocol_and_response() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.headers["sec-websocket-protocol"] == "custom_protocol, unsupported_protocol" + + return httpx.Response( + 101, + headers={"sec-websocket-protocol": "custom_protocol"}, + extensions={"network_stream": MagicMock(spec=NetworkStream)}, + ) + + def async_handler(request: httpx.Request) -> httpx.Response: + assert request.headers["sec-websocket-protocol"] == "custom_protocol, unsupported_protocol" + + return httpx.Response( + 101, + headers={"sec-websocket-protocol": "custom_protocol"}, + extensions={"network_stream": MagicMock(spec=AsyncNetworkStream)}, + ) + + with httpx.Client(base_url="http://localhost:8000", transport=httpx.MockTransport(handler)) as client: + with connect_ws( + "http://socket/ws", + client, + subprotocols=["custom_protocol", "unsupported_protocol"], + ) as ws: + assert isinstance(ws.response, httpx.Response) + assert ws.subprotocol == "custom_protocol" + assert ws.response.headers["sec-websocket-protocol"] == ws.subprotocol + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=httpx.MockTransport(async_handler) + ) as aclient: + async with aconnect_ws( + "http://socket/ws", + aclient, + subprotocols=["custom_protocol", "unsupported_protocol"], + ) as aws: + assert isinstance(aws.response, httpx.Response) + assert aws.subprotocol == "custom_protocol" + assert aws.response.headers["sec-websocket-protocol"] == aws.subprotocol + + +@pytest.mark.anyio +async def test_threads_wont_hang(server_factory: ServerFactoryFixture) -> None: + """ + Check that all threads spawned in WebSocketSession are properly terminated during + a series of messages exchange. This used to be the cause of a memory leak in the + connect_ws client, see https://github.com/frankie567/httpx-ws/issues/76. + """ + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + for _ in range(50): + await websocket.send_text("SERVER_MESSAGE") + await websocket.receive_text() + await websocket.close() + + def session_threads() -> set[threading.Thread]: + return set(threading.enumerate()) - threads_before + + def wait_for_session_threads(expected: int) -> None: + for _ in range(100): + if len(session_threads()) == expected: + return + time.sleep(0.05) + assert len(session_threads()) == expected # pragma: no cover + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + threads_before = set(threading.enumerate()) + with connect_ws("http://socket/ws", client, keepalive_ping_interval_seconds=None) as ws: + for _ in range(50): + ws.receive() + ws.send_text("CLIENT_MESSAGE") + wait_for_session_threads(2) + wait_for_session_threads(0) diff --git a/tests/httpx2/websockets/test_high_level.py b/tests/httpx2/websockets/test_high_level.py new file mode 100644 index 00000000..ece658bf --- /dev/null +++ b/tests/httpx2/websockets/test_high_level.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import contextlib +from unittest.mock import patch + +import anyio +import pytest +import wsproto +from starlette.websockets import WebSocket + +import httpx2 as httpx +from httpcore2 import AsyncNetworkStream +from httpx2._websockets.api import AsyncWebSocketSession, WebSocketSession +from httpx2._websockets.transport import ASGIWebSocketTransport +from tests.httpx2.websockets.conftest import ServerFactoryFixture + + +@pytest.mark.parametrize( + ("name", "expected"), + [ + ("WebSocketSession", WebSocketSession), + ("AsyncWebSocketSession", AsyncWebSocketSession), + ("ASGIWebSocketTransport", ASGIWebSocketTransport), + ("WebSocketDisconnect", httpx.WebSocketDisconnect), + ], +) +def test_top_level_names_are_lazily_exported(name: str, expected: object) -> None: + assert getattr(httpx, name) is expected + + +def test_top_level_websocket_uses_a_dedicated_client() -> None: + mock_context = contextlib.ExitStack() + with patch("httpx2._websockets.api._connect_ws", return_value=mock_context) as mock_connect_ws: + with httpx.websocket("http://socket/ws"): + pass + mock_connect_ws.assert_called_once() + client = mock_connect_ws.call_args[1]["client"] + assert isinstance(client, httpx.Client) + assert client.is_closed + + +@pytest.mark.anyio +async def test_client_websocket(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.send_text("SERVER_MESSAGE") + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + assert ws.receive_text() == "SERVER_MESSAGE" + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + assert await aws.receive_text() == "SERVER_MESSAGE" + + +@pytest.mark.anyio +async def test_client_websocket_forwards_request_params(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.send_text(websocket.headers.get("x-token", "")) + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws", headers={"x-token": "secret"}) as ws: + assert ws.receive_text() == "secret" + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws", headers={"x-token": "secret"}) as aws: + assert await aws.receive_text() == "secret" + + +@pytest.mark.anyio +async def test_async_receive_reassembles_fragmented_message() -> None: + server = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + fragments = server.send(wsproto.events.TextMessage("FRAG", message_finished=False)) + fragments += server.send(wsproto.events.TextMessage("MEN", message_finished=False)) + fragments += server.send(wsproto.events.TextMessage("TED", message_finished=True)) + + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self._sent = False + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + if self._sent: + await anyio.sleep_forever() + self._sent = True + return fragments + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: ... + + async def aclose(self) -> None: ... + + async with AsyncWebSocketSession(AsyncMockNetworkStream()) as ws: + assert await ws.receive_text() == "FRAGMENTED" diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py new file mode 100644 index 00000000..e14365f3 --- /dev/null +++ b/tests/httpx2/websockets/test_transport.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +import base64 +import secrets +from typing import Any + +import pytest +import wsproto +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import PlainTextResponse +from starlette.routing import Route, WebSocketRoute +from starlette.websockets import WebSocket + +import httpx2 as httpx +from httpx2._websockets.api import aconnect_ws +from httpx2._websockets.exceptions import WebSocketDisconnect +from httpx2._websockets.transport import ( + ASGIWebSocketAsyncNetworkStream, + ASGIWebSocketTransport, + Receive, + Scope, + Send, + UnhandledASGIMessageType, + UnhandledWebSocketEvent, +) + + +@pytest.fixture +def websocket_request_headers() -> dict[str, str]: + return { + "connection": "upgrade", + "upgrade": "websocket", + "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), + "sec-websocket-version": "13", + } + + +@pytest.fixture +def scope(websocket_request_headers: dict[str, str]) -> Scope: + return { + "type": "websocket", + "path": "/ws", + "raw_path": "/ws", + "root_path": "/", + "scheme": "ws", + "headers": [ + ("host", "localhost"), + *websocket_request_headers.items(), + ], + "subprotocols": [], + "server": ("localhost", 8000), + } + + +@pytest.mark.anyio +class TestASGIWebSocketAsyncNetworkStream: + async def test_write(self, scope: Scope) -> None: + received_messages = [] + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + message = await receive() + received_messages.append(message) + while message["type"] != "websocket.close": + message = await receive() + received_messages.append(message) + + connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + async with ASGIWebSocketAsyncNetworkStream(app, scope) as (stream, _): + text_event = wsproto.events.TextMessage("CLIENT_MESSAGE") + await stream.write(connection.send(text_event)) + + bytes_event = wsproto.events.BytesMessage(b"CLIENT_MESSAGE") + await stream.write(connection.send(bytes_event)) + + close_event = wsproto.events.CloseConnection(1000) + await stream.write(connection.send(close_event)) + + assert received_messages == [ + {"type": "websocket.connect"}, + {"type": "websocket.receive", "text": "CLIENT_MESSAGE"}, + {"type": "websocket.receive", "bytes": b"CLIENT_MESSAGE"}, + {"type": "websocket.close", "code": 1000, "reason": ""}, + ] + + async def test_write_unhandled_event(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await receive() + + connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + async with ASGIWebSocketAsyncNetworkStream(app, scope) as (stream, _): + with pytest.raises(UnhandledWebSocketEvent): + ping_event = wsproto.events.Ping(b"PING") + await stream.write(connection.send(ping_event)) + + async def test_read(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await send({"type": "websocket.send", "text": "SERVER_MESSAGE"}) + await send({"type": "websocket.send", "bytes": b"SERVER_MESSAGE"}) + await send({"type": "websocket.close", "code": 1000, "reason": ""}) + + connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + events = [] + async with ASGIWebSocketAsyncNetworkStream(app, scope) as (stream, _): + for _ in range(3): + data = await stream.read(4096) + connection.receive_data(data) + + events = list(connection.events()) + assert events == [ + wsproto.events.TextMessage("SERVER_MESSAGE"), + wsproto.events.BytesMessage(bytearray(b"SERVER_MESSAGE")), + wsproto.events.CloseConnection(1000, ""), + ] + + async def test_read_unhandled_asgi_message(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await send({"type": "websocket.foo"}) + + async with ASGIWebSocketAsyncNetworkStream(app, scope) as (stream, _): + with pytest.raises(UnhandledASGIMessageType): + await stream.read(4096) + + async def test_close_immediately(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.close", "code": 1000, "reason": ""}) + + with pytest.raises(WebSocketDisconnect): + async with ASGIWebSocketAsyncNetworkStream(app, scope): + pass # pragma: no cover + + async def test_exception(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + raise Exception("Error") + + with pytest.raises(WebSocketDisconnect) as excinfo: + async with ASGIWebSocketAsyncNetworkStream(app, scope): + pass # pragma: no cover + assert excinfo.value.code == 1011 + assert excinfo.value.reason == "Error" + + +@pytest.fixture +def test_app() -> Starlette: + async def http_endpoint(request: Request) -> PlainTextResponse: + return PlainTextResponse("Hello, world!") + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.receive_text() + await websocket.close() # pragma: no cover + + routes = [ + Route("/http", endpoint=http_endpoint), + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + + return Starlette(routes=routes) + + +@pytest.mark.anyio +class TestASGIWebSocketTransport: + async def test_http(self, test_app: Starlette) -> None: + async with ASGIWebSocketTransport(app=test_app) as transport: + request = httpx.Request("GET", "http://localhost:8000/http") + response = await transport.handle_async_request(request) + assert response.status_code == 200 + + @pytest.mark.parametrize( + "url,headers", + [ + ("ws://localhost:8000/ws", {}), + ("wss://localhost:8000/ws", {}), + ("http://localhost:8000/ws", {"upgrade": "websocket"}), + ], + ) + async def test_websocket( + self, + url: str, + headers: dict[str, Any], + test_app: Starlette, + websocket_request_headers: dict[str, str], + ) -> None: + async with ASGIWebSocketTransport(app=test_app) as transport: + request = httpx.Request("GET", url, headers={**websocket_request_headers, **headers}) + response = await transport.handle_async_request(request) + assert response.status_code == 101 + + assert isinstance(response.extensions["network_stream"], ASGIWebSocketAsyncNetworkStream) + + +@pytest.mark.anyio +async def test_subprotocol_support() -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept("custom_protocol") + assert websocket.scope.get("subprotocols") == ["custom_protocol"] + await websocket.send_text("SERVER_MESSAGE") + await websocket.close() + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + async with aconnect_ws("ws://localhost:8000/ws", client, subprotocols=["custom_protocol"]) as ws: + await ws.receive_text() + assert ws.subprotocol == "custom_protocol" + + +@pytest.mark.anyio +async def test_keepalive_ping_disabled() -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.receive_text() + await websocket.close() # pragma: no cover + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + async with aconnect_ws("ws://localhost:8000/ws", client) as ws: + assert ws._keepalive_ping_interval_seconds is None diff --git a/uv.lock b/uv.lock index ef708e99..de30c2c4 100644 --- a/uv.lock +++ b/uv.lock @@ -33,19 +33,22 @@ dev = [ { name = "chardet", specifier = "==6.0.0.post1" }, { name = "coverage", extras = ["toml"], specifier = "==7.10.6" }, { name = "cryptography", specifier = "==46.0.7" }, + { name = "flaky", specifier = ">=3.8" }, { name = "httpcore2", extras = ["asyncio", "trio", "http2", "socks"], editable = "src/httpcore2" }, - { name = "httpx2", extras = ["brotli", "cli", "http2", "socks", "zstd"], editable = "src/httpx2" }, + { name = "httpx2", extras = ["brotli", "cli", "http2", "socks", "ws", "zstd"], editable = "src/httpx2" }, { name = "mypy", specifier = "==1.17.1" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-codspeed", specifier = ">=4.1.1" }, { name = "pytest-httpbin", specifier = "==2.0.0" }, { name = "pytest-trio", specifier = "==0.8.0" }, { name = "ruff", specifier = "==0.15.13" }, + { name = "starlette", specifier = ">=0.49" }, { name = "trio", specifier = "==0.31.0" }, { name = "trio-typing", specifier = "==0.10.0" }, { name = "trustme", specifier = "==1.2.1" }, { name = "twine", specifier = "==6.1.0" }, { name = "uvicorn", specifier = ">=0.35" }, + { name = "websockets", specifier = ">=15" }, { name = "werkzeug", specifier = ">=3.1.6" }, ] docs = [ @@ -934,6 +937,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, ] +[[package]] +name = "flaky" +version = "3.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5b/c5/ef69119a01427204ff2db5fc8f98001087bcce719bbb94749dcd7b191365/flaky-3.8.1.tar.gz", hash = "sha256:47204a81ec905f3d5acfbd61daeabcada8f9d4031616d9bcb0618461729699f5", size = 25248, upload-time = "2024-03-12T22:17:59.265Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/b8/b830fc43663246c3f3dd1ae7dca4847b96ed992537e85311e27fa41ac40e/flaky-3.8.1-py2.py3-none-any.whl", hash = "sha256:194ccf4f0d3a22b2de7130f4b62e45e977ac1b5ccad74d4d48f3005dcc38815e", size = 19139, upload-time = "2024-03-12T22:17:51.59Z" }, +] + [[package]] name = "flasgger" version = "0.9.7.1" @@ -1372,6 +1384,9 @@ http2 = [ socks = [ { name = "socksio" }, ] +ws = [ + { name = "wsproto" }, +] zstd = [ { name = "zstandard", marker = "python_full_version < '3.14'" }, ] @@ -1390,9 +1405,10 @@ requires-dist = [ { name = "socksio", marker = "extra == 'socks'", specifier = "==1.*" }, { name = "truststore", specifier = ">=0.10" }, { name = "typing-extensions", marker = "python_full_version < '3.13'", specifier = ">=4.5.0" }, + { name = "wsproto", marker = "extra == 'ws'", specifier = ">=1.2" }, { name = "zstandard", marker = "python_full_version < '3.14' and extra == 'zstd'", specifier = ">=0.18.0" }, ] -provides-extras = ["brotli", "cli", "http2", "socks", "zstd"] +provides-extras = ["brotli", "cli", "http2", "socks", "ws", "zstd"] [[package]] name = "hyperframe" @@ -3175,6 +3191,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, ] +[[package]] +name = "starlette" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/e3/7c1dc7381d9f8ab7d854328ebfa884e62cb3f3d8549ddfd37c7814f42afa/starlette-1.3.1.tar.gz", hash = "sha256:05d0213193f2fbaae60e2ecb593b4add4262ad4e46536b54abe36f11a71724e0", size = 2703240, upload-time = "2026-06-12T09:23:11.602Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/bb/2799cc2ede3ed41131f8975621e7213dfc7ef4acbbaadfa440f32500c370/starlette-1.3.1-py3-none-any.whl", hash = "sha256:c7372aae11c3c3f26a42df7bd626cec2f47d03483d261d369516a615a53714c6", size = 73632, upload-time = "2026-06-12T09:23:10.017Z" }, +] + [[package]] name = "tomli" version = "2.4.1" @@ -3370,6 +3399,74 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "websockets" +version = "16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/74/221f58decd852f4b59cc3354cccaf87e8ef695fede361d03dc9a7396573b/websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a", size = 177343, upload-time = "2026-01-10T09:22:21.28Z" }, + { url = "https://files.pythonhosted.org/packages/19/0f/22ef6107ee52ab7f0b710d55d36f5a5d3ef19e8a205541a6d7ffa7994e5a/websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0", size = 175021, upload-time = "2026-01-10T09:22:22.696Z" }, + { url = "https://files.pythonhosted.org/packages/10/40/904a4cb30d9b61c0e278899bf36342e9b0208eb3c470324a9ecbaac2a30f/websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957", size = 175320, upload-time = "2026-01-10T09:22:23.94Z" }, + { url = "https://files.pythonhosted.org/packages/9d/2f/4b3ca7e106bc608744b1cdae041e005e446124bebb037b18799c2d356864/websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72", size = 183815, upload-time = "2026-01-10T09:22:25.469Z" }, + { url = "https://files.pythonhosted.org/packages/86/26/d40eaa2a46d4302becec8d15b0fc5e45bdde05191e7628405a19cf491ccd/websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde", size = 185054, upload-time = "2026-01-10T09:22:27.101Z" }, + { url = "https://files.pythonhosted.org/packages/b0/ba/6500a0efc94f7373ee8fefa8c271acdfd4dca8bd49a90d4be7ccabfc397e/websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3", size = 184565, upload-time = "2026-01-10T09:22:28.293Z" }, + { url = "https://files.pythonhosted.org/packages/04/b4/96bf2cee7c8d8102389374a2616200574f5f01128d1082f44102140344cc/websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3", size = 183848, upload-time = "2026-01-10T09:22:30.394Z" }, + { url = "https://files.pythonhosted.org/packages/02/8e/81f40fb00fd125357814e8c3025738fc4ffc3da4b6b4a4472a82ba304b41/websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9", size = 178249, upload-time = "2026-01-10T09:22:32.083Z" }, + { url = "https://files.pythonhosted.org/packages/b4/5f/7e40efe8df57db9b91c88a43690ac66f7b7aa73a11aa6a66b927e44f26fa/websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35", size = 178685, upload-time = "2026-01-10T09:22:33.345Z" }, + { url = "https://files.pythonhosted.org/packages/f2/db/de907251b4ff46ae804ad0409809504153b3f30984daf82a1d84a9875830/websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8", size = 177340, upload-time = "2026-01-10T09:22:34.539Z" }, + { url = "https://files.pythonhosted.org/packages/f3/fa/abe89019d8d8815c8781e90d697dec52523fb8ebe308bf11664e8de1877e/websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad", size = 175022, upload-time = "2026-01-10T09:22:36.332Z" }, + { url = "https://files.pythonhosted.org/packages/58/5d/88ea17ed1ded2079358b40d31d48abe90a73c9e5819dbcde1606e991e2ad/websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d", size = 175319, upload-time = "2026-01-10T09:22:37.602Z" }, + { url = "https://files.pythonhosted.org/packages/d2/ae/0ee92b33087a33632f37a635e11e1d99d429d3d323329675a6022312aac2/websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe", size = 184631, upload-time = "2026-01-10T09:22:38.789Z" }, + { url = "https://files.pythonhosted.org/packages/c8/c5/27178df583b6c5b31b29f526ba2da5e2f864ecc79c99dae630a85d68c304/websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b", size = 185870, upload-time = "2026-01-10T09:22:39.893Z" }, + { url = "https://files.pythonhosted.org/packages/87/05/536652aa84ddc1c018dbb7e2c4cbcd0db884580bf8e95aece7593fde526f/websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5", size = 185361, upload-time = "2026-01-10T09:22:41.016Z" }, + { url = "https://files.pythonhosted.org/packages/6d/e2/d5332c90da12b1e01f06fb1b85c50cfc489783076547415bf9f0a659ec19/websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64", size = 184615, upload-time = "2026-01-10T09:22:42.442Z" }, + { url = "https://files.pythonhosted.org/packages/77/fb/d3f9576691cae9253b51555f841bc6600bf0a983a461c79500ace5a5b364/websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6", size = 178246, upload-time = "2026-01-10T09:22:43.654Z" }, + { url = "https://files.pythonhosted.org/packages/54/67/eaff76b3dbaf18dcddabc3b8c1dba50b483761cccff67793897945b37408/websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac", size = 178684, upload-time = "2026-01-10T09:22:44.941Z" }, + { url = "https://files.pythonhosted.org/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00", size = 177365, upload-time = "2026-01-10T09:22:46.787Z" }, + { url = "https://files.pythonhosted.org/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79", size = 175038, upload-time = "2026-01-10T09:22:47.999Z" }, + { url = "https://files.pythonhosted.org/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39", size = 175328, upload-time = "2026-01-10T09:22:49.809Z" }, + { url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" }, + { url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" }, + { url = "https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" }, + { url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" }, + { url = "https://files.pythonhosted.org/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89", size = 178261, upload-time = "2026-01-10T09:22:56.251Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea", size = 178693, upload-time = "2026-01-10T09:22:57.478Z" }, + { url = "https://files.pythonhosted.org/packages/cc/9c/baa8456050d1c1b08dd0ec7346026668cbc6f145ab4e314d707bb845bf0d/websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9", size = 177364, upload-time = "2026-01-10T09:22:59.333Z" }, + { url = "https://files.pythonhosted.org/packages/7e/0c/8811fc53e9bcff68fe7de2bcbe75116a8d959ac699a3200f4847a8925210/websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230", size = 175039, upload-time = "2026-01-10T09:23:01.171Z" }, + { url = "https://files.pythonhosted.org/packages/aa/82/39a5f910cb99ec0b59e482971238c845af9220d3ab9fa76dd9162cda9d62/websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c", size = 175323, upload-time = "2026-01-10T09:23:02.341Z" }, + { url = "https://files.pythonhosted.org/packages/bd/28/0a25ee5342eb5d5f297d992a77e56892ecb65e7854c7898fb7d35e9b33bd/websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5", size = 184975, upload-time = "2026-01-10T09:23:03.756Z" }, + { url = "https://files.pythonhosted.org/packages/f9/66/27ea52741752f5107c2e41fda05e8395a682a1e11c4e592a809a90c6a506/websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82", size = 186203, upload-time = "2026-01-10T09:23:05.01Z" }, + { url = "https://files.pythonhosted.org/packages/37/e5/8e32857371406a757816a2b471939d51c463509be73fa538216ea52b792a/websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8", size = 185653, upload-time = "2026-01-10T09:23:06.301Z" }, + { url = "https://files.pythonhosted.org/packages/9b/67/f926bac29882894669368dc73f4da900fcdf47955d0a0185d60103df5737/websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f", size = 184920, upload-time = "2026-01-10T09:23:07.492Z" }, + { url = "https://files.pythonhosted.org/packages/3c/a1/3d6ccdcd125b0a42a311bcd15a7f705d688f73b2a22d8cf1c0875d35d34a/websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a", size = 178255, upload-time = "2026-01-10T09:23:09.245Z" }, + { url = "https://files.pythonhosted.org/packages/6b/ae/90366304d7c2ce80f9b826096a9e9048b4bb760e44d3b873bb272cba696b/websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156", size = 178689, upload-time = "2026-01-10T09:23:10.483Z" }, + { url = "https://files.pythonhosted.org/packages/f3/1d/e88022630271f5bd349ed82417136281931e558d628dd52c4d8621b4a0b2/websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0", size = 177406, upload-time = "2026-01-10T09:23:12.178Z" }, + { url = "https://files.pythonhosted.org/packages/f2/78/e63be1bf0724eeb4616efb1ae1c9044f7c3953b7957799abb5915bffd38e/websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904", size = 175085, upload-time = "2026-01-10T09:23:13.511Z" }, + { url = "https://files.pythonhosted.org/packages/bb/f4/d3c9220d818ee955ae390cf319a7c7a467beceb24f05ee7aaaa2414345ba/websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4", size = 175328, upload-time = "2026-01-10T09:23:14.727Z" }, + { url = "https://files.pythonhosted.org/packages/63/bc/d3e208028de777087e6fb2b122051a6ff7bbcca0d6df9d9c2bf1dd869ae9/websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e", size = 185044, upload-time = "2026-01-10T09:23:15.939Z" }, + { url = "https://files.pythonhosted.org/packages/ad/6e/9a0927ac24bd33a0a9af834d89e0abc7cfd8e13bed17a86407a66773cc0e/websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4", size = 186279, upload-time = "2026-01-10T09:23:17.148Z" }, + { url = "https://files.pythonhosted.org/packages/b9/ca/bf1c68440d7a868180e11be653c85959502efd3a709323230314fda6e0b3/websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1", size = 185711, upload-time = "2026-01-10T09:23:18.372Z" }, + { url = "https://files.pythonhosted.org/packages/c4/f8/fdc34643a989561f217bb477cbc47a3a07212cbda91c0e4389c43c296ebf/websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3", size = 184982, upload-time = "2026-01-10T09:23:19.652Z" }, + { url = "https://files.pythonhosted.org/packages/dd/d1/574fa27e233764dbac9c52730d63fcf2823b16f0856b3329fc6268d6ae4f/websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8", size = 177915, upload-time = "2026-01-10T09:23:21.458Z" }, + { url = "https://files.pythonhosted.org/packages/8a/f1/ae6b937bf3126b5134ce1f482365fde31a357c784ac51852978768b5eff4/websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d", size = 178381, upload-time = "2026-01-10T09:23:22.715Z" }, + { url = "https://files.pythonhosted.org/packages/06/9b/f791d1db48403e1f0a27577a6beb37afae94254a8c6f08be4a23e4930bc0/websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244", size = 177737, upload-time = "2026-01-10T09:23:24.523Z" }, + { url = "https://files.pythonhosted.org/packages/bd/40/53ad02341fa33b3ce489023f635367a4ac98b73570102ad2cdd770dacc9a/websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e", size = 175268, upload-time = "2026-01-10T09:23:25.781Z" }, + { url = "https://files.pythonhosted.org/packages/74/9b/6158d4e459b984f949dcbbb0c5d270154c7618e11c01029b9bbd1bb4c4f9/websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641", size = 175486, upload-time = "2026-01-10T09:23:27.033Z" }, + { url = "https://files.pythonhosted.org/packages/e5/2d/7583b30208b639c8090206f95073646c2c9ffd66f44df967981a64f849ad/websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8", size = 185331, upload-time = "2026-01-10T09:23:28.259Z" }, + { url = "https://files.pythonhosted.org/packages/45/b0/cce3784eb519b7b5ad680d14b9673a31ab8dcb7aad8b64d81709d2430aa8/websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e", size = 186501, upload-time = "2026-01-10T09:23:29.449Z" }, + { url = "https://files.pythonhosted.org/packages/19/60/b8ebe4c7e89fb5f6cdf080623c9d92789a53636950f7abacfc33fe2b3135/websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944", size = 186062, upload-time = "2026-01-10T09:23:31.368Z" }, + { url = "https://files.pythonhosted.org/packages/88/a8/a080593f89b0138b6cba1b28f8df5673b5506f72879322288b031337c0b8/websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206", size = 185356, upload-time = "2026-01-10T09:23:32.627Z" }, + { url = "https://files.pythonhosted.org/packages/c2/b6/b9afed2afadddaf5ebb2afa801abf4b0868f42f8539bfe4b071b5266c9fe/websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6", size = 178085, upload-time = "2026-01-10T09:23:33.816Z" }, + { url = "https://files.pythonhosted.org/packages/9f/3e/28135a24e384493fa804216b79a6a6759a38cc4ff59118787b9fb693df93/websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd", size = 178531, upload-time = "2026-01-10T09:23:35.016Z" }, + { url = "https://files.pythonhosted.org/packages/72/07/c98a68571dcf256e74f1f816b8cc5eae6eb2d3d5cfa44d37f801619d9166/websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d", size = 174947, upload-time = "2026-01-10T09:23:36.166Z" }, + { url = "https://files.pythonhosted.org/packages/7e/52/93e166a81e0305b33fe416338be92ae863563fe7bce446b0f687b9df5aea/websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03", size = 175260, upload-time = "2026-01-10T09:23:37.409Z" }, + { url = "https://files.pythonhosted.org/packages/56/0c/2dbf513bafd24889d33de2ff0368190a0e69f37bcfa19009ef819fe4d507/websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da", size = 176071, upload-time = "2026-01-10T09:23:39.158Z" }, + { url = "https://files.pythonhosted.org/packages/a5/8f/aea9c71cc92bf9b6cc0f7f70df8f0b420636b6c96ef4feee1e16f80f75dd/websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c", size = 176968, upload-time = "2026-01-10T09:23:41.031Z" }, + { url = "https://files.pythonhosted.org/packages/9a/3f/f70e03f40ffc9a30d817eef7da1be72ee4956ba8d7255c399a01b135902a/websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767", size = 178735, upload-time = "2026-01-10T09:23:42.259Z" }, + { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, +] + [[package]] name = "werkzeug" version = "3.1.8" @@ -3382,6 +3479,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/93/8c/2e650f2afeb7ee576912636c23ddb621c91ac6a98e66dc8d29c3c69446e1/werkzeug-3.1.8-py3-none-any.whl", hash = "sha256:63a77fb8892bf28ebc3178683445222aa500e48ebad5ec77b0ad80f8726b1f50", size = 226459, upload-time = "2026-04-02T18:49:12.72Z" }, ] +[[package]] +name = "wsproto" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" }, +] + [[package]] name = "yarl" version = "1.23.0"