diff --git a/.env.example b/.env.example index c7b0ea6..f7e6f60 100644 --- a/.env.example +++ b/.env.example @@ -23,13 +23,9 @@ GOOGLE_API_KEY=your-google-ai-studio-key-here # BROWSER_USE_API_KEY=bu_your_key_here # Database Configuration -# asyncpg requires ssl=require instead of sslmode=require -DATABASE_URL=postgresql://user:pass@host:port/dbname?ssl=require -DB_POOL_PRE_PING=true -DB_POOL_RECYCLE=1800 -DB_POOL_SIZE=5 -DB_MAX_OVERFLOW=10 -DB_POOL_TIMEOUT=30 +# SQLite database for reminders, calories, workouts, and preferences +# Default: {AGENT_DIR}/.adk/tools.db (inside container: /app/src/.adk/tools.db) +# SQLITE_PATH=/data/tools.db # Server Configuration LOG_LEVEL=INFO diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index 0521691..5629115 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -98,7 +98,6 @@ jobs: TELEGRAM_ENABLED: ${{ secrets.TELEGRAM_ENABLED }} TELEGRAM_BOT_TOKEN: ${{ secrets.TELEGRAM_BOT_TOKEN }} TELEGRAM_TOOL_NOTIFICATIONS: ${{ secrets.TELEGRAM_TOOL_NOTIFICATIONS }} - DATABASE_URL: ${{ secrets.DATABASE_URL }} BRAVE_SEARCH_API_KEY: ${{ secrets.BRAVE_SEARCH_API_KEY }} BROWSER_USE_API_KEY: ${{ secrets.BROWSER_USE_API_KEY }} SANDBOX_ENABLED: ${{ secrets.SANDBOX_ENABLED }} @@ -131,7 +130,6 @@ jobs: TELEGRAM_ENABLED="${TELEGRAM_ENABLED}" TELEGRAM_BOT_TOKEN="${TELEGRAM_BOT_TOKEN}" TELEGRAM_TOOL_NOTIFICATIONS="${TELEGRAM_TOOL_NOTIFICATIONS}" - DATABASE_URL="${DATABASE_URL}" BRAVE_SEARCH_API_KEY="${BRAVE_SEARCH_API_KEY}" BROWSER_USE_API_KEY="${BROWSER_USE_API_KEY}" SANDBOX_ENABLED="${SANDBOX_ENABLED}" @@ -210,7 +208,6 @@ jobs: TELEGRAM_ENABLED="${TELEGRAM_ENABLED}" TELEGRAM_BOT_TOKEN="${TELEGRAM_BOT_TOKEN}" TELEGRAM_TOOL_NOTIFICATIONS="${TELEGRAM_TOOL_NOTIFICATIONS}" - DATABASE_URL="${DATABASE_URL}" BRAVE_SEARCH_API_KEY="${BRAVE_SEARCH_API_KEY}" BROWSER_USE_API_KEY="${BROWSER_USE_API_KEY}" HOST_PORT="${HOST_PORT}" diff --git a/Dockerfile b/Dockerfile index a5beba8..8c3a700 100644 --- a/Dockerfile +++ b/Dockerfile @@ -36,12 +36,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # ============================================================================ FROM python:3.13-slim AS runtime -# Install system dependencies -# - netcat-openbsd: for checking DB readiness (used in entrypoint.sh) -RUN apt-get update && apt-get install -y --no-install-recommends \ - netcat-openbsd \ - && rm -rf /var/lib/apt/lists/* - # Create non-root user for security (matching common host UID 1000) RUN groupadd -g 1000 app && \ useradd -u 1000 -g app -s /bin/sh -m app diff --git a/entrypoint.sh b/entrypoint.sh index ccfb4a7..8959ebc 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,25 +1,4 @@ #!/bin/sh set -e -# Wait for the database if DATABASE_URL is set and looks like a postgres url -if echo "$DATABASE_URL" | grep -q "postgresql://"; then - echo "Waiting for database..." - # Extract host and port from DATABASE_URL - # Assumes format postgresql://user:pass@host:port/dbname - # This is a basic extraction and might need adjustment for complex URLs - DB_HOST=$(echo $DATABASE_URL | sed -e 's|^.*@||' -e 's|/.*$||' -e 's|:.*$||') - DB_PORT=$(echo $DATABASE_URL | sed -e 's|^.*@||' -e 's|/.*$||' -e 's|^.*:||') - - # Default port if not specified - if [ "$DB_HOST" = "$DB_PORT" ]; then - DB_PORT=5432 - fi - - # Loop until the database is ready - while ! nc -z $DB_HOST $DB_PORT; do - sleep 1 - done - echo "Database started" -fi - exec "$@" diff --git a/pyproject.toml b/pyproject.toml index e6dfcf5..57112d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "pydantic>=2.11.0,<3.0.0", "python-dotenv>=1.0.0,<2.0.0", "litellm>=1.60.0", - "asyncpg>=0.30.0", + "aiosqlite>=0.20.0", "greenlet>=3.0.0", "openinference-instrumentation-google-adk>=0.1.8", "httpx>=0.27.0,<1.0.0", diff --git a/scripts/speech_client.py b/scripts/speech_client.py index 061bfdb..42c942a 100644 --- a/scripts/speech_client.py +++ b/scripts/speech_client.py @@ -4,6 +4,7 @@ import os import re import sys +from pathlib import Path from typing import Any import httpx @@ -259,11 +260,11 @@ async def main() -> None: # Initialize Blacki ADK Runtime env = initialize_environment(ServerEnv) - # Initialize global container so tools that depend on Postgres can function + # Initialize global container so tools that depend on SQLite can function container = None - if env.database_url: - container = await init_container(env.database_url) - await container.initialize_all_storages() + sqlite_path = env.sqlite_path or str(Path(env.agent_dir) / ".adk" / "tools.db") + container = await init_container(sqlite_path) + await container.initialize_all_storages() runtime = create_adk_runtime(env) diff --git a/src/blacki/adk_runtime.py b/src/blacki/adk_runtime.py index 4cc75b5..3ada299 100644 --- a/src/blacki/adk_runtime.py +++ b/src/blacki/adk_runtime.py @@ -78,14 +78,12 @@ def build_session_service_uri(env: ServerEnv) -> str | None: def build_session_db_kwargs(env: ServerEnv) -> dict[str, Any]: - """Build shared SQLAlchemy kwargs for database-backed ADK sessions.""" - return { - "pool_pre_ping": env.db_pool_pre_ping, - "pool_recycle": env.db_pool_recycle, - "pool_size": env.db_pool_size, - "max_overflow": env.db_max_overflow, - "pool_timeout": env.db_pool_timeout, - } + """Build shared SQLAlchemy kwargs for database-backed ADK sessions. + + Note: Pool settings are only relevant for PostgreSQL. SQLite uses + a single connection and ignores pool settings. + """ + return {} def create_session_service( diff --git a/src/blacki/calories/__init__.py b/src/blacki/calories/__init__.py index f43c099..fea43fd 100644 --- a/src/blacki/calories/__init__.py +++ b/src/blacki/calories/__init__.py @@ -1,7 +1,3 @@ -from .storage import ( - close_calorie_storage, - init_calorie_storage, -) from .tools import ( delete_meal, edit_meal, @@ -11,8 +7,6 @@ ) __all__ = [ - "close_calorie_storage", - "init_calorie_storage", "delete_meal", "edit_meal", "get_calorie_summary", diff --git a/src/blacki/calories/storage.py b/src/blacki/calories/storage.py index 6142129..ddca075 100644 --- a/src/blacki/calories/storage.py +++ b/src/blacki/calories/storage.py @@ -1,14 +1,18 @@ +"""Persistent storage for calorie tracking backed by SQLite.""" + +from __future__ import annotations + import logging -from collections.abc import Mapping from typing import TYPE_CHECKING, Any -import asyncpg # type: ignore[import-untyped] from pydantic import BaseModel -from blacki.storage.base import PostgresStorage +from blacki.storage.base import SqlStorage if TYPE_CHECKING: - pass + import asyncio + + import aiosqlite _ALLOWED_UPDATE_COLUMNS = frozenset( { @@ -36,96 +40,83 @@ class CalorieEntry(BaseModel): protein_g: float | None = None carbs_g: float | None = None fat_g: float | None = None - meal_type: str | None = None # breakfast/lunch/dinner/snack - logged_at: str # UTC ISO - logged_date: str # YYYY-MM-DD local + meal_type: str | None = None + logged_at: str + logged_date: str class DailySummary(BaseModel): """Summary of calorie intake for a specific date.""" - date: str # YYYY-MM-DD + date: str total_calories: int = 0 total_protein_g: float | None = None total_carbs_g: float | None = None total_fat_g: float | None = None entry_count: int = 0 - entries: list[CalorieEntry] = [] # populated only in single-day queries + entries: list[CalorieEntry] = [] -class PostgresCalorieStorage(PostgresStorage): - """Storage for calorie tracking using Postgres via asyncpg.""" +class SqliteCalorieStorage(SqlStorage): + """Storage for calorie tracking using SQLite via aiosqlite.""" - def __init__(self, pool: asyncpg.Pool) -> None: - super().__init__(pool) + def __init__(self, conn: aiosqlite.Connection, lock: asyncio.Lock) -> None: + super().__init__(conn, lock) - async def _create_tables(self, conn: asyncpg.Connection) -> None: - await conn.execute(""" + async def _create_tables(self) -> None: + await self._conn.execute(""" CREATE TABLE IF NOT EXISTS calorie_logs ( - id BIGSERIAL PRIMARY KEY, - user_id TEXT NOT NULL, - description TEXT NOT NULL, - calories INTEGER NOT NULL, - protein_g REAL, - carbs_g REAL, - fat_g REAL, - meal_type TEXT, - logged_at TIMESTAMPTZ NOT NULL, - logged_date DATE NOT NULL + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + description TEXT NOT NULL, + calories INTEGER NOT NULL, + protein_g REAL, + carbs_g REAL, + fat_g REAL, + meal_type TEXT, + logged_at TEXT NOT NULL, + logged_date TEXT NOT NULL ) """) - column_type = await conn.fetchval(""" - SELECT data_type - FROM information_schema.columns - WHERE table_name = 'calorie_logs' AND column_name = 'protein_g' - """) - if column_type == "integer": - await conn.execute(""" - ALTER TABLE calorie_logs - ALTER COLUMN protein_g TYPE REAL USING protein_g::REAL, - ALTER COLUMN carbs_g TYPE REAL USING carbs_g::REAL, - ALTER COLUMN fat_g TYPE REAL USING fat_g::REAL; - """) - await conn.execute(""" + await self._conn.execute(""" CREATE INDEX IF NOT EXISTS idx_calorie_logs_user_date ON calorie_logs (user_id, logged_date) """) async def add_entry(self, entry: CalorieEntry) -> int: """Insert a calorie entry and return its new row ID.""" - rid = await self._pool.fetchval( + rid = await self._execute( """ INSERT INTO calorie_logs ( user_id, description, calories, protein_g, carbs_g, fat_g, meal_type, logged_at, logged_date ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - RETURNING id - + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, - entry.user_id, - entry.description, - entry.calories, - entry.protein_g, - entry.carbs_g, - entry.fat_g, - entry.meal_type, - entry.logged_at, - entry.logged_date, + ( + entry.user_id, + entry.description, + entry.calories, + entry.protein_g, + entry.carbs_g, + entry.fat_g, + entry.meal_type, + entry.logged_at, + entry.logged_date, + ), ) - return int(rid) + return rid async def get_daily_summary(self, user_id: str, date_str: str) -> DailySummary: """Get summary and up to 50 entries for a specific day.""" - rows = await self._pool.fetch( + rows = await self._fetch_all( """ SELECT * FROM calorie_logs - WHERE user_id = $1 AND logged_date = $2 + WHERE user_id = ? AND logged_date = ? ORDER BY logged_at ASC """, - user_id, - date_str, + (user_id, date_str), ) entries = [self._row_to_entry(r) for r in rows] @@ -152,11 +143,11 @@ async def get_daily_summary(self, user_id: str, date_str: str) -> DailySummary: has_fat = True summary.total_fat_g += e.fat_g - if not has_protein: # pragma: no cover + if not has_protein: summary.total_protein_g = None - if not has_carbs: # pragma: no cover + if not has_carbs: summary.total_carbs_g = None - if not has_fat: # pragma: no cover + if not has_fat: summary.total_fat_g = None return summary @@ -165,7 +156,7 @@ async def get_date_range_summary( self, user_id: str, start_date: str, end_date: str ) -> list[DailySummary]: """Get summaries for a date range, capped at 30 days (no individual entries).""" - rows = await self._pool.fetch( + rows = await self._fetch_all( """ SELECT logged_date, @@ -175,14 +166,12 @@ async def get_date_range_summary( SUM(carbs_g) as total_carbs_g, SUM(fat_g) as total_fat_g FROM calorie_logs - WHERE user_id = $1 AND logged_date >= $2 AND logged_date <= $3 + WHERE user_id = ? AND logged_date >= ? AND logged_date <= ? GROUP BY logged_date ORDER BY logged_date DESC LIMIT 30 """, - user_id, - start_date, - end_date, + (user_id, start_date, end_date), ) summaries = [] @@ -210,36 +199,38 @@ async def get_date_range_summary( async def update_entry(self, entry_id: int, user_id: str, **fields: Any) -> bool: """Update a specific calorie entry.""" - if not fields: # pragma: no cover + if not fields: return False set_clauses = [] - values: list[Any] = [entry_id, user_id] + values: list[Any] = [] - for i, (key, value) in enumerate(fields.items(), start=3): + for key, value in fields.items(): if key not in _ALLOWED_UPDATE_COLUMNS: raise ValueError( f"Column '{key}' is not allowed in calorie_logs UPDATE" ) - set_clauses.append(f"{key} = ${i}") + set_clauses.append(f"{key} = ?") values.append(value) + values.extend([entry_id, user_id]) updates_str = ", ".join(set_clauses) - query = f"UPDATE calorie_logs SET {updates_str} WHERE id = $1 AND user_id = $2" # noqa: S608 + query = f"UPDATE calorie_logs SET {updates_str} WHERE id = ? AND user_id = ?" # noqa: S608 - result = await self._pool.execute(query, *values) - return bool(result == "UPDATE 1") + async with self._lock: + cursor = await self._conn.execute(query, values) + return cursor.rowcount > 0 async def delete_entry(self, entry_id: int, user_id: str) -> bool: """Delete a calorie entry.""" - result = await self._pool.execute( - "DELETE FROM calorie_logs WHERE id = $1 AND user_id = $2", - entry_id, - user_id, - ) - return bool(result == "DELETE 1") + async with self._lock: + cursor = await self._conn.execute( + "DELETE FROM calorie_logs WHERE id = ? AND user_id = ?", + (entry_id, user_id), + ) + return cursor.rowcount > 0 - def _row_to_entry(self, row: Mapping[str, Any]) -> CalorieEntry: + def _row_to_entry(self, row: dict[str, Any]) -> CalorieEntry: return CalorieEntry( id=int(row["id"]), user_id=row["user_id"], @@ -249,20 +240,16 @@ def _row_to_entry(self, row: Mapping[str, Any]) -> CalorieEntry: carbs_g=float(row["carbs_g"]) if row["carbs_g"] is not None else None, fat_g=float(row["fat_g"]) if row["fat_g"] is not None else None, meal_type=row["meal_type"], - logged_at=( - row["logged_at"].isoformat() - if hasattr(row["logged_at"], "isoformat") - else str(row["logged_at"]) - ), + logged_at=row["logged_at"], logged_date=str(row["logged_date"]), ) -_storage: PostgresCalorieStorage | None = None +_storage: SqliteCalorieStorage | None = None -def get_storage() -> PostgresCalorieStorage: - """Return the process-wide singleton PostgresCalorieStorage instance. +def get_storage() -> SqliteCalorieStorage: + """Return the process-wide singleton SqliteCalorieStorage instance. Uses the AppContainer for dependency injection. """ @@ -272,51 +259,6 @@ def get_storage() -> PostgresCalorieStorage: storage = container.calorie_storage if not storage.is_initialized: raise RuntimeError( - "Calorie storage not initialized. Call init_calorie_storage() first." + "Calorie storage not initialized. Call storage.initialize() first." ) return storage - - -async def init_calorie_storage(pool: asyncpg.Pool) -> PostgresCalorieStorage: - """Initialize the calorie storage with a Postgres pool. - - Note: This function is provided for backward compatibility. - Prefer using AppContainer directly for new code. - """ - global _storage - import blacki.container as container_module - - if container_module._container is None: # pragma: no cover - container_module.set_container_from_pool(pool) - - if _storage is not None: - await _storage.close() - _storage = None - - container = container_module._container - if container is None: # pragma: no cover - raise RuntimeError("Container not initialized") - if container._calorie_storage is not None: # pragma: no cover - await container._calorie_storage.close() - - storage = container.calorie_storage - await storage.initialize() - _storage = storage - return storage - - -async def close_calorie_storage() -> None: - """Close the singleton calorie storage. - - Note: This function is provided for backward compatibility. - Prefer using AppContainer.close() for new code. - """ - global _storage - import blacki.container as container_module - - if container_module._container is not None: # pragma: no cover - container = container_module._container - if container._calorie_storage is not None: - await container._calorie_storage.close() - container._calorie_storage = None - _storage = None diff --git a/src/blacki/container.py b/src/blacki/container.py index 28c6d21..2c72070 100644 --- a/src/blacki/container.py +++ b/src/blacki/container.py @@ -4,7 +4,7 @@ replacing the global singleton pattern with explicit dependency injection. Usage: - container = await AppContainer.create(database_url) + container = await AppContainer.create(sqlite_path) await container.initialize_all_storages() set_container(container) try: @@ -17,17 +17,19 @@ from __future__ import annotations +import asyncio import logging from dataclasses import dataclass, field +from pathlib import Path from typing import TYPE_CHECKING, Self if TYPE_CHECKING: - import asyncpg # type: ignore[import-untyped] + import aiosqlite - from blacki.calories.storage import PostgresCalorieStorage - from blacki.reminders.storage import PostgresReminderStorage - from blacki.utils.preferences import PostgresPreferencesStorage - from blacki.workouts.storage import PostgresWorkoutStorage + from blacki.calories.storage import SqliteCalorieStorage + from blacki.reminders.storage import SqliteReminderStorage + from blacki.utils.preferences import SqlitePreferencesStorage + from blacki.workouts.storage import SqliteWorkoutStorage logger = logging.getLogger(__name__) @@ -51,17 +53,16 @@ def set_container(container: AppContainer | None) -> None: _container = container -async def init_container(database_url: str, pool_size: int = 5) -> AppContainer: +async def init_container(sqlite_path: str | Path) -> AppContainer: """Create and set the global container. Args: - database_url: Postgres connection string. - pool_size: Maximum number of connections (default: 5). + sqlite_path: Path to the SQLite database file. Returns: - Initialized container with database pool. + Initialized container with database connection. """ - container = await AppContainer.create(database_url, pool_size) + container = await AppContainer.create(sqlite_path) set_container(container) return container @@ -84,19 +85,26 @@ def reset_container_for_tests() -> None: _container = None -def set_container_from_pool(pool: asyncpg.Pool) -> AppContainer: - """Create and set a container from an existing pool. +def set_container_from_connection( + conn: aiosqlite.Connection, + lock: asyncio.Lock | None = None, +) -> AppContainer: + """Create and set a container from an existing connection. - Useful for tests that create their own mock pool. + Useful for tests that create their own mock connection. Args: - pool: An existing asyncpg pool. + conn: An existing aiosqlite connection. + lock: Optional lock for write operations. If None, creates a new one. Returns: - Container instance using the provided pool. + Container instance using the provided connection. """ global _container - _container = AppContainer(pool=pool) + _container = AppContainer( + conn=conn, + _lock=lock or asyncio.Lock(), + ) return _container @@ -104,53 +112,48 @@ def set_container_from_pool(pool: asyncpg.Pool) -> AppContainer: class AppContainer: """Container for managing application-wide resources. - Manages the lifecycle of the database pool and storage singletons. + Manages the lifecycle of the database connection and storage singletons. All storages are lazily instantiated on first access. Attributes: - pool: The asyncpg connection pool. + conn: The aiosqlite connection. + _lock: Shared lock for write operations. """ - pool: asyncpg.Pool - _reminder_storage: PostgresReminderStorage | None = field( + conn: aiosqlite.Connection + _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False) + _reminder_storage: SqliteReminderStorage | None = field( default=None, init=False, repr=False ) - _calorie_storage: PostgresCalorieStorage | None = field( + _calorie_storage: SqliteCalorieStorage | None = field( default=None, init=False, repr=False ) - _workout_storage: PostgresWorkoutStorage | None = field( + _workout_storage: SqliteWorkoutStorage | None = field( default=None, init=False, repr=False ) - _preferences_storage: PostgresPreferencesStorage | None = field( + _preferences_storage: SqlitePreferencesStorage | None = field( default=None, init=False, repr=False ) @classmethod - async def create( - cls, database_url: str, pool_size: int = 5 - ) -> Self: # pragma: no cover - """Create and initialize the container with a database pool. + async def create(cls, sqlite_path: str | Path) -> Self: + """Create and initialize the container with a SQLite database. Args: - database_url: Postgres connection string. - pool_size: Maximum number of connections (default: 5). + sqlite_path: Path to the SQLite database file. Returns: - Initialized container with database pool. + Initialized container with database connection. """ - import asyncpg + from blacki.storage.sqlite import create_connection - pool = await asyncpg.create_pool( - database_url, - min_size=1, - max_size=pool_size, - ) - return cls(pool=pool) + conn = await create_connection(sqlite_path) + return cls(conn=conn) async def close(self) -> None: - """Close all storage instances and the pool.""" + """Close all storage instances and the connection.""" await self._close_storages() - await self.pool.close() + await self.conn.close() logger.info("AppContainer closed") async def _close_storages(self) -> None: @@ -183,37 +186,42 @@ async def initialize_all_storages(self) -> None: await self.preferences_storage.initialize() @property - def reminder_storage(self) -> PostgresReminderStorage: + def lock(self) -> asyncio.Lock: + """Get the shared write lock.""" + return self._lock + + @property + def reminder_storage(self) -> SqliteReminderStorage: """Get or create the reminder storage instance.""" if self._reminder_storage is None: - from blacki.reminders.storage import PostgresReminderStorage + from blacki.reminders.storage import SqliteReminderStorage - self._reminder_storage = PostgresReminderStorage(self.pool) + self._reminder_storage = SqliteReminderStorage(self.conn, self._lock) return self._reminder_storage @property - def calorie_storage(self) -> PostgresCalorieStorage: + def calorie_storage(self) -> SqliteCalorieStorage: """Get or create the calorie storage instance.""" if self._calorie_storage is None: - from blacki.calories.storage import PostgresCalorieStorage + from blacki.calories.storage import SqliteCalorieStorage - self._calorie_storage = PostgresCalorieStorage(self.pool) + self._calorie_storage = SqliteCalorieStorage(self.conn, self._lock) return self._calorie_storage @property - def workout_storage(self) -> PostgresWorkoutStorage: + def workout_storage(self) -> SqliteWorkoutStorage: """Get or create the workout storage instance.""" if self._workout_storage is None: - from blacki.workouts.storage import PostgresWorkoutStorage + from blacki.workouts.storage import SqliteWorkoutStorage - self._workout_storage = PostgresWorkoutStorage(self.pool) + self._workout_storage = SqliteWorkoutStorage(self.conn, self._lock) return self._workout_storage @property - def preferences_storage(self) -> PostgresPreferencesStorage: + def preferences_storage(self) -> SqlitePreferencesStorage: """Get or create the preferences storage instance.""" if self._preferences_storage is None: - from blacki.utils.preferences import PostgresPreferencesStorage + from blacki.utils.preferences import SqlitePreferencesStorage - self._preferences_storage = PostgresPreferencesStorage(self.pool) + self._preferences_storage = SqlitePreferencesStorage(self.conn, self._lock) return self._preferences_storage diff --git a/src/blacki/registry.py b/src/blacki/registry.py index 284e906..1c9a4d0 100644 --- a/src/blacki/registry.py +++ b/src/blacki/registry.py @@ -23,13 +23,13 @@ class ToolConfig: Attributes: brave_search_api_key: API key for Brave Search. - database_url: Postgres connection string for storage-backed tools. + sqlite_path: Path to SQLite database for storage-backed tools. sandbox_enabled: Whether to enable sandbox tools. skills_dir: Directory containing skill definitions. """ brave_search_api_key: str | None = None - database_url: str | None = None + sqlite_path: str | None = None sandbox_enabled: bool = False skills_dir: Path | None = None weather_enabled: bool = True @@ -50,7 +50,7 @@ def build_tools(config: ToolConfig) -> list[Any]: tools.extend(_build_brave_search_tools()) logger.info("Brave Search tool enabled") - if config.database_url: + if config.sqlite_path: tools.extend(_build_reminder_tools()) tools.extend(_build_calorie_tools()) tools.extend(_build_workout_tools()) @@ -235,7 +235,7 @@ def build_tool_config_from_env() -> ToolConfig: return ToolConfig( brave_search_api_key=os.getenv("BRAVE_SEARCH_API_KEY", "").strip() or None, - database_url=os.getenv("DATABASE_URL", "").strip() or None, + sqlite_path=os.getenv("SQLITE_PATH", "").strip() or None, sandbox_enabled=os.getenv("SANDBOX_ENABLED", "false").strip().lower() in ("true", "1", "yes"), skills_dir=skills_dir, diff --git a/src/blacki/reminders/__init__.py b/src/blacki/reminders/__init__.py index 1efde41..c593aea 100644 --- a/src/blacki/reminders/__init__.py +++ b/src/blacki/reminders/__init__.py @@ -5,7 +5,7 @@ """ from .scheduler import ReminderScheduler, get_scheduler -from .storage import Reminder, get_storage, init_reminder_storage +from .storage import Reminder, get_storage from .tools import ( SUPPORTED_RECURRENCE_MESSAGE, cancel_reminder, @@ -20,7 +20,6 @@ "cancel_reminder", "get_scheduler", "get_storage", - "init_reminder_storage", "list_reminders", "schedule_reminder", ] diff --git a/src/blacki/reminders/storage.py b/src/blacki/reminders/storage.py index e9ce243..6deeea1 100644 --- a/src/blacki/reminders/storage.py +++ b/src/blacki/reminders/storage.py @@ -1,22 +1,23 @@ -"""Persistent storage for scheduled reminders backed by Postgres via asyncpg. +"""Persistent storage for scheduled reminders backed by SQLite. -A single asyncpg pool is used for all database operations. The pool is shared -with the ADK session service when DATABASE_URL is configured. +A single aiosqlite connection is used for all database operations. """ +from __future__ import annotations + import abc import logging -from collections.abc import Mapping from typing import TYPE_CHECKING, Any -import asyncpg # type: ignore[import-untyped] from pydantic import BaseModel -from blacki.storage.base import PostgresStorage +from blacki.storage.base import SqlStorage from blacki.utils.timezone import now_utc if TYPE_CHECKING: - pass + import asyncio + + import aiosqlite logger = logging.getLogger(__name__) @@ -95,39 +96,37 @@ async def delete_reminder(self, reminder_id: int, user_id: str) -> bool: """Delete a reminder if it belongs to the given user.""" -class PostgresReminderStorage(PostgresStorage): - """Storage for reminders using Postgres via asyncpg.""" +class SqliteReminderStorage(SqlStorage): + """Storage for reminders using SQLite via aiosqlite.""" - def __init__(self, pool: asyncpg.Pool) -> None: - super().__init__(pool) + def __init__(self, conn: aiosqlite.Connection, lock: asyncio.Lock) -> None: + super().__init__(conn, lock) - async def _create_tables(self, conn: asyncpg.Connection) -> None: - await conn.execute(""" + async def _create_tables(self) -> None: + await self._conn.execute(""" CREATE TABLE IF NOT EXISTS reminders ( - id BIGSERIAL PRIMARY KEY, - user_id TEXT NOT NULL, - message TEXT NOT NULL, - trigger_time TEXT NOT NULL, - is_sent BOOLEAN NOT NULL DEFAULT FALSE, + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + message TEXT NOT NULL, + trigger_time TEXT NOT NULL, + is_sent INTEGER NOT NULL DEFAULT 0, recurrence_rule TEXT, recurrence_text TEXT, timezone_name TEXT, - created_at TEXT NOT NULL + created_at TEXT NOT NULL ) """) - await conn.execute("DROP INDEX IF EXISTS idx_reminders_trigger_time_sent") - await conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_reminders_due_reminders - ON reminders (trigger_time) - WHERE is_sent = FALSE + await self._conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_reminders_due + ON reminders (is_sent, trigger_time) """) - await conn.execute(""" + await self._conn.execute(""" CREATE INDEX IF NOT EXISTS idx_reminders_user_id ON reminders (user_id) """) async def add_reminder(self, reminder: Reminder) -> int: - rid = await self._pool.fetchval( + rid = await self._execute( """ INSERT INTO reminders ( @@ -140,30 +139,31 @@ async def add_reminder(self, reminder: Reminder) -> int: timezone_name, created_at ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - RETURNING id + VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, - reminder.user_id, - reminder.message, - reminder.trigger_time, - reminder.is_sent, - reminder.recurrence_rule, - reminder.recurrence_text, - reminder.timezone_name, - reminder.created_at, + ( + reminder.user_id, + reminder.message, + reminder.trigger_time, + int(reminder.is_sent), + reminder.recurrence_rule, + reminder.recurrence_text, + reminder.timezone_name, + reminder.created_at, + ), ) logger.info( - "Added reminder %s for user %s: '%s...' at %s (Postgres)", + "Added reminder %s for user %s: '%s...' at %s (SQLite)", rid, reminder.user_id, reminder.message[:30], reminder.trigger_time, ) - return int(rid) + return rid async def get_due_reminders(self) -> list[Reminder]: now = now_utc().isoformat(timespec="seconds") - rows = await self._pool.fetch( + rows = await self._fetch_all( """ SELECT id, @@ -176,35 +176,36 @@ async def get_due_reminders(self) -> list[Reminder]: timezone_name, created_at FROM reminders - WHERE trigger_time <= $1 AND is_sent = FALSE + WHERE trigger_time <= ? AND is_sent = 0 ORDER BY trigger_time ASC - LIMIT $2 + LIMIT ? """, - now, - DUE_REMINDERS_FETCH_LIMIT, + (now, DUE_REMINDERS_FETCH_LIMIT), ) return [self._row_to_reminder(r) for r in rows] async def mark_sent(self, reminder_id: int) -> None: - await self._pool.execute( - "UPDATE reminders SET is_sent = TRUE WHERE id = $1", reminder_id - ) - logger.info("Marked reminder %s as sent (Postgres)", reminder_id) + async with self._lock: + await self._conn.execute( + "UPDATE reminders SET is_sent = 1 WHERE id = ?", + (reminder_id,), + ) + logger.info("Marked reminder %s as sent (SQLite)", reminder_id) async def reschedule_reminder( self, reminder_id: int, next_trigger_time: str ) -> None: - await self._pool.execute( - """ - UPDATE reminders - SET trigger_time = $1, is_sent = FALSE - WHERE id = $2 - """, - next_trigger_time, - reminder_id, - ) + async with self._lock: + await self._conn.execute( + """ + UPDATE reminders + SET trigger_time = ?, is_sent = 0 + WHERE id = ? + """, + (next_trigger_time, reminder_id), + ) logger.info( - "Rescheduled recurring reminder %s for %s (Postgres)", + "Rescheduled recurring reminder %s for %s (SQLite)", reminder_id, next_trigger_time, ) @@ -214,36 +215,29 @@ async def get_user_reminders( ) -> list[Reminder]: query = """ SELECT - id, - user_id, - message, - trigger_time, - is_sent, - recurrence_rule, - recurrence_text, - timezone_name, - created_at - FROM reminders WHERE user_id = $1 + id, user_id, message, trigger_time, is_sent, + recurrence_rule, recurrence_text, timezone_name, created_at + FROM reminders WHERE user_id = ? """ + params: list[Any] = [user_id] if not include_sent: - query += " AND is_sent = FALSE" + query += " AND is_sent = 0" query += " ORDER BY trigger_time ASC" - - rows = await self._pool.fetch(query, user_id) + rows = await self._fetch_all(query, tuple(params)) return [self._row_to_reminder(r) for r in rows] async def delete_reminder(self, reminder_id: int, user_id: str) -> bool: - result = await self._pool.execute( - "DELETE FROM reminders WHERE id = $1 AND user_id = $2", - reminder_id, - user_id, - ) - deleted = bool(result == "DELETE 1") + async with self._lock: + cursor = await self._conn.execute( + "DELETE FROM reminders WHERE id = ? AND user_id = ?", + (reminder_id, user_id), + ) + deleted = cursor.rowcount > 0 if deleted: - logger.info("Deleted reminder %s (Postgres)", reminder_id) + logger.info("Deleted reminder %s (SQLite)", reminder_id) return deleted - def _row_to_reminder(self, row: Mapping[str, Any]) -> Reminder: + def _row_to_reminder(self, row: dict[str, Any]) -> Reminder: return Reminder( id=int(row["id"]), user_id=row["user_id"], @@ -257,10 +251,10 @@ def _row_to_reminder(self, row: Mapping[str, Any]) -> Reminder: ) -_storage: PostgresReminderStorage | None = None +_storage: SqliteReminderStorage | None = None -def get_storage() -> PostgresReminderStorage: +def get_storage() -> SqliteReminderStorage: """Return the process-wide singleton ReminderStorage instance. Uses the AppContainer for dependency injection. @@ -271,51 +265,6 @@ def get_storage() -> PostgresReminderStorage: storage = container.reminder_storage if not storage.is_initialized: raise RuntimeError( - "Reminder storage not initialized. Call init_reminder_storage() first." + "Reminder storage not initialized. Call storage.initialize() first." ) return storage - - -async def init_reminder_storage(pool: asyncpg.Pool) -> PostgresReminderStorage: - """Initialize the reminder storage with a Postgres pool. - - Note: This function is provided for backward compatibility. - Prefer using AppContainer directly for new code. - """ - global _storage - import blacki.container as container_module - - if container_module._container is None: # pragma: no cover - container_module.set_container_from_pool(pool) - - if _storage is not None: - await _storage.close() - _storage = None - - container = container_module._container - if container is None: # pragma: no cover - raise RuntimeError("Container not initialized") - if container._reminder_storage is not None: # pragma: no cover - await container._reminder_storage.close() - - storage = container.reminder_storage - await storage.initialize() - _storage = storage - return storage - - -async def close_reminder_storage() -> None: - """Close the singleton reminder storage. - - Note: This function is provided for backward compatibility. - Prefer using AppContainer.close() for new code. - """ - global _storage - import blacki.container as container_module - - if container_module._container is not None: # pragma: no cover - container = container_module._container - if container._reminder_storage is not None: - await container._reminder_storage.close() - container._reminder_storage = None - _storage = None diff --git a/src/blacki/server.py b/src/blacki/server.py index 44fc0ea..e5f714e 100644 --- a/src/blacki/server.py +++ b/src/blacki/server.py @@ -17,11 +17,7 @@ from google.adk.cli.fast_api import get_fast_api_app from openinference.instrumentation.google_adk import GoogleADKInstrumentor -from .adk_runtime import ( - build_session_db_kwargs, - build_session_service_uri, - create_adk_runtime, -) +from .adk_runtime import create_adk_runtime from .container import AppContainer, close_container, init_container from .utils import ( ConfigurationError, @@ -131,13 +127,11 @@ async def _stop_reminder_scheduler() -> None: AGENT_DIR = os.getenv("AGENT_DIR", str(Path(__file__).resolve().parent.parent)) -session_uri = build_session_service_uri(env) -session_db_kwargs = build_session_db_kwargs(env) +DEFAULT_SQLITE_PATH = str(Path(AGENT_DIR) / ".adk" / "tools.db") app: FastAPI = get_fast_api_app( agents_dir=AGENT_DIR, - session_service_uri=session_uri, - session_db_kwargs=session_db_kwargs, + session_service_uri=None, artifact_service_uri=None, memory_service_uri="mem0://", allow_origins=env.allow_origins_list, @@ -156,9 +150,9 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]: """ global _container - if env.database_url: - _container = await init_container(env.database_url) - await _container.initialize_all_storages() + sqlite_path = env.sqlite_path or DEFAULT_SQLITE_PATH + _container = await init_container(sqlite_path) + await _container.initialize_all_storages() logger.info("Validating configuration...") try: @@ -208,7 +202,8 @@ async def health() -> dict[str, Any]: if _container is not None: try: - await _container.pool.fetchval("SELECT 1") + async with _container.conn.execute("SELECT 1") as cursor: + await cursor.fetchone() checks["database"] = "healthy" except Exception: checks["database"] = "unhealthy" @@ -246,7 +241,7 @@ def main() -> None: SERVE_WEB_INTERFACE: Whether to serve the web interface (true/false) RELOAD_AGENTS: Whether to reload agents on file changes (true/false) AGENT_ENGINE: Agent Engine instance for session and memory - DATABASE_URL: Postgres URL for session and memory + SQLITE_PATH: Path to SQLite database (default: {AGENT_DIR}/.adk/tools.db) OPENROUTER_API_KEY: Key for LiteLLM/OpenRouter ALLOW_ORIGINS: JSON array string of allowed CORS origins HOST: Server host (default: 127.0.0.1, set to 0.0.0.0 for containers) diff --git a/src/blacki/storage/__init__.py b/src/blacki/storage/__init__.py index 09c9012..b10e786 100644 --- a/src/blacki/storage/__init__.py +++ b/src/blacki/storage/__init__.py @@ -1,5 +1,5 @@ """Storage module with base class and implementations.""" -from .base import PostgresStorage +from .base import SqlStorage -__all__ = ["PostgresStorage"] +__all__ = ["SqlStorage"] diff --git a/src/blacki/storage/base.py b/src/blacki/storage/base.py index 8e332c9..5834a26 100644 --- a/src/blacki/storage/base.py +++ b/src/blacki/storage/base.py @@ -1,38 +1,39 @@ -"""Base class for Postgres-backed storage implementations.""" +"""Base class for SQLite-backed storage implementations.""" from __future__ import annotations import asyncio import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - import asyncpg # type: ignore[import-untyped] + import aiosqlite logger = logging.getLogger(__name__) -class PostgresStorage(ABC): - """Abstract base class for Postgres-backed storage. +class SqlStorage(ABC): + """Abstract base class for SQLite-backed storage. - Provides common initialization pattern with thread-safe schema creation. - Subclasses must implement _create_tables(). + Provides common initialization pattern with thread-safe schema creation + and unified query helpers that abstract away SQLite-specific patterns. Attributes: - _pool: The asyncpg connection pool. - _lock: Async lock for thread-safe initialization. + _conn: The aiosqlite connection. + _lock: Async lock for thread-safe operations (shared across storages). _schema_ready: Whether schema has been created. """ - def __init__(self, pool: asyncpg.Pool) -> None: - """Initialize storage with a Postgres pool. + def __init__(self, conn: aiosqlite.Connection, lock: asyncio.Lock) -> None: + """Initialize storage with a SQLite connection. Args: - pool: asyncpg connection pool. + conn: aiosqlite connection. + lock: Shared lock for write operations. """ - self._pool = pool - self._lock = asyncio.Lock() + self._conn = conn + self._lock = lock self._schema_ready = False async def initialize(self) -> None: @@ -44,15 +45,14 @@ async def initialize(self) -> None: async with self._lock: if self._schema_ready: return - async with self._pool.acquire() as conn: - await self._create_tables(conn) + await self._create_tables() self._schema_ready = True - logger.info("%s schema ready (Postgres)", self.__class__.__name__) + logger.info("%s schema ready (SQLite)", self.__class__.__name__) async def close(self) -> None: """Mark storage as uninitialized. - Note: Pool lifecycle is managed externally (by AppContainer). + Note: Connection lifecycle is managed externally (by AppContainer). """ async with self._lock: self._schema_ready = False @@ -63,9 +63,122 @@ def is_initialized(self) -> bool: return self._schema_ready @abstractmethod - async def _create_tables(self, conn: asyncpg.Connection) -> None: + async def _create_tables(self) -> None: """Create tables and indexes. + Override in subclasses to define schema. + """ + + async def _execute( + self, + query: str, + params: tuple[Any, ...] = (), + *, + use_lock: bool = True, + ) -> int: + """Execute a write query and return lastrowid. + + Args: + query: SQL query with ? placeholders. + params: Query parameters. + use_lock: Whether to acquire the write lock. + + Returns: + The last inserted row ID. + + Raises: + RuntimeError: If lastrowid is None after insert. + """ + if use_lock: + async with self._lock: # noqa: SIM117 + async with self._conn.execute(query, params) as cursor: + if cursor.lastrowid is None: + raise RuntimeError("Failed to get lastrowid after insert") + return cursor.lastrowid + else: + async with self._conn.execute(query, params) as cursor: + if cursor.lastrowid is None: + raise RuntimeError("Failed to get lastrowid after insert") + return cursor.lastrowid + + async def _execute_many( + self, + query: str, + params_list: list[tuple[Any, ...]], + *, + use_lock: bool = True, + ) -> None: + """Execute a write query multiple times with different params. + + Args: + query: SQL query with ? placeholders. + params_list: List of parameter tuples. + use_lock: Whether to acquire the write lock. + """ + if use_lock: + async with self._lock: + await self._conn.executemany(query, params_list) + else: + await self._conn.executemany(query, params_list) + + async def _fetch_one( + self, + query: str, + params: tuple[Any, ...] = (), + ) -> dict[str, Any] | None: + """Execute a query and return a single row as dict. + + Args: + query: SQL query with ? placeholders. + params: Query parameters. + + Returns: + A dict representing the row, or None if no result. + """ + async with self._conn.execute(query, params) as cursor: + row = await cursor.fetchone() + return dict(row) if row else None + + async def _fetch_all( + self, + query: str, + params: tuple[Any, ...] = (), + ) -> list[dict[str, Any]]: + """Execute a query and return all rows as dicts. + Args: - conn: Database connection to use for DDL operations. + query: SQL query with ? placeholders. + params: Query parameters. + + Returns: + A list of dicts representing the rows. + """ + async with self._conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + return [dict(r) for r in rows] + + async def _fetch_val( + self, + query: str, + params: tuple[Any, ...] = (), + ) -> Any: + """Execute a query and return a single value. + + Args: + query: SQL query with ? placeholders. + params: Query parameters. + + Returns: + The first column of the first row, or None. + """ + async with self._conn.execute(query, params) as cursor: + row = await cursor.fetchone() + return row[0] if row else None + + @property + def conn(self) -> aiosqlite.Connection: + """Get the underlying connection for advanced operations. + + Use with caution - bypasses the write lock. """ + return self._conn diff --git a/src/blacki/storage/sqlite.py b/src/blacki/storage/sqlite.py new file mode 100644 index 0000000..97e456f --- /dev/null +++ b/src/blacki/storage/sqlite.py @@ -0,0 +1,54 @@ +"""SQLite connection management for tools.db.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +import aiosqlite + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +DEFAULT_BUSY_TIMEOUT_MS = 5000 + + +async def create_connection( + db_path: str | Path, + *, + busy_timeout_ms: int = DEFAULT_BUSY_TIMEOUT_MS, +) -> aiosqlite.Connection: + """Create a SQLite connection with WAL mode and optimal settings. + + Args: + db_path: Path to the SQLite database file. + busy_timeout_ms: Milliseconds to wait when database is locked. + + Returns: + Configured aiosqlite connection. + """ + path = Path(db_path) + path.parent.mkdir(parents=True, exist_ok=True) + + conn = await aiosqlite.connect(path) + conn.row_factory = aiosqlite.Row + + await conn.execute("PRAGMA journal_mode=WAL") + await conn.execute(f"PRAGMA busy_timeout={busy_timeout_ms}") + await conn.execute("PRAGMA foreign_keys=ON") + + logger.info("SQLite connection opened: %s (WAL mode)", path) + return conn + + +async def close_connection(conn: aiosqlite.Connection) -> None: + """Close a SQLite connection. + + Args: + conn: The connection to close. + """ + await conn.close() + logger.info("SQLite connection closed") diff --git a/src/blacki/utils/config.py b/src/blacki/utils/config.py index 5635471..302ee61 100644 --- a/src/blacki/utils/config.py +++ b/src/blacki/utils/config.py @@ -125,40 +125,10 @@ class ServerEnv(BaseModel): description="Agent Engine instance ID for session and memory persistence", ) - database_url: str | None = Field( + sqlite_path: str | None = Field( default=None, - alias="DATABASE_URL", - description="Database URL for session storage (e.g., postgresql://...)", - ) - - db_pool_pre_ping: bool = Field( - default=True, - alias="DB_POOL_PRE_PING", - description="Validate DB connections before use", - ) - - db_pool_recycle: int = Field( - default=1800, - alias="DB_POOL_RECYCLE", - description="Recycle connections after this many seconds", - ) - - db_pool_size: int = Field( - default=5, - alias="DB_POOL_SIZE", - description="Number of connections to keep open inside the connection pool", - ) - - db_max_overflow: int = Field( - default=10, - alias="DB_MAX_OVERFLOW", - description="Number of connections to allow beyond pool_size", - ) - - db_pool_timeout: int = Field( - default=30, - alias="DB_POOL_TIMEOUT", - description="Seconds to wait before giving up on getting a connection", + alias="SQLITE_PATH", + description="Path to SQLite database file (default: {AGENT_DIR}/.adk/tools.db)", ) openrouter_api_key: str | None = Field( @@ -219,13 +189,7 @@ def print_config(self) -> None: print(f"SERVE_WEB_INTERFACE: {self.serve_web_interface}") print(f"RELOAD_AGENTS: {self.reload_agents}") print(f"AGENT_ENGINE: {self.agent_engine}") - print(f"DATABASE_URL: {self.database_url}") - if self.database_url: - print(f"DB_POOL_PRE_PING: {self.db_pool_pre_ping}") - print(f"DB_POOL_RECYCLE: {self.db_pool_recycle}") - print(f"DB_POOL_SIZE: {self.db_pool_size}") - print(f"DB_MAX_OVERFLOW: {self.db_max_overflow}") - print(f"DB_POOL_TIMEOUT: {self.db_pool_timeout}") + print(f"SQLITE_PATH: {self.sqlite_path}") masked_key = "********" if self.openrouter_api_key else "[not set]" print(f"OPENROUTER_KEY: {masked_key}") print(f"HOST: {self.host}") @@ -246,7 +210,7 @@ def session_uri(self) -> str | None: """Session service URI (Agent Engine or in-memory). Returns agent_engine_uri for ADK sessions, defaulting to None for - in-memory sessions. DATABASE_URL is reserved for the Reminders system. + in-memory sessions. """ return self.agent_engine_uri diff --git a/src/blacki/utils/preferences.py b/src/blacki/utils/preferences.py index 9a29b52..2b26007 100644 --- a/src/blacki/utils/preferences.py +++ b/src/blacki/utils/preferences.py @@ -1,80 +1,80 @@ +"""Persistent storage for user preferences backed by SQLite.""" + +from __future__ import annotations + import json import logging from typing import TYPE_CHECKING, Any -import asyncpg # type: ignore[import-untyped] - -from blacki.storage.base import PostgresStorage +from blacki.storage.base import SqlStorage from blacki.utils.timezone import now_utc if TYPE_CHECKING: - pass + import asyncio + + import aiosqlite logger = logging.getLogger(__name__) -class PostgresPreferencesStorage(PostgresStorage): - """Storage for user preferences using Postgres via asyncpg.""" +class SqlitePreferencesStorage(SqlStorage): + """Storage for user preferences using SQLite via aiosqlite.""" - def __init__(self, pool: asyncpg.Pool) -> None: - super().__init__(pool) + def __init__(self, conn: aiosqlite.Connection, lock: asyncio.Lock) -> None: + super().__init__(conn, lock) - async def _create_tables(self, conn: asyncpg.Connection) -> None: - await conn.execute(""" + async def _create_tables(self) -> None: + await self._conn.execute(""" CREATE TABLE IF NOT EXISTS user_preferences ( - user_id TEXT NOT NULL, - key TEXT NOT NULL, - value JSONB NOT NULL, - updated_at TEXT NOT NULL, + user_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + updated_at TEXT NOT NULL, PRIMARY KEY (user_id, key) ) """) async def get(self, user_id: str, key: str, default: Any = None) -> Any: """Get a preference value.""" - row = await self._pool.fetchrow( - "SELECT value FROM user_preferences WHERE user_id = $1 AND key = $2", - user_id, - key, + row = await self._fetch_one( + "SELECT value FROM user_preferences WHERE user_id = ? AND key = ?", + (user_id, key), ) if row is None: return default - value = row["value"] - return json.loads(value) if isinstance(value, str) else value + return json.loads(row["value"]) async def set(self, user_id: str, key: str, value: Any) -> None: """Set a preference value.""" now = now_utc().isoformat(timespec="seconds") value_json = json.dumps(value) - await self._pool.execute( - """ - INSERT INTO user_preferences (user_id, key, value, updated_at) - VALUES ($1, $2, $3::jsonb, $4) - ON CONFLICT (user_id, key) DO UPDATE - SET value = EXCLUDED.value, updated_at = EXCLUDED.updated_at - """, - user_id, - key, - value_json, - now, - ) + async with self._lock: + await self._conn.execute( + """ + INSERT INTO user_preferences (user_id, key, value, updated_at) + VALUES (?, ?, ?, ?) + ON CONFLICT (user_id, key) DO UPDATE + SET value = excluded.value, updated_at = excluded.updated_at + """, + (user_id, key, value_json, now), + ) logger.info("Updated preference %s for user %s", key, user_id) async def delete(self, user_id: str, key: str) -> bool: """Delete a preference.""" - result = await self._pool.execute( - "DELETE FROM user_preferences WHERE user_id = $1 AND key = $2", - user_id, - key, - ) - return bool(result == "DELETE 1") + async with self._lock: + cursor = await self._conn.execute( + "DELETE FROM user_preferences WHERE user_id = ? AND key = ?", + (user_id, key), + ) + return cursor.rowcount > 0 -_storage: PostgresPreferencesStorage | None = None +_storage: SqlitePreferencesStorage | None = None -def get_preferences_storage() -> PostgresPreferencesStorage: - """Return the process-wide singleton PostgresPreferencesStorage instance. +def get_preferences_storage() -> SqlitePreferencesStorage: + """Return the process-wide singleton SqlitePreferencesStorage instance. Uses the AppContainer for dependency injection. """ @@ -84,52 +84,6 @@ def get_preferences_storage() -> PostgresPreferencesStorage: storage = container.preferences_storage if not storage.is_initialized: raise RuntimeError( - "Preferences storage not initialized. " - "Call init_preferences_storage() first." + "Preferences storage not initialized. Call storage.initialize() first." ) return storage - - -async def init_preferences_storage(pool: asyncpg.Pool) -> PostgresPreferencesStorage: - """Initialize the preferences storage with a Postgres pool. - - Note: This function is provided for backward compatibility. - Prefer using AppContainer directly for new code. - """ - global _storage - import blacki.container as container_module - - if container_module._container is None: # pragma: no cover - container_module.set_container_from_pool(pool) - - if _storage is not None: - await _storage.close() - _storage = None - - container = container_module._container - if container is None: # pragma: no cover - raise RuntimeError("Container not initialized") - if container._preferences_storage is not None: # pragma: no cover - await container._preferences_storage.close() - - storage = container.preferences_storage - await storage.initialize() - _storage = storage - return storage - - -async def close_preferences_storage() -> None: - """Close the singleton preferences storage. - - Note: This function is provided for backward compatibility. - Prefer using AppContainer.close() for new code. - """ - global _storage - import blacki.container as container_module - - if container_module._container is not None: # pragma: no cover - container = container_module._container - if container._preferences_storage is not None: - await container._preferences_storage.close() - container._preferences_storage = None - _storage = None diff --git a/src/blacki/workouts/__init__.py b/src/blacki/workouts/__init__.py index 4f4813b..d758335 100644 --- a/src/blacki/workouts/__init__.py +++ b/src/blacki/workouts/__init__.py @@ -1,7 +1,3 @@ -from .storage import ( - close_workout_storage, - init_workout_storage, -) from .tools import ( delete_workout, get_exercise_progress, @@ -13,8 +9,6 @@ ) __all__ = [ - "close_workout_storage", - "init_workout_storage", "delete_workout", "get_exercise_progress", "get_last_workout", diff --git a/src/blacki/workouts/storage.py b/src/blacki/workouts/storage.py index 64a9f76..3dc2948 100644 --- a/src/blacki/workouts/storage.py +++ b/src/blacki/workouts/storage.py @@ -1,15 +1,19 @@ +"""Persistent storage for workout tracking backed by SQLite.""" + +from __future__ import annotations + import json import logging -from collections.abc import Mapping from typing import TYPE_CHECKING, Any -import asyncpg # type: ignore[import-untyped] from pydantic import BaseModel -from blacki.storage.base import PostgresStorage +from blacki.storage.base import SqlStorage if TYPE_CHECKING: - pass + import asyncio + + import aiosqlite logger = logging.getLogger(__name__) @@ -28,7 +32,7 @@ class WorkoutExercise(BaseModel): id: int | None = None session_id: int | None = None - exercise_name: str # always lowercase + exercise_name: str sets: list[SetDetail] exercise_order: int = 0 notes: str | None = None @@ -39,10 +43,10 @@ class WorkoutSession(BaseModel): id: int | None = None user_id: str - workout_date: str # YYYY-MM-DD local - split_name: str # Push / Pull / Legs / etc. + workout_date: str + split_name: str notes: str | None = None - created_at: str # ISO UTC + created_at: str exercises: list[WorkoutExercise] = [] @@ -63,101 +67,115 @@ class ExerciseHistoryEntry(BaseModel): sets: list[SetDetail] best_set_weight_kg: float best_set_reps: int - total_volume_kg: float # sum(weight * reps) across all working sets + total_volume_kg: float -class PostgresWorkoutStorage(PostgresStorage): - """Storage for workout tracking using Postgres via asyncpg.""" +class SqliteWorkoutStorage(SqlStorage): + """Storage for workout tracking using SQLite via aiosqlite.""" - def __init__(self, pool: asyncpg.Pool) -> None: - super().__init__(pool) + def __init__(self, conn: aiosqlite.Connection, lock: asyncio.Lock) -> None: + super().__init__(conn, lock) - async def _create_tables(self, conn: asyncpg.Connection) -> None: - await conn.execute(""" + async def _create_tables(self) -> None: + await self._conn.execute(""" CREATE TABLE IF NOT EXISTS workout_sessions ( - id BIGSERIAL PRIMARY KEY, - user_id TEXT NOT NULL, - workout_date DATE NOT NULL, - split_name TEXT NOT NULL, - notes TEXT, - created_at TIMESTAMPTZ NOT NULL + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + workout_date TEXT NOT NULL, + split_name TEXT NOT NULL, + notes TEXT, + created_at TEXT NOT NULL ) """) - await conn.execute(""" + await self._conn.execute(""" CREATE INDEX IF NOT EXISTS idx_workout_sessions_user_date ON workout_sessions (user_id, workout_date DESC) """) - await conn.execute(""" + await self._conn.execute(""" CREATE INDEX IF NOT EXISTS idx_workout_sessions_user_split ON workout_sessions (user_id, split_name) """) - await conn.execute(""" + await self._conn.execute(""" CREATE TABLE IF NOT EXISTS workout_exercises ( - id BIGSERIAL PRIMARY KEY, - session_id BIGINT NOT NULL REFERENCES workout_sessions(id) - ON DELETE CASCADE, - exercise_name TEXT NOT NULL, - sets JSONB NOT NULL, - exercise_order INTEGER NOT NULL DEFAULT 0, - notes TEXT + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL, + exercise_name TEXT NOT NULL, + sets TEXT NOT NULL, + exercise_order INTEGER NOT NULL DEFAULT 0, + notes TEXT, + FOREIGN KEY (session_id) + REFERENCES workout_sessions(id) ON DELETE CASCADE ) """) - await conn.execute(""" + await self._conn.execute(""" CREATE INDEX IF NOT EXISTS idx_workout_exercises_session ON workout_exercises (session_id) """) async def create_session(self, session: WorkoutSession) -> int: """Create session row + all exercises atomically.""" - async with self._pool.acquire() as conn, conn.transaction(): - sid = await conn.fetchval( - """ + async with self._lock: + await self._conn.execute("BEGIN") + try: + cursor = await self._conn.execute( + """ INSERT INTO workout_sessions (user_id, workout_date, split_name, notes, created_at) - VALUES ($1, $2, $3, $4, $5) - RETURNING id + VALUES (?, ?, ?, ?, ?) """, - session.user_id, - session.workout_date, - session.split_name, - session.notes, - session.created_at, - ) - - for exercise in session.exercises: - sets_json = json.dumps([s.model_dump() for s in exercise.sets]) - await conn.execute( - """ + ( + session.user_id, + session.workout_date, + session.split_name, + session.notes, + session.created_at, + ), + ) + sid = cursor.lastrowid + if sid is None: + raise RuntimeError("Failed to get lastrowid after session insert") + + for exercise in session.exercises: + sets_json = json.dumps([s.model_dump() for s in exercise.sets]) + await self._conn.execute( + """ INSERT INTO workout_exercises (session_id, exercise_name, sets, exercise_order, notes) - VALUES ($1, $2, $3::jsonb, $4, $5) + VALUES (?, ?, ?, ?, ?) """, - sid, - exercise.exercise_name, - sets_json, - exercise.exercise_order, - exercise.notes, - ) - return int(sid) + ( + sid, + exercise.exercise_name, + sets_json, + exercise.exercise_order, + exercise.notes, + ), + ) + await self._conn.commit() + return sid + except Exception: + await self._conn.rollback() + raise async def add_exercise(self, session_id: int, exercise: WorkoutExercise) -> int: """Add one exercise to an existing session.""" sets_json = json.dumps([s.model_dump() for s in exercise.sets]) - eid = await self._pool.fetchval( + eid = await self._execute( """ INSERT INTO workout_exercises (session_id, exercise_name, sets, exercise_order, notes) - VALUES ($1, $2, $3::jsonb, $4, $5) - RETURNING id + VALUES (?, ?, ?, ?, ?) """, - session_id, - exercise.exercise_name, - sets_json, - exercise.exercise_order, - exercise.notes, + ( + session_id, + exercise.exercise_name, + sets_json, + exercise.exercise_order, + exercise.notes, + ), ) - return int(eid) + return eid async def update_exercise( self, @@ -167,72 +185,73 @@ async def update_exercise( notes: str | None = None, ) -> bool: """Update sets/notes for an exercise. Needs user_id for authorization.""" - owner = await self._pool.fetchval( + owner_row = await self._fetch_one( """ SELECT s.user_id FROM workout_sessions s JOIN workout_exercises e ON s.id = e.session_id - WHERE e.id = $1 + WHERE e.id = ? """, - exercise_id, + (exercise_id,), ) - if owner != user_id: # pragma: no cover + if owner_row is None or owner_row["user_id"] != user_id: return False updates: list[str] = [] values: list[Any] = [] - if sets is not None: # pragma: no cover - updates.append(f"sets = ${len(values) + 1}::jsonb") + if sets is not None: + updates.append("sets = ?") values.append(json.dumps([s.model_dump() for s in sets])) if notes is not None: - updates.append(f"notes = ${len(values) + 1}") + updates.append("notes = ?") values.append(notes) - if not updates: # pragma: no cover + if not updates: return False updates_str = ", ".join(updates) - query = ( - f"UPDATE workout_exercises SET {updates_str} WHERE id = ${len(values) + 1}" # noqa: S608 - ) values.append(exercise_id) + query = f"UPDATE workout_exercises SET {updates_str} WHERE id = ?" # noqa: S608 - result = await self._pool.execute(query, *values) - return bool(result == "UPDATE 1") + async with self._lock: + cursor = await self._conn.execute(query, values) + return cursor.rowcount > 0 async def delete_exercise(self, exercise_id: int, user_id: str) -> bool: """Remove one exercise from a session.""" - owner = await self._pool.fetchval( + owner_row = await self._fetch_one( """ SELECT s.user_id FROM workout_sessions s JOIN workout_exercises e ON s.id = e.session_id - WHERE e.id = $1 + WHERE e.id = ? """, - exercise_id, + (exercise_id,), ) - if owner != user_id: # pragma: no cover + if owner_row is None or owner_row["user_id"] != user_id: return False - result = await self._pool.execute( - "DELETE FROM workout_exercises WHERE id = $1", exercise_id - ) - return bool(result == "DELETE 1") + async with self._lock: + cursor = await self._conn.execute( + "DELETE FROM workout_exercises WHERE id = ?", (exercise_id,) + ) + return cursor.rowcount > 0 async def get_session(self, session_id: int, user_id: str) -> WorkoutSession | None: """Get full session with exercises.""" - row = await self._pool.fetchrow( - "SELECT * FROM workout_sessions WHERE id = $1 AND user_id = $2", - session_id, - user_id, + row = await self._fetch_one( + "SELECT * FROM workout_sessions WHERE id = ? AND user_id = ?", + (session_id, user_id), ) - if not row: # pragma: no cover + if not row: return None session = self._row_to_session(row) - ex_rows = await self._pool.fetch( - "SELECT * FROM workout_exercises WHERE session_id = $1 " - "ORDER BY exercise_order ASC, id ASC", - session_id, + ex_rows = await self._fetch_all( + """ + SELECT * FROM workout_exercises WHERE session_id = ? + ORDER BY exercise_order ASC, id ASC + """, + (session_id,), ) session.exercises = [self._row_to_exercise(r) for r in ex_rows] return session @@ -241,17 +260,16 @@ async def get_latest_split_session( self, user_id: str, split_name: str ) -> WorkoutSession | None: """Returns the most recent session for a given split.""" - row = await self._pool.fetchrow( + row = await self._fetch_one( """ SELECT * FROM workout_sessions - WHERE user_id = $1 AND split_name = $2 + WHERE user_id = ? AND split_name = ? ORDER BY workout_date DESC, created_at DESC LIMIT 1 """, - user_id, - split_name, + (user_id, split_name), ) - if not row: # pragma: no cover + if not row: return None return await self.get_session(row["id"], user_id) @@ -260,19 +278,18 @@ async def get_recent_sessions( self, user_id: str, limit: int = 10 ) -> list[WorkoutSessionSummary]: """Returns lightweight view of recent sessions.""" - limit = min(limit, 20) # Capped at 20 - rows = await self._pool.fetch( + limit = min(limit, 20) + rows = await self._fetch_all( """ SELECT s.id, s.workout_date, s.split_name, COUNT(e.id) as exercise_count FROM workout_sessions s LEFT JOIN workout_exercises e ON s.id = e.session_id - WHERE s.user_id = $1 + WHERE s.user_id = ? GROUP BY s.id ORDER BY s.workout_date DESC, s.created_at DESC - LIMIT $2 + LIMIT ? """, - user_id, - limit, + (user_id, limit), ) return [ WorkoutSessionSummary( @@ -288,19 +305,17 @@ async def get_exercise_history( self, user_id: str, exercise_name: str, limit: int = 8 ) -> list[ExerciseHistoryEntry]: """Returns the last N instances of a specific exercise.""" - limit = min(limit, 8) # Capped at 8 - rows = await self._pool.fetch( + limit = min(limit, 8) + rows = await self._fetch_all( """ SELECT s.workout_date, s.split_name, e.sets FROM workout_exercises e JOIN workout_sessions s ON e.session_id = s.id - WHERE s.user_id = $1 AND e.exercise_name = $2 + WHERE s.user_id = ? AND e.exercise_name = ? ORDER BY s.workout_date DESC, s.created_at DESC - LIMIT $3 + LIMIT ? """, - user_id, - exercise_name.lower(), - limit, + (user_id, exercise_name.lower(), limit), ) history = [] @@ -315,9 +330,9 @@ async def get_exercise_history( volume = 0.0 for s in sets: - if not s.is_warmup: # pragma: no cover + if not s.is_warmup: volume += s.weight_kg * s.reps - if s.weight_kg > best_weight or ( # pragma: no cover + if s.weight_kg > best_weight or ( s.weight_kg == best_weight and s.reps > best_reps ): best_weight = s.weight_kg @@ -338,29 +353,25 @@ async def get_exercise_history( async def delete_session(self, session_id: int, user_id: str) -> bool: """Cascades to exercises.""" - result = await self._pool.execute( - "DELETE FROM workout_sessions WHERE id = $1 AND user_id = $2", - session_id, - user_id, - ) - return bool(result == "DELETE 1") + async with self._lock: + cursor = await self._conn.execute( + "DELETE FROM workout_sessions WHERE id = ? AND user_id = ?", + (session_id, user_id), + ) + return cursor.rowcount > 0 - def _row_to_session(self, row: Mapping[str, Any]) -> WorkoutSession: + def _row_to_session(self, row: dict[str, Any]) -> WorkoutSession: return WorkoutSession( id=int(row["id"]), user_id=row["user_id"], workout_date=str(row["workout_date"]), split_name=row["split_name"], notes=row["notes"], - created_at=( - row["created_at"].isoformat() - if hasattr(row["created_at"], "isoformat") - else str(row["created_at"]) - ), + created_at=row["created_at"], exercises=[], ) - def _row_to_exercise(self, row: Mapping[str, Any]) -> WorkoutExercise: + def _row_to_exercise(self, row: dict[str, Any]) -> WorkoutExercise: sets_data = ( json.loads(row["sets"]) if isinstance(row["sets"], str) else row["sets"] ) @@ -374,11 +385,11 @@ def _row_to_exercise(self, row: Mapping[str, Any]) -> WorkoutExercise: ) -_storage: PostgresWorkoutStorage | None = None +_storage: SqliteWorkoutStorage | None = None -def get_storage() -> PostgresWorkoutStorage: - """Return the process-wide singleton PostgresWorkoutStorage instance. +def get_storage() -> SqliteWorkoutStorage: + """Return the process-wide singleton SqliteWorkoutStorage instance. Uses the AppContainer for dependency injection. """ @@ -388,51 +399,6 @@ def get_storage() -> PostgresWorkoutStorage: storage = container.workout_storage if not storage.is_initialized: raise RuntimeError( - "Workout storage not initialized. Call init_workout_storage() first." + "Workout storage not initialized. Call storage.initialize() first." ) return storage - - -async def init_workout_storage(pool: asyncpg.Pool) -> PostgresWorkoutStorage: - """Initialize the workout storage with a Postgres pool. - - Note: This function is provided for backward compatibility. - Prefer using AppContainer directly for new code. - """ - global _storage - import blacki.container as container_module - - if container_module._container is None: # pragma: no cover - container_module.set_container_from_pool(pool) - - if _storage is not None: - await _storage.close() - _storage = None - - container = container_module._container - if container is None: # pragma: no cover - raise RuntimeError("Container not initialized") - if container._workout_storage is not None: # pragma: no cover - await container._workout_storage.close() - - storage = container.workout_storage - await storage.initialize() - _storage = storage - return storage - - -async def close_workout_storage() -> None: - """Close the singleton workout storage. - - Note: This function is provided for backward compatibility. - Prefer using AppContainer.close() for new code. - """ - global _storage - import blacki.container as container_module - - if container_module._container is not None: # pragma: no cover - container = container_module._container - if container._workout_storage is not None: - await container._workout_storage.close() - container._workout_storage = None - _storage = None diff --git a/tests/calories/test_storage.py b/tests/calories/test_storage.py index 2adb55c..1a5820b 100644 --- a/tests/calories/test_storage.py +++ b/tests/calories/test_storage.py @@ -1,212 +1,341 @@ # mypy: disable-error-code="no-untyped-def" -from unittest.mock import AsyncMock, MagicMock +"""Unit tests for calorie storage.""" +import asyncio + +import aiosqlite import pytest from blacki.calories.storage import ( CalorieEntry, - PostgresCalorieStorage, - close_calorie_storage, - get_storage, - init_calorie_storage, + SqliteCalorieStorage, ) @pytest.fixture -def mock_pool(): - pool = MagicMock() - conn = AsyncMock() - conn.execute = AsyncMock() - pool.acquire.return_value.__aenter__.return_value = conn - pool.execute = AsyncMock() - pool.fetch = AsyncMock() - pool.fetchval = AsyncMock() - return pool +async def conn(): + """Create an in-memory SQLite connection for testing.""" + conn = await aiosqlite.connect(":memory:") + conn.row_factory = aiosqlite.Row + yield conn + await conn.close() @pytest.fixture -async def calorie_storage(mock_pool): - storage = PostgresCalorieStorage(mock_pool) - await storage.initialize() - yield storage - await storage.close() +def lock(): + """Create a lock for write operations.""" + return asyncio.Lock() -@pytest.mark.asyncio -async def test_initialize_creates_tables(mock_pool) -> None: - conn = mock_pool.acquire.return_value.__aenter__.return_value - conn.fetchval.return_value = "integer" - - storage = PostgresCalorieStorage(mock_pool) +@pytest.fixture +async def storage(conn, lock): + """Create a storage instance with the test connection.""" + storage = SqliteCalorieStorage(conn, lock) await storage.initialize() + yield storage + await storage.close() - assert conn.execute.call_count == 3 - assert conn.fetchval.call_count == 1 - assert storage._schema_ready is True - - -@pytest.mark.asyncio -async def test_add_entry(calorie_storage, mock_pool) -> None: - mock_pool.fetchval.return_value = 123 - entry = CalorieEntry( - user_id="user1", - description="apple", - calories=95, - logged_at="2026-04-26T10:00:00", - logged_date="2026-04-26", - ) - - entry_id = await calorie_storage.add_entry(entry) - - assert entry_id == 123 - mock_pool.fetchval.assert_called_once() - args = mock_pool.fetchval.call_args[0] - assert args[1] == "user1" - assert args[2] == "apple" - assert args[3] == 95 - - -@pytest.mark.asyncio -async def test_get_daily_summary(calorie_storage, mock_pool) -> None: - mock_pool.fetch.return_value = [ - { - "id": 1, - "user_id": "user1", - "description": "apple", - "calories": 100, - "protein_g": None, - "carbs_g": 25, - "fat_g": None, - "meal_type": "snack", - "logged_at": "2026-04-26T10:00:00", - "logged_date": "2026-04-26", - }, - { - "id": 2, - "user_id": "user1", - "description": "egg", - "calories": 70, - "protein_g": 6, - "carbs_g": None, - "fat_g": 5, - "meal_type": "breakfast", - "logged_at": "2026-04-26T11:00:00", - "logged_date": "2026-04-26", - }, - ] - - summary = await calorie_storage.get_daily_summary("user1", "2026-04-26") - - assert summary.date == "2026-04-26" - assert summary.entry_count == 2 - assert summary.total_calories == 170 - assert summary.total_protein_g == 6 - assert summary.total_carbs_g == 25 - assert summary.total_fat_g == 5 - assert len(summary.entries) == 2 - - -@pytest.mark.asyncio -async def test_get_date_range_summary(calorie_storage, mock_pool) -> None: - mock_pool.fetch.return_value = [ - { - "logged_date": "2026-04-26", - "entry_count": 2, - "total_calories": 500, - "total_protein_g": 20, - "total_carbs_g": 50, - "total_fat_g": 10, - }, - { - "logged_date": "2026-04-25", - "entry_count": 3, - "total_calories": 2000, - "total_protein_g": 100, - "total_carbs_g": 200, - "total_fat_g": 50, - }, - ] - - summaries = await calorie_storage.get_date_range_summary( - "user1", "2026-04-20", "2026-04-26" - ) - - assert len(summaries) == 2 - assert summaries[0].date == "2026-04-26" - assert summaries[0].total_calories == 500 - assert summaries[1].date == "2026-04-25" - assert summaries[1].total_calories == 2000 - - -@pytest.mark.asyncio -async def test_update_entry(calorie_storage, mock_pool) -> None: - mock_pool.execute.return_value = "UPDATE 1" - - result = await calorie_storage.update_entry( - 1, "user1", calories=200, meal_type="lunch" - ) - - assert result is True - mock_pool.execute.assert_called_once() - args = mock_pool.execute.call_args[0] - assert "calories = $3" in args[0] - assert "meal_type = $4" in args[0] - assert args[1] == 1 - assert args[2] == "user1" - assert args[3] == 200 - assert args[4] == "lunch" - - -@pytest.mark.asyncio -async def test_delete_entry(calorie_storage, mock_pool) -> None: - mock_pool.execute.return_value = "DELETE 1" - - result = await calorie_storage.delete_entry(1, "user1") - - assert result is True - mock_pool.execute.assert_called_once_with( - "DELETE FROM calorie_logs WHERE id = $1 AND user_id = $2", - 1, - "user1", - ) - - -@pytest.mark.asyncio -async def test_singleton(mock_pool) -> None: - import blacki.calories.storage as storage - - storage._storage = None - - with pytest.raises(RuntimeError): - get_storage() - - instance = await init_calorie_storage(mock_pool) - assert get_storage() is instance - - await close_calorie_storage() - with pytest.raises(RuntimeError): - get_storage() - - -@pytest.mark.asyncio -async def test_update_entry_invalid_column(calorie_storage, mock_pool) -> None: - """update_entry raises ValueError for columns not in whitelist.""" - with pytest.raises(ValueError, match="Column 'bogus' is not allowed"): - await calorie_storage.update_entry(1, "user1", bogus="value") - - -@pytest.mark.asyncio -async def test_reinit_calorie_storage_closes_existing(mock_pool) -> None: - """init_calorie_storage closes existing storage before replacing.""" - import blacki.calories.storage as storage - - existing = PostgresCalorieStorage(mock_pool) - existing.close = AsyncMock() # type: ignore[method-assign] - storage._storage = existing - new = await init_calorie_storage(mock_pool) - - existing.close.assert_awaited_once() - assert storage._storage is new - - storage._storage = None +class TestSqliteCalorieStorage: + """Tests for SqliteCalorieStorage.""" + + @pytest.mark.asyncio + async def test_initialize_creates_tables(self, conn, lock) -> None: + """Should create tables on initialization.""" + storage = SqliteCalorieStorage(conn, lock) + await storage.initialize() + + assert storage.is_initialized is True + + cursor = await conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='calorie_logs'" + ) + row = await cursor.fetchone() + assert row is not None + + @pytest.mark.asyncio + async def test_add_entry(self, storage) -> None: + """Should add an entry and return its ID.""" + entry = CalorieEntry( + user_id="user1", + description="apple", + calories=95, + logged_at="2026-04-26T10:00:00", + logged_date="2026-04-26", + ) + + entry_id = await storage.add_entry(entry) + + assert entry_id == 1 + + @pytest.mark.asyncio + async def test_add_entry_with_macros(self, storage) -> None: + """Should add an entry with macro nutrients.""" + entry = CalorieEntry( + user_id="user1", + description="protein shake", + calories=200, + protein_g=30.0, + carbs_g=10.0, + fat_g=5.0, + meal_type="snack", + logged_at="2026-04-26T10:00:00", + logged_date="2026-04-26", + ) + + entry_id = await storage.add_entry(entry) + + assert entry_id == 1 + + summary = await storage.get_daily_summary("user1", "2026-04-26") + assert summary.total_protein_g == 30.0 + assert summary.total_carbs_g == 10.0 + assert summary.total_fat_g == 5.0 + + @pytest.mark.asyncio + async def test_get_daily_summary(self, storage) -> None: + """Should get daily summary with entries.""" + entry1 = CalorieEntry( + user_id="user1", + description="apple", + calories=100, + carbs_g=25.0, + meal_type="snack", + logged_at="2026-04-26T10:00:00", + logged_date="2026-04-26", + ) + entry2 = CalorieEntry( + user_id="user1", + description="egg", + calories=70, + protein_g=6.0, + fat_g=5.0, + meal_type="breakfast", + logged_at="2026-04-26T11:00:00", + logged_date="2026-04-26", + ) + await storage.add_entry(entry1) + await storage.add_entry(entry2) + + summary = await storage.get_daily_summary("user1", "2026-04-26") + + assert summary.date == "2026-04-26" + assert summary.entry_count == 2 + assert summary.total_calories == 170 + assert summary.total_protein_g == 6.0 + assert summary.total_carbs_g == 25.0 + assert summary.total_fat_g == 5.0 + assert len(summary.entries) == 2 + + @pytest.mark.asyncio + async def test_get_daily_summary_empty(self, storage) -> None: + """Should return empty summary for date with no entries.""" + summary = await storage.get_daily_summary("user1", "2026-04-26") + + assert summary.date == "2026-04-26" + assert summary.entry_count == 0 + assert summary.total_calories == 0 + assert len(summary.entries) == 0 + + @pytest.mark.asyncio + async def test_get_date_range_summary(self, storage) -> None: + """Should get summaries for date range.""" + for day in range(20, 27): + entry = CalorieEntry( + user_id="user1", + description=f"food day {day}", + calories=500 + day, + protein_g=20.0 + day, + logged_at=f"2026-04-{day:02d}T10:00:00", + logged_date=f"2026-04-{day:02d}", + ) + await storage.add_entry(entry) + + summaries = await storage.get_date_range_summary( + "user1", "2026-04-20", "2026-04-26" + ) + + assert len(summaries) == 7 + assert summaries[0].date == "2026-04-26" + assert summaries[0].total_calories == 526 + + @pytest.mark.asyncio + async def test_update_entry(self, storage) -> None: + """Should update an entry.""" + entry = CalorieEntry( + user_id="user1", + description="apple", + calories=100, + logged_at="2026-04-26T10:00:00", + logged_date="2026-04-26", + ) + entry_id = await storage.add_entry(entry) + + result = await storage.update_entry( + entry_id, "user1", calories=200, meal_type="lunch" + ) + + assert result is True + summary = await storage.get_daily_summary("user1", "2026-04-26") + assert summary.entries[0].calories == 200 + assert summary.entries[0].meal_type == "lunch" + + @pytest.mark.asyncio + async def test_update_entry_wrong_user(self, storage) -> None: + """Should not update entry belonging to different user.""" + entry = CalorieEntry( + user_id="user1", + description="apple", + calories=100, + logged_at="2026-04-26T10:00:00", + logged_date="2026-04-26", + ) + entry_id = await storage.add_entry(entry) + + result = await storage.update_entry(entry_id, "user2", calories=200) + + assert result is False + summary = await storage.get_daily_summary("user1", "2026-04-26") + assert summary.entries[0].calories == 100 + + @pytest.mark.asyncio + async def test_update_entry_invalid_column(self, storage) -> None: + """Should raise ValueError for invalid column.""" + entry = CalorieEntry( + user_id="user1", + description="apple", + calories=100, + logged_at="2026-04-26T10:00:00", + logged_date="2026-04-26", + ) + entry_id = await storage.add_entry(entry) + + with pytest.raises(ValueError, match="Column 'bogus' is not allowed"): + await storage.update_entry(entry_id, "user1", bogus="value") + + @pytest.mark.asyncio + async def test_delete_entry(self, storage) -> None: + """Should delete an entry.""" + entry = CalorieEntry( + user_id="user1", + description="apple", + calories=100, + logged_at="2026-04-26T10:00:00", + logged_date="2026-04-26", + ) + entry_id = await storage.add_entry(entry) + + result = await storage.delete_entry(entry_id, "user1") + + assert result is True + summary = await storage.get_daily_summary("user1", "2026-04-26") + assert summary.entry_count == 0 + + @pytest.mark.asyncio + async def test_delete_entry_wrong_user(self, storage) -> None: + """Should not delete entry belonging to different user.""" + entry = CalorieEntry( + user_id="user1", + description="apple", + calories=100, + logged_at="2026-04-26T10:00:00", + logged_date="2026-04-26", + ) + entry_id = await storage.add_entry(entry) + + result = await storage.delete_entry(entry_id, "user2") + + assert result is False + summary = await storage.get_daily_summary("user1", "2026-04-26") + assert summary.entry_count == 1 + + @pytest.mark.asyncio + async def test_delete_entry_not_found(self, storage) -> None: + """Should return False for non-existent entry.""" + result = await storage.delete_entry(999, "user1") + + assert result is False + + @pytest.mark.asyncio + async def test_multiple_users_isolated(self, storage) -> None: + """Should isolate data between users.""" + entry1 = CalorieEntry( + user_id="user1", + description="apple", + calories=100, + logged_at="2026-04-26T10:00:00", + logged_date="2026-04-26", + ) + entry2 = CalorieEntry( + user_id="user2", + description="banana", + calories=150, + logged_at="2026-04-26T11:00:00", + logged_date="2026-04-26", + ) + await storage.add_entry(entry1) + await storage.add_entry(entry2) + + summary1 = await storage.get_daily_summary("user1", "2026-04-26") + summary2 = await storage.get_daily_summary("user2", "2026-04-26") + + assert summary1.entry_count == 1 + assert summary1.total_calories == 100 + assert summary2.entry_count == 1 + assert summary2.total_calories == 150 + + @pytest.mark.asyncio + async def test_update_entry_no_fields_returns_false(self, storage) -> None: + """Should return False when no fields provided for update.""" + entry = CalorieEntry( + user_id="user1", + description="apple", + calories=100, + logged_at="2026-04-26T10:00:00", + logged_date="2026-04-26", + ) + entry_id = await storage.add_entry(entry) + + result = await storage.update_entry(entry_id, "user1") + + assert result is False + + +class TestGetStorage: + """Tests for get_storage function.""" + + @pytest.mark.asyncio + async def test_get_storage_raises_when_not_initialized(self, conn, lock) -> None: + """Should raise RuntimeError when storage is not initialized.""" + from blacki.calories.storage import get_storage + from blacki.container import ( + reset_container_for_tests, + set_container_from_connection, + ) + + set_container_from_connection(conn, lock) + + with pytest.raises(RuntimeError, match="Calorie storage not initialized"): + get_storage() + + reset_container_for_tests() + + @pytest.mark.asyncio + async def test_get_storage_returns_storage_when_initialized( + self, conn, lock + ) -> None: + """Should return storage when initialized.""" + from blacki.calories.storage import get_storage + from blacki.container import ( + reset_container_for_tests, + set_container_from_connection, + ) + + container = set_container_from_connection(conn, lock) + await container.calorie_storage.initialize() + + result = get_storage() + + assert result is container.calorie_storage + + reset_container_for_tests() diff --git a/tests/reminders/test_storage.py b/tests/reminders/test_storage.py index 353f1e7..14936d3 100644 --- a/tests/reminders/test_storage.py +++ b/tests/reminders/test_storage.py @@ -1,20 +1,42 @@ +# mypy: disable-error-code="no-untyped-def" """Unit tests for reminder storage.""" -from unittest.mock import AsyncMock, MagicMock +import asyncio -import asyncpg # type: ignore[import-untyped] +import aiosqlite import pytest from blacki.reminders.storage import ( DUE_REMINDERS_FETCH_LIMIT, - PostgresReminderStorage, Reminder, - close_reminder_storage, - get_storage, - init_reminder_storage, + SqliteReminderStorage, ) +@pytest.fixture +async def conn(): + """Create an in-memory SQLite connection for testing.""" + conn = await aiosqlite.connect(":memory:") + conn.row_factory = aiosqlite.Row + yield conn + await conn.close() + + +@pytest.fixture +def lock(): + """Create a lock for write operations.""" + return asyncio.Lock() + + +@pytest.fixture +async def storage(conn, lock): + """Create a storage instance with the test connection.""" + storage = SqliteReminderStorage(conn, lock) + await storage.initialize() + yield storage + await storage.close() + + class TestReminder: """Tests for Reminder model.""" @@ -61,49 +83,26 @@ def test_is_recurring_true_for_recurring(self) -> None: assert reminder.is_recurring is True -class TestPostgresReminderStorage: - """Tests for PostgresReminderStorage.""" - - @pytest.fixture - def mock_pool(self) -> MagicMock: - """Create a mock asyncpg Pool.""" - pool = MagicMock(spec=asyncpg.Pool) - pool.acquire = MagicMock() - pool.fetch = AsyncMock() - pool.fetchval = AsyncMock() - pool.execute = AsyncMock() - return pool - - @pytest.fixture - def mock_connection(self) -> MagicMock: - """Create a mock asyncpg Connection.""" - conn = MagicMock(spec=asyncpg.Connection) - conn.execute = AsyncMock() - return conn +class TestSqliteReminderStorage: + """Tests for SqliteReminderStorage.""" @pytest.mark.asyncio - async def test_initialize_creates_tables( - self, mock_pool: MagicMock, mock_connection: MagicMock - ) -> None: + async def test_initialize_creates_tables(self, conn, lock) -> None: """Should create tables on initialization.""" - mock_pool.acquire.return_value.__aenter__ = AsyncMock( - return_value=mock_connection - ) - mock_pool.acquire.return_value.__aexit__ = AsyncMock() - - storage = PostgresReminderStorage(mock_pool) + storage = SqliteReminderStorage(conn, lock) await storage.initialize() - assert mock_connection.execute.call_count >= 1 + assert storage.is_initialized is True + + cursor = await conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='reminders'" + ) + row = await cursor.fetchone() + assert row is not None @pytest.mark.asyncio - async def test_add_reminder(self, mock_pool: MagicMock) -> None: + async def test_add_reminder(self, storage) -> None: """Should add a reminder and return its ID.""" - mock_pool.fetchval.return_value = 42 - - storage = PostgresReminderStorage(mock_pool) - storage._schema_ready = True - reminder = Reminder( user_id="user1", message="Test reminder", @@ -113,90 +112,103 @@ async def test_add_reminder(self, mock_pool: MagicMock) -> None: result = await storage.add_reminder(reminder) - assert result == 42 - mock_pool.fetchval.assert_called_once() + assert result == 1 @pytest.mark.asyncio - async def test_get_due_reminders(self, mock_pool: MagicMock) -> None: + async def test_get_due_reminders(self, storage) -> None: """Should fetch due reminders.""" - mock_pool.fetch.return_value = [ - { - "id": 1, - "user_id": "user1", - "message": "Test", - "trigger_time": "2026-04-18T12:00:00+00:00", - "is_sent": False, - "recurrence_rule": None, - "recurrence_text": None, - "timezone_name": None, - "created_at": "2026-04-18T10:00:00+00:00", - } - ] - - storage = PostgresReminderStorage(mock_pool) - storage._schema_ready = True + reminder = Reminder( + user_id="user1", + message="Test", + trigger_time="2020-01-01T00:00:00+00:00", + created_at="2026-04-18T10:00:00+00:00", + ) + await storage.add_reminder(reminder) result = await storage.get_due_reminders() assert len(result) == 1 - assert result[0].id == 1 assert result[0].message == "Test" - fetch_sql = mock_pool.fetch.call_args[0][0] - assert "LIMIT $2" in fetch_sql - assert mock_pool.fetch.call_args[0][2] == DUE_REMINDERS_FETCH_LIMIT @pytest.mark.asyncio - async def test_mark_sent(self, mock_pool: MagicMock) -> None: + async def test_get_due_reminders_respects_limit(self, conn, lock) -> None: + """Should limit the number of due reminders fetched.""" + storage = SqliteReminderStorage(conn, lock) + await storage.initialize() + + for i in range(DUE_REMINDERS_FETCH_LIMIT + 10): + reminder = Reminder( + user_id="user1", + message=f"Test {i}", + trigger_time="2020-01-01T00:00:00+00:00", + created_at="2026-04-18T10:00:00+00:00", + ) + await storage.add_reminder(reminder) + + result = await storage.get_due_reminders() + + assert len(result) == DUE_REMINDERS_FETCH_LIMIT + + @pytest.mark.asyncio + async def test_mark_sent(self, storage) -> None: """Should mark a reminder as sent.""" - storage = PostgresReminderStorage(mock_pool) - storage._schema_ready = True + reminder = Reminder( + user_id="user1", + message="Test", + trigger_time="2026-04-18T12:00:00+00:00", + created_at="2026-04-18T10:00:00+00:00", + ) + rid = await storage.add_reminder(reminder) - await storage.mark_sent(42) + await storage.mark_sent(rid) - mock_pool.execute.assert_called_once() + rows = await storage.get_user_reminders("user1", include_sent=True) + assert rows[0].is_sent is True @pytest.mark.asyncio - async def test_reschedule_reminder(self, mock_pool: MagicMock) -> None: + async def test_reschedule_reminder(self, storage) -> None: """Should reschedule a recurring reminder.""" - storage = PostgresReminderStorage(mock_pool) - storage._schema_ready = True + reminder = Reminder( + user_id="user1", + message="Test", + trigger_time="2026-04-18T12:00:00+00:00", + created_at="2026-04-18T10:00:00+00:00", + ) + rid = await storage.add_reminder(reminder) - await storage.reschedule_reminder(42, "2026-04-19T12:00:00+00:00") + await storage.reschedule_reminder(rid, "2026-04-19T12:00:00+00:00") - mock_pool.execute.assert_called_once() + rows = await storage.get_user_reminders("user1") + assert rows[0].trigger_time == "2026-04-19T12:00:00+00:00" + assert rows[0].is_sent is False @pytest.mark.asyncio - async def test_get_user_reminders(self, mock_pool: MagicMock) -> None: + async def test_get_user_reminders(self, storage) -> None: """Should get reminders for a user.""" - mock_pool.fetch.return_value = [] - - storage = PostgresReminderStorage(mock_pool) - storage._schema_ready = True + reminder = Reminder( + user_id="user1", + message="Test", + trigger_time="2026-04-18T12:00:00+00:00", + created_at="2026-04-18T10:00:00+00:00", + ) + await storage.add_reminder(reminder) result = await storage.get_user_reminders("user1") - assert result == [] - mock_pool.fetch.assert_called_once() + assert len(result) == 1 + assert result[0].message == "Test" @pytest.mark.asyncio - async def test_get_user_reminders_include_sent(self, mock_pool: MagicMock) -> None: + async def test_get_user_reminders_include_sent(self, storage) -> None: """Should include sent reminders when requested.""" - mock_pool.fetch.return_value = [ - { - "id": 1, - "user_id": "user1", - "message": "Test", - "trigger_time": "2026-04-18T12:00:00+00:00", - "is_sent": True, - "recurrence_rule": None, - "recurrence_text": None, - "timezone_name": None, - "created_at": "2026-04-18T10:00:00+00:00", - } - ] - - storage = PostgresReminderStorage(mock_pool) - storage._schema_ready = True + reminder = Reminder( + user_id="user1", + message="Test", + trigger_time="2026-04-18T12:00:00+00:00", + created_at="2026-04-18T10:00:00+00:00", + ) + rid = await storage.add_reminder(reminder) + await storage.mark_sent(rid) result = await storage.get_user_reminders("user1", include_sent=True) @@ -204,135 +216,117 @@ async def test_get_user_reminders_include_sent(self, mock_pool: MagicMock) -> No assert result[0].is_sent is True @pytest.mark.asyncio - async def test_delete_reminder_found(self, mock_pool: MagicMock) -> None: - """Should delete a reminder and return True.""" - mock_pool.execute.return_value = "DELETE 1" + async def test_get_user_reminders_excludes_sent_by_default(self, storage) -> None: + """Should exclude sent reminders by default.""" + reminder = Reminder( + user_id="user1", + message="Test", + trigger_time="2026-04-18T12:00:00+00:00", + created_at="2026-04-18T10:00:00+00:00", + ) + rid = await storage.add_reminder(reminder) + await storage.mark_sent(rid) + + result = await storage.get_user_reminders("user1") + + assert len(result) == 0 - storage = PostgresReminderStorage(mock_pool) - storage._schema_ready = True + @pytest.mark.asyncio + async def test_delete_reminder_found(self, storage) -> None: + """Should delete a reminder and return True.""" + reminder = Reminder( + user_id="user1", + message="Test", + trigger_time="2026-04-18T12:00:00+00:00", + created_at="2026-04-18T10:00:00+00:00", + ) + rid = await storage.add_reminder(reminder) - result = await storage.delete_reminder(42, "user1") + result = await storage.delete_reminder(rid, "user1") assert result is True + rows = await storage.get_user_reminders("user1") + assert len(rows) == 0 @pytest.mark.asyncio - async def test_delete_reminder_not_found(self, mock_pool: MagicMock) -> None: + async def test_delete_reminder_not_found(self, storage) -> None: """Should return False if reminder not found.""" - mock_pool.execute.return_value = "DELETE 0" + result = await storage.delete_reminder(999, "user1") - storage = PostgresReminderStorage(mock_pool) - storage._schema_ready = True + assert result is False + + @pytest.mark.asyncio + async def test_delete_reminder_wrong_user(self, storage) -> None: + """Should return False if reminder belongs to different user.""" + reminder = Reminder( + user_id="user1", + message="Test", + trigger_time="2026-04-18T12:00:00+00:00", + created_at="2026-04-18T10:00:00+00:00", + ) + rid = await storage.add_reminder(reminder) - result = await storage.delete_reminder(42, "user1") + result = await storage.delete_reminder(rid, "user2") assert result is False + rows = await storage.get_user_reminders("user1") + assert len(rows) == 1 @pytest.mark.asyncio - async def test_initialize_returns_early_if_schema_ready( - self, mock_pool: MagicMock - ) -> None: + async def test_initialize_returns_early_if_schema_ready(self, conn, lock) -> None: """Should return early if schema already ready.""" - storage = PostgresReminderStorage(mock_pool) - storage._schema_ready = True + storage = SqliteReminderStorage(conn, lock) + await storage.initialize() await storage.initialize() - mock_pool.acquire.assert_not_called() + assert storage.is_initialized is True @pytest.mark.asyncio - async def test_close_resets_schema_ready(self, mock_pool: MagicMock) -> None: + async def test_close_resets_schema_ready(self, storage) -> None: """Should reset schema ready flag on close.""" - storage = PostgresReminderStorage(mock_pool) - storage._schema_ready = True + assert storage.is_initialized is True await storage.close() - assert storage._schema_ready is False + assert storage.is_initialized is False -class TestStorageSingleton: - """Tests for storage singleton management.""" +class TestGetStorage: + """Tests for get_storage function.""" @pytest.mark.asyncio - async def test_get_storage_raises_if_not_initialized(self) -> None: - """Should raise RuntimeError if storage not initialized.""" - import blacki.reminders.storage as storage_module + async def test_get_storage_raises_when_not_initialized(self, conn, lock) -> None: + """Should raise RuntimeError when storage is not initialized.""" + from blacki.container import ( + reset_container_for_tests, + set_container_from_connection, + ) + from blacki.reminders.storage import get_storage - storage_module._storage = None + set_container_from_connection(conn, lock) - with pytest.raises(RuntimeError, match="not initialized"): + with pytest.raises(RuntimeError, match="Reminder storage not initialized"): get_storage() - @pytest.mark.asyncio - async def test_init_and_get_storage(self) -> None: - """Should initialize and return storage singleton.""" - import blacki.reminders.storage as storage_module - - storage_module._storage = None - - mock_pool = MagicMock(spec=asyncpg.Pool) - mock_pool.acquire = MagicMock() - mock_conn = MagicMock() - mock_conn.execute = AsyncMock() - mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) - mock_pool.acquire.return_value.__aexit__ = AsyncMock() - - storage = await init_reminder_storage(mock_pool) - - assert storage is not None - assert get_storage() is storage - - storage_module._storage = None - - @pytest.mark.asyncio - async def test_close_reminder_storage(self) -> None: - """Should close and reset storage singleton.""" - import blacki.reminders.storage as storage_module - - mock_pool = MagicMock(spec=asyncpg.Pool) - mock_pool.acquire = MagicMock() - mock_conn = MagicMock() - mock_conn.execute = AsyncMock() - mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) - mock_pool.acquire.return_value.__aexit__ = AsyncMock() - - storage = await init_reminder_storage(mock_pool) - assert storage is not None - - await close_reminder_storage() - - assert storage_module._storage is None + reset_container_for_tests() @pytest.mark.asyncio - async def test_reinit_reminder_storage_closes_existing(self) -> None: - """init_reminder_storage closes existing storage before replacing.""" - import blacki.reminders.storage as storage_module - - mock_pool = MagicMock(spec=asyncpg.Pool) - mock_pool.acquire = MagicMock() - mock_conn = MagicMock() - mock_conn.execute = AsyncMock() - mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) - mock_pool.acquire.return_value.__aexit__ = AsyncMock() - - existing = PostgresReminderStorage(mock_pool) - existing.close = AsyncMock() # type: ignore[method-assign] - storage_module._storage = existing - - new = await init_reminder_storage(mock_pool) - - existing.close.assert_awaited_once() - assert storage_module._storage is new - - storage_module._storage = None + async def test_get_storage_returns_storage_when_initialized( + self, conn, lock + ) -> None: + """Should return storage when initialized.""" + from blacki.container import ( + reset_container_for_tests, + set_container_from_connection, + ) + from blacki.reminders.storage import get_storage - @pytest.mark.asyncio - async def test_close_reminder_storage_when_none(self) -> None: - """Should do nothing if storage is already None.""" - import blacki.reminders.storage as storage_module + container = set_container_from_connection(conn, lock) + await container.reminder_storage.initialize() - storage_module._storage = None + result = get_storage() - await close_reminder_storage() + assert result is container.reminder_storage - assert storage_module._storage is None + reset_container_for_tests() diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py new file mode 100644 index 0000000..5f39473 --- /dev/null +++ b/tests/storage/test_base.py @@ -0,0 +1,187 @@ +# mypy: disable-error-code="no-untyped-def" +"""Unit tests for SqlStorage base class.""" + +import asyncio +from unittest.mock import AsyncMock + +import aiosqlite +import pytest + +from blacki.storage.base import SqlStorage + + +class ConcreteStorage(SqlStorage): + """Concrete implementation for testing abstract base class.""" + + async def _create_tables(self) -> None: + await self._conn.execute( + "CREATE TABLE IF NOT EXISTS test_table (id INTEGER PRIMARY KEY, value TEXT)" + ) + + +@pytest.fixture +async def conn(): + """Create an in-memory SQLite connection for testing.""" + conn = await aiosqlite.connect(":memory:") + conn.row_factory = aiosqlite.Row + yield conn + await conn.close() + + +@pytest.fixture +def lock(): + """Create a lock for write operations.""" + return asyncio.Lock() + + +@pytest.fixture +async def storage(conn, lock): + """Create a storage instance with the test connection.""" + storage = ConcreteStorage(conn, lock) + await storage.initialize() + yield storage + await storage.close() + + +class TestSqlStorageExecute: + """Tests for _execute method.""" + + @pytest.mark.asyncio + async def test_execute_with_lock(self, storage) -> None: + """Should execute query with lock by default.""" + rid = await storage._execute( + "INSERT INTO test_table (value) VALUES (?)", ("test_value",) + ) + + assert rid == 1 + + row = await storage._fetch_one("SELECT * FROM test_table WHERE id = ?", (rid,)) + assert row is not None + assert row["value"] == "test_value" + + @pytest.mark.asyncio + async def test_execute_without_lock(self, storage) -> None: + """Should execute query without lock when use_lock=False.""" + rid = await storage._execute( + "INSERT INTO test_table (value) VALUES (?)", + ("test_value",), + use_lock=False, + ) + + assert rid == 1 + + row = await storage._fetch_one("SELECT * FROM test_table WHERE id = ?", (rid,)) + assert row is not None + assert row["value"] == "test_value" + + @pytest.mark.asyncio + async def test_execute_raises_runtime_error_when_lastrowid_none( + self, conn, lock + ) -> None: + """Should raise RuntimeError when lastrowid is None after insert.""" + from unittest.mock import MagicMock + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.lastrowid = None + + async_cm = MagicMock() + async_cm.__aenter__ = AsyncMock(return_value=mock_cursor) + async_cm.__aexit__ = AsyncMock(return_value=None) + mock_conn.execute.return_value = async_cm + + storage = ConcreteStorage(mock_conn, lock) + + with pytest.raises(RuntimeError, match="Failed to get lastrowid after insert"): + await storage._execute( + "INSERT INTO test_table (value) VALUES (?)", ("test",) + ) + + @pytest.mark.asyncio + async def test_execute_without_lock_raises_runtime_error_when_lastrowid_none( + self, conn, lock + ) -> None: + """Should raise RuntimeError when lastrowid is None with use_lock=False.""" + from unittest.mock import MagicMock + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.lastrowid = None + + async_cm = MagicMock() + async_cm.__aenter__ = AsyncMock(return_value=mock_cursor) + async_cm.__aexit__ = AsyncMock(return_value=None) + mock_conn.execute.return_value = async_cm + + storage = ConcreteStorage(mock_conn, lock) + + with pytest.raises(RuntimeError, match="Failed to get lastrowid after insert"): + await storage._execute( + "INSERT INTO test_table (value) VALUES (?)", + ("test",), + use_lock=False, + ) + + +class TestSqlStorageExecuteMany: + """Tests for _execute_many method.""" + + @pytest.mark.asyncio + async def test_execute_many_with_lock(self, storage) -> None: + """Should execute many with lock by default.""" + params_list = [("value1",), ("value2",), ("value3",)] + await storage._execute_many( + "INSERT INTO test_table (value) VALUES (?)", params_list + ) + + rows = await storage._fetch_all("SELECT * FROM test_table ORDER BY id") + assert len(rows) == 3 + assert rows[0]["value"] == "value1" + assert rows[1]["value"] == "value2" + assert rows[2]["value"] == "value3" + + @pytest.mark.asyncio + async def test_execute_many_without_lock(self, storage) -> None: + """Should execute many without lock when use_lock=False.""" + params_list = [("value1",), ("value2",)] + await storage._execute_many( + "INSERT INTO test_table (value) VALUES (?)", + params_list, + use_lock=False, + ) + + rows = await storage._fetch_all("SELECT * FROM test_table ORDER BY id") + assert len(rows) == 2 + + +class TestSqlStorageFetchVal: + """Tests for _fetch_val method.""" + + @pytest.mark.asyncio + async def test_fetch_val_returns_value(self, storage) -> None: + """Should return single value.""" + await storage._execute( + "INSERT INTO test_table (value) VALUES (?)", ("test_value",) + ) + + result = await storage._fetch_val("SELECT value FROM test_table WHERE id = 1") + + assert result == "test_value" + + @pytest.mark.asyncio + async def test_fetch_val_returns_none_when_no_row(self, storage) -> None: + """Should return None when no row found.""" + result = await storage._fetch_val("SELECT value FROM test_table WHERE id = 999") + + assert result is None + + +class TestSqlStorageConn: + """Tests for conn property.""" + + @pytest.mark.asyncio + async def test_conn_returns_underlying_connection(self, storage, conn) -> None: + """Should return the underlying connection.""" + result = storage.conn + + assert result is conn diff --git a/tests/storage/test_sqlite.py b/tests/storage/test_sqlite.py new file mode 100644 index 0000000..271f588 --- /dev/null +++ b/tests/storage/test_sqlite.py @@ -0,0 +1,81 @@ +# mypy: disable-error-code="no-untyped-def" +"""Unit tests for SQLite connection management.""" + +from pathlib import Path + +import pytest + +from blacki.storage.sqlite import close_connection, create_connection + + +class TestCreateConnection: + """Tests for create_connection function.""" + + @pytest.mark.asyncio + async def test_create_connection_creates_file(self, tmp_path: Path) -> None: + """Should create database file and parent directories.""" + db_path = tmp_path / "subdir" / "test.db" + + conn = await create_connection(db_path) + + assert db_path.exists() + assert db_path.parent.exists() + + await conn.close() + + @pytest.mark.asyncio + async def test_create_connection_sets_row_factory(self, tmp_path: Path) -> None: + """Should set row_factory to aiosqlite.Row.""" + db_path = tmp_path / "test.db" + + conn = await create_connection(db_path) + + assert conn.row_factory is not None + + await conn.close() + + @pytest.mark.asyncio + async def test_create_connection_configures_pragmas(self, tmp_path: Path) -> None: + """Should configure WAL mode and other pragmas.""" + db_path = tmp_path / "test.db" + + conn = await create_connection(db_path) + + async with conn.execute("PRAGMA journal_mode") as cursor: + row = await cursor.fetchone() + assert row is not None + assert row[0].lower() == "wal" + + async with conn.execute("PRAGMA foreign_keys") as cursor: + row = await cursor.fetchone() + assert row is not None + assert row[0] == 1 + + await conn.close() + + +class TestCloseConnection: + """Tests for close_connection function.""" + + @pytest.mark.asyncio + async def test_close_connection_closes_connection(self, tmp_path: Path) -> None: + """Should close the SQLite connection.""" + db_path = tmp_path / "test.db" + conn = await create_connection(db_path) + + await close_connection(conn) + + with pytest.raises(ValueError, match="no active connection"): + await conn.execute("SELECT 1") + + @pytest.mark.asyncio + async def test_close_connection_with_memory_db(self) -> None: + """Should close an in-memory connection.""" + import aiosqlite + + conn = await aiosqlite.connect(":memory:") + + await close_connection(conn) + + with pytest.raises(ValueError, match="no active connection"): + await conn.execute("SELECT 1") diff --git a/tests/test_adk_runtime.py b/tests/test_adk_runtime.py index fca27ed..8708485 100644 --- a/tests/test_adk_runtime.py +++ b/tests/test_adk_runtime.py @@ -55,23 +55,26 @@ def test_build_session_service_uri_keeps_agentengine_scheme() -> None: assert build_session_service_uri(env) == "agentengine://test-engine-id" -def test_build_session_db_kwargs_uses_env_values() -> None: - """Test that session DB kwargs are derived from ServerEnv.""" - env = _build_server_env( - DB_POOL_PRE_PING="false", - DB_POOL_RECYCLE="99", - DB_POOL_SIZE="7", - DB_MAX_OVERFLOW="8", - DB_POOL_TIMEOUT="9", - ) +def test_build_session_service_uri_converts_postgresql_to_asyncpg() -> None: + """Test that postgresql:// URIs are converted to postgresql+asyncpg://.""" + env = _build_server_env() - assert build_session_db_kwargs(env) == { - "pool_pre_ping": False, - "pool_recycle": 99, - "pool_size": 7, - "max_overflow": 8, - "pool_timeout": 9, - } + with patch.object( + type(env), + "session_uri", + property(lambda self: "postgresql://user:pass@localhost/db"), + ): + assert ( + build_session_service_uri(env) + == "postgresql+asyncpg://user:pass@localhost/db" + ) + + +def test_build_session_db_kwargs_returns_empty_dict() -> None: + """Test that session DB kwargs returns empty dict for SQLite.""" + env = _build_server_env() + + assert build_session_db_kwargs(env) == {} def test_create_session_service_without_uri_uses_sqlite(tmp_path: Path) -> None: @@ -81,10 +84,11 @@ def test_create_session_service_without_uri_uses_sqlite(tmp_path: Path) -> None: assert isinstance(session_service, DatabaseSessionService) -def test_create_session_service_with_postgres_uri() -> None: - """Test that Postgres session services use DatabaseSessionService.""" +def test_create_session_service_with_sqlite_uri(tmp_path: Path) -> None: + """Test that SQLite session services use DatabaseSessionService.""" + db_path = tmp_path / "sessions.db" session_service = create_session_service( - "postgresql+asyncpg://user:pass@localhost/db", + f"sqlite+aiosqlite:///{db_path}", {}, ) diff --git a/tests/test_config.py b/tests/test_config.py index 8fc4a3f..d49b79a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -45,7 +45,7 @@ def test_server_env_optional_fields_use_defaults( assert env.serve_web_interface is False assert env.reload_agents is False assert env.agent_engine is None - assert env.database_url is None + assert env.sqlite_path is None assert env.openrouter_api_key is None assert env.allow_origins == '["http://127.0.0.1", "http://127.0.0.1:8080"]' assert env.host == "127.0.0.1" @@ -108,7 +108,7 @@ def test_server_env_optional_fields_with_values( "SERVE_WEB_INTERFACE": "true", "RELOAD_AGENTS": "true", "AGENT_ENGINE": "test-engine-id", - "DATABASE_URL": "postgresql://user:pass@localhost/db", + "SQLITE_PATH": "/tmp/blacki.db", "OPENROUTER_API_KEY": "sk-or-v1-test", "ALLOW_ORIGINS": '["http://localhost:3000"]', "HOST": "0.0.0.0", # noqa: S104 @@ -122,7 +122,7 @@ def test_server_env_optional_fields_with_values( assert env.serve_web_interface is True assert env.reload_agents is True assert env.agent_engine == "test-engine-id" - assert env.database_url == "postgresql://user:pass@localhost/db" + assert env.sqlite_path == "/tmp/blacki.db" assert env.openrouter_api_key == "sk-or-v1-test" assert env.allow_origins == '["http://localhost:3000"]' assert env.host == "0.0.0.0" # noqa: S104 @@ -142,9 +142,9 @@ def test_agent_engine_uri_property(self, valid_server_env: dict[str, str]) -> No def test_session_uri_property(self, valid_server_env: dict[str, str]) -> None: """Test that session_uri property uses agent_engine_uri only. - DATABASE_URL is ignored for sessions (reserved for Reminders system). + SQLITE_PATH is used for SQLite storage, not for ADK sessions. """ - # Case 1: Neither database_url nor agent_engine -> in-memory (None) + # Case 1: Neither sqlite_path nor agent_engine -> in-memory (None) env = ServerEnv.model_validate(valid_server_env) assert env.session_uri is None @@ -153,16 +153,15 @@ def test_session_uri_property(self, valid_server_env: dict[str, str]) -> None: env = ServerEnv.model_validate(data) assert env.session_uri == "agentengine://test-engine-id" - # Case 3: Only database_url -> ignored, returns None (in-memory) - db_url = "postgresql://user:pass@localhost/db?sslmode=require" - data = {**valid_server_env, "DATABASE_URL": db_url} + # Case 3: Only sqlite_path -> ignored for sessions, returns None + data = {**valid_server_env, "SQLITE_PATH": "/tmp/blacki.db"} env = ServerEnv.model_validate(data) assert env.session_uri is None - # Case 4: Both database_url and agent_engine -> agent_engine wins + # Case 4: Both sqlite_path and agent_engine -> agent_engine wins data = { **valid_server_env, - "DATABASE_URL": "postgresql://user:pass@localhost/db", + "SQLITE_PATH": "/tmp/blacki.db", "AGENT_ENGINE": "test-engine-id", } env = ServerEnv.model_validate(data) @@ -226,13 +225,13 @@ def test_server_env_print_config( assert "AGENT_NAME" in output assert "LOG_LEVEL" in output - def test_server_env_print_config_with_db( + def test_server_env_print_config_with_sqlite( self, valid_server_env: dict[str, str], capsys: pytest.CaptureFixture[str] ) -> None: - """Test print_config outputs DB pool settings when DATABASE_URL is set.""" + """Test print_config outputs SQLite path when SQLITE_PATH is set.""" data = { **valid_server_env, - "DATABASE_URL": "postgresql://user:pass@localhost/db", + "SQLITE_PATH": "/tmp/blacki.db", } env = ServerEnv.model_validate(data) env.print_config() @@ -240,11 +239,7 @@ def test_server_env_print_config_with_db( captured = capsys.readouterr() output = captured.out - assert "DB_POOL_PRE_PING" in output - assert "DB_POOL_RECYCLE" in output - assert "DB_POOL_SIZE" in output - assert "DB_MAX_OVERFLOW" in output - assert "DB_POOL_TIMEOUT" in output + assert "SQLITE_PATH" in output def test_server_env_ignores_extra_fields( self, valid_server_env: dict[str, str] @@ -425,3 +420,10 @@ def test_port_field_parsing(self, valid_server_env: dict[str, str]) -> None: env = ServerEnv.model_validate(data) assert env.port == 9000 assert isinstance(env.port, int) + + def test_sqlite_path_can_be_set(self, valid_server_env: dict[str, str]) -> None: + """Test that SQLITE_PATH can be configured.""" + data = {**valid_server_env, "SQLITE_PATH": "/var/data/blacki.db"} + + env = ServerEnv.model_validate(data) + assert env.sqlite_path == "/var/data/blacki.db" diff --git a/tests/test_container.py b/tests/test_container.py index 7cb0f45..bc13fc2 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -1,8 +1,10 @@ +# mypy: disable-error-code="no-untyped-def,method-assign" """Tests for the dependency injection container.""" +import asyncio from unittest.mock import AsyncMock, MagicMock, patch -import asyncpg # type: ignore[import-untyped] +import aiosqlite import pytest from blacki.container import ( @@ -12,7 +14,7 @@ init_container, reset_container_for_tests, set_container, - set_container_from_pool, + set_container_from_connection, ) @@ -56,20 +58,36 @@ def test_reset_container_for_tests(self) -> None: get_container() -class TestSetContainerFromPool: - """Tests for set_container_from_pool function.""" +class TestSetContainerFromConnection: + """Tests for set_container_from_connection function.""" def teardown_method(self) -> None: """Reset container after each test.""" reset_container_for_tests() - def test_creates_container_from_pool(self) -> None: - """Should create and set container from existing pool.""" - mock_pool = MagicMock(spec=asyncpg.Pool) - container = set_container_from_pool(mock_pool) + @pytest.mark.asyncio + async def test_creates_container_from_connection(self) -> None: + """Should create and set container from existing connection.""" + conn = await aiosqlite.connect(":memory:") + try: + container = set_container_from_connection(conn) + + assert container.conn is conn + assert get_container() is container + finally: + await conn.close() + + @pytest.mark.asyncio + async def test_creates_container_with_custom_lock(self) -> None: + """Should use provided lock.""" + conn = await aiosqlite.connect(":memory:") + try: + custom_lock = asyncio.Lock() + container = set_container_from_connection(conn, lock=custom_lock) - assert container.pool is mock_pool - assert get_container() is container + assert container.lock is custom_lock + finally: + await conn.close() class TestCloseContainer: @@ -115,25 +133,12 @@ async def test_init_container_creates_and_sets(self) -> None: mock_container = MagicMock(spec=AppContainer) mock_create.return_value = mock_container - result = await init_container("postgres://localhost/test") + result = await init_container("/tmp/test.db") - mock_create.assert_called_once_with("postgres://localhost/test", 5) + mock_create.assert_called_once_with("/tmp/test.db") assert result is mock_container assert get_container() is mock_container - @pytest.mark.asyncio - async def test_init_container_with_custom_pool_size(self) -> None: - """Should pass custom pool size to create.""" - with patch.object( - AppContainer, "create", new_callable=AsyncMock - ) as mock_create: - mock_container = MagicMock(spec=AppContainer) - mock_create.return_value = mock_container - - await init_container("postgres://localhost/test", pool_size=10) - - mock_create.assert_called_once_with("postgres://localhost/test", 10) - class TestAppContainer: """Tests for AppContainer class.""" @@ -143,61 +148,69 @@ def teardown_method(self) -> None: reset_container_for_tests() @pytest.fixture - def mock_pool(self) -> MagicMock: - """Create a mock asyncpg Pool.""" - pool = MagicMock(spec=asyncpg.Pool) - pool.close = AsyncMock() - return pool + async def conn(self): + """Create an in-memory SQLite connection for testing.""" + conn = await aiosqlite.connect(":memory:") + yield conn + await conn.close() - def test_container_properties_lazy_instantiate(self, mock_pool: MagicMock) -> None: + @pytest.fixture + def lock(self) -> asyncio.Lock: + """Create a lock for write operations.""" + return asyncio.Lock() + + @pytest.mark.asyncio + async def test_container_properties_lazy_instantiate(self, conn, lock) -> None: """Should lazily instantiate storage on first access.""" - container = AppContainer(pool=mock_pool) + container = AppContainer(conn=conn, _lock=lock) assert container._reminder_storage is None storage = container.reminder_storage assert container._reminder_storage is storage - def test_calorie_storage_property(self, mock_pool: MagicMock) -> None: + @pytest.mark.asyncio + async def test_calorie_storage_property(self, conn, lock) -> None: """Should lazily instantiate calorie storage.""" - container = AppContainer(pool=mock_pool) + container = AppContainer(conn=conn, _lock=lock) storage = container.calorie_storage assert storage is not None assert container._calorie_storage is storage - def test_workout_storage_property(self, mock_pool: MagicMock) -> None: + @pytest.mark.asyncio + async def test_workout_storage_property(self, conn, lock) -> None: """Should lazily instantiate workout storage.""" - container = AppContainer(pool=mock_pool) + container = AppContainer(conn=conn, _lock=lock) storage = container.workout_storage assert storage is not None assert container._workout_storage is storage - def test_preferences_storage_property(self, mock_pool: MagicMock) -> None: + @pytest.mark.asyncio + async def test_preferences_storage_property(self, conn, lock) -> None: """Should lazily instantiate preferences storage.""" - container = AppContainer(pool=mock_pool) + container = AppContainer(conn=conn, _lock=lock) storage = container.preferences_storage assert storage is not None assert container._preferences_storage is storage @pytest.mark.asyncio - async def test_close_closes_pool_and_storages(self, mock_pool: MagicMock) -> None: - """Should close pool and all storage instances.""" - container = AppContainer(pool=mock_pool) + async def test_close_closes_connection_and_storages(self, conn, lock) -> None: + """Should close connection and all storage instances.""" + container = AppContainer(conn=conn, _lock=lock) reminder = container.reminder_storage - reminder.close = AsyncMock() # type: ignore[method-assign] + reminder.close = AsyncMock() await container.close() reminder.close.assert_called_once() - mock_pool.close.assert_called_once() @pytest.mark.asyncio - async def test_close_storages_resets_references(self, mock_pool: MagicMock) -> None: + async def test_close_storages_resets_references(self, conn, lock) -> None: """Should reset storage references after close.""" - container = AppContainer(pool=mock_pool) + container = AppContainer(conn=conn, _lock=lock) _ = container.reminder_storage _ = container.calorie_storage @@ -212,9 +225,9 @@ async def test_close_storages_resets_references(self, mock_pool: MagicMock) -> N assert container._preferences_storage is None @pytest.mark.asyncio - async def test_close_storages_partial(self, mock_pool: MagicMock) -> None: + async def test_close_storages_partial(self, conn, lock) -> None: """Should handle partial storage initialization.""" - container = AppContainer(pool=mock_pool) + container = AppContainer(conn=conn, _lock=lock) _ = container.calorie_storage _ = container.workout_storage @@ -228,19 +241,19 @@ async def test_close_storages_partial(self, mock_pool: MagicMock) -> None: assert container._preferences_storage is None @pytest.mark.asyncio - async def test_initialize_all_storages(self, mock_pool: MagicMock) -> None: + async def test_initialize_all_storages(self, conn, lock) -> None: """Should initialize all storage instances.""" - container = AppContainer(pool=mock_pool) + container = AppContainer(conn=conn, _lock=lock) reminder = container.reminder_storage calorie = container.calorie_storage workout = container.workout_storage preferences = container.preferences_storage - reminder.initialize = AsyncMock() # type: ignore[method-assign] - calorie.initialize = AsyncMock() # type: ignore[method-assign] - workout.initialize = AsyncMock() # type: ignore[method-assign] - preferences.initialize = AsyncMock() # type: ignore[method-assign] + reminder.initialize = AsyncMock() + calorie.initialize = AsyncMock() + workout.initialize = AsyncMock() + preferences.initialize = AsyncMock() await container.initialize_all_storages() @@ -248,3 +261,19 @@ async def test_initialize_all_storages(self, mock_pool: MagicMock) -> None: calorie.initialize.assert_called_once() workout.initialize.assert_called_once() preferences.initialize.assert_called_once() + + @pytest.mark.asyncio + async def test_create_creates_container_with_connection(self) -> None: + """Should create container with SQLite connection.""" + import tempfile + from pathlib import Path + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.db" + + container = await AppContainer.create(db_path) + + assert container.conn is not None + assert container.lock is not None + + await container.close() diff --git a/tests/test_preferences.py b/tests/test_preferences.py index f009009..b601f34 100644 --- a/tests/test_preferences.py +++ b/tests/test_preferences.py @@ -1,141 +1,180 @@ # mypy: disable-error-code="no-untyped-def" -from unittest.mock import AsyncMock, MagicMock, patch +"""Unit tests for preferences storage.""" +import asyncio + +import aiosqlite import pytest -from blacki.utils.preferences import ( - PostgresPreferencesStorage, - close_preferences_storage, - get_preferences_storage, - init_preferences_storage, -) +from blacki.utils.preferences import SqlitePreferencesStorage @pytest.fixture -def mock_pool(): - pool = MagicMock() - - conn = AsyncMock() - conn.execute = AsyncMock() +async def conn(): + """Create an in-memory SQLite connection for testing.""" + conn = await aiosqlite.connect(":memory:") + conn.row_factory = aiosqlite.Row + yield conn + await conn.close() - pool.acquire.return_value.__aenter__.return_value = conn - pool.execute = AsyncMock() - pool.fetchrow = AsyncMock() - return pool +@pytest.fixture +def lock(): + """Create a lock for write operations.""" + return asyncio.Lock() @pytest.fixture -async def preferences_storage(mock_pool): - storage = PostgresPreferencesStorage(mock_pool) +async def storage(conn, lock): + """Create a storage instance with the test connection.""" + storage = SqlitePreferencesStorage(conn, lock) await storage.initialize() yield storage await storage.close() -@pytest.mark.asyncio -async def test_initialize_creates_tables(mock_pool) -> None: - storage = PostgresPreferencesStorage(mock_pool) - await storage.initialize() - - conn = mock_pool.acquire.return_value.__aenter__.return_value - conn.execute.assert_called() - assert storage._schema_ready is True - - -@pytest.mark.asyncio -async def test_get_existing(preferences_storage, mock_pool) -> None: - mock_pool.fetchrow.return_value = {"value": '{"monday": "push"}'} - - result = await preferences_storage.get("user1", "workout_split") - - assert result == {"monday": "push"} - mock_pool.fetchrow.assert_called_once_with( - "SELECT value FROM user_preferences WHERE user_id = $1 AND key = $2", - "user1", - "workout_split", - ) - - -@pytest.mark.asyncio -async def test_get_not_found_returns_default(preferences_storage, mock_pool) -> None: - mock_pool.fetchrow.return_value = None - - result = await preferences_storage.get("user1", "calorie_goal", 2000) - - assert result == 2000 - - -@pytest.mark.asyncio -async def test_set(preferences_storage, mock_pool) -> None: - mock_pool.execute.return_value = "INSERT 0 1" - - with patch("blacki.utils.preferences.now_utc") as mock_now: - mock_now.return_value.isoformat.return_value = "2026-04-26T12:00:00" - await preferences_storage.set("user1", "calorie_goal", 2500) - - mock_pool.execute.assert_called_once() - args = mock_pool.execute.call_args[0] - assert args[1] == "user1" - assert args[2] == "calorie_goal" - assert args[3] == "2500" - assert args[4] == "2026-04-26T12:00:00" - +class TestSqlitePreferencesStorage: + """Tests for SqlitePreferencesStorage.""" + + @pytest.mark.asyncio + async def test_initialize_creates_tables(self, conn, lock) -> None: + """Should create tables on initialization.""" + storage = SqlitePreferencesStorage(conn, lock) + await storage.initialize() -@pytest.mark.asyncio -async def test_delete_success(preferences_storage, mock_pool) -> None: - mock_pool.execute.return_value = "DELETE 1" + assert storage.is_initialized is True - result = await preferences_storage.delete("user1", "calorie_goal") + cursor = await conn.execute( + """ + SELECT name FROM sqlite_master + WHERE type='table' AND name='user_preferences' + """ + ) + row = await cursor.fetchone() + assert row is not None + + @pytest.mark.asyncio + async def test_get_existing(self, storage) -> None: + """Should get an existing preference.""" + await storage.set("user1", "workout_split", {"monday": "push"}) - assert result is True - mock_pool.execute.assert_called_once_with( - "DELETE FROM user_preferences WHERE user_id = $1 AND key = $2", - "user1", - "calorie_goal", - ) + result = await storage.get("user1", "workout_split") + assert result == {"monday": "push"} -@pytest.mark.asyncio -async def test_delete_not_found(preferences_storage, mock_pool) -> None: - mock_pool.execute.return_value = "DELETE 0" + @pytest.mark.asyncio + async def test_get_not_found_returns_default(self, storage) -> None: + """Should return default when preference not found.""" + result = await storage.get("user1", "calorie_goal", 2000) + + assert result == 2000 + + @pytest.mark.asyncio + async def test_set(self, storage) -> None: + """Should set a preference.""" + await storage.set("user1", "calorie_goal", 2500) + + result = await storage.get("user1", "calorie_goal") + assert result == 2500 - result = await preferences_storage.delete("user1", "calorie_goal") + @pytest.mark.asyncio + async def test_set_updates_existing(self, storage) -> None: + """Should update an existing preference.""" + await storage.set("user1", "calorie_goal", 2000) + await storage.set("user1", "calorie_goal", 2500) - assert result is False + result = await storage.get("user1", "calorie_goal") + assert result == 2500 + @pytest.mark.asyncio + async def test_delete_success(self, storage) -> None: + """Should delete a preference.""" + await storage.set("user1", "calorie_goal", 2500) -@pytest.mark.asyncio -async def test_singleton(mock_pool) -> None: - # Ensure it raises before init - # Need to clear global first since other tests might have run - import blacki.utils.preferences as prefs + result = await storage.delete("user1", "calorie_goal") - prefs._storage = None + assert result is True + value = await storage.get("user1", "calorie_goal") + assert value is None - with pytest.raises(RuntimeError): - get_preferences_storage() + @pytest.mark.asyncio + async def test_delete_not_found(self, storage) -> None: + """Should return False when deleting non-existent preference.""" + result = await storage.delete("user1", "calorie_goal") - storage = await init_preferences_storage(mock_pool) - assert get_preferences_storage() is storage + assert result is False - await close_preferences_storage() - with pytest.raises(RuntimeError): - get_preferences_storage() + @pytest.mark.asyncio + async def test_multiple_keys_per_user(self, storage) -> None: + """Should handle multiple keys per user.""" + await storage.set("user1", "calorie_goal", 2500) + await storage.set("user1", "workout_split", {"monday": "push"}) + await storage.set("user1", "timezone", "America/New_York") + assert await storage.get("user1", "calorie_goal") == 2500 + assert await storage.get("user1", "workout_split") == {"monday": "push"} + assert await storage.get("user1", "timezone") == "America/New_York" -@pytest.mark.asyncio -async def test_reinit_preferences_storage_closes_existing(mock_pool) -> None: - """init_preferences_storage closes existing storage before replacing.""" - import blacki.utils.preferences as prefs + @pytest.mark.asyncio + async def test_multiple_users_isolated(self, storage) -> None: + """Should isolate preferences between users.""" + await storage.set("user1", "calorie_goal", 2000) + await storage.set("user2", "calorie_goal", 2500) - existing = PostgresPreferencesStorage(mock_pool) - existing.close = AsyncMock() # type: ignore[method-assign] - prefs._storage = existing + assert await storage.get("user1", "calorie_goal") == 2000 + assert await storage.get("user2", "calorie_goal") == 2500 - new = await init_preferences_storage(mock_pool) + @pytest.mark.asyncio + async def test_complex_value(self, storage) -> None: + """Should store and retrieve complex values.""" + complex_value = { + "split_name": "push", + "exercises": ["bench press", "shoulder press"], + "rest_days": ["wednesday", "sunday"], + } + await storage.set("user1", "workout_split", complex_value) + + result = await storage.get("user1", "workout_split") + assert result == complex_value + + +class TestGetPreferencesStorage: + """Tests for get_preferences_storage function.""" + + @pytest.mark.asyncio + async def test_get_preferences_storage_raises_when_not_initialized( + self, conn, lock + ) -> None: + """Should raise RuntimeError when storage is not initialized.""" + from blacki.container import ( + reset_container_for_tests, + set_container_from_connection, + ) + from blacki.utils.preferences import get_preferences_storage - existing.close.assert_awaited_once() - assert prefs._storage is new + set_container_from_connection(conn, lock) + + with pytest.raises(RuntimeError, match="Preferences storage not initialized"): + get_preferences_storage() - prefs._storage = None + reset_container_for_tests() + + @pytest.mark.asyncio + async def test_get_preferences_storage_returns_storage_when_initialized( + self, conn, lock + ) -> None: + """Should return storage when initialized.""" + from blacki.container import ( + reset_container_for_tests, + set_container_from_connection, + ) + from blacki.utils.preferences import get_preferences_storage + + container = set_container_from_connection(conn, lock) + await container.preferences_storage.initialize() + + result = get_preferences_storage() + + assert result is container.preferences_storage + + reset_container_for_tests() diff --git a/tests/test_registry.py b/tests/test_registry.py index 60c7681..b2fe365 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -14,7 +14,7 @@ def test_default_values(self) -> None: config = ToolConfig() assert config.brave_search_api_key is None - assert config.database_url is None + assert config.sqlite_path is None assert config.sandbox_enabled is False assert config.skills_dir is None @@ -23,13 +23,13 @@ def test_custom_values(self) -> None: skills_path = Path("/tmp/skills") config = ToolConfig( brave_search_api_key="test-key", - database_url="postgres://localhost/test", + sqlite_path="/tmp/blacki.db", sandbox_enabled=True, skills_dir=skills_path, ) assert config.brave_search_api_key == "test-key" - assert config.database_url == "postgres://localhost/test" + assert config.sqlite_path == "/tmp/blacki.db" assert config.sandbox_enabled is True assert config.skills_dir == skills_path @@ -53,8 +53,8 @@ def test_brave_search_tools_added(self) -> None: assert len(tools) == 9 def test_database_tools_added(self) -> None: - """Should add database-backed tools when database URL provided.""" - config = ToolConfig(database_url="postgres://localhost/test") + """Should add database-backed tools when sqlite path provided.""" + config = ToolConfig(sqlite_path="/tmp/blacki.db") tools = build_tools(config) @@ -80,7 +80,7 @@ def test_all_tools_with_full_config(self) -> None: """Should include all tools with full configuration.""" config = ToolConfig( brave_search_api_key="test-key", - database_url="postgres://localhost/test", + sqlite_path="/tmp/blacki.db", sandbox_enabled=True, skills_dir=Path(__file__).parent.parent / "src" / "blacki" / "skills", ) @@ -110,7 +110,7 @@ def test_empty_env(self) -> None: config = build_tool_config_from_env() assert config.brave_search_api_key is None - assert config.database_url is None + assert config.sqlite_path is None assert config.sandbox_enabled is False assert config.skills_dir is not None @@ -139,14 +139,12 @@ def test_brave_search_api_key_empty_string_becomes_none(self) -> None: assert config.brave_search_api_key is None - def test_database_url_from_env(self) -> None: - """Should read DATABASE_URL from env.""" - with patch.dict( - "os.environ", {"DATABASE_URL": "postgres://localhost/test"}, clear=False - ): + def test_sqlite_path_from_env(self) -> None: + """Should read SQLITE_PATH from env.""" + with patch.dict("os.environ", {"SQLITE_PATH": "/tmp/blacki.db"}, clear=False): config = build_tool_config_from_env() - assert config.database_url == "postgres://localhost/test" + assert config.sqlite_path == "/tmp/blacki.db" def test_sandbox_enabled_from_env_true(self) -> None: """Should enable sandbox when SANDBOX_ENABLED is true.""" diff --git a/tests/test_server_config.py b/tests/test_server_config.py index 23e9df7..b7155cc 100644 --- a/tests/test_server_config.py +++ b/tests/test_server_config.py @@ -18,49 +18,30 @@ def mock_dependencies() -> Generator[MagicMock]: patch("openinference.instrumentation.google_adk.GoogleADKInstrumentor"), patch("blacki.utils.setup_logging"), ): - # Setup basic env mock mock_env = MagicMock() - mock_env.session_uri = "postgresql://user:pass@localhost/db" + mock_env.session_uri = None mock_env.allow_origins_list = ["*"] mock_env.serve_web_interface = True mock_env.reload_agents = False + mock_env.sqlite_path = None + mock_env.agent_dir = "src" - # Helper to support .host and .port access if needed mock_env.host = "127.0.0.1" mock_env.port = 8080 - # DB pool settings - mock_env.db_pool_pre_ping = True - mock_env.db_pool_recycle = 1800 - mock_env.db_pool_size = 5 - mock_env.db_max_overflow = 10 - mock_env.db_pool_timeout = 30 - mock_init_env.return_value = mock_env yield mock_get_app -def test_server_session_db_kwargs_configuration(mock_dependencies: MagicMock) -> None: - """Verify session_db_kwargs is configured and passed to get_fast_api_app.""" - # Ensure blacki.server is reloaded if it was already imported +def test_server_session_service_uri_is_none(mock_dependencies: MagicMock) -> None: + """Verify session_service_uri is None for default SQLite sessions.""" if "blacki.server" in sys.modules: del sys.modules["blacki.server"] import blacki.server # noqa: F401 - # expected kwargs - expected_db_kwargs = { - "pool_pre_ping": True, - "pool_recycle": 1800, - "pool_size": 5, - "max_overflow": 10, - "pool_timeout": 30, - } - - # Verify the call mock_dependencies.assert_called_once() call_kwargs = mock_dependencies.call_args[1] - assert "session_db_kwargs" in call_kwargs - assert call_kwargs["session_db_kwargs"] == expected_db_kwargs + assert call_kwargs["session_service_uri"] is None diff --git a/tests/workouts/test_storage.py b/tests/workouts/test_storage.py index c30149d..b449ac9 100644 --- a/tests/workouts/test_storage.py +++ b/tests/workouts/test_storage.py @@ -1,246 +1,638 @@ # mypy: disable-error-code="no-untyped-def" -from unittest.mock import AsyncMock, MagicMock +"""Unit tests for workout storage.""" +import asyncio + +import aiosqlite import pytest from blacki.workouts.storage import ( - PostgresWorkoutStorage, SetDetail, + SqliteWorkoutStorage, WorkoutExercise, WorkoutSession, - close_workout_storage, - get_storage, - init_workout_storage, ) @pytest.fixture -def mock_pool(): - pool = MagicMock() - conn = AsyncMock() - conn.execute = AsyncMock() - tx = MagicMock() - tx.__aenter__ = AsyncMock(return_value=tx) - tx.__aexit__ = AsyncMock(return_value=None) - conn.transaction = MagicMock(return_value=tx) - pool.acquire.return_value.__aenter__.return_value = conn - pool.execute = AsyncMock() - pool.fetch = AsyncMock() - pool.fetchval = AsyncMock() - pool.fetchrow = AsyncMock() - return pool +async def conn(): + """Create an in-memory SQLite connection for testing.""" + conn = await aiosqlite.connect(":memory:") + conn.row_factory = aiosqlite.Row + yield conn + await conn.close() + + +@pytest.fixture +def lock(): + """Create a lock for write operations.""" + return asyncio.Lock() @pytest.fixture -async def workout_storage(mock_pool): - storage = PostgresWorkoutStorage(mock_pool) +async def storage(conn, lock): + """Create a storage instance with the test connection.""" + storage = SqliteWorkoutStorage(conn, lock) await storage.initialize() yield storage await storage.close() -@pytest.mark.asyncio -async def test_initialize_creates_tables(mock_pool) -> None: - storage = PostgresWorkoutStorage(mock_pool) - await storage.initialize() - - conn = mock_pool.acquire.return_value.__aenter__.return_value - assert conn.execute.call_count == 5 - assert storage._schema_ready is True - - -@pytest.mark.asyncio -async def test_create_session(workout_storage, mock_pool) -> None: - conn = mock_pool.acquire.return_value.__aenter__.return_value - conn.fetchval.return_value = 123 - - session = WorkoutSession( - user_id="user1", - workout_date="2026-04-26", - split_name="push", - created_at="2026-04-26T10:00:00", - exercises=[ - WorkoutExercise( - exercise_name="bench press", - sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], +class TestSqliteWorkoutStorage: + """Tests for SqliteWorkoutStorage.""" + + @pytest.mark.asyncio + async def test_initialize_creates_tables(self, conn, lock) -> None: + """Should create tables on initialization.""" + storage = SqliteWorkoutStorage(conn, lock) + await storage.initialize() + + assert storage.is_initialized is True + + cursor = await conn.execute( + """ + SELECT name FROM sqlite_master + WHERE type='table' AND name='workout_sessions' + """ + ) + row = await cursor.fetchone() + assert row is not None + + cursor = await conn.execute( + """ + SELECT name FROM sqlite_master + WHERE type='table' AND name='workout_exercises' + """ + ) + row = await cursor.fetchone() + assert row is not None + + @pytest.mark.asyncio + async def test_create_session(self, storage) -> None: + """Should create a session with exercises.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], + ) + ], + ) + + session_id = await storage.create_session(session) + + assert session_id == 1 + + saved = await storage.get_session(session_id, "user1") + assert saved is not None + assert saved.split_name == "push" + assert len(saved.exercises) == 1 + assert saved.exercises[0].exercise_name == "bench press" + assert saved.exercises[0].sets[0].weight_kg == 100 + + @pytest.mark.asyncio + async def test_create_session_multiple_exercises(self, storage) -> None: + """Should create a session with multiple exercises.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[ + SetDetail(set_num=1, weight_kg=100, reps=10), + SetDetail(set_num=2, weight_kg=105, reps=8), + ], + exercise_order=0, + ), + WorkoutExercise( + exercise_name="shoulder press", + sets=[SetDetail(set_num=1, weight_kg=60, reps=12)], + exercise_order=1, + ), + ], + ) + + session_id = await storage.create_session(session) + + saved = await storage.get_session(session_id, "user1") + assert len(saved.exercises) == 2 + assert saved.exercises[0].sets[0].weight_kg == 100 + assert saved.exercises[1].exercise_name == "shoulder press" + + @pytest.mark.asyncio + async def test_add_exercise(self, storage) -> None: + """Should add an exercise to an existing session.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[], + ) + session_id = await storage.create_session(session) + + exercise = WorkoutExercise( + exercise_name="squat", + sets=[SetDetail(set_num=1, weight_kg=120, reps=8)], + ) + + exercise_id = await storage.add_exercise(session_id, exercise) + + assert exercise_id == 1 + + saved = await storage.get_session(session_id, "user1") + assert len(saved.exercises) == 1 + assert saved.exercises[0].exercise_name == "squat" + + @pytest.mark.asyncio + async def test_update_exercise(self, storage) -> None: + """Should update an exercise's sets and notes.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], + ) + ], + ) + session_id = await storage.create_session(session) + saved = await storage.get_session(session_id, "user1") + exercise_id = saved.exercises[0].id + + result = await storage.update_exercise( + exercise_id, + "user1", + sets=[SetDetail(set_num=1, weight_kg=110, reps=8)], + notes="felt strong", + ) + + assert result is True + + updated = await storage.get_session(session_id, "user1") + assert updated.exercises[0].sets[0].weight_kg == 110 + assert updated.exercises[0].notes == "felt strong" + + @pytest.mark.asyncio + async def test_update_exercise_wrong_user(self, storage) -> None: + """Should not update exercise belonging to different user.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], + ) + ], + ) + session_id = await storage.create_session(session) + saved = await storage.get_session(session_id, "user1") + exercise_id = saved.exercises[0].id + + result = await storage.update_exercise( + exercise_id, + "user2", + notes="should not work", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_delete_exercise(self, storage) -> None: + """Should delete an exercise.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], + ) + ], + ) + session_id = await storage.create_session(session) + saved = await storage.get_session(session_id, "user1") + exercise_id = saved.exercises[0].id + + result = await storage.delete_exercise(exercise_id, "user1") + + assert result is True + + updated = await storage.get_session(session_id, "user1") + assert len(updated.exercises) == 0 + + @pytest.mark.asyncio + async def test_delete_exercise_wrong_user(self, storage) -> None: + """Should not delete exercise belonging to different user.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], + ) + ], + ) + session_id = await storage.create_session(session) + saved = await storage.get_session(session_id, "user1") + exercise_id = saved.exercises[0].id + + result = await storage.delete_exercise(exercise_id, "user2") + + assert result is False + + @pytest.mark.asyncio + async def test_get_session_not_found(self, storage) -> None: + """Should return None for non-existent session.""" + result = await storage.get_session(999, "user1") + + assert result is None + + @pytest.mark.asyncio + async def test_get_session_wrong_user(self, storage) -> None: + """Should return None for session belonging to different user.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[], + ) + session_id = await storage.create_session(session) + + result = await storage.get_session(session_id, "user2") + + assert result is None + + @pytest.mark.asyncio + async def test_get_recent_sessions(self, storage) -> None: + """Should get recent sessions with exercise counts.""" + for i in range(3): + session = WorkoutSession( + user_id="user1", + workout_date=f"2026-04-{26 - i:02d}", + split_name="push" if i % 2 == 0 else "pull", + created_at=f"2026-04-{26 - i:02d}T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name=f"exercise {j}", + sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], + ) + for j in range(i + 1) + ], ) - ], - ) - - conn.execute.reset_mock() - session_id = await workout_storage.create_session(session) - - assert session_id == 123 - conn.fetchval.assert_called_once() - assert conn.execute.call_count == 1 - args = conn.execute.call_args[0] - assert args[1] == 123 - assert args[2] == "bench press" - - -@pytest.mark.asyncio -async def test_add_exercise(workout_storage, mock_pool) -> None: - mock_pool.fetchval.return_value = 456 - - exercise = WorkoutExercise( - exercise_name="squat", sets=[SetDetail(set_num=1, weight_kg=120, reps=8)] - ) - - exercise_id = await workout_storage.add_exercise(123, exercise) - - assert exercise_id == 456 - mock_pool.fetchval.assert_called_once() - - -@pytest.mark.asyncio -async def test_update_exercise(workout_storage, mock_pool) -> None: - mock_pool.fetchval.return_value = "user1" # ownership check - mock_pool.execute.return_value = "UPDATE 1" - - result = await workout_storage.update_exercise(456, "user1", notes="form felt good") - - assert result is True - mock_pool.execute.assert_called_once() - - -@pytest.mark.asyncio -async def test_delete_exercise(workout_storage, mock_pool) -> None: - mock_pool.fetchval.return_value = "user1" - mock_pool.execute.return_value = "DELETE 1" - - result = await workout_storage.delete_exercise(456, "user1") - - assert result is True - mock_pool.execute.assert_called_once() - - -@pytest.mark.asyncio -async def test_get_session(workout_storage, mock_pool) -> None: - mock_pool.fetchrow.return_value = { - "id": 1, - "user_id": "user1", - "workout_date": "2026-04-26", - "split_name": "push", - "notes": None, - "created_at": "2026-04-26T10:00:00", - } - - mock_pool.fetch.return_value = [ - { - "id": 10, - "session_id": 1, - "exercise_name": "bench press", - "sets": ( - '[{"set_num": 1, "weight_kg": 100, "reps": 10, "is_warmup": false}]' - ), - "exercise_order": 0, - "notes": None, - } - ] - - session = await workout_storage.get_session(1, "user1") - - assert session is not None - assert session.id == 1 - assert session.split_name == "push" - assert len(session.exercises) == 1 - assert session.exercises[0].exercise_name == "bench press" - assert session.exercises[0].sets[0].weight_kg == 100 - - -@pytest.mark.asyncio -async def test_get_recent_sessions(workout_storage, mock_pool) -> None: - mock_pool.fetch.return_value = [ - { - "id": 1, - "workout_date": "2026-04-26", - "split_name": "push", - "exercise_count": 5, - } - ] - - sessions = await workout_storage.get_recent_sessions("user1") - - assert len(sessions) == 1 - assert sessions[0].id == 1 - assert sessions[0].exercise_count == 5 - - -@pytest.mark.asyncio -async def test_get_exercise_history(workout_storage, mock_pool) -> None: - mock_pool.fetch.return_value = [ - { - "workout_date": "2026-04-26", - "split_name": "push", - "sets": ( - '[{"set_num": 1, "weight_kg": 100, "reps": 10, "is_warmup": false}]' - ), - } - ] - - history = await workout_storage.get_exercise_history("user1", "bench press") - - assert len(history) == 1 - assert history[0].best_set_weight_kg == 100 - assert history[0].best_set_reps == 10 - assert history[0].total_volume_kg == 1000 - - -@pytest.mark.asyncio -async def test_delete_session(workout_storage, mock_pool) -> None: - mock_pool.execute.return_value = "DELETE 1" - - result = await workout_storage.delete_session(1, "user1") - - assert result is True - - -@pytest.mark.asyncio -async def test_singleton(mock_pool) -> None: - import blacki.workouts.storage as storage - - storage._storage = None - - with pytest.raises(RuntimeError): - get_storage() - - instance = await init_workout_storage(mock_pool) - assert get_storage() is instance - - await close_workout_storage() - with pytest.raises(RuntimeError): - get_storage() - - -@pytest.mark.asyncio -async def test_reinit_workout_storage_closes_existing(mock_pool) -> None: - """init_workout_storage closes existing storage before replacing.""" - import blacki.workouts.storage as storage - - existing = PostgresWorkoutStorage(mock_pool) - existing.close = AsyncMock() # type: ignore[method-assign] - storage._storage = existing - - new = await init_workout_storage(mock_pool) - - existing.close.assert_awaited_once() - assert storage._storage is new + await storage.create_session(session) + + sessions = await storage.get_recent_sessions("user1") + + assert len(sessions) == 3 + assert sessions[0].workout_date == "2026-04-26" + assert sessions[0].exercise_count == 1 + assert sessions[1].exercise_count == 2 + assert sessions[2].exercise_count == 3 + + @pytest.mark.asyncio + async def test_get_recent_sessions_respects_limit(self, storage) -> None: + """Should limit the number of sessions returned.""" + for i in range(15): + session = WorkoutSession( + user_id="user1", + workout_date=f"2026-04-{26 - i:02d}", + split_name="push", + created_at=f"2026-04-{26 - i:02d}T10:00:00", + exercises=[], + ) + await storage.create_session(session) + + sessions = await storage.get_recent_sessions("user1", limit=5) + + assert len(sessions) == 5 + + @pytest.mark.asyncio + async def test_get_exercise_history(self, storage) -> None: + """Should get exercise history with best sets.""" + for i in range(3): + session = WorkoutSession( + user_id="user1", + workout_date=f"2026-04-{26 - i:02d}", + split_name="push", + created_at=f"2026-04-{26 - i:02d}T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[ + SetDetail( + set_num=1, + weight_kg=100.0 + (2 - i) * 5, + reps=10 - (2 - i), + ), + SetDetail( + set_num=2, + weight_kg=95.0 + (2 - i) * 5, + reps=12 - (2 - i), + ), + ], + ) + ], + ) + await storage.create_session(session) + + history = await storage.get_exercise_history("user1", "bench press") + + assert len(history) == 3 + assert history[0].best_set_weight_kg == 110.0 + assert history[0].best_set_reps == 8 + + @pytest.mark.asyncio + async def test_get_exercise_history_excludes_warmup(self, storage) -> None: + """Should exclude warmup sets from best set calculation.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[ + SetDetail(set_num=1, weight_kg=60.0, reps=15, is_warmup=True), + SetDetail(set_num=2, weight_kg=100.0, reps=10, is_warmup=False), + ], + ) + ], + ) + await storage.create_session(session) + + history = await storage.get_exercise_history("user1", "bench press") + + assert history[0].best_set_weight_kg == 100.0 + assert history[0].best_set_reps == 10 + + @pytest.mark.asyncio + async def test_delete_session(self, storage) -> None: + """Should delete a session and cascade to exercises.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], + ) + ], + ) + session_id = await storage.create_session(session) + + result = await storage.delete_session(session_id, "user1") + + assert result is True + + saved = await storage.get_session(session_id, "user1") + assert saved is None + + @pytest.mark.asyncio + async def test_delete_session_wrong_user(self, storage) -> None: + """Should not delete session belonging to different user.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[], + ) + session_id = await storage.create_session(session) + + result = await storage.delete_session(session_id, "user2") + + assert result is False + + @pytest.mark.asyncio + async def test_get_latest_split_session(self, storage) -> None: + """Should get the most recent session for a split.""" + for i in range(3): + session = WorkoutSession( + user_id="user1", + workout_date=f"2026-04-{20 + i:02d}", + split_name="push" if i < 2 else "pull", + created_at=f"2026-04-{20 + i:02d}T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name=f"exercise {i}", + sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], + ) + ], + ) + await storage.create_session(session) + + result = await storage.get_latest_split_session("user1", "push") + + assert result is not None + assert result.workout_date == "2026-04-21" + assert result.split_name == "push" + + @pytest.mark.asyncio + async def test_get_latest_split_session_not_found(self, storage) -> None: + """Should return None if no session for split.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[], + ) + await storage.create_session(session) + + result = await storage.get_latest_split_session("user1", "legs") + + assert result is None + + @pytest.mark.asyncio + async def test_multiple_users_isolated(self, storage) -> None: + """Should isolate data between users.""" + session1 = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], + ) + ], + ) + session2 = WorkoutSession( + user_id="user2", + workout_date="2026-04-26", + split_name="pull", + created_at="2026-04-26T11:00:00", + exercises=[ + WorkoutExercise( + exercise_name="deadlift", + sets=[SetDetail(set_num=1, weight_kg=200, reps=5)], + ) + ], + ) + id1 = await storage.create_session(session1) + id2 = await storage.create_session(session2) + + s1 = await storage.get_session(id1, "user1") + s2 = await storage.get_session(id2, "user2") + + assert s1.split_name == "push" + assert s2.split_name == "pull" + + sessions1 = await storage.get_recent_sessions("user1") + sessions2 = await storage.get_recent_sessions("user2") + + assert len(sessions1) == 1 + assert len(sessions2) == 1 + + @pytest.mark.asyncio + async def test_update_exercise_no_updates_returns_false(self, storage) -> None: + """Should return False when no updates provided.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[ + WorkoutExercise( + exercise_name="bench press", + sets=[SetDetail(set_num=1, weight_kg=100, reps=10)], + ) + ], + ) + session_id = await storage.create_session(session) + saved = await storage.get_session(session_id, "user1") + exercise_id = saved.exercises[0].id + + result = await storage.update_exercise(exercise_id, "user1") + + assert result is False + + +class TestGetStorage: + """Tests for get_storage function.""" + + @pytest.mark.asyncio + async def test_get_storage_raises_when_not_initialized(self, conn, lock) -> None: + """Should raise RuntimeError when storage is not initialized.""" + from blacki.container import ( + reset_container_for_tests, + set_container_from_connection, + ) + from blacki.workouts.storage import get_storage + + set_container_from_connection(conn, lock) + + with pytest.raises(RuntimeError, match="Workout storage not initialized"): + get_storage() + + reset_container_for_tests() + + @pytest.mark.asyncio + async def test_get_storage_returns_storage_when_initialized( + self, conn, lock + ) -> None: + """Should return storage when initialized.""" + from blacki.container import ( + reset_container_for_tests, + set_container_from_connection, + ) + from blacki.workouts.storage import get_storage + + container = set_container_from_connection(conn, lock) + await container.workout_storage.initialize() + + result = get_storage() + + assert result is container.workout_storage + + reset_container_for_tests() + + +class TestCreateSessionEdgeCases: + """Tests for create_session edge cases.""" + + @pytest.mark.asyncio + async def test_create_session_raises_when_lastrowid_none(self, conn, lock) -> None: + """Should raise RuntimeError when lastrowid is None after session insert.""" + from unittest.mock import AsyncMock + + import aiosqlite + + mock_conn = AsyncMock(spec=aiosqlite.Connection) + mock_cursor = AsyncMock() + mock_cursor.lastrowid = None + mock_conn.execute.return_value = mock_cursor + mock_conn.commit = AsyncMock() + mock_conn.rollback = AsyncMock() + + storage = SqliteWorkoutStorage(mock_conn, lock) + storage._schema_ready = True + + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[], + ) + + with pytest.raises( + RuntimeError, match="Failed to get lastrowid after session insert" + ): + await storage.create_session(session) + + mock_conn.rollback.assert_called_once() + + @pytest.mark.asyncio + async def test_create_session_rollback_on_exception(self, conn, lock) -> None: + """Should rollback transaction on exception during session creation.""" + from unittest.mock import AsyncMock + + import aiosqlite + + mock_conn = AsyncMock(spec=aiosqlite.Connection) + mock_cursor = AsyncMock() + mock_cursor.lastrowid = 1 + mock_conn.execute.return_value = mock_cursor + mock_conn.commit = AsyncMock(side_effect=Exception("commit failed")) + mock_conn.rollback = AsyncMock() - storage._storage = None + storage = SqliteWorkoutStorage(mock_conn, lock) + storage._schema_ready = True + session = WorkoutSession( + user_id="user1", + workout_date="2026-04-26", + split_name="push", + created_at="2026-04-26T10:00:00", + exercises=[], + ) -@pytest.mark.asyncio -async def test_get_latest_split_session(workout_storage, mock_pool) -> None: - mock_pool.fetchrow.return_value = { - "id": 1, - "user_id": "user1", - "workout_date": "2026-04-26", - "split_name": "push", - "notes": None, - "created_at": "2026-04-26T10:00:00", - } - mock_pool.fetch.return_value = [] + with pytest.raises(Exception, match="commit failed"): + await storage.create_session(session) - session = await workout_storage.get_latest_split_session("user1", "push") - assert session is not None - assert session.id == 1 + mock_conn.rollback.assert_called_once() diff --git a/uv.lock b/uv.lock index e626fb9..3dd8467 100644 --- a/uv.lock +++ b/uv.lock @@ -181,38 +181,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl", hash = "sha256:15a3ebc0f43c2d0a50eeafea25e19046c68398e487b9f1f5b517f7c0f40f976a", size = 27047, upload-time = "2025-11-15T16:43:16.109Z" }, ] -[[package]] -name = "asyncpg" -version = "0.31.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fe/cc/d18065ce2380d80b1bcce927c24a2642efd38918e33fd724bc4bca904877/asyncpg-0.31.0.tar.gz", hash = "sha256:c989386c83940bfbd787180f2b1519415e2d3d6277a70d9d0f0145ac73500735", size = 993667, upload-time = "2025-11-24T23:27:00.812Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/08/17/cc02bc49bc350623d050fa139e34ea512cd6e020562f2a7312a7bcae4bc9/asyncpg-0.31.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:eee690960e8ab85063ba93af2ce128c0f52fd655fdff9fdb1a28df01329f031d", size = 643159, upload-time = "2025-11-24T23:25:36.443Z" }, - { url = "https://files.pythonhosted.org/packages/a4/62/4ded7d400a7b651adf06f49ea8f73100cca07c6df012119594d1e3447aa6/asyncpg-0.31.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2657204552b75f8288de08ca60faf4a99a65deef3a71d1467454123205a88fab", size = 638157, upload-time = "2025-11-24T23:25:37.89Z" }, - { url = "https://files.pythonhosted.org/packages/d6/5b/4179538a9a72166a0bf60ad783b1ef16efb7960e4d7b9afe9f77a5551680/asyncpg-0.31.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a429e842a3a4b4ea240ea52d7fe3f82d5149853249306f7ff166cb9948faa46c", size = 2918051, upload-time = "2025-11-24T23:25:39.461Z" }, - { url = "https://files.pythonhosted.org/packages/e6/35/c27719ae0536c5b6e61e4701391ffe435ef59539e9360959240d6e47c8c8/asyncpg-0.31.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c0807be46c32c963ae40d329b3a686356e417f674c976c07fa49f1b30303f109", size = 2972640, upload-time = "2025-11-24T23:25:41.512Z" }, - { url = "https://files.pythonhosted.org/packages/43/f4/01ebb9207f29e645a64699b9ce0eefeff8e7a33494e1d29bb53736f7766b/asyncpg-0.31.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e5d5098f63beeae93512ee513d4c0c53dc12e9aa2b7a1af5a81cddf93fe4e4da", size = 2851050, upload-time = "2025-11-24T23:25:43.153Z" }, - { url = "https://files.pythonhosted.org/packages/3e/f4/03ff1426acc87be0f4e8d40fa2bff5c3952bef0080062af9efc2212e3be8/asyncpg-0.31.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:37fc6c00a814e18eef51833545d1891cac9aa69140598bb076b4cd29b3e010b9", size = 2962574, upload-time = "2025-11-24T23:25:44.942Z" }, - { url = "https://files.pythonhosted.org/packages/c7/39/cc788dfca3d4060f9d93e67be396ceec458dfc429e26139059e58c2c244d/asyncpg-0.31.0-cp311-cp311-win32.whl", hash = "sha256:5a4af56edf82a701aece93190cc4e094d2df7d33f6e915c222fb09efbb5afc24", size = 521076, upload-time = "2025-11-24T23:25:46.486Z" }, - { url = "https://files.pythonhosted.org/packages/28/fc/735af5384c029eb7f1ca60ccb8fa95521dbdaeef788edf4cecfc604c3cab/asyncpg-0.31.0-cp311-cp311-win_amd64.whl", hash = "sha256:480c4befbdf079c14c9ca43c8c5e1fe8b6296c96f1f927158d4f1e750aacc047", size = 584980, upload-time = "2025-11-24T23:25:47.938Z" }, - { url = "https://files.pythonhosted.org/packages/2a/a6/59d0a146e61d20e18db7396583242e32e0f120693b67a8de43f1557033e2/asyncpg-0.31.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b44c31e1efc1c15188ef183f287c728e2046abb1d26af4d20858215d50d91fad", size = 662042, upload-time = "2025-11-24T23:25:49.578Z" }, - { url = "https://files.pythonhosted.org/packages/36/01/ffaa189dcb63a2471720615e60185c3f6327716fdc0fc04334436fbb7c65/asyncpg-0.31.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0c89ccf741c067614c9b5fc7f1fc6f3b61ab05ae4aaa966e6fd6b93097c7d20d", size = 638504, upload-time = "2025-11-24T23:25:51.501Z" }, - { url = "https://files.pythonhosted.org/packages/9f/62/3f699ba45d8bd24c5d65392190d19656d74ff0185f42e19d0bbd973bb371/asyncpg-0.31.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:12b3b2e39dc5470abd5e98c8d3373e4b1d1234d9fbdedf538798b2c13c64460a", size = 3426241, upload-time = "2025-11-24T23:25:53.278Z" }, - { url = "https://files.pythonhosted.org/packages/8c/d1/a867c2150f9c6e7af6462637f613ba67f78a314b00db220cd26ff559d532/asyncpg-0.31.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:aad7a33913fb8bcb5454313377cc330fbb19a0cd5faa7272407d8a0c4257b671", size = 3520321, upload-time = "2025-11-24T23:25:54.982Z" }, - { url = "https://files.pythonhosted.org/packages/7a/1a/cce4c3f246805ecd285a3591222a2611141f1669d002163abef999b60f98/asyncpg-0.31.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3df118d94f46d85b2e434fd62c84cb66d5834d5a890725fe625f498e72e4d5ec", size = 3316685, upload-time = "2025-11-24T23:25:57.43Z" }, - { url = "https://files.pythonhosted.org/packages/40/ae/0fc961179e78cc579e138fad6eb580448ecae64908f95b8cb8ee2f241f67/asyncpg-0.31.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bd5b6efff3c17c3202d4b37189969acf8927438a238c6257f66be3c426beba20", size = 3471858, upload-time = "2025-11-24T23:25:59.636Z" }, - { url = "https://files.pythonhosted.org/packages/52/b2/b20e09670be031afa4cbfabd645caece7f85ec62d69c312239de568e058e/asyncpg-0.31.0-cp312-cp312-win32.whl", hash = "sha256:027eaa61361ec735926566f995d959ade4796f6a49d3bde17e5134b9964f9ba8", size = 527852, upload-time = "2025-11-24T23:26:01.084Z" }, - { url = "https://files.pythonhosted.org/packages/b5/f0/f2ed1de154e15b107dc692262395b3c17fc34eafe2a78fc2115931561730/asyncpg-0.31.0-cp312-cp312-win_amd64.whl", hash = "sha256:72d6bdcbc93d608a1158f17932de2321f68b1a967a13e014998db87a72ed3186", size = 597175, upload-time = "2025-11-24T23:26:02.564Z" }, - { url = "https://files.pythonhosted.org/packages/95/11/97b5c2af72a5d0b9bc3fa30cd4b9ce22284a9a943a150fdc768763caf035/asyncpg-0.31.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c204fab1b91e08b0f47e90a75d1b3c62174dab21f670ad6c5d0f243a228f015b", size = 661111, upload-time = "2025-11-24T23:26:04.467Z" }, - { url = "https://files.pythonhosted.org/packages/1b/71/157d611c791a5e2d0423f09f027bd499935f0906e0c2a416ce712ba51ef3/asyncpg-0.31.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:54a64f91839ba59008eccf7aad2e93d6e3de688d796f35803235ea1c4898ae1e", size = 636928, upload-time = "2025-11-24T23:26:05.944Z" }, - { url = "https://files.pythonhosted.org/packages/2e/fc/9e3486fb2bbe69d4a867c0b76d68542650a7ff1574ca40e84c3111bb0c6e/asyncpg-0.31.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0e0822b1038dc7253b337b0f3f676cadc4ac31b126c5d42691c39691962e403", size = 3424067, upload-time = "2025-11-24T23:26:07.957Z" }, - { url = "https://files.pythonhosted.org/packages/12/c6/8c9d076f73f07f995013c791e018a1cd5f31823c2a3187fc8581706aa00f/asyncpg-0.31.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bef056aa502ee34204c161c72ca1f3c274917596877f825968368b2c33f585f4", size = 3518156, upload-time = "2025-11-24T23:26:09.591Z" }, - { url = "https://files.pythonhosted.org/packages/ae/3b/60683a0baf50fbc546499cfb53132cb6835b92b529a05f6a81471ab60d0c/asyncpg-0.31.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0bfbcc5b7ffcd9b75ab1558f00db2ae07db9c80637ad1b2469c43df79d7a5ae2", size = 3319636, upload-time = "2025-11-24T23:26:11.168Z" }, - { url = "https://files.pythonhosted.org/packages/50/dc/8487df0f69bd398a61e1792b3cba0e47477f214eff085ba0efa7eac9ce87/asyncpg-0.31.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:22bc525ebbdc24d1261ecbf6f504998244d4e3be1721784b5f64664d61fbe602", size = 3472079, upload-time = "2025-11-24T23:26:13.164Z" }, - { url = "https://files.pythonhosted.org/packages/13/a1/c5bbeeb8531c05c89135cb8b28575ac2fac618bcb60119ee9696c3faf71c/asyncpg-0.31.0-cp313-cp313-win32.whl", hash = "sha256:f890de5e1e4f7e14023619399a471ce4b71f5418cd67a51853b9910fdfa73696", size = 527606, upload-time = "2025-11-24T23:26:14.78Z" }, - { url = "https://files.pythonhosted.org/packages/91/66/b25ccb84a246b470eb943b0107c07edcae51804912b824054b3413995a10/asyncpg-0.31.0-cp313-cp313-win_amd64.whl", hash = "sha256:dc5f2fa9916f292e5c5c8b2ac2813763bcd7f58e130055b4ad8a0531314201ab", size = 596569, upload-time = "2025-11-24T23:26:16.189Z" }, -] - [[package]] name = "attrs" version = "25.4.0" @@ -248,8 +216,8 @@ name = "blacki" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "aiosqlite" }, { name = "apscheduler" }, - { name = "asyncpg" }, { name = "dateparser" }, { name = "google-adk" }, { name = "google-auth" }, @@ -284,8 +252,8 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiosqlite", specifier = ">=0.20.0" }, { name = "apscheduler", specifier = ">=3.11.0,<4.0.0" }, - { name = "asyncpg", specifier = ">=0.30.0" }, { name = "dateparser", specifier = ">=1.2.0,<2.0.0" }, { name = "google-adk", specifier = "==1.25.1" }, { name = "google-auth", specifier = ">=2.40.3,<3.0.0" },