Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/blacki/adk_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def build_session_db_kwargs(env: ServerEnv) -> dict[str, Any]:
Note: Pool settings are only relevant for PostgreSQL. SQLite uses
a single connection and ignores pool settings.
"""
session_uri = build_session_service_uri(env)
is_sqlite = session_uri is None or session_uri.startswith("sqlite")

if is_sqlite:
return {"connect_args": {"timeout": 15}}
return {}


Expand All @@ -102,7 +107,21 @@ def create_session_service(
)

if session_service_uri.startswith(("postgresql+asyncpg://", "sqlite+aiosqlite://")):
return DatabaseSessionService(session_service_uri, **session_db_kwargs)
service = DatabaseSessionService(session_service_uri, **session_db_kwargs)

if session_service_uri.startswith("sqlite"):
from sqlalchemy import event

@event.listens_for(service.db_engine.sync_engine, "connect")
def set_sqlite_pragma(
dbapi_connection: Any, connection_record: Any
) -> None:
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.close()
Comment on lines +119 to +122
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To ensure proper resource cleanup even if an exception occurs during the execution of the PRAGMA statements, wrap the cursor operations in a try...finally block. Additionally, consider enabling PRAGMA foreign_keys=ON to enforce referential integrity constraints, keeping it consistent with how SQLite is configured in src/blacki/storage/sqlite.py.

Suggested change
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.close()
cursor = dbapi_connection.cursor()
try:
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA foreign_keys=ON")
finally:
cursor.close()


return service

msg = (
"Shared ADK runtime does not support the configured session URI: "
Expand Down
51 changes: 48 additions & 3 deletions tests/test_adk_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from google.adk.events import Event
Expand Down Expand Up @@ -70,10 +70,17 @@ def test_build_session_service_uri_converts_postgresql_to_asyncpg() -> None:
)


def test_build_session_db_kwargs_returns_empty_dict() -> None:
"""Test that session DB kwargs returns empty dict for SQLite."""
def test_build_session_db_kwargs_returns_timeout_for_sqlite() -> None:
"""Test that session DB kwargs returns timeout config for SQLite."""
env = _build_server_env()

assert build_session_db_kwargs(env) == {"connect_args": {"timeout": 15}}


def test_build_session_db_kwargs_returns_empty_for_postgres() -> None:
"""Test that session DB kwargs returns empty dict for PostgreSQL."""
env = _build_server_env(agent_engine="postgresql://localhost:5432/db")

assert build_session_db_kwargs(env) == {}


Expand All @@ -95,6 +102,44 @@ def test_create_session_service_with_sqlite_uri(tmp_path: Path) -> None:
assert isinstance(session_service, DatabaseSessionService)


def test_create_session_service_sqlite_event_listener(tmp_path: Path) -> None:
"""Test that SQLite session services execute PRAGMAs on connect."""
db_path = tmp_path / "sessions.db"

with patch("sqlalchemy.event.listens_for") as mock_listens_for:
create_session_service(
f"sqlite+aiosqlite:///{db_path}",
{},
)

# Get the decorator returned by event.listens_for and the function it wrapped
decorator = mock_listens_for.return_value
set_sqlite_pragma = decorator.call_args[0][0]

# Call the wrapped function with a mock connection
mock_conn = MagicMock()
set_sqlite_pragma(mock_conn, None)

# Verify it executed the pragmas
cursor = mock_conn.cursor.return_value
assert cursor.execute.call_count == 2
cursor.execute.assert_any_call("PRAGMA journal_mode=WAL")
cursor.execute.assert_any_call("PRAGMA synchronous=NORMAL")
cursor.close.assert_called_once()


@patch("blacki.adk_runtime.DatabaseSessionService")
def test_create_session_service_with_postgres_uri(mock_db_service: AsyncMock) -> None:
"""Test that PostgreSQL session services skip SQLite PRAGMAs."""
session_service = create_session_service(
"postgresql+asyncpg://user:pass@localhost/db",
{},
)

# Since we mocked it, it returns the mock object
assert session_service == mock_db_service.return_value


def test_create_session_service_rejects_unsupported_uri() -> None:
"""Test that unsupported session URIs fail fast."""
with pytest.raises(ValueError, match="does not support"):
Expand Down
Loading