Skip to content

Commit 797946e

Browse files
Skn0ttCopilot
andauthored
fix: support async predicates in page.expect_request/expect_response (#3055)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 1f847dd commit 797946e

10 files changed

Lines changed: 137 additions & 14 deletions

File tree

CLAUDE.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ This is the recurring high-stakes task. Use the dedicated skill:
4848

4949
It documents the full process: the upstream commit-range diff over `docs/src/api/`, how to classify each commit (PORT / MISMATCH / N/A), how to handle the `langs:` filter, the recurring failure modes, and the tests/sync-mirroring conventions.
5050

51+
## Working on PRs
52+
53+
- Never post comments, replies, or reviews on GitHub PRs/issues under my account without my explicit approval. Draft the proposed text and wait for me to approve before sending.
54+
5155
## House style
5256

5357
- Don't hand-edit generated files.

playwright/_impl/_helper.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import (
2323
TYPE_CHECKING,
2424
Any,
25+
Awaitable,
2526
Callable,
2627
Dict,
2728
List,
@@ -54,8 +55,12 @@
5455
from playwright._impl._network import Request, Response, Route, WebSocketRoute
5556

5657
URLMatch = Union[str, Pattern[str], Callable[[str], bool]]
57-
URLMatchRequest = Union[str, Pattern[str], Callable[["Request"], bool]]
58-
URLMatchResponse = Union[str, Pattern[str], Callable[["Response"], bool]]
58+
URLMatchRequest = Union[
59+
str, Pattern[str], Callable[["Request"], Union[bool, Awaitable[bool]]]
60+
]
61+
URLMatchResponse = Union[
62+
str, Pattern[str], Callable[["Response"], Union[bool, Awaitable[bool]]]
63+
]
5964
RouteHandlerCallback = Union[
6065
Callable[["Route"], Any], Callable[["Route", "Request"], Any]
6166
]

playwright/_impl/_page.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import (
2323
TYPE_CHECKING,
2424
Any,
25+
Awaitable,
2526
Callable,
2627
Dict,
2728
List,
@@ -1278,7 +1279,7 @@ def expect_request(
12781279
urlOrPredicate: URLMatchRequest,
12791280
timeout: float = None,
12801281
) -> EventContextManagerImpl[Request]:
1281-
def my_predicate(request: Request) -> bool:
1282+
def my_predicate(request: Request) -> Union[bool, Awaitable[bool]]:
12821283
if not callable(urlOrPredicate):
12831284
return url_matches(
12841285
self._browser_context._base_url,
@@ -1310,7 +1311,7 @@ def expect_response(
13101311
urlOrPredicate: URLMatchResponse,
13111312
timeout: float = None,
13121313
) -> EventContextManagerImpl[Response]:
1313-
def my_predicate(request: Response) -> bool:
1314+
def my_predicate(request: Response) -> Union[bool, Awaitable[bool]]:
13141315
if not callable(urlOrPredicate):
13151316
return url_matches(
13161317
self._browser_context._base_url,

playwright/_impl/_waiter.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import inspect
1617
import math
1718
import uuid
1819
from asyncio.tasks import Task
19-
from typing import Any, Callable, List, Tuple, Union
20+
from typing import Any, Callable, List, Optional, Tuple, Union
2021

2122
from pyee import EventEmitter
2223

@@ -71,9 +72,11 @@ def reject_on_event(
7172
error: Union[Error, Callable[..., Error]],
7273
predicate: Callable = None,
7374
) -> None:
75+
def on_match() -> None:
76+
self._reject(error() if callable(error) else error)
77+
7478
def listener(event_data: Any = None) -> None:
75-
if not predicate or predicate(event_data):
76-
self._reject(error() if callable(error) else error)
79+
self._evaluate_predicate(predicate, event_data, on_match)
7780

7881
emitter.on(event, listener)
7982
self._registered_listeners.append((emitter, event, listener))
@@ -117,12 +120,43 @@ def wait_for_event(
117120
predicate: Callable = None,
118121
) -> None:
119122
def listener(event_data: Any = None) -> None:
120-
if not predicate or predicate(event_data):
121-
self._fulfill(event_data)
123+
self._evaluate_predicate(
124+
predicate, event_data, lambda: self._fulfill(event_data)
125+
)
122126

123127
emitter.on(event, listener)
124128
self._registered_listeners.append((emitter, event, listener))
125129

130+
def _evaluate_predicate(
131+
self,
132+
predicate: Optional[Callable],
133+
event_data: Any,
134+
on_match: Callable[[], None],
135+
) -> None:
136+
if predicate is None:
137+
on_match()
138+
return
139+
try:
140+
result = predicate(event_data)
141+
except Exception as e:
142+
self._reject(e)
143+
return
144+
if inspect.iscoroutine(result):
145+
146+
async def _await_predicate(coro: Any) -> None:
147+
try:
148+
matched = await coro
149+
except Exception as e:
150+
self._reject(e)
151+
return
152+
if matched and not self._result.done():
153+
on_match()
154+
155+
self._pending_tasks.append(self._loop.create_task(_await_predicate(result)))
156+
return
157+
if result:
158+
on_match()
159+
126160
def result(self) -> asyncio.Future:
127161
return self._result
128162

playwright/async_api/_generated.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12306,7 +12306,9 @@ def expect_popup(
1230612306
def expect_request(
1230712307
self,
1230812308
url_or_predicate: typing.Union[
12309-
str, typing.Pattern[str], typing.Callable[["Request"], bool]
12309+
str,
12310+
typing.Pattern[str],
12311+
typing.Callable[["Request"], typing.Union[bool, typing.Awaitable[bool]]],
1231012312
],
1231112313
*,
1231212314
timeout: typing.Optional[float] = None,
@@ -12331,7 +12333,7 @@ def expect_request(
1233112333

1233212334
Parameters
1233312335
----------
12334-
url_or_predicate : Union[Callable[[Request], bool], Pattern[str], str]
12336+
url_or_predicate : Union[Callable[[Request], Union[bool, typing.Awaitable[bool]]], Pattern[str], str]
1233512337
Request URL string, regex or predicate receiving `Request` object. When a `baseURL` via the context options was
1233612338
provided and the passed URL is a path, it gets merged via the
1233712339
[`new URL()`](https://developer.mozilla.org/en-US/docs/Web/API/URL/URL) constructor.
@@ -12384,7 +12386,9 @@ def expect_request_finished(
1238412386
def expect_response(
1238512387
self,
1238612388
url_or_predicate: typing.Union[
12387-
str, typing.Pattern[str], typing.Callable[["Response"], bool]
12389+
str,
12390+
typing.Pattern[str],
12391+
typing.Callable[["Response"], typing.Union[bool, typing.Awaitable[bool]]],
1238812392
],
1238912393
*,
1239012394
timeout: typing.Optional[float] = None,
@@ -12411,7 +12415,7 @@ def expect_response(
1241112415

1241212416
Parameters
1241312417
----------
12414-
url_or_predicate : Union[Callable[[Response], bool], Pattern[str], str]
12418+
url_or_predicate : Union[Callable[[Response], Union[bool, typing.Awaitable[bool]]], Pattern[str], str]
1241512419
Request URL string, regex or predicate receiving `Response` object. When a `baseURL` via the context options was
1241612420
provided and the passed URL is a path, it gets merged via the
1241712421
[`new URL()`](https://developer.mozilla.org/en-US/docs/Web/API/URL/URL) constructor.

scripts/documentation_provider.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,16 @@ def serialize_python_type(self, value: Any, direction: str) -> str:
408408
return f"{{{', '.join(signature)}}}"
409409
if origin == Union:
410410
args = get_args(value)
411+
if not self.is_async:
412+
# Sync API doesn't accept awaitable callbacks; drop the
413+
# Awaitable arm so docstring types match the sync signature.
414+
args = tuple(
415+
a
416+
for a in args
417+
if str(get_origin(a)) != "<class 'collections.abc.Awaitable'>"
418+
)
419+
if len(args) == 1:
420+
return self.serialize_python_type(args[0], direction)
411421
if len(args) == 2 and str(args[1]) == "<class 'NoneType'>":
412422
return self.make_optional(
413423
self.serialize_python_type(args[0], direction)

scripts/expected_api_mismatch.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ Parameter type mismatch in BrowserContext.route_web_socket(handler=): documented
2020
Parameter type mismatch in Page.route_web_socket(handler=): documented as Callable[[WebSocketRoute], Union[Any, Any]], code has Callable[[WebSocketRoute], Any]
2121
Parameter type mismatch in WebSocketRoute.on_close(handler=): documented as Callable[[Union[int, undefined]], Union[Any, Any]], code has Callable[[Union[int, None], Union[str, None]], Any]
2222
Parameter type mismatch in WebSocketRoute.on_message(handler=): documented as Callable[[str], Union[Any, Any]], code has Callable[[Union[bytes, str]], Any]
23+
24+
# Async API additionally accepts an `async def` predicate.
25+
Parameter type mismatch in Page.expect_request(url_or_predicate=): documented as Union[Callable[[Request], bool], Pattern[str], str], code has Union[Callable[[Request], Union[bool, typing.Awaitable[bool]]], Pattern[str], str]
26+
Parameter type mismatch in Page.expect_response(url_or_predicate=): documented as Union[Callable[[Response], bool], Pattern[str], str], code has Union[Callable[[Response], Union[bool, typing.Awaitable[bool]]], Pattern[str], str]

scripts/generate_api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
from playwright._impl._video import Video
5757
from playwright._impl._web_error import WebError
5858

59+
SYNC_API = False
60+
5961

6062
def process_type(value: Any, param: bool = False) -> str:
6163
value = str(value)
@@ -65,6 +67,15 @@ def process_type(value: Any, param: bool = False) -> str:
6567
value = re.sub(r"playwright\._impl\._api_structures.([\w]+)", r"\1", value)
6668
value = re.sub(r"playwright\._impl\.[\w]+\.([\w]+)", r'"\1"', value)
6769
value = re.sub(r"typing.Literal", "Literal", value)
70+
if SYNC_API:
71+
# Sync API does not accept awaitable callbacks; collapse
72+
# Union[X, Awaitable[X]] (used for predicates the async API also
73+
# accepts as `async def`) down to just X.
74+
value = re.sub(
75+
r"typing\.Union\[([^\[\],]+),\s*typing\.Awaitable\[\1\]\]",
76+
r"\1",
77+
value,
78+
)
6879
if param:
6980
value = re.sub(r"^typing.Union\[([^,]+), None\]$", r"\1 = None", value)
7081
value = re.sub(

scripts/generate_sync_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from types import FunctionType
2020
from typing import Any
2121

22+
import generate_api
2223
from documentation_provider import DocumentationProvider
2324
from generate_api import (
2425
api_globals,
@@ -33,6 +34,8 @@
3334
signature,
3435
)
3536

37+
generate_api.SYNC_API = True
38+
3639
documentation_provider = DocumentationProvider(False)
3740

3841

tests/async/test_page.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
import re
1818
from pathlib import Path
19-
from typing import Dict, List, Optional
19+
from typing import Any, Dict, List, Optional
2020

2121
import pytest
2222

@@ -352,6 +352,53 @@ async def test_wait_for_response_should_work_with_predicate(
352352
assert response.url == server.PREFIX + "/digits/2.png"
353353

354354

355+
async def test_wait_for_response_should_work_with_async_predicate(
356+
page: Page, server: Server
357+
) -> None:
358+
await page.goto(server.EMPTY_PAGE)
359+
360+
async def predicate(response: Any) -> bool:
361+
await asyncio.sleep(0)
362+
return response.url == server.PREFIX + "/digits/2.png"
363+
364+
async with page.expect_response(predicate) as response_info:
365+
await page.evaluate(
366+
"""() => {
367+
fetch('/digits/1.png')
368+
fetch('/digits/2.png')
369+
fetch('/digits/3.png')
370+
}"""
371+
)
372+
response = await response_info.value
373+
assert response.url == server.PREFIX + "/digits/2.png"
374+
375+
376+
async def test_expect_response_should_reject_when_async_predicate_throws(
377+
page: Page, server: Server
378+
) -> None:
379+
await page.goto(server.EMPTY_PAGE)
380+
381+
async def predicate(response: Any) -> bool:
382+
raise Exception("Async oops!")
383+
384+
with pytest.raises(Exception, match="Async oops!"):
385+
async with page.expect_response(predicate):
386+
await page.evaluate("() => fetch('/digits/1.png')")
387+
388+
389+
async def test_expect_response_should_reject_when_sync_predicate_throws(
390+
page: Page, server: Server
391+
) -> None:
392+
await page.goto(server.EMPTY_PAGE)
393+
394+
def predicate(response: Any) -> bool:
395+
raise Exception("Sync oops!")
396+
397+
with pytest.raises(Exception, match="Sync oops!"):
398+
async with page.expect_response(predicate):
399+
await page.evaluate("() => fetch('/digits/1.png')")
400+
401+
355402
async def test_wait_for_response_should_work_with_no_timeout(
356403
page: Page, server: Server
357404
) -> None:

0 commit comments

Comments
 (0)