Skip to content

Commit 2f2e03b

Browse files
DeanChensjcopybara-github
authored andcommitted
fix: Fix event loop closed and thread leak errors in unit tests
- Wrap run_live generator in aclosing/Aclosing in InMemoryRunner and test runners to ensure they are closed when tests exit early. - Refactor test_streaming.py to use a single, robust TestRunner helper, eliminating 10 duplicated CustomTestRunner definitions. - Implement close() and async context manager support directly in PerAgentDatabaseSessionService in local_storage.py to clean up SQLAlchemy engines and connection threads without changing BaseSessionService. - Wrap PerAgentDatabaseSessionService in async with blocks (or try-finally) in test_local_storage.py to ensure they are closed. Co-authored-by: Shangjie Chen <deanchen@google.com> PiperOrigin-RevId: 936950444
1 parent dbd4bb0 commit 2f2e03b

5 files changed

Lines changed: 143 additions & 539 deletions

File tree

src/google/adk/cli/utils/local_storage.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,21 @@ async def append_event(self, session: Session, event: Event) -> Event:
236236
service = await self._get_service(session.app_name)
237237
return await service.append_event(session, event)
238238

239+
async def close(self) -> None:
240+
"""Closes all underlying session services."""
241+
for service in self._services.values():
242+
if hasattr(service, "close"):
243+
await service.close()
244+
self._services.clear()
245+
246+
async def __aenter__(self) -> PerAgentDatabaseSessionService:
247+
"""Enters the async context manager."""
248+
return self
249+
250+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
251+
"""Exits the async context manager and closes the service."""
252+
await self.close()
253+
239254

240255
class PerAgentFileArtifactService(BaseArtifactService):
241256
"""Routes artifact storage to per-agent `.adk/artifacts` folders."""

tests/unittests/cli/utils/test_local_storage.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,20 @@ async def test_per_agent_session_service_creates_scoped_dot_adk(
3838
agent_a.mkdir()
3939
agent_b.mkdir()
4040

41-
service = PerAgentDatabaseSessionService(agents_root=tmp_path)
41+
async with PerAgentDatabaseSessionService(agents_root=tmp_path) as service:
42+
await service.create_session(app_name="agent_a", user_id="user_a")
43+
await service.create_session(app_name="agent_b", user_id="user_b")
4244

43-
await service.create_session(app_name="agent_a", user_id="user_a")
44-
await service.create_session(app_name="agent_b", user_id="user_b")
45+
assert (agent_a / ".adk" / "session.db").exists()
46+
assert (agent_b / ".adk" / "session.db").exists()
4547

46-
assert (agent_a / ".adk" / "session.db").exists()
47-
assert (agent_b / ".adk" / "session.db").exists()
48+
agent_a_sessions = await service.list_sessions(app_name="agent_a")
49+
agent_b_sessions = await service.list_sessions(app_name="agent_b")
4850

49-
agent_a_sessions = await service.list_sessions(app_name="agent_a")
50-
agent_b_sessions = await service.list_sessions(app_name="agent_b")
51-
52-
assert len(agent_a_sessions.sessions) == 1
53-
assert agent_a_sessions.sessions[0].app_name == "agent_a"
54-
assert len(agent_b_sessions.sessions) == 1
55-
assert agent_b_sessions.sessions[0].app_name == "agent_b"
51+
assert len(agent_a_sessions.sessions) == 1
52+
assert agent_a_sessions.sessions[0].app_name == "agent_a"
53+
assert len(agent_b_sessions.sessions) == 1
54+
assert agent_b_sessions.sessions[0].app_name == "agent_b"
5655

5756

5857
@pytest.mark.asyncio
@@ -68,26 +67,28 @@ async def test_per_agent_session_service_respects_app_name_alias(
6867
per_agent=True,
6968
app_name_to_dir={logical_name: folder_name},
7069
)
70+
try:
71+
session = await service.create_session(
72+
app_name=logical_name,
73+
user_id="user",
74+
)
7175

72-
session = await service.create_session(
73-
app_name=logical_name,
74-
user_id="user",
75-
)
76-
77-
assert session.app_name == logical_name
78-
assert (tmp_path / folder_name / ".adk" / "session.db").exists()
76+
assert session.app_name == logical_name
77+
assert (tmp_path / folder_name / ".adk" / "session.db").exists()
78+
finally:
79+
if isinstance(service, PerAgentDatabaseSessionService):
80+
await service.close()
7981

8082

8183
@pytest.mark.asyncio
8284
async def test_per_agent_session_service_routes_built_in_agents_to_root_dot_adk(
8385
tmp_path: Path,
8486
) -> None:
85-
service = PerAgentDatabaseSessionService(agents_root=tmp_path)
87+
async with PerAgentDatabaseSessionService(agents_root=tmp_path) as service:
88+
await service.create_session(app_name="__helper", user_id="user")
8689

87-
await service.create_session(app_name="__helper", user_id="user")
88-
89-
assert not (tmp_path / "__helper").exists()
90-
assert (tmp_path / ".adk" / "session.db").exists()
90+
assert not (tmp_path / "__helper").exists()
91+
assert (tmp_path / ".adk" / "session.db").exists()
9192

9293

9394
def test_create_local_database_session_service_returns_sqlite(
@@ -106,22 +107,25 @@ async def test_per_agent_session_service_get_user_state(tmp_path: Path) -> None:
106107
agent_a.mkdir()
107108
agent_b.mkdir()
108109

109-
service = PerAgentDatabaseSessionService(agents_root=tmp_path)
110-
111-
session_a = await service.create_session(app_name="agent_a", user_id="user_a")
112-
await service.append_event(
113-
session_a,
114-
Event(
115-
author="system",
116-
actions=EventActions(state_delta={"user:profile": {"name": "Alice"}}),
117-
),
118-
)
119-
120-
state_a = await service.get_user_state(app_name="agent_a", user_id="user_a")
121-
state_b = await service.get_user_state(app_name="agent_b", user_id="user_b")
122-
123-
assert state_a == {"profile": {"name": "Alice"}}
124-
assert not state_b
110+
async with PerAgentDatabaseSessionService(agents_root=tmp_path) as service:
111+
session_a = await service.create_session(
112+
app_name="agent_a", user_id="user_a"
113+
)
114+
await service.append_event(
115+
session_a,
116+
Event(
117+
author="system",
118+
actions=EventActions(
119+
state_delta={"user:profile": {"name": "Alice"}}
120+
),
121+
),
122+
)
123+
124+
state_a = await service.get_user_state(app_name="agent_a", user_id="user_a")
125+
state_b = await service.get_user_state(app_name="agent_b", user_id="user_b")
126+
127+
assert state_a == {"profile": {"name": "Alice"}}
128+
assert not state_b
125129

126130

127131
@pytest.mark.asyncio

tests/unittests/streaming/test_multi_agent_streaming.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,13 @@ async def consume_responses(session: testing_utils.Session):
9494
live_request_queue=live_request_queue,
9595
run_config=run_config or testing_utils.RunConfig(),
9696
)
97-
async for response in run_res:
98-
collected_responses.append(response)
99-
if len(collected_responses) >= 5:
100-
return
97+
from contextlib import aclosing
98+
99+
async with aclosing(run_res) as agen:
100+
async for response in agen:
101+
collected_responses.append(response)
102+
if len(collected_responses) >= 5:
103+
return
101104

102105
try:
103106
session = self.session

0 commit comments

Comments
 (0)