diff --git a/src/blacki/adk_runtime.py b/src/blacki/adk_runtime.py index 3ada299..11abf39 100644 --- a/src/blacki/adk_runtime.py +++ b/src/blacki/adk_runtime.py @@ -83,6 +83,11 @@ def build_session_db_kwargs(env: ServerEnv) -> dict[str, Any]: Note: Pool settings are only relevant for PostgreSQL. SQLite uses a single connection and ignores pool settings. """ + session_uri = build_session_service_uri(env) + is_sqlite = session_uri is None or session_uri.startswith("sqlite") + + if is_sqlite: + return {"connect_args": {"timeout": 15}} return {} @@ -102,7 +107,21 @@ def create_session_service( ) if session_service_uri.startswith(("postgresql+asyncpg://", "sqlite+aiosqlite://")): - return DatabaseSessionService(session_service_uri, **session_db_kwargs) + service = DatabaseSessionService(session_service_uri, **session_db_kwargs) + + if session_service_uri.startswith("sqlite"): + from sqlalchemy import event + + @event.listens_for(service.db_engine.sync_engine, "connect") + def set_sqlite_pragma( + dbapi_connection: Any, connection_record: Any + ) -> None: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA journal_mode=WAL") + cursor.execute("PRAGMA synchronous=NORMAL") + cursor.close() + + return service msg = ( "Shared ADK runtime does not support the configured session URI: " diff --git a/tests/test_adk_runtime.py b/tests/test_adk_runtime.py index 8708485..2c3db0c 100644 --- a/tests/test_adk_runtime.py +++ b/tests/test_adk_runtime.py @@ -3,7 +3,7 @@ from collections.abc import AsyncIterator from pathlib import Path -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from google.adk.events import Event @@ -70,10 +70,17 @@ def test_build_session_service_uri_converts_postgresql_to_asyncpg() -> None: ) -def test_build_session_db_kwargs_returns_empty_dict() -> None: - """Test that session DB kwargs returns empty dict for SQLite.""" +def test_build_session_db_kwargs_returns_timeout_for_sqlite() -> None: + """Test that session DB kwargs returns timeout config for SQLite.""" env = _build_server_env() + assert build_session_db_kwargs(env) == {"connect_args": {"timeout": 15}} + + +def test_build_session_db_kwargs_returns_empty_for_postgres() -> None: + """Test that session DB kwargs returns empty dict for PostgreSQL.""" + env = _build_server_env(agent_engine="postgresql://localhost:5432/db") + assert build_session_db_kwargs(env) == {} @@ -95,6 +102,44 @@ def test_create_session_service_with_sqlite_uri(tmp_path: Path) -> None: assert isinstance(session_service, DatabaseSessionService) +def test_create_session_service_sqlite_event_listener(tmp_path: Path) -> None: + """Test that SQLite session services execute PRAGMAs on connect.""" + db_path = tmp_path / "sessions.db" + + with patch("sqlalchemy.event.listens_for") as mock_listens_for: + create_session_service( + f"sqlite+aiosqlite:///{db_path}", + {}, + ) + + # Get the decorator returned by event.listens_for and the function it wrapped + decorator = mock_listens_for.return_value + set_sqlite_pragma = decorator.call_args[0][0] + + # Call the wrapped function with a mock connection + mock_conn = MagicMock() + set_sqlite_pragma(mock_conn, None) + + # Verify it executed the pragmas + cursor = mock_conn.cursor.return_value + assert cursor.execute.call_count == 2 + cursor.execute.assert_any_call("PRAGMA journal_mode=WAL") + cursor.execute.assert_any_call("PRAGMA synchronous=NORMAL") + cursor.close.assert_called_once() + + +@patch("blacki.adk_runtime.DatabaseSessionService") +def test_create_session_service_with_postgres_uri(mock_db_service: AsyncMock) -> None: + """Test that PostgreSQL session services skip SQLite PRAGMAs.""" + session_service = create_session_service( + "postgresql+asyncpg://user:pass@localhost/db", + {}, + ) + + # Since we mocked it, it returns the mock object + assert session_service == mock_db_service.return_value + + def test_create_session_service_rejects_unsupported_uri() -> None: """Test that unsupported session URIs fail fast.""" with pytest.raises(ValueError, match="does not support"):