Skip to content

Commit 930bfa6

Browse files
DeanChensjcopybara-github
authored andcommitted
test: Fix resource leaks and brittle synchronization in streaming tests
- Wrap streaming runners in aclosing to prevent background task leaks and ensure clean test execution. - Implement a test-only _wait_for_queue_empty helper inside test_streaming.py to eliminate brittle asyncio.sleep(0.1) calls, resulting in robust and faster test execution. - Enhance custom test event loops to cancel and gather all pending tasks during loop teardown, preventing PytestUnraisableExceptionWarnings. Co-authored-by: Shangjie Chen <deanchen@google.com> PiperOrigin-RevId: 938235975
1 parent 844cd36 commit 930bfa6

1 file changed

Lines changed: 61 additions & 22 deletions

File tree

tests/unittests/streaming/test_streaming.py

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@
2727
from .. import testing_utils
2828

2929

30+
async def _wait_for_queue_empty(queue: LiveRequestQueue):
31+
"""Wait until the queue is empty and the background consumer has finished."""
32+
while not queue._queue.empty():
33+
await asyncio.sleep(0)
34+
# Give opportunity for _send_to_model to finish processing (e.g. append_event)
35+
for _ in range(10):
36+
await asyncio.sleep(0)
37+
38+
3039
class StreamingTestRunner(testing_utils.InMemoryRunner):
3140
"""A robust runner for streaming tests that avoids resource leaks."""
3241

@@ -47,6 +56,17 @@ def _run_with_loop(self, coro):
4756
except (asyncio.TimeoutError, asyncio.CancelledError):
4857
pass
4958
finally:
59+
# Cancel all pending tasks to prevent leaks and warnings
60+
pending = asyncio.all_tasks(loop)
61+
for task in pending:
62+
task.cancel()
63+
if pending:
64+
try:
65+
loop.run_until_complete(
66+
asyncio.gather(*pending, return_exceptions=True)
67+
)
68+
except Exception: # pylint: disable=broad-except
69+
pass
5070
loop.close()
5171
asyncio.set_event_loop(old_loop)
5272

@@ -68,7 +88,7 @@ async def consume_responses(session: testing_utils.Session):
6888
async for response in agen:
6989
collected_responses.append(response)
7090
if len(collected_responses) >= self.max_responses:
71-
await asyncio.sleep(0.1)
91+
await _wait_for_queue_empty(live_request_queue)
7292
return
7393

7494
self._run_with_loop(
@@ -887,6 +907,17 @@ def _run_with_loop(self, coro: Awaitable[Any]) -> None:
887907
except (asyncio.TimeoutError, asyncio.CancelledError):
888908
pass
889909
finally:
910+
# Cancel all pending tasks to prevent leaks and warnings
911+
pending = asyncio.all_tasks(loop)
912+
for task in pending:
913+
task.cancel()
914+
if pending:
915+
try:
916+
loop.run_until_complete(
917+
asyncio.gather(*pending, return_exceptions=True)
918+
)
919+
except Exception: # pylint: disable=broad-except
920+
pass
890921
loop.close()
891922
asyncio.set_event_loop(old_loop)
892923

@@ -899,13 +930,16 @@ def run_live(
899930
collected = []
900931

901932
async def consume(session: testing_utils.Session):
902-
async for response in self.runner.run_live(
933+
run_res = self.runner.run_live(
903934
session=session,
904935
live_request_queue=live_request_queue,
905-
):
906-
collected.append(response)
907-
if len(collected) >= max_responses:
908-
return
936+
)
937+
async with aclosing(run_res) as agen:
938+
async for response in agen:
939+
collected.append(response)
940+
if len(collected) >= max_responses:
941+
await _wait_for_queue_empty(live_request_queue)
942+
return
909943

910944
self._run_with_loop(asyncio.wait_for(consume(self.session), timeout=5.0))
911945
return collected
@@ -977,25 +1011,30 @@ def capturing_method(*args, **kwargs) -> Any:
9771011

9781012
async def consume(session: testing_utils.Session):
9791013
nonlocal not_registered_before_call
980-
async for response in runner.runner.run_live(
1014+
run_res = runner.runner.run_live(
9811015
session=session,
9821016
live_request_queue=live_request_queue,
983-
):
984-
collected.append(response)
985-
# On the first non-function-call event, verify the tool is not
986-
# yet registered (lazy registration).
987-
active = (
988-
captured_context.active_streaming_tools if captured_context else None
989-
)
990-
if (
991-
not_registered_before_call is None
992-
and not response.get_function_calls()
993-
):
994-
not_registered_before_call = (
995-
active is None or "monitor_video_stream" not in active
1017+
)
1018+
async with aclosing(run_res) as agen:
1019+
async for response in agen:
1020+
collected.append(response)
1021+
# On the first non-function-call event, verify the tool is not
1022+
# yet registered (lazy registration).
1023+
active = (
1024+
captured_context.active_streaming_tools
1025+
if captured_context
1026+
else None
9961027
)
997-
if len(collected) >= 4:
998-
return
1028+
if (
1029+
not_registered_before_call is None
1030+
and not response.get_function_calls()
1031+
):
1032+
not_registered_before_call = (
1033+
active is None or "monitor_video_stream" not in active
1034+
)
1035+
if len(collected) >= 4:
1036+
await _wait_for_queue_empty(live_request_queue)
1037+
return
9991038

10001039
runner._run_with_loop(asyncio.wait_for(consume(runner.session), timeout=5.0))
10011040

0 commit comments

Comments
 (0)