From b1363ca1978efdfcda483006b14ad81aa8f2337b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 22 Nov 2022 09:17:03 +0100 Subject: [PATCH 001/108] Boostrap project --- src/httpx2/httpx2/_websockets/LICENSE | 21 +++++++++++++++++++++ src/httpx2/httpx2/_websockets/__init__.py | 1 + src/httpx2/httpx2/_websockets/py.typed | 0 tests/httpx2/websockets/__init__.py | 0 tests/httpx2/websockets/conftest.py | 10 ++++++++++ 5 files changed, 32 insertions(+) create mode 100644 src/httpx2/httpx2/_websockets/LICENSE create mode 100644 src/httpx2/httpx2/_websockets/__init__.py create mode 100644 src/httpx2/httpx2/_websockets/py.typed create mode 100644 tests/httpx2/websockets/__init__.py create mode 100644 tests/httpx2/websockets/conftest.py diff --git a/src/httpx2/httpx2/_websockets/LICENSE b/src/httpx2/httpx2/_websockets/LICENSE new file mode 100644 index 00000000..3ea977c7 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 François Voron + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py new file mode 100644 index 00000000..6c8e6b97 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -0,0 +1 @@ +__version__ = "0.0.0" diff --git a/src/httpx2/httpx2/_websockets/py.typed b/src/httpx2/httpx2/_websockets/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/tests/httpx2/websockets/__init__.py b/tests/httpx2/websockets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py new file mode 100644 index 00000000..ef30ed3a --- /dev/null +++ b/tests/httpx2/websockets/conftest.py @@ -0,0 +1,10 @@ +import asyncio + +import pytest + + +@pytest.fixture(scope="session") +def event_loop(): + """Force the pytest-asyncio loop to be the main one.""" + loop = asyncio.get_event_loop() + yield loop From 0e91d2befa73d833e89afd370433dbf07404abd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 22 Nov 2022 15:54:56 +0100 Subject: [PATCH 002/108] Start implementing API and ASGI transport --- src/httpx2/httpx2/_websockets/__init__.py | 4 + src/httpx2/httpx2/_websockets/_api.py | 69 +++++++++ src/httpx2/httpx2/_websockets/transport.py | 157 +++++++++++++++++++++ tests/httpx2/websockets/conftest.py | 21 +++ tests/httpx2/websockets/test_api.py | 83 +++++++++++ 5 files changed, 334 insertions(+) create mode 100644 src/httpx2/httpx2/_websockets/_api.py create mode 100644 src/httpx2/httpx2/_websockets/transport.py create mode 100644 tests/httpx2/websockets/test_api.py diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 6c8e6b97..3a84b684 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1 +1,5 @@ __version__ = "0.0.0" + +from httpx_ws._api import WebSocketDisconnect, WebSocketSession, aconnect_ws + +__all__ = ["aconnect_ws", "WebSocketSession", "WebSocketDisconnect"] diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py new file mode 100644 index 00000000..6c5f66bf --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -0,0 +1,69 @@ +import base64 +import contextlib +import secrets +import typing + +import httpx +import wsproto + + +class HTTPXWSException(Exception): + pass + + +class WebSocketUpgradeError(HTTPXWSException): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + +class WebSocketDisconnect(HTTPXWSException): + def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: + self.code = code + self.reason = reason or "" + + +class WebSocketSession: + def __init__(self, response: httpx.Response) -> None: + self.stream = response.extensions["network_stream"] + self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) + + async def send(self, message: str): + event = wsproto.events.TextMessage(message) + await self._send_event(event) + + async def receive(self): + data = await self.stream.read(max_bytes=4096) + self.connection.receive_data(data) + for event in self.connection.events(): + if isinstance(event, wsproto.events.CloseConnection): + raise WebSocketDisconnect(event.code, event.reason) + return event.data + + async def close(self, code: int = 1000, reason: typing.Optional[str] = None): + event = wsproto.events.CloseConnection(code, reason) + await self._send_event(event) + + async def _send_event(self, event: wsproto.events.Event): + data = self.connection.send(event) + await self.stream.write(data) + + +@contextlib.asynccontextmanager +async def aconnect_ws( + client: httpx.AsyncClient, url: str, **kwargs: typing.Any +) -> typing.AsyncGenerator[WebSocketSession, None]: + headers = kwargs.pop("headers", {}) + headers["connection"] = "upgrade" + headers["upgrade"] = "websocket" + headers["sec-websocket-key"] = base64.b64encode(secrets.token_bytes(16)).decode( + "utf-8" + ) + headers["sec-websocket-version"] = "13" + + async with client.stream("GET", url, headers=headers) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + session = WebSocketSession(response) + yield session + await session.close() diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py new file mode 100644 index 00000000..effd67ab --- /dev/null +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -0,0 +1,157 @@ +import contextlib +import queue +import typing +from concurrent.futures import Future + +import anyio +import wsproto +from httpcore.backends.base import AsyncNetworkStream +from httpx import ASGITransport, AsyncByteStream, Request, Response + +Scope = typing.MutableMapping[str, typing.Any] +Message = typing.MutableMapping[str, typing.Any] +Receive = typing.Callable[[], typing.Awaitable[Message]] +Send = typing.Callable[[Message], typing.Awaitable[None]] +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] + + +class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): + def __init__(self, app: ASGIApp, scope: Scope) -> None: + self.app = app + self.scope = scope + self._receive_queue: queue.Queue[Message] = queue.Queue() + self._send_queue: queue.Queue[Message] = queue.Queue() + self.connection = wsproto.connection.Connection(wsproto.connection.SERVER) + + async def __aenter__(self) -> "ASGIWebSocketAsyncNetworkStream": + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context( + anyio.start_blocking_portal("asyncio") + ) + try: + _: "Future[None]" = self.portal.start_task_soon(self._run) + await self.send({"type": "websocket.connect"}) + message = await self.receive() + assert message["type"] == "websocket.accept" + except Exception: + raise + return self + + async def __aexit__(self, *args: typing.Any) -> None: + await self.aclose() + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + message: Message = await self.receive(timeout=timeout) + type = message["type"] + + if type not in {"websocket.send", "websocket.close"}: + raise ValueError("Unknown message", message) + + event: wsproto.events.Event + if type == "websocket.send": + data_str: typing.Optional[str] = message.get("text") + if data_str is not None: + event = wsproto.events.TextMessage(data_str) + data_bytes: typing.Optional[bytes] = message.get("bytes") + if data_bytes is not None: + event = wsproto.events.BytesMessage(data_bytes) + elif type == "websocket.close": + event = wsproto.events.CloseConnection(message["code"], message["reason"]) + + return self.connection.send(event) + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.CloseConnection): + await self.send( + { + "type": "websocket.close", + "code": event.code, + "reason": event.reason, + } + ) + elif isinstance(event, wsproto.events.TextMessage): + await self.send({"type": "websocket.receive", "text": event.data}) + elif isinstance(event, wsproto.events.BytesMessage): + await self.send({"type": "websocket.receive", "bytes": event.data}) + else: + raise ValueError("Unhandled event", event) + + async def aclose(self) -> None: + self.exit_stack.close() + + def get_extra_info(self, info: str) -> typing.Any: + return None # pragma: nocover + + async def send(self, message: Message) -> None: + self._receive_queue.put(message) + + async def receive(self, timeout: typing.Optional[float] = None) -> Message: + return self._send_queue.get(timeout=timeout) + + async def _run(self) -> None: + """ + The sub-thread in which the websocket session runs. + """ + scope = self.scope + receive = self._asgi_receive + send = self._asgi_send + await self.app(scope, receive, send) + + async def _asgi_receive(self) -> Message: + while self._receive_queue.empty(): + await anyio.sleep(0) + return self._receive_queue.get() + + async def _asgi_send(self, message: Message) -> None: + self._send_queue.put(message) + + +class ASGIWebSocketTransport(ASGITransport): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.exit_stack: typing.Optional[contextlib.AsyncExitStack] = None + + async def handle_async_request(self, request: Request) -> Response: + scheme = request.url.scheme + headers = request.headers + + if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": + scope = { + "type": "websocket", + "path": request.url.path, + "raw_path": request.url.raw_path, + "root_path": self.root_path, + "scheme": scheme, + "query_string": request.url.query, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "client": self.client, + "server": (request.url.host, request.url.port), + } + return await self._handle_ws_request(request, scope) + + return await super().handle_async_request(request) + + async def _handle_ws_request( + self, + request: Request, + scope: Scope, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + self.scope = scope + self.exit_stack = contextlib.AsyncExitStack() + stream = await self.exit_stack.enter_async_context( + ASGIWebSocketAsyncNetworkStream(self.app, self.scope) + ) + + return Response(101, extensions={"network_stream": stream}) + + async def aclose(self) -> None: + if self.exit_stack: + await self.exit_stack.aclose() diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index ef30ed3a..6e440ec4 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -1,6 +1,11 @@ import asyncio +from typing import Callable +from unittest.mock import MagicMock import pytest +from starlette.applications import Starlette +from starlette.routing import WebSocketRoute +from starlette.types import ASGIApp @pytest.fixture(scope="session") @@ -8,3 +13,19 @@ def event_loop(): """Force the pytest-asyncio loop to be the main one.""" loop = asyncio.get_event_loop() yield loop + + +@pytest.fixture +def on_receive_message(): + return MagicMock() + + +@pytest.fixture +def websocket_app_factory() -> Callable[[Callable], ASGIApp]: + def _websocket_app_factory(endpoint: Callable): + routes = [ + WebSocketRoute("/ws", endpoint=endpoint), + ] + return Starlette(routes=routes) + + return _websocket_app_factory diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py new file mode 100644 index 00000000..d472194d --- /dev/null +++ b/tests/httpx2/websockets/test_api.py @@ -0,0 +1,83 @@ +from typing import Callable +from unittest.mock import MagicMock + +import httpx +import pytest +from starlette.types import ASGIApp +from starlette.websockets import WebSocket + +from httpx_ws import WebSocketDisconnect, aconnect_ws +from httpx_ws.transport import ASGIWebSocketTransport + + +@pytest.mark.asyncio +async def test_send_message( + websocket_app_factory: Callable[[Callable], ASGIApp], on_receive_message: MagicMock +): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + message = await websocket.receive_text() + on_receive_message(message) + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws") as ws: + await ws.send("CLIENT_MESSAGE") + + on_receive_message.assert_called_once_with("CLIENT_MESSAGE") + + +@pytest.mark.asyncio +async def test_receive_message(websocket_app_factory: Callable[[Callable], ASGIApp]): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws") as ws: + message = await ws.receive() + assert message == "SERVER_MESSAGE" + + +@pytest.mark.asyncio +async def test_send_close(websocket_app_factory: Callable[[Callable], ASGIApp]): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.receive_text() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws"): + pass + + +@pytest.mark.asyncio +async def test_receive_close(websocket_app_factory: Callable[[Callable], ASGIApp]): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + with pytest.raises(WebSocketDisconnect): + async with aconnect_ws(client, "/ws") as ws: + await ws.receive() From ecf30b6d8104fd702d9bc7430541e2a784c3f409 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 22 Nov 2022 16:39:35 +0100 Subject: [PATCH 003/108] Improve transport exceptions --- src/httpx2/httpx2/_websockets/transport.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index effd67ab..69a84647 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -15,6 +15,20 @@ ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] +class ASGIWebSocketTransportError(Exception): + pass + + +class UnhandledASGIMessageType(ASGIWebSocketTransportError): + def __init__(self, message: Message) -> None: + self.message = message + + +class UnhandledWebSocketEvent(ASGIWebSocketTransportError): + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): def __init__(self, app: ASGIApp, scope: Scope) -> None: self.app = app @@ -47,7 +61,7 @@ async def read( type = message["type"] if type not in {"websocket.send", "websocket.close"}: - raise ValueError("Unknown message", message) + raise UnhandledASGIMessageType(message) event: wsproto.events.Event if type == "websocket.send": @@ -80,14 +94,11 @@ async def write( elif isinstance(event, wsproto.events.BytesMessage): await self.send({"type": "websocket.receive", "bytes": event.data}) else: - raise ValueError("Unhandled event", event) + raise UnhandledWebSocketEvent(event) async def aclose(self) -> None: self.exit_stack.close() - def get_extra_info(self, info: str) -> typing.Any: - return None # pragma: nocover - async def send(self, message: Message) -> None: self._receive_queue.put(message) From 3499bbb911ad178d4478e7f5950df4aa2de1b208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 22 Nov 2022 16:44:19 +0100 Subject: [PATCH 004/108] Add test for upgrade error --- src/httpx2/httpx2/_websockets/__init__.py | 16 ++++++++++++++-- tests/httpx2/websockets/test_api.py | 15 ++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 3a84b684..bc2c17fe 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,5 +1,17 @@ __version__ = "0.0.0" -from httpx_ws._api import WebSocketDisconnect, WebSocketSession, aconnect_ws +from httpx_ws._api import ( + HTTPXWSException, + WebSocketDisconnect, + WebSocketSession, + WebSocketUpgradeError, + aconnect_ws, +) -__all__ = ["aconnect_ws", "WebSocketSession", "WebSocketDisconnect"] +__all__ = [ + "HTTPXWSException", + "WebSocketDisconnect", + "WebSocketSession", + "WebSocketUpgradeError", + "aconnect_ws", +] diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index d472194d..4afdc7d7 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -6,10 +6,23 @@ from starlette.types import ASGIApp from starlette.websockets import WebSocket -from httpx_ws import WebSocketDisconnect, aconnect_ws +from httpx_ws import WebSocketDisconnect, WebSocketUpgradeError, aconnect_ws from httpx_ws.transport import ASGIWebSocketTransport +@pytest.mark.asyncio +async def test_upgrade_error(): + def handler(request): + return httpx.Response(400) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=httpx.MockTransport(handler) + ) as client: + with pytest.raises(WebSocketUpgradeError): + async with aconnect_ws(client, "/ws"): + pass + + @pytest.mark.asyncio async def test_send_message( websocket_app_factory: Callable[[Callable], ASGIApp], on_receive_message: MagicMock From 2b673d42a0de055537dfa1a936dd131bb020bea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 22 Nov 2022 16:50:03 +0100 Subject: [PATCH 005/108] Change API so base methods work with pure wsproto Event --- src/httpx2/httpx2/_websockets/_api.py | 7 +++---- tests/httpx2/websockets/test_api.py | 8 +++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 6c5f66bf..2043291e 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -27,17 +27,16 @@ def __init__(self, response: httpx.Response) -> None: self.stream = response.extensions["network_stream"] self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) - async def send(self, message: str): - event = wsproto.events.TextMessage(message) + async def send(self, event: wsproto.events.Event): await self._send_event(event) - async def receive(self): + async def receive(self) -> wsproto.events.Event: data = await self.stream.read(max_bytes=4096) self.connection.receive_data(data) for event in self.connection.events(): if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) - return event.data + return event async def close(self, code: int = 1000, reason: typing.Optional[str] = None): event = wsproto.events.CloseConnection(code, reason) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 4afdc7d7..b1c61ba9 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -3,6 +3,7 @@ import httpx import pytest +import wsproto from starlette.types import ASGIApp from starlette.websockets import WebSocket @@ -41,7 +42,7 @@ async def websocket_endpoint(websocket: WebSocket): base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) ) as client: async with aconnect_ws(client, "/ws") as ws: - await ws.send("CLIENT_MESSAGE") + await ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) on_receive_message.assert_called_once_with("CLIENT_MESSAGE") @@ -61,8 +62,9 @@ async def websocket_endpoint(websocket: WebSocket): base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) ) as client: async with aconnect_ws(client, "/ws") as ws: - message = await ws.receive() - assert message == "SERVER_MESSAGE" + event = await ws.receive() + assert isinstance(event, wsproto.events.TextMessage) + assert event.data == "SERVER_MESSAGE" @pytest.mark.asyncio From 4a815b095239c0a1b60e96268ad53867bace854e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 22 Nov 2022 17:09:36 +0100 Subject: [PATCH 006/108] Add helper send methods --- src/httpx2/httpx2/_websockets/_api.py | 20 ++++- tests/httpx2/websockets/test_api.py | 109 +++++++++++++++++++++++--- 2 files changed, 118 insertions(+), 11 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 2043291e..3d3512a1 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -1,5 +1,6 @@ import base64 import contextlib +import json import secrets import typing @@ -27,9 +28,25 @@ def __init__(self, response: httpx.Response) -> None: self.stream = response.extensions["network_stream"] self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) - async def send(self, event: wsproto.events.Event): + async def send(self, event: wsproto.events.Event) -> None: await self._send_event(event) + async def send_text(self, data: str) -> None: + event = wsproto.events.TextMessage(data=data) + await self.send(event) + + async def send_bytes(self, data: bytes) -> None: + event = wsproto.events.BytesMessage(data=data) + await self.send(event) + + async def send_json(self, data: typing.Any, mode: str = "text") -> None: + assert mode in ["text", "binary"] + serialized_data = json.dumps(data) + if mode == "text": + await self.send_text(serialized_data) + else: + await self.send_bytes(serialized_data.encode("utf-8")) + async def receive(self) -> wsproto.events.Event: data = await self.stream.read(max_bytes=4096) self.connection.receive_data(data) @@ -37,6 +54,7 @@ async def receive(self) -> wsproto.events.Event: if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) return event + raise HTTPXWSException() # pragma: no cover async def close(self, code: int = 1000, reason: typing.Optional[str] = None): event = wsproto.events.CloseConnection(code, reason) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index b1c61ba9..d1f79136 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -24,10 +24,10 @@ def handler(request): pass -@pytest.mark.asyncio -async def test_send_message( +@pytest.fixture +def send_app( websocket_app_factory: Callable[[Callable], ASGIApp], on_receive_message: MagicMock -): +) -> ASGIApp: async def websocket_endpoint(websocket: WebSocket): await websocket.accept() @@ -36,15 +36,104 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + return websocket_app_factory(websocket_endpoint) - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws") as ws: - await ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) - on_receive_message.assert_called_once_with("CLIENT_MESSAGE") +@pytest.mark.asyncio +class TestSend: + async def test_send( + self, + websocket_app_factory: Callable[[Callable], ASGIApp], + on_receive_message: MagicMock, + ): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + message = await websocket.receive_text() + on_receive_message(message) + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws") as ws: + await ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) + + on_receive_message.assert_called_once_with("CLIENT_MESSAGE") + + async def test_send_text( + self, + websocket_app_factory: Callable[[Callable], ASGIApp], + on_receive_message: MagicMock, + ): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + message = await websocket.receive_text() + on_receive_message(message) + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws") as ws: + await ws.send_text("CLIENT_MESSAGE") + + on_receive_message.assert_called_once_with("CLIENT_MESSAGE") + + async def test_send_bytes( + self, + websocket_app_factory: Callable[[Callable], ASGIApp], + on_receive_message: MagicMock, + ): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + message = await websocket.receive_bytes() + on_receive_message(message) + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws") as ws: + await ws.send_bytes(b"CLIENT_MESSAGE") + + on_receive_message.assert_called_once_with(b"CLIENT_MESSAGE") + + @pytest.mark.parametrize("mode", ["text", "binary"]) + async def test_send_json( + self, + mode: str, + websocket_app_factory: Callable[[Callable], ASGIApp], + on_receive_message: MagicMock, + ): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + message = await websocket.receive_json(mode=mode) + on_receive_message(message) + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws") as ws: + await ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + + on_receive_message.assert_called_once_with({"message": "CLIENT_MESSAGE"}) @pytest.mark.asyncio From a5d89720639d8c141ac62aa895c0b3c36ac257c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 22 Nov 2022 17:23:52 +0100 Subject: [PATCH 007/108] Add receive methods helpers --- src/httpx2/httpx2/_websockets/__init__.py | 4 + src/httpx2/httpx2/_websockets/_api.py | 36 +++++- tests/httpx2/websockets/test_api.py | 133 +++++++++++++++++++--- 3 files changed, 157 insertions(+), 16 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index bc2c17fe..54018dec 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -2,7 +2,9 @@ from httpx_ws._api import ( HTTPXWSException, + JSONMode, WebSocketDisconnect, + WebSocketInvalidTypeReceived, WebSocketSession, WebSocketUpgradeError, aconnect_ws, @@ -10,7 +12,9 @@ __all__ = [ "HTTPXWSException", + "JSONMode", "WebSocketDisconnect", + "WebSocketInvalidTypeReceived", "WebSocketSession", "WebSocketUpgradeError", "aconnect_ws", diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 3d3512a1..5de42373 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -2,11 +2,19 @@ import contextlib import json import secrets +import sys import typing +if sys.version_info < (3, 8): + from typing_extensions import Literal # pragma: no cover +else: + from typing import Literal # pragma: no cover + import httpx import wsproto +JSONMode = Literal["text", "binary"] + class HTTPXWSException(Exception): pass @@ -23,6 +31,11 @@ def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> Non self.reason = reason or "" +class WebSocketInvalidTypeReceived(HTTPXWSException): + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + class WebSocketSession: def __init__(self, response: httpx.Response) -> None: self.stream = response.extensions["network_stream"] @@ -39,7 +52,7 @@ async def send_bytes(self, data: bytes) -> None: event = wsproto.events.BytesMessage(data=data) await self.send(event) - async def send_json(self, data: typing.Any, mode: str = "text") -> None: + async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: assert mode in ["text", "binary"] serialized_data = json.dumps(data) if mode == "text": @@ -56,6 +69,27 @@ async def receive(self) -> wsproto.events.Event: return event raise HTTPXWSException() # pragma: no cover + async def receive_text(self) -> str: + event = await self.receive() + if isinstance(event, wsproto.events.TextMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + async def receive_bytes(self) -> bytes: + event = await self.receive() + if isinstance(event, wsproto.events.BytesMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + async def receive_json(self, mode: JSONMode = "text") -> typing.Any: + assert mode in ["text", "binary"] + data: typing.Union[str, bytes] + if mode == "text": + data = await self.receive_text() + elif mode == "binary": + data = await self.receive_bytes() + return json.loads(data) + async def close(self, code: int = 1000, reason: typing.Optional[str] = None): event = wsproto.events.CloseConnection(code, reason) await self._send_event(event) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index d1f79136..0ef5baf4 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -7,7 +7,13 @@ from starlette.types import ASGIApp from starlette.websockets import WebSocket -from httpx_ws import WebSocketDisconnect, WebSocketUpgradeError, aconnect_ws +from httpx_ws import ( + JSONMode, + WebSocketDisconnect, + WebSocketInvalidTypeReceived, + WebSocketUpgradeError, + aconnect_ws, +) from httpx_ws.transport import ASGIWebSocketTransport @@ -113,7 +119,7 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.parametrize("mode", ["text", "binary"]) async def test_send_json( self, - mode: str, + mode: JSONMode, websocket_app_factory: Callable[[Callable], ASGIApp], on_receive_message: MagicMock, ): @@ -137,23 +143,120 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.asyncio -async def test_receive_message(websocket_app_factory: Callable[[Callable], ASGIApp]): - async def websocket_endpoint(websocket: WebSocket): - await websocket.accept() +class TestReceive: + async def test_receive(self, websocket_app_factory: Callable[[Callable], ASGIApp]): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() - await websocket.send_text("SERVER_MESSAGE") + await websocket.send_text("SERVER_MESSAGE") - await websocket.close() + await websocket.close() - app = websocket_app_factory(websocket_endpoint) + app = websocket_app_factory(websocket_endpoint) - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws") as ws: - event = await ws.receive() - assert isinstance(event, wsproto.events.TextMessage) - assert event.data == "SERVER_MESSAGE" + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws") as ws: + event = await ws.receive() + assert isinstance(event, wsproto.events.TextMessage) + assert event.data == "SERVER_MESSAGE" + + async def test_receive_text( + self, websocket_app_factory: Callable[[Callable], ASGIApp] + ): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws") as ws: + data = await ws.receive_text() + assert data == "SERVER_MESSAGE" + + async def test_receive_text_invalid_type( + self, websocket_app_factory: Callable[[Callable], ASGIApp] + ): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + await websocket.send_bytes(b"SERVER_MESSAGE") + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + with pytest.raises(WebSocketInvalidTypeReceived): + async with aconnect_ws(client, "/ws") as ws: + await ws.receive_text() + + async def test_receive_bytes( + self, websocket_app_factory: Callable[[Callable], ASGIApp] + ): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + await websocket.send_bytes(b"SERVER_MESSAGE") + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws") as ws: + data = await ws.receive_bytes() + assert data == b"SERVER_MESSAGE" + + async def test_receive_bytes_invalid_type( + self, websocket_app_factory: Callable[[Callable], ASGIApp] + ): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + with pytest.raises(WebSocketInvalidTypeReceived): + async with aconnect_ws(client, "/ws") as ws: + await ws.receive_bytes() + + @pytest.mark.parametrize("mode", ["text", "binary"]) + async def test_receive_json( + self, mode: JSONMode, websocket_app_factory: Callable[[Callable], ASGIApp] + ): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + await websocket.send_json({"message": "SERVER_MESSAGE"}, mode=mode) + + await websocket.close() + + app = websocket_app_factory(websocket_endpoint) + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) + ) as client: + async with aconnect_ws(client, "/ws") as ws: + data = await ws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} @pytest.mark.asyncio From aa636c6760c76305b68a30c04aa2cdcbe61fdeef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 22 Nov 2022 17:45:24 +0100 Subject: [PATCH 008/108] Add sync version of the API --- src/httpx2/httpx2/_websockets/__init__.py | 4 + src/httpx2/httpx2/_websockets/_api.py | 99 +++++++++++++++++++++-- tests/httpx2/websockets/test_api.py | 10 ++- 3 files changed, 101 insertions(+), 12 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 54018dec..1c54e396 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,6 +1,7 @@ __version__ = "0.0.0" from httpx_ws._api import ( + AsyncWebSocketSession, HTTPXWSException, JSONMode, WebSocketDisconnect, @@ -8,6 +9,7 @@ WebSocketSession, WebSocketUpgradeError, aconnect_ws, + connect_ws, ) __all__ = [ @@ -15,7 +17,9 @@ "JSONMode", "WebSocketDisconnect", "WebSocketInvalidTypeReceived", + "AsyncWebSocketSession", "WebSocketSession", "WebSocketUpgradeError", "aconnect_ws", + "connect_ws", ] diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 5de42373..8e2c5adc 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -41,6 +41,69 @@ def __init__(self, response: httpx.Response) -> None: self.stream = response.extensions["network_stream"] self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) + def send(self, event: wsproto.events.Event) -> None: + self._send_event(event) + + def send_text(self, data: str) -> None: + event = wsproto.events.TextMessage(data=data) + self.send(event) + + def send_bytes(self, data: bytes) -> None: + event = wsproto.events.BytesMessage(data=data) + self.send(event) + + def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + assert mode in ["text", "binary"] + serialized_data = json.dumps(data) + if mode == "text": + self.send_text(serialized_data) + else: + self.send_bytes(serialized_data.encode("utf-8")) + + def receive(self) -> wsproto.events.Event: + data = self.stream.read(max_bytes=4096) + self.connection.receive_data(data) + for event in self.connection.events(): + if isinstance(event, wsproto.events.CloseConnection): + raise WebSocketDisconnect(event.code, event.reason) + return event + raise HTTPXWSException() # pragma: no cover + + def receive_text(self) -> str: + event = self.receive() + if isinstance(event, wsproto.events.TextMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + def receive_bytes(self) -> bytes: + event = self.receive() + if isinstance(event, wsproto.events.BytesMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + def receive_json(self, mode: JSONMode = "text") -> typing.Any: + assert mode in ["text", "binary"] + data: typing.Union[str, bytes] + if mode == "text": + data = self.receive_text() + elif mode == "binary": + data = self.receive_bytes() + return json.loads(data) + + def close(self, code: int = 1000, reason: typing.Optional[str] = None): + event = wsproto.events.CloseConnection(code, reason) + self._send_event(event) + + def _send_event(self, event: wsproto.events.Event): + data = self.connection.send(event) + self.stream.write(data) + + +class AsyncWebSocketSession: + def __init__(self, response: httpx.Response) -> None: + self.stream = response.extensions["network_stream"] + self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) + async def send(self, event: wsproto.events.Event) -> None: await self._send_event(event) @@ -99,22 +162,42 @@ async def _send_event(self, event: wsproto.events.Event): await self.stream.write(data) +def _get_headers() -> typing.Dict[str, typing.Any]: + return { + "connection": "upgrade", + "upgrade": "websocket", + "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), + "sec-websocket-version": "13", + } + + +@contextlib.contextmanager +def connect_ws( + client: httpx.Client, url: str, **kwargs: typing.Any +) -> typing.Generator[WebSocketSession, None, None]: + headers = kwargs.pop("headers", {}) + headers.update(_get_headers()) + + with client.stream("GET", url, headers=headers) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + session = WebSocketSession(response) + yield session + session.close() + + @contextlib.asynccontextmanager async def aconnect_ws( client: httpx.AsyncClient, url: str, **kwargs: typing.Any -) -> typing.AsyncGenerator[WebSocketSession, None]: +) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: headers = kwargs.pop("headers", {}) - headers["connection"] = "upgrade" - headers["upgrade"] = "websocket" - headers["sec-websocket-key"] = base64.b64encode(secrets.token_bytes(16)).decode( - "utf-8" - ) - headers["sec-websocket-version"] = "13" + headers.update(_get_headers()) async with client.stream("GET", url, headers=headers) as response: if response.status_code != 101: raise WebSocketUpgradeError(response) - session = WebSocketSession(response) + session = AsyncWebSocketSession(response) yield session await session.close() diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 0ef5baf4..20fe5d2d 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -46,7 +46,7 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.asyncio -class TestSend: +class TestAsyncSend: async def test_send( self, websocket_app_factory: Callable[[Callable], ASGIApp], @@ -143,7 +143,7 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.asyncio -class TestReceive: +class TestAsyncReceive: async def test_receive(self, websocket_app_factory: Callable[[Callable], ASGIApp]): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() @@ -260,7 +260,7 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.asyncio -async def test_send_close(websocket_app_factory: Callable[[Callable], ASGIApp]): +async def test_async_send_close(websocket_app_factory: Callable[[Callable], ASGIApp]): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() await websocket.receive_text() @@ -275,7 +275,9 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.asyncio -async def test_receive_close(websocket_app_factory: Callable[[Callable], ASGIApp]): +async def test_async_receive_close( + websocket_app_factory: Callable[[Callable], ASGIApp] +): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() await websocket.close() From 35480c62938b8319d0036f5aa66659bde1373004 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 23 Nov 2022 10:03:52 +0100 Subject: [PATCH 009/108] Revamp unit tests so they work against a real Uvicorn server --- tests/httpx2/websockets/conftest.py | 43 +++-- tests/httpx2/websockets/test_api.py | 259 +++++++++++++++------------- 2 files changed, 169 insertions(+), 133 deletions(-) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index 6e440ec4..3211f266 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -1,11 +1,14 @@ import asyncio -from typing import Callable +import contextlib +import queue +from typing import Callable, ContextManager from unittest.mock import MagicMock import pytest +import uvicorn +from anyio.from_thread import start_blocking_portal from starlette.applications import Starlette from starlette.routing import WebSocketRoute -from starlette.types import ASGIApp @pytest.fixture(scope="session") @@ -20,12 +23,32 @@ def on_receive_message(): return MagicMock() +ServerFactoryFixture = Callable[[Callable], ContextManager[None]] + + @pytest.fixture -def websocket_app_factory() -> Callable[[Callable], ASGIApp]: - def _websocket_app_factory(endpoint: Callable): - routes = [ - WebSocketRoute("/ws", endpoint=endpoint), - ] - return Starlette(routes=routes) - - return _websocket_app_factory +def server_factory() -> ServerFactoryFixture: + @contextlib.contextmanager + def _server_factory(endpoint: Callable): + startup_queue: queue.Queue[bool] = queue.Queue() + + async def start_uvicorn(): + routes = [ + WebSocketRoute("/ws", endpoint=endpoint), + ] + + async def on_startup(): + startup_queue.put(True) + + app = Starlette(routes=routes, on_startup=[on_startup]) + config = uvicorn.Config(app, port=8000) + server = uvicorn.Server(config) + await server.serve() + + with start_blocking_portal(backend="asyncio") as portal: + future = portal.start_task_soon(start_uvicorn) + startup_queue.get(True) + yield + future.cancel() + + return _server_factory diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 20fe5d2d..1021dde6 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -1,10 +1,8 @@ -from typing import Callable -from unittest.mock import MagicMock +from unittest.mock import MagicMock, call import httpx import pytest import wsproto -from starlette.types import ASGIApp from starlette.websockets import WebSocket from httpx_ws import ( @@ -13,8 +11,9 @@ WebSocketInvalidTypeReceived, WebSocketUpgradeError, aconnect_ws, + connect_ws, ) -from httpx_ws.transport import ASGIWebSocketTransport +from tests.conftest import ServerFactoryFixture @pytest.mark.asyncio @@ -22,6 +21,13 @@ async def test_upgrade_error(): def handler(request): return httpx.Response(400) + with httpx.Client( + base_url="http://localhost:8000", transport=httpx.MockTransport(handler) + ) as client: + with pytest.raises(WebSocketUpgradeError): + with connect_ws(client, "/ws"): + pass + async with httpx.AsyncClient( base_url="http://localhost:8000", transport=httpx.MockTransport(handler) ) as client: @@ -30,26 +36,11 @@ def handler(request): pass -@pytest.fixture -def send_app( - websocket_app_factory: Callable[[Callable], ASGIApp], on_receive_message: MagicMock -) -> ASGIApp: - async def websocket_endpoint(websocket: WebSocket): - await websocket.accept() - - message = await websocket.receive_text() - on_receive_message(message) - - await websocket.close() - - return websocket_app_factory(websocket_endpoint) - - @pytest.mark.asyncio -class TestAsyncSend: +class TestSend: async def test_send( self, - websocket_app_factory: Callable[[Callable], ASGIApp], + server_factory: ServerFactoryFixture, on_receive_message: MagicMock, ): async def websocket_endpoint(websocket: WebSocket): @@ -60,19 +51,22 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with connect_ws(client, "/ws") as ws: + ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws") as ws: - await ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with aconnect_ws(aclient, "/ws") as aws: + await aws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) - on_receive_message.assert_called_once_with("CLIENT_MESSAGE") + on_receive_message.assert_has_calls( + [call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")] + ) async def test_send_text( self, - websocket_app_factory: Callable[[Callable], ASGIApp], + server_factory: ServerFactoryFixture, on_receive_message: MagicMock, ): async def websocket_endpoint(websocket: WebSocket): @@ -83,19 +77,22 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with connect_ws(client, "/ws") as ws: + ws.send_text("CLIENT_MESSAGE") - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws") as ws: - await ws.send_text("CLIENT_MESSAGE") + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with aconnect_ws(aclient, "/ws") as aws: + await aws.send_text("CLIENT_MESSAGE") - on_receive_message.assert_called_once_with("CLIENT_MESSAGE") + on_receive_message.assert_has_calls( + [call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")] + ) async def test_send_bytes( self, - websocket_app_factory: Callable[[Callable], ASGIApp], + server_factory: ServerFactoryFixture, on_receive_message: MagicMock, ): async def websocket_endpoint(websocket: WebSocket): @@ -106,21 +103,24 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with connect_ws(client, "/ws") as ws: + ws.send_bytes(b"CLIENT_MESSAGE") - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws") as ws: - await ws.send_bytes(b"CLIENT_MESSAGE") + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with aconnect_ws(aclient, "/ws") as aws: + await aws.send_bytes(b"CLIENT_MESSAGE") - on_receive_message.assert_called_once_with(b"CLIENT_MESSAGE") + on_receive_message.assert_has_calls( + [call(b"CLIENT_MESSAGE"), call(b"CLIENT_MESSAGE")] + ) @pytest.mark.parametrize("mode", ["text", "binary"]) async def test_send_json( self, mode: JSONMode, - websocket_app_factory: Callable[[Callable], ASGIApp], + server_factory: ServerFactoryFixture, on_receive_message: MagicMock, ): async def websocket_endpoint(websocket: WebSocket): @@ -131,20 +131,23 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with connect_ws(client, "/ws") as ws: + ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws") as ws: - await ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with aconnect_ws(aclient, "/ws") as aws: + await aws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) - on_receive_message.assert_called_once_with({"message": "CLIENT_MESSAGE"}) + on_receive_message.assert_has_calls( + [call({"message": "CLIENT_MESSAGE"}), call({"message": "CLIENT_MESSAGE"})] + ) @pytest.mark.asyncio -class TestAsyncReceive: - async def test_receive(self, websocket_app_factory: Callable[[Callable], ASGIApp]): +class TestReceive: + async def test_receive(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() @@ -152,19 +155,20 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with connect_ws(client, "/ws") as ws: + event = ws.receive() + assert isinstance(event, wsproto.events.TextMessage) + assert event.data == "SERVER_MESSAGE" - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws") as ws: - event = await ws.receive() - assert isinstance(event, wsproto.events.TextMessage) - assert event.data == "SERVER_MESSAGE" + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with aconnect_ws(aclient, "/ws") as aws: + event = await aws.receive() + assert isinstance(event, wsproto.events.TextMessage) + assert event.data == "SERVER_MESSAGE" - async def test_receive_text( - self, websocket_app_factory: Callable[[Callable], ASGIApp] - ): + async def test_receive_text(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() @@ -172,17 +176,19 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with connect_ws(client, "/ws") as ws: + data = ws.receive_text() + assert data == "SERVER_MESSAGE" - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws") as ws: - data = await ws.receive_text() - assert data == "SERVER_MESSAGE" + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with aconnect_ws(aclient, "/ws") as aws: + data = await aws.receive_text() + assert data == "SERVER_MESSAGE" async def test_receive_text_invalid_type( - self, websocket_app_factory: Callable[[Callable], ASGIApp] + self, server_factory: ServerFactoryFixture ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() @@ -191,18 +197,18 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with pytest.raises(WebSocketInvalidTypeReceived): + with connect_ws(client, "/ws") as ws: + ws.receive_text() - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - with pytest.raises(WebSocketInvalidTypeReceived): - async with aconnect_ws(client, "/ws") as ws: - await ws.receive_text() + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + with pytest.raises(WebSocketInvalidTypeReceived): + async with aconnect_ws(aclient, "/ws") as aws: + await aws.receive_text() - async def test_receive_bytes( - self, websocket_app_factory: Callable[[Callable], ASGIApp] - ): + async def test_receive_bytes(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() @@ -210,17 +216,19 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with connect_ws(client, "/ws") as ws: + data = ws.receive_bytes() + assert data == b"SERVER_MESSAGE" - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws") as ws: - data = await ws.receive_bytes() - assert data == b"SERVER_MESSAGE" + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with aconnect_ws(aclient, "/ws") as aws: + data = await aws.receive_bytes() + assert data == b"SERVER_MESSAGE" async def test_receive_bytes_invalid_type( - self, websocket_app_factory: Callable[[Callable], ASGIApp] + self, server_factory: ServerFactoryFixture ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() @@ -229,18 +237,20 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with pytest.raises(WebSocketInvalidTypeReceived): + with connect_ws(client, "/ws") as ws: + ws.receive_bytes() - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - with pytest.raises(WebSocketInvalidTypeReceived): - async with aconnect_ws(client, "/ws") as ws: - await ws.receive_bytes() + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + with pytest.raises(WebSocketInvalidTypeReceived): + async with aconnect_ws(aclient, "/ws") as aws: + await aws.receive_bytes() @pytest.mark.parametrize("mode", ["text", "binary"]) async def test_receive_json( - self, mode: JSONMode, websocket_app_factory: Callable[[Callable], ASGIApp] + self, mode: JSONMode, server_factory: ServerFactoryFixture ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() @@ -249,44 +259,47 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with connect_ws(client, "/ws") as ws: + data = ws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws") as ws: - data = await ws.receive_json(mode=mode) - assert data == {"message": "SERVER_MESSAGE"} + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with aconnect_ws(aclient, "/ws") as aws: + data = await aws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} @pytest.mark.asyncio -async def test_async_send_close(websocket_app_factory: Callable[[Callable], ASGIApp]): +async def test_send_close(server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() await websocket.receive_text() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with connect_ws(client, "/ws"): + pass - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - async with aconnect_ws(client, "/ws"): - pass + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with aconnect_ws(aclient, "/ws"): + pass @pytest.mark.asyncio -async def test_async_receive_close( - websocket_app_factory: Callable[[Callable], ASGIApp] -): +async def test_receive_close(server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() await websocket.close() - app = websocket_app_factory(websocket_endpoint) + with server_factory(websocket_endpoint): + with httpx.Client(base_url="http://localhost:8000") as client: + with pytest.raises(WebSocketDisconnect): + with connect_ws(client, "/ws") as ws: + ws.receive() - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=ASGIWebSocketTransport(app) - ) as client: - with pytest.raises(WebSocketDisconnect): - async with aconnect_ws(client, "/ws") as ws: - await ws.receive() + async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + with pytest.raises(WebSocketDisconnect): + async with aconnect_ws(aclient, "/ws") as aws: + await aws.receive() From f8a96d20bb762a50895993ac1d000aec797aaff9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 23 Nov 2022 11:27:30 +0100 Subject: [PATCH 010/108] Add tests for ASGIWebSocketTransport --- src/httpx2/httpx2/_websockets/transport.py | 12 +- tests/httpx2/websockets/test_transport.py | 136 +++++++++++++++++++++ 2 files changed, 141 insertions(+), 7 deletions(-) create mode 100644 tests/httpx2/websockets/test_transport.py diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 69a84647..69d582cc 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -42,13 +42,10 @@ async def __aenter__(self) -> "ASGIWebSocketAsyncNetworkStream": self.portal = self.exit_stack.enter_context( anyio.start_blocking_portal("asyncio") ) - try: - _: "Future[None]" = self.portal.start_task_soon(self._run) - await self.send({"type": "websocket.connect"}) - message = await self.receive() - assert message["type"] == "websocket.accept" - except Exception: - raise + _: "Future[None]" = self.portal.start_task_soon(self._run) + await self.send({"type": "websocket.connect"}) + message = await self.receive() + assert message["type"] == "websocket.accept" return self async def __aexit__(self, *args: typing.Any) -> None: @@ -97,6 +94,7 @@ async def write( raise UnhandledWebSocketEvent(event) async def aclose(self) -> None: + await self.send({"type": "websocket.close"}) self.exit_stack.close() async def send(self, message: Message) -> None: diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py new file mode 100644 index 00000000..70b3c68d --- /dev/null +++ b/tests/httpx2/websockets/test_transport.py @@ -0,0 +1,136 @@ +from typing import Any, Dict + +import httpx +import pytest +import wsproto +from starlette.applications import Starlette +from starlette.responses import PlainTextResponse +from starlette.routing import Route, WebSocketRoute +from starlette.websockets import WebSocket + +from httpx_ws.transport import ( + ASGIWebSocketAsyncNetworkStream, + ASGIWebSocketTransport, + UnhandledASGIMessageType, + UnhandledWebSocketEvent, +) + + +@pytest.mark.asyncio +class TestASGIWebSocketAsyncNetworkStream: + async def test_write(self): + received_messages = [] + + async def app(scope, receive, send): + await send({"type": "websocket.accept"}) + message = await receive() + received_messages.append(message) + while message["type"] != "websocket.close": + message = await receive() + received_messages.append(message) + + connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + async with ASGIWebSocketAsyncNetworkStream(app, {}) as stream: + text_event = wsproto.events.TextMessage("CLIENT_MESSAGE") + await stream.write(connection.send(text_event)) + + bytes_event = wsproto.events.BytesMessage(b"CLIENT_MESSAGE") + await stream.write(connection.send(bytes_event)) + + close_event = wsproto.events.CloseConnection(1000) + await stream.write(connection.send(close_event)) + + assert received_messages == [ + {"type": "websocket.connect"}, + {"type": "websocket.receive", "text": "CLIENT_MESSAGE"}, + {"type": "websocket.receive", "bytes": b"CLIENT_MESSAGE"}, + {"type": "websocket.close", "code": 1000, "reason": ""}, + ] + + async def test_write_unhandled_event(self): + async def app(scope, receive, send): + await send({"type": "websocket.accept"}) + await receive() + + connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + async with ASGIWebSocketAsyncNetworkStream(app, {}) as stream: + with pytest.raises(UnhandledWebSocketEvent): + ping_event = wsproto.events.Ping(b"PING") + await stream.write(connection.send(ping_event)) + + async def test_read(self): + async def app(scope, receive, send): + await send({"type": "websocket.accept"}) + await send({"type": "websocket.send", "text": "SERVER_MESSAGE"}) + await send({"type": "websocket.send", "bytes": b"SERVER_MESSAGE"}) + await send({"type": "websocket.close", "code": 1000, "reason": ""}) + + connection = wsproto.connection.Connection(wsproto.connection.CLIENT) + events = [] + async with ASGIWebSocketAsyncNetworkStream(app, {}) as stream: + for _ in range(3): + data = await stream.read(4096) + connection.receive_data(data) + + events = list(connection.events()) + assert events == [ + wsproto.events.TextMessage("SERVER_MESSAGE"), + wsproto.events.BytesMessage(bytearray(b"SERVER_MESSAGE")), + wsproto.events.CloseConnection(1000, ""), + ] + + async def test_read_unhandled_asgi_message(self): + async def app(scope, receive, send): + await send({"type": "websocket.accept"}) + await send({"type": "websocket.foo"}) + + async with ASGIWebSocketAsyncNetworkStream(app, {}) as stream: + with pytest.raises(UnhandledASGIMessageType): + await stream.read(4096) + + +@pytest.fixture +def test_app() -> Starlette: + async def http_endpoint(request): + return PlainTextResponse("Hello, world!") + + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.receive_text() + await websocket.close() + + routes = [ + Route("/http", endpoint=http_endpoint), + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + + return Starlette(routes=routes) + + +@pytest.mark.asyncio +class TestASGIWebSocketTransport: + async def test_http(self, test_app: Starlette): + async with ASGIWebSocketTransport(app=test_app) as transport: + request = httpx.Request("GET", "http://localhost:8000/http") + response = await transport.handle_async_request(request) + assert response.status_code == 200 + + @pytest.mark.parametrize( + "url,headers", + [ + ("ws://localhost:8000/ws", {}), + ("wss://localhost:8000/ws", {}), + ("http://localhost:8000/ws", {"upgrade": "websocket"}), + ], + ) + async def test_websocket( + self, url: str, headers: Dict[str, Any], test_app: Starlette + ): + async with ASGIWebSocketTransport(app=test_app) as transport: + request = httpx.Request("GET", url, headers=headers) + response = await transport.handle_async_request(request) + assert response.status_code == 101 + + assert isinstance( + response.extensions["network_stream"], ASGIWebSocketAsyncNetworkStream + ) From 1c704db93cb5dcb220f8fb68e0c3b4f623037d68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 23 Nov 2022 14:30:59 +0100 Subject: [PATCH 011/108] Try to use Unix socket to serve apps during testing --- tests/httpx2/websockets/conftest.py | 21 ++-- tests/httpx2/websockets/test_api.py | 150 ++++++++++++++++------------ 2 files changed, 100 insertions(+), 71 deletions(-) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index 3211f266..c03a3acf 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -1,6 +1,8 @@ import asyncio import contextlib +import pathlib import queue +import tempfile from typing import Callable, ContextManager from unittest.mock import MagicMock @@ -14,8 +16,9 @@ @pytest.fixture(scope="session") def event_loop(): """Force the pytest-asyncio loop to be the main one.""" - loop = asyncio.get_event_loop() + loop = asyncio.new_event_loop() yield loop + loop.close() @pytest.fixture @@ -23,7 +26,7 @@ def on_receive_message(): return MagicMock() -ServerFactoryFixture = Callable[[Callable], ContextManager[None]] +ServerFactoryFixture = Callable[[Callable], ContextManager[str]] @pytest.fixture @@ -32,7 +35,7 @@ def server_factory() -> ServerFactoryFixture: def _server_factory(endpoint: Callable): startup_queue: queue.Queue[bool] = queue.Queue() - async def start_uvicorn(): + async def start_uvicorn(socket: str): routes = [ WebSocketRoute("/ws", endpoint=endpoint), ] @@ -41,14 +44,16 @@ async def on_startup(): startup_queue.put(True) app = Starlette(routes=routes, on_startup=[on_startup]) - config = uvicorn.Config(app, port=8000) + config = uvicorn.Config(app, uds=socket) server = uvicorn.Server(config) await server.serve() with start_blocking_portal(backend="asyncio") as portal: - future = portal.start_task_soon(start_uvicorn) - startup_queue.get(True) - yield - future.cancel() + with tempfile.TemporaryDirectory() as socket_directory: + socket_path = str(pathlib.Path(socket_directory) / "socket.sock") + future = portal.start_task_soon(start_uvicorn, socket_path) + startup_queue.get(True) + yield socket_path + future.cancel() return _server_factory diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 1021dde6..9bf16e8b 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -25,20 +25,20 @@ def handler(request): base_url="http://localhost:8000", transport=httpx.MockTransport(handler) ) as client: with pytest.raises(WebSocketUpgradeError): - with connect_ws(client, "/ws"): + with connect_ws(client, "http://socket/ws"): pass async with httpx.AsyncClient( base_url="http://localhost:8000", transport=httpx.MockTransport(handler) ) as client: with pytest.raises(WebSocketUpgradeError): - async with aconnect_ws(client, "/ws"): + async with aconnect_ws(client, "http://socket/ws"): pass @pytest.mark.asyncio class TestSend: - async def test_send( + async def test_send_foo( self, server_factory: ServerFactoryFixture, on_receive_message: MagicMock, @@ -51,13 +51,15 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: - with connect_ws(client, "/ws") as ws: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws(client, "http://socket/ws") as ws: ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: - async with aconnect_ws(aclient, "/ws") as aws: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + async with aconnect_ws(aclient, "http://socket/ws") as aws: await aws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) on_receive_message.assert_has_calls( @@ -77,13 +79,15 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: - with connect_ws(client, "/ws") as ws: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws(client, "http://socket/ws") as ws: ws.send_text("CLIENT_MESSAGE") - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: - async with aconnect_ws(aclient, "/ws") as aws: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + async with aconnect_ws(aclient, "http://socket/ws") as aws: await aws.send_text("CLIENT_MESSAGE") on_receive_message.assert_has_calls( @@ -103,13 +107,15 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: - with connect_ws(client, "/ws") as ws: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws(client, "http://socket/ws") as ws: ws.send_bytes(b"CLIENT_MESSAGE") - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: - async with aconnect_ws(aclient, "/ws") as aws: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + async with aconnect_ws(aclient, "http://socket/ws") as aws: await aws.send_bytes(b"CLIENT_MESSAGE") on_receive_message.assert_has_calls( @@ -131,13 +137,15 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: - with connect_ws(client, "/ws") as ws: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws(client, "http://socket/ws") as ws: ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: - async with aconnect_ws(aclient, "/ws") as aws: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + async with aconnect_ws(aclient, "http://socket/ws") as aws: await aws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) on_receive_message.assert_has_calls( @@ -155,15 +163,17 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: - with connect_ws(client, "/ws") as ws: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws(client, "http://socket/ws") as ws: event = ws.receive() assert isinstance(event, wsproto.events.TextMessage) assert event.data == "SERVER_MESSAGE" - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: - async with aconnect_ws(aclient, "/ws") as aws: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + async with aconnect_ws(aclient, "http://socket/ws") as aws: event = await aws.receive() assert isinstance(event, wsproto.events.TextMessage) assert event.data == "SERVER_MESSAGE" @@ -176,14 +186,16 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: - with connect_ws(client, "/ws") as ws: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws(client, "http://socket/ws") as ws: data = ws.receive_text() assert data == "SERVER_MESSAGE" - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: - async with aconnect_ws(aclient, "/ws") as aws: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + async with aconnect_ws(aclient, "http://socket/ws") as aws: data = await aws.receive_text() assert data == "SERVER_MESSAGE" @@ -197,15 +209,17 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: with pytest.raises(WebSocketInvalidTypeReceived): - with connect_ws(client, "/ws") as ws: + with connect_ws(client, "http://socket/ws") as ws: ws.receive_text() - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: with pytest.raises(WebSocketInvalidTypeReceived): - async with aconnect_ws(aclient, "/ws") as aws: + async with aconnect_ws(aclient, "http://socket/ws") as aws: await aws.receive_text() async def test_receive_bytes(self, server_factory: ServerFactoryFixture): @@ -216,14 +230,16 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: - with connect_ws(client, "/ws") as ws: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws(client, "http://socket/ws") as ws: data = ws.receive_bytes() assert data == b"SERVER_MESSAGE" - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: - async with aconnect_ws(aclient, "/ws") as aws: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + async with aconnect_ws(aclient, "http://socket/ws") as aws: data = await aws.receive_bytes() assert data == b"SERVER_MESSAGE" @@ -237,15 +253,17 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: with pytest.raises(WebSocketInvalidTypeReceived): - with connect_ws(client, "/ws") as ws: + with connect_ws(client, "http://socket/ws") as ws: ws.receive_bytes() - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: with pytest.raises(WebSocketInvalidTypeReceived): - async with aconnect_ws(aclient, "/ws") as aws: + async with aconnect_ws(aclient, "http://socket/ws") as aws: await aws.receive_bytes() @pytest.mark.parametrize("mode", ["text", "binary"]) @@ -259,14 +277,16 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: - with connect_ws(client, "/ws") as ws: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws(client, "http://socket/ws") as ws: data = ws.receive_json(mode=mode) assert data == {"message": "SERVER_MESSAGE"} - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: - async with aconnect_ws(aclient, "/ws") as aws: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + async with aconnect_ws(aclient, "http://socket/ws") as aws: data = await aws.receive_json(mode=mode) assert data == {"message": "SERVER_MESSAGE"} @@ -277,13 +297,15 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.accept() await websocket.receive_text() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: - with connect_ws(client, "/ws"): + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws(client, "http://socket/ws"): pass - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: - async with aconnect_ws(aclient, "/ws"): + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + async with aconnect_ws(aclient, "http://socket/ws"): pass @@ -293,13 +315,15 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.accept() await websocket.close() - with server_factory(websocket_endpoint): - with httpx.Client(base_url="http://localhost:8000") as client: + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: with pytest.raises(WebSocketDisconnect): - with connect_ws(client, "/ws") as ws: + with connect_ws(client, "http://socket/ws") as ws: ws.receive() - async with httpx.AsyncClient(base_url="http://localhost:8000") as aclient: + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: with pytest.raises(WebSocketDisconnect): - async with aconnect_ws(aclient, "/ws") as aws: + async with aconnect_ws(aclient, "http://socket/ws") as aws: await aws.receive() From fd98e0c72a8e10981a1d64959c4b390a8b9edb6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 23 Nov 2022 15:15:23 +0100 Subject: [PATCH 012/108] Add max_bytes parameter to all receive methods --- src/httpx2/httpx2/_websockets/_api.py | 46 ++++++++++++++++----------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 8e2c5adc..54724710 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -15,6 +15,8 @@ JSONMode = Literal["text", "binary"] +DEFAULT_RECEIVE_MAX_BYTES = 65_536 + class HTTPXWSException(Exception): pass @@ -60,8 +62,10 @@ def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: else: self.send_bytes(serialized_data.encode("utf-8")) - def receive(self) -> wsproto.events.Event: - data = self.stream.read(max_bytes=4096) + def receive( + self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES + ) -> wsproto.events.Event: + data = self.stream.read(max_bytes=max_bytes) self.connection.receive_data(data) for event in self.connection.events(): if isinstance(event, wsproto.events.CloseConnection): @@ -69,25 +73,27 @@ def receive(self) -> wsproto.events.Event: return event raise HTTPXWSException() # pragma: no cover - def receive_text(self) -> str: - event = self.receive() + def receive_text(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> str: + event = self.receive(max_bytes) if isinstance(event, wsproto.events.TextMessage): return event.data raise WebSocketInvalidTypeReceived(event) - def receive_bytes(self) -> bytes: - event = self.receive() + def receive_bytes(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> bytes: + event = self.receive(max_bytes) if isinstance(event, wsproto.events.BytesMessage): return event.data raise WebSocketInvalidTypeReceived(event) - def receive_json(self, mode: JSONMode = "text") -> typing.Any: + def receive_json( + self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES, mode: JSONMode = "text" + ) -> typing.Any: assert mode in ["text", "binary"] data: typing.Union[str, bytes] if mode == "text": - data = self.receive_text() + data = self.receive_text(max_bytes) elif mode == "binary": - data = self.receive_bytes() + data = self.receive_bytes(max_bytes) return json.loads(data) def close(self, code: int = 1000, reason: typing.Optional[str] = None): @@ -123,8 +129,10 @@ async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: else: await self.send_bytes(serialized_data.encode("utf-8")) - async def receive(self) -> wsproto.events.Event: - data = await self.stream.read(max_bytes=4096) + async def receive( + self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES + ) -> wsproto.events.Event: + data = await self.stream.read(max_bytes=max_bytes) self.connection.receive_data(data) for event in self.connection.events(): if isinstance(event, wsproto.events.CloseConnection): @@ -132,25 +140,27 @@ async def receive(self) -> wsproto.events.Event: return event raise HTTPXWSException() # pragma: no cover - async def receive_text(self) -> str: - event = await self.receive() + async def receive_text(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> str: + event = await self.receive(max_bytes) if isinstance(event, wsproto.events.TextMessage): return event.data raise WebSocketInvalidTypeReceived(event) - async def receive_bytes(self) -> bytes: - event = await self.receive() + async def receive_bytes(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> bytes: + event = await self.receive(max_bytes) if isinstance(event, wsproto.events.BytesMessage): return event.data raise WebSocketInvalidTypeReceived(event) - async def receive_json(self, mode: JSONMode = "text") -> typing.Any: + async def receive_json( + self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES, mode: JSONMode = "text" + ) -> typing.Any: assert mode in ["text", "binary"] data: typing.Union[str, bytes] if mode == "text": - data = await self.receive_text() + data = await self.receive_text(max_bytes) elif mode == "binary": - data = await self.receive_bytes() + data = await self.receive_bytes(max_bytes) return json.loads(data) async def close(self, code: int = 1000, reason: typing.Optional[str] = None): From 5e11fc9f80ac42c05168d005ba6842dfe067ed1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 23 Nov 2022 15:30:23 +0100 Subject: [PATCH 013/108] Swallow WebSocketDisconnect during unit tests --- tests/httpx2/websockets/test_api.py | 152 +++++++++++++++++++--------- 1 file changed, 104 insertions(+), 48 deletions(-) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 9bf16e8b..5bab013a 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -53,14 +53,22 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with connect_ws(client, "http://socket/ws") as ws: - ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) + try: + with connect_ws(client, "http://socket/ws") as ws: + ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) + except WebSocketDisconnect: + pass async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - async with aconnect_ws(aclient, "http://socket/ws") as aws: - await aws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) + try: + async with aconnect_ws(aclient, "http://socket/ws") as aws: + await aws.send( + wsproto.events.TextMessage(data="CLIENT_MESSAGE") + ) + except WebSocketDisconnect: + pass on_receive_message.assert_has_calls( [call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")] @@ -81,14 +89,20 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with connect_ws(client, "http://socket/ws") as ws: - ws.send_text("CLIENT_MESSAGE") + try: + with connect_ws(client, "http://socket/ws") as ws: + ws.send_text("CLIENT_MESSAGE") + except WebSocketDisconnect: + pass async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - async with aconnect_ws(aclient, "http://socket/ws") as aws: - await aws.send_text("CLIENT_MESSAGE") + try: + async with aconnect_ws(aclient, "http://socket/ws") as aws: + await aws.send_text("CLIENT_MESSAGE") + except WebSocketDisconnect: + pass on_receive_message.assert_has_calls( [call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")] @@ -109,14 +123,20 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with connect_ws(client, "http://socket/ws") as ws: - ws.send_bytes(b"CLIENT_MESSAGE") + try: + with connect_ws(client, "http://socket/ws") as ws: + ws.send_bytes(b"CLIENT_MESSAGE") + except WebSocketDisconnect: + pass async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - async with aconnect_ws(aclient, "http://socket/ws") as aws: - await aws.send_bytes(b"CLIENT_MESSAGE") + try: + async with aconnect_ws(aclient, "http://socket/ws") as aws: + await aws.send_bytes(b"CLIENT_MESSAGE") + except WebSocketDisconnect: + pass on_receive_message.assert_has_calls( [call(b"CLIENT_MESSAGE"), call(b"CLIENT_MESSAGE")] @@ -139,14 +159,20 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with connect_ws(client, "http://socket/ws") as ws: - ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + try: + with connect_ws(client, "http://socket/ws") as ws: + ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + except WebSocketDisconnect: + pass async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - async with aconnect_ws(aclient, "http://socket/ws") as aws: - await aws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + try: + async with aconnect_ws(aclient, "http://socket/ws") as aws: + await aws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + except WebSocketDisconnect: + pass on_receive_message.assert_has_calls( [call({"message": "CLIENT_MESSAGE"}), call({"message": "CLIENT_MESSAGE"})] @@ -165,18 +191,24 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with connect_ws(client, "http://socket/ws") as ws: - event = ws.receive() - assert isinstance(event, wsproto.events.TextMessage) - assert event.data == "SERVER_MESSAGE" + try: + with connect_ws(client, "http://socket/ws") as ws: + event = ws.receive() + assert isinstance(event, wsproto.events.TextMessage) + assert event.data == "SERVER_MESSAGE" + except WebSocketDisconnect: + pass async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - async with aconnect_ws(aclient, "http://socket/ws") as aws: - event = await aws.receive() - assert isinstance(event, wsproto.events.TextMessage) - assert event.data == "SERVER_MESSAGE" + try: + async with aconnect_ws(aclient, "http://socket/ws") as aws: + event = await aws.receive() + assert isinstance(event, wsproto.events.TextMessage) + assert event.data == "SERVER_MESSAGE" + except WebSocketDisconnect: + pass async def test_receive_text(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): @@ -188,16 +220,22 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with connect_ws(client, "http://socket/ws") as ws: - data = ws.receive_text() - assert data == "SERVER_MESSAGE" + try: + with connect_ws(client, "http://socket/ws") as ws: + data = ws.receive_text() + assert data == "SERVER_MESSAGE" + except WebSocketDisconnect: + pass async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - async with aconnect_ws(aclient, "http://socket/ws") as aws: - data = await aws.receive_text() - assert data == "SERVER_MESSAGE" + try: + async with aconnect_ws(aclient, "http://socket/ws") as aws: + data = await aws.receive_text() + assert data == "SERVER_MESSAGE" + except WebSocketDisconnect: + pass async def test_receive_text_invalid_type( self, server_factory: ServerFactoryFixture @@ -211,16 +249,22 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with pytest.raises(WebSocketInvalidTypeReceived): - with connect_ws(client, "http://socket/ws") as ws: - ws.receive_text() + try: + with pytest.raises(WebSocketInvalidTypeReceived): + with connect_ws(client, "http://socket/ws") as ws: + ws.receive_text() + except WebSocketDisconnect: + pass async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - with pytest.raises(WebSocketInvalidTypeReceived): - async with aconnect_ws(aclient, "http://socket/ws") as aws: - await aws.receive_text() + try: + with pytest.raises(WebSocketInvalidTypeReceived): + async with aconnect_ws(aclient, "http://socket/ws") as aws: + await aws.receive_text() + except WebSocketDisconnect: + pass async def test_receive_bytes(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): @@ -232,16 +276,22 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with connect_ws(client, "http://socket/ws") as ws: - data = ws.receive_bytes() - assert data == b"SERVER_MESSAGE" + try: + with connect_ws(client, "http://socket/ws") as ws: + data = ws.receive_bytes() + assert data == b"SERVER_MESSAGE" + except WebSocketDisconnect: + pass async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - async with aconnect_ws(aclient, "http://socket/ws") as aws: - data = await aws.receive_bytes() - assert data == b"SERVER_MESSAGE" + try: + async with aconnect_ws(aclient, "http://socket/ws") as aws: + data = await aws.receive_bytes() + assert data == b"SERVER_MESSAGE" + except WebSocketDisconnect: + pass async def test_receive_bytes_invalid_type( self, server_factory: ServerFactoryFixture @@ -279,16 +329,22 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with connect_ws(client, "http://socket/ws") as ws: - data = ws.receive_json(mode=mode) - assert data == {"message": "SERVER_MESSAGE"} + try: + with connect_ws(client, "http://socket/ws") as ws: + data = ws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} + except WebSocketDisconnect: + pass async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - async with aconnect_ws(aclient, "http://socket/ws") as aws: - data = await aws.receive_json(mode=mode) - assert data == {"message": "SERVER_MESSAGE"} + try: + async with aconnect_ws(aclient, "http://socket/ws") as aws: + data = await aws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} + except WebSocketDisconnect: + pass @pytest.mark.asyncio From dd6b986d4ff7944823842ec36a621850ed34386d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 23 Nov 2022 15:33:31 +0000 Subject: [PATCH 014/108] Improve test Uvicorn servers so they shutdown gracefully --- tests/httpx2/websockets/conftest.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index c03a3acf..1833c6c0 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -35,7 +35,7 @@ def server_factory() -> ServerFactoryFixture: def _server_factory(endpoint: Callable): startup_queue: queue.Queue[bool] = queue.Queue() - async def start_uvicorn(socket: str): + def create_app() -> Starlette: routes = [ WebSocketRoute("/ws", endpoint=endpoint), ] @@ -43,17 +43,20 @@ async def start_uvicorn(socket: str): async def on_startup(): startup_queue.put(True) - app = Starlette(routes=routes, on_startup=[on_startup]) + return Starlette(routes=routes, on_startup=[on_startup]) + + def create_server(app: Starlette, socket: str): config = uvicorn.Config(app, uds=socket) - server = uvicorn.Server(config) - await server.serve() + return uvicorn.Server(config) with start_blocking_portal(backend="asyncio") as portal: with tempfile.TemporaryDirectory() as socket_directory: - socket_path = str(pathlib.Path(socket_directory) / "socket.sock") - future = portal.start_task_soon(start_uvicorn, socket_path) + socket = str(pathlib.Path(socket_directory) / "socket.sock") + app = create_app() + server = create_server(app, socket) + portal.start_task_soon(server.serve) startup_queue.get(True) - yield socket_path - future.cancel() + yield socket + server.should_exit = True return _server_factory From 748189a32b60fa17f39b9b0f9308ec14f52a5290 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 24 Nov 2022 10:56:37 +0100 Subject: [PATCH 015/108] Tweak conftest --- tests/httpx2/websockets/conftest.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index 1833c6c0..820ab1f6 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -3,8 +3,16 @@ import pathlib import queue import tempfile +import sys from typing import Callable, ContextManager from unittest.mock import MagicMock +import time + + +if sys.version_info < (3, 8): + from typing_extensions import Literal # pragma: no cover +else: + from typing import Literal # pragma: no cover import pytest import uvicorn @@ -26,6 +34,11 @@ def on_receive_message(): return MagicMock() +@pytest.fixture(params=("wsproto", "websockets")) +def websocket_implementation(request) -> Literal["wsproto", "websockets"]: + return request.param + + ServerFactoryFixture = Callable[[Callable], ContextManager[str]] @@ -34,6 +47,7 @@ def server_factory() -> ServerFactoryFixture: @contextlib.contextmanager def _server_factory(endpoint: Callable): startup_queue: queue.Queue[bool] = queue.Queue() + shutdown_queue: queue.Queue[bool] = queue.Queue() def create_app() -> Starlette: routes = [ @@ -46,17 +60,24 @@ async def on_startup(): return Starlette(routes=routes, on_startup=[on_startup]) def create_server(app: Starlette, socket: str): - config = uvicorn.Config(app, uds=socket) + config = uvicorn.Config(app, uds=socket, ws="wsproto") return uvicorn.Server(config) + def on_server_stopped(_task): + shutdown_queue.put(True) + with start_blocking_portal(backend="asyncio") as portal: with tempfile.TemporaryDirectory() as socket_directory: socket = str(pathlib.Path(socket_directory) / "socket.sock") app = create_app() server = create_server(app, socket) - portal.start_task_soon(server.serve) + task = portal.start_task_soon(server.serve) + task.add_done_callback(on_server_stopped) startup_queue.get(True) + time.sleep(1) yield socket + time.sleep(1) server.should_exit = True + shutdown_queue.get(True) return _server_factory From 0cad9bb7554766fd33204374db1efb8d3bd98b59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 24 Nov 2022 11:05:39 +0100 Subject: [PATCH 016/108] Tweak close methods --- src/httpx2/httpx2/_websockets/_api.py | 22 ++++++++++++++++------ tests/httpx2/websockets/conftest.py | 6 +----- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 54724710..9da8dfb3 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -10,8 +10,10 @@ else: from typing import Literal # pragma: no cover +import httpcore import httpx import wsproto +from httpcore.backends.base import AsyncNetworkStream, NetworkStream JSONMode = Literal["text", "binary"] @@ -40,7 +42,7 @@ def __init__(self, event: wsproto.events.Event) -> None: class WebSocketSession: def __init__(self, response: httpx.Response) -> None: - self.stream = response.extensions["network_stream"] + self.stream: NetworkStream = response.extensions["network_stream"] self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) def send(self, event: wsproto.events.Event) -> None: @@ -97,8 +99,12 @@ def receive_json( return json.loads(data) def close(self, code: int = 1000, reason: typing.Optional[str] = None): - event = wsproto.events.CloseConnection(code, reason) - self._send_event(event) + if self.connection.state != wsproto.connection.ConnectionState.CLOSED: + event = wsproto.events.CloseConnection(code, reason) + try: + self._send_event(event) + except httpcore.WriteError: + pass def _send_event(self, event: wsproto.events.Event): data = self.connection.send(event) @@ -107,7 +113,7 @@ def _send_event(self, event: wsproto.events.Event): class AsyncWebSocketSession: def __init__(self, response: httpx.Response) -> None: - self.stream = response.extensions["network_stream"] + self.stream: AsyncNetworkStream = response.extensions["network_stream"] self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) async def send(self, event: wsproto.events.Event) -> None: @@ -164,8 +170,12 @@ async def receive_json( return json.loads(data) async def close(self, code: int = 1000, reason: typing.Optional[str] = None): - event = wsproto.events.CloseConnection(code, reason) - await self._send_event(event) + if self.connection.state != wsproto.connection.ConnectionState.CLOSED: + event = wsproto.events.CloseConnection(code, reason) + try: + await self._send_event(event) + except httpcore.WriteError: + pass async def _send_event(self, event: wsproto.events.Event): data = self.connection.send(event) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index 820ab1f6..11edb63b 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -2,12 +2,10 @@ import contextlib import pathlib import queue -import tempfile import sys +import tempfile from typing import Callable, ContextManager from unittest.mock import MagicMock -import time - if sys.version_info < (3, 8): from typing_extensions import Literal # pragma: no cover @@ -74,9 +72,7 @@ def on_server_stopped(_task): task = portal.start_task_soon(server.serve) task.add_done_callback(on_server_stopped) startup_queue.get(True) - time.sleep(1) yield socket - time.sleep(1) server.should_exit = True shutdown_queue.get(True) From 0c10cfc6ce2d2bdde5e51fdf2033c8a4fba1cfa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 24 Nov 2022 11:06:12 +0100 Subject: [PATCH 017/108] Test against uvicorn websockets impl --- tests/httpx2/websockets/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index 11edb63b..f842e9e4 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -58,7 +58,7 @@ async def on_startup(): return Starlette(routes=routes, on_startup=[on_startup]) def create_server(app: Starlette, socket: str): - config = uvicorn.Config(app, uds=socket, ws="wsproto") + config = uvicorn.Config(app, uds=socket, ws="websockets") return uvicorn.Server(config) def on_server_stopped(_task): From 46955b6e62861d048d69abdc3de7d9fbbb34b105 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 24 Nov 2022 12:35:45 +0100 Subject: [PATCH 018/108] Tweak events handling --- src/httpx2/httpx2/_websockets/_api.py | 33 +++++++++++++++------------ 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 9da8dfb3..3e28747c 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -1,4 +1,5 @@ import base64 +import collections import contextlib import json import secrets @@ -44,6 +45,7 @@ class WebSocketSession: def __init__(self, response: httpx.Response) -> None: self.stream: NetworkStream = response.extensions["network_stream"] self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) + self._events: typing.Deque[wsproto.events.Event] = collections.deque() def send(self, event: wsproto.events.Event) -> None: self._send_event(event) @@ -67,13 +69,14 @@ def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: def receive( self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES ) -> wsproto.events.Event: - data = self.stream.read(max_bytes=max_bytes) - self.connection.receive_data(data) - for event in self.connection.events(): - if isinstance(event, wsproto.events.CloseConnection): - raise WebSocketDisconnect(event.code, event.reason) - return event - raise HTTPXWSException() # pragma: no cover + while len(self._events) == 0: + data = self.stream.read(max_bytes=max_bytes) + self.connection.receive_data(data) + self._events.extend(self.connection.events()) + event = self._events.popleft() + if isinstance(event, wsproto.events.CloseConnection): + raise WebSocketDisconnect(event.code, event.reason) + return event def receive_text(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> str: event = self.receive(max_bytes) @@ -115,6 +118,7 @@ class AsyncWebSocketSession: def __init__(self, response: httpx.Response) -> None: self.stream: AsyncNetworkStream = response.extensions["network_stream"] self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) + self._events: typing.Deque[wsproto.events.Event] = collections.deque() async def send(self, event: wsproto.events.Event) -> None: await self._send_event(event) @@ -138,13 +142,14 @@ async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: async def receive( self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES ) -> wsproto.events.Event: - data = await self.stream.read(max_bytes=max_bytes) - self.connection.receive_data(data) - for event in self.connection.events(): - if isinstance(event, wsproto.events.CloseConnection): - raise WebSocketDisconnect(event.code, event.reason) - return event - raise HTTPXWSException() # pragma: no cover + while len(self._events) == 0: + data = await self.stream.read(max_bytes=max_bytes) + self.connection.receive_data(data) + self._events.extend(self.connection.events()) + event = self._events.popleft() + if isinstance(event, wsproto.events.CloseConnection): + raise WebSocketDisconnect(event.code, event.reason) + return event async def receive_text(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> str: event = await self.receive(max_bytes) From 6c240a4534613416d3f62d4f8d5765b38aa83723 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 24 Nov 2022 11:47:19 +0000 Subject: [PATCH 019/108] Fix unit tests by adding sleeps --- tests/httpx2/websockets/conftest.py | 4 ++-- tests/httpx2/websockets/test_api.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index f842e9e4..2771c6d9 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -41,7 +41,7 @@ def websocket_implementation(request) -> Literal["wsproto", "websockets"]: @pytest.fixture -def server_factory() -> ServerFactoryFixture: +def server_factory(websocket_implementation: str) -> ServerFactoryFixture: @contextlib.contextmanager def _server_factory(endpoint: Callable): startup_queue: queue.Queue[bool] = queue.Queue() @@ -58,7 +58,7 @@ async def on_startup(): return Starlette(routes=routes, on_startup=[on_startup]) def create_server(app: Starlette, socket: str): - config = uvicorn.Config(app, uds=socket, ws="websockets") + config = uvicorn.Config(app, uds=socket, ws=websocket_implementation) return uvicorn.Server(config) def on_server_stopped(_task): diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 5bab013a..08bbc7a4 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -1,9 +1,11 @@ +import asyncio from unittest.mock import MagicMock, call import httpx import pytest import wsproto from starlette.websockets import WebSocket +from starlette.websockets import WebSocketDisconnect as StarletteWebSocketDisconnect from httpx_ws import ( JSONMode, @@ -184,6 +186,7 @@ class TestReceive: async def test_receive(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + await asyncio.sleep(0.1) # FIXME: see #7 await websocket.send_text("SERVER_MESSAGE") @@ -213,6 +216,7 @@ async def websocket_endpoint(websocket: WebSocket): async def test_receive_text(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + await asyncio.sleep(0.1) # FIXME: see #7 await websocket.send_text("SERVER_MESSAGE") @@ -242,6 +246,7 @@ async def test_receive_text_invalid_type( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + await asyncio.sleep(0.1) # FIXME: see #7 await websocket.send_bytes(b"SERVER_MESSAGE") @@ -269,6 +274,7 @@ async def websocket_endpoint(websocket: WebSocket): async def test_receive_bytes(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + await asyncio.sleep(0.1) # FIXME: see #7 await websocket.send_bytes(b"SERVER_MESSAGE") @@ -298,6 +304,7 @@ async def test_receive_bytes_invalid_type( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + await asyncio.sleep(0.1) # FIXME: see #7 await websocket.send_text("SERVER_MESSAGE") @@ -322,6 +329,7 @@ async def test_receive_json( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + await asyncio.sleep(0.1) # FIXME: see #7 await websocket.send_json({"message": "SERVER_MESSAGE"}, mode=mode) @@ -351,7 +359,10 @@ async def websocket_endpoint(websocket: WebSocket): async def test_send_close(server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await websocket.receive_text() + try: + await websocket.receive_text() + except StarletteWebSocketDisconnect: + await websocket.close() with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: @@ -369,6 +380,7 @@ async def websocket_endpoint(websocket: WebSocket): async def test_receive_close(server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + await asyncio.sleep(0.1) # FIXME: see #7 await websocket.close() with server_factory(websocket_endpoint) as socket: From f55940773f39d2e8dec007a1319249d90f62f3fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 24 Nov 2022 13:04:38 +0100 Subject: [PATCH 020/108] Improve API --- src/httpx2/httpx2/_websockets/_api.py | 10 +++--- tests/httpx2/websockets/test_api.py | 52 +++++++++++++-------------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 3e28747c..bc044464 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -198,12 +198,13 @@ def _get_headers() -> typing.Dict[str, typing.Any]: @contextlib.contextmanager def connect_ws( - client: httpx.Client, url: str, **kwargs: typing.Any + url: str, client: typing.Optional[httpx.Client] = None, **kwargs: typing.Any ) -> typing.Generator[WebSocketSession, None, None]: + client = httpx.Client() if client is None else client headers = kwargs.pop("headers", {}) headers.update(_get_headers()) - with client.stream("GET", url, headers=headers) as response: + with client.stream("GET", url, headers=headers, **kwargs) as response: if response.status_code != 101: raise WebSocketUpgradeError(response) @@ -214,12 +215,13 @@ def connect_ws( @contextlib.asynccontextmanager async def aconnect_ws( - client: httpx.AsyncClient, url: str, **kwargs: typing.Any + url: str, client: typing.Optional[httpx.AsyncClient] = None, **kwargs: typing.Any ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: + client = httpx.AsyncClient() if client is None else client headers = kwargs.pop("headers", {}) headers.update(_get_headers()) - async with client.stream("GET", url, headers=headers) as response: + async with client.stream("GET", url, headers=headers, **kwargs) as response: if response.status_code != 101: raise WebSocketUpgradeError(response) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 08bbc7a4..24a5f53c 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -27,14 +27,14 @@ def handler(request): base_url="http://localhost:8000", transport=httpx.MockTransport(handler) ) as client: with pytest.raises(WebSocketUpgradeError): - with connect_ws(client, "http://socket/ws"): + with connect_ws("http://socket/ws", client): pass async with httpx.AsyncClient( base_url="http://localhost:8000", transport=httpx.MockTransport(handler) ) as client: with pytest.raises(WebSocketUpgradeError): - async with aconnect_ws(client, "http://socket/ws"): + async with aconnect_ws("http://socket/ws", client): pass @@ -56,7 +56,7 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) except WebSocketDisconnect: pass @@ -65,7 +65,7 @@ async def websocket_endpoint(websocket: WebSocket): transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: try: - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send( wsproto.events.TextMessage(data="CLIENT_MESSAGE") ) @@ -92,7 +92,7 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: ws.send_text("CLIENT_MESSAGE") except WebSocketDisconnect: pass @@ -101,7 +101,7 @@ async def websocket_endpoint(websocket: WebSocket): transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: try: - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send_text("CLIENT_MESSAGE") except WebSocketDisconnect: pass @@ -126,7 +126,7 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: ws.send_bytes(b"CLIENT_MESSAGE") except WebSocketDisconnect: pass @@ -135,7 +135,7 @@ async def websocket_endpoint(websocket: WebSocket): transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: try: - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send_bytes(b"CLIENT_MESSAGE") except WebSocketDisconnect: pass @@ -162,7 +162,7 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) except WebSocketDisconnect: pass @@ -171,7 +171,7 @@ async def websocket_endpoint(websocket: WebSocket): transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: try: - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) except WebSocketDisconnect: pass @@ -195,7 +195,7 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: event = ws.receive() assert isinstance(event, wsproto.events.TextMessage) assert event.data == "SERVER_MESSAGE" @@ -206,7 +206,7 @@ async def websocket_endpoint(websocket: WebSocket): transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: try: - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: event = await aws.receive() assert isinstance(event, wsproto.events.TextMessage) assert event.data == "SERVER_MESSAGE" @@ -225,7 +225,7 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: data = ws.receive_text() assert data == "SERVER_MESSAGE" except WebSocketDisconnect: @@ -235,7 +235,7 @@ async def websocket_endpoint(websocket: WebSocket): transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: try: - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: data = await aws.receive_text() assert data == "SERVER_MESSAGE" except WebSocketDisconnect: @@ -256,7 +256,7 @@ async def websocket_endpoint(websocket: WebSocket): with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: with pytest.raises(WebSocketInvalidTypeReceived): - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: ws.receive_text() except WebSocketDisconnect: pass @@ -266,7 +266,7 @@ async def websocket_endpoint(websocket: WebSocket): ) as aclient: try: with pytest.raises(WebSocketInvalidTypeReceived): - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.receive_text() except WebSocketDisconnect: pass @@ -283,7 +283,7 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: data = ws.receive_bytes() assert data == b"SERVER_MESSAGE" except WebSocketDisconnect: @@ -293,7 +293,7 @@ async def websocket_endpoint(websocket: WebSocket): transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: try: - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: data = await aws.receive_bytes() assert data == b"SERVER_MESSAGE" except WebSocketDisconnect: @@ -313,14 +313,14 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: with pytest.raises(WebSocketInvalidTypeReceived): - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: ws.receive_bytes() async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: with pytest.raises(WebSocketInvalidTypeReceived): - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.receive_bytes() @pytest.mark.parametrize("mode", ["text", "binary"]) @@ -338,7 +338,7 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: data = ws.receive_json(mode=mode) assert data == {"message": "SERVER_MESSAGE"} except WebSocketDisconnect: @@ -348,7 +348,7 @@ async def websocket_endpoint(websocket: WebSocket): transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: try: - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: data = await aws.receive_json(mode=mode) assert data == {"message": "SERVER_MESSAGE"} except WebSocketDisconnect: @@ -366,13 +366,13 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with connect_ws(client, "http://socket/ws"): + with connect_ws("http://socket/ws", client): pass async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - async with aconnect_ws(aclient, "http://socket/ws"): + async with aconnect_ws("http://socket/ws", aclient): pass @@ -386,12 +386,12 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: with pytest.raises(WebSocketDisconnect): - with connect_ws(client, "http://socket/ws") as ws: + with connect_ws("http://socket/ws", client) as ws: ws.receive() async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: with pytest.raises(WebSocketDisconnect): - async with aconnect_ws(aclient, "http://socket/ws") as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.receive() From 09084244bb6b903d9164f7960ac091cf3eca9d53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 24 Nov 2022 13:15:42 +0100 Subject: [PATCH 021/108] =?UTF-8?q?Bump=20version=200.0.0=20=E2=86=92=200.?= =?UTF-8?q?1.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit > ⚠️ This is a very young project. Expect bugs 🐛 Features -------- * `connect_ws` helper to talk to WebSockets synchronously. * `aconnect_ws` helper to talk to WebSockets asynchronously. * `ASGIWebSocketTransport` to test WebSockets in ASGI apps directly. --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 1c54e396..6abd7bdc 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.0" +__version__ = "0.1.0" from httpx_ws._api import ( AsyncWebSocketSession, From d948b564f7a4400444a66421a88cb3daaf954724 Mon Sep 17 00:00:00 2001 From: Kousik Mitra Date: Thu, 24 Nov 2022 21:54:12 +0530 Subject: [PATCH 022/108] #2 Add ping method (#8) * Add ping method (#2) * add optional arguement paylod to ping --- src/httpx2/httpx2/_websockets/_api.py | 8 +++++++ tests/httpx2/websockets/test_api.py | 31 +++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index bc044464..c6334b4f 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -47,6 +47,10 @@ def __init__(self, response: httpx.Response) -> None: self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) self._events: typing.Deque[wsproto.events.Event] = collections.deque() + def ping(self, payload: bytes = b"") -> None: + event = wsproto.events.Ping(payload) + self._send_event(event) + def send(self, event: wsproto.events.Event) -> None: self._send_event(event) @@ -120,6 +124,10 @@ def __init__(self, response: httpx.Response) -> None: self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) self._events: typing.Deque[wsproto.events.Event] = collections.deque() + async def ping(self, payload: bytes = b"") -> None: + event = wsproto.events.Ping(payload) + await self._send_event(event) + async def send(self, event: wsproto.events.Event) -> None: await self._send_event(event) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 24a5f53c..d2fad3ca 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -40,6 +40,37 @@ def handler(request): @pytest.mark.asyncio class TestSend: + async def test_ping(self, server_factory: ServerFactoryFixture): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + await asyncio.sleep(0.1) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws("http://socket/ws", client) as ws: + ws.ping() + event = ws.receive() + assert isinstance(event, wsproto.events.Pong) + assert event.payload == b"" + except WebSocketDisconnect: + pass + + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + await aws.ping() + event = await aws.receive() + assert isinstance(event, wsproto.events.Pong) + assert event.payload == b"" + except WebSocketDisconnect: + pass + async def test_send_foo( self, server_factory: ServerFactoryFixture, From 56b010b490b0cae0f226b658cb8f192ebdcec8c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 25 Nov 2022 09:30:56 +0100 Subject: [PATCH 023/108] Fix test name --- tests/httpx2/websockets/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index d2fad3ca..c49c30b7 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -71,7 +71,7 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async def test_send_foo( + async def test_send( self, server_factory: ServerFactoryFixture, on_receive_message: MagicMock, From 96ef50d18cba9ea726e4a0109fb73d9876f0bad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 25 Nov 2022 10:46:53 +0100 Subject: [PATCH 024/108] =?UTF-8?q?Bump=20version=200.1.0=20=E2=86=92=200.?= =?UTF-8?q?1.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Features -------- * Add a `.ping` method. Thanks @kousikmitra 🎉 Improvements ------------ * Pin lower bound version of `httpx` and `httpcore` dependencies * `httpx>=0.23.1` * `httpcore>=0.16.1` --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 6abd7bdc..9a75dc92 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.0" +__version__ = "0.1.1" from httpx_ws._api import ( AsyncWebSocketSession, From 3785e2a67aee88ad61f0d118366e0eb44425e78d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 25 Nov 2022 15:32:18 +0100 Subject: [PATCH 025/108] Implement auto ping respond --- src/httpx2/httpx2/_websockets/_api.py | 18 ++++--- tests/httpx2/websockets/conftest.py | 12 +++-- tests/httpx2/websockets/test_api.py | 69 +++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 10 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index c6334b4f..135a98c4 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -42,8 +42,8 @@ def __init__(self, event: wsproto.events.Event) -> None: class WebSocketSession: - def __init__(self, response: httpx.Response) -> None: - self.stream: NetworkStream = response.extensions["network_stream"] + def __init__(self, stream: NetworkStream) -> None: + self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) self._events: typing.Deque[wsproto.events.Event] = collections.deque() @@ -78,6 +78,9 @@ def receive( self.connection.receive_data(data) self._events.extend(self.connection.events()) event = self._events.popleft() + if isinstance(event, wsproto.events.Ping): + self.send(event.response()) + return self.receive(max_bytes) if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) return event @@ -119,8 +122,8 @@ def _send_event(self, event: wsproto.events.Event): class AsyncWebSocketSession: - def __init__(self, response: httpx.Response) -> None: - self.stream: AsyncNetworkStream = response.extensions["network_stream"] + def __init__(self, stream: AsyncNetworkStream) -> None: + self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) self._events: typing.Deque[wsproto.events.Event] = collections.deque() @@ -155,6 +158,9 @@ async def receive( self.connection.receive_data(data) self._events.extend(self.connection.events()) event = self._events.popleft() + if isinstance(event, wsproto.events.Ping): + await self.send(event.response()) + return await self.receive(max_bytes) if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) return event @@ -216,7 +222,7 @@ def connect_ws( if response.status_code != 101: raise WebSocketUpgradeError(response) - session = WebSocketSession(response) + session = WebSocketSession(response.extensions["network_stream"]) yield session session.close() @@ -233,6 +239,6 @@ async def aconnect_ws( if response.status_code != 101: raise WebSocketUpgradeError(response) - session = AsyncWebSocketSession(response) + session = AsyncWebSocketSession(response.extensions["network_stream"]) yield session await session.close() diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index 2771c6d9..dd385ec2 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -8,9 +8,9 @@ from unittest.mock import MagicMock if sys.version_info < (3, 8): - from typing_extensions import Literal # pragma: no cover + from typing_extensions import Literal, Protocol # pragma: no cover else: - from typing import Literal # pragma: no cover + from typing import Literal, Protocol # pragma: no cover import pytest import uvicorn @@ -37,11 +37,15 @@ def websocket_implementation(request) -> Literal["wsproto", "websockets"]: return request.param -ServerFactoryFixture = Callable[[Callable], ContextManager[str]] +class ServerFactoryFixture(Protocol): + def __call__(self, endpoint: Callable) -> ContextManager[str]: + ... @pytest.fixture -def server_factory(websocket_implementation: str) -> ServerFactoryFixture: +def server_factory( + websocket_implementation: Literal["wsproto", "websockets"] +) -> ServerFactoryFixture: @contextlib.contextmanager def _server_factory(endpoint: Callable): startup_queue: queue.Queue[bool] = queue.Queue() diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index c49c30b7..86f5b154 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -1,16 +1,20 @@ import asyncio +import typing from unittest.mock import MagicMock, call import httpx import pytest import wsproto +from httpcore.backends.base import AsyncNetworkStream, NetworkStream from starlette.websockets import WebSocket from starlette.websockets import WebSocketDisconnect as StarletteWebSocketDisconnect from httpx_ws import ( + AsyncWebSocketSession, JSONMode, WebSocketDisconnect, WebSocketInvalidTypeReceived, + WebSocketSession, WebSocketUpgradeError, aconnect_ws, connect_ws, @@ -386,6 +390,71 @@ async def websocket_endpoint(websocket: WebSocket): pass +class TestReceivePing: + def test_receive_ping(self): + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection( + wsproto.connection.ConnectionType.SERVER + ) + self.ping_sent = False + + def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + event: wsproto.events.Event + if not self.ping_sent: + event = wsproto.events.Ping(b"SERVER_PING") + self.ping_sent = True + else: + event = wsproto.events.TextMessage("SERVER_MESSAGE") + return self.connection.send(event) + + def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + self.connection.receive_data(buffer) + + stream = MockNetworkStream() + websocket_session = WebSocketSession(stream) + websocket_session.receive() + + received_events = list(stream.connection.events()) + assert received_events == [wsproto.events.Pong(b"SERVER_PING")] + + @pytest.mark.asyncio + async def test_async_receive_ping(self): + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection( + wsproto.connection.ConnectionType.SERVER + ) + self.ping_sent = False + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + event: wsproto.events.Event + if not self.ping_sent: + event = wsproto.events.Ping(b"SERVER_PING") + self.ping_sent = True + else: + event = wsproto.events.TextMessage("SERVER_MESSAGE") + return self.connection.send(event) + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + self.connection.receive_data(buffer) + + stream = MockAsyncNetworkStream() + websocket_session = AsyncWebSocketSession(stream) + await websocket_session.receive() + + received_events = list(stream.connection.events()) + assert received_events == [wsproto.events.Pong(b"SERVER_PING")] + + @pytest.mark.asyncio async def test_send_close(server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): From 0b573128c2778a74e15bd758f6eaae42872853d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 25 Nov 2022 16:12:47 +0100 Subject: [PATCH 026/108] Revamp receive implementation with a background thread --- src/httpx2/httpx2/_websockets/_api.py | 101 ++++++++++++++++---------- 1 file changed, 64 insertions(+), 37 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 135a98c4..f3406738 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -1,9 +1,11 @@ +import asyncio import base64 -import collections import contextlib import json +import queue import secrets import sys +import threading import typing if sys.version_info < (3, 8): @@ -45,7 +47,12 @@ class WebSocketSession: def __init__(self, stream: NetworkStream) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) - self._events: typing.Deque[wsproto.events.Event] = collections.deque() + self._events: queue.Queue[wsproto.events.Event] = queue.Queue() + self._should_close = False + self._background_receive_task = threading.Thread( + target=self._background_receive + ) + self._background_receive_task.start() def ping(self, payload: bytes = b"") -> None: event = wsproto.events.Ping(payload) @@ -70,45 +77,37 @@ def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: else: self.send_bytes(serialized_data.encode("utf-8")) - def receive( - self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES - ) -> wsproto.events.Event: - while len(self._events) == 0: - data = self.stream.read(max_bytes=max_bytes) - self.connection.receive_data(data) - self._events.extend(self.connection.events()) - event = self._events.popleft() - if isinstance(event, wsproto.events.Ping): - self.send(event.response()) - return self.receive(max_bytes) + def receive(self, timeout: typing.Optional[float] = None) -> wsproto.events.Event: + event = self._events.get(block=True, timeout=timeout) if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) return event - def receive_text(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> str: - event = self.receive(max_bytes) + def receive_text(self, timeout: typing.Optional[float] = None) -> str: + event = self.receive(timeout) if isinstance(event, wsproto.events.TextMessage): return event.data raise WebSocketInvalidTypeReceived(event) - def receive_bytes(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> bytes: - event = self.receive(max_bytes) + def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: + event = self.receive(timeout) if isinstance(event, wsproto.events.BytesMessage): return event.data raise WebSocketInvalidTypeReceived(event) def receive_json( - self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES, mode: JSONMode = "text" + self, timeout: typing.Optional[float] = None, mode: JSONMode = "text" ) -> typing.Any: assert mode in ["text", "binary"] data: typing.Union[str, bytes] if mode == "text": - data = self.receive_text(max_bytes) + data = self.receive_text(timeout) elif mode == "binary": - data = self.receive_bytes(max_bytes) + data = self.receive_bytes(timeout) return json.loads(data) def close(self, code: int = 1000, reason: typing.Optional[str] = None): + self._should_close = True if self.connection.state != wsproto.connection.ConnectionState.CLOSED: event = wsproto.events.CloseConnection(code, reason) try: @@ -116,6 +115,21 @@ def close(self, code: int = 1000, reason: typing.Optional[str] = None): except httpcore.WriteError: pass + def _background_receive(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> None: + try: + while not self._should_close: + data = self.stream.read(max_bytes=max_bytes) + self.connection.receive_data(data) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + self.send(event.response()) + continue + if isinstance(event, wsproto.events.CloseConnection): + self._should_close = True + self._events.put(event) + except httpcore.ReadError: + pass + def _send_event(self, event: wsproto.events.Event): data = self.connection.send(event) self.stream.write(data) @@ -125,7 +139,9 @@ class AsyncWebSocketSession: def __init__(self, stream: AsyncNetworkStream) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) - self._events: typing.Deque[wsproto.events.Event] = collections.deque() + self._events: asyncio.Queue[wsproto.events.Event] = asyncio.Queue() + self._should_close = False + self._background_receive_task = asyncio.create_task(self._background_receive()) async def ping(self, payload: bytes = b"") -> None: event = wsproto.events.Ping(payload) @@ -151,44 +167,38 @@ async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: await self.send_bytes(serialized_data.encode("utf-8")) async def receive( - self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES + self, timeout: typing.Optional[float] = None ) -> wsproto.events.Event: - while len(self._events) == 0: - data = await self.stream.read(max_bytes=max_bytes) - self.connection.receive_data(data) - self._events.extend(self.connection.events()) - event = self._events.popleft() - if isinstance(event, wsproto.events.Ping): - await self.send(event.response()) - return await self.receive(max_bytes) + event = await self._events.get() if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) return event - async def receive_text(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> str: - event = await self.receive(max_bytes) + async def receive_text(self, timeout: typing.Optional[float] = None) -> str: + event = await self.receive(timeout) if isinstance(event, wsproto.events.TextMessage): return event.data raise WebSocketInvalidTypeReceived(event) - async def receive_bytes(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> bytes: - event = await self.receive(max_bytes) + async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: + event = await self.receive(timeout) if isinstance(event, wsproto.events.BytesMessage): return event.data raise WebSocketInvalidTypeReceived(event) async def receive_json( - self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES, mode: JSONMode = "text" + self, timeout: typing.Optional[float] = None, mode: JSONMode = "text" ) -> typing.Any: assert mode in ["text", "binary"] data: typing.Union[str, bytes] if mode == "text": - data = await self.receive_text(max_bytes) + data = await self.receive_text(timeout) elif mode == "binary": - data = await self.receive_bytes(max_bytes) + data = await self.receive_bytes(timeout) return json.loads(data) async def close(self, code: int = 1000, reason: typing.Optional[str] = None): + self._should_close = True if self.connection.state != wsproto.connection.ConnectionState.CLOSED: event = wsproto.events.CloseConnection(code, reason) try: @@ -196,6 +206,23 @@ async def close(self, code: int = 1000, reason: typing.Optional[str] = None): except httpcore.WriteError: pass + async def _background_receive( + self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES + ) -> None: + try: + while not self._should_close: + data = await self.stream.read(max_bytes=max_bytes) + self.connection.receive_data(data) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + await self.send(event.response()) + continue + if isinstance(event, wsproto.events.CloseConnection): + self._should_close = True + await self._events.put(event) + except httpcore.ReadError: + pass + async def _send_event(self, event: wsproto.events.Event): data = self.connection.send(event) await self.stream.write(data) From 8189b6def1ad4eef7b133ae8cae4901046204840 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 25 Nov 2022 16:51:56 +0100 Subject: [PATCH 027/108] Implement ping-pong handling --- src/httpx2/httpx2/_websockets/_api.py | 26 +++++++-- src/httpx2/httpx2/_websockets/_ping.py | 43 +++++++++++++++ tests/httpx2/websockets/test_api.py | 76 ++++++++++++++++++-------- 3 files changed, 118 insertions(+), 27 deletions(-) create mode 100644 src/httpx2/httpx2/_websockets/_ping.py diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index f3406738..287d18b1 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -18,6 +18,8 @@ import wsproto from httpcore.backends.base import AsyncNetworkStream, NetworkStream +from httpx_ws._ping import AsyncPingManager, PingManager + JSONMode = Literal["text", "binary"] DEFAULT_RECEIVE_MAX_BYTES = 65_536 @@ -48,15 +50,20 @@ def __init__(self, stream: NetworkStream) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) self._events: queue.Queue[wsproto.events.Event] = queue.Queue() + + self._ping_manager = PingManager() + self._should_close = False self._background_receive_task = threading.Thread( target=self._background_receive ) self._background_receive_task.start() - def ping(self, payload: bytes = b"") -> None: - event = wsproto.events.Ping(payload) + def ping(self, payload: bytes = b"") -> threading.Event: + ping_id, callback = self._ping_manager.create(payload) + event = wsproto.events.Ping(ping_id) self._send_event(event) + return callback def send(self, event: wsproto.events.Event) -> None: self._send_event(event) @@ -124,6 +131,9 @@ def _background_receive(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> Non if isinstance(event, wsproto.events.Ping): self.send(event.response()) continue + if isinstance(event, wsproto.events.Pong): + self._ping_manager.ack(event.payload) + continue if isinstance(event, wsproto.events.CloseConnection): self._should_close = True self._events.put(event) @@ -140,12 +150,17 @@ def __init__(self, stream: AsyncNetworkStream) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) self._events: asyncio.Queue[wsproto.events.Event] = asyncio.Queue() + + self._ping_manager = AsyncPingManager() + self._should_close = False self._background_receive_task = asyncio.create_task(self._background_receive()) - async def ping(self, payload: bytes = b"") -> None: - event = wsproto.events.Ping(payload) + async def ping(self, payload: bytes = b"") -> asyncio.Event: + ping_id, callback = self._ping_manager.create(payload) + event = wsproto.events.Ping(ping_id) await self._send_event(event) + return callback async def send(self, event: wsproto.events.Event) -> None: await self._send_event(event) @@ -217,6 +232,9 @@ async def _background_receive( if isinstance(event, wsproto.events.Ping): await self.send(event.response()) continue + if isinstance(event, wsproto.events.Pong): + self._ping_manager.ack(event.payload) + continue if isinstance(event, wsproto.events.CloseConnection): self._should_close = True await self._events.put(event) diff --git a/src/httpx2/httpx2/_websockets/_ping.py b/src/httpx2/httpx2/_websockets/_ping.py new file mode 100644 index 00000000..c7ff025e --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_ping.py @@ -0,0 +1,43 @@ +import asyncio +import secrets +import threading +import typing + + +class PingManagerBase: + def _generate_id(self) -> bytes: + return secrets.token_bytes() + + +class PingManager(PingManagerBase): + def __init__(self) -> None: + self._pings: typing.Dict[bytes, threading.Event] = {} + + def create( + self, ping_id: typing.Optional[bytes] = None + ) -> typing.Tuple[bytes, threading.Event]: + ping_id = self._generate_id() if not ping_id else ping_id + event = threading.Event() + self._pings[ping_id] = event + return ping_id, event + + def ack(self, ping_id: typing.Union[bytes, bytearray]): + event = self._pings.pop(bytes(ping_id)) + event.set() + + +class AsyncPingManager(PingManagerBase): + def __init__(self) -> None: + self._pings: typing.Dict[bytes, asyncio.Event] = {} + + def create( + self, ping_id: typing.Optional[bytes] = None + ) -> typing.Tuple[bytes, asyncio.Event]: + ping_id = self._generate_id() if not ping_id else ping_id + event = asyncio.Event() + self._pings[ping_id] = event + return ping_id, event + + def ack(self, ping_id: typing.Union[bytes, bytearray]): + event = self._pings.pop(bytes(ping_id)) + event.set() diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 86f5b154..7b6f14d7 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -2,6 +2,7 @@ import typing from unittest.mock import MagicMock, call +import httpcore import httpx import pytest import wsproto @@ -390,25 +391,24 @@ async def websocket_endpoint(websocket: WebSocket): pass +@pytest.mark.asyncio class TestReceivePing: - def test_receive_ping(self): + async def test_receive_ping(self): class MockNetworkStream(NetworkStream): def __init__(self) -> None: self.connection = wsproto.connection.Connection( wsproto.connection.ConnectionType.SERVER ) - self.ping_sent = False + self.events_to_send = [wsproto.events.Ping(b"SERVER_PING")] def read( self, max_bytes: int, timeout: typing.Optional[float] = None ) -> bytes: - event: wsproto.events.Event - if not self.ping_sent: - event = wsproto.events.Ping(b"SERVER_PING") - self.ping_sent = True - else: - event = wsproto.events.TextMessage("SERVER_MESSAGE") - return self.connection.send(event) + try: + event = self.events_to_send.pop(0) + return self.connection.send(event) + except IndexError: + raise httpcore.ReadError() def write( self, buffer: bytes, timeout: typing.Optional[float] = None @@ -417,30 +417,31 @@ def write( stream = MockNetworkStream() websocket_session = WebSocketSession(stream) - websocket_session.receive() + await asyncio.sleep(0.1) + websocket_session.close() received_events = list(stream.connection.events()) - assert received_events == [wsproto.events.Pong(b"SERVER_PING")] + assert received_events == [ + wsproto.events.Pong(b"SERVER_PING"), + wsproto.events.CloseConnection(1000, ""), + ] - @pytest.mark.asyncio async def test_async_receive_ping(self): class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: self.connection = wsproto.connection.Connection( wsproto.connection.ConnectionType.SERVER ) - self.ping_sent = False + self.events_to_send = [wsproto.events.Ping(b"SERVER_PING")] async def read( self, max_bytes: int, timeout: typing.Optional[float] = None ) -> bytes: - event: wsproto.events.Event - if not self.ping_sent: - event = wsproto.events.Ping(b"SERVER_PING") - self.ping_sent = True - else: - event = wsproto.events.TextMessage("SERVER_MESSAGE") - return self.connection.send(event) + try: + event = self.events_to_send.pop(0) + return self.connection.send(event) + except IndexError: + raise httpcore.ReadError() async def write( self, buffer: bytes, timeout: typing.Optional[float] = None @@ -449,10 +450,39 @@ async def write( stream = MockAsyncNetworkStream() websocket_session = AsyncWebSocketSession(stream) - await websocket_session.receive() + await asyncio.sleep(0.1) + await websocket_session.close() received_events = list(stream.connection.events()) - assert received_events == [wsproto.events.Pong(b"SERVER_PING")] + assert received_events == [ + wsproto.events.Pong(b"SERVER_PING"), + wsproto.events.CloseConnection(1000, ""), + ] + + +@pytest.mark.asyncio +async def test_ping_pong(server_factory: ServerFactoryFixture): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + try: + await websocket.receive_text() + except StarletteWebSocketDisconnect: + pass + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with connect_ws("http://socket/ws", client) as ws: + ping_callback = ws.ping() + result = ping_callback.wait() + assert result is True + + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + async with aconnect_ws("http://socket/ws", aclient) as aws: + aping_callback = await aws.ping() + aresult = await aping_callback.wait() + assert aresult is True @pytest.mark.asyncio @@ -462,7 +492,7 @@ async def websocket_endpoint(websocket: WebSocket): try: await websocket.receive_text() except StarletteWebSocketDisconnect: - await websocket.close() + pass with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: From b49462fa30d12444fc1dba4e0abbbf2e7e6387f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 25 Nov 2022 17:12:24 +0100 Subject: [PATCH 028/108] Add timeout handling for async receive --- src/httpx2/httpx2/_websockets/_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 287d18b1..53bc0dde 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -184,7 +184,7 @@ async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: async def receive( self, timeout: typing.Optional[float] = None ) -> wsproto.events.Event: - event = await self._events.get() + event = await asyncio.wait_for(self._events.get(), timeout) if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) return event From 765bd2cfc517ddd682a9a79f5d358757bee01549 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 25 Nov 2022 17:32:07 +0100 Subject: [PATCH 029/108] Add parameters to control message and queue size --- src/httpx2/httpx2/_websockets/_api.py | 61 ++++++++++++++++++++------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 53bc0dde..2d01f0d0 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -22,7 +22,8 @@ JSONMode = Literal["text", "binary"] -DEFAULT_RECEIVE_MAX_BYTES = 65_536 +DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536 +DEFAULT_QUEUE_SIZE = 512 class HTTPXWSException(Exception): @@ -46,16 +47,22 @@ def __init__(self, event: wsproto.events.Event) -> None: class WebSocketSession: - def __init__(self, stream: NetworkStream) -> None: + def __init__( + self, + stream: NetworkStream, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + ) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) - self._events: queue.Queue[wsproto.events.Event] = queue.Queue() + self._events: queue.Queue[wsproto.events.Event] = queue.Queue(queue_size) self._ping_manager = PingManager() self._should_close = False self._background_receive_task = threading.Thread( - target=self._background_receive + target=self._background_receive, args=(max_message_size_bytes,) ) self._background_receive_task.start() @@ -122,7 +129,7 @@ def close(self, code: int = 1000, reason: typing.Optional[str] = None): except httpcore.WriteError: pass - def _background_receive(self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES) -> None: + def _background_receive(self, max_bytes: int) -> None: try: while not self._should_close: data = self.stream.read(max_bytes=max_bytes) @@ -146,15 +153,23 @@ def _send_event(self, event: wsproto.events.Event): class AsyncWebSocketSession: - def __init__(self, stream: AsyncNetworkStream) -> None: + def __init__( + self, + stream: AsyncNetworkStream, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + ) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) - self._events: asyncio.Queue[wsproto.events.Event] = asyncio.Queue() + self._events: asyncio.Queue[wsproto.events.Event] = asyncio.Queue(queue_size) self._ping_manager = AsyncPingManager() self._should_close = False - self._background_receive_task = asyncio.create_task(self._background_receive()) + self._background_receive_task = asyncio.create_task( + self._background_receive(max_message_size_bytes) + ) async def ping(self, payload: bytes = b"") -> asyncio.Event: ping_id, callback = self._ping_manager.create(payload) @@ -221,9 +236,7 @@ async def close(self, code: int = 1000, reason: typing.Optional[str] = None): except httpcore.WriteError: pass - async def _background_receive( - self, max_bytes: int = DEFAULT_RECEIVE_MAX_BYTES - ) -> None: + async def _background_receive(self, max_bytes: int) -> None: try: while not self._should_close: data = await self.stream.read(max_bytes=max_bytes) @@ -257,7 +270,12 @@ def _get_headers() -> typing.Dict[str, typing.Any]: @contextlib.contextmanager def connect_ws( - url: str, client: typing.Optional[httpx.Client] = None, **kwargs: typing.Any + url: str, + client: typing.Optional[httpx.Client] = None, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + **kwargs: typing.Any, ) -> typing.Generator[WebSocketSession, None, None]: client = httpx.Client() if client is None else client headers = kwargs.pop("headers", {}) @@ -267,14 +285,23 @@ def connect_ws( if response.status_code != 101: raise WebSocketUpgradeError(response) - session = WebSocketSession(response.extensions["network_stream"]) + session = WebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + ) yield session session.close() @contextlib.asynccontextmanager async def aconnect_ws( - url: str, client: typing.Optional[httpx.AsyncClient] = None, **kwargs: typing.Any + url: str, + client: typing.Optional[httpx.AsyncClient] = None, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + **kwargs: typing.Any, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: client = httpx.AsyncClient() if client is None else client headers = kwargs.pop("headers", {}) @@ -284,6 +311,10 @@ async def aconnect_ws( if response.status_code != 101: raise WebSocketUpgradeError(response) - session = AsyncWebSocketSession(response.extensions["network_stream"]) + session = AsyncWebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + ) yield session await session.close() From 5c00b25e5349007ce820de65528cc40edb4d61fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 25 Nov 2022 18:46:17 +0100 Subject: [PATCH 030/108] Add a sleep in transport receive to let it yield to the event loop --- src/httpx2/httpx2/_websockets/transport.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 69d582cc..a4b5ebd7 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -101,6 +101,8 @@ async def send(self, message: Message) -> None: self._receive_queue.put(message) async def receive(self, timeout: typing.Optional[float] = None) -> Message: + while self._send_queue.empty(): + await anyio.sleep(0) return self._send_queue.get(timeout=timeout) async def _run(self) -> None: From 1c632948a2735265fa3aa2bb5a12f9c063c5df4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sun, 27 Nov 2022 11:21:21 +0100 Subject: [PATCH 031/108] Add docstrings for session classes --- src/httpx2/httpx2/_websockets/_api.py | 542 +++++++++++++++++++++++++- tests/httpx2/websockets/test_api.py | 31 -- 2 files changed, 528 insertions(+), 45 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 2d01f0d0..94570631 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -27,26 +27,52 @@ class HTTPXWSException(Exception): + """ + Base exception class for HTTPX WS. + """ + pass class WebSocketUpgradeError(HTTPXWSException): + """ + Raised when the initial connection didn't correctly upgrade to a WebSocket session. + """ + def __init__(self, response: httpx.Response) -> None: self.response = response class WebSocketDisconnect(HTTPXWSException): + """ + Raised when the server closed the WebSocket session. + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + """ + def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: self.code = code self.reason = reason or "" class WebSocketInvalidTypeReceived(HTTPXWSException): + """ + Raised when a event is not of the expected type. + """ + def __init__(self, event: wsproto.events.Event) -> None: self.event = event class WebSocketSession: + """ + Sync helper representing an opened WebSocket session. + """ + def __init__( self, stream: NetworkStream, @@ -67,23 +93,98 @@ def __init__( self._background_receive_task.start() def ping(self, payload: bytes = b"") -> threading.Event: + """ + Send a Ping message. + + Args: + payload: + Payload to attach to the Ping event. + Internally, it's used to track this specific event. + If left empty, a random one will be generated. + + Returns: + An event that can be used to wait for the corresponding Pong response. + + Examples: + Send a Ping and wait for the Pong + + pong_callback = ws.ping() + # Will block until the corresponding Pong is received. + pong_callback.wait() + """ ping_id, callback = self._ping_manager.create(payload) event = wsproto.events.Ping(ping_id) - self._send_event(event) + self.send(event) return callback def send(self, event: wsproto.events.Event) -> None: - self._send_event(event) + """ + Send an Event message. + + Mainly useful to send events that are not supported by the library. + Most of the time, [ping()][httpx_ws.WebSocketSession.ping], + [send_text()][httpx_ws.WebSocketSession.send_text], + [send_bytes()][httpx_ws.WebSocketSession.send_bytes] + and [send_json()][httpx_ws.WebSocketSession.send_json] are preferred. + + Args: + event: The event to send. + + Examples: + Send an event. + + event = wsproto.events.Message(b"Hello!") + ws.send(event) + """ + data = self.connection.send(event) + self.stream.write(data) def send_text(self, data: str) -> None: + """ + Send a text message. + + Args: + data: The text to send. + + Examples: + Send a text message. + + ws.send_text("Hello!") + """ event = wsproto.events.TextMessage(data=data) self.send(event) def send_bytes(self, data: bytes) -> None: + """ + Send a bytes message. + + Args: + data: The data to send. + + Examples: + Send a bytes message. + + ws.send_bytes(b"Hello!") + """ event = wsproto.events.BytesMessage(data=data) self.send(event) def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + """ + Send JSON data. + + Args: + data: + The data to send. Must be serializable by [json.dumps][json.dumps]. + mode: + The sending mode. Should either be `'text'` or `'bytes'`. + + Examples: + Send JSON data. + + data = {"message": "Hello!"} + ws.send_json(data) + """ assert mode in ["text", "binary"] serialized_data = json.dumps(data) if mode == "text": @@ -92,18 +193,121 @@ def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: self.send_bytes(serialized_data.encode("utf-8")) def receive(self, timeout: typing.Optional[float] = None) -> wsproto.events.Event: + """ + Receive an event from the server. + + Mainly useful to receive raw [wsproto.events.Event][wsproto.events.Event]. + Most of the time, [receive_text()][httpx_ws.WebSocketSession.receive_text], + [receive_bytes()][httpx_ws.WebSocketSession.receive_bytes], + and [receive_json()][httpx_ws.WebSocketSession.receive_json] are preferred. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + A raw [wsproto.events.Event][wsproto.events.Event]. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + + Examples: + Wait for an event until one is available. + + try: + event = ws.receive() + except WebSocketDisconnect: + print("Connection closed") + + Wait for an event for 2 seconds. + + try: + event = ws.receive(timeout=2.) + except queue.Empty: + print("No event received.") + except WebSocketDisconnect: + print("Connection closed") + """ event = self._events.get(block=True, timeout=timeout) if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) return event def receive_text(self, timeout: typing.Optional[float] = None) -> str: + """ + Receive text from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Text data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketInvalidTypeReceived: The received event was not a text message. + + Examples: + Wait for text until available. + + try: + text = ws.receive_text() + except WebSocketDisconnect: + print("Connection closed") + + Wait for text for 2 seconds. + + try: + event = ws.receive_text(timeout=2.) + except queue.Empty: + print("No text received.") + except WebSocketDisconnect: + print("Connection closed") + """ event = self.receive(timeout) if isinstance(event, wsproto.events.TextMessage): return event.data raise WebSocketInvalidTypeReceived(event) def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: + """ + Receive bytes from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Bytes data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketInvalidTypeReceived: The received event was not a bytes message. + + Examples: + Wait for bytes until available. + + try: + data = ws.receive_bytes() + except WebSocketDisconnect: + print("Connection closed") + + Wait for bytes for 2 seconds. + + try: + data = ws.receive_bytes(timeout=2.) + except queue.Empty: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ event = self.receive(timeout) if isinstance(event, wsproto.events.BytesMessage): return event.data @@ -112,6 +316,44 @@ def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: def receive_json( self, timeout: typing.Optional[float] = None, mode: JSONMode = "text" ) -> typing.Any: + """ + Receive JSON data from the server. + + The received data should be parseable by [json.loads][json.loads]. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + mode: + Receive mode. Should either be `'text'` or `'bytes'`. + + Returns: + Parsed JSON data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketInvalidTypeReceived: + The received event didn't correspond to the specified mode. + + Examples: + Wait for data until available. + + try: + data = ws.receive_json() + except WebSocketDisconnect: + print("Connection closed") + + Wait for data for 2 seconds. + + try: + data = ws.receive_json(timeout=2.) + except queue.Empty: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ assert mode in ["text", "binary"] data: typing.Union[str, bytes] if mode == "text": @@ -121,15 +363,45 @@ def receive_json( return json.loads(data) def close(self, code: int = 1000, reason: typing.Optional[str] = None): + """ + Close the WebSocket session. + + Internally, it'll send the + [CloseConnection][wsproto.events.CloseConnection] event. + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + + Examples: + Close the WebSocket session. + + ws.close() + """ self._should_close = True if self.connection.state != wsproto.connection.ConnectionState.CLOSED: event = wsproto.events.CloseConnection(code, reason) try: - self._send_event(event) + self.send(event) except httpcore.WriteError: pass def _background_receive(self, max_bytes: int) -> None: + """ + Background thread listening for data from the server. + + Internally, it'll: + + * Answer to Ping events. + * Acknowledge Pong events. + * Put other events in the [_events][_events] + queue that'll eventually be consumed by the user. + + Args: + max_bytes: The maximum chunk size to read at each iteration. + """ try: while not self._should_close: data = self.stream.read(max_bytes=max_bytes) @@ -147,12 +419,12 @@ def _background_receive(self, max_bytes: int) -> None: except httpcore.ReadError: pass - def _send_event(self, event: wsproto.events.Event): - data = self.connection.send(event) - self.stream.write(data) - class AsyncWebSocketSession: + """ + Async helper representing an opened WebSocket session. + """ + def __init__( self, stream: AsyncNetworkStream, @@ -172,23 +444,98 @@ def __init__( ) async def ping(self, payload: bytes = b"") -> asyncio.Event: + """ + Send a Ping message. + + Args: + payload: + Payload to attach to the Ping event. + Internally, it's used to track this specific event. + If left empty, a random one will be generated. + + Returns: + An event that can be used to wait for the corresponding Pong response. + + Examples: + Send a Ping and wait for the Pong + + pong_callback = await ws.ping() + # Will block until the corresponding Pong is received. + await pong_callback.wait() + """ ping_id, callback = self._ping_manager.create(payload) event = wsproto.events.Ping(ping_id) - await self._send_event(event) + await self.send(event) return callback async def send(self, event: wsproto.events.Event) -> None: - await self._send_event(event) + """ + Send an Event message. + + Mainly useful to send events that are not supported by the library. + Most of the time, [ping()][httpx_ws.AsyncWebSocketSession.ping], + [send_text()][httpx_ws.AsyncWebSocketSession.send_text], + [send_bytes()][httpx_ws.AsyncWebSocketSession.send_bytes] + and [send_json()][httpx_ws.AsyncWebSocketSession.send_json] are preferred. + + Args: + event: The event to send. + + Examples: + Send an event. + + event = await wsproto.events.Message(b"Hello!") + ws.send(event) + """ + data = self.connection.send(event) + await self.stream.write(data) async def send_text(self, data: str) -> None: + """ + Send a text message. + + Args: + data: The text to send. + + Examples: + Send a text message. + + await ws.send_text("Hello!") + """ event = wsproto.events.TextMessage(data=data) await self.send(event) async def send_bytes(self, data: bytes) -> None: + """ + Send a bytes message. + + Args: + data: The data to send. + + Examples: + Send a bytes message. + + await ws.send_bytes(b"Hello!") + """ event = wsproto.events.BytesMessage(data=data) await self.send(event) async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + """ + Send JSON data. + + Args: + data: + The data to send. Must be serializable by [json.dumps][json.dumps]. + mode: + The sending mode. Should either be `'text'` or `'bytes'`. + + Examples: + Send JSON data. + + data = {"message": "Hello!"} + await ws.send_json(data) + """ assert mode in ["text", "binary"] serialized_data = json.dumps(data) if mode == "text": @@ -199,18 +546,121 @@ async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: async def receive( self, timeout: typing.Optional[float] = None ) -> wsproto.events.Event: + """ + Receive an event from the server. + + Mainly useful to receive raw [wsproto.events.Event][wsproto.events.Event]. + Most of the time, [receive_text()][httpx_ws.AsyncWebSocketSession.receive_text], + [receive_bytes()][httpx_ws.AsyncWebSocketSession.receive_bytes], + and [receive_json()][httpx_ws.AsyncWebSocketSession.receive_json] are preferred. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + A raw [wsproto.events.Event][wsproto.events.Event]. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + + Examples: + Wait for an event until one is available. + + try: + event = await ws.receive() + except WebSocketDisconnect: + print("Connection closed") + + Wait for an event for 2 seconds. + + try: + event = await ws.receive(timeout=2.) + except queue.Empty: + print("No event received.") + except WebSocketDisconnect: + print("Connection closed") + """ event = await asyncio.wait_for(self._events.get(), timeout) if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) return event async def receive_text(self, timeout: typing.Optional[float] = None) -> str: + """ + Receive text from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Text data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketInvalidTypeReceived: The received event was not a text message. + + Examples: + Wait for text until available. + + try: + text = await ws.receive_text() + except WebSocketDisconnect: + print("Connection closed") + + Wait for text for 2 seconds. + + try: + event = await ws.receive_text(timeout=2.) + except queue.Empty: + print("No text received.") + except WebSocketDisconnect: + print("Connection closed") + """ event = await self.receive(timeout) if isinstance(event, wsproto.events.TextMessage): return event.data raise WebSocketInvalidTypeReceived(event) async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: + """ + Receive bytes from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Bytes data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketInvalidTypeReceived: The received event was not a bytes message. + + Examples: + Wait for bytes until available. + + try: + data = await ws.receive_bytes() + except WebSocketDisconnect: + print("Connection closed") + + Wait for bytes for 2 seconds. + + try: + data = await ws.receive_bytes(timeout=2.) + except queue.Empty: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ event = await self.receive(timeout) if isinstance(event, wsproto.events.BytesMessage): return event.data @@ -219,6 +669,44 @@ async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: async def receive_json( self, timeout: typing.Optional[float] = None, mode: JSONMode = "text" ) -> typing.Any: + """ + Receive JSON data from the server. + + The received data should be parseable by [json.loads][json.loads]. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + mode: + Receive mode. Should either be `'text'` or `'bytes'`. + + Returns: + Parsed JSON data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketInvalidTypeReceived: + The received event didn't correspond to the specified mode. + + Examples: + Wait for data until available. + + try: + data = await ws.receive_json() + except WebSocketDisconnect: + print("Connection closed") + + Wait for data for 2 seconds. + + try: + data = await ws.receive_json(timeout=2.) + except queue.Empty: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ assert mode in ["text", "binary"] data: typing.Union[str, bytes] if mode == "text": @@ -228,15 +716,45 @@ async def receive_json( return json.loads(data) async def close(self, code: int = 1000, reason: typing.Optional[str] = None): + """ + Close the WebSocket session. + + Internally, it'll send the + [CloseConnection][wsproto.events.CloseConnection] event. + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + + Examples: + Close the WebSocket session. + + await ws.close() + """ self._should_close = True if self.connection.state != wsproto.connection.ConnectionState.CLOSED: event = wsproto.events.CloseConnection(code, reason) try: - await self._send_event(event) + await self.send(event) except httpcore.WriteError: pass async def _background_receive(self, max_bytes: int) -> None: + """ + Background task listening for data from the server. + + Internally, it'll: + + * Answer to Ping events. + * Acknowledge Pong events. + * Put other events in the [_events][_events] + queue that'll eventually be consumed by the user. + + Args: + max_bytes: The maximum chunk size to read at each iteration. + """ try: while not self._should_close: data = await self.stream.read(max_bytes=max_bytes) @@ -254,10 +772,6 @@ async def _background_receive(self, max_bytes: int) -> None: except httpcore.ReadError: pass - async def _send_event(self, event: wsproto.events.Event): - data = self.connection.send(event) - await self.stream.write(data) - def _get_headers() -> typing.Dict[str, typing.Any]: return { diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 7b6f14d7..b709a345 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -45,37 +45,6 @@ def handler(request): @pytest.mark.asyncio class TestSend: - async def test_ping(self, server_factory: ServerFactoryFixture): - async def websocket_endpoint(websocket: WebSocket): - await websocket.accept() - - await asyncio.sleep(0.1) - - await websocket.close() - - with server_factory(websocket_endpoint) as socket: - with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - try: - with connect_ws("http://socket/ws", client) as ws: - ws.ping() - event = ws.receive() - assert isinstance(event, wsproto.events.Pong) - assert event.payload == b"" - except WebSocketDisconnect: - pass - - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: - try: - async with aconnect_ws("http://socket/ws", aclient) as aws: - await aws.ping() - event = await aws.receive() - assert isinstance(event, wsproto.events.Pong) - assert event.payload == b"" - except WebSocketDisconnect: - pass - async def test_send( self, server_factory: ServerFactoryFixture, From d72fef541126a331c551fb409b1c6d6c7d46b33e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sun, 27 Nov 2022 11:53:19 +0100 Subject: [PATCH 032/108] Complete docs --- src/httpx2/httpx2/_websockets/_api.py | 96 +++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 4 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 94570631..5750a5f7 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -334,8 +334,8 @@ def receive_json( Raises: queue.Empty: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. - WebSocketInvalidTypeReceived: - The received event didn't correspond to the specified mode. + WebSocketInvalidTypeReceived: The received event + didn't correspond to the specified mode. Examples: Wait for data until available. @@ -687,8 +687,8 @@ async def receive_json( Raises: queue.Empty: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. - WebSocketInvalidTypeReceived: - The received event didn't correspond to the specified mode. + WebSocketInvalidTypeReceived: The received event + didn't correspond to the specified mode. Examples: Wait for data until available. @@ -791,6 +791,50 @@ def connect_ws( queue_size: int = DEFAULT_QUEUE_SIZE, **kwargs: typing.Any, ) -> typing.Generator[WebSocketSession, None, None]: + """ + Start a sync WebSocket session. + + It returns a context manager that'll automatically + call [close()][httpx_ws.WebSocketSession.close] when exiting. + + Args: + url: The WebSocket URL. + client: + HTTPX client to use. + If not provided, a default one will be initialized. + max_message_size_bytes: + Message size in bytes to receive from the server. + Defaults to 65 KiB. + queue_size: + Size of the queue where the received messages will be held + until they are consumed. + If the queue is full, the client will stop receive messages + from the server until the queue has room available. + Defaults to 512. + **kwargs: + Additional keyword arguments that will be passed to + the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. + + Returns: + A [context manager][contextlib.AbstractContextManager] + for [WebSocketSession][httpx_ws.WebSocketSession]. + + Examples: + Without explicit HTTPX client. + + with connect_ws("http://localhost:8000/ws") as ws: + message = ws.receive_text() + print(message) + ws.send_text("Hello!") + + With explicit HTTPX client. + + with httpx.Client() as client: + with connect_ws("http://localhost:8000/ws", client) as ws: + message = ws.receive_text() + print(message) + ws.send_text("Hello!") + """ client = httpx.Client() if client is None else client headers = kwargs.pop("headers", {}) headers.update(_get_headers()) @@ -817,6 +861,50 @@ async def aconnect_ws( queue_size: int = DEFAULT_QUEUE_SIZE, **kwargs: typing.Any, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: + """ + Start an async WebSocket session. + + It returns an async context manager that'll automatically + call [close()][httpx_ws.AsyncWebSocketSession.close] when exiting. + + Args: + url: The WebSocket URL. + client: + HTTPX client to use. + If not provided, a default one will be initialized. + max_message_size_bytes: + Message size in bytes to receive from the server. + Defaults to 65 KiB. + queue_size: + Size of the queue where the received messages will be held + until they are consumed. + If the queue is full, the client will stop receive messages + from the server until the queue has room available. + Defaults to 512. + **kwargs: + Additional keyword arguments that will be passed to + the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. + + Returns: + An [async context manager][contextlib.AbstractAsyncContextManager] + for [AsyncWebSocketSession][httpx_ws.AsyncWebSocketSession]. + + Examples: + Without explicit HTTPX client. + + async with aconnect_ws("http://localhost:8000/ws") as ws: + message = await ws.receive_text() + print(message) + await ws.send_text("Hello!") + + With explicit HTTPX client. + + async with httpx.AsyncClient() as client: + async with aconnect_ws("http://localhost:8000/ws", client) as ws: + message = await ws.receive_text() + print(message) + await ws.send_text("Hello!") + """ client = httpx.AsyncClient() if client is None else client headers = kwargs.pop("headers", {}) headers.update(_get_headers()) From a02695dd95a102078608474f510f1e7d77c3f9de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sun, 27 Nov 2022 12:02:28 +0100 Subject: [PATCH 033/108] =?UTF-8?q?Bump=20version=200.1.1=20=E2=86=92=200.?= =?UTF-8?q?2.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improvements ------------ * Documentation is live ➡️ https://frankie567.github.io/httpx-ws/ * Revamp implementation with a background thread/task receiving messages * Automatic management of ping/pong events --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 9a75dc92..2441d433 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.1" +__version__ = "0.2.0" from httpx_ws._api import ( AsyncWebSocketSession, From 08e58487f73281d3296bd65026396f4d1b4c43be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sun, 27 Nov 2022 15:27:18 +0100 Subject: [PATCH 034/108] Ensure to close stream when calling close --- src/httpx2/httpx2/_websockets/_api.py | 2 ++ tests/httpx2/websockets/test_api.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 5750a5f7..df735af0 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -387,6 +387,7 @@ def close(self, code: int = 1000, reason: typing.Optional[str] = None): self.send(event) except httpcore.WriteError: pass + self.stream.close() def _background_receive(self, max_bytes: int) -> None: """ @@ -740,6 +741,7 @@ async def close(self, code: int = 1000, reason: typing.Optional[str] = None): await self.send(event) except httpcore.WriteError: pass + await self.stream.aclose() async def _background_receive(self, max_bytes: int) -> None: """ diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index b709a345..b92a1062 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -384,6 +384,9 @@ def write( ) -> None: self.connection.receive_data(buffer) + def close(self) -> None: + pass + stream = MockNetworkStream() websocket_session = WebSocketSession(stream) await asyncio.sleep(0.1) @@ -417,6 +420,9 @@ async def write( ) -> None: self.connection.receive_data(buffer) + async def aclose(self) -> None: + pass + stream = MockAsyncNetworkStream() websocket_session = AsyncWebSocketSession(stream) await asyncio.sleep(0.1) From 0677904761369cb79a0c878ec3e3c0542bbe3e98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sun, 27 Nov 2022 15:57:12 +0100 Subject: [PATCH 035/108] Make sure some unit tests correctly close the websocket session --- tests/httpx2/websockets/test_api.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index b92a1062..a671823d 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -260,8 +260,8 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: - with pytest.raises(WebSocketInvalidTypeReceived): - with connect_ws("http://socket/ws", client) as ws: + with connect_ws("http://socket/ws", client) as ws: + with pytest.raises(WebSocketInvalidTypeReceived): ws.receive_text() except WebSocketDisconnect: pass @@ -270,8 +270,8 @@ async def websocket_endpoint(websocket: WebSocket): transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: try: - with pytest.raises(WebSocketInvalidTypeReceived): - async with aconnect_ws("http://socket/ws", aclient) as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: + with pytest.raises(WebSocketInvalidTypeReceived): await aws.receive_text() except WebSocketDisconnect: pass @@ -317,15 +317,15 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with pytest.raises(WebSocketInvalidTypeReceived): - with connect_ws("http://socket/ws", client) as ws: + with connect_ws("http://socket/ws", client) as ws: + with pytest.raises(WebSocketInvalidTypeReceived): ws.receive_bytes() async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - with pytest.raises(WebSocketInvalidTypeReceived): - async with aconnect_ws("http://socket/ws", aclient) as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: + with pytest.raises(WebSocketInvalidTypeReceived): await aws.receive_bytes() @pytest.mark.parametrize("mode", ["text", "binary"]) @@ -490,13 +490,13 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with pytest.raises(WebSocketDisconnect): - with connect_ws("http://socket/ws", client) as ws: + with connect_ws("http://socket/ws", client) as ws: + with pytest.raises(WebSocketDisconnect): ws.receive() async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - with pytest.raises(WebSocketDisconnect): - async with aconnect_ws("http://socket/ws", aclient) as aws: + async with aconnect_ws("http://socket/ws", aclient) as aws: + with pytest.raises(WebSocketDisconnect): await aws.receive() From 9d7653d53924008b691853c731368b146a0a1434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sun, 27 Nov 2022 18:13:04 +0100 Subject: [PATCH 036/108] Greatly improve network errors handling --- src/httpx2/httpx2/_websockets/__init__.py | 4 +- src/httpx2/httpx2/_websockets/_api.py | 120 +++++++++++++++++----- tests/httpx2/websockets/test_api.py | 10 +- 3 files changed, 107 insertions(+), 27 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 2441d433..9b15e713 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -6,6 +6,7 @@ JSONMode, WebSocketDisconnect, WebSocketInvalidTypeReceived, + WebSocketNetworkError, WebSocketSession, WebSocketUpgradeError, aconnect_ws, @@ -13,11 +14,12 @@ ) __all__ = [ + "AsyncWebSocketSession", "HTTPXWSException", "JSONMode", "WebSocketDisconnect", "WebSocketInvalidTypeReceived", - "AsyncWebSocketSession", + "WebSocketNetworkError", "WebSocketSession", "WebSocketUpgradeError", "aconnect_ws", diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index df735af0..d67fcc9e 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -17,6 +17,7 @@ import httpx import wsproto from httpcore.backends.base import AsyncNetworkStream, NetworkStream +from wsproto.connection import CloseReason from httpx_ws._ping import AsyncPingManager, PingManager @@ -68,6 +69,15 @@ def __init__(self, event: wsproto.events.Event) -> None: self.event = event +class WebSocketNetworkError(HTTPXWSException): + """ + Raised when a network error occured, + typically if the underlying stream has closed or timeout. + """ + + pass + + class WebSocketSession: """ Sync helper representing an opened WebSocket session. @@ -82,11 +92,14 @@ def __init__( ) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) - self._events: queue.Queue[wsproto.events.Event] = queue.Queue(queue_size) + self._events: queue.Queue[ + typing.Union[wsproto.events.Event, HTTPXWSException] + ] = queue.Queue(queue_size) self._ping_manager = PingManager() - self._should_close = False + self._should_close = threading.Event() + self._background_receive_task = threading.Thread( target=self._background_receive, args=(max_message_size_bytes,) ) @@ -130,14 +143,21 @@ def send(self, event: wsproto.events.Event) -> None: Args: event: The event to send. + Raises: + WebSocketNetworkError: A network error occured. + Examples: Send an event. event = wsproto.events.Message(b"Hello!") ws.send(event) """ - data = self.connection.send(event) - self.stream.write(data) + try: + data = self.connection.send(event) + self.stream.write(data) + except httpcore.WriteError as e: + self.close(CloseReason.INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e def send_text(self, data: str) -> None: """ @@ -146,6 +166,9 @@ def send_text(self, data: str) -> None: Args: data: The text to send. + Raises: + WebSocketNetworkError: A network error occured. + Examples: Send a text message. @@ -161,6 +184,9 @@ def send_bytes(self, data: bytes) -> None: Args: data: The data to send. + Raises: + WebSocketNetworkError: A network error occured. + Examples: Send a bytes message. @@ -179,6 +205,9 @@ def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: mode: The sending mode. Should either be `'text'` or `'bytes'`. + Raises: + WebSocketNetworkError: A network error occured. + Examples: Send JSON data. @@ -212,6 +241,7 @@ def receive(self, timeout: typing.Optional[float] = None) -> wsproto.events.Even Raises: queue.Empty: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. Examples: Wait for an event until one is available. @@ -231,6 +261,8 @@ def receive(self, timeout: typing.Optional[float] = None) -> wsproto.events.Even print("Connection closed") """ event = self._events.get(block=True, timeout=timeout) + if isinstance(event, HTTPXWSException): + raise event if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) return event @@ -250,6 +282,7 @@ def receive_text(self, timeout: typing.Optional[float] = None) -> str: Raises: queue.Empty: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event was not a text message. Examples: @@ -289,6 +322,7 @@ def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: Raises: queue.Empty: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event was not a bytes message. Examples: @@ -334,6 +368,7 @@ def receive_json( Raises: queue.Empty: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event didn't correspond to the specified mode. @@ -380,11 +415,15 @@ def close(self, code: int = 1000, reason: typing.Optional[str] = None): ws.close() """ - self._should_close = True - if self.connection.state != wsproto.connection.ConnectionState.CLOSED: + self._should_close.set() + if self.connection.state not in { + wsproto.connection.ConnectionState.LOCAL_CLOSING, + wsproto.connection.ConnectionState.CLOSED, + }: event = wsproto.events.CloseConnection(code, reason) + data = self.connection.send(event) try: - self.send(event) + self.stream.write(data) except httpcore.WriteError: pass self.stream.close() @@ -404,21 +443,23 @@ def _background_receive(self, max_bytes: int) -> None: max_bytes: The maximum chunk size to read at each iteration. """ try: - while not self._should_close: + while not self._should_close.is_set(): data = self.stream.read(max_bytes=max_bytes) self.connection.receive_data(data) for event in self.connection.events(): if isinstance(event, wsproto.events.Ping): - self.send(event.response()) + data = self.connection.send(event.response()) + self.stream.write(data) continue if isinstance(event, wsproto.events.Pong): self._ping_manager.ack(event.payload) continue if isinstance(event, wsproto.events.CloseConnection): - self._should_close = True + self._should_close.set() self._events.put(event) - except httpcore.ReadError: - pass + except (httpcore.ReadError, httpcore.WriteError): + self.close(CloseReason.INTERNAL_ERROR, "Stream error") + self._events.put(WebSocketNetworkError()) class AsyncWebSocketSession: @@ -435,11 +476,14 @@ def __init__( ) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) - self._events: asyncio.Queue[wsproto.events.Event] = asyncio.Queue(queue_size) + self._events: asyncio.Queue[ + typing.Union[wsproto.events.Event, HTTPXWSException] + ] = asyncio.Queue(queue_size) self._ping_manager = AsyncPingManager() - self._should_close = False + self._should_close = asyncio.Event() + self._background_receive_task = asyncio.create_task( self._background_receive(max_message_size_bytes) ) @@ -482,14 +526,21 @@ async def send(self, event: wsproto.events.Event) -> None: Args: event: The event to send. + Raises: + WebSocketNetworkError: A network error occured. + Examples: Send an event. event = await wsproto.events.Message(b"Hello!") ws.send(event) """ - data = self.connection.send(event) - await self.stream.write(data) + try: + data = self.connection.send(event) + await self.stream.write(data) + except httpcore.WriteError as e: + await self.close(CloseReason.INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e async def send_text(self, data: str) -> None: """ @@ -498,6 +549,9 @@ async def send_text(self, data: str) -> None: Args: data: The text to send. + Raises: + WebSocketNetworkError: A network error occured. + Examples: Send a text message. @@ -513,6 +567,9 @@ async def send_bytes(self, data: bytes) -> None: Args: data: The data to send. + Raises: + WebSocketNetworkError: A network error occured. + Examples: Send a bytes message. @@ -531,6 +588,9 @@ async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: mode: The sending mode. Should either be `'text'` or `'bytes'`. + Raises: + WebSocketNetworkError: A network error occured. + Examples: Send JSON data. @@ -566,6 +626,7 @@ async def receive( Raises: queue.Empty: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. Examples: Wait for an event until one is available. @@ -585,6 +646,8 @@ async def receive( print("Connection closed") """ event = await asyncio.wait_for(self._events.get(), timeout) + if isinstance(event, HTTPXWSException): + raise event if isinstance(event, wsproto.events.CloseConnection): raise WebSocketDisconnect(event.code, event.reason) return event @@ -604,6 +667,7 @@ async def receive_text(self, timeout: typing.Optional[float] = None) -> str: Raises: queue.Empty: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event was not a text message. Examples: @@ -643,6 +707,7 @@ async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: Raises: queue.Empty: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event was not a bytes message. Examples: @@ -688,6 +753,7 @@ async def receive_json( Raises: queue.Empty: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event didn't correspond to the specified mode. @@ -734,11 +800,15 @@ async def close(self, code: int = 1000, reason: typing.Optional[str] = None): await ws.close() """ - self._should_close = True - if self.connection.state != wsproto.connection.ConnectionState.CLOSED: + self._should_close.set() + if self.connection.state not in { + wsproto.connection.ConnectionState.LOCAL_CLOSING, + wsproto.connection.ConnectionState.CLOSED, + }: event = wsproto.events.CloseConnection(code, reason) + data = self.connection.send(event) try: - await self.send(event) + await self.stream.write(data) except httpcore.WriteError: pass await self.stream.aclose() @@ -758,21 +828,23 @@ async def _background_receive(self, max_bytes: int) -> None: max_bytes: The maximum chunk size to read at each iteration. """ try: - while not self._should_close: + while not self._should_close.is_set(): data = await self.stream.read(max_bytes=max_bytes) self.connection.receive_data(data) for event in self.connection.events(): if isinstance(event, wsproto.events.Ping): - await self.send(event.response()) + data = self.connection.send(event.response()) + await self.stream.write(data) continue if isinstance(event, wsproto.events.Pong): self._ping_manager.ack(event.payload) continue if isinstance(event, wsproto.events.CloseConnection): - self._should_close = True + self._should_close.set() await self._events.put(event) - except httpcore.ReadError: - pass + except (httpcore.ReadError, httpcore.WriteError): + await self.close(CloseReason.INTERNAL_ERROR, "Stream error") + await self._events.put(WebSocketNetworkError()) def _get_headers() -> typing.Dict[str, typing.Any]: diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index a671823d..af4e4e55 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -368,7 +368,10 @@ def __init__(self) -> None: self.connection = wsproto.connection.Connection( wsproto.connection.ConnectionType.SERVER ) - self.events_to_send = [wsproto.events.Ping(b"SERVER_PING")] + self.events_to_send = [ + wsproto.events.Ping(b"SERVER_PING"), + wsproto.events.CloseConnection(1000), + ] def read( self, max_bytes: int, timeout: typing.Optional[float] = None @@ -404,7 +407,10 @@ def __init__(self) -> None: self.connection = wsproto.connection.Connection( wsproto.connection.ConnectionType.SERVER ) - self.events_to_send = [wsproto.events.Ping(b"SERVER_PING")] + self.events_to_send = [ + wsproto.events.Ping(b"SERVER_PING"), + wsproto.events.CloseConnection(1000), + ] async def read( self, max_bytes: int, timeout: typing.Optional[float] = None From 6f666fabb8ac82396c2237138a7423f46bd7d7a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sun, 27 Nov 2022 18:16:26 +0100 Subject: [PATCH 037/108] Add missing unit tests --- tests/httpx2/websockets/test_api.py | 112 ++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index af4e4e55..517de7ef 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -1,4 +1,5 @@ import asyncio +import time import typing from unittest.mock import MagicMock, call @@ -15,6 +16,7 @@ JSONMode, WebSocketDisconnect, WebSocketInvalidTypeReceived, + WebSocketNetworkError, WebSocketSession, WebSocketUpgradeError, aconnect_ws, @@ -45,6 +47,64 @@ def handler(request): @pytest.mark.asyncio class TestSend: + async def test_send_error(self): + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection( + wsproto.connection.ConnectionType.SERVER + ) + self._should_close = False + + def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + while not self._should_close: + time.sleep(0.1) + raise httpcore.ReadError() + + def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + raise httpcore.WriteError() + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + websocket_session = WebSocketSession(stream) + with pytest.raises(WebSocketNetworkError): + websocket_session.send(wsproto.events.Ping()) + websocket_session.close() + + async def test_async_send_error(self): + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection( + wsproto.connection.ConnectionType.SERVER + ) + self._should_close = False + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + while not self._should_close: + await asyncio.sleep(0.1) + raise httpcore.ReadError() + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + raise httpcore.WriteError() + + async def aclose(self) -> None: + self._should_close = True + + stream = AsyncMockNetworkStream() + websocket_session = AsyncWebSocketSession(stream) + with pytest.raises(WebSocketNetworkError): + await websocket_session.send(wsproto.events.Ping()) + await websocket_session.close() + async def test_send( self, server_factory: ServerFactoryFixture, @@ -188,6 +248,58 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.asyncio class TestReceive: + async def test_receive_error(self): + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection( + wsproto.connection.ConnectionType.SERVER + ) + + def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + raise httpcore.ReadError() + + def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + pass + + def close(self) -> None: + pass + + stream = MockNetworkStream() + websocket_session = WebSocketSession(stream) + with pytest.raises(WebSocketNetworkError): + websocket_session.receive() + websocket_session.close() + + async def test_async_receive_error(self): + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection( + wsproto.connection.ConnectionType.SERVER + ) + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + raise httpcore.ReadError() + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + pass + + async def aclose(self) -> None: + pass + + stream = AsyncMockNetworkStream() + websocket_session = AsyncWebSocketSession(stream) + with pytest.raises(WebSocketNetworkError): + await websocket_session.receive() + await websocket_session.close() + async def test_receive(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() From 722bcc2930b68f8d21e8ec623484200fafb8a5c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 28 Nov 2022 09:56:16 +0100 Subject: [PATCH 038/108] Handle immediate close in ASGI transport --- src/httpx2/httpx2/_websockets/transport.py | 4 ++++ tests/httpx2/websockets/test_transport.py | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index a4b5ebd7..0febb5b1 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -8,6 +8,8 @@ from httpcore.backends.base import AsyncNetworkStream from httpx import ASGITransport, AsyncByteStream, Request, Response +from httpx_ws._api import WebSocketDisconnect + Scope = typing.MutableMapping[str, typing.Any] Message = typing.MutableMapping[str, typing.Any] Receive = typing.Callable[[], typing.Awaitable[Message]] @@ -45,6 +47,8 @@ async def __aenter__(self) -> "ASGIWebSocketAsyncNetworkStream": _: "Future[None]" = self.portal.start_task_soon(self._run) await self.send({"type": "websocket.connect"}) message = await self.receive() + if message["type"] == "websocket.close": + raise WebSocketDisconnect(message["code"], message.get("reason")) assert message["type"] == "websocket.accept" return self diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index 70b3c68d..523d9f18 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -8,6 +8,7 @@ from starlette.routing import Route, WebSocketRoute from starlette.websockets import WebSocket +from httpx_ws import WebSocketDisconnect from httpx_ws.transport import ( ASGIWebSocketAsyncNetworkStream, ASGIWebSocketTransport, @@ -88,6 +89,14 @@ async def app(scope, receive, send): with pytest.raises(UnhandledASGIMessageType): await stream.read(4096) + async def test_close_immediately(self): + async def app(scope, receive, send): + await send({"type": "websocket.close", "code": 1000, "reason": ""}) + + with pytest.raises(WebSocketDisconnect): + async with ASGIWebSocketAsyncNetworkStream(app, {}): + pass + @pytest.fixture def test_app() -> Starlette: From 8f77e6a7d617dbd439266ce2c3eb53c2cbd5c346 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 28 Nov 2022 09:58:54 +0100 Subject: [PATCH 039/108] =?UTF-8?q?Bump=20version=200.2.0=20=E2=86=92=200.?= =?UTF-8?q?2.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improvements ------------ * Handle immediate close from the server in ASGI transport * Improve network errors handling --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 9b15e713..d344ced6 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.0" +__version__ = "0.2.1" from httpx_ws._api import ( AsyncWebSocketSession, From 0e015ef3a29f384760e8fd683153e54d599121aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 28 Nov 2022 14:28:32 +0100 Subject: [PATCH 040/108] Add automatic keepalive ping mechanism --- src/httpx2/httpx2/_websockets/_api.py | 116 ++++++++++++++++-- tests/httpx2/websockets/test_api.py | 164 ++++++++++++++++++++++++++ 2 files changed, 272 insertions(+), 8 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index d67fcc9e..bfbd0cbe 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -25,6 +25,8 @@ DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536 DEFAULT_QUEUE_SIZE = 512 +DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = 20.0 +DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 class HTTPXWSException(Exception): @@ -89,6 +91,12 @@ def __init__( *, max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, ) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) @@ -105,6 +113,14 @@ def __init__( ) self._background_receive_task.start() + self._background_keepalive_ping_task: typing.Optional[threading.Thread] = None + if keepalive_ping_interval_seconds is not None: + self._background_keepalive_ping_task = threading.Thread( + target=self._background_keepalive_ping, + args=(keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds), + ) + self._background_keepalive_ping_task.start() + def ping(self, payload: bytes = b"") -> threading.Event: """ Send a Ping message. @@ -461,6 +477,21 @@ def _background_receive(self, max_bytes: int) -> None: self.close(CloseReason.INTERNAL_ERROR, "Stream error") self._events.put(WebSocketNetworkError()) + def _background_keepalive_ping( + self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None + ) -> None: + while True: + should_close = self._should_close.wait(interval_seconds) + if should_close: + break + + pong_callback = self.ping() + if timeout_seconds is not None: + acknowledged = pong_callback.wait(timeout_seconds) + if not acknowledged: + self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") + self._events.put(WebSocketNetworkError()) + class AsyncWebSocketSession: """ @@ -473,6 +504,12 @@ def __init__( *, max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, ) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) @@ -488,6 +525,14 @@ def __init__( self._background_receive(max_message_size_bytes) ) + self._background_keepalive_ping_task: typing.Optional[asyncio.Task] = None + if keepalive_ping_interval_seconds is not None: + self._background_keepalive_ping_task = asyncio.create_task( + self._background_keepalive_ping( + keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds + ) + ) + async def ping(self, payload: bytes = b"") -> asyncio.Event: """ Send a Ping message. @@ -624,7 +669,7 @@ async def receive( A raw [wsproto.events.Event][wsproto.events.Event]. Raises: - queue.Empty: No event was received before the timeout delay. + asyncio.TimeoutError: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. WebSocketNetworkError: A network error occured. @@ -640,7 +685,7 @@ async def receive( try: event = await ws.receive(timeout=2.) - except queue.Empty: + except asyncio.TimeoutError: print("No event received.") except WebSocketDisconnect: print("Connection closed") @@ -665,7 +710,7 @@ async def receive_text(self, timeout: typing.Optional[float] = None) -> str: Text data. Raises: - queue.Empty: No event was received before the timeout delay. + asyncio.TimeoutError: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event was not a text message. @@ -682,7 +727,7 @@ async def receive_text(self, timeout: typing.Optional[float] = None) -> str: try: event = await ws.receive_text(timeout=2.) - except queue.Empty: + except asyncio.TimeoutError: print("No text received.") except WebSocketDisconnect: print("Connection closed") @@ -705,7 +750,7 @@ async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: Bytes data. Raises: - queue.Empty: No event was received before the timeout delay. + asyncio.TimeoutError: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event was not a bytes message. @@ -722,7 +767,7 @@ async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: try: data = await ws.receive_bytes(timeout=2.) - except queue.Empty: + except asyncio.TimeoutError: print("No data received.") except WebSocketDisconnect: print("Connection closed") @@ -751,7 +796,7 @@ async def receive_json( Parsed JSON data. Raises: - queue.Empty: No event was received before the timeout delay. + asyncio.TimeoutError: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event @@ -769,7 +814,7 @@ async def receive_json( try: data = await ws.receive_json(timeout=2.) - except queue.Empty: + except asyncio.TimeoutError: print("No data received.") except WebSocketDisconnect: print("Connection closed") @@ -846,6 +891,25 @@ async def _background_receive(self, max_bytes: int) -> None: await self.close(CloseReason.INTERNAL_ERROR, "Stream error") await self._events.put(WebSocketNetworkError()) + async def _background_keepalive_ping( + self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None + ) -> None: + while True: + try: + await asyncio.wait_for(self._should_close.wait(), interval_seconds) + except asyncio.TimeoutError: + pass + + pong_callback = await self.ping() + if timeout_seconds is not None: + try: + await asyncio.wait_for(pong_callback.wait(), timeout_seconds) + except asyncio.TimeoutError: + await self.close( + CloseReason.INTERNAL_ERROR, "Keepalive ping timeout" + ) + await self._events.put(WebSocketNetworkError()) + def _get_headers() -> typing.Dict[str, typing.Any]: return { @@ -863,6 +927,12 @@ def connect_ws( *, max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, **kwargs: typing.Any, ) -> typing.Generator[WebSocketSession, None, None]: """ @@ -885,6 +955,16 @@ def connect_ws( If the queue is full, the client will stop receive messages from the server until the queue has room available. Defaults to 512. + keepalive_ping_interval_seconds: + Interval at which the client will automatically send a Ping event + to keep the connection alive. Set it to `None` to disable this mechanism. + Defaults to 20 seconds. + keepalive_ping_timeout_seconds: + Maximum delay the client will wait for an answer to its Ping event. + If the delay is exceeded, + [WebSocketNetworkError][httpx_ws.WebSocketNetworkError] + will be raised and the connection closed. + Defaults to 20 seconds. **kwargs: Additional keyword arguments that will be passed to the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. @@ -921,6 +1001,8 @@ def connect_ws( response.extensions["network_stream"], max_message_size_bytes=max_message_size_bytes, queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, ) yield session session.close() @@ -933,6 +1015,12 @@ async def aconnect_ws( *, max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, **kwargs: typing.Any, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: """ @@ -955,6 +1043,16 @@ async def aconnect_ws( If the queue is full, the client will stop receive messages from the server until the queue has room available. Defaults to 512. + keepalive_ping_interval_seconds: + Interval at which the client will automatically send a Ping event + to keep the connection alive. Set it to `None` to disable this mechanism. + Defaults to 20 seconds. + keepalive_ping_timeout_seconds: + Maximum delay the client will wait for an answer to its Ping event. + If the delay is exceeded, + [WebSocketNetworkError][httpx_ws.WebSocketNetworkError] + will be raised and the connection closed. + Defaults to 20 seconds. **kwargs: Additional keyword arguments that will be passed to the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. @@ -991,6 +1089,8 @@ async def aconnect_ws( response.extensions["network_stream"], max_message_size_bytes=max_message_size_bytes, queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, ) yield session await session.close() diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 517de7ef..10aa4f0c 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -1,4 +1,5 @@ import asyncio +import queue import time import typing from unittest.mock import MagicMock, call @@ -553,6 +554,169 @@ async def aclose(self) -> None: ] +@pytest.mark.asyncio +class TestKeepalivePing: + async def test_keepalive_ping(self): + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection( + wsproto.connection.ConnectionType.SERVER + ) + self._should_close = False + self.ping_received = 0 + self.ping_answered = 0 + self.events_to_send: queue.Queue[wsproto.events.Event] = queue.Queue() + + def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + while not self._should_close: + try: + event = self.events_to_send.get_nowait() + self.ping_answered += 1 + return self.connection.send(event) + except queue.Empty: + pass + raise httpcore.ReadError() + + def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + self.ping_received += 1 + self.events_to_send.put(event.response()) + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + websocket_session = WebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ) + await asyncio.sleep(0.2) + websocket_session.close() + + assert stream.ping_received >= 1 + assert stream.ping_answered >= 1 + + async def test_keepalive_ping_timeout(self): + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection( + wsproto.connection.ConnectionType.SERVER + ) + self._should_close = False + + def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + while not self._should_close: + time.sleep(0.1) + raise httpcore.ReadError() + + def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + pass + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + websocket_session = WebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ) + websocket_session.receive() + + async def test_async_keepalive_ping(self): + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection( + wsproto.connection.ConnectionType.SERVER + ) + self._should_close = False + self.ping_received = 0 + self.ping_answered = 0 + self.events_to_send: asyncio.Queue[ + wsproto.events.Event + ] = asyncio.Queue() + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + while not self._should_close: + try: + event = self.events_to_send.get_nowait() + self.ping_answered += 1 + return self.connection.send(event) + except asyncio.QueueEmpty: + await asyncio.sleep(0.1) + raise httpcore.ReadError() + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + self.ping_received += 1 + await self.events_to_send.put(event.response()) + + async def aclose(self) -> None: + self._should_close = True + + stream = MockAsyncNetworkStream() + websocket_session = AsyncWebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ) + await asyncio.sleep(0.3) + await websocket_session.close() + + assert stream.ping_received >= 1 + assert stream.ping_answered >= 1 + + async def test_async_keepalive_ping_timeout(self): + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.connection = wsproto.connection.Connection( + wsproto.connection.ConnectionType.SERVER + ) + self._should_close = False + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + while not self._should_close: + await asyncio.sleep(0.1) + raise httpcore.ReadError() + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + pass + + async def aclose(self) -> None: + self._should_close = True + + stream = MockAsyncNetworkStream() + with pytest.raises(WebSocketNetworkError): + websocket_session = AsyncWebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ) + await websocket_session.receive() + + @pytest.mark.asyncio async def test_ping_pong(server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): From ab338e13cf0980c89b100e82f2732aa947b285c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 28 Nov 2022 15:22:36 +0100 Subject: [PATCH 041/108] =?UTF-8?q?Bump=20version=200.2.1=20=E2=86=92=200.?= =?UTF-8?q?2.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improvements ------------ * Automatic keepalive ping mechanism --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index d344ced6..c86cfdcc 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.1" +__version__ = "0.2.2" from httpx_ws._api import ( AsyncWebSocketSession, From 82577fbe204ab7ef5333faee208f60febbb5c55a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 29 Nov 2022 16:27:46 +0100 Subject: [PATCH 042/108] Add mechanism to stop thread early when they're stuck in a waiting operation --- src/httpx2/httpx2/_websockets/_api.py | 107 +++++++++++++++++++------- 1 file changed, 79 insertions(+), 28 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index bfbd0cbe..49fb440f 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -1,5 +1,6 @@ import asyncio import base64 +import concurrent.futures import contextlib import json import queue @@ -22,6 +23,8 @@ from httpx_ws._ping import AsyncPingManager, PingManager JSONMode = Literal["text", "binary"] +TaskFunction = typing.TypeVar("TaskFunction") +TaskResult = typing.TypeVar("TaskResult") DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536 DEFAULT_QUEUE_SIZE = 512 @@ -80,6 +83,10 @@ class WebSocketNetworkError(HTTPXWSException): pass +class ShouldClose(Exception): + pass + + class WebSocketSession: """ Sync helper representing an opened WebSocket session. @@ -460,7 +467,7 @@ def _background_receive(self, max_bytes: int) -> None: """ try: while not self._should_close.is_set(): - data = self.stream.read(max_bytes=max_bytes) + data = self._wait_until_closed(self.stream.read, max_bytes) self.connection.receive_data(data) for event in self.connection.events(): if isinstance(event, wsproto.events.Ping): @@ -476,21 +483,46 @@ def _background_receive(self, max_bytes: int) -> None: except (httpcore.ReadError, httpcore.WriteError): self.close(CloseReason.INTERNAL_ERROR, "Stream error") self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass def _background_keepalive_ping( self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None ) -> None: - while True: - should_close = self._should_close.wait(interval_seconds) - if should_close: - break - - pong_callback = self.ping() - if timeout_seconds is not None: - acknowledged = pong_callback.wait(timeout_seconds) - if not acknowledged: - self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") - self._events.put(WebSocketNetworkError()) + try: + while not self._should_close.is_set(): + should_close = self._wait_until_closed( + self._should_close.wait, interval_seconds + ) + if should_close: + raise ShouldClose() + pong_callback = self.ping() + if timeout_seconds is not None: + acknowledged = self._wait_until_closed( + pong_callback.wait, timeout_seconds + ) + if not acknowledged: + self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") + self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass + + def _wait_until_closed( + self, callable: typing.Callable[..., TaskResult], *args, **kwargs + ) -> TaskResult: + executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + wait_close_task = executor.submit(self._should_close.wait) + todo_task = executor.submit(callable, *args, **kwargs) + done, pending = concurrent.futures.wait( # type: ignore + (todo_task, wait_close_task), return_when=concurrent.futures.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + if wait_close_task in done: + raise ShouldClose() + result = todo_task.result() + executor.shutdown(False) + return result class AsyncWebSocketSession: @@ -874,7 +906,9 @@ async def _background_receive(self, max_bytes: int) -> None: """ try: while not self._should_close.is_set(): - data = await self.stream.read(max_bytes=max_bytes) + data = await self._wait_until_closed( + self.stream.read(max_bytes=max_bytes) + ) self.connection.receive_data(data) for event in self.connection.events(): if isinstance(event, wsproto.events.Ping): @@ -890,25 +924,42 @@ async def _background_receive(self, max_bytes: int) -> None: except (httpcore.ReadError, httpcore.WriteError): await self.close(CloseReason.INTERNAL_ERROR, "Stream error") await self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass async def _background_keepalive_ping( self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None ) -> None: - while True: - try: - await asyncio.wait_for(self._should_close.wait(), interval_seconds) - except asyncio.TimeoutError: - pass - - pong_callback = await self.ping() - if timeout_seconds is not None: - try: - await asyncio.wait_for(pong_callback.wait(), timeout_seconds) - except asyncio.TimeoutError: - await self.close( - CloseReason.INTERNAL_ERROR, "Keepalive ping timeout" - ) - await self._events.put(WebSocketNetworkError()) + try: + while not self._should_close.is_set(): + await self._wait_until_closed(asyncio.sleep(interval_seconds)) + pong_callback = await self.ping() + if timeout_seconds is not None: + try: + await self._wait_until_closed( + asyncio.wait_for(pong_callback.wait(), timeout_seconds) + ) + except asyncio.TimeoutError: + await self.close( + CloseReason.INTERNAL_ERROR, "Keepalive ping timeout" + ) + await self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass + + async def _wait_until_closed( + self, coro: typing.Coroutine[typing.Any, typing.Any, TaskResult] + ) -> TaskResult: + wait_close_task = asyncio.create_task(self._should_close.wait()) + todo_task = asyncio.create_task(coro) + done, pending = await asyncio.wait( + {todo_task, wait_close_task}, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + if wait_close_task in done: + raise ShouldClose() + return todo_task.result() def _get_headers() -> typing.Dict[str, typing.Any]: From dd95a860542c124863f0423bd0fdb718b4b41455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 29 Nov 2022 16:31:20 +0100 Subject: [PATCH 043/108] =?UTF-8?q?Bump=20version=200.2.2=20=E2=86=92=200.?= =?UTF-8?q?2.3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improvements ------------ * When closing, stop background threads early when they're stuck in a waiting operation --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index c86cfdcc..b5c06950 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.2" +__version__ = "0.2.3" from httpx_ws._api import ( AsyncWebSocketSession, From 28811b91ad922a9a1d969ad9de0244fae86aca23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 27 Dec 2022 10:21:51 +0100 Subject: [PATCH 044/108] Fix #15: make sure default HTTPX client is closed (#16) --- src/httpx2/httpx2/_websockets/_api.py | 112 ++++++++++++++++++++++---- tests/httpx2/websockets/test_api.py | 28 ++++++- 2 files changed, 122 insertions(+), 18 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 49fb440f..03bacf0a 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -971,6 +971,39 @@ def _get_headers() -> typing.Dict[str, typing.Any]: } +@contextlib.contextmanager +def _connect_ws( + url: str, + client: httpx.Client, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + **kwargs: typing.Any, +) -> typing.Generator[WebSocketSession, None, None]: + headers = kwargs.pop("headers", {}) + headers.update(_get_headers()) + + with client.stream("GET", url, headers=headers, **kwargs) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + session = WebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + ) + yield session + session.close() + + @contextlib.contextmanager def connect_ws( url: str, @@ -1040,15 +1073,54 @@ def connect_ws( print(message) ws.send_text("Hello!") """ - client = httpx.Client() if client is None else client + if client is None: + with httpx.Client() as client: + with _connect_ws( + url, + client=client, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + **kwargs, + ) as websocket: + yield websocket + else: + with _connect_ws( + url, + client=client, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + **kwargs, + ) as websocket: + yield websocket + + +@contextlib.asynccontextmanager +async def _aconnect_ws( + url: str, + client: httpx.AsyncClient, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + **kwargs: typing.Any, +) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: headers = kwargs.pop("headers", {}) headers.update(_get_headers()) - with client.stream("GET", url, headers=headers, **kwargs) as response: + async with client.stream("GET", url, headers=headers, **kwargs) as response: if response.status_code != 101: raise WebSocketUpgradeError(response) - session = WebSocketSession( + session = AsyncWebSocketSession( response.extensions["network_stream"], max_message_size_bytes=max_message_size_bytes, queue_size=queue_size, @@ -1056,7 +1128,7 @@ def connect_ws( keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, ) yield session - session.close() + await session.close() @contextlib.asynccontextmanager @@ -1128,20 +1200,26 @@ async def aconnect_ws( print(message) await ws.send_text("Hello!") """ - client = httpx.AsyncClient() if client is None else client - headers = kwargs.pop("headers", {}) - headers.update(_get_headers()) - - async with client.stream("GET", url, headers=headers, **kwargs) as response: - if response.status_code != 101: - raise WebSocketUpgradeError(response) - - session = AsyncWebSocketSession( - response.extensions["network_stream"], + if client is None: + async with httpx.AsyncClient() as client: + async with _aconnect_ws( + url, + client=client, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + **kwargs, + ) as websocket: + yield websocket + else: + async with _aconnect_ws( + url, + client=client, max_message_size_bytes=max_message_size_bytes, queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - ) - yield session - await session.close() + **kwargs, + ) as websocket: + yield websocket diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 10aa4f0c..2f2e17c6 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -1,8 +1,9 @@ import asyncio +import contextlib import queue import time import typing -from unittest.mock import MagicMock, call +from unittest.mock import MagicMock, call, patch import httpcore import httpx @@ -782,3 +783,28 @@ async def websocket_endpoint(websocket: WebSocket): async with aconnect_ws("http://socket/ws", aclient) as aws: with pytest.raises(WebSocketDisconnect): await aws.receive() + + +@pytest.mark.asyncio +async def test_default_httpx_client(): + mock_context = contextlib.ExitStack() + with patch( + "httpx_ws._api._connect_ws", return_value=mock_context + ) as mock_connect_ws: + with connect_ws("http://socket/ws"): + pass + mock_connect_ws.assert_called_once() + httpx_client = mock_connect_ws.call_args[1]["client"] + assert isinstance(httpx_client, httpx.Client) + assert httpx_client.is_closed + + mock_async_context = contextlib.AsyncExitStack() + with patch( + "httpx_ws._api._aconnect_ws", return_value=mock_async_context + ) as mock_aconnect_ws: + async with aconnect_ws("http://socket/ws"): + pass + mock_aconnect_ws.assert_called_once() + httpx_client = mock_aconnect_ws.call_args[1]["client"] + assert isinstance(httpx_client, httpx.AsyncClient) + assert httpx_client.is_closed From 66445dc147950808ce12d7367e6c3e656273e4db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 27 Dec 2022 10:23:09 +0100 Subject: [PATCH 045/108] =?UTF-8?q?Bump=20version=200.2.3=20=E2=86=92=200.?= =?UTF-8?q?2.4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes --------- * Fix #15: make sure default HTTPX client is closed. --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index b5c06950..f6113dc9 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.3" +__version__ = "0.2.4" from httpx_ws._api import ( AsyncWebSocketSession, From d1d7121494c2c89a4f046a779b893f86e9e4faa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 2 Jan 2023 15:20:11 +0100 Subject: [PATCH 046/108] Fix #19: when both todo task and wait task are done in _wait_until_closed, make sure to always return the todo task result --- src/httpx2/httpx2/_websockets/_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 03bacf0a..de41a428 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -518,7 +518,7 @@ def _wait_until_closed( ) for task in pending: task.cancel() - if wait_close_task in done: + if wait_close_task in done and todo_task not in done: raise ShouldClose() result = todo_task.result() executor.shutdown(False) @@ -957,7 +957,7 @@ async def _wait_until_closed( ) for task in pending: task.cancel() - if wait_close_task in done: + if wait_close_task in done and todo_task not in done: raise ShouldClose() return todo_task.result() From e5a53a2cfb321f7765051db2e956f465bf2d66ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 2 Jan 2023 15:27:59 +0100 Subject: [PATCH 047/108] Fix typings --- src/httpx2/httpx2/_websockets/transport.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 0febb5b1..745318d3 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -10,11 +10,11 @@ from httpx_ws._api import WebSocketDisconnect -Scope = typing.MutableMapping[str, typing.Any] -Message = typing.MutableMapping[str, typing.Any] +Scope = typing.Dict[str, typing.Any] +Message = typing.Dict[str, typing.Any] Receive = typing.Callable[[], typing.Awaitable[Message]] -Send = typing.Callable[[Message], typing.Awaitable[None]] -ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] +Send = typing.Callable[[Scope], typing.Coroutine[None, None, None]] +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Coroutine[None, None, None]] class ASGIWebSocketTransportError(Exception): From aed282e05a2630f869edc944fa2961c3feaa1f37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 2 Jan 2023 15:38:13 +0100 Subject: [PATCH 048/108] =?UTF-8?q?Bump=20version=200.2.4=20=E2=86=92=200.?= =?UTF-8?q?2.5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes --------- * Fix #19: when both todo task and wait task are done in `_wait_until_closed`, make sure to always return the todo task result --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index f6113dc9..d172e420 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.4" +__version__ = "0.2.5" from httpx_ws._api import ( AsyncWebSocketSession, From 1fc1ed03c47961f742cfd46ab0bb9b7f62aed263 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 2 Jan 2023 18:12:43 +0100 Subject: [PATCH 049/108] Fix case where we try to schedule tasks in new threadpool during client close See https://github.com/frankie567/httpx-ws/issues/19#issuecomment-1368995731 --- src/httpx2/httpx2/_websockets/_api.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index de41a428..d4f6036f 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -510,9 +510,12 @@ def _background_keepalive_ping( def _wait_until_closed( self, callable: typing.Callable[..., TaskResult], *args, **kwargs ) -> TaskResult: - executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) - wait_close_task = executor.submit(self._should_close.wait) - todo_task = executor.submit(callable, *args, **kwargs) + try: + executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + wait_close_task = executor.submit(self._should_close.wait) + todo_task = executor.submit(callable, *args, **kwargs) + except RuntimeError as e: + raise ShouldClose() from e done, pending = concurrent.futures.wait( # type: ignore (todo_task, wait_close_task), return_when=concurrent.futures.FIRST_COMPLETED ) From f86992fe0b85f87df88d76dc46f5c602157ccf5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 3 Jan 2023 09:17:28 +0100 Subject: [PATCH 050/108] =?UTF-8?q?Bump=20version=200.2.5=20=E2=86=92=200.?= =?UTF-8?q?2.6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes --------- * Fix case where we try to schedule tasks in new threadpool during client close --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index d172e420..a165f5bc 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.5" +__version__ = "0.2.6" from httpx_ws._api import ( AsyncWebSocketSession, From 44b9e0c2df225a38e2c624f89d963b142711b09c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 15 Feb 2023 11:54:51 +0100 Subject: [PATCH 051/108] Add support for subprotocols (#24) --- src/httpx2/httpx2/_websockets/_api.py | 49 +++++++++++++++++++++++-- tests/httpx2/websockets/test_api.py | 53 +++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 4 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index d4f6036f..27b009fb 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -90,8 +90,14 @@ class ShouldClose(Exception): class WebSocketSession: """ Sync helper representing an opened WebSocket session. + + Attributes: + subprotocol (typing.Optional[str]): + Optional protocol that has been accepted by the server. """ + subprotocol: typing.Optional[str] + def __init__( self, stream: NetworkStream, @@ -104,9 +110,12 @@ def __init__( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocol: typing.Optional[str] = None, ) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) + self.subprotocol = subprotocol + self._events: queue.Queue[ typing.Union[wsproto.events.Event, HTTPXWSException] ] = queue.Queue(queue_size) @@ -531,8 +540,14 @@ def _wait_until_closed( class AsyncWebSocketSession: """ Async helper representing an opened WebSocket session. + + Attributes: + subprotocol (typing.Optional[str]): + Optional protocol that has been accepted by the server. """ + subprotocol: typing.Optional[str] + def __init__( self, stream: AsyncNetworkStream, @@ -545,9 +560,12 @@ def __init__( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocol: typing.Optional[str] = None, ) -> None: self.stream = stream self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) + self.subprotocol = subprotocol + self._events: asyncio.Queue[ typing.Union[wsproto.events.Event, HTTPXWSException] ] = asyncio.Queue(queue_size) @@ -965,13 +983,18 @@ async def _wait_until_closed( return todo_task.result() -def _get_headers() -> typing.Dict[str, typing.Any]: - return { +def _get_headers( + subprotocols: typing.Optional[typing.List[str]], +) -> typing.Dict[str, typing.Any]: + headers = { "connection": "upgrade", "upgrade": "websocket", "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), "sec-websocket-version": "13", } + if subprotocols is not None: + headers["sec-websocket-protocol"] = ", ".join(subprotocols) + return headers @contextlib.contextmanager @@ -987,21 +1010,25 @@ def _connect_ws( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: typing.Optional[typing.List[str]] = None, **kwargs: typing.Any, ) -> typing.Generator[WebSocketSession, None, None]: headers = kwargs.pop("headers", {}) - headers.update(_get_headers()) + headers.update(_get_headers(subprotocols)) with client.stream("GET", url, headers=headers, **kwargs) as response: if response.status_code != 101: raise WebSocketUpgradeError(response) + subprotocol = response.headers.get("sec-websocket-protocol") + session = WebSocketSession( response.extensions["network_stream"], max_message_size_bytes=max_message_size_bytes, queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocol=subprotocol, ) yield session session.close() @@ -1020,6 +1047,7 @@ def connect_ws( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: typing.Optional[typing.List[str]] = None, **kwargs: typing.Any, ) -> typing.Generator[WebSocketSession, None, None]: """ @@ -1052,6 +1080,8 @@ def connect_ws( [WebSocketNetworkError][httpx_ws.WebSocketNetworkError] will be raised and the connection closed. Defaults to 20 seconds. + subprotocols: + Optional list of suprotocols to negotiate with the server. **kwargs: Additional keyword arguments that will be passed to the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. @@ -1085,6 +1115,7 @@ def connect_ws( queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, **kwargs, ) as websocket: yield websocket @@ -1096,6 +1127,7 @@ def connect_ws( queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, **kwargs, ) as websocket: yield websocket @@ -1114,21 +1146,25 @@ async def _aconnect_ws( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: typing.Optional[typing.List[str]] = None, **kwargs: typing.Any, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: headers = kwargs.pop("headers", {}) - headers.update(_get_headers()) + headers.update(_get_headers(subprotocols)) async with client.stream("GET", url, headers=headers, **kwargs) as response: if response.status_code != 101: raise WebSocketUpgradeError(response) + subprotocol = response.headers.get("sec-websocket-protocol") + session = AsyncWebSocketSession( response.extensions["network_stream"], max_message_size_bytes=max_message_size_bytes, queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocol=subprotocol, ) yield session await session.close() @@ -1147,6 +1183,7 @@ async def aconnect_ws( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: typing.Optional[typing.List[str]] = None, **kwargs: typing.Any, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: """ @@ -1179,6 +1216,8 @@ async def aconnect_ws( [WebSocketNetworkError][httpx_ws.WebSocketNetworkError] will be raised and the connection closed. Defaults to 20 seconds. + subprotocols: + Optional list of suprotocols to negotiate with the server. **kwargs: Additional keyword arguments that will be passed to the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. @@ -1212,6 +1251,7 @@ async def aconnect_ws( queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, **kwargs, ) as websocket: yield websocket @@ -1223,6 +1263,7 @@ async def aconnect_ws( queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, **kwargs, ) as websocket: yield websocket diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 2f2e17c6..e1ef09bb 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -808,3 +808,56 @@ async def test_default_httpx_client(): httpx_client = mock_aconnect_ws.call_args[1]["client"] assert isinstance(httpx_client, httpx.AsyncClient) assert httpx_client.is_closed + + +@pytest.mark.asyncio +async def test_subprotocol(): + def handler(request): + assert ( + request.headers["sec-websocket-protocol"] + == "custom_protocol, unsupported_protocol" + ) + + return httpx.Response( + 101, + headers={"sec-websocket-protocol": "custom_protocol"}, + extensions={"network_stream": MagicMock()}, + ) + + def async_handler(request): + assert ( + request.headers["sec-websocket-protocol"] + == "custom_protocol, unsupported_protocol" + ) + + network_stream = MagicMock() + async_method_return_value = asyncio.Future() + async_method_return_value.set_result(MagicMock()) + network_stream.write.return_value = async_method_return_value + network_stream.aclose.return_value = async_method_return_value + + return httpx.Response( + 101, + headers={"sec-websocket-protocol": "custom_protocol"}, + extensions={"network_stream": network_stream}, + ) + + with httpx.Client( + base_url="http://localhost:8000", transport=httpx.MockTransport(handler) + ) as client: + with connect_ws( + "http://socket/ws", + client, + subprotocols=["custom_protocol", "unsupported_protocol"], + ) as ws: + assert ws.subprotocol == "custom_protocol" + + async with httpx.AsyncClient( + base_url="http://localhost:8000", transport=httpx.MockTransport(async_handler) + ) as client: + async with aconnect_ws( + "http://socket/ws", + client, + subprotocols=["custom_protocol", "unsupported_protocol"], + ) as aws: + assert aws.subprotocol == "custom_protocol" From 1093bba83c37de6a436a3bd3224d53786d8e1fe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 15 Feb 2023 12:03:09 +0100 Subject: [PATCH 052/108] =?UTF-8?q?Bump=20version=200.2.6=20=E2=86=92=200.?= =?UTF-8?q?3.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New features ------------ * Add support for subprotocols [[Doccumentation](https://frankie567.github.io/httpx-ws/usage/subprotocols/)] * Thanks @davidbrochart for the idea and feedback 👍 --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index a165f5bc..a3002b9e 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.6" +__version__ = "0.3.0" from httpx_ws._api import ( AsyncWebSocketSession, From 99fc8ee6583d86c476ceb8e462ff68ae3676c2a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sat, 15 Apr 2023 13:45:45 +0200 Subject: [PATCH 053/108] Fix #30: handle server error in ASGI transport --- src/httpx2/httpx2/_websockets/transport.py | 12 +++++++++++- tests/httpx2/websockets/test_transport.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 745318d3..432db8ec 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -7,6 +7,7 @@ import wsproto from httpcore.backends.base import AsyncNetworkStream from httpx import ASGITransport, AsyncByteStream, Request, Response +from wsproto.connection import CloseReason from httpx_ws._api import WebSocketDisconnect @@ -48,6 +49,7 @@ async def __aenter__(self) -> "ASGIWebSocketAsyncNetworkStream": await self.send({"type": "websocket.connect"}) message = await self.receive() if message["type"] == "websocket.close": + await self.aclose() raise WebSocketDisconnect(message["code"], message.get("reason")) assert message["type"] == "websocket.accept" return self @@ -116,7 +118,15 @@ async def _run(self) -> None: scope = self.scope receive = self._asgi_receive send = self._asgi_send - await self.app(scope, receive, send) + try: + await self.app(scope, receive, send) + except Exception as e: + message = { + "type": "websocket.close", + "code": CloseReason.INTERNAL_ERROR, + "reason": str(e), + } + await self._asgi_send(message) async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index 523d9f18..ebaed25d 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -97,6 +97,16 @@ async def app(scope, receive, send): async with ASGIWebSocketAsyncNetworkStream(app, {}): pass + async def test_exception(self): + async def app(scope, receive, send): + raise Exception("Error") + + with pytest.raises(WebSocketDisconnect) as excinfo: + async with ASGIWebSocketAsyncNetworkStream(app, {}): + pass + assert excinfo.value.code == 1011 + assert excinfo.value.reason == "Error" + @pytest.fixture def test_app() -> Starlette: From 4975b7686487fce42b9e7483805ce40f1484d50f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sat, 15 Apr 2023 13:46:46 +0200 Subject: [PATCH 054/108] Remove useless docstring on event_loop fixture --- tests/httpx2/websockets/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index dd385ec2..b7699bad 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -21,7 +21,6 @@ @pytest.fixture(scope="session") def event_loop(): - """Force the pytest-asyncio loop to be the main one.""" loop = asyncio.new_event_loop() yield loop loop.close() From e62c81db1c63bf06eb6ea827c2e3669f2938ae46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sat, 15 Apr 2023 13:48:29 +0200 Subject: [PATCH 055/108] =?UTF-8?q?Bump=20version=200.3.0=20=E2=86=92=200.?= =?UTF-8?q?3.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes --------- * Fix #30: ASGI transport now correctly handles server errors and closes the WebSocket instead of hanging. Thanks @ysmu 🎉 --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index a3002b9e..5858a2ee 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.3.0" +__version__ = "0.3.1" from httpx_ws._api import ( AsyncWebSocketSession, From f895896fb4210e75a8742811bc63971f1e10a1b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 27 Jun 2023 15:10:49 +0200 Subject: [PATCH 056/108] Drop Python 3.7 support --- src/httpx2/httpx2/_websockets/_api.py | 8 +------- tests/httpx2/websockets/conftest.py | 8 +------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 27b009fb..471540af 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -5,15 +5,9 @@ import json import queue import secrets -import sys import threading import typing -if sys.version_info < (3, 8): - from typing_extensions import Literal # pragma: no cover -else: - from typing import Literal # pragma: no cover - import httpcore import httpx import wsproto @@ -22,7 +16,7 @@ from httpx_ws._ping import AsyncPingManager, PingManager -JSONMode = Literal["text", "binary"] +JSONMode = typing.Literal["text", "binary"] TaskFunction = typing.TypeVar("TaskFunction") TaskResult = typing.TypeVar("TaskResult") diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index b7699bad..ea8510df 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -2,16 +2,10 @@ import contextlib import pathlib import queue -import sys import tempfile -from typing import Callable, ContextManager +from typing import Callable, ContextManager, Literal, Protocol from unittest.mock import MagicMock -if sys.version_info < (3, 8): - from typing_extensions import Literal, Protocol # pragma: no cover -else: - from typing import Literal, Protocol # pragma: no cover - import pytest import uvicorn from anyio.from_thread import start_blocking_portal From 77e83561efce12466606d947d0b185d22b774c32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 27 Jun 2023 15:12:44 +0200 Subject: [PATCH 057/108] Use Starlette lifespan instead of on_startup --- tests/httpx2/websockets/conftest.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index ea8510df..dd5127a3 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -49,10 +49,12 @@ def create_app() -> Starlette: WebSocketRoute("/ws", endpoint=endpoint), ] - async def on_startup(): + @contextlib.asynccontextmanager + async def lifespan(app: Starlette): startup_queue.put(True) + yield - return Starlette(routes=routes, on_startup=[on_startup]) + return Starlette(routes=routes, lifespan=lifespan) def create_server(app: Starlette, socket: str): config = uvicorn.Config(app, uds=socket, ws=websocket_implementation) From 231736a20187d976e53992752fb3b28b7473a017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 27 Jun 2023 15:15:31 +0200 Subject: [PATCH 058/108] =?UTF-8?q?Bump=20version=200.3.1=20=E2=86=92=200.?= =?UTF-8?q?4.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Breaking changes ---------------- * Drop Python 3.7 support --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 5858a2ee..72771f63 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.3.1" +__version__ = "0.4.0" from httpx_ws._api import ( AsyncWebSocketSession, From 7fee7d39b5b2161a851a2437f06b14ba86e4c044 Mon Sep 17 00:00:00 2001 From: Sam Foreman Date: Thu, 6 Jul 2023 08:48:49 -0500 Subject: [PATCH 059/108] Fixes `httpcore` import in `httpx_ws/_api.py` (#36) * Fixes `httpcore` import in `httpx_ws/_api.py` * Replace `httpcore.backends` import in `httpx_ws/_api.py` * Replace `httpcore.backends` import in `httpx_ws/transport.py` * Bump `httpcore` in `pyproject.toml` * Remove `NetworkStream` import from `transport.py` * Fix `httpcore` import in `tests/test_api.py` --- src/httpx2/httpx2/_websockets/_api.py | 2 +- src/httpx2/httpx2/_websockets/transport.py | 2 +- tests/httpx2/websockets/test_api.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 471540af..cd6a838d 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -11,7 +11,7 @@ import httpcore import httpx import wsproto -from httpcore.backends.base import AsyncNetworkStream, NetworkStream +from httpcore import AsyncNetworkStream, NetworkStream from wsproto.connection import CloseReason from httpx_ws._ping import AsyncPingManager, PingManager diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 432db8ec..5d687282 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -5,7 +5,7 @@ import anyio import wsproto -from httpcore.backends.base import AsyncNetworkStream +from httpcore import AsyncNetworkStream from httpx import ASGITransport, AsyncByteStream, Request, Response from wsproto.connection import CloseReason diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index e1ef09bb..9ff7fb67 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -9,7 +9,7 @@ import httpx import pytest import wsproto -from httpcore.backends.base import AsyncNetworkStream, NetworkStream +from httpcore import AsyncNetworkStream, NetworkStream from starlette.websockets import WebSocket from starlette.websockets import WebSocketDisconnect as StarletteWebSocketDisconnect From 5c166069da24c80e93425931ebcdc1641ceb9e70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 6 Jul 2023 15:52:16 +0200 Subject: [PATCH 060/108] =?UTF-8?q?Bump=20version=200.4.0=20=E2=86=92=200.?= =?UTF-8?q?4.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes --------- * Fix import issue with `httpcore>=0.17.3`. Thanks @saforem2 🚀 --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 72771f63..99ff81a5 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.0" +__version__ = "0.4.1" from httpx_ws._api import ( AsyncWebSocketSession, From b3d2f080cce3762985df51c5f593fa320ad08ab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 8 Aug 2023 11:53:13 +0200 Subject: [PATCH 061/108] Fix #40: handle large message buffering (#44) --- src/httpx2/httpx2/_websockets/_api.py | 40 ++++++++++++++++++++++++ tests/httpx2/websockets/test_api.py | 45 +++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index cd6a838d..5abfebe1 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -468,6 +468,7 @@ def _background_receive(self, max_bytes: int) -> None: Args: max_bytes: The maximum chunk size to read at each iteration. """ + partial_message_buffer: typing.Union[str, bytes, None] = None try: while not self._should_close.is_set(): data = self._wait_until_closed(self.stream.read, max_bytes) @@ -482,6 +483,25 @@ def _background_receive(self, max_bytes: int) -> None: continue if isinstance(event, wsproto.events.CloseConnection): self._should_close.set() + if isinstance(event, wsproto.events.Message): + # Unfinished message: bufferize + if not event.message_finished: + if partial_message_buffer is None: + partial_message_buffer = event.data + else: + partial_message_buffer += event.data + # Finished message but no buffer: just emit the event + elif partial_message_buffer is None: + self._events.put(event) + # Finished message with buffer: emit the full event + else: + event_type = type(event) + full_message_event = event_type( + partial_message_buffer + event.data + ) + partial_message_buffer = None + self._events.put(full_message_event) + continue self._events.put(event) except (httpcore.ReadError, httpcore.WriteError): self.close(CloseReason.INTERNAL_ERROR, "Stream error") @@ -919,6 +939,7 @@ async def _background_receive(self, max_bytes: int) -> None: Args: max_bytes: The maximum chunk size to read at each iteration. """ + partial_message_buffer: typing.Union[str, bytes, None] = None try: while not self._should_close.is_set(): data = await self._wait_until_closed( @@ -935,6 +956,25 @@ async def _background_receive(self, max_bytes: int) -> None: continue if isinstance(event, wsproto.events.CloseConnection): self._should_close.set() + if isinstance(event, wsproto.events.Message): + # Unfinished message: bufferize + if not event.message_finished: + if partial_message_buffer is None: + partial_message_buffer = event.data + else: + partial_message_buffer += event.data + # Finished message but no buffer: just emit the event + elif partial_message_buffer is None: + await self._events.put(event) + # Finished message with buffer: emit the full event + else: + event_type = type(event) + full_message_event = event_type( + partial_message_buffer + event.data + ) + partial_message_buffer = None + await self._events.put(full_message_event) + continue await self._events.put(event) except (httpcore.ReadError, httpcore.WriteError): await self.close(CloseReason.INTERNAL_ERROR, "Stream error") diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 9ff7fb67..c304f276 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -332,6 +332,51 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass + @pytest.mark.parametrize( + "full_message,send_method", + [ + (b"A" * 1024 * 1024, "send_bytes"), + ("A" * 1024 * 1024, "send_text"), + ], + ) + async def test_receive_oversized_message( + self, + full_message: typing.Union[str, bytes], + send_method: str, + server_factory: ServerFactoryFixture, + ): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await asyncio.sleep(0.1) # FIXME: see #7 + + method = getattr(websocket, send_method) + await method(full_message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + try: + with connect_ws( + "http://socket/ws", client, max_message_size_bytes=1024 + ) as ws: + event = ws.receive() + assert isinstance(event, wsproto.events.Message) + assert event.data == full_message + except WebSocketDisconnect: + pass + + async with httpx.AsyncClient( + transport=httpx.AsyncHTTPTransport(uds=socket) + ) as aclient: + try: + async with aconnect_ws("http://socket/ws", aclient) as aws: + event = await aws.receive() + assert isinstance(event, wsproto.events.Message) + assert event.data == full_message + except WebSocketDisconnect: + pass + async def test_receive_text(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() From 45a66175569f0a604c6d41e1c1d92b3066c10160 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Parent Date: Thu, 7 Sep 2023 02:12:28 -0400 Subject: [PATCH 062/108] start_blocking_portal was moved in anyio 4 (#48) --- src/httpx2/httpx2/_websockets/transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 5d687282..309928ef 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -43,7 +43,7 @@ def __init__(self, app: ASGIApp, scope: Scope) -> None: async def __aenter__(self) -> "ASGIWebSocketAsyncNetworkStream": self.exit_stack = contextlib.ExitStack() self.portal = self.exit_stack.enter_context( - anyio.start_blocking_portal("asyncio") + anyio.from_thread.start_blocking_portal("asyncio") ) _: "Future[None]" = self.portal.start_task_soon(self._run) await self.send({"type": "websocket.connect"}) From b3f7885d0703584d06613d92cb2c69d1f24933a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 7 Sep 2023 08:35:42 +0200 Subject: [PATCH 063/108] Fix some wsproto imports --- src/httpx2/httpx2/_websockets/_api.py | 6 +++--- src/httpx2/httpx2/_websockets/transport.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 5abfebe1..ded6b87b 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -12,7 +12,7 @@ import httpx import wsproto from httpcore import AsyncNetworkStream, NetworkStream -from wsproto.connection import CloseReason +from wsproto.frame_protocol import CloseReason from httpx_ws._ping import AsyncPingManager, PingManager @@ -107,7 +107,7 @@ def __init__( subprotocol: typing.Optional[str] = None, ) -> None: self.stream = stream - self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) + self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) self.subprotocol = subprotocol self._events: queue.Queue[ @@ -577,7 +577,7 @@ def __init__( subprotocol: typing.Optional[str] = None, ) -> None: self.stream = stream - self.connection = wsproto.Connection(wsproto.ConnectionType.CLIENT) + self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) self.subprotocol = subprotocol self._events: asyncio.Queue[ diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 309928ef..6c08d33f 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -7,7 +7,7 @@ import wsproto from httpcore import AsyncNetworkStream from httpx import ASGITransport, AsyncByteStream, Request, Response -from wsproto.connection import CloseReason +from wsproto.frame_protocol import CloseReason from httpx_ws._api import WebSocketDisconnect From ba606bfe4bda7736db5213e1b35bf392eeec7fac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 27 Sep 2023 08:47:00 +0200 Subject: [PATCH 064/108] Fix #34: handle subprotocols corrrectly in `ASGIWebSocketTransport` --- src/httpx2/httpx2/_websockets/transport.py | 7 +++++++ tests/httpx2/websockets/test_transport.py | 23 +++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 6c08d33f..9335cd01 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -147,6 +147,12 @@ async def handle_async_request(self, request: Request) -> Response: headers = request.headers if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": + subprotocols: typing.List[str] = [] + if ( + subprotocols_header := headers.get("sec-websocket-protocol") + ) is not None: + subprotocols = subprotocols_header.split(",") + scope = { "type": "websocket", "path": request.url.path, @@ -157,6 +163,7 @@ async def handle_async_request(self, request: Request) -> Response: "headers": [(k.lower(), v) for (k, v) in request.headers.raw], "client": self.client, "server": (request.url.host, request.url.port), + "subprotocols": subprotocols, } return await self._handle_ws_request(request, scope) diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index ebaed25d..0cd832ea 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -8,7 +8,7 @@ from starlette.routing import Route, WebSocketRoute from starlette.websockets import WebSocket -from httpx_ws import WebSocketDisconnect +from httpx_ws import WebSocketDisconnect, aconnect_ws from httpx_ws.transport import ( ASGIWebSocketAsyncNetworkStream, ASGIWebSocketTransport, @@ -153,3 +153,24 @@ async def test_websocket( assert isinstance( response.extensions["network_stream"], ASGIWebSocketAsyncNetworkStream ) + + +@pytest.mark.asyncio +async def test_subprotocol_support(): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + assert websocket.scope.get("subprotocols") == ["custom_protocol"] + await websocket.send_text("SERVER_MESSAGE") + await websocket.close() + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + async with aconnect_ws( + "ws://localhost:8000/ws", client, subprotocols=["custom_protocol"] + ) as ws: + await ws.receive_text() From 03c17d5114068ad41e8f2f64da63dce8b3a3d780 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 27 Sep 2023 08:50:38 +0200 Subject: [PATCH 065/108] =?UTF-8?q?Bump=20version=200.4.1=20=E2=86=92=200.?= =?UTF-8?q?4.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes --------- * Fix anyio `start_blocking_portal` import. Thanks @maparent 🎉 * Fix #40: handle large message buffering * Fix #34: handle subprotocols corrrectly in `ASGIWebSocketTransport` --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 99ff81a5..bd1383ea 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.1" +__version__ = "0.4.2" from httpx_ws._api import ( AsyncWebSocketSession, From 142c205c84fdbdb0a1bfb39e4bd668d6590843b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 16 Nov 2023 10:11:52 +0100 Subject: [PATCH 066/108] Fix #56: return proper accept response headers with ASGIWebSocketTransport --- src/httpx2/httpx2/_websockets/transport.py | 41 +++++++++++-- tests/httpx2/websockets/test_transport.py | 67 +++++++++++++++++----- 2 files changed, 87 insertions(+), 21 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 9335cd01..92c38081 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -38,21 +38,27 @@ def __init__(self, app: ASGIApp, scope: Scope) -> None: self.scope = scope self._receive_queue: queue.Queue[Message] = queue.Queue() self._send_queue: queue.Queue[Message] = queue.Queue() - self.connection = wsproto.connection.Connection(wsproto.connection.SERVER) + self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER) + self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) - async def __aenter__(self) -> "ASGIWebSocketAsyncNetworkStream": + async def __aenter__( + self + ) -> typing.Tuple["ASGIWebSocketAsyncNetworkStream", bytes]: self.exit_stack = contextlib.ExitStack() self.portal = self.exit_stack.enter_context( anyio.from_thread.start_blocking_portal("asyncio") ) _: "Future[None]" = self.portal.start_task_soon(self._run) + await self.send({"type": "websocket.connect"}) message = await self.receive() + if message["type"] == "websocket.close": await self.aclose() raise WebSocketDisconnect(message["code"], message.get("reason")) + assert message["type"] == "websocket.accept" - return self + return self, self._build_accept_response(message) async def __aexit__(self, *args: typing.Any) -> None: await self.aclose() @@ -84,7 +90,9 @@ async def write( ) -> None: self.connection.receive_data(buffer) for event in self.connection.events(): - if isinstance(event, wsproto.events.CloseConnection): + if isinstance(event, wsproto.events.Request): + pass + elif isinstance(event, wsproto.events.CloseConnection): await self.send( { "type": "websocket.close", @@ -136,6 +144,16 @@ async def _asgi_receive(self) -> Message: async def _asgi_send(self, message: Message) -> None: self._send_queue.put(message) + def _build_accept_response(self, message: Message) -> bytes: + subprotocol = message.get("subprotocol", None) + headers = message.get("headers", []) + return self.connection.send( + wsproto.events.AcceptConnection( + subprotocol=subprotocol, + extra_headers=headers, + ) + ) + class ASGIWebSocketTransport(ASGITransport): def __init__(self, *args, **kwargs) -> None: @@ -178,11 +196,22 @@ async def _handle_ws_request( self.scope = scope self.exit_stack = contextlib.AsyncExitStack() - stream = await self.exit_stack.enter_async_context( + stream, accept_response = await self.exit_stack.enter_async_context( ASGIWebSocketAsyncNetworkStream(self.app, self.scope) ) - return Response(101, extensions={"network_stream": stream}) + accept_response_lines = accept_response.decode("utf-8").splitlines() + headers = [ + typing.cast(typing.Tuple[str, str], line.split(": ", 1)) + for line in accept_response_lines[1:] + if line.strip() != "" + ] + + return Response( + status_code=101, + headers=headers, + extensions={"network_stream": stream}, + ) async def aclose(self) -> None: if self.exit_stack: diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index 0cd832ea..be850729 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -1,3 +1,5 @@ +import base64 +import secrets from typing import Any, Dict import httpx @@ -12,14 +14,42 @@ from httpx_ws.transport import ( ASGIWebSocketAsyncNetworkStream, ASGIWebSocketTransport, + Scope, UnhandledASGIMessageType, UnhandledWebSocketEvent, ) +@pytest.fixture +def websocket_request_headers() -> Dict[str, str]: + return { + "connection": "upgrade", + "upgrade": "websocket", + "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), + "sec-websocket-version": "13", + } + + +@pytest.fixture +def scope(websocket_request_headers: Dict[str, str]) -> Scope: + return { + "type": "websocket", + "path": "/ws", + "raw_path": "/ws", + "root_path": "/", + "scheme": "ws", + "headers": [ + ("host", "localhost"), + *websocket_request_headers.items(), + ], + "subprotocols": [], + "server": ("localhost", 8000), + } + + @pytest.mark.asyncio class TestASGIWebSocketAsyncNetworkStream: - async def test_write(self): + async def test_write(self, scope: Scope): received_messages = [] async def app(scope, receive, send): @@ -31,7 +61,7 @@ async def app(scope, receive, send): received_messages.append(message) connection = wsproto.connection.Connection(wsproto.connection.CLIENT) - async with ASGIWebSocketAsyncNetworkStream(app, {}) as stream: + async with ASGIWebSocketAsyncNetworkStream(app, scope) as (stream, _): text_event = wsproto.events.TextMessage("CLIENT_MESSAGE") await stream.write(connection.send(text_event)) @@ -48,18 +78,18 @@ async def app(scope, receive, send): {"type": "websocket.close", "code": 1000, "reason": ""}, ] - async def test_write_unhandled_event(self): + async def test_write_unhandled_event(self, scope: Scope): async def app(scope, receive, send): await send({"type": "websocket.accept"}) await receive() connection = wsproto.connection.Connection(wsproto.connection.CLIENT) - async with ASGIWebSocketAsyncNetworkStream(app, {}) as stream: + async with ASGIWebSocketAsyncNetworkStream(app, scope) as (stream, _): with pytest.raises(UnhandledWebSocketEvent): ping_event = wsproto.events.Ping(b"PING") await stream.write(connection.send(ping_event)) - async def test_read(self): + async def test_read(self, scope): async def app(scope, receive, send): await send({"type": "websocket.accept"}) await send({"type": "websocket.send", "text": "SERVER_MESSAGE"}) @@ -68,7 +98,7 @@ async def app(scope, receive, send): connection = wsproto.connection.Connection(wsproto.connection.CLIENT) events = [] - async with ASGIWebSocketAsyncNetworkStream(app, {}) as stream: + async with ASGIWebSocketAsyncNetworkStream(app, scope) as (stream, _): for _ in range(3): data = await stream.read(4096) connection.receive_data(data) @@ -80,29 +110,29 @@ async def app(scope, receive, send): wsproto.events.CloseConnection(1000, ""), ] - async def test_read_unhandled_asgi_message(self): + async def test_read_unhandled_asgi_message(self, scope): async def app(scope, receive, send): await send({"type": "websocket.accept"}) await send({"type": "websocket.foo"}) - async with ASGIWebSocketAsyncNetworkStream(app, {}) as stream: + async with ASGIWebSocketAsyncNetworkStream(app, scope) as (stream, _): with pytest.raises(UnhandledASGIMessageType): await stream.read(4096) - async def test_close_immediately(self): + async def test_close_immediately(self, scope): async def app(scope, receive, send): await send({"type": "websocket.close", "code": 1000, "reason": ""}) with pytest.raises(WebSocketDisconnect): - async with ASGIWebSocketAsyncNetworkStream(app, {}): + async with ASGIWebSocketAsyncNetworkStream(app, scope): pass - async def test_exception(self): + async def test_exception(self, scope): async def app(scope, receive, send): raise Exception("Error") with pytest.raises(WebSocketDisconnect) as excinfo: - async with ASGIWebSocketAsyncNetworkStream(app, {}): + async with ASGIWebSocketAsyncNetworkStream(app, scope): pass assert excinfo.value.code == 1011 assert excinfo.value.reason == "Error" @@ -143,10 +173,16 @@ async def test_http(self, test_app: Starlette): ], ) async def test_websocket( - self, url: str, headers: Dict[str, Any], test_app: Starlette + self, + url: str, + headers: Dict[str, Any], + test_app: Starlette, + websocket_request_headers: Dict[str, str], ): async with ASGIWebSocketTransport(app=test_app) as transport: - request = httpx.Request("GET", url, headers=headers) + request = httpx.Request( + "GET", url, headers={**websocket_request_headers, **headers} + ) response = await transport.handle_async_request(request) assert response.status_code == 101 @@ -158,7 +194,7 @@ async def test_websocket( @pytest.mark.asyncio async def test_subprotocol_support(): async def websocket_endpoint(websocket: WebSocket): - await websocket.accept() + await websocket.accept("custom_protocol") assert websocket.scope.get("subprotocols") == ["custom_protocol"] await websocket.send_text("SERVER_MESSAGE") await websocket.close() @@ -174,3 +210,4 @@ async def websocket_endpoint(websocket: WebSocket): "ws://localhost:8000/ws", client, subprotocols=["custom_protocol"] ) as ws: await ws.receive_text() + assert ws.subprotocol == "custom_protocol" From c4e1b6504d299803f61e653a454605a0b51a7f45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 4 Dec 2023 09:39:31 +0100 Subject: [PATCH 067/108] Fix pytest-asyncio warning --- tests/httpx2/websockets/conftest.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index dd5127a3..eb4197c6 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import pathlib import queue @@ -13,13 +12,6 @@ from starlette.routing import WebSocketRoute -@pytest.fixture(scope="session") -def event_loop(): - loop = asyncio.new_event_loop() - yield loop - loop.close() - - @pytest.fixture def on_receive_message(): return MagicMock() From 08123464bf20e0fafcbb1cb4e368339f4ad34074 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 4 Dec 2023 09:40:55 +0100 Subject: [PATCH 068/108] =?UTF-8?q?Bump=20version=200.4.2=20=E2=86=92=200.?= =?UTF-8?q?4.3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes & improvements ------------------------ * Fix #57: compatibility with `httpx>=0.25.2` --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index bd1383ea..cf301e3d 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.2" +__version__ = "0.4.3" from httpx_ws._api import ( AsyncWebSocketSession, From c4f053d9d5f9f162da38bd4725d751a547438876 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 8 Feb 2024 17:20:43 +0100 Subject: [PATCH 069/108] Implement AsyncWebSocketSession with AnyIO --- src/httpx2/httpx2/_websockets/_api.py | 109 ++++++++++----------- src/httpx2/httpx2/_websockets/_ping.py | 9 +- src/httpx2/httpx2/_websockets/transport.py | 2 +- tests/httpx2/websockets/conftest.py | 2 +- tests/httpx2/websockets/test_api.py | 109 ++++++++++----------- tests/httpx2/websockets/test_transport.py | 6 +- 6 files changed, 115 insertions(+), 122 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index ded6b87b..fb9a4601 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -1,4 +1,3 @@ -import asyncio import base64 import concurrent.futures import contextlib @@ -8,6 +7,7 @@ import threading import typing +import anyio import httpcore import httpx import wsproto @@ -580,27 +580,41 @@ def __init__( self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) self.subprotocol = subprotocol - self._events: asyncio.Queue[ + self._send_event, self._receive_event = anyio.create_memory_object_stream[ typing.Union[wsproto.events.Event, HTTPXWSException] - ] = asyncio.Queue(queue_size) + ]() self._ping_manager = AsyncPingManager() + self._should_close = anyio.Event() - self._should_close = asyncio.Event() + self._max_message_size_bytes = max_message_size_bytes + self._queue_size = queue_size + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds - self._background_receive_task = asyncio.create_task( - self._background_receive(max_message_size_bytes) - ) + async def __aenter__(self): + self._exit_stack = contextlib.AsyncExitStack() + self._background_task_group = anyio.create_task_group() + await self._exit_stack.enter_async_context(self._background_task_group) - self._background_keepalive_ping_task: typing.Optional[asyncio.Task] = None - if keepalive_ping_interval_seconds is not None: - self._background_keepalive_ping_task = asyncio.create_task( - self._background_keepalive_ping( - keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds - ) + self._background_task_group.start_soon( + self._background_receive, self._max_message_size_bytes + ) + if self._keepalive_ping_interval_seconds is not None: + self._background_task_group.start_soon( + self._background_keepalive_ping, + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, ) - async def ping(self, payload: bytes = b"") -> asyncio.Event: + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.close() + self._background_task_group.cancel_scope.cancel() + await self._exit_stack.aclose() + + async def ping(self, payload: bytes = b"") -> anyio.Event: """ Send a Ping message. @@ -736,7 +750,7 @@ async def receive( A raw [wsproto.events.Event][wsproto.events.Event]. Raises: - asyncio.TimeoutError: No event was received before the timeout delay. + TimeoutError: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. WebSocketNetworkError: A network error occured. @@ -752,12 +766,13 @@ async def receive( try: event = await ws.receive(timeout=2.) - except asyncio.TimeoutError: + except TimeoutError: print("No event received.") except WebSocketDisconnect: print("Connection closed") """ - event = await asyncio.wait_for(self._events.get(), timeout) + with anyio.fail_after(timeout): + event = await self._receive_event.receive() if isinstance(event, HTTPXWSException): raise event if isinstance(event, wsproto.events.CloseConnection): @@ -777,7 +792,7 @@ async def receive_text(self, timeout: typing.Optional[float] = None) -> str: Text data. Raises: - asyncio.TimeoutError: No event was received before the timeout delay. + TimeoutError: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event was not a text message. @@ -794,7 +809,7 @@ async def receive_text(self, timeout: typing.Optional[float] = None) -> str: try: event = await ws.receive_text(timeout=2.) - except asyncio.TimeoutError: + except TimeoutError: print("No text received.") except WebSocketDisconnect: print("Connection closed") @@ -817,7 +832,7 @@ async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: Bytes data. Raises: - asyncio.TimeoutError: No event was received before the timeout delay. + TimeoutError: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event was not a bytes message. @@ -834,7 +849,7 @@ async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: try: data = await ws.receive_bytes(timeout=2.) - except asyncio.TimeoutError: + except TimeoutError: print("No data received.") except WebSocketDisconnect: print("Connection closed") @@ -863,7 +878,7 @@ async def receive_json( Parsed JSON data. Raises: - asyncio.TimeoutError: No event was received before the timeout delay. + TimeoutError: No event was received before the timeout delay. WebSocketDisconnect: The server closed the websocket. WebSocketNetworkError: A network error occured. WebSocketInvalidTypeReceived: The received event @@ -881,7 +896,7 @@ async def receive_json( try: data = await ws.receive_json(timeout=2.) - except asyncio.TimeoutError: + except TimeoutError: print("No data received.") except WebSocketDisconnect: print("Connection closed") @@ -942,9 +957,7 @@ async def _background_receive(self, max_bytes: int) -> None: partial_message_buffer: typing.Union[str, bytes, None] = None try: while not self._should_close.is_set(): - data = await self._wait_until_closed( - self.stream.read(max_bytes=max_bytes) - ) + data = await self.stream.read(max_bytes=max_bytes) self.connection.receive_data(data) for event in self.connection.events(): if isinstance(event, wsproto.events.Ping): @@ -965,7 +978,7 @@ async def _background_receive(self, max_bytes: int) -> None: partial_message_buffer += event.data # Finished message but no buffer: just emit the event elif partial_message_buffer is None: - await self._events.put(event) + await self._send_event.send(event) # Finished message with buffer: emit the full event else: event_type = type(event) @@ -973,13 +986,13 @@ async def _background_receive(self, max_bytes: int) -> None: partial_message_buffer + event.data ) partial_message_buffer = None - await self._events.put(full_message_event) + await self._send_event.send(full_message_event) continue - await self._events.put(event) + await self._send_event.send(event) except (httpcore.ReadError, httpcore.WriteError): await self.close(CloseReason.INTERNAL_ERROR, "Stream error") - await self._events.put(WebSocketNetworkError()) - except ShouldClose: + await self._send_event.send(WebSocketNetworkError()) + except anyio.get_cancelled_exc_class(): pass async def _background_keepalive_ping( @@ -987,35 +1000,20 @@ async def _background_keepalive_ping( ) -> None: try: while not self._should_close.is_set(): - await self._wait_until_closed(asyncio.sleep(interval_seconds)) + await anyio.sleep(interval_seconds) pong_callback = await self.ping() if timeout_seconds is not None: try: - await self._wait_until_closed( - asyncio.wait_for(pong_callback.wait(), timeout_seconds) - ) - except asyncio.TimeoutError: + with anyio.fail_after(timeout_seconds): + await pong_callback.wait() + except TimeoutError: await self.close( CloseReason.INTERNAL_ERROR, "Keepalive ping timeout" ) - await self._events.put(WebSocketNetworkError()) - except ShouldClose: + await self._send_event.send(WebSocketNetworkError()) + except anyio.get_cancelled_exc_class(): pass - async def _wait_until_closed( - self, coro: typing.Coroutine[typing.Any, typing.Any, TaskResult] - ) -> TaskResult: - wait_close_task = asyncio.create_task(self._should_close.wait()) - todo_task = asyncio.create_task(coro) - done, pending = await asyncio.wait( - {todo_task, wait_close_task}, return_when=asyncio.FIRST_COMPLETED - ) - for task in pending: - task.cancel() - if wait_close_task in done and todo_task not in done: - raise ShouldClose() - return todo_task.result() - def _get_headers( subprotocols: typing.Optional[typing.List[str]], @@ -1192,16 +1190,15 @@ async def _aconnect_ws( subprotocol = response.headers.get("sec-websocket-protocol") - session = AsyncWebSocketSession( + async with AsyncWebSocketSession( response.extensions["network_stream"], max_message_size_bytes=max_message_size_bytes, queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, subprotocol=subprotocol, - ) - yield session - await session.close() + ) as session: + yield session @contextlib.asynccontextmanager diff --git a/src/httpx2/httpx2/_websockets/_ping.py b/src/httpx2/httpx2/_websockets/_ping.py index c7ff025e..bd3c0a07 100644 --- a/src/httpx2/httpx2/_websockets/_ping.py +++ b/src/httpx2/httpx2/_websockets/_ping.py @@ -1,8 +1,9 @@ -import asyncio import secrets import threading import typing +import anyio + class PingManagerBase: def _generate_id(self) -> bytes: @@ -28,13 +29,13 @@ def ack(self, ping_id: typing.Union[bytes, bytearray]): class AsyncPingManager(PingManagerBase): def __init__(self) -> None: - self._pings: typing.Dict[bytes, asyncio.Event] = {} + self._pings: typing.Dict[bytes, anyio.Event] = {} def create( self, ping_id: typing.Optional[bytes] = None - ) -> typing.Tuple[bytes, asyncio.Event]: + ) -> typing.Tuple[bytes, anyio.Event]: ping_id = self._generate_id() if not ping_id else ping_id - event = asyncio.Event() + event = anyio.Event() self._pings[ping_id] = event return ping_id, event diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 92c38081..60653eee 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -42,7 +42,7 @@ def __init__(self, app: ASGIApp, scope: Scope) -> None: self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) async def __aenter__( - self + self, ) -> typing.Tuple["ASGIWebSocketAsyncNetworkStream", bytes]: self.exit_stack = contextlib.ExitStack() self.portal = self.exit_stack.enter_context( diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index eb4197c6..6337d68d 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -29,7 +29,7 @@ def __call__(self, endpoint: Callable) -> ContextManager[str]: @pytest.fixture def server_factory( - websocket_implementation: Literal["wsproto", "websockets"] + websocket_implementation: Literal["wsproto", "websockets"], ) -> ServerFactoryFixture: @contextlib.contextmanager def _server_factory(endpoint: Callable): diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index c304f276..38359527 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -1,10 +1,10 @@ -import asyncio import contextlib import queue import time import typing from unittest.mock import MagicMock, call, patch +import anyio import httpcore import httpx import pytest @@ -27,7 +27,7 @@ from tests.conftest import ServerFactoryFixture -@pytest.mark.asyncio +@pytest.mark.anyio async def test_upgrade_error(): def handler(request): return httpx.Response(400) @@ -47,7 +47,7 @@ def handler(request): pass -@pytest.mark.asyncio +@pytest.mark.anyio class TestSend: async def test_send_error(self): class MockNetworkStream(NetworkStream): @@ -90,7 +90,7 @@ async def read( self, max_bytes: int, timeout: typing.Optional[float] = None ) -> bytes: while not self._should_close: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) raise httpcore.ReadError() async def write( @@ -102,10 +102,9 @@ async def aclose(self) -> None: self._should_close = True stream = AsyncMockNetworkStream() - websocket_session = AsyncWebSocketSession(stream) with pytest.raises(WebSocketNetworkError): - await websocket_session.send(wsproto.events.Ping()) - await websocket_session.close() + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.send(wsproto.events.Ping()) async def test_send( self, @@ -248,7 +247,7 @@ async def websocket_endpoint(websocket: WebSocket): ) -@pytest.mark.asyncio +@pytest.mark.anyio class TestReceive: async def test_receive_error(self): class MockNetworkStream(NetworkStream): @@ -297,15 +296,14 @@ async def aclose(self) -> None: pass stream = AsyncMockNetworkStream() - websocket_session = AsyncWebSocketSession(stream) with pytest.raises(WebSocketNetworkError): - await websocket_session.receive() - await websocket_session.close() + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.receive() async def test_receive(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await asyncio.sleep(0.1) # FIXME: see #7 + await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_text("SERVER_MESSAGE") @@ -347,7 +345,7 @@ async def test_receive_oversized_message( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await asyncio.sleep(0.1) # FIXME: see #7 + await anyio.sleep(0.1) # FIXME: see #7 method = getattr(websocket, send_method) await method(full_message) @@ -370,7 +368,11 @@ async def websocket_endpoint(websocket: WebSocket): transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: try: - async with aconnect_ws("http://socket/ws", aclient) as aws: + async with aconnect_ws( + "http://socket/ws", + aclient, + keepalive_ping_interval_seconds=None, + ) as aws: event = await aws.receive() assert isinstance(event, wsproto.events.Message) assert event.data == full_message @@ -380,7 +382,7 @@ async def websocket_endpoint(websocket: WebSocket): async def test_receive_text(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await asyncio.sleep(0.1) # FIXME: see #7 + await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_text("SERVER_MESSAGE") @@ -410,7 +412,7 @@ async def test_receive_text_invalid_type( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await asyncio.sleep(0.1) # FIXME: see #7 + await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_bytes(b"SERVER_MESSAGE") @@ -438,7 +440,7 @@ async def websocket_endpoint(websocket: WebSocket): async def test_receive_bytes(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await asyncio.sleep(0.1) # FIXME: see #7 + await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_bytes(b"SERVER_MESSAGE") @@ -468,7 +470,7 @@ async def test_receive_bytes_invalid_type( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await asyncio.sleep(0.1) # FIXME: see #7 + await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_text("SERVER_MESSAGE") @@ -493,7 +495,7 @@ async def test_receive_json( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await asyncio.sleep(0.1) # FIXME: see #7 + await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_json({"message": "SERVER_MESSAGE"}, mode=mode) @@ -519,7 +521,7 @@ async def websocket_endpoint(websocket: WebSocket): pass -@pytest.mark.asyncio +@pytest.mark.anyio class TestReceivePing: async def test_receive_ping(self): class MockNetworkStream(NetworkStream): @@ -551,7 +553,7 @@ def close(self) -> None: stream = MockNetworkStream() websocket_session = WebSocketSession(stream) - await asyncio.sleep(0.1) + await anyio.sleep(0.1) websocket_session.close() received_events = list(stream.connection.events()) @@ -589,9 +591,8 @@ async def aclose(self) -> None: pass stream = MockAsyncNetworkStream() - websocket_session = AsyncWebSocketSession(stream) - await asyncio.sleep(0.1) - await websocket_session.close() + async with AsyncWebSocketSession(stream): + await anyio.sleep(0.1) received_events = list(stream.connection.events()) assert received_events == [ @@ -600,7 +601,7 @@ async def aclose(self) -> None: ] -@pytest.mark.asyncio +@pytest.mark.anyio class TestKeepalivePing: async def test_keepalive_ping(self): class MockNetworkStream(NetworkStream): @@ -643,7 +644,7 @@ def close(self) -> None: keepalive_ping_interval_seconds=0.1, keepalive_ping_timeout_seconds=0.1, ) - await asyncio.sleep(0.2) + await anyio.sleep(0.2) websocket_session.close() assert stream.ping_received >= 1 @@ -690,20 +691,21 @@ def __init__(self) -> None: self._should_close = False self.ping_received = 0 self.ping_answered = 0 - self.events_to_send: asyncio.Queue[ - wsproto.events.Event - ] = asyncio.Queue() + ( + self.send_events, + self.receive_events, + ) = anyio.create_memory_object_stream[wsproto.events.Event]() async def read( self, max_bytes: int, timeout: typing.Optional[float] = None ) -> bytes: while not self._should_close: try: - event = self.events_to_send.get_nowait() + event = self.receive_events.receive_nowait() self.ping_answered += 1 return self.connection.send(event) - except asyncio.QueueEmpty: - await asyncio.sleep(0.1) + except anyio.WouldBlock: + await anyio.sleep(0.1) raise httpcore.ReadError() async def write( @@ -713,19 +715,18 @@ async def write( for event in self.connection.events(): if isinstance(event, wsproto.events.Ping): self.ping_received += 1 - await self.events_to_send.put(event.response()) + await self.send_events.send(event.response()) async def aclose(self) -> None: self._should_close = True stream = MockAsyncNetworkStream() - websocket_session = AsyncWebSocketSession( + async with AsyncWebSocketSession( stream, keepalive_ping_interval_seconds=0.1, keepalive_ping_timeout_seconds=0.1, - ) - await asyncio.sleep(0.3) - await websocket_session.close() + ): + await anyio.sleep(0.3) assert stream.ping_received >= 1 assert stream.ping_answered >= 1 @@ -742,7 +743,7 @@ async def read( self, max_bytes: int, timeout: typing.Optional[float] = None ) -> bytes: while not self._should_close: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) raise httpcore.ReadError() async def write( @@ -755,15 +756,15 @@ async def aclose(self) -> None: stream = MockAsyncNetworkStream() with pytest.raises(WebSocketNetworkError): - websocket_session = AsyncWebSocketSession( + async with AsyncWebSocketSession( stream, keepalive_ping_interval_seconds=0.1, keepalive_ping_timeout_seconds=0.1, - ) - await websocket_session.receive() + ) as websocket_session: + await websocket_session.receive() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_ping_pong(server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() @@ -784,11 +785,11 @@ async def websocket_endpoint(websocket: WebSocket): ) as aclient: async with aconnect_ws("http://socket/ws", aclient) as aws: aping_callback = await aws.ping() - aresult = await aping_callback.wait() - assert aresult is True + await aping_callback.wait() + assert aping_callback.is_set() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_send_close(server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() @@ -809,11 +810,11 @@ async def websocket_endpoint(websocket: WebSocket): pass -@pytest.mark.asyncio +@pytest.mark.anyio async def test_receive_close(server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await asyncio.sleep(0.1) # FIXME: see #7 + await anyio.sleep(0.1) # FIXME: see #7 await websocket.close() with server_factory(websocket_endpoint) as socket: @@ -830,7 +831,7 @@ async def websocket_endpoint(websocket: WebSocket): await aws.receive() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_default_httpx_client(): mock_context = contextlib.ExitStack() with patch( @@ -855,7 +856,7 @@ async def test_default_httpx_client(): assert httpx_client.is_closed -@pytest.mark.asyncio +@pytest.mark.anyio async def test_subprotocol(): def handler(request): assert ( @@ -866,7 +867,7 @@ def handler(request): return httpx.Response( 101, headers={"sec-websocket-protocol": "custom_protocol"}, - extensions={"network_stream": MagicMock()}, + extensions={"network_stream": MagicMock(spec=NetworkStream)}, ) def async_handler(request): @@ -875,16 +876,10 @@ def async_handler(request): == "custom_protocol, unsupported_protocol" ) - network_stream = MagicMock() - async_method_return_value = asyncio.Future() - async_method_return_value.set_result(MagicMock()) - network_stream.write.return_value = async_method_return_value - network_stream.aclose.return_value = async_method_return_value - return httpx.Response( 101, headers={"sec-websocket-protocol": "custom_protocol"}, - extensions={"network_stream": network_stream}, + extensions={"network_stream": MagicMock(spec=AsyncNetworkStream)}, ) with httpx.Client( diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index be850729..22822ab4 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -47,7 +47,7 @@ def scope(websocket_request_headers: Dict[str, str]) -> Scope: } -@pytest.mark.asyncio +@pytest.mark.anyio class TestASGIWebSocketAsyncNetworkStream: async def test_write(self, scope: Scope): received_messages = [] @@ -156,7 +156,7 @@ async def websocket_endpoint(websocket: WebSocket): return Starlette(routes=routes) -@pytest.mark.asyncio +@pytest.mark.anyio class TestASGIWebSocketTransport: async def test_http(self, test_app: Starlette): async with ASGIWebSocketTransport(app=test_app) as transport: @@ -191,7 +191,7 @@ async def test_websocket( ) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_subprotocol_support(): async def websocket_endpoint(websocket: WebSocket): await websocket.accept("custom_protocol") From 35fd6424d7cffc9c1d8ac93958a7503cc4569bf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 8 Feb 2024 17:31:36 +0100 Subject: [PATCH 070/108] Make WebSocketSession a context manager to mirror the async implementation --- src/httpx2/httpx2/_websockets/_api.py | 31 ++++++++++++++++++++------- tests/httpx2/websockets/test_api.py | 28 +++++++++++------------- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index fb9a4601..d42e764f 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -115,22 +115,38 @@ def __init__( ] = queue.Queue(queue_size) self._ping_manager = PingManager() - self._should_close = threading.Event() + self._max_message_size_bytes = max_message_size_bytes + self._queue_size = queue_size + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + + def __enter__(self) -> "WebSocketSession": self._background_receive_task = threading.Thread( - target=self._background_receive, args=(max_message_size_bytes,) + target=self._background_receive, args=(self._max_message_size_bytes,) ) self._background_receive_task.start() self._background_keepalive_ping_task: typing.Optional[threading.Thread] = None - if keepalive_ping_interval_seconds is not None: + if self._keepalive_ping_interval_seconds is not None: self._background_keepalive_ping_task = threading.Thread( target=self._background_keepalive_ping, - args=(keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds), + args=( + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, + ), ) self._background_keepalive_ping_task.start() + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + self._background_receive_task.join() + if self._background_keepalive_ping_task is not None: + self._background_keepalive_ping_task.join() + def ping(self, payload: bytes = b"") -> threading.Event: """ Send a Ping message. @@ -1054,16 +1070,15 @@ def _connect_ws( subprotocol = response.headers.get("sec-websocket-protocol") - session = WebSocketSession( + with WebSocketSession( response.extensions["network_stream"], max_message_size_bytes=max_message_size_bytes, queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, subprotocol=subprotocol, - ) - yield session - session.close() + ) as session: + yield session @contextlib.contextmanager diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 38359527..d0068ba0 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -73,10 +73,9 @@ def close(self) -> None: self._should_close = True stream = MockNetworkStream() - websocket_session = WebSocketSession(stream) with pytest.raises(WebSocketNetworkError): - websocket_session.send(wsproto.events.Ping()) - websocket_session.close() + with WebSocketSession(stream) as websocket_session: + websocket_session.send(wsproto.events.Ping()) async def test_async_send_error(self): class AsyncMockNetworkStream(AsyncNetworkStream): @@ -270,10 +269,9 @@ def close(self) -> None: pass stream = MockNetworkStream() - websocket_session = WebSocketSession(stream) with pytest.raises(WebSocketNetworkError): - websocket_session.receive() - websocket_session.close() + with WebSocketSession(stream) as websocket_session: + websocket_session.receive() async def test_async_receive_error(self): class AsyncMockNetworkStream(AsyncNetworkStream): @@ -552,9 +550,8 @@ def close(self) -> None: pass stream = MockNetworkStream() - websocket_session = WebSocketSession(stream) - await anyio.sleep(0.1) - websocket_session.close() + with WebSocketSession(stream): + await anyio.sleep(0.1) received_events = list(stream.connection.events()) assert received_events == [ @@ -639,13 +636,12 @@ def close(self) -> None: self._should_close = True stream = MockNetworkStream() - websocket_session = WebSocketSession( + with WebSocketSession( stream, keepalive_ping_interval_seconds=0.1, keepalive_ping_timeout_seconds=0.1, - ) - await anyio.sleep(0.2) - websocket_session.close() + ): + await anyio.sleep(0.2) assert stream.ping_received >= 1 assert stream.ping_answered >= 1 @@ -675,12 +671,12 @@ def close(self) -> None: stream = MockNetworkStream() with pytest.raises(WebSocketNetworkError): - websocket_session = WebSocketSession( + with WebSocketSession( stream, keepalive_ping_interval_seconds=0.1, keepalive_ping_timeout_seconds=0.1, - ) - websocket_session.receive() + ) as websocket_session: + websocket_session.receive() async def test_async_keepalive_ping(self): class MockAsyncNetworkStream(AsyncNetworkStream): From 4592cf823d32b41777a9570bb918d52b5de499c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 9 Feb 2024 09:43:39 +0100 Subject: [PATCH 071/108] Add flaky marker on test_async_keepalive_ping --- tests/httpx2/websockets/test_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index d0068ba0..fcb5936a 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -678,6 +678,7 @@ def close(self) -> None: ) as websocket_session: websocket_session.receive() + @pytest.mark.flaky async def test_async_keepalive_ping(self): class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: From 45f521bdf581bc5f7d8411d3f12668973e7f816d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 9 Feb 2024 17:06:56 +0100 Subject: [PATCH 072/108] Update docs --- src/httpx2/httpx2/_websockets/_api.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index d42e764f..10993d17 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -83,7 +83,7 @@ class ShouldClose(Exception): class WebSocketSession: """ - Sync helper representing an opened WebSocket session. + Sync context manager representing an opened WebSocket session. Attributes: subprotocol (typing.Optional[str]): @@ -446,6 +446,8 @@ def close(self, code: int = 1000, reason: typing.Optional[str] = None): Internally, it'll send the [CloseConnection][wsproto.events.CloseConnection] event. + *This method is automatically called when exiting the context manager.* + Args: code: The integer close code to indicate why the connection has closed. @@ -569,7 +571,7 @@ def _wait_until_closed( class AsyncWebSocketSession: """ - Async helper representing an opened WebSocket session. + Async context manager representing an opened WebSocket session. Attributes: subprotocol (typing.Optional[str]): @@ -932,6 +934,8 @@ async def close(self, code: int = 1000, reason: typing.Optional[str] = None): Internally, it'll send the [CloseConnection][wsproto.events.CloseConnection] event. + *This method is automatically called when exiting the context manager.* + Args: code: The integer close code to indicate why the connection has closed. From 8452893200d1be217ac1032a3360781dad3d7716 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 9 Feb 2024 17:07:04 +0100 Subject: [PATCH 073/108] =?UTF-8?q?Bump=20version=200.4.3=20=E2=86=92=200.?= =?UTF-8?q?5.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New features ------------ * `asyncio` and [Trio](https://trio.readthedocs.io/) support through [AnyIO](https://anyio.readthedocs.io/) Breaking changes ---------------- * `WebSocketSession` and `AsyncWebSocketSession` are now context managers. If you were using them directly instead of relying on `connect_ws` and `aconnect_ws`, you'll have to adapt your code accordingly: ```py with WebSocketSession(...) as session: ... async with AsyncWebSocketSession(...) as session: ... ``` * `AsyncWebSocketSession.receive_*` methods may now raise `TimeoutError` instead of `asyncio.TimeoutError`: **Before** ```py try: event = await ws.receive(timeout=2.) except asyncio.TimeoutError: print("No event received.") except WebSocketDisconnect: print("Connection closed") ``` **After** ```py try: event = await ws.receive(timeout=2.) except TimeoutError: print("No event received.") except WebSocketDisconnect: print("Connection closed") ``` --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index cf301e3d..c3565c20 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.3" +__version__ = "0.5.0" from httpx_ws._api import ( AsyncWebSocketSession, From 4306a2072063b2939456fb32a7cabd0bf435de0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 22 Feb 2024 09:06:13 +0100 Subject: [PATCH 074/108] Upgrade httpcore>=1.0.4 and remove connection hang workaround in tests Fix #7 --- tests/httpx2/websockets/test_api.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index fcb5936a..ea3bb4f2 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -301,7 +301,6 @@ async def aclose(self) -> None: async def test_receive(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_text("SERVER_MESSAGE") @@ -343,7 +342,6 @@ async def test_receive_oversized_message( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await anyio.sleep(0.1) # FIXME: see #7 method = getattr(websocket, send_method) await method(full_message) @@ -380,7 +378,6 @@ async def websocket_endpoint(websocket: WebSocket): async def test_receive_text(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_text("SERVER_MESSAGE") @@ -410,7 +407,6 @@ async def test_receive_text_invalid_type( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_bytes(b"SERVER_MESSAGE") @@ -438,7 +434,6 @@ async def websocket_endpoint(websocket: WebSocket): async def test_receive_bytes(self, server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_bytes(b"SERVER_MESSAGE") @@ -468,7 +463,6 @@ async def test_receive_bytes_invalid_type( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_text("SERVER_MESSAGE") @@ -493,7 +487,6 @@ async def test_receive_json( ): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await anyio.sleep(0.1) # FIXME: see #7 await websocket.send_json({"message": "SERVER_MESSAGE"}, mode=mode) @@ -811,7 +804,6 @@ async def websocket_endpoint(websocket: WebSocket): async def test_receive_close(server_factory: ServerFactoryFixture): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() - await anyio.sleep(0.1) # FIXME: see #7 await websocket.close() with server_factory(websocket_endpoint) as socket: From d2720705557d615033eaa4f44372833a0998fe7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 22 Feb 2024 10:56:59 +0100 Subject: [PATCH 075/108] Increase flakiness of test_async_keepalive_ping --- tests/httpx2/websockets/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index ea3bb4f2..ff963ad1 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -671,7 +671,7 @@ def close(self) -> None: ) as websocket_session: websocket_session.receive() - @pytest.mark.flaky + @pytest.mark.flaky(max_runs=5, min_passes=1) async def test_async_keepalive_ping(self): class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: From 3c03a6754a8f429dd4eddbbf73e5cf4aa51771d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 22 Feb 2024 15:10:03 +0100 Subject: [PATCH 076/108] Disable automatic keepalive ping when using ASGI transport --- src/httpx2/httpx2/_websockets/__init__.py | 12 ++-- src/httpx2/httpx2/_websockets/_api.py | 71 +++++--------------- src/httpx2/httpx2/_websockets/_exceptions.py | 55 +++++++++++++++ src/httpx2/httpx2/_websockets/transport.py | 2 +- tests/httpx2/websockets/test_transport.py | 18 +++++ 5 files changed, 98 insertions(+), 60 deletions(-) create mode 100644 src/httpx2/httpx2/_websockets/_exceptions.py diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index c3565c20..64a6d5d8 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,16 +1,18 @@ __version__ = "0.5.0" -from httpx_ws._api import ( +from ._api import ( AsyncWebSocketSession, - HTTPXWSException, JSONMode, + WebSocketSession, + aconnect_ws, + connect_ws, +) +from ._exceptions import ( + HTTPXWSException, WebSocketDisconnect, WebSocketInvalidTypeReceived, WebSocketNetworkError, - WebSocketSession, WebSocketUpgradeError, - aconnect_ws, - connect_ws, ) __all__ = [ diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 10993d17..f1c34986 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -14,7 +14,15 @@ from httpcore import AsyncNetworkStream, NetworkStream from wsproto.frame_protocol import CloseReason -from httpx_ws._ping import AsyncPingManager, PingManager +from ._exceptions import ( + HTTPXWSException, + WebSocketDisconnect, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, +) +from ._ping import AsyncPingManager, PingManager +from .transport import ASGIWebSocketAsyncNetworkStream JSONMode = typing.Literal["text", "binary"] TaskFunction = typing.TypeVar("TaskFunction") @@ -26,57 +34,6 @@ DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 -class HTTPXWSException(Exception): - """ - Base exception class for HTTPX WS. - """ - - pass - - -class WebSocketUpgradeError(HTTPXWSException): - """ - Raised when the initial connection didn't correctly upgrade to a WebSocket session. - """ - - def __init__(self, response: httpx.Response) -> None: - self.response = response - - -class WebSocketDisconnect(HTTPXWSException): - """ - Raised when the server closed the WebSocket session. - - Args: - code: - The integer close code to indicate why the connection has closed. - reason: - Additional reasoning for why the connection has closed. - """ - - def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: - self.code = code - self.reason = reason or "" - - -class WebSocketInvalidTypeReceived(HTTPXWSException): - """ - Raised when a event is not of the expected type. - """ - - def __init__(self, event: wsproto.events.Event) -> None: - self.event = event - - -class WebSocketNetworkError(HTTPXWSException): - """ - Raised when a network error occured, - typically if the underlying stream has closed or timeout. - """ - - pass - - class ShouldClose(Exception): pass @@ -607,8 +564,14 @@ def __init__( self._max_message_size_bytes = max_message_size_bytes self._queue_size = queue_size - self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds - self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + + # Always disable keepalive ping when emulating ASGI + if isinstance(stream, ASGIWebSocketAsyncNetworkStream): + self._keepalive_ping_interval_seconds = None + self._keepalive_ping_timeout_seconds = None + else: + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds async def __aenter__(self): self._exit_stack = contextlib.AsyncExitStack() diff --git a/src/httpx2/httpx2/_websockets/_exceptions.py b/src/httpx2/httpx2/_websockets/_exceptions.py new file mode 100644 index 00000000..0facbf82 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_exceptions.py @@ -0,0 +1,55 @@ +import typing + +import httpx +import wsproto + + +class HTTPXWSException(Exception): + """ + Base exception class for HTTPX WS. + """ + + pass + + +class WebSocketUpgradeError(HTTPXWSException): + """ + Raised when the initial connection didn't correctly upgrade to a WebSocket session. + """ + + def __init__(self, response: httpx.Response) -> None: + self.response = response + + +class WebSocketDisconnect(HTTPXWSException): + """ + Raised when the server closed the WebSocket session. + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + """ + + def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: + self.code = code + self.reason = reason or "" + + +class WebSocketInvalidTypeReceived(HTTPXWSException): + """ + Raised when a event is not of the expected type. + """ + + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + +class WebSocketNetworkError(HTTPXWSException): + """ + Raised when a network error occured, + typically if the underlying stream has closed or timeout. + """ + + pass diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 60653eee..3ec3291f 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -9,7 +9,7 @@ from httpx import ASGITransport, AsyncByteStream, Request, Response from wsproto.frame_protocol import CloseReason -from httpx_ws._api import WebSocketDisconnect +from ._exceptions import WebSocketDisconnect Scope = typing.Dict[str, typing.Any] Message = typing.Dict[str, typing.Any] diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index 22822ab4..032b8402 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -211,3 +211,21 @@ async def websocket_endpoint(websocket: WebSocket): ) as ws: await ws.receive_text() assert ws.subprotocol == "custom_protocol" + + +@pytest.mark.anyio +async def test_keepalive_ping_disabled(): + async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.receive_text() + await websocket.close() + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + async with aconnect_ws("ws://localhost:8000/ws", client) as ws: + assert ws._keepalive_ping_interval_seconds is None From 7aa4a373dc251c673956f321c47015334fc82f73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 22 Feb 2024 16:45:47 +0100 Subject: [PATCH 077/108] =?UTF-8?q?Bump=20version=200.5.0=20=E2=86=92=200.?= =?UTF-8?q?5.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes and improvements -------------------------- * Always disable automatic keepalive ping when using ASGI transport. Thanks @dmontagu and @Kludex 🎉 * Bump dependencies: * `httpcore>=1.0.4` - Solves #7, thanks to @tomchristie and @MtkN1 🎉 --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 64a6d5d8..ea85b4e9 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.0" +__version__ = "0.5.1" from ._api import ( AsyncWebSocketSession, From 1a5f10b8f876bbaa84f6fbd0096699cec7260906 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 22 Feb 2024 17:02:38 +0100 Subject: [PATCH 078/108] Fix test_receive_oversized_message unit test --- tests/httpx2/websockets/test_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index ff963ad1..3132bf55 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -330,8 +330,8 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.parametrize( "full_message,send_method", [ - (b"A" * 1024 * 1024, "send_bytes"), - ("A" * 1024 * 1024, "send_text"), + (b"A" * 1024 * 4, "send_bytes"), + ("A" * 1024 * 4, "send_text"), ], ) async def test_receive_oversized_message( From 87d25f4e148022cc2885cd24be7479afcfe5af1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 19 Mar 2024 09:14:54 +0100 Subject: [PATCH 079/108] Upgrade and fix Ruff linting --- tests/httpx2/websockets/conftest.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index 6337d68d..88bfe824 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -23,8 +23,7 @@ def websocket_implementation(request) -> Literal["wsproto", "websockets"]: class ServerFactoryFixture(Protocol): - def __call__(self, endpoint: Callable) -> ContextManager[str]: - ... + def __call__(self, endpoint: Callable) -> ContextManager[str]: ... @pytest.fixture From 2d99bf96c103c79f5f789f03ea2e1b149669a744 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 19 Mar 2024 09:18:12 +0100 Subject: [PATCH 080/108] =?UTF-8?q?Bump=20version=200.5.1=20=E2=86=92=200.?= =?UTF-8?q?5.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes --------- * Set `anyio` dependency lower bound version to `>4` --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index ea85b4e9..faee150f 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.1" +__version__ = "0.5.2" from ._api import ( AsyncWebSocketSession, From 8b7d0e7bd81bfed4574af68f53e193ad38937077 Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Mon, 30 Oct 2023 13:46:48 +0800 Subject: [PATCH 081/108] Add `response` attribute for `WebSocketSession` --- src/httpx2/httpx2/_websockets/_api.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index f1c34986..94b26bd9 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -45,9 +45,12 @@ class WebSocketSession: Attributes: subprotocol (typing.Optional[str]): Optional protocol that has been accepted by the server. + response (typing.Optional[httpx.Response]): + The response received after completing the WebSocket handshake. """ subprotocol: typing.Optional[str] + response: typing.Optional[httpx.Response] def __init__( self, @@ -62,10 +65,12 @@ def __init__( float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, subprotocol: typing.Optional[str] = None, + response: typing.Optional[httpx.Response] = None, ) -> None: self.stream = stream self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) self.subprotocol = subprotocol + self.response = response self._events: queue.Queue[ typing.Union[wsproto.events.Event, HTTPXWSException] @@ -533,6 +538,8 @@ class AsyncWebSocketSession: Attributes: subprotocol (typing.Optional[str]): Optional protocol that has been accepted by the server. + response (typing.Optional[httpx.Response]): + The response received after completing the WebSocket handshake. """ subprotocol: typing.Optional[str] @@ -550,10 +557,12 @@ def __init__( float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, subprotocol: typing.Optional[str] = None, + response: typing.Optional[httpx.Response] = None, ) -> None: self.stream = stream self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) self.subprotocol = subprotocol + self.response = response self._send_event, self._receive_event = anyio.create_memory_object_stream[ typing.Union[wsproto.events.Event, HTTPXWSException] @@ -1044,6 +1053,7 @@ def _connect_ws( keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, subprotocol=subprotocol, + response=response, ) as session: yield session @@ -1179,6 +1189,7 @@ async def _aconnect_ws( keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, subprotocol=subprotocol, + response=response, ) as session: yield session From afb0802d7e181ae96550fae1c6c30cd00f7357c6 Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Mon, 30 Oct 2023 15:41:42 +0800 Subject: [PATCH 082/108] Add test for attribute `response` - Add test for attribute `response` - Add missing type annotations --- src/httpx2/httpx2/_websockets/_api.py | 1 + tests/httpx2/websockets/test_api.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 94b26bd9..4486b012 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -543,6 +543,7 @@ class AsyncWebSocketSession: """ subprotocol: typing.Optional[str] + response: typing.Optional[httpx.Response] def __init__( self, diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 3132bf55..0bc6951e 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -846,7 +846,7 @@ async def test_default_httpx_client(): @pytest.mark.anyio -async def test_subprotocol(): +async def test_subprotocol_and_response(): def handler(request): assert ( request.headers["sec-websocket-protocol"] @@ -879,7 +879,9 @@ def async_handler(request): client, subprotocols=["custom_protocol", "unsupported_protocol"], ) as ws: + assert isinstance(ws.response, httpx.Response) assert ws.subprotocol == "custom_protocol" + assert ws.response.headers["sec-websocket-protocol"] == ws.subprotocol async with httpx.AsyncClient( base_url="http://localhost:8000", transport=httpx.MockTransport(async_handler) @@ -889,4 +891,6 @@ def async_handler(request): client, subprotocols=["custom_protocol", "unsupported_protocol"], ) as aws: + assert isinstance(aws.response, httpx.Response) assert aws.subprotocol == "custom_protocol" + assert aws.response.headers["sec-websocket-protocol"] == aws.subprotocol From 3b25adff7cba43fe9cf74e31d94c53a59a157fed Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Wed, 3 Apr 2024 18:31:37 +0800 Subject: [PATCH 083/108] parse subprotocol in `WebSocketSession`, rather than in `connect_ws` BREAKING CHANGE: the `subprotocol` parameter of `WebSocketSession` has been removed. --- src/httpx2/httpx2/_websockets/_api.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 4486b012..2970531e 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -46,7 +46,7 @@ class WebSocketSession: subprotocol (typing.Optional[str]): Optional protocol that has been accepted by the server. response (typing.Optional[httpx.Response]): - The response received after completing the WebSocket handshake. + The webSocket handshake response. """ subprotocol: typing.Optional[str] @@ -64,13 +64,15 @@ def __init__( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocol: typing.Optional[str] = None, response: typing.Optional[httpx.Response] = None, ) -> None: self.stream = stream self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) - self.subprotocol = subprotocol self.response = response + if self.response is not None: + self.subprotocol = self.response.headers.get("sec-websocket-protocol") + else: + self.subprotocol = None self._events: queue.Queue[ typing.Union[wsproto.events.Event, HTTPXWSException] @@ -539,7 +541,7 @@ class AsyncWebSocketSession: subprotocol (typing.Optional[str]): Optional protocol that has been accepted by the server. response (typing.Optional[httpx.Response]): - The response received after completing the WebSocket handshake. + The webSocket handshake response. """ subprotocol: typing.Optional[str] @@ -557,13 +559,15 @@ def __init__( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocol: typing.Optional[str] = None, response: typing.Optional[httpx.Response] = None, ) -> None: self.stream = stream self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) - self.subprotocol = subprotocol self.response = response + if self.response is not None: + self.subprotocol = self.response.headers.get("sec-websocket-protocol") + else: + self.subprotocol = None self._send_event, self._receive_event = anyio.create_memory_object_stream[ typing.Union[wsproto.events.Event, HTTPXWSException] @@ -1045,15 +1049,12 @@ def _connect_ws( if response.status_code != 101: raise WebSocketUpgradeError(response) - subprotocol = response.headers.get("sec-websocket-protocol") - with WebSocketSession( response.extensions["network_stream"], max_message_size_bytes=max_message_size_bytes, queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - subprotocol=subprotocol, response=response, ) as session: yield session @@ -1181,15 +1182,12 @@ async def _aconnect_ws( if response.status_code != 101: raise WebSocketUpgradeError(response) - subprotocol = response.headers.get("sec-websocket-protocol") - async with AsyncWebSocketSession( response.extensions["network_stream"], max_message_size_bytes=max_message_size_bytes, queue_size=queue_size, keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - subprotocol=subprotocol, response=response, ) as session: yield session From 6fa05a45947dc0a23d3b76e2233cd61a09f8f0b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 5 Apr 2024 08:36:46 +0200 Subject: [PATCH 084/108] =?UTF-8?q?Bump=20version=200.5.2=20=E2=86=92=200.?= =?UTF-8?q?6.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Breaking changes ---------------- * [`AsyncWebSocketSession`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.AsyncWebSocketSession) and [`WebSocketSession`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.WebSocketSession) no longer accept the `subprotocol` parameter. It's automatically set from the `response` headers (see below). > [!NOTE] > If you only use the `connect_ws` and `aconnect_ws` functions, you don't need to change anything. Improvements ------------ * [`AsyncWebSocketSession`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.AsyncWebSocketSession) and [`WebSocketSession`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.WebSocketSession) now accepts the original HTTPX handshake response in parameter. Thanks @WSH032 🎉 --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index faee150f..0d7742d2 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.2" +__version__ = "0.6.0" from ._api import ( AsyncWebSocketSession, From 41ce91a2bc3e686f365b083e31bad87898b78c73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 2 May 2024 16:37:28 +0200 Subject: [PATCH 085/108] Fix typing --- src/httpx2/httpx2/_websockets/_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 2970531e..d6b4fa2a 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -522,7 +522,8 @@ def _wait_until_closed( except RuntimeError as e: raise ShouldClose() from e done, pending = concurrent.futures.wait( # type: ignore - (todo_task, wait_close_task), return_when=concurrent.futures.FIRST_COMPLETED + (todo_task, wait_close_task), # type: ignore + return_when=concurrent.futures.FIRST_COMPLETED, ) for task in pending: task.cancel() From fae27f61d500786955321952b6b8fa49666f260d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Mon, 15 Jul 2024 01:11:07 +0300 Subject: [PATCH 086/108] Don't catch cancellation exceptions Doing so without re-raising them is harmful from the asyncio uncancellation PoV, and is completely unnecessary with task groups anyway. --- src/httpx2/httpx2/_websockets/_api.py | 29 +++++++++++---------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index d6b4fa2a..ac3ec762 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -990,27 +990,22 @@ async def _background_receive(self, max_bytes: int) -> None: except (httpcore.ReadError, httpcore.WriteError): await self.close(CloseReason.INTERNAL_ERROR, "Stream error") await self._send_event.send(WebSocketNetworkError()) - except anyio.get_cancelled_exc_class(): - pass async def _background_keepalive_ping( self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None ) -> None: - try: - while not self._should_close.is_set(): - await anyio.sleep(interval_seconds) - pong_callback = await self.ping() - if timeout_seconds is not None: - try: - with anyio.fail_after(timeout_seconds): - await pong_callback.wait() - except TimeoutError: - await self.close( - CloseReason.INTERNAL_ERROR, "Keepalive ping timeout" - ) - await self._send_event.send(WebSocketNetworkError()) - except anyio.get_cancelled_exc_class(): - pass + while not self._should_close.is_set(): + await anyio.sleep(interval_seconds) + pong_callback = await self.ping() + if timeout_seconds is not None: + try: + with anyio.fail_after(timeout_seconds): + await pong_callback.wait() + except TimeoutError: + await self.close( + CloseReason.INTERNAL_ERROR, "Keepalive ping timeout" + ) + await self._send_event.send(WebSocketNetworkError()) def _get_headers( From 4ac6154230b8dd41d5c02bd435659bc9b1dee0de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 15 Jul 2024 09:45:24 +0200 Subject: [PATCH 087/108] Fix linting --- src/httpx2/httpx2/_websockets/transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 3ec3291f..e40a4347 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -48,7 +48,7 @@ async def __aenter__( self.portal = self.exit_stack.enter_context( anyio.from_thread.start_blocking_portal("asyncio") ) - _: "Future[None]" = self.portal.start_task_soon(self._run) + _: Future[None] = self.portal.start_task_soon(self._run) await self.send({"type": "websocket.connect"}) message = await self.receive() From 8f1e355ca1fad9932cc9262d3230f35f2093f6b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Mon, 15 Jul 2024 01:50:06 +0300 Subject: [PATCH 088/108] Close memory object streams at exit This avoids the ResourceWarning about unclosed memory object streams, introduced in AnyIO 4.4.0. --- src/httpx2/httpx2/_websockets/_api.py | 53 ++++++++++++++++++--------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index ac3ec762..dc31cd23 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -6,11 +6,13 @@ import secrets import threading import typing +from types import TracebackType import anyio import httpcore import httpx import wsproto +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpcore import AsyncNetworkStream, NetworkStream from wsproto.frame_protocol import CloseReason @@ -547,6 +549,12 @@ class AsyncWebSocketSession: subprotocol: typing.Optional[str] response: typing.Optional[httpx.Response] + _send_event: MemoryObjectSendStream[ + typing.Union[wsproto.events.Event, HTTPXWSException] + ] + _receive_event: MemoryObjectReceiveStream[ + typing.Union[wsproto.events.Event, HTTPXWSException] + ] def __init__( self, @@ -570,10 +578,6 @@ def __init__( else: self.subprotocol = None - self._send_event, self._receive_event = anyio.create_memory_object_stream[ - typing.Union[wsproto.events.Event, HTTPXWSException] - ]() - self._ping_manager = AsyncPingManager() self._should_close = anyio.Event() @@ -588,26 +592,39 @@ def __init__( self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds - async def __aenter__(self): - self._exit_stack = contextlib.AsyncExitStack() - self._background_task_group = anyio.create_task_group() - await self._exit_stack.enter_async_context(self._background_task_group) + async def __aenter__(self) -> "AsyncWebSocketSession": + async with contextlib.AsyncExitStack() as exit_stack: + self._send_event, self._receive_event = anyio.create_memory_object_stream[ + typing.Union[wsproto.events.Event, HTTPXWSException] + ]() + exit_stack.enter_context(self._send_event) + exit_stack.enter_context(self._receive_event) + + self._background_task_group = anyio.create_task_group() + await exit_stack.enter_async_context(self._background_task_group) - self._background_task_group.start_soon( - self._background_receive, self._max_message_size_bytes - ) - if self._keepalive_ping_interval_seconds is not None: self._background_task_group.start_soon( - self._background_keepalive_ping, - self._keepalive_ping_interval_seconds, - self._keepalive_ping_timeout_seconds, + self._background_receive, self._max_message_size_bytes ) + if self._keepalive_ping_interval_seconds is not None: + self._background_task_group.start_soon( + self._background_keepalive_ping, + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, + ) + + exit_stack.callback(self._background_task_group.cancel_scope.cancel) + exit_stack.push_async_callback(self.close) + self._exit_stack = exit_stack.pop_all() return self - async def __aexit__(self, exc_type, exc, tb): - await self.close() - self._background_task_group.cancel_scope.cancel() + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]], + exc: typing.Optional[BaseException], + tb: typing.Optional[TracebackType], + ) -> None: await self._exit_stack.aclose() async def ping(self, payload: bytes = b"") -> anyio.Event: From 6d3d4cd07d015ad1cd80d4d013205689351af44a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 8 Apr 2024 13:16:01 +0200 Subject: [PATCH 089/108] Add unit test to reproduce #70 --- tests/httpx2/websockets/test_api.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 0bc6951e..fed43267 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -780,24 +780,31 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.anyio -async def test_send_close(server_factory: ServerFactoryFixture): +async def test_send_close( + server_factory: ServerFactoryFixture, on_receive_message: MagicMock +): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() try: await websocket.receive_text() - except StarletteWebSocketDisconnect: - pass + except StarletteWebSocketDisconnect as e: + print(e) + on_receive_message(e.code, e.reason) with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - with connect_ws("http://socket/ws", client): - pass + with connect_ws("http://socket/ws", client) as ws: + ws.close(code=1001, reason="CLOSE_REASON") async with httpx.AsyncClient( transport=httpx.AsyncHTTPTransport(uds=socket) ) as aclient: - async with aconnect_ws("http://socket/ws", aclient): - pass + async with aconnect_ws("http://socket/ws", aclient) as aws: + await aws.close(code=1001, reason="CLOSE_REASON") + + on_receive_message.assert_has_calls( + [call(1001, "CLOSE_REASON"), call(1001, "CLOSE_REASON")] + ) @pytest.mark.anyio From d0f2c8e3bcb6bd2fc8af64e6c42675e15cf01fdf Mon Sep 17 00:00:00 2001 From: Rodrigo Date: Sun, 22 Sep 2024 23:17:04 -0300 Subject: [PATCH 090/108] Prevent threads to hang and pile up in WebSocketSession Previously the threads responsible for watching the stop event set signal were not being properly terminated, leading then to pile up and becoming a memory leak when using `connect_ws`. This commit fix this issue (#76). --- src/httpx2/httpx2/_websockets/_api.py | 35 ++++++++++++++-------- src/httpx2/httpx2/_websockets/transport.py | 2 +- tests/httpx2/websockets/test_api.py | 34 +++++++++++++++++++-- 3 files changed, 56 insertions(+), 15 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index dc31cd23..40c2e1ad 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -5,6 +5,7 @@ import queue import secrets import threading +import time import typing from types import TracebackType @@ -517,22 +518,32 @@ def _background_keepalive_ping( def _wait_until_closed( self, callable: typing.Callable[..., TaskResult], *args, **kwargs ) -> TaskResult: + exit_await = threading.Event() + + def wait_close() -> None: + while not exit_await.is_set() and not self._should_close.is_set(): + time.sleep(0.05) + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) try: - executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) - wait_close_task = executor.submit(self._should_close.wait) + wait_close_task = executor.submit(wait_close) todo_task = executor.submit(callable, *args, **kwargs) except RuntimeError as e: raise ShouldClose() from e - done, pending = concurrent.futures.wait( # type: ignore - (todo_task, wait_close_task), # type: ignore - return_when=concurrent.futures.FIRST_COMPLETED, - ) - for task in pending: - task.cancel() - if wait_close_task in done and todo_task not in done: - raise ShouldClose() - result = todo_task.result() - executor.shutdown(False) + else: + done, _ = concurrent.futures.wait( + (todo_task, wait_close_task), # type: ignore[misc] + return_when=concurrent.futures.FIRST_COMPLETED, + ) + if wait_close_task in done: + raise ShouldClose() + assert todo_task in done + if not wait_close_task.cancel(): + exit_await.set() + wait_close_task.result() + result = todo_task.result() + finally: + executor.shutdown(False) return result diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index e40a4347..5b0bafb8 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -197,7 +197,7 @@ async def _handle_ws_request( self.scope = scope self.exit_stack = contextlib.AsyncExitStack() stream, accept_response = await self.exit_stack.enter_async_context( - ASGIWebSocketAsyncNetworkStream(self.app, self.scope) + ASGIWebSocketAsyncNetworkStream(self.app, self.scope) # type: ignore[arg-type] ) accept_response_lines = accept_response.decode("utf-8").splitlines() diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index fed43267..7edbb7e2 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -1,5 +1,6 @@ import contextlib import queue +import threading import time import typing from unittest.mock import MagicMock, call, patch @@ -330,8 +331,8 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.parametrize( "full_message,send_method", [ - (b"A" * 1024 * 4, "send_bytes"), - ("A" * 1024 * 4, "send_text"), + pytest.param(b"A" * 1024 * 4, "send_bytes", id="bytes"), + pytest.param("A" * 1024 * 4, "send_text", id="text"), ], ) async def test_receive_oversized_message( @@ -901,3 +902,32 @@ def async_handler(request): assert isinstance(aws.response, httpx.Response) assert aws.subprotocol == "custom_protocol" assert aws.response.headers["sec-websocket-protocol"] == aws.subprotocol + + +@pytest.mark.anyio +async def test_threads_wont_hang(server_factory: ServerFactoryFixture) -> None: + """ + Check that all threads spawned in WebSocketSession are properly terminated during + a series of messages exchange. This used to be the cause of a memory leak in the + connect_ws client, see https://github.com/frankie567/httpx-ws/issues/76. + """ + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + for _ in range(50): + await websocket.send_text("SERVER_MESSAGE") + await websocket.receive_text() + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + initial_threads_count = threading.active_count() + with connect_ws( + "http://socket/ws", client, keepalive_ping_interval_seconds=None + ) as ws: + for _ in range(50): + ws.receive() + ws.send_text("CLIENT_MESSAGE") + time.sleep(0.5) # Let the websocket endpoint finish its handling. + final_threads_count = threading.active_count() + assert initial_threads_count == final_threads_count From 7a9e7f9b6fcc903273edf60aeaa8100addf963cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Sat, 5 Oct 2024 09:22:05 +0200 Subject: [PATCH 091/108] =?UTF-8?q?Bump=20version=200.6.0=20=E2=86=92=200.?= =?UTF-8?q?6.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes --------- * Fix (#73) anyio misusages. Thanks @agronholm 🎉 * Fix (#74) unclosed anyio streams. Thanks @agronholm 🎉 * Fix (#76) memory leak with non-async WebSocketSession. Thanks @ro-oliveira95 🎉 --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 0d7742d2..95edfcdc 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.0" +__version__ = "0.6.1" from ._api import ( AsyncWebSocketSession, From 1ca2bedbc0bc364e51d97ef31d6f13269ba12bc8 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Mon, 7 Oct 2024 06:45:03 +0200 Subject: [PATCH 092/108] Improve _wait_until_closed() (#81) * Improve _wait_until_closed() * Fix type annotation for Python<3.9 --- src/httpx2/httpx2/_websockets/_api.py | 34 ++++++++++++++------------- tests/httpx2/websockets/test_api.py | 9 ++++--- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 40c2e1ad..cb2d1b99 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -5,7 +5,6 @@ import queue import secrets import threading -import time import typing from types import TracebackType @@ -83,12 +82,25 @@ def __init__( self._ping_manager = PingManager() self._should_close = threading.Event() + self._should_close_task: typing.Optional[concurrent.futures.Future[bool]] = None + self._executor: typing.Optional[concurrent.futures.ThreadPoolExecutor] = None self._max_message_size_bytes = max_message_size_bytes self._queue_size = queue_size self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + def _get_executor_should_close_task( + self, + ) -> typing.Tuple[ + concurrent.futures.ThreadPoolExecutor, "concurrent.futures.Future[bool]" + ]: + if self._should_close_task is None: + self._executor = concurrent.futures.ThreadPoolExecutor() + self._should_close_task = self._executor.submit(self._should_close.wait) + assert self._executor is not None + return self._executor, self._should_close_task + def __enter__(self) -> "WebSocketSession": self._background_receive_task = threading.Thread( target=self._background_receive, args=(self._max_message_size_bytes,) @@ -427,6 +439,8 @@ def close(self, code: int = 1000, reason: typing.Optional[str] = None): ws.close() """ self._should_close.set() + if self._executor is not None: + self._executor.shutdown(False) if self.connection.state not in { wsproto.connection.ConnectionState.LOCAL_CLOSING, wsproto.connection.ConnectionState.CLOSED, @@ -518,32 +532,20 @@ def _background_keepalive_ping( def _wait_until_closed( self, callable: typing.Callable[..., TaskResult], *args, **kwargs ) -> TaskResult: - exit_await = threading.Event() - - def wait_close() -> None: - while not exit_await.is_set() and not self._should_close.is_set(): - time.sleep(0.05) - - executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) try: - wait_close_task = executor.submit(wait_close) + executor, should_close_task = self._get_executor_should_close_task() todo_task = executor.submit(callable, *args, **kwargs) except RuntimeError as e: raise ShouldClose() from e else: done, _ = concurrent.futures.wait( - (todo_task, wait_close_task), # type: ignore[misc] + (todo_task, should_close_task), # type: ignore[misc] return_when=concurrent.futures.FIRST_COMPLETED, ) - if wait_close_task in done: + if should_close_task in done: raise ShouldClose() assert todo_task in done - if not wait_close_task.cancel(): - exit_await.set() - wait_close_task.result() result = todo_task.result() - finally: - executor.shutdown(False) return result diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 7edbb7e2..a9472e79 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -928,6 +928,9 @@ async def websocket_endpoint(websocket: WebSocket) -> None: for _ in range(50): ws.receive() ws.send_text("CLIENT_MESSAGE") - time.sleep(0.5) # Let the websocket endpoint finish its handling. - final_threads_count = threading.active_count() - assert initial_threads_count == final_threads_count + time.sleep(0.1) # Let the websocket endpoint finish its handling. + threads_count = threading.active_count() + assert initial_threads_count + 2 == threads_count + time.sleep(0.1) + final_threads_count = threading.active_count() + assert initial_threads_count == final_threads_count From e5bd84dfe291d29f6d35a550d40e28c8d609838e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 7 Oct 2024 09:29:26 +0200 Subject: [PATCH 093/108] =?UTF-8?q?Bump=20version=200.6.1=20=E2=86=92=200.?= =?UTF-8?q?6.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes and improvements -------------------------- * Improve efficiency of `WebSocketSession` by reusing a single thread pool when waiting for messages. Thank you @davidbrochart 🎉 --- src/httpx2/httpx2/_websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 95edfcdc..2ae6b1b8 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.1" +__version__ = "0.6.2" from ._api import ( AsyncWebSocketSession, From c3dcf210dbaf46337ec4f55199ed3ca1a5ee00cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 9 Oct 2024 08:15:05 +0200 Subject: [PATCH 094/108] Drop Python 3.8 support --- src/httpx2/httpx2/_websockets/_api.py | 16 ++++++++-------- src/httpx2/httpx2/_websockets/_ping.py | 8 ++++---- src/httpx2/httpx2/_websockets/transport.py | 10 +++++----- tests/httpx2/websockets/conftest.py | 6 ++++-- tests/httpx2/websockets/test_transport.py | 10 +++++----- 5 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index cb2d1b99..92a098fb 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -92,7 +92,7 @@ def __init__( def _get_executor_should_close_task( self, - ) -> typing.Tuple[ + ) -> tuple[ concurrent.futures.ThreadPoolExecutor, "concurrent.futures.Future[bool]" ]: if self._should_close_task is None: @@ -634,7 +634,7 @@ async def __aenter__(self) -> "AsyncWebSocketSession": async def __aexit__( self, - exc_type: typing.Optional[typing.Type[BaseException]], + exc_type: typing.Optional[type[BaseException]], exc: typing.Optional[BaseException], tb: typing.Optional[TracebackType], ) -> None: @@ -1039,8 +1039,8 @@ async def _background_keepalive_ping( def _get_headers( - subprotocols: typing.Optional[typing.List[str]], -) -> typing.Dict[str, typing.Any]: + subprotocols: typing.Optional[list[str]], +) -> dict[str, typing.Any]: headers = { "connection": "upgrade", "upgrade": "websocket", @@ -1065,7 +1065,7 @@ def _connect_ws( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[typing.List[str]] = None, + subprotocols: typing.Optional[list[str]] = None, **kwargs: typing.Any, ) -> typing.Generator[WebSocketSession, None, None]: headers = kwargs.pop("headers", {}) @@ -1099,7 +1099,7 @@ def connect_ws( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[typing.List[str]] = None, + subprotocols: typing.Optional[list[str]] = None, **kwargs: typing.Any, ) -> typing.Generator[WebSocketSession, None, None]: """ @@ -1198,7 +1198,7 @@ async def _aconnect_ws( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[typing.List[str]] = None, + subprotocols: typing.Optional[list[str]] = None, **kwargs: typing.Any, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: headers = kwargs.pop("headers", {}) @@ -1232,7 +1232,7 @@ async def aconnect_ws( keepalive_ping_timeout_seconds: typing.Optional[ float ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[typing.List[str]] = None, + subprotocols: typing.Optional[list[str]] = None, **kwargs: typing.Any, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: """ diff --git a/src/httpx2/httpx2/_websockets/_ping.py b/src/httpx2/httpx2/_websockets/_ping.py index bd3c0a07..5920eafe 100644 --- a/src/httpx2/httpx2/_websockets/_ping.py +++ b/src/httpx2/httpx2/_websockets/_ping.py @@ -12,11 +12,11 @@ def _generate_id(self) -> bytes: class PingManager(PingManagerBase): def __init__(self) -> None: - self._pings: typing.Dict[bytes, threading.Event] = {} + self._pings: dict[bytes, threading.Event] = {} def create( self, ping_id: typing.Optional[bytes] = None - ) -> typing.Tuple[bytes, threading.Event]: + ) -> tuple[bytes, threading.Event]: ping_id = self._generate_id() if not ping_id else ping_id event = threading.Event() self._pings[ping_id] = event @@ -29,11 +29,11 @@ def ack(self, ping_id: typing.Union[bytes, bytearray]): class AsyncPingManager(PingManagerBase): def __init__(self) -> None: - self._pings: typing.Dict[bytes, anyio.Event] = {} + self._pings: dict[bytes, anyio.Event] = {} def create( self, ping_id: typing.Optional[bytes] = None - ) -> typing.Tuple[bytes, anyio.Event]: + ) -> tuple[bytes, anyio.Event]: ping_id = self._generate_id() if not ping_id else ping_id event = anyio.Event() self._pings[ping_id] = event diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/transport.py index 5b0bafb8..63118d3c 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -11,8 +11,8 @@ from ._exceptions import WebSocketDisconnect -Scope = typing.Dict[str, typing.Any] -Message = typing.Dict[str, typing.Any] +Scope = dict[str, typing.Any] +Message = dict[str, typing.Any] Receive = typing.Callable[[], typing.Awaitable[Message]] Send = typing.Callable[[Scope], typing.Coroutine[None, None, None]] ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Coroutine[None, None, None]] @@ -43,7 +43,7 @@ def __init__(self, app: ASGIApp, scope: Scope) -> None: async def __aenter__( self, - ) -> typing.Tuple["ASGIWebSocketAsyncNetworkStream", bytes]: + ) -> tuple["ASGIWebSocketAsyncNetworkStream", bytes]: self.exit_stack = contextlib.ExitStack() self.portal = self.exit_stack.enter_context( anyio.from_thread.start_blocking_portal("asyncio") @@ -165,7 +165,7 @@ async def handle_async_request(self, request: Request) -> Response: headers = request.headers if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": - subprotocols: typing.List[str] = [] + subprotocols: list[str] = [] if ( subprotocols_header := headers.get("sec-websocket-protocol") ) is not None: @@ -202,7 +202,7 @@ async def _handle_ws_request( accept_response_lines = accept_response.decode("utf-8").splitlines() headers = [ - typing.cast(typing.Tuple[str, str], line.split(": ", 1)) + typing.cast(tuple[str, str], line.split(": ", 1)) for line in accept_response_lines[1:] if line.strip() != "" ] diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index 88bfe824..58451e7b 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -2,7 +2,7 @@ import pathlib import queue import tempfile -from typing import Callable, ContextManager, Literal, Protocol +from typing import Callable, Literal, Protocol from unittest.mock import MagicMock import pytest @@ -23,7 +23,9 @@ def websocket_implementation(request) -> Literal["wsproto", "websockets"]: class ServerFactoryFixture(Protocol): - def __call__(self, endpoint: Callable) -> ContextManager[str]: ... + def __call__( + self, endpoint: Callable + ) -> contextlib.AbstractContextManager[str]: ... @pytest.fixture diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index 032b8402..3e0c8ca2 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -1,6 +1,6 @@ import base64 import secrets -from typing import Any, Dict +from typing import Any import httpx import pytest @@ -21,7 +21,7 @@ @pytest.fixture -def websocket_request_headers() -> Dict[str, str]: +def websocket_request_headers() -> dict[str, str]: return { "connection": "upgrade", "upgrade": "websocket", @@ -31,7 +31,7 @@ def websocket_request_headers() -> Dict[str, str]: @pytest.fixture -def scope(websocket_request_headers: Dict[str, str]) -> Scope: +def scope(websocket_request_headers: dict[str, str]) -> Scope: return { "type": "websocket", "path": "/ws", @@ -175,9 +175,9 @@ async def test_http(self, test_app: Starlette): async def test_websocket( self, url: str, - headers: Dict[str, Any], + headers: dict[str, Any], test_app: Starlette, - websocket_request_headers: Dict[str, str], + websocket_request_headers: dict[str, str], ): async with ASGIWebSocketTransport(app=test_app) as transport: request = httpx.Request( From d5ed6a5ca1ee361bcdc6ddf0144961c146ec5cb2 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 24 Jun 2026 12:57:24 +0200 Subject: [PATCH 095/108] Adapt vendored httpx-ws package to httpx2 Rewrite the vendored `_websockets` package to httpx2's namespaces and conventions while keeping httpcore2 lazily imported: - Import the public types from httpx2's own modules (`.._models`, `.._transports.asgi`, `.._types`) instead of `httpx`. - Defer `import httpcore2` into the methods that catch its errors, and drop the `AsyncNetworkStream` base class from the ASGI stream, so importing the package no longer eagerly pulls in httpcore2. - Rename `transport.py` to `_transport.py` to match httpx2's private-module layout. - Move the numeric defaults into a wsproto-free `_defaults` module and add a `require_wsproto()` guard so the package can be imported without `wsproto` installed; `httpx2[ws]` is now an optional extra. - Apply httpx2's typing standards (`from __future__ import annotations`, modern unions, full annotations) so the code passes ruff and mypy --strict. --- src/httpx2/httpx2/_websockets/__init__.py | 56 ++-- src/httpx2/httpx2/_websockets/_api.py | 255 ++++++++---------- src/httpx2/httpx2/_websockets/_defaults.py | 15 ++ src/httpx2/httpx2/_websockets/_exceptions.py | 16 +- src/httpx2/httpx2/_websockets/_ping.py | 15 +- .../{transport.py => _transport.py} | 45 ++-- 6 files changed, 207 insertions(+), 195 deletions(-) create mode 100644 src/httpx2/httpx2/_websockets/_defaults.py rename src/httpx2/httpx2/_websockets/{transport.py => _transport.py} (85%) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 2ae6b1b8..856cd0cb 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,22 +1,17 @@ -__version__ = "0.6.2" - -from ._api import ( - AsyncWebSocketSession, - JSONMode, - WebSocketSession, - aconnect_ws, - connect_ws, -) -from ._exceptions import ( - HTTPXWSException, - WebSocketDisconnect, - WebSocketInvalidTypeReceived, - WebSocketNetworkError, - WebSocketUpgradeError, +from ._defaults import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, ) __all__ = [ + "ASGIWebSocketTransport", "AsyncWebSocketSession", + "DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS", + "DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS", + "DEFAULT_MAX_MESSAGE_SIZE_BYTES", + "DEFAULT_QUEUE_SIZE", "HTTPXWSException", "JSONMode", "WebSocketDisconnect", @@ -27,3 +22,34 @@ "aconnect_ws", "connect_ws", ] + +_API_NAMES = { + "AsyncWebSocketSession", + "JSONMode", + "WebSocketSession", + "aconnect_ws", + "connect_ws", +} +_EXCEPTION_NAMES = { + "HTTPXWSException", + "WebSocketDisconnect", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketUpgradeError", +} + + +def __getattr__(name: str) -> object: + if name in _API_NAMES: + from . import _api + + return getattr(_api, name) + if name in _EXCEPTION_NAMES: + from . import _exceptions + + return getattr(_exceptions, name) + if name == "ASGIWebSocketTransport": + from ._transport import ASGIWebSocketTransport + + return ASGIWebSocketTransport + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/_api.py index 92a098fb..ac9f07c6 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/_api.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import concurrent.futures import contextlib @@ -9,13 +11,16 @@ from types import TracebackType import anyio -import httpcore -import httpx import wsproto from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from httpcore import AsyncNetworkStream, NetworkStream from wsproto.frame_protocol import CloseReason +from ._defaults import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, +) from ._exceptions import ( HTTPXWSException, WebSocketDisconnect, @@ -24,17 +29,18 @@ WebSocketUpgradeError, ) from ._ping import AsyncPingManager, PingManager -from .transport import ASGIWebSocketAsyncNetworkStream +from ._transport import ASGIWebSocketAsyncNetworkStream + +if typing.TYPE_CHECKING: + from httpcore2 import AsyncNetworkStream, NetworkStream + + from .._client import AsyncClient, Client + from .._models import Response JSONMode = typing.Literal["text", "binary"] TaskFunction = typing.TypeVar("TaskFunction") TaskResult = typing.TypeVar("TaskResult") -DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536 -DEFAULT_QUEUE_SIZE = 512 -DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = 20.0 -DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 - class ShouldClose(Exception): pass @@ -47,12 +53,12 @@ class WebSocketSession: Attributes: subprotocol (typing.Optional[str]): Optional protocol that has been accepted by the server. - response (typing.Optional[httpx.Response]): + response (Response | None): The webSocket handshake response. """ - subprotocol: typing.Optional[str] - response: typing.Optional[httpx.Response] + subprotocol: str | None + response: Response | None def __init__( self, @@ -60,13 +66,9 @@ def __init__( *, max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - response: typing.Optional[httpx.Response] = None, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + response: Response | None = None, ) -> None: self.stream = stream self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) @@ -76,14 +78,12 @@ def __init__( else: self.subprotocol = None - self._events: queue.Queue[ - typing.Union[wsproto.events.Event, HTTPXWSException] - ] = queue.Queue(queue_size) + self._events: queue.Queue[wsproto.events.Event | HTTPXWSException] = queue.Queue(queue_size) self._ping_manager = PingManager() self._should_close = threading.Event() - self._should_close_task: typing.Optional[concurrent.futures.Future[bool]] = None - self._executor: typing.Optional[concurrent.futures.ThreadPoolExecutor] = None + self._should_close_task: concurrent.futures.Future[bool] | None = None + self._executor: concurrent.futures.ThreadPoolExecutor | None = None self._max_message_size_bytes = max_message_size_bytes self._queue_size = queue_size @@ -92,22 +92,20 @@ def __init__( def _get_executor_should_close_task( self, - ) -> tuple[ - concurrent.futures.ThreadPoolExecutor, "concurrent.futures.Future[bool]" - ]: + ) -> tuple[concurrent.futures.ThreadPoolExecutor, concurrent.futures.Future[bool]]: if self._should_close_task is None: self._executor = concurrent.futures.ThreadPoolExecutor() self._should_close_task = self._executor.submit(self._should_close.wait) assert self._executor is not None return self._executor, self._should_close_task - def __enter__(self) -> "WebSocketSession": + def __enter__(self) -> WebSocketSession: self._background_receive_task = threading.Thread( target=self._background_receive, args=(self._max_message_size_bytes,) ) self._background_receive_task.start() - self._background_keepalive_ping_task: typing.Optional[threading.Thread] = None + self._background_keepalive_ping_task: threading.Thread | None = None if self._keepalive_ping_interval_seconds is not None: self._background_keepalive_ping_task = threading.Thread( target=self._background_keepalive_ping, @@ -120,7 +118,12 @@ def __enter__(self) -> "WebSocketSession": return self - def __exit__(self, exc_type, exc, tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: self.close() self._background_receive_task.join() if self._background_keepalive_ping_task is not None: @@ -173,10 +176,12 @@ def send(self, event: wsproto.events.Event) -> None: event = wsproto.events.Message(b"Hello!") ws.send(event) """ + import httpcore2 + try: data = self.connection.send(event) self.stream.write(data) - except httpcore.WriteError as e: + except httpcore2.WriteError as e: self.close(CloseReason.INTERNAL_ERROR, "Stream write error") raise WebSocketNetworkError() from e @@ -242,7 +247,7 @@ def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: else: self.send_bytes(serialized_data.encode("utf-8")) - def receive(self, timeout: typing.Optional[float] = None) -> wsproto.events.Event: + def receive(self, timeout: float | None = None) -> wsproto.events.Event: """ Receive an event from the server. @@ -288,7 +293,7 @@ def receive(self, timeout: typing.Optional[float] = None) -> wsproto.events.Even raise WebSocketDisconnect(event.code, event.reason) return event - def receive_text(self, timeout: typing.Optional[float] = None) -> str: + def receive_text(self, timeout: float | None = None) -> str: """ Receive text from the server. @@ -328,7 +333,7 @@ def receive_text(self, timeout: typing.Optional[float] = None) -> str: return event.data raise WebSocketInvalidTypeReceived(event) - def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: + def receive_bytes(self, timeout: float | None = None) -> bytes: """ Receive bytes from the server. @@ -365,12 +370,10 @@ def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: """ event = self.receive(timeout) if isinstance(event, wsproto.events.BytesMessage): - return event.data + return bytes(event.data) raise WebSocketInvalidTypeReceived(event) - def receive_json( - self, timeout: typing.Optional[float] = None, mode: JSONMode = "text" - ) -> typing.Any: + def receive_json(self, timeout: float | None = None, mode: JSONMode = "text") -> typing.Any: """ Receive JSON data from the server. @@ -411,14 +414,14 @@ def receive_json( print("Connection closed") """ assert mode in ["text", "binary"] - data: typing.Union[str, bytes] + data: str | bytes if mode == "text": data = self.receive_text(timeout) elif mode == "binary": data = self.receive_bytes(timeout) return json.loads(data) - def close(self, code: int = 1000, reason: typing.Optional[str] = None): + def close(self, code: int = 1000, reason: str | None = None) -> None: """ Close the WebSocket session. @@ -438,6 +441,8 @@ def close(self, code: int = 1000, reason: typing.Optional[str] = None): ws.close() """ + import httpcore2 + self._should_close.set() if self._executor is not None: self._executor.shutdown(False) @@ -449,7 +454,7 @@ def close(self, code: int = 1000, reason: typing.Optional[str] = None): data = self.connection.send(event) try: self.stream.write(data) - except httpcore.WriteError: + except httpcore2.WriteError: pass self.stream.close() @@ -467,7 +472,9 @@ def _background_receive(self, max_bytes: int) -> None: Args: max_bytes: The maximum chunk size to read at each iteration. """ - partial_message_buffer: typing.Union[str, bytes, None] = None + import httpcore2 + + partial_message_buffer: str | bytes | None = None try: while not self._should_close.is_set(): data = self._wait_until_closed(self.stream.read, max_bytes) @@ -495,34 +502,26 @@ def _background_receive(self, max_bytes: int) -> None: # Finished message with buffer: emit the full event else: event_type = type(event) - full_message_event = event_type( - partial_message_buffer + event.data - ) + full_message_event = event_type(partial_message_buffer + event.data) partial_message_buffer = None self._events.put(full_message_event) continue self._events.put(event) - except (httpcore.ReadError, httpcore.WriteError): + except (httpcore2.ReadError, httpcore2.WriteError): self.close(CloseReason.INTERNAL_ERROR, "Stream error") self._events.put(WebSocketNetworkError()) except ShouldClose: pass - def _background_keepalive_ping( - self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None - ) -> None: + def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: float | None = None) -> None: try: while not self._should_close.is_set(): - should_close = self._wait_until_closed( - self._should_close.wait, interval_seconds - ) + should_close = self._wait_until_closed(self._should_close.wait, interval_seconds) if should_close: raise ShouldClose() pong_callback = self.ping() if timeout_seconds is not None: - acknowledged = self._wait_until_closed( - pong_callback.wait, timeout_seconds - ) + acknowledged = self._wait_until_closed(pong_callback.wait, timeout_seconds) if not acknowledged: self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") self._events.put(WebSocketNetworkError()) @@ -530,7 +529,7 @@ def _background_keepalive_ping( pass def _wait_until_closed( - self, callable: typing.Callable[..., TaskResult], *args, **kwargs + self, callable: typing.Callable[..., TaskResult], *args: typing.Any, **kwargs: typing.Any ) -> TaskResult: try: executor, should_close_task = self._get_executor_should_close_task() @@ -556,18 +555,14 @@ class AsyncWebSocketSession: Attributes: subprotocol (typing.Optional[str]): Optional protocol that has been accepted by the server. - response (typing.Optional[httpx.Response]): + response (Response | None): The webSocket handshake response. """ - subprotocol: typing.Optional[str] - response: typing.Optional[httpx.Response] - _send_event: MemoryObjectSendStream[ - typing.Union[wsproto.events.Event, HTTPXWSException] - ] - _receive_event: MemoryObjectReceiveStream[ - typing.Union[wsproto.events.Event, HTTPXWSException] - ] + subprotocol: str | None + response: Response | None + _send_event: MemoryObjectSendStream[wsproto.events.Event | HTTPXWSException] + _receive_event: MemoryObjectReceiveStream[wsproto.events.Event | HTTPXWSException] def __init__( self, @@ -575,13 +570,9 @@ def __init__( *, max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - response: typing.Optional[httpx.Response] = None, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + response: Response | None = None, ) -> None: self.stream = stream self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) @@ -605,10 +596,10 @@ def __init__( self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds - async def __aenter__(self) -> "AsyncWebSocketSession": + async def __aenter__(self) -> AsyncWebSocketSession: async with contextlib.AsyncExitStack() as exit_stack: self._send_event, self._receive_event = anyio.create_memory_object_stream[ - typing.Union[wsproto.events.Event, HTTPXWSException] + wsproto.events.Event | HTTPXWSException ]() exit_stack.enter_context(self._send_event) exit_stack.enter_context(self._receive_event) @@ -616,9 +607,7 @@ async def __aenter__(self) -> "AsyncWebSocketSession": self._background_task_group = anyio.create_task_group() await exit_stack.enter_async_context(self._background_task_group) - self._background_task_group.start_soon( - self._background_receive, self._max_message_size_bytes - ) + self._background_task_group.start_soon(self._background_receive, self._max_message_size_bytes) if self._keepalive_ping_interval_seconds is not None: self._background_task_group.start_soon( self._background_keepalive_ping, @@ -634,9 +623,9 @@ async def __aenter__(self) -> "AsyncWebSocketSession": async def __aexit__( self, - exc_type: typing.Optional[type[BaseException]], - exc: typing.Optional[BaseException], - tb: typing.Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ) -> None: await self._exit_stack.aclose() @@ -687,10 +676,12 @@ async def send(self, event: wsproto.events.Event) -> None: event = await wsproto.events.Message(b"Hello!") ws.send(event) """ + import httpcore2 + try: data = self.connection.send(event) await self.stream.write(data) - except httpcore.WriteError as e: + except httpcore2.WriteError as e: await self.close(CloseReason.INTERNAL_ERROR, "Stream write error") raise WebSocketNetworkError() from e @@ -756,9 +747,7 @@ async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: else: await self.send_bytes(serialized_data.encode("utf-8")) - async def receive( - self, timeout: typing.Optional[float] = None - ) -> wsproto.events.Event: + async def receive(self, timeout: float | None = None) -> wsproto.events.Event: """ Receive an event from the server. @@ -805,7 +794,7 @@ async def receive( raise WebSocketDisconnect(event.code, event.reason) return event - async def receive_text(self, timeout: typing.Optional[float] = None) -> str: + async def receive_text(self, timeout: float | None = None) -> str: """ Receive text from the server. @@ -845,7 +834,7 @@ async def receive_text(self, timeout: typing.Optional[float] = None) -> str: return event.data raise WebSocketInvalidTypeReceived(event) - async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: + async def receive_bytes(self, timeout: float | None = None) -> bytes: """ Receive bytes from the server. @@ -882,12 +871,10 @@ async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: """ event = await self.receive(timeout) if isinstance(event, wsproto.events.BytesMessage): - return event.data + return bytes(event.data) raise WebSocketInvalidTypeReceived(event) - async def receive_json( - self, timeout: typing.Optional[float] = None, mode: JSONMode = "text" - ) -> typing.Any: + async def receive_json(self, timeout: float | None = None, mode: JSONMode = "text") -> typing.Any: """ Receive JSON data from the server. @@ -928,14 +915,14 @@ async def receive_json( print("Connection closed") """ assert mode in ["text", "binary"] - data: typing.Union[str, bytes] + data: str | bytes if mode == "text": data = await self.receive_text(timeout) elif mode == "binary": data = await self.receive_bytes(timeout) return json.loads(data) - async def close(self, code: int = 1000, reason: typing.Optional[str] = None): + async def close(self, code: int = 1000, reason: str | None = None) -> None: """ Close the WebSocket session. @@ -955,6 +942,8 @@ async def close(self, code: int = 1000, reason: typing.Optional[str] = None): await ws.close() """ + import httpcore2 + self._should_close.set() if self.connection.state not in { wsproto.connection.ConnectionState.LOCAL_CLOSING, @@ -964,7 +953,7 @@ async def close(self, code: int = 1000, reason: typing.Optional[str] = None): data = self.connection.send(event) try: await self.stream.write(data) - except httpcore.WriteError: + except httpcore2.WriteError: pass await self.stream.aclose() @@ -982,7 +971,9 @@ async def _background_receive(self, max_bytes: int) -> None: Args: max_bytes: The maximum chunk size to read at each iteration. """ - partial_message_buffer: typing.Union[str, bytes, None] = None + import httpcore2 + + partial_message_buffer: str | bytes | None = None try: while not self._should_close.is_set(): data = await self.stream.read(max_bytes=max_bytes) @@ -1010,20 +1001,16 @@ async def _background_receive(self, max_bytes: int) -> None: # Finished message with buffer: emit the full event else: event_type = type(event) - full_message_event = event_type( - partial_message_buffer + event.data - ) + full_message_event = event_type(partial_message_buffer + event.data) partial_message_buffer = None await self._send_event.send(full_message_event) continue await self._send_event.send(event) - except (httpcore.ReadError, httpcore.WriteError): + except (httpcore2.ReadError, httpcore2.WriteError): await self.close(CloseReason.INTERNAL_ERROR, "Stream error") await self._send_event.send(WebSocketNetworkError()) - async def _background_keepalive_ping( - self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None - ) -> None: + async def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: float | None = None) -> None: while not self._should_close.is_set(): await anyio.sleep(interval_seconds) pong_callback = await self.ping() @@ -1032,14 +1019,12 @@ async def _background_keepalive_ping( with anyio.fail_after(timeout_seconds): await pong_callback.wait() except TimeoutError: - await self.close( - CloseReason.INTERNAL_ERROR, "Keepalive ping timeout" - ) + await self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") await self._send_event.send(WebSocketNetworkError()) def _get_headers( - subprotocols: typing.Optional[list[str]], + subprotocols: list[str] | None, ) -> dict[str, typing.Any]: headers = { "connection": "upgrade", @@ -1055,17 +1040,13 @@ def _get_headers( @contextlib.contextmanager def _connect_ws( url: str, - client: httpx.Client, + client: Client, *, max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[list[str]] = None, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, **kwargs: typing.Any, ) -> typing.Generator[WebSocketSession, None, None]: headers = kwargs.pop("headers", {}) @@ -1089,17 +1070,13 @@ def _connect_ws( @contextlib.contextmanager def connect_ws( url: str, - client: typing.Optional[httpx.Client] = None, + client: Client | None = None, *, max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[list[str]] = None, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, **kwargs: typing.Any, ) -> typing.Generator[WebSocketSession, None, None]: """ @@ -1152,14 +1129,16 @@ def connect_ws( With explicit HTTPX client. - with httpx.Client() as client: + with httpx2.Client() as client: with connect_ws("http://localhost:8000/ws", client) as ws: message = ws.receive_text() print(message) ws.send_text("Hello!") """ if client is None: - with httpx.Client() as client: + from .._client import Client + + with Client() as client: with _connect_ws( url, client=client, @@ -1188,17 +1167,13 @@ def connect_ws( @contextlib.asynccontextmanager async def _aconnect_ws( url: str, - client: httpx.AsyncClient, + client: AsyncClient, *, max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[list[str]] = None, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, **kwargs: typing.Any, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: headers = kwargs.pop("headers", {}) @@ -1222,17 +1197,13 @@ async def _aconnect_ws( @contextlib.asynccontextmanager async def aconnect_ws( url: str, - client: typing.Optional[httpx.AsyncClient] = None, + client: AsyncClient | None = None, *, max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[list[str]] = None, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, **kwargs: typing.Any, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: """ @@ -1285,14 +1256,16 @@ async def aconnect_ws( With explicit HTTPX client. - async with httpx.AsyncClient() as client: + async with httpx2.AsyncClient() as client: async with aconnect_ws("http://localhost:8000/ws", client) as ws: message = await ws.receive_text() print(message) await ws.send_text("Hello!") """ if client is None: - async with httpx.AsyncClient() as client: + from .._client import AsyncClient + + async with AsyncClient() as client: async with _aconnect_ws( url, client=client, diff --git a/src/httpx2/httpx2/_websockets/_defaults.py b/src/httpx2/httpx2/_websockets/_defaults.py new file mode 100644 index 00000000..e7c5a30b --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_defaults.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536 +DEFAULT_QUEUE_SIZE = 512 +DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = 20.0 +DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 + +WS_EXTRA_INSTALL_MESSAGE = "WebSocket support requires the `wsproto` package. Install it with `pip install httpx2[ws]`." + + +def require_wsproto() -> None: + try: + import wsproto # noqa: F401 + except ImportError as exc: # pragma: no cover + raise ImportError(WS_EXTRA_INSTALL_MESSAGE) from exc diff --git a/src/httpx2/httpx2/_websockets/_exceptions.py b/src/httpx2/httpx2/_websockets/_exceptions.py index 0facbf82..762643aa 100644 --- a/src/httpx2/httpx2/_websockets/_exceptions.py +++ b/src/httpx2/httpx2/_websockets/_exceptions.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import typing -import httpx -import wsproto +if typing.TYPE_CHECKING: + import wsproto + + from .._models import Response class HTTPXWSException(Exception): @@ -9,15 +13,13 @@ class HTTPXWSException(Exception): Base exception class for HTTPX WS. """ - pass - class WebSocketUpgradeError(HTTPXWSException): """ Raised when the initial connection didn't correctly upgrade to a WebSocket session. """ - def __init__(self, response: httpx.Response) -> None: + def __init__(self, response: Response) -> None: self.response = response @@ -32,7 +34,7 @@ class WebSocketDisconnect(HTTPXWSException): Additional reasoning for why the connection has closed. """ - def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: + def __init__(self, code: int = 1000, reason: str | None = None) -> None: self.code = code self.reason = reason or "" @@ -51,5 +53,3 @@ class WebSocketNetworkError(HTTPXWSException): Raised when a network error occured, typically if the underlying stream has closed or timeout. """ - - pass diff --git a/src/httpx2/httpx2/_websockets/_ping.py b/src/httpx2/httpx2/_websockets/_ping.py index 5920eafe..b9116cb0 100644 --- a/src/httpx2/httpx2/_websockets/_ping.py +++ b/src/httpx2/httpx2/_websockets/_ping.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import secrets import threading -import typing import anyio @@ -14,15 +15,13 @@ class PingManager(PingManagerBase): def __init__(self) -> None: self._pings: dict[bytes, threading.Event] = {} - def create( - self, ping_id: typing.Optional[bytes] = None - ) -> tuple[bytes, threading.Event]: + def create(self, ping_id: bytes | None = None) -> tuple[bytes, threading.Event]: ping_id = self._generate_id() if not ping_id else ping_id event = threading.Event() self._pings[ping_id] = event return ping_id, event - def ack(self, ping_id: typing.Union[bytes, bytearray]): + def ack(self, ping_id: bytes | bytearray) -> None: event = self._pings.pop(bytes(ping_id)) event.set() @@ -31,14 +30,12 @@ class AsyncPingManager(PingManagerBase): def __init__(self) -> None: self._pings: dict[bytes, anyio.Event] = {} - def create( - self, ping_id: typing.Optional[bytes] = None - ) -> tuple[bytes, anyio.Event]: + def create(self, ping_id: bytes | None = None) -> tuple[bytes, anyio.Event]: ping_id = self._generate_id() if not ping_id else ping_id event = anyio.Event() self._pings[ping_id] = event return ping_id, event - def ack(self, ping_id: typing.Union[bytes, bytearray]): + def ack(self, ping_id: bytes | bytearray) -> None: event = self._pings.pop(bytes(ping_id)) event.set() diff --git a/src/httpx2/httpx2/_websockets/transport.py b/src/httpx2/httpx2/_websockets/_transport.py similarity index 85% rename from src/httpx2/httpx2/_websockets/transport.py rename to src/httpx2/httpx2/_websockets/_transport.py index 63118d3c..761fb52d 100644 --- a/src/httpx2/httpx2/_websockets/transport.py +++ b/src/httpx2/httpx2/_websockets/_transport.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import queue import typing @@ -5,10 +7,11 @@ import anyio import wsproto -from httpcore import AsyncNetworkStream -from httpx import ASGITransport, AsyncByteStream, Request, Response from wsproto.frame_protocol import CloseReason +from .._models import Request, Response +from .._transports.asgi import ASGITransport, _ASGIApp +from .._types import AsyncByteStream from ._exceptions import WebSocketDisconnect Scope = dict[str, typing.Any] @@ -32,7 +35,7 @@ def __init__(self, event: wsproto.events.Event) -> None: self.event = event -class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): +class ASGIWebSocketAsyncNetworkStream: def __init__(self, app: ASGIApp, scope: Scope) -> None: self.app = app self.scope = scope @@ -43,11 +46,9 @@ def __init__(self, app: ASGIApp, scope: Scope) -> None: async def __aenter__( self, - ) -> tuple["ASGIWebSocketAsyncNetworkStream", bytes]: + ) -> tuple[ASGIWebSocketAsyncNetworkStream, bytes]: self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context( - anyio.from_thread.start_blocking_portal("asyncio") - ) + self.portal = self.exit_stack.enter_context(anyio.from_thread.start_blocking_portal("asyncio")) _: Future[None] = self.portal.start_task_soon(self._run) await self.send({"type": "websocket.connect"}) @@ -63,9 +64,7 @@ async def __aenter__( async def __aexit__(self, *args: typing.Any) -> None: await self.aclose() - async def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: message: Message = await self.receive(timeout=timeout) type = message["type"] @@ -74,10 +73,10 @@ async def read( event: wsproto.events.Event if type == "websocket.send": - data_str: typing.Optional[str] = message.get("text") + data_str: str | None = message.get("text") if data_str is not None: event = wsproto.events.TextMessage(data_str) - data_bytes: typing.Optional[bytes] = message.get("bytes") + data_bytes: bytes | None = message.get("bytes") if data_bytes is not None: event = wsproto.events.BytesMessage(data_bytes) elif type == "websocket.close": @@ -85,9 +84,7 @@ async def read( return self.connection.send(event) - async def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + async def write(self, buffer: bytes, timeout: float | None = None) -> None: self.connection.receive_data(buffer) for event in self.connection.events(): if isinstance(event, wsproto.events.Request): @@ -114,7 +111,7 @@ async def aclose(self) -> None: async def send(self, message: Message) -> None: self._receive_queue.put(message) - async def receive(self, timeout: typing.Optional[float] = None) -> Message: + async def receive(self, timeout: float | None = None) -> Message: while self._send_queue.empty(): await anyio.sleep(0) return self._send_queue.get(timeout=timeout) @@ -156,9 +153,15 @@ def _build_accept_response(self, message: Message) -> bytes: class ASGIWebSocketTransport(ASGITransport): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.exit_stack: typing.Optional[contextlib.AsyncExitStack] = None + def __init__( + self, + app: _ASGIApp, + raise_app_exceptions: bool = True, + root_path: str = "", + client: tuple[str, int] = ("127.0.0.1", 123), + ) -> None: + super().__init__(app, raise_app_exceptions, root_path, client) + self.exit_stack: contextlib.AsyncExitStack | None = None async def handle_async_request(self, request: Request) -> Response: scheme = request.url.scheme @@ -166,9 +169,7 @@ async def handle_async_request(self, request: Request) -> Response: if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": subprotocols: list[str] = [] - if ( - subprotocols_header := headers.get("sec-websocket-protocol") - ) is not None: + if (subprotocols_header := headers.get("sec-websocket-protocol")) is not None: subprotocols = subprotocols_header.split(",") scope = { From 6c7b81d2242d1a1c87728ead4f7435476fa17685 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 24 Jun 2026 12:57:33 +0200 Subject: [PATCH 096/108] Expose WebSocket support in the httpx2 public API Add `httpx2.websocket()`, `Client.websocket()` and `AsyncClient.websocket()` context managers, and re-export the WebSocket session classes, the `ASGIWebSocketTransport` and the exception hierarchy from the top-level `httpx2` namespace. These names resolve lazily through `__getattr__` so `import httpx2` keeps working without `wsproto`; a missing dependency raises a clear error pointing to the `httpx2[ws]` extra. A `TYPE_CHECKING` block re-imports them from the typed submodules so static type checkers still see the real types. --- src/httpx2/httpx2/__init__.py | 54 ++++++++++++++++++-- src/httpx2/httpx2/_api.py | 63 ++++++++++++++++++++++++ src/httpx2/httpx2/_client.py | 92 +++++++++++++++++++++++++++++++++++ src/httpx2/pyproject.toml | 3 ++ 4 files changed, 209 insertions(+), 3 deletions(-) diff --git a/src/httpx2/httpx2/__init__.py b/src/httpx2/httpx2/__init__.py index 068e0a25..5c4ec95f 100644 --- a/src/httpx2/httpx2/__init__.py +++ b/src/httpx2/httpx2/__init__.py @@ -1,3 +1,5 @@ +import typing as _typing + from .__version__ import __description__, __title__, __version__ from ._api import * from ._auth import * @@ -11,15 +13,28 @@ from ._types import * from ._urls import * +if _typing.TYPE_CHECKING: + from ._websockets._api import AsyncWebSocketSession, WebSocketSession + from ._websockets._exceptions import ( + HTTPXWSException, + WebSocketDisconnect, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, + ) + from ._websockets._transport import ASGIWebSocketTransport + __all__ = [ "__description__", "__title__", "__version__", "ASGITransport", + "ASGIWebSocketTransport", "AsyncBaseTransport", "AsyncByteStream", "AsyncClient", "AsyncHTTPTransport", + "AsyncWebSocketSession", "Auth", "BaseTransport", "BasicAuth", @@ -42,6 +57,7 @@ "HTTPError", "HTTPStatusError", "HTTPTransport", + "HTTPXWSException", "InvalidURL", "Limits", "LocalProtocolError", @@ -78,20 +94,38 @@ "UnsupportedProtocol", "URL", "USE_CLIENT_DEFAULT", + "websocket", + "WebSocketDisconnect", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketSession", + "WebSocketUpgradeError", "WriteError", "WriteTimeout", "WSGITransport", ] +_WEBSOCKET_NAMES = { + "ASGIWebSocketTransport", + "AsyncWebSocketSession", + "HTTPXWSException", + "WebSocketDisconnect", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketSession", + "WebSocketUpgradeError", + "websocket", +} + __locals = locals() for __name in __all__: - if not __name.startswith("__"): + if not __name.startswith("__") and __name not in _WEBSOCKET_NAMES: setattr(__locals[__name], "__module__", "httpx2") # noqa -def __getattr__(name: str) -> object: # pragma: no cover - if name == "main": +def __getattr__(name: str) -> object: + if name == "main": # pragma: no cover import warnings warnings.warn( @@ -104,4 +138,18 @@ def __getattr__(name: str) -> object: # pragma: no cover return main + if name in _WEBSOCKET_NAMES: + from ._websockets._defaults import require_wsproto + + require_wsproto() + + if name == "websocket": + from ._api import websocket + + return websocket + + from . import _websockets + + return getattr(_websockets, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/httpx2/httpx2/_api.py b/src/httpx2/httpx2/_api.py index 25171cbc..60cfddcf 100644 --- a/src/httpx2/httpx2/_api.py +++ b/src/httpx2/httpx2/_api.py @@ -19,10 +19,18 @@ TimeoutTypes, ) from ._urls import URL +from ._websockets._defaults import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, +) if typing.TYPE_CHECKING: import ssl # pragma: no cover + from ._websockets._api import WebSocketSession + __all__ = [ "delete", @@ -34,6 +42,7 @@ "put", "request", "stream", + "websocket", ] @@ -424,3 +433,57 @@ def delete( timeout=timeout, trust_env=trust_env, ) + + +@contextmanager +def websocket( + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + follow_redirects: bool = False, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + verify: ssl.SSLContext | str | bool = True, + trust_env: bool = True, + subprotocols: list[str] | None = None, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, +) -> Generator[WebSocketSession]: + """ + Open a WebSocket session. + + The session is closed automatically when exiting the context manager. + + ```python + with httpx2.websocket("ws://localhost:8000/ws") as ws: + ws.send_text("Hello!") + message = ws.receive_text() + ``` + + **Parameters**: See `httpx2.request` and `httpx2.Client.websocket`. + """ + with Client( + cookies=cookies, + proxy=proxy, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) as client: + with client.websocket( + url, + params=params, + headers=headers, + auth=auth, + follow_redirects=follow_redirects, + subprotocols=subprotocols, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + ) as session: + yield session diff --git a/src/httpx2/httpx2/_client.py b/src/httpx2/httpx2/_client.py index 18720ee6..b7a0f00d 100644 --- a/src/httpx2/httpx2/_client.py +++ b/src/httpx2/httpx2/_client.py @@ -48,10 +48,18 @@ ) from ._urls import URL, QueryParams from ._utils import URLPattern, get_environment_proxies +from ._websockets._defaults import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, +) if typing.TYPE_CHECKING: import ssl # pragma: no cover + from ._websockets._api import AsyncWebSocketSession, WebSocketSession + __all__ = ["USE_CLIENT_DEFAULT", "AsyncClient", "Client"] # The type annotation for @classmethod and context managers here follows PEP 484 @@ -845,6 +853,48 @@ def stream( finally: response.close() + @contextmanager + def websocket( + self, + url: URL | str, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, + **kwargs: typing.Any, + ) -> Generator[WebSocketSession]: + """ + Open a WebSocket session, using this client's configuration. + + The session is closed automatically when exiting the context manager. + + ```python + with httpx2.Client() as client: + with client.websocket("ws://localhost:8000/ws") as ws: + ws.send_text("Hello!") + message = ws.receive_text() + ``` + """ + from ._websockets._defaults import require_wsproto + + require_wsproto() + + from ._websockets._api import connect_ws + + with connect_ws( + str(url), + self, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, + **kwargs, + ) as session: + yield session + def send( self, request: Request, @@ -1548,6 +1598,48 @@ async def stream( finally: await response.aclose() + @asynccontextmanager + async def websocket( + self, + url: URL | str, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: list[str] | None = None, + **kwargs: typing.Any, + ) -> AsyncGenerator[AsyncWebSocketSession]: + """ + Open a WebSocket session, using this client's configuration. + + The session is closed automatically when exiting the context manager. + + ```python + async with httpx2.AsyncClient() as client: + async with client.websocket("ws://localhost:8000/ws") as ws: + await ws.send_text("Hello!") + message = await ws.receive_text() + ``` + """ + from ._websockets._defaults import require_wsproto + + require_wsproto() + + from ._websockets._api import aconnect_ws + + async with aconnect_ws( + str(url), + self, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, + **kwargs, + ) as session: + yield session + async def send( self, request: Request, diff --git a/src/httpx2/pyproject.toml b/src/httpx2/pyproject.toml index dc194f7f..2c3c54c0 100644 --- a/src/httpx2/pyproject.toml +++ b/src/httpx2/pyproject.toml @@ -67,6 +67,9 @@ http2 = [ socks = [ "socksio==1.*", ] +ws = [ + "wsproto>=1.2", +] # TODO(Marcelo): Remove when Python 3.13 reaches EOL. zstd = [ "zstandard>=0.18.0; python_version <= '3.13'", From 29150edd36df0d0a015f6a8953442f5c11faa9f0 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 24 Jun 2026 12:57:44 +0200 Subject: [PATCH 097/108] Adapt httpx-ws tests to httpx2 and wire up test dependencies Rewrite the vendored WebSocket tests to httpx2's namespaces, point the server fixture at uvicorn's `wsproto`/`websockets-sansio` implementations (the legacy implementation is incompatible with `filterwarnings=error`), and close the mock memory streams so the trio backend doesn't trip an unraisable ResourceWarning. Add `starlette`, `websockets` and `flaky` to the dev group, the `ws` extra to the dev `httpx2[...]` install, and update `test_exported_members` to account for the lazily-exported WebSocket names. --- pyproject.toml | 5 +- tests/httpx2/test_exported_members.py | 10 +- tests/httpx2/websockets/conftest.py | 52 ++- tests/httpx2/websockets/test_api.py | 368 +++++++--------------- tests/httpx2/websockets/test_transport.py | 64 ++-- uv.lock | 113 ++++++- 6 files changed, 299 insertions(+), 313 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 08859dbc..d31db5ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,20 +14,23 @@ httpx2 = { workspace = true } [dependency-groups] dev = [ - "httpx2[brotli,cli,http2,socks,zstd]", + "httpx2[brotli,cli,http2,socks,ws,zstd]", "httpcore2[asyncio,trio,http2,socks]", # Tests "chardet==6.0.0.post1", "coverage[toml]==7.10.6", "cryptography==46.0.7", + "flaky>=3.8", "pytest>=9.0.3", "pytest-codspeed>=4.1.1", "pytest-httpbin==2.0.0", "pytest-trio==0.8.0", + "starlette>=0.49", "trio==0.31.0", "trio-typing==0.10.0", "trustme==1.2.1", "uvicorn>=0.35", + "websockets>=15", "werkzeug>=3.1.6", # Linting "mypy==1.17.1", diff --git a/tests/httpx2/test_exported_members.py b/tests/httpx2/test_exported_members.py index afa4d8e0..cd711074 100644 --- a/tests/httpx2/test_exported_members.py +++ b/tests/httpx2/test_exported_members.py @@ -3,7 +3,9 @@ def test_all_imports_are_exported() -> None: included_private_members = ["__description__", "__title__", "__version__"] - assert httpx2.__all__ == sorted( - (member for member in vars(httpx2).keys() if not member.startswith("_") or member in included_private_members), - key=str.casefold, - ) + # WebSocket members are exported lazily through `__getattr__` so they only + # appear in `vars(httpx2)` once accessed; force them in for the comparison. + lazy_members = httpx2._WEBSOCKET_NAMES + exported = {member for member in vars(httpx2) if not member.startswith("_") or member in included_private_members} + assert set(httpx2.__all__) == exported | lazy_members + assert httpx2.__all__ == sorted(httpx2.__all__, key=str.casefold) diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py index 58451e7b..3699b550 100644 --- a/tests/httpx2/websockets/conftest.py +++ b/tests/httpx2/websockets/conftest.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import contextlib import pathlib import queue import tempfile -from typing import Callable, Literal, Protocol +import time +import typing from unittest.mock import MagicMock import pytest @@ -10,50 +13,40 @@ from anyio.from_thread import start_blocking_portal from starlette.applications import Starlette from starlette.routing import WebSocketRoute +from starlette.websockets import WebSocket + +WebSocketEndpoint = typing.Callable[[WebSocket], typing.Awaitable[None]] @pytest.fixture -def on_receive_message(): +def on_receive_message() -> MagicMock: return MagicMock() -@pytest.fixture(params=("wsproto", "websockets")) -def websocket_implementation(request) -> Literal["wsproto", "websockets"]: - return request.param +@pytest.fixture(params=("wsproto", "websockets-sansio")) +def websocket_implementation(request: pytest.FixtureRequest) -> typing.Literal["wsproto", "websockets-sansio"]: + return request.param # type: ignore[no-any-return] -class ServerFactoryFixture(Protocol): - def __call__( - self, endpoint: Callable - ) -> contextlib.AbstractContextManager[str]: ... +class ServerFactoryFixture(typing.Protocol): + def __call__(self, endpoint: WebSocketEndpoint) -> contextlib.AbstractContextManager[str]: ... @pytest.fixture -def server_factory( - websocket_implementation: Literal["wsproto", "websockets"], -) -> ServerFactoryFixture: +def server_factory(websocket_implementation: typing.Literal["wsproto", "websockets-sansio"]) -> ServerFactoryFixture: @contextlib.contextmanager - def _server_factory(endpoint: Callable): - startup_queue: queue.Queue[bool] = queue.Queue() + def _server_factory(endpoint: WebSocketEndpoint) -> typing.Iterator[str]: shutdown_queue: queue.Queue[bool] = queue.Queue() def create_app() -> Starlette: - routes = [ - WebSocketRoute("/ws", endpoint=endpoint), - ] - - @contextlib.asynccontextmanager - async def lifespan(app: Starlette): - startup_queue.put(True) - yield - - return Starlette(routes=routes, lifespan=lifespan) + routes = [WebSocketRoute("/ws", endpoint=endpoint)] + return Starlette(routes=routes) - def create_server(app: Starlette, socket: str): - config = uvicorn.Config(app, uds=socket, ws=websocket_implementation) + def create_server(app: Starlette, socket: str) -> uvicorn.Server: + config = uvicorn.Config(app, uds=socket, ws=websocket_implementation, lifespan="off") return uvicorn.Server(config) - def on_server_stopped(_task): + def on_server_stopped(_task: object) -> None: shutdown_queue.put(True) with start_blocking_portal(backend="asyncio") as portal: @@ -63,7 +56,10 @@ def on_server_stopped(_task): server = create_server(app, socket) task = portal.start_task_soon(server.serve) task.add_done_callback(on_server_stopped) - startup_queue.get(True) + while not server.started and not task.done(): + time.sleep(0.01) + if task.done() and task.exception() is not None: # pragma: no cover + raise typing.cast(BaseException, task.exception()) yield socket server.should_exit = True shutdown_queue.get(True) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index a9472e79..71b5659e 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -1,73 +1,65 @@ +from __future__ import annotations + import contextlib import queue import threading import time -import typing from unittest.mock import MagicMock, call, patch import anyio -import httpcore -import httpx import pytest import wsproto -from httpcore import AsyncNetworkStream, NetworkStream -from starlette.websockets import WebSocket -from starlette.websockets import WebSocketDisconnect as StarletteWebSocketDisconnect +from starlette.websockets import WebSocket, WebSocketDisconnect as StarletteWebSocketDisconnect -from httpx_ws import ( +import httpcore2 as httpcore +import httpx2 as httpx +from httpcore2 import AsyncNetworkStream, NetworkStream +from httpx2._websockets._api import ( AsyncWebSocketSession, JSONMode, + WebSocketSession, + aconnect_ws, + connect_ws, +) +from httpx2._websockets._exceptions import ( WebSocketDisconnect, WebSocketInvalidTypeReceived, WebSocketNetworkError, - WebSocketSession, WebSocketUpgradeError, - aconnect_ws, - connect_ws, ) -from tests.conftest import ServerFactoryFixture +from tests.httpx2.websockets.conftest import ServerFactoryFixture @pytest.mark.anyio -async def test_upgrade_error(): - def handler(request): +async def test_upgrade_error() -> None: + def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(400) - with httpx.Client( - base_url="http://localhost:8000", transport=httpx.MockTransport(handler) - ) as client: + with httpx.Client(base_url="http://localhost:8000", transport=httpx.MockTransport(handler)) as client: with pytest.raises(WebSocketUpgradeError): with connect_ws("http://socket/ws", client): pass - async with httpx.AsyncClient( - base_url="http://localhost:8000", transport=httpx.MockTransport(handler) - ) as client: + async with httpx.AsyncClient(base_url="http://localhost:8000", transport=httpx.MockTransport(handler)) as aclient: with pytest.raises(WebSocketUpgradeError): - async with aconnect_ws("http://socket/ws", client): + async with aconnect_ws("http://socket/ws", aclient): pass @pytest.mark.anyio class TestSend: - async def test_send_error(self): + async def test_send_error(self) -> None: class MockNetworkStream(NetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection( - wsproto.connection.ConnectionType.SERVER - ) + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self._should_close = False - def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: while not self._should_close: time.sleep(0.1) raise httpcore.ReadError() - def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + def write(self, buffer: bytes, timeout: float | None = None) -> None: raise httpcore.WriteError() def close(self) -> None: @@ -78,24 +70,18 @@ def close(self) -> None: with WebSocketSession(stream) as websocket_session: websocket_session.send(wsproto.events.Ping()) - async def test_async_send_error(self): + async def test_async_send_error(self) -> None: class AsyncMockNetworkStream(AsyncNetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection( - wsproto.connection.ConnectionType.SERVER - ) + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self._should_close = False - async def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: while not self._should_close: await anyio.sleep(0.1) raise httpcore.ReadError() - async def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + async def write(self, buffer: bytes, timeout: float | None = None) -> None: raise httpcore.WriteError() async def aclose(self) -> None: @@ -110,8 +96,8 @@ async def test_send( self, server_factory: ServerFactoryFixture, on_receive_message: MagicMock, - ): - async def websocket_endpoint(websocket: WebSocket): + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() message = await websocket.receive_text() @@ -127,27 +113,21 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: - await aws.send( - wsproto.events.TextMessage(data="CLIENT_MESSAGE") - ) + await aws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) except WebSocketDisconnect: pass - on_receive_message.assert_has_calls( - [call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")] - ) + on_receive_message.assert_has_calls([call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")]) async def test_send_text( self, server_factory: ServerFactoryFixture, on_receive_message: MagicMock, - ): - async def websocket_endpoint(websocket: WebSocket): + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() message = await websocket.receive_text() @@ -163,25 +143,21 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send_text("CLIENT_MESSAGE") except WebSocketDisconnect: pass - on_receive_message.assert_has_calls( - [call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")] - ) + on_receive_message.assert_has_calls([call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")]) async def test_send_bytes( self, server_factory: ServerFactoryFixture, on_receive_message: MagicMock, - ): - async def websocket_endpoint(websocket: WebSocket): + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() message = await websocket.receive_bytes() @@ -197,18 +173,14 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send_bytes(b"CLIENT_MESSAGE") except WebSocketDisconnect: pass - on_receive_message.assert_has_calls( - [call(b"CLIENT_MESSAGE"), call(b"CLIENT_MESSAGE")] - ) + on_receive_message.assert_has_calls([call(b"CLIENT_MESSAGE"), call(b"CLIENT_MESSAGE")]) @pytest.mark.parametrize("mode", ["text", "binary"]) async def test_send_json( @@ -216,8 +188,8 @@ async def test_send_json( mode: JSONMode, server_factory: ServerFactoryFixture, on_receive_message: MagicMock, - ): - async def websocket_endpoint(websocket: WebSocket): + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() message = await websocket.receive_json(mode=mode) @@ -233,37 +205,27 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) except WebSocketDisconnect: pass - on_receive_message.assert_has_calls( - [call({"message": "CLIENT_MESSAGE"}), call({"message": "CLIENT_MESSAGE"})] - ) + on_receive_message.assert_has_calls([call({"message": "CLIENT_MESSAGE"}), call({"message": "CLIENT_MESSAGE"})]) @pytest.mark.anyio class TestReceive: - async def test_receive_error(self): + async def test_receive_error(self) -> None: class MockNetworkStream(NetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection( - wsproto.connection.ConnectionType.SERVER - ) + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) - def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: raise httpcore.ReadError() - def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + def write(self, buffer: bytes, timeout: float | None = None) -> None: pass def close(self) -> None: @@ -274,21 +236,15 @@ def close(self) -> None: with WebSocketSession(stream) as websocket_session: websocket_session.receive() - async def test_async_receive_error(self): + async def test_async_receive_error(self) -> None: class AsyncMockNetworkStream(AsyncNetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection( - wsproto.connection.ConnectionType.SERVER - ) + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) - async def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: raise httpcore.ReadError() - async def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + async def write(self, buffer: bytes, timeout: float | None = None) -> None: pass async def aclose(self) -> None: @@ -299,8 +255,8 @@ async def aclose(self) -> None: async with AsyncWebSocketSession(stream) as websocket_session: await websocket_session.receive() - async def test_receive(self, server_factory: ServerFactoryFixture): - async def websocket_endpoint(websocket: WebSocket): + async def test_receive(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_text("SERVER_MESSAGE") @@ -317,9 +273,7 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: event = await aws.receive() @@ -337,11 +291,11 @@ async def websocket_endpoint(websocket: WebSocket): ) async def test_receive_oversized_message( self, - full_message: typing.Union[str, bytes], + full_message: str | bytes, send_method: str, server_factory: ServerFactoryFixture, - ): - async def websocket_endpoint(websocket: WebSocket): + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() method = getattr(websocket, send_method) @@ -352,18 +306,14 @@ async def websocket_endpoint(websocket: WebSocket): with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: try: - with connect_ws( - "http://socket/ws", client, max_message_size_bytes=1024 - ) as ws: + with connect_ws("http://socket/ws", client, max_message_size_bytes=1024) as ws: event = ws.receive() assert isinstance(event, wsproto.events.Message) assert event.data == full_message except WebSocketDisconnect: pass - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws( "http://socket/ws", @@ -376,8 +326,8 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async def test_receive_text(self, server_factory: ServerFactoryFixture): - async def websocket_endpoint(websocket: WebSocket): + async def test_receive_text(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_text("SERVER_MESSAGE") @@ -393,9 +343,7 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: data = await aws.receive_text() @@ -403,10 +351,8 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async def test_receive_text_invalid_type( - self, server_factory: ServerFactoryFixture - ): - async def websocket_endpoint(websocket: WebSocket): + async def test_receive_text_invalid_type(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_bytes(b"SERVER_MESSAGE") @@ -422,9 +368,7 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: with pytest.raises(WebSocketInvalidTypeReceived): @@ -432,8 +376,8 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async def test_receive_bytes(self, server_factory: ServerFactoryFixture): - async def websocket_endpoint(websocket: WebSocket): + async def test_receive_bytes(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_bytes(b"SERVER_MESSAGE") @@ -449,9 +393,7 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: data = await aws.receive_bytes() @@ -459,10 +401,8 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async def test_receive_bytes_invalid_type( - self, server_factory: ServerFactoryFixture - ): - async def websocket_endpoint(websocket: WebSocket): + async def test_receive_bytes_invalid_type(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_text("SERVER_MESSAGE") @@ -475,18 +415,14 @@ async def websocket_endpoint(websocket: WebSocket): with pytest.raises(WebSocketInvalidTypeReceived): ws.receive_bytes() - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: async with aconnect_ws("http://socket/ws", aclient) as aws: with pytest.raises(WebSocketInvalidTypeReceived): await aws.receive_bytes() @pytest.mark.parametrize("mode", ["text", "binary"]) - async def test_receive_json( - self, mode: JSONMode, server_factory: ServerFactoryFixture - ): - async def websocket_endpoint(websocket: WebSocket): + async def test_receive_json(self, mode: JSONMode, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_json({"message": "SERVER_MESSAGE"}, mode=mode) @@ -502,9 +438,7 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: data = await aws.receive_json(mode=mode) @@ -515,29 +449,23 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.anyio class TestReceivePing: - async def test_receive_ping(self): + async def test_receive_ping(self) -> None: class MockNetworkStream(NetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection( - wsproto.connection.ConnectionType.SERVER - ) + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self.events_to_send = [ wsproto.events.Ping(b"SERVER_PING"), wsproto.events.CloseConnection(1000), ] - def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: try: event = self.events_to_send.pop(0) return self.connection.send(event) except IndexError: raise httpcore.ReadError() - def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + def write(self, buffer: bytes, timeout: float | None = None) -> None: self.connection.receive_data(buffer) def close(self) -> None: @@ -553,29 +481,23 @@ def close(self) -> None: wsproto.events.CloseConnection(1000, ""), ] - async def test_async_receive_ping(self): + async def test_async_receive_ping(self) -> None: class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection( - wsproto.connection.ConnectionType.SERVER - ) + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self.events_to_send = [ wsproto.events.Ping(b"SERVER_PING"), wsproto.events.CloseConnection(1000), ] - async def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: try: event = self.events_to_send.pop(0) return self.connection.send(event) except IndexError: raise httpcore.ReadError() - async def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + async def write(self, buffer: bytes, timeout: float | None = None) -> None: self.connection.receive_data(buffer) async def aclose(self) -> None: @@ -594,20 +516,16 @@ async def aclose(self) -> None: @pytest.mark.anyio class TestKeepalivePing: - async def test_keepalive_ping(self): + async def test_keepalive_ping(self) -> None: class MockNetworkStream(NetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection( - wsproto.connection.ConnectionType.SERVER - ) + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self._should_close = False self.ping_received = 0 self.ping_answered = 0 self.events_to_send: queue.Queue[wsproto.events.Event] = queue.Queue() - def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: while not self._should_close: try: event = self.events_to_send.get_nowait() @@ -617,9 +535,7 @@ def read( pass raise httpcore.ReadError() - def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + def write(self, buffer: bytes, timeout: float | None = None) -> None: self.connection.receive_data(buffer) for event in self.connection.events(): if isinstance(event, wsproto.events.Ping): @@ -640,24 +556,18 @@ def close(self) -> None: assert stream.ping_received >= 1 assert stream.ping_answered >= 1 - async def test_keepalive_ping_timeout(self): + async def test_keepalive_ping_timeout(self) -> None: class MockNetworkStream(NetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection( - wsproto.connection.ConnectionType.SERVER - ) + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self._should_close = False - def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: while not self._should_close: time.sleep(0.1) raise httpcore.ReadError() - def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + def write(self, buffer: bytes, timeout: float | None = None) -> None: pass def close(self) -> None: @@ -673,12 +583,10 @@ def close(self) -> None: websocket_session.receive() @pytest.mark.flaky(max_runs=5, min_passes=1) - async def test_async_keepalive_ping(self): + async def test_async_keepalive_ping(self) -> None: class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection( - wsproto.connection.ConnectionType.SERVER - ) + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self._should_close = False self.ping_received = 0 self.ping_answered = 0 @@ -687,9 +595,7 @@ def __init__(self) -> None: self.receive_events, ) = anyio.create_memory_object_stream[wsproto.events.Event]() - async def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: while not self._should_close: try: event = self.receive_events.receive_nowait() @@ -699,9 +605,7 @@ async def read( await anyio.sleep(0.1) raise httpcore.ReadError() - async def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + async def write(self, buffer: bytes, timeout: float | None = None) -> None: self.connection.receive_data(buffer) for event in self.connection.events(): if isinstance(event, wsproto.events.Ping): @@ -710,6 +614,8 @@ async def write( async def aclose(self) -> None: self._should_close = True + await self.send_events.aclose() + await self.receive_events.aclose() stream = MockAsyncNetworkStream() async with AsyncWebSocketSession( @@ -722,24 +628,18 @@ async def aclose(self) -> None: assert stream.ping_received >= 1 assert stream.ping_answered >= 1 - async def test_async_keepalive_ping_timeout(self): + async def test_async_keepalive_ping_timeout(self) -> None: class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: - self.connection = wsproto.connection.Connection( - wsproto.connection.ConnectionType.SERVER - ) + self.connection = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) self._should_close = False - async def read( - self, max_bytes: int, timeout: typing.Optional[float] = None - ) -> bytes: + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: while not self._should_close: await anyio.sleep(0.1) raise httpcore.ReadError() - async def write( - self, buffer: bytes, timeout: typing.Optional[float] = None - ) -> None: + async def write(self, buffer: bytes, timeout: float | None = None) -> None: pass async def aclose(self) -> None: @@ -756,8 +656,8 @@ async def aclose(self) -> None: @pytest.mark.anyio -async def test_ping_pong(server_factory: ServerFactoryFixture): - async def websocket_endpoint(websocket: WebSocket): +async def test_ping_pong(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() try: await websocket.receive_text() @@ -771,9 +671,7 @@ async def websocket_endpoint(websocket: WebSocket): result = ping_callback.wait() assert result is True - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: async with aconnect_ws("http://socket/ws", aclient) as aws: aping_callback = await aws.ping() await aping_callback.wait() @@ -781,10 +679,8 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.anyio -async def test_send_close( - server_factory: ServerFactoryFixture, on_receive_message: MagicMock -): - async def websocket_endpoint(websocket: WebSocket): +async def test_send_close(server_factory: ServerFactoryFixture, on_receive_message: MagicMock) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() try: await websocket.receive_text() @@ -797,20 +693,16 @@ async def websocket_endpoint(websocket: WebSocket): with connect_ws("http://socket/ws", client) as ws: ws.close(code=1001, reason="CLOSE_REASON") - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.close(code=1001, reason="CLOSE_REASON") - on_receive_message.assert_has_calls( - [call(1001, "CLOSE_REASON"), call(1001, "CLOSE_REASON")] - ) + on_receive_message.assert_has_calls([call(1001, "CLOSE_REASON"), call(1001, "CLOSE_REASON")]) @pytest.mark.anyio -async def test_receive_close(server_factory: ServerFactoryFixture): - async def websocket_endpoint(websocket: WebSocket): +async def test_receive_close(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.close() @@ -820,20 +712,16 @@ async def websocket_endpoint(websocket: WebSocket): with pytest.raises(WebSocketDisconnect): ws.receive() - async with httpx.AsyncClient( - transport=httpx.AsyncHTTPTransport(uds=socket) - ) as aclient: + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: async with aconnect_ws("http://socket/ws", aclient) as aws: with pytest.raises(WebSocketDisconnect): await aws.receive() @pytest.mark.anyio -async def test_default_httpx_client(): +async def test_default_httpx_client() -> None: mock_context = contextlib.ExitStack() - with patch( - "httpx_ws._api._connect_ws", return_value=mock_context - ) as mock_connect_ws: + with patch("httpx2._websockets._api._connect_ws", return_value=mock_context) as mock_connect_ws: with connect_ws("http://socket/ws"): pass mock_connect_ws.assert_called_once() @@ -842,9 +730,7 @@ async def test_default_httpx_client(): assert httpx_client.is_closed mock_async_context = contextlib.AsyncExitStack() - with patch( - "httpx_ws._api._aconnect_ws", return_value=mock_async_context - ) as mock_aconnect_ws: + with patch("httpx2._websockets._api._aconnect_ws", return_value=mock_async_context) as mock_aconnect_ws: async with aconnect_ws("http://socket/ws"): pass mock_aconnect_ws.assert_called_once() @@ -854,12 +740,9 @@ async def test_default_httpx_client(): @pytest.mark.anyio -async def test_subprotocol_and_response(): - def handler(request): - assert ( - request.headers["sec-websocket-protocol"] - == "custom_protocol, unsupported_protocol" - ) +async def test_subprotocol_and_response() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.headers["sec-websocket-protocol"] == "custom_protocol, unsupported_protocol" return httpx.Response( 101, @@ -867,11 +750,8 @@ def handler(request): extensions={"network_stream": MagicMock(spec=NetworkStream)}, ) - def async_handler(request): - assert ( - request.headers["sec-websocket-protocol"] - == "custom_protocol, unsupported_protocol" - ) + def async_handler(request: httpx.Request) -> httpx.Response: + assert request.headers["sec-websocket-protocol"] == "custom_protocol, unsupported_protocol" return httpx.Response( 101, @@ -879,9 +759,7 @@ def async_handler(request): extensions={"network_stream": MagicMock(spec=AsyncNetworkStream)}, ) - with httpx.Client( - base_url="http://localhost:8000", transport=httpx.MockTransport(handler) - ) as client: + with httpx.Client(base_url="http://localhost:8000", transport=httpx.MockTransport(handler)) as client: with connect_ws( "http://socket/ws", client, @@ -893,10 +771,10 @@ def async_handler(request): async with httpx.AsyncClient( base_url="http://localhost:8000", transport=httpx.MockTransport(async_handler) - ) as client: + ) as aclient: async with aconnect_ws( "http://socket/ws", - client, + aclient, subprotocols=["custom_protocol", "unsupported_protocol"], ) as aws: assert isinstance(aws.response, httpx.Response) @@ -922,9 +800,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: initial_threads_count = threading.active_count() - with connect_ws( - "http://socket/ws", client, keepalive_ping_interval_seconds=None - ) as ws: + with connect_ws("http://socket/ws", client, keepalive_ping_interval_seconds=None) as ws: for _ in range(50): ws.receive() ws.send_text("CLIENT_MESSAGE") diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index 3e0c8ca2..736f70b2 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -1,20 +1,26 @@ +from __future__ import annotations + import base64 import secrets from typing import Any -import httpx import pytest import wsproto from starlette.applications import Starlette +from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route, WebSocketRoute from starlette.websockets import WebSocket -from httpx_ws import WebSocketDisconnect, aconnect_ws -from httpx_ws.transport import ( +import httpx2 as httpx +from httpx2._websockets._api import aconnect_ws +from httpx2._websockets._exceptions import WebSocketDisconnect +from httpx2._websockets._transport import ( ASGIWebSocketAsyncNetworkStream, ASGIWebSocketTransport, + Receive, Scope, + Send, UnhandledASGIMessageType, UnhandledWebSocketEvent, ) @@ -49,10 +55,10 @@ def scope(websocket_request_headers: dict[str, str]) -> Scope: @pytest.mark.anyio class TestASGIWebSocketAsyncNetworkStream: - async def test_write(self, scope: Scope): + async def test_write(self, scope: Scope) -> None: received_messages = [] - async def app(scope, receive, send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.accept"}) message = await receive() received_messages.append(message) @@ -78,8 +84,8 @@ async def app(scope, receive, send): {"type": "websocket.close", "code": 1000, "reason": ""}, ] - async def test_write_unhandled_event(self, scope: Scope): - async def app(scope, receive, send): + async def test_write_unhandled_event(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.accept"}) await receive() @@ -89,8 +95,8 @@ async def app(scope, receive, send): ping_event = wsproto.events.Ping(b"PING") await stream.write(connection.send(ping_event)) - async def test_read(self, scope): - async def app(scope, receive, send): + async def test_read(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.accept"}) await send({"type": "websocket.send", "text": "SERVER_MESSAGE"}) await send({"type": "websocket.send", "bytes": b"SERVER_MESSAGE"}) @@ -110,8 +116,8 @@ async def app(scope, receive, send): wsproto.events.CloseConnection(1000, ""), ] - async def test_read_unhandled_asgi_message(self, scope): - async def app(scope, receive, send): + async def test_read_unhandled_asgi_message(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.accept"}) await send({"type": "websocket.foo"}) @@ -119,16 +125,16 @@ async def app(scope, receive, send): with pytest.raises(UnhandledASGIMessageType): await stream.read(4096) - async def test_close_immediately(self, scope): - async def app(scope, receive, send): + async def test_close_immediately(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.close", "code": 1000, "reason": ""}) with pytest.raises(WebSocketDisconnect): async with ASGIWebSocketAsyncNetworkStream(app, scope): pass - async def test_exception(self, scope): - async def app(scope, receive, send): + async def test_exception(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: raise Exception("Error") with pytest.raises(WebSocketDisconnect) as excinfo: @@ -140,10 +146,10 @@ async def app(scope, receive, send): @pytest.fixture def test_app() -> Starlette: - async def http_endpoint(request): + async def http_endpoint(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, world!") - async def websocket_endpoint(websocket: WebSocket): + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.receive_text() await websocket.close() @@ -158,7 +164,7 @@ async def websocket_endpoint(websocket: WebSocket): @pytest.mark.anyio class TestASGIWebSocketTransport: - async def test_http(self, test_app: Starlette): + async def test_http(self, test_app: Starlette) -> None: async with ASGIWebSocketTransport(app=test_app) as transport: request = httpx.Request("GET", "http://localhost:8000/http") response = await transport.handle_async_request(request) @@ -178,22 +184,18 @@ async def test_websocket( headers: dict[str, Any], test_app: Starlette, websocket_request_headers: dict[str, str], - ): + ) -> None: async with ASGIWebSocketTransport(app=test_app) as transport: - request = httpx.Request( - "GET", url, headers={**websocket_request_headers, **headers} - ) + request = httpx.Request("GET", url, headers={**websocket_request_headers, **headers}) response = await transport.handle_async_request(request) assert response.status_code == 101 - assert isinstance( - response.extensions["network_stream"], ASGIWebSocketAsyncNetworkStream - ) + assert isinstance(response.extensions["network_stream"], ASGIWebSocketAsyncNetworkStream) @pytest.mark.anyio -async def test_subprotocol_support(): - async def websocket_endpoint(websocket: WebSocket): +async def test_subprotocol_support() -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept("custom_protocol") assert websocket.scope.get("subprotocols") == ["custom_protocol"] await websocket.send_text("SERVER_MESSAGE") @@ -206,16 +208,14 @@ async def websocket_endpoint(websocket: WebSocket): ) async with httpx.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: - async with aconnect_ws( - "ws://localhost:8000/ws", client, subprotocols=["custom_protocol"] - ) as ws: + async with aconnect_ws("ws://localhost:8000/ws", client, subprotocols=["custom_protocol"]) as ws: await ws.receive_text() assert ws.subprotocol == "custom_protocol" @pytest.mark.anyio -async def test_keepalive_ping_disabled(): - async def websocket_endpoint(websocket: WebSocket): +async def test_keepalive_ping_disabled() -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.receive_text() await websocket.close() diff --git a/uv.lock b/uv.lock index ef708e99..de30c2c4 100644 --- a/uv.lock +++ b/uv.lock @@ -33,19 +33,22 @@ dev = [ { name = "chardet", specifier = "==6.0.0.post1" }, { name = "coverage", extras = ["toml"], specifier = "==7.10.6" }, { name = "cryptography", specifier = "==46.0.7" }, + { name = "flaky", specifier = ">=3.8" }, { name = "httpcore2", extras = ["asyncio", "trio", "http2", "socks"], editable = "src/httpcore2" }, - { name = "httpx2", extras = ["brotli", "cli", "http2", "socks", "zstd"], editable = "src/httpx2" }, + { name = "httpx2", extras = ["brotli", "cli", "http2", "socks", "ws", "zstd"], editable = "src/httpx2" }, { name = "mypy", specifier = "==1.17.1" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-codspeed", specifier = ">=4.1.1" }, { name = "pytest-httpbin", specifier = "==2.0.0" }, { name = "pytest-trio", specifier = "==0.8.0" }, { name = "ruff", specifier = "==0.15.13" }, + { name = "starlette", specifier = ">=0.49" }, { name = "trio", specifier = "==0.31.0" }, { name = "trio-typing", specifier = "==0.10.0" }, { name = "trustme", specifier = "==1.2.1" }, { name = "twine", specifier = "==6.1.0" }, { name = "uvicorn", specifier = ">=0.35" }, + { name = "websockets", specifier = ">=15" }, { name = "werkzeug", specifier = ">=3.1.6" }, ] docs = [ @@ -934,6 +937,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, ] +[[package]] +name = "flaky" +version = "3.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5b/c5/ef69119a01427204ff2db5fc8f98001087bcce719bbb94749dcd7b191365/flaky-3.8.1.tar.gz", hash = "sha256:47204a81ec905f3d5acfbd61daeabcada8f9d4031616d9bcb0618461729699f5", size = 25248, upload-time = "2024-03-12T22:17:59.265Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/b8/b830fc43663246c3f3dd1ae7dca4847b96ed992537e85311e27fa41ac40e/flaky-3.8.1-py2.py3-none-any.whl", hash = "sha256:194ccf4f0d3a22b2de7130f4b62e45e977ac1b5ccad74d4d48f3005dcc38815e", size = 19139, upload-time = "2024-03-12T22:17:51.59Z" }, +] + [[package]] name = "flasgger" version = "0.9.7.1" @@ -1372,6 +1384,9 @@ http2 = [ socks = [ { name = "socksio" }, ] +ws = [ + { name = "wsproto" }, +] zstd = [ { name = "zstandard", marker = "python_full_version < '3.14'" }, ] @@ -1390,9 +1405,10 @@ requires-dist = [ { name = "socksio", marker = "extra == 'socks'", specifier = "==1.*" }, { name = "truststore", specifier = ">=0.10" }, { name = "typing-extensions", marker = "python_full_version < '3.13'", specifier = ">=4.5.0" }, + { name = "wsproto", marker = "extra == 'ws'", specifier = ">=1.2" }, { name = "zstandard", marker = "python_full_version < '3.14' and extra == 'zstd'", specifier = ">=0.18.0" }, ] -provides-extras = ["brotli", "cli", "http2", "socks", "zstd"] +provides-extras = ["brotli", "cli", "http2", "socks", "ws", "zstd"] [[package]] name = "hyperframe" @@ -3175,6 +3191,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, ] +[[package]] +name = "starlette" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/e3/7c1dc7381d9f8ab7d854328ebfa884e62cb3f3d8549ddfd37c7814f42afa/starlette-1.3.1.tar.gz", hash = "sha256:05d0213193f2fbaae60e2ecb593b4add4262ad4e46536b54abe36f11a71724e0", size = 2703240, upload-time = "2026-06-12T09:23:11.602Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/bb/2799cc2ede3ed41131f8975621e7213dfc7ef4acbbaadfa440f32500c370/starlette-1.3.1-py3-none-any.whl", hash = "sha256:c7372aae11c3c3f26a42df7bd626cec2f47d03483d261d369516a615a53714c6", size = 73632, upload-time = "2026-06-12T09:23:10.017Z" }, +] + [[package]] name = "tomli" version = "2.4.1" @@ -3370,6 +3399,74 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "websockets" +version = "16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/74/221f58decd852f4b59cc3354cccaf87e8ef695fede361d03dc9a7396573b/websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a", size = 177343, upload-time = "2026-01-10T09:22:21.28Z" }, + { url = "https://files.pythonhosted.org/packages/19/0f/22ef6107ee52ab7f0b710d55d36f5a5d3ef19e8a205541a6d7ffa7994e5a/websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0", size = 175021, upload-time = "2026-01-10T09:22:22.696Z" }, + { url = "https://files.pythonhosted.org/packages/10/40/904a4cb30d9b61c0e278899bf36342e9b0208eb3c470324a9ecbaac2a30f/websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957", size = 175320, upload-time = "2026-01-10T09:22:23.94Z" }, + { url = "https://files.pythonhosted.org/packages/9d/2f/4b3ca7e106bc608744b1cdae041e005e446124bebb037b18799c2d356864/websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72", size = 183815, upload-time = "2026-01-10T09:22:25.469Z" }, + { url = "https://files.pythonhosted.org/packages/86/26/d40eaa2a46d4302becec8d15b0fc5e45bdde05191e7628405a19cf491ccd/websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde", size = 185054, upload-time = "2026-01-10T09:22:27.101Z" }, + { url = "https://files.pythonhosted.org/packages/b0/ba/6500a0efc94f7373ee8fefa8c271acdfd4dca8bd49a90d4be7ccabfc397e/websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3", size = 184565, upload-time = "2026-01-10T09:22:28.293Z" }, + { url = "https://files.pythonhosted.org/packages/04/b4/96bf2cee7c8d8102389374a2616200574f5f01128d1082f44102140344cc/websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3", size = 183848, upload-time = "2026-01-10T09:22:30.394Z" }, + { url = "https://files.pythonhosted.org/packages/02/8e/81f40fb00fd125357814e8c3025738fc4ffc3da4b6b4a4472a82ba304b41/websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9", size = 178249, upload-time = "2026-01-10T09:22:32.083Z" }, + { url = "https://files.pythonhosted.org/packages/b4/5f/7e40efe8df57db9b91c88a43690ac66f7b7aa73a11aa6a66b927e44f26fa/websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35", size = 178685, upload-time = "2026-01-10T09:22:33.345Z" }, + { url = "https://files.pythonhosted.org/packages/f2/db/de907251b4ff46ae804ad0409809504153b3f30984daf82a1d84a9875830/websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8", size = 177340, upload-time = "2026-01-10T09:22:34.539Z" }, + { url = "https://files.pythonhosted.org/packages/f3/fa/abe89019d8d8815c8781e90d697dec52523fb8ebe308bf11664e8de1877e/websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad", size = 175022, upload-time = "2026-01-10T09:22:36.332Z" }, + { url = "https://files.pythonhosted.org/packages/58/5d/88ea17ed1ded2079358b40d31d48abe90a73c9e5819dbcde1606e991e2ad/websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d", size = 175319, upload-time = "2026-01-10T09:22:37.602Z" }, + { url = "https://files.pythonhosted.org/packages/d2/ae/0ee92b33087a33632f37a635e11e1d99d429d3d323329675a6022312aac2/websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe", size = 184631, upload-time = "2026-01-10T09:22:38.789Z" }, + { url = "https://files.pythonhosted.org/packages/c8/c5/27178df583b6c5b31b29f526ba2da5e2f864ecc79c99dae630a85d68c304/websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b", size = 185870, upload-time = "2026-01-10T09:22:39.893Z" }, + { url = "https://files.pythonhosted.org/packages/87/05/536652aa84ddc1c018dbb7e2c4cbcd0db884580bf8e95aece7593fde526f/websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5", size = 185361, upload-time = "2026-01-10T09:22:41.016Z" }, + { url = "https://files.pythonhosted.org/packages/6d/e2/d5332c90da12b1e01f06fb1b85c50cfc489783076547415bf9f0a659ec19/websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64", size = 184615, upload-time = "2026-01-10T09:22:42.442Z" }, + { url = "https://files.pythonhosted.org/packages/77/fb/d3f9576691cae9253b51555f841bc6600bf0a983a461c79500ace5a5b364/websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6", size = 178246, upload-time = "2026-01-10T09:22:43.654Z" }, + { url = "https://files.pythonhosted.org/packages/54/67/eaff76b3dbaf18dcddabc3b8c1dba50b483761cccff67793897945b37408/websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac", size = 178684, upload-time = "2026-01-10T09:22:44.941Z" }, + { url = "https://files.pythonhosted.org/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00", size = 177365, upload-time = "2026-01-10T09:22:46.787Z" }, + { url = "https://files.pythonhosted.org/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79", size = 175038, upload-time = "2026-01-10T09:22:47.999Z" }, + { url = "https://files.pythonhosted.org/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39", size = 175328, upload-time = "2026-01-10T09:22:49.809Z" }, + { url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" }, + { url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" }, + { url = "https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" }, + { url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" }, + { url = "https://files.pythonhosted.org/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89", size = 178261, upload-time = "2026-01-10T09:22:56.251Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea", size = 178693, upload-time = "2026-01-10T09:22:57.478Z" }, + { url = "https://files.pythonhosted.org/packages/cc/9c/baa8456050d1c1b08dd0ec7346026668cbc6f145ab4e314d707bb845bf0d/websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9", size = 177364, upload-time = "2026-01-10T09:22:59.333Z" }, + { url = "https://files.pythonhosted.org/packages/7e/0c/8811fc53e9bcff68fe7de2bcbe75116a8d959ac699a3200f4847a8925210/websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230", size = 175039, upload-time = "2026-01-10T09:23:01.171Z" }, + { url = "https://files.pythonhosted.org/packages/aa/82/39a5f910cb99ec0b59e482971238c845af9220d3ab9fa76dd9162cda9d62/websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c", size = 175323, upload-time = "2026-01-10T09:23:02.341Z" }, + { url = "https://files.pythonhosted.org/packages/bd/28/0a25ee5342eb5d5f297d992a77e56892ecb65e7854c7898fb7d35e9b33bd/websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5", size = 184975, upload-time = "2026-01-10T09:23:03.756Z" }, + { url = "https://files.pythonhosted.org/packages/f9/66/27ea52741752f5107c2e41fda05e8395a682a1e11c4e592a809a90c6a506/websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82", size = 186203, upload-time = "2026-01-10T09:23:05.01Z" }, + { url = "https://files.pythonhosted.org/packages/37/e5/8e32857371406a757816a2b471939d51c463509be73fa538216ea52b792a/websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8", size = 185653, upload-time = "2026-01-10T09:23:06.301Z" }, + { url = "https://files.pythonhosted.org/packages/9b/67/f926bac29882894669368dc73f4da900fcdf47955d0a0185d60103df5737/websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f", size = 184920, upload-time = "2026-01-10T09:23:07.492Z" }, + { url = "https://files.pythonhosted.org/packages/3c/a1/3d6ccdcd125b0a42a311bcd15a7f705d688f73b2a22d8cf1c0875d35d34a/websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a", size = 178255, upload-time = "2026-01-10T09:23:09.245Z" }, + { url = "https://files.pythonhosted.org/packages/6b/ae/90366304d7c2ce80f9b826096a9e9048b4bb760e44d3b873bb272cba696b/websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156", size = 178689, upload-time = "2026-01-10T09:23:10.483Z" }, + { url = "https://files.pythonhosted.org/packages/f3/1d/e88022630271f5bd349ed82417136281931e558d628dd52c4d8621b4a0b2/websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0", size = 177406, upload-time = "2026-01-10T09:23:12.178Z" }, + { url = "https://files.pythonhosted.org/packages/f2/78/e63be1bf0724eeb4616efb1ae1c9044f7c3953b7957799abb5915bffd38e/websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904", size = 175085, upload-time = "2026-01-10T09:23:13.511Z" }, + { url = "https://files.pythonhosted.org/packages/bb/f4/d3c9220d818ee955ae390cf319a7c7a467beceb24f05ee7aaaa2414345ba/websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4", size = 175328, upload-time = "2026-01-10T09:23:14.727Z" }, + { url = "https://files.pythonhosted.org/packages/63/bc/d3e208028de777087e6fb2b122051a6ff7bbcca0d6df9d9c2bf1dd869ae9/websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e", size = 185044, upload-time = "2026-01-10T09:23:15.939Z" }, + { url = "https://files.pythonhosted.org/packages/ad/6e/9a0927ac24bd33a0a9af834d89e0abc7cfd8e13bed17a86407a66773cc0e/websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4", size = 186279, upload-time = "2026-01-10T09:23:17.148Z" }, + { url = "https://files.pythonhosted.org/packages/b9/ca/bf1c68440d7a868180e11be653c85959502efd3a709323230314fda6e0b3/websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1", size = 185711, upload-time = "2026-01-10T09:23:18.372Z" }, + { url = "https://files.pythonhosted.org/packages/c4/f8/fdc34643a989561f217bb477cbc47a3a07212cbda91c0e4389c43c296ebf/websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3", size = 184982, upload-time = "2026-01-10T09:23:19.652Z" }, + { url = "https://files.pythonhosted.org/packages/dd/d1/574fa27e233764dbac9c52730d63fcf2823b16f0856b3329fc6268d6ae4f/websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8", size = 177915, upload-time = "2026-01-10T09:23:21.458Z" }, + { url = "https://files.pythonhosted.org/packages/8a/f1/ae6b937bf3126b5134ce1f482365fde31a357c784ac51852978768b5eff4/websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d", size = 178381, upload-time = "2026-01-10T09:23:22.715Z" }, + { url = "https://files.pythonhosted.org/packages/06/9b/f791d1db48403e1f0a27577a6beb37afae94254a8c6f08be4a23e4930bc0/websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244", size = 177737, upload-time = "2026-01-10T09:23:24.523Z" }, + { url = "https://files.pythonhosted.org/packages/bd/40/53ad02341fa33b3ce489023f635367a4ac98b73570102ad2cdd770dacc9a/websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e", size = 175268, upload-time = "2026-01-10T09:23:25.781Z" }, + { url = "https://files.pythonhosted.org/packages/74/9b/6158d4e459b984f949dcbbb0c5d270154c7618e11c01029b9bbd1bb4c4f9/websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641", size = 175486, upload-time = "2026-01-10T09:23:27.033Z" }, + { url = "https://files.pythonhosted.org/packages/e5/2d/7583b30208b639c8090206f95073646c2c9ffd66f44df967981a64f849ad/websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8", size = 185331, upload-time = "2026-01-10T09:23:28.259Z" }, + { url = "https://files.pythonhosted.org/packages/45/b0/cce3784eb519b7b5ad680d14b9673a31ab8dcb7aad8b64d81709d2430aa8/websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e", size = 186501, upload-time = "2026-01-10T09:23:29.449Z" }, + { url = "https://files.pythonhosted.org/packages/19/60/b8ebe4c7e89fb5f6cdf080623c9d92789a53636950f7abacfc33fe2b3135/websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944", size = 186062, upload-time = "2026-01-10T09:23:31.368Z" }, + { url = "https://files.pythonhosted.org/packages/88/a8/a080593f89b0138b6cba1b28f8df5673b5506f72879322288b031337c0b8/websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206", size = 185356, upload-time = "2026-01-10T09:23:32.627Z" }, + { url = "https://files.pythonhosted.org/packages/c2/b6/b9afed2afadddaf5ebb2afa801abf4b0868f42f8539bfe4b071b5266c9fe/websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6", size = 178085, upload-time = "2026-01-10T09:23:33.816Z" }, + { url = "https://files.pythonhosted.org/packages/9f/3e/28135a24e384493fa804216b79a6a6759a38cc4ff59118787b9fb693df93/websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd", size = 178531, upload-time = "2026-01-10T09:23:35.016Z" }, + { url = "https://files.pythonhosted.org/packages/72/07/c98a68571dcf256e74f1f816b8cc5eae6eb2d3d5cfa44d37f801619d9166/websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d", size = 174947, upload-time = "2026-01-10T09:23:36.166Z" }, + { url = "https://files.pythonhosted.org/packages/7e/52/93e166a81e0305b33fe416338be92ae863563fe7bce446b0f687b9df5aea/websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03", size = 175260, upload-time = "2026-01-10T09:23:37.409Z" }, + { url = "https://files.pythonhosted.org/packages/56/0c/2dbf513bafd24889d33de2ff0368190a0e69f37bcfa19009ef819fe4d507/websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da", size = 176071, upload-time = "2026-01-10T09:23:39.158Z" }, + { url = "https://files.pythonhosted.org/packages/a5/8f/aea9c71cc92bf9b6cc0f7f70df8f0b420636b6c96ef4feee1e16f80f75dd/websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c", size = 176968, upload-time = "2026-01-10T09:23:41.031Z" }, + { url = "https://files.pythonhosted.org/packages/9a/3f/f70e03f40ffc9a30d817eef7da1be72ee4956ba8d7255c399a01b135902a/websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767", size = 178735, upload-time = "2026-01-10T09:23:42.259Z" }, + { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, +] + [[package]] name = "werkzeug" version = "3.1.8" @@ -3382,6 +3479,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/93/8c/2e650f2afeb7ee576912636c23ddb621c91ac6a98e66dc8d29c3c69446e1/werkzeug-3.1.8-py3-none-any.whl", hash = "sha256:63a77fb8892bf28ebc3178683445222aa500e48ebad5ec77b0ad80f8726b1f50", size = 226459, upload-time = "2026-04-02T18:49:12.72Z" }, ] +[[package]] +name = "wsproto" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" }, +] + [[package]] name = "yarl" version = "1.23.0" From 6022321916b1136328bfa6fb5d62be66d4d43f01 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 24 Jun 2026 13:40:01 +0200 Subject: [PATCH 098/108] Guard WebSocket entrypoints with an inline wsproto import check Replace the `require_wsproto()` helper with a direct `try`/`import` guard in each WebSocket entrypoint (`Client.websocket`, `AsyncClient.websocket` and the top-level `__getattr__` resolution). When `wsproto` is missing the entrypoints raise a clear ImportError pointing to the `httpx2[ws]` extra, instead of surfacing a raw `ModuleNotFoundError`. --- src/httpx2/httpx2/__init__.py | 16 +++++----------- src/httpx2/httpx2/_client.py | 19 +++++++++---------- src/httpx2/httpx2/_websockets/_defaults.py | 7 ------- 3 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/httpx2/httpx2/__init__.py b/src/httpx2/httpx2/__init__.py index 5c4ec95f..df0f1057 100644 --- a/src/httpx2/httpx2/__init__.py +++ b/src/httpx2/httpx2/__init__.py @@ -115,7 +115,6 @@ "WebSocketNetworkError", "WebSocketSession", "WebSocketUpgradeError", - "websocket", } __locals = locals() @@ -139,17 +138,12 @@ def __getattr__(name: str) -> object: return main if name in _WEBSOCKET_NAMES: - from ._websockets._defaults import require_wsproto - - require_wsproto() - - if name == "websocket": - from ._api import websocket - - return websocket - from . import _websockets + from ._websockets._defaults import WS_EXTRA_INSTALL_MESSAGE - return getattr(_websockets, name) + try: + return getattr(_websockets, name) + except ImportError: + raise ImportError(WS_EXTRA_INSTALL_MESSAGE) raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/httpx2/httpx2/_client.py b/src/httpx2/httpx2/_client.py index b7a0f00d..88cee5b1 100644 --- a/src/httpx2/httpx2/_client.py +++ b/src/httpx2/httpx2/_client.py @@ -53,6 +53,7 @@ DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, DEFAULT_MAX_MESSAGE_SIZE_BYTES, DEFAULT_QUEUE_SIZE, + WS_EXTRA_INSTALL_MESSAGE, ) if typing.TYPE_CHECKING: @@ -877,11 +878,10 @@ def websocket( message = ws.receive_text() ``` """ - from ._websockets._defaults import require_wsproto - - require_wsproto() - - from ._websockets._api import connect_ws + try: + from ._websockets._api import connect_ws + except ImportError: + raise ImportError(WS_EXTRA_INSTALL_MESSAGE) with connect_ws( str(url), @@ -1622,11 +1622,10 @@ async def websocket( message = await ws.receive_text() ``` """ - from ._websockets._defaults import require_wsproto - - require_wsproto() - - from ._websockets._api import aconnect_ws + try: + from ._websockets._api import aconnect_ws + except ImportError: + raise ImportError(WS_EXTRA_INSTALL_MESSAGE) async with aconnect_ws( str(url), diff --git a/src/httpx2/httpx2/_websockets/_defaults.py b/src/httpx2/httpx2/_websockets/_defaults.py index e7c5a30b..6adc202e 100644 --- a/src/httpx2/httpx2/_websockets/_defaults.py +++ b/src/httpx2/httpx2/_websockets/_defaults.py @@ -6,10 +6,3 @@ DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 WS_EXTRA_INSTALL_MESSAGE = "WebSocket support requires the `wsproto` package. Install it with `pip install httpx2[ws]`." - - -def require_wsproto() -> None: - try: - import wsproto # noqa: F401 - except ImportError as exc: # pragma: no cover - raise ImportError(WS_EXTRA_INSTALL_MESSAGE) from exc From 3ec1f634104311b0d85570f01e25091903cbc52a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 29 Jun 2026 15:59:48 +0200 Subject: [PATCH 099/108] Credit httpx-ws and its MIT license in _websockets/__init__.py --- src/httpx2/httpx2/_websockets/LICENSE | 21 --------------------- src/httpx2/httpx2/_websockets/__init__.py | 6 ++++++ 2 files changed, 6 insertions(+), 21 deletions(-) delete mode 100644 src/httpx2/httpx2/_websockets/LICENSE diff --git a/src/httpx2/httpx2/_websockets/LICENSE b/src/httpx2/httpx2/_websockets/LICENSE deleted file mode 100644 index 3ea977c7..00000000 --- a/src/httpx2/httpx2/_websockets/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 François Voron - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index 856cd0cb..a21e1838 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -1,3 +1,9 @@ +""" +WebSocket support, derived from httpx-ws (https://github.com/frankie567/httpx-ws). + +Copyright (c) 2021 François Voron, MIT License (https://github.com/frankie567/httpx-ws/blob/main/LICENSE). +""" + from ._defaults import ( DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, From 4437d9d7941104d7eb6c6e82fb7f327707866f0e Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 29 Jun 2026 23:28:13 +0200 Subject: [PATCH 100/108] Drop redundant underscore prefix from _websockets submodules --- src/httpx2/httpx2/__init__.py | 8 ++++---- src/httpx2/httpx2/_api.py | 4 ++-- src/httpx2/httpx2/_client.py | 8 ++++---- src/httpx2/httpx2/_websockets/__init__.py | 12 ++++++------ src/httpx2/httpx2/_websockets/{_api.py => api.py} | 8 ++++---- .../httpx2/_websockets/{_defaults.py => defaults.py} | 0 .../_websockets/{_exceptions.py => exceptions.py} | 0 src/httpx2/httpx2/_websockets/{_ping.py => ping.py} | 0 .../_websockets/{_transport.py => transport.py} | 2 +- tests/httpx2/websockets/test_api.py | 8 ++++---- tests/httpx2/websockets/test_transport.py | 6 +++--- 11 files changed, 28 insertions(+), 28 deletions(-) rename src/httpx2/httpx2/_websockets/{_api.py => api.py} (99%) rename src/httpx2/httpx2/_websockets/{_defaults.py => defaults.py} (100%) rename src/httpx2/httpx2/_websockets/{_exceptions.py => exceptions.py} (100%) rename src/httpx2/httpx2/_websockets/{_ping.py => ping.py} (100%) rename src/httpx2/httpx2/_websockets/{_transport.py => transport.py} (99%) diff --git a/src/httpx2/httpx2/__init__.py b/src/httpx2/httpx2/__init__.py index 23163777..79cd6a45 100644 --- a/src/httpx2/httpx2/__init__.py +++ b/src/httpx2/httpx2/__init__.py @@ -15,15 +15,15 @@ from ._urls import * if _typing.TYPE_CHECKING: - from ._websockets._api import AsyncWebSocketSession, WebSocketSession - from ._websockets._exceptions import ( + from ._websockets.api import AsyncWebSocketSession, WebSocketSession + from ._websockets.exceptions import ( HTTPXWSException, WebSocketDisconnect, WebSocketInvalidTypeReceived, WebSocketNetworkError, WebSocketUpgradeError, ) - from ._websockets._transport import ASGIWebSocketTransport + from ._websockets.transport import ASGIWebSocketTransport __all__ = [ "__description__", @@ -143,7 +143,7 @@ def __getattr__(name: str) -> object: if name in _WEBSOCKET_NAMES: from . import _websockets - from ._websockets._defaults import WS_EXTRA_INSTALL_MESSAGE + from ._websockets.defaults import WS_EXTRA_INSTALL_MESSAGE try: return getattr(_websockets, name) diff --git a/src/httpx2/httpx2/_api.py b/src/httpx2/httpx2/_api.py index 60cfddcf..08fa5bdd 100644 --- a/src/httpx2/httpx2/_api.py +++ b/src/httpx2/httpx2/_api.py @@ -19,7 +19,7 @@ TimeoutTypes, ) from ._urls import URL -from ._websockets._defaults import ( +from ._websockets.defaults import ( DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, DEFAULT_MAX_MESSAGE_SIZE_BYTES, @@ -29,7 +29,7 @@ if typing.TYPE_CHECKING: import ssl # pragma: no cover - from ._websockets._api import WebSocketSession + from ._websockets.api import WebSocketSession __all__ = [ diff --git a/src/httpx2/httpx2/_client.py b/src/httpx2/httpx2/_client.py index 595e3ede..5eebd636 100644 --- a/src/httpx2/httpx2/_client.py +++ b/src/httpx2/httpx2/_client.py @@ -49,7 +49,7 @@ ) from ._urls import URL, QueryParams from ._utils import URLPattern, get_environment_proxies -from ._websockets._defaults import ( +from ._websockets.defaults import ( DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, DEFAULT_MAX_MESSAGE_SIZE_BYTES, @@ -60,7 +60,7 @@ if typing.TYPE_CHECKING: import ssl # pragma: no cover - from ._websockets._api import AsyncWebSocketSession, WebSocketSession + from ._websockets.api import AsyncWebSocketSession, WebSocketSession __all__ = ["USE_CLIENT_DEFAULT", "AsyncClient", "Client"] @@ -922,7 +922,7 @@ def websocket( ``` """ try: - from ._websockets._api import connect_ws + from ._websockets.api import connect_ws except ImportError: raise ImportError(WS_EXTRA_INSTALL_MESSAGE) @@ -1708,7 +1708,7 @@ async def websocket( ``` """ try: - from ._websockets._api import aconnect_ws + from ._websockets.api import aconnect_ws except ImportError: raise ImportError(WS_EXTRA_INSTALL_MESSAGE) diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index a21e1838..a6d5d5bd 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -4,7 +4,7 @@ Copyright (c) 2021 François Voron, MIT License (https://github.com/frankie567/httpx-ws/blob/main/LICENSE). """ -from ._defaults import ( +from .defaults import ( DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, DEFAULT_MAX_MESSAGE_SIZE_BYTES, @@ -47,15 +47,15 @@ def __getattr__(name: str) -> object: if name in _API_NAMES: - from . import _api + from . import api - return getattr(_api, name) + return getattr(api, name) if name in _EXCEPTION_NAMES: - from . import _exceptions + from . import exceptions - return getattr(_exceptions, name) + return getattr(exceptions, name) if name == "ASGIWebSocketTransport": - from ._transport import ASGIWebSocketTransport + from .transport import ASGIWebSocketTransport return ASGIWebSocketTransport raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/httpx2/httpx2/_websockets/_api.py b/src/httpx2/httpx2/_websockets/api.py similarity index 99% rename from src/httpx2/httpx2/_websockets/_api.py rename to src/httpx2/httpx2/_websockets/api.py index ac9f07c6..84f8f91f 100644 --- a/src/httpx2/httpx2/_websockets/_api.py +++ b/src/httpx2/httpx2/_websockets/api.py @@ -15,21 +15,21 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from wsproto.frame_protocol import CloseReason -from ._defaults import ( +from .defaults import ( DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, DEFAULT_MAX_MESSAGE_SIZE_BYTES, DEFAULT_QUEUE_SIZE, ) -from ._exceptions import ( +from .exceptions import ( HTTPXWSException, WebSocketDisconnect, WebSocketInvalidTypeReceived, WebSocketNetworkError, WebSocketUpgradeError, ) -from ._ping import AsyncPingManager, PingManager -from ._transport import ASGIWebSocketAsyncNetworkStream +from .ping import AsyncPingManager, PingManager +from .transport import ASGIWebSocketAsyncNetworkStream if typing.TYPE_CHECKING: from httpcore2 import AsyncNetworkStream, NetworkStream diff --git a/src/httpx2/httpx2/_websockets/_defaults.py b/src/httpx2/httpx2/_websockets/defaults.py similarity index 100% rename from src/httpx2/httpx2/_websockets/_defaults.py rename to src/httpx2/httpx2/_websockets/defaults.py diff --git a/src/httpx2/httpx2/_websockets/_exceptions.py b/src/httpx2/httpx2/_websockets/exceptions.py similarity index 100% rename from src/httpx2/httpx2/_websockets/_exceptions.py rename to src/httpx2/httpx2/_websockets/exceptions.py diff --git a/src/httpx2/httpx2/_websockets/_ping.py b/src/httpx2/httpx2/_websockets/ping.py similarity index 100% rename from src/httpx2/httpx2/_websockets/_ping.py rename to src/httpx2/httpx2/_websockets/ping.py diff --git a/src/httpx2/httpx2/_websockets/_transport.py b/src/httpx2/httpx2/_websockets/transport.py similarity index 99% rename from src/httpx2/httpx2/_websockets/_transport.py rename to src/httpx2/httpx2/_websockets/transport.py index 761fb52d..621bd55c 100644 --- a/src/httpx2/httpx2/_websockets/_transport.py +++ b/src/httpx2/httpx2/_websockets/transport.py @@ -12,7 +12,7 @@ from .._models import Request, Response from .._transports.asgi import ASGITransport, _ASGIApp from .._types import AsyncByteStream -from ._exceptions import WebSocketDisconnect +from .exceptions import WebSocketDisconnect Scope = dict[str, typing.Any] Message = dict[str, typing.Any] diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 71b5659e..efc66a0d 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -14,14 +14,14 @@ import httpcore2 as httpcore import httpx2 as httpx from httpcore2 import AsyncNetworkStream, NetworkStream -from httpx2._websockets._api import ( +from httpx2._websockets.api import ( AsyncWebSocketSession, JSONMode, WebSocketSession, aconnect_ws, connect_ws, ) -from httpx2._websockets._exceptions import ( +from httpx2._websockets.exceptions import ( WebSocketDisconnect, WebSocketInvalidTypeReceived, WebSocketNetworkError, @@ -721,7 +721,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: @pytest.mark.anyio async def test_default_httpx_client() -> None: mock_context = contextlib.ExitStack() - with patch("httpx2._websockets._api._connect_ws", return_value=mock_context) as mock_connect_ws: + with patch("httpx2._websockets.api._connect_ws", return_value=mock_context) as mock_connect_ws: with connect_ws("http://socket/ws"): pass mock_connect_ws.assert_called_once() @@ -730,7 +730,7 @@ async def test_default_httpx_client() -> None: assert httpx_client.is_closed mock_async_context = contextlib.AsyncExitStack() - with patch("httpx2._websockets._api._aconnect_ws", return_value=mock_async_context) as mock_aconnect_ws: + with patch("httpx2._websockets.api._aconnect_ws", return_value=mock_async_context) as mock_aconnect_ws: async with aconnect_ws("http://socket/ws"): pass mock_aconnect_ws.assert_called_once() diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index 736f70b2..259789f6 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -13,9 +13,9 @@ from starlette.websockets import WebSocket import httpx2 as httpx -from httpx2._websockets._api import aconnect_ws -from httpx2._websockets._exceptions import WebSocketDisconnect -from httpx2._websockets._transport import ( +from httpx2._websockets.api import aconnect_ws +from httpx2._websockets.exceptions import WebSocketDisconnect +from httpx2._websockets.transport import ( ASGIWebSocketAsyncNetworkStream, ASGIWebSocketTransport, Receive, From bd3a16962b5ed4a009214b9125b1bb5c782e2a17 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 30 Jun 2026 09:41:31 +0200 Subject: [PATCH 101/108] Pin the test server to ws="none" so installing websockets does not hang it The session-scoped `server` fixture serves a raw ASGI app and never needs WebSocket support, but its `uvicorn.Config` left `ws` at the default `"auto"`. Once `websockets`/`wsproto` are installed (as this branch requires), uvicorn's auto-detection deadlocks server startup in the fixture thread, so `started` is never set and `serve_in_thread` busy-waits forever. --- tests/httpx2/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/httpx2/conftest.py b/tests/httpx2/conftest.py index 156c5e3a..5db7aa68 100644 --- a/tests/httpx2/conftest.py +++ b/tests/httpx2/conftest.py @@ -274,6 +274,6 @@ def serve_in_thread(server: TestServer) -> typing.Iterator[TestServer]: @pytest.fixture(scope="session") def server(free_tcp_port_factory: typing.Callable[[], int]) -> typing.Iterator[TestServer]: - config = Config(app=app, lifespan="off", loop="asyncio", port=free_tcp_port_factory()) + config = Config(app=app, lifespan="off", loop="asyncio", ws="none", port=free_tcp_port_factory()) server = TestServer(config=config) yield from serve_in_thread(server) From edeb01f1547e6a8887de0bd82436890628d5ad0c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 30 Jun 2026 14:21:23 +0200 Subject: [PATCH 102/108] Restore 100% coverage for the WebSocket code Add high-level tests covering `httpx2.websocket()`, `Client.websocket()`, top-level lazy name access, and fragmented-message reassembly. Exclude the aliased `_typing.TYPE_CHECKING` block from coverage, and mark the defensive `wsproto`-not-installed import guards and racy test-only branches as no cover. --- pyproject.toml | 2 +- src/httpx2/httpx2/__init__.py | 2 +- src/httpx2/httpx2/_client.py | 4 +- src/httpx2/httpx2/_websockets/__init__.py | 2 +- tests/httpx2/websockets/test_api.py | 58 ++++++++-------- tests/httpx2/websockets/test_high_level.py | 81 ++++++++++++++++++++++ tests/httpx2/websockets/test_transport.py | 8 +-- 7 files changed, 119 insertions(+), 38 deletions(-) create mode 100644 tests/httpx2/websockets/test_high_level.py diff --git a/pyproject.toml b/pyproject.toml index d31db5ee..f248d8c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ omit = ["src/httpcore2/httpcore2/_sync/*", "tests/test_benchmark.py"] [tool.coverage.report] exclude_also = [ "if TYPE_CHECKING:", - "if typing.TYPE_CHECKING:", + "if (typing|_typing).TYPE_CHECKING:", "raise NotImplementedError", "@(typing\\.)?overload", ] diff --git a/src/httpx2/httpx2/__init__.py b/src/httpx2/httpx2/__init__.py index 79cd6a45..b87e3b92 100644 --- a/src/httpx2/httpx2/__init__.py +++ b/src/httpx2/httpx2/__init__.py @@ -147,7 +147,7 @@ def __getattr__(name: str) -> object: try: return getattr(_websockets, name) - except ImportError: + except ImportError: # pragma: no cover raise ImportError(WS_EXTRA_INSTALL_MESSAGE) raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/httpx2/httpx2/_client.py b/src/httpx2/httpx2/_client.py index 5eebd636..8f203b52 100644 --- a/src/httpx2/httpx2/_client.py +++ b/src/httpx2/httpx2/_client.py @@ -923,7 +923,7 @@ def websocket( """ try: from ._websockets.api import connect_ws - except ImportError: + except ImportError: # pragma: no cover raise ImportError(WS_EXTRA_INSTALL_MESSAGE) with connect_ws( @@ -1709,7 +1709,7 @@ async def websocket( """ try: from ._websockets.api import aconnect_ws - except ImportError: + except ImportError: # pragma: no cover raise ImportError(WS_EXTRA_INSTALL_MESSAGE) async with aconnect_ws( diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py index a6d5d5bd..7ff2cf31 100644 --- a/src/httpx2/httpx2/_websockets/__init__.py +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -58,4 +58,4 @@ def __getattr__(name: str) -> object: from .transport import ASGIWebSocketTransport return ASGIWebSocketTransport - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") # pragma: no cover diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index efc66a0d..36c63588 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -38,12 +38,12 @@ def handler(request: httpx.Request) -> httpx.Response: with httpx.Client(base_url="http://localhost:8000", transport=httpx.MockTransport(handler)) as client: with pytest.raises(WebSocketUpgradeError): with connect_ws("http://socket/ws", client): - pass + pass # pragma: no cover async with httpx.AsyncClient(base_url="http://localhost:8000", transport=httpx.MockTransport(handler)) as aclient: with pytest.raises(WebSocketUpgradeError): async with aconnect_ws("http://socket/ws", aclient): - pass + pass # pragma: no cover @pytest.mark.anyio @@ -77,9 +77,9 @@ def __init__(self) -> None: self._should_close = False async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - while not self._should_close: + while not self._should_close: # pragma: no cover await anyio.sleep(0.1) - raise httpcore.ReadError() + raise httpcore.ReadError() # pragma: no cover async def write(self, buffer: bytes, timeout: float | None = None) -> None: raise httpcore.WriteError() @@ -110,14 +110,14 @@ async def websocket_endpoint(websocket: WebSocket) -> None: try: with connect_ws("http://socket/ws", client) as ws: ws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send(wsproto.events.TextMessage(data="CLIENT_MESSAGE")) - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass on_receive_message.assert_has_calls([call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")]) @@ -140,14 +140,14 @@ async def websocket_endpoint(websocket: WebSocket) -> None: try: with connect_ws("http://socket/ws", client) as ws: ws.send_text("CLIENT_MESSAGE") - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send_text("CLIENT_MESSAGE") - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass on_receive_message.assert_has_calls([call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")]) @@ -170,14 +170,14 @@ async def websocket_endpoint(websocket: WebSocket) -> None: try: with connect_ws("http://socket/ws", client) as ws: ws.send_bytes(b"CLIENT_MESSAGE") - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send_bytes(b"CLIENT_MESSAGE") - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass on_receive_message.assert_has_calls([call(b"CLIENT_MESSAGE"), call(b"CLIENT_MESSAGE")]) @@ -202,14 +202,14 @@ async def websocket_endpoint(websocket: WebSocket) -> None: try: with connect_ws("http://socket/ws", client) as ws: ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: try: async with aconnect_ws("http://socket/ws", aclient) as aws: await aws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass on_receive_message.assert_has_calls([call({"message": "CLIENT_MESSAGE"}), call({"message": "CLIENT_MESSAGE"})]) @@ -270,7 +270,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: event = ws.receive() assert isinstance(event, wsproto.events.TextMessage) assert event.data == "SERVER_MESSAGE" - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: @@ -279,7 +279,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: event = await aws.receive() assert isinstance(event, wsproto.events.TextMessage) assert event.data == "SERVER_MESSAGE" - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass @pytest.mark.parametrize( @@ -310,7 +310,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: event = ws.receive() assert isinstance(event, wsproto.events.Message) assert event.data == full_message - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: @@ -323,7 +323,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: event = await aws.receive() assert isinstance(event, wsproto.events.Message) assert event.data == full_message - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async def test_receive_text(self, server_factory: ServerFactoryFixture) -> None: @@ -340,7 +340,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: with connect_ws("http://socket/ws", client) as ws: data = ws.receive_text() assert data == "SERVER_MESSAGE" - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: @@ -348,7 +348,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: async with aconnect_ws("http://socket/ws", aclient) as aws: data = await aws.receive_text() assert data == "SERVER_MESSAGE" - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async def test_receive_text_invalid_type(self, server_factory: ServerFactoryFixture) -> None: @@ -365,7 +365,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: with connect_ws("http://socket/ws", client) as ws: with pytest.raises(WebSocketInvalidTypeReceived): ws.receive_text() - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: @@ -373,7 +373,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: async with aconnect_ws("http://socket/ws", aclient) as aws: with pytest.raises(WebSocketInvalidTypeReceived): await aws.receive_text() - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async def test_receive_bytes(self, server_factory: ServerFactoryFixture) -> None: @@ -390,7 +390,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: with connect_ws("http://socket/ws", client) as ws: data = ws.receive_bytes() assert data == b"SERVER_MESSAGE" - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: @@ -398,7 +398,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: async with aconnect_ws("http://socket/ws", aclient) as aws: data = await aws.receive_bytes() assert data == b"SERVER_MESSAGE" - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async def test_receive_bytes_invalid_type(self, server_factory: ServerFactoryFixture) -> None: @@ -435,7 +435,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: with connect_ws("http://socket/ws", client) as ws: data = ws.receive_json(mode=mode) assert data == {"message": "SERVER_MESSAGE"} - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: @@ -443,7 +443,7 @@ async def websocket_endpoint(websocket: WebSocket) -> None: async with aconnect_ws("http://socket/ws", aclient) as aws: data = await aws.receive_json(mode=mode) assert data == {"message": "SERVER_MESSAGE"} - except WebSocketDisconnect: + except WebSocketDisconnect: # pragma: no cover pass @@ -462,7 +462,7 @@ def read(self, max_bytes: int, timeout: float | None = None) -> bytes: try: event = self.events_to_send.pop(0) return self.connection.send(event) - except IndexError: + except IndexError: # pragma: no cover raise httpcore.ReadError() def write(self, buffer: bytes, timeout: float | None = None) -> None: @@ -494,7 +494,7 @@ async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: try: event = self.events_to_send.pop(0) return self.connection.send(event) - except IndexError: + except IndexError: # pragma: no cover raise httpcore.ReadError() async def write(self, buffer: bytes, timeout: float | None = None) -> None: @@ -603,7 +603,7 @@ async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: return self.connection.send(event) except anyio.WouldBlock: await anyio.sleep(0.1) - raise httpcore.ReadError() + raise httpcore.ReadError() # pragma: no cover async def write(self, buffer: bytes, timeout: float | None = None) -> None: self.connection.receive_data(buffer) @@ -635,9 +635,9 @@ def __init__(self) -> None: self._should_close = False async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - while not self._should_close: + while not self._should_close: # pragma: no cover await anyio.sleep(0.1) - raise httpcore.ReadError() + raise httpcore.ReadError() # pragma: no cover async def write(self, buffer: bytes, timeout: float | None = None) -> None: pass diff --git a/tests/httpx2/websockets/test_high_level.py b/tests/httpx2/websockets/test_high_level.py new file mode 100644 index 00000000..7423963d --- /dev/null +++ b/tests/httpx2/websockets/test_high_level.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import contextlib +from unittest.mock import patch + +import anyio +import pytest +import wsproto +from starlette.websockets import WebSocket + +import httpx2 as httpx +from httpcore2 import AsyncNetworkStream +from httpx2._websockets.api import AsyncWebSocketSession, WebSocketSession +from httpx2._websockets.transport import ASGIWebSocketTransport +from tests.httpx2.websockets.conftest import ServerFactoryFixture + + +@pytest.mark.parametrize( + ("name", "expected"), + [ + ("WebSocketSession", WebSocketSession), + ("AsyncWebSocketSession", AsyncWebSocketSession), + ("ASGIWebSocketTransport", ASGIWebSocketTransport), + ("WebSocketDisconnect", httpx.WebSocketDisconnect), + ], +) +def test_top_level_names_are_lazily_exported(name: str, expected: object) -> None: + assert getattr(httpx, name) is expected + + +def test_top_level_websocket_uses_a_dedicated_client() -> None: + mock_context = contextlib.ExitStack() + with patch("httpx2._websockets.api._connect_ws", return_value=mock_context) as mock_connect_ws: + with httpx.websocket("http://socket/ws"): + pass + mock_connect_ws.assert_called_once() + client = mock_connect_ws.call_args[1]["client"] + assert isinstance(client, httpx.Client) + assert client.is_closed + + +@pytest.mark.anyio +async def test_client_websocket(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.send_text("SERVER_MESSAGE") + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + assert ws.receive_text() == "SERVER_MESSAGE" + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + assert await aws.receive_text() == "SERVER_MESSAGE" + + +@pytest.mark.anyio +async def test_async_receive_reassembles_fragmented_message() -> None: + server = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER) + fragments = server.send(wsproto.events.TextMessage("FRAG", message_finished=False)) + fragments += server.send(wsproto.events.TextMessage("MEN", message_finished=False)) + fragments += server.send(wsproto.events.TextMessage("TED", message_finished=True)) + + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self._sent = False + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + if self._sent: + await anyio.sleep_forever() + self._sent = True + return fragments + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: ... + + async def aclose(self) -> None: ... + + async with AsyncWebSocketSession(AsyncMockNetworkStream()) as ws: + assert await ws.receive_text() == "FRAGMENTED" diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py index 259789f6..e14365f3 100644 --- a/tests/httpx2/websockets/test_transport.py +++ b/tests/httpx2/websockets/test_transport.py @@ -131,7 +131,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: with pytest.raises(WebSocketDisconnect): async with ASGIWebSocketAsyncNetworkStream(app, scope): - pass + pass # pragma: no cover async def test_exception(self, scope: Scope) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: @@ -139,7 +139,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: with pytest.raises(WebSocketDisconnect) as excinfo: async with ASGIWebSocketAsyncNetworkStream(app, scope): - pass + pass # pragma: no cover assert excinfo.value.code == 1011 assert excinfo.value.reason == "Error" @@ -152,7 +152,7 @@ async def http_endpoint(request: Request) -> PlainTextResponse: async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.receive_text() - await websocket.close() + await websocket.close() # pragma: no cover routes = [ Route("/http", endpoint=http_endpoint), @@ -218,7 +218,7 @@ async def test_keepalive_ping_disabled() -> None: async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.receive_text() - await websocket.close() + await websocket.close() # pragma: no cover app = Starlette( routes=[ From 6a48d21f7554056c28b956d5cf93803568f80318 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 30 Jun 2026 14:44:01 +0200 Subject: [PATCH 103/108] Make the WebSocket suite robust under full-suite ordering on 3.10 Resolve `httpx2._websockets` via `importlib` so `mock.patch` targets keep working after `tests/httpx2/test_api.py` drops `httpx2` from `sys.modules`. Track session threads by identity instead of asserting absolute counts, so `test_threads_wont_hang` no longer flakes on threads owned by other tests. --- src/httpx2/httpx2/__init__.py | 5 +++++ tests/httpx2/websockets/test_api.py | 20 +++++++++++++------- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/httpx2/httpx2/__init__.py b/src/httpx2/httpx2/__init__.py index b87e3b92..7abf8097 100644 --- a/src/httpx2/httpx2/__init__.py +++ b/src/httpx2/httpx2/__init__.py @@ -150,4 +150,9 @@ def __getattr__(name: str) -> object: except ImportError: # pragma: no cover raise ImportError(WS_EXTRA_INSTALL_MESSAGE) + if name == "_websockets": + import importlib + + return importlib.import_module(f"{__name__}._websockets") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 36c63588..a8595f92 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -797,16 +797,22 @@ async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.receive_text() await websocket.close() + def session_threads() -> set[threading.Thread]: + return set(threading.enumerate()) - threads_before + + def wait_for_session_threads(expected: int) -> None: + for _ in range(100): + if len(session_threads()) == expected: + return + time.sleep(0.05) + assert len(session_threads()) == expected # pragma: no cover + with server_factory(websocket_endpoint) as socket: with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: - initial_threads_count = threading.active_count() + threads_before = set(threading.enumerate()) with connect_ws("http://socket/ws", client, keepalive_ping_interval_seconds=None) as ws: for _ in range(50): ws.receive() ws.send_text("CLIENT_MESSAGE") - time.sleep(0.1) # Let the websocket endpoint finish its handling. - threads_count = threading.active_count() - assert initial_threads_count + 2 == threads_count - time.sleep(0.1) - final_threads_count = threading.active_count() - assert initial_threads_count == final_threads_count + wait_for_session_threads(2) + wait_for_session_threads(0) From f2a0211487239f59f0f488ae267199a59132d5e1 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 30 Jun 2026 15:00:27 +0200 Subject: [PATCH 104/108] Skip the async keepalive ping once the session starts closing The async keepalive loop slept for the interval and then pinged unconditionally, so a close racing in during the sleep left it sending a Ping on a connection already in `LOCAL_CLOSING`, raising `LocalProtocolError`. Re-check `_should_close` after the sleep, mirroring the synchronous keepalive loop. --- src/httpx2/httpx2/_websockets/api.py | 2 ++ tests/httpx2/websockets/test_api.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/httpx2/httpx2/_websockets/api.py b/src/httpx2/httpx2/_websockets/api.py index 84f8f91f..e0e48cf0 100644 --- a/src/httpx2/httpx2/_websockets/api.py +++ b/src/httpx2/httpx2/_websockets/api.py @@ -1013,6 +1013,8 @@ async def _background_receive(self, max_bytes: int) -> None: async def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: float | None = None) -> None: while not self._should_close.is_set(): await anyio.sleep(interval_seconds) + if self._should_close.is_set(): + return pong_callback = await self.ping() if timeout_seconds is not None: try: diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index a8595f92..a6ae8918 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -628,6 +628,27 @@ async def aclose(self) -> None: assert stream.ping_received >= 1 assert stream.ping_answered >= 1 + async def test_async_keepalive_ping_skips_when_closing(self) -> None: + writes = 0 + + class MockAsyncNetworkStream(AsyncNetworkStream): + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + return b"" # pragma: no cover + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + nonlocal writes + writes += 1 # pragma: no cover + + async def aclose(self) -> None: ... # pragma: no cover + + session = AsyncWebSocketSession(MockAsyncNetworkStream(), keepalive_ping_interval_seconds=0.2) + async with anyio.create_task_group() as tg: + tg.start_soon(session._background_keepalive_ping, 0.2) + await anyio.sleep(0.05) # Let the loop enter its sleep before closing. + session._should_close.set() + + assert writes == 0 + async def test_async_keepalive_ping_timeout(self) -> None: class MockAsyncNetworkStream(AsyncNetworkStream): def __init__(self) -> None: From 9b612eb3e394a4a3a533086adfa947f563ac0fbd Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 30 Jun 2026 15:16:00 +0200 Subject: [PATCH 105/108] Exclude blocking mock reads from coverage to keep it stable under load These mock `read()` loops run in the session's background thread and only exit once the test closes the connection, so whether the loop body and the trailing `raise` execute before teardown depends on scheduling. Under full suite load that made coverage dip below 100% on some Python versions. --- tests/httpx2/websockets/test_api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index a6ae8918..94143e84 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -55,9 +55,9 @@ def __init__(self) -> None: self._should_close = False def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - while not self._should_close: + while not self._should_close: # pragma: no cover time.sleep(0.1) - raise httpcore.ReadError() + raise httpcore.ReadError() # pragma: no cover def write(self, buffer: bytes, timeout: float | None = None) -> None: raise httpcore.WriteError() @@ -563,9 +563,9 @@ def __init__(self) -> None: self._should_close = False def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - while not self._should_close: + while not self._should_close: # pragma: no cover time.sleep(0.1) - raise httpcore.ReadError() + raise httpcore.ReadError() # pragma: no cover def write(self, buffer: bytes, timeout: float | None = None) -> None: pass From fa810a12163b736bf072d40937bc11d3b9306769 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 30 Jun 2026 15:32:49 +0200 Subject: [PATCH 106/108] Measure background-thread coverage with coverage's thread concurrency The WebSocket session runs its receive and keepalive loops in background threads, whose lines coverage only records with `concurrency = ["thread"]`. Without it, coverage dipped below 100% on whichever Python version happened to tear a thread down early under load. Enabling it lets the existing tests cover those lines deterministically, so the ad-hoc `no cover` pragmas on the blocking mock reads are no longer needed. --- pyproject.toml | 1 + tests/httpx2/websockets/test_api.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f248d8c7..663231ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ markers = [ ] [tool.coverage.run] +concurrency = ["thread"] source_pkgs = ["httpx2", "httpcore2", "tests"] omit = ["src/httpcore2/httpcore2/_sync/*", "tests/test_benchmark.py"] diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index 94143e84..a6ae8918 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -55,9 +55,9 @@ def __init__(self) -> None: self._should_close = False def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - while not self._should_close: # pragma: no cover + while not self._should_close: time.sleep(0.1) - raise httpcore.ReadError() # pragma: no cover + raise httpcore.ReadError() def write(self, buffer: bytes, timeout: float | None = None) -> None: raise httpcore.WriteError() @@ -563,9 +563,9 @@ def __init__(self) -> None: self._should_close = False def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - while not self._should_close: # pragma: no cover + while not self._should_close: time.sleep(0.1) - raise httpcore.ReadError() # pragma: no cover + raise httpcore.ReadError() def write(self, buffer: bytes, timeout: float | None = None) -> None: pass From b77a190dc59ff8e8df0020ec39633b5c604fc092 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 30 Jun 2026 15:59:54 +0200 Subject: [PATCH 107/108] Exclude race-dependent background-thread lines from coverage The sync keepalive close-guard and the blocking mock reads only run when a background thread happens to reach them before the session tears down, so their coverage swings with scheduling and fails the gate on whichever Python version runs slowest under load. Mark them `no cover` and drop the thread concurrency setting, which only masked the flakiness via incidental timing. --- pyproject.toml | 1 - src/httpx2/httpx2/_websockets/api.py | 2 +- tests/httpx2/websockets/test_api.py | 8 ++++---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 663231ac..f248d8c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,6 @@ markers = [ ] [tool.coverage.run] -concurrency = ["thread"] source_pkgs = ["httpx2", "httpcore2", "tests"] omit = ["src/httpcore2/httpcore2/_sync/*", "tests/test_benchmark.py"] diff --git a/src/httpx2/httpx2/_websockets/api.py b/src/httpx2/httpx2/_websockets/api.py index e0e48cf0..b81d7e77 100644 --- a/src/httpx2/httpx2/_websockets/api.py +++ b/src/httpx2/httpx2/_websockets/api.py @@ -517,7 +517,7 @@ def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: f try: while not self._should_close.is_set(): should_close = self._wait_until_closed(self._should_close.wait, interval_seconds) - if should_close: + if should_close: # pragma: no cover raise ShouldClose() pong_callback = self.ping() if timeout_seconds is not None: diff --git a/tests/httpx2/websockets/test_api.py b/tests/httpx2/websockets/test_api.py index a6ae8918..94143e84 100644 --- a/tests/httpx2/websockets/test_api.py +++ b/tests/httpx2/websockets/test_api.py @@ -55,9 +55,9 @@ def __init__(self) -> None: self._should_close = False def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - while not self._should_close: + while not self._should_close: # pragma: no cover time.sleep(0.1) - raise httpcore.ReadError() + raise httpcore.ReadError() # pragma: no cover def write(self, buffer: bytes, timeout: float | None = None) -> None: raise httpcore.WriteError() @@ -563,9 +563,9 @@ def __init__(self) -> None: self._should_close = False def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - while not self._should_close: + while not self._should_close: # pragma: no cover time.sleep(0.1) - raise httpcore.ReadError() + raise httpcore.ReadError() # pragma: no cover def write(self, buffer: bytes, timeout: float | None = None) -> None: pass From 79289f306c24771b7c26d473e5d4b6bf6334887b Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 1 Jul 2026 09:14:59 +0200 Subject: [PATCH 108/108] Replace WebSocket **kwargs with explicit handshake request parameters `Client.websocket()`, `AsyncClient.websocket()` and the vendored `connect_ws`/`aconnect_ws` forwarded arbitrary `**kwargs` to the handshake `stream()` call. Spell out the request parameters instead - `params`, `headers`, `cookies`, `auth`, `follow_redirects`, `timeout`, `extensions` - so the surface is typed and matches `sse()`, and merge the mandatory upgrade headers over any caller-supplied ones with the `Headers` union operator. --- src/httpx2/httpx2/_client.py | 32 ++++- src/httpx2/httpx2/_websockets/api.py | 156 +++++++++++++++------ tests/httpx2/websockets/test_high_level.py | 17 +++ 3 files changed, 156 insertions(+), 49 deletions(-) diff --git a/src/httpx2/httpx2/_client.py b/src/httpx2/httpx2/_client.py index 8f203b52..0b6f4b24 100644 --- a/src/httpx2/httpx2/_client.py +++ b/src/httpx2/httpx2/_client.py @@ -907,7 +907,13 @@ def websocket( keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, subprotocols: list[str] | None = None, - **kwargs: typing.Any, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, ) -> Generator[WebSocketSession]: """ Open a WebSocket session, using this client's configuration. @@ -934,7 +940,13 @@ def websocket( keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, subprotocols=subprotocols, - **kwargs, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, ) as session: yield session @@ -1693,7 +1705,13 @@ async def websocket( keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, subprotocols: list[str] | None = None, - **kwargs: typing.Any, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, ) -> AsyncGenerator[AsyncWebSocketSession]: """ Open a WebSocket session, using this client's configuration. @@ -1720,7 +1738,13 @@ async def websocket( keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, subprotocols=subprotocols, - **kwargs, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, ) as session: yield session diff --git a/src/httpx2/httpx2/_websockets/api.py b/src/httpx2/httpx2/_websockets/api.py index b81d7e77..004a7482 100644 --- a/src/httpx2/httpx2/_websockets/api.py +++ b/src/httpx2/httpx2/_websockets/api.py @@ -15,6 +15,8 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from wsproto.frame_protocol import CloseReason +from .._client import USE_CLIENT_DEFAULT +from .._models import Headers from .defaults import ( DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, @@ -34,8 +36,16 @@ if typing.TYPE_CHECKING: from httpcore2 import AsyncNetworkStream, NetworkStream - from .._client import AsyncClient, Client + from .._client import AsyncClient, Client, UseClientDefault from .._models import Response + from .._types import ( + AuthTypes, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestExtensions, + TimeoutTypes, + ) JSONMode = typing.Literal["text", "binary"] TaskFunction = typing.TypeVar("TaskFunction") @@ -1049,12 +1059,25 @@ def _connect_ws( keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, subprotocols: list[str] | None = None, - **kwargs: typing.Any, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, ) -> typing.Generator[WebSocketSession, None, None]: - headers = kwargs.pop("headers", {}) - headers.update(_get_headers(subprotocols)) - - with client.stream("GET", url, headers=headers, **kwargs) as response: + with client.stream( + "GET", + url, + params=params, + headers=Headers(headers) | _get_headers(subprotocols), + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as response: if response.status_code != 101: raise WebSocketUpgradeError(response) @@ -1079,7 +1102,13 @@ def connect_ws( keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, subprotocols: list[str] | None = None, - **kwargs: typing.Any, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, ) -> typing.Generator[WebSocketSession, None, None]: """ Start a sync WebSocket session. @@ -1113,9 +1142,20 @@ def connect_ws( Defaults to 20 seconds. subprotocols: Optional list of suprotocols to negotiate with the server. - **kwargs: - Additional keyword arguments that will be passed to - the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. + params: + Query parameters to include in the handshake request. + headers: + Headers to include in the handshake request. + cookies: + Cookies to include in the handshake request. + auth: + Authentication to use for the handshake request. + follow_redirects: + Whether to follow redirects on the handshake request. + timeout: + Timeout configuration for the handshake request. + extensions: + Request extensions for the handshake request. Returns: A [context manager][contextlib.AbstractContextManager] @@ -1140,19 +1180,11 @@ def connect_ws( if client is None: from .._client import Client - with Client() as client: - with _connect_ws( - url, - client=client, - max_message_size_bytes=max_message_size_bytes, - queue_size=queue_size, - keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, - keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - subprotocols=subprotocols, - **kwargs, - ) as websocket: - yield websocket + owned_client: contextlib.AbstractContextManager[Client] = Client() else: + owned_client = contextlib.nullcontext(client) + + with owned_client as client: with _connect_ws( url, client=client, @@ -1161,7 +1193,13 @@ def connect_ws( keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, subprotocols=subprotocols, - **kwargs, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, ) as websocket: yield websocket @@ -1176,12 +1214,25 @@ async def _aconnect_ws( keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, subprotocols: list[str] | None = None, - **kwargs: typing.Any, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: - headers = kwargs.pop("headers", {}) - headers.update(_get_headers(subprotocols)) - - async with client.stream("GET", url, headers=headers, **kwargs) as response: + async with client.stream( + "GET", + url, + params=params, + headers=Headers(headers) | _get_headers(subprotocols), + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as response: if response.status_code != 101: raise WebSocketUpgradeError(response) @@ -1206,7 +1257,13 @@ async def aconnect_ws( keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, subprotocols: list[str] | None = None, - **kwargs: typing.Any, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, ) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: """ Start an async WebSocket session. @@ -1240,9 +1297,20 @@ async def aconnect_ws( Defaults to 20 seconds. subprotocols: Optional list of suprotocols to negotiate with the server. - **kwargs: - Additional keyword arguments that will be passed to - the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. + params: + Query parameters to include in the handshake request. + headers: + Headers to include in the handshake request. + cookies: + Cookies to include in the handshake request. + auth: + Authentication to use for the handshake request. + follow_redirects: + Whether to follow redirects on the handshake request. + timeout: + Timeout configuration for the handshake request. + extensions: + Request extensions for the handshake request. Returns: An [async context manager][contextlib.AbstractAsyncContextManager] @@ -1267,19 +1335,11 @@ async def aconnect_ws( if client is None: from .._client import AsyncClient - async with AsyncClient() as client: - async with _aconnect_ws( - url, - client=client, - max_message_size_bytes=max_message_size_bytes, - queue_size=queue_size, - keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, - keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - subprotocols=subprotocols, - **kwargs, - ) as websocket: - yield websocket + owned_client: contextlib.AbstractAsyncContextManager[AsyncClient] = AsyncClient() else: + owned_client = contextlib.nullcontext(client) + + async with owned_client as client: async with _aconnect_ws( url, client=client, @@ -1288,6 +1348,12 @@ async def aconnect_ws( keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, subprotocols=subprotocols, - **kwargs, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, ) as websocket: yield websocket diff --git a/tests/httpx2/websockets/test_high_level.py b/tests/httpx2/websockets/test_high_level.py index 7423963d..ece658bf 100644 --- a/tests/httpx2/websockets/test_high_level.py +++ b/tests/httpx2/websockets/test_high_level.py @@ -56,6 +56,23 @@ async def websocket_endpoint(websocket: WebSocket) -> None: assert await aws.receive_text() == "SERVER_MESSAGE" +@pytest.mark.anyio +async def test_client_websocket_forwards_request_params(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.send_text(websocket.headers.get("x-token", "")) + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx.Client(transport=httpx.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws", headers={"x-token": "secret"}) as ws: + assert ws.receive_text() == "secret" + + async with httpx.AsyncClient(transport=httpx.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws", headers={"x-token": "secret"}) as aws: + assert await aws.receive_text() == "secret" + + @pytest.mark.anyio async def test_async_receive_reassembles_fragmented_message() -> None: server = wsproto.connection.Connection(wsproto.connection.ConnectionType.SERVER)