|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import asyncio |
| 16 | +import inspect |
16 | 17 | import math |
17 | 18 | import uuid |
18 | 19 | from asyncio.tasks import Task |
19 | | -from typing import Any, Callable, List, Tuple, Union |
| 20 | +from typing import Any, Callable, List, Optional, Tuple, Union |
20 | 21 |
|
21 | 22 | from pyee import EventEmitter |
22 | 23 |
|
@@ -71,9 +72,11 @@ def reject_on_event( |
71 | 72 | error: Union[Error, Callable[..., Error]], |
72 | 73 | predicate: Callable = None, |
73 | 74 | ) -> None: |
| 75 | + def on_match() -> None: |
| 76 | + self._reject(error() if callable(error) else error) |
| 77 | + |
74 | 78 | def listener(event_data: Any = None) -> None: |
75 | | - if not predicate or predicate(event_data): |
76 | | - self._reject(error() if callable(error) else error) |
| 79 | + self._evaluate_predicate(predicate, event_data, on_match) |
77 | 80 |
|
78 | 81 | emitter.on(event, listener) |
79 | 82 | self._registered_listeners.append((emitter, event, listener)) |
@@ -117,12 +120,43 @@ def wait_for_event( |
117 | 120 | predicate: Callable = None, |
118 | 121 | ) -> None: |
119 | 122 | def listener(event_data: Any = None) -> None: |
120 | | - if not predicate or predicate(event_data): |
121 | | - self._fulfill(event_data) |
| 123 | + self._evaluate_predicate( |
| 124 | + predicate, event_data, lambda: self._fulfill(event_data) |
| 125 | + ) |
122 | 126 |
|
123 | 127 | emitter.on(event, listener) |
124 | 128 | self._registered_listeners.append((emitter, event, listener)) |
125 | 129 |
|
| 130 | + def _evaluate_predicate( |
| 131 | + self, |
| 132 | + predicate: Optional[Callable], |
| 133 | + event_data: Any, |
| 134 | + on_match: Callable[[], None], |
| 135 | + ) -> None: |
| 136 | + if predicate is None: |
| 137 | + on_match() |
| 138 | + return |
| 139 | + try: |
| 140 | + result = predicate(event_data) |
| 141 | + except Exception as e: |
| 142 | + self._reject(e) |
| 143 | + return |
| 144 | + if inspect.iscoroutine(result): |
| 145 | + |
| 146 | + async def _await_predicate(coro: Any) -> None: |
| 147 | + try: |
| 148 | + matched = await coro |
| 149 | + except Exception as e: |
| 150 | + self._reject(e) |
| 151 | + return |
| 152 | + if matched and not self._result.done(): |
| 153 | + on_match() |
| 154 | + |
| 155 | + self._pending_tasks.append(self._loop.create_task(_await_predicate(result))) |
| 156 | + return |
| 157 | + if result: |
| 158 | + on_match() |
| 159 | + |
126 | 160 | def result(self) -> asyncio.Future: |
127 | 161 | return self._result |
128 | 162 |
|
|
0 commit comments