From 7f04dcd96cc75af12bd30472a9c37298caeb0fc1 Mon Sep 17 00:00:00 2001 From: QueryPlanner Date: Sun, 7 Jun 2026 07:26:36 +0530 Subject: [PATCH 1/2] feat: add periodized training program and tools - Add training program, day, state, and metric database tables - Implement 7 new training tools for logging, tracking, and cycles - Register training tools and integrate prompt workout guidance - Add comprehensive unit and integration tests with 100% coverage --- src/blacki/prompt.py | 25 +- src/blacki/registry.py | 14 + src/blacki/workouts/__init__.py | 14 + src/blacki/workouts/storage.py | 569 ++++++++++++++- src/blacki/workouts/tools.py | 510 ++++++++++++- tests/test_registry.py | 2 +- tests/workouts/test_training.py | 1177 +++++++++++++++++++++++++++++++ 7 files changed, 2261 insertions(+), 50 deletions(-) create mode 100644 tests/workouts/test_training.py diff --git a/src/blacki/prompt.py b/src/blacki/prompt.py index 4535d71..b390d51 100644 --- a/src/blacki/prompt.py +++ b/src/blacki/prompt.py @@ -44,14 +44,27 @@ def return_instruction_root() -> str: -- When logging a workout, ask for the split name and exercises - with sets/reps/weight (kg). -- After logging, compare with the previous session for the same split - and highlight improvements or regressions conversationally. -- When the user asks "what should I do today?", use get_todays_workout. +- Prefer the training-program tools for workouts. Use set_training_program + for rotating or periodized plans, get_todays_training for the current + cycle day, log_training for resistance, conditioning, recovery, and rest, + get_training_history for comparable sessions, and advance_training_cycle + only when the user wants the cycle pointer moved. +- Support multi-modal sessions: resistance sets/reps/weight, Zone 2 heart + rate and duration, VO2 intervals, rower protocols, rucks, mobility, + active recovery, and complete rest. +- Treat cycle_day, session_type, completion_status, and metrics as structured + fields. Use metrics for duration, distance, average/max heart rate, watts, + intervals, ruck load, incline, and lower_back_status. +- Use get_training_metrics and update_training_metrics for 1RMs, max heart + rate, and other macro training metrics. +- Respect stored deload and progression rules from the active training program. + Do not invent progression changes when rules are available. +- Do not silently advance the training cycle after logging. Only advance when + the user asks or log_training is explicitly called with advance_day=true. +- Legacy weekly split tools still exist for simple split-based workouts: + log_workout, get_last_workout, get_todays_workout, and set_workout_split. - Normalize exercise names to lowercase (e.g., "Bench Press" → "bench press") for consistent history tracking. -- If no workout split is configured, prompt the user to set one via set_workout_split. - When summarizing workouts, DO NOT use lists or bullet points. Speak it naturally. diff --git a/src/blacki/registry.py b/src/blacki/registry.py index b4307fb..8c116e8 100644 --- a/src/blacki/registry.py +++ b/src/blacki/registry.py @@ -115,16 +115,30 @@ def _build_workout_tools() -> list[Any]: """Build workout tracking tools.""" try: from blacki.workouts import ( + advance_training_cycle, delete_workout, get_exercise_progress, get_last_workout, + get_todays_training, get_todays_workout, + get_training_history, + get_training_metrics, list_recent_workouts, + log_training, log_workout, + set_training_program, set_workout_split, + update_training_metrics, ) return [ + set_training_program, + get_todays_training, + log_training, + advance_training_cycle, + get_training_history, + get_training_metrics, + update_training_metrics, log_workout, get_last_workout, get_exercise_progress, diff --git a/src/blacki/workouts/__init__.py b/src/blacki/workouts/__init__.py index d758335..724a292 100644 --- a/src/blacki/workouts/__init__.py +++ b/src/blacki/workouts/__init__.py @@ -1,19 +1,33 @@ from .tools import ( + advance_training_cycle, delete_workout, get_exercise_progress, get_last_workout, + get_todays_training, get_todays_workout, + get_training_history, + get_training_metrics, list_recent_workouts, + log_training, log_workout, + set_training_program, set_workout_split, + update_training_metrics, ) __all__ = [ + "advance_training_cycle", "delete_workout", "get_exercise_progress", "get_last_workout", + "get_todays_training", "get_todays_workout", + "get_training_history", + "get_training_metrics", "list_recent_workouts", "log_workout", + "log_training", + "set_training_program", "set_workout_split", + "update_training_metrics", ] diff --git a/src/blacki/workouts/storage.py b/src/blacki/workouts/storage.py index 1c42712..bc9dc8c 100644 --- a/src/blacki/workouts/storage.py +++ b/src/blacki/workouts/storage.py @@ -6,7 +6,7 @@ import logging from typing import TYPE_CHECKING, Any -from pydantic import BaseModel +from pydantic import BaseModel, Field from blacki.storage.base import SqlStorage @@ -17,6 +17,21 @@ logger = logging.getLogger(__name__) +SESSION_TYPES = frozenset( + { + "resistance", + "zone2", + "vo2", + "sugarcane", + "ruck", + "recovery", + "rest", + "mobility", + "other", + } +) +COMPLETION_STATUSES = frozenset({"planned", "completed", "partial", "skipped"}) + class SetDetail(BaseModel): """Details for a single workout set.""" @@ -33,7 +48,7 @@ class WorkoutExercise(BaseModel): id: int | None = None session_id: int | None = None exercise_name: str - sets: list[SetDetail] + sets: list[SetDetail] = Field(default_factory=list) exercise_order: int = 0 notes: str | None = None @@ -47,7 +62,13 @@ class WorkoutSession(BaseModel): split_name: str notes: str | None = None created_at: str - exercises: list[WorkoutExercise] = [] + exercises: list[WorkoutExercise] = Field(default_factory=list) + program_id: int | None = None + program_version: int | None = None + cycle_day: int | None = None + session_type: str = "resistance" + completion_status: str = "completed" + metrics: dict[str, Any] = Field(default_factory=dict) class WorkoutSessionSummary(BaseModel): @@ -57,6 +78,9 @@ class WorkoutSessionSummary(BaseModel): workout_date: str split_name: str exercise_count: int + cycle_day: int | None = None + session_type: str | None = None + completion_status: str = "completed" class ExerciseHistoryEntry(BaseModel): @@ -70,6 +94,64 @@ class ExerciseHistoryEntry(BaseModel): total_volume_kg: float +class TrainingProgramDay(BaseModel): + """One scheduled day in a rotating training program.""" + + id: int | None = None + program_id: int | None = None + cycle_day: int + focus: str + session_type: str + prescription: str | None = None + modality: str | None = None + target_zone: str | None = None + target_duration_min: int | None = None + exercises: list[dict[str, Any]] = Field(default_factory=list) + rules: dict[str, Any] = Field(default_factory=dict) + notes: str | None = None + + +class TrainingProgramState(BaseModel): + """Current pointer for a user's active training program.""" + + user_id: str + program_id: int + current_cycle_day: int + current_mesocycle_week: int + updated_at: str + + +class TrainingProgram(BaseModel): + """A rotating training program and its scheduled days.""" + + id: int | None = None + user_id: str + name: str + cycle_length_days: int = 14 + mesocycle_length_days: int = 28 + deload_week_interval: int = 5 + starts_on: str + version: int = 1 + is_active: bool = True + notes: str | None = None + created_at: str + updated_at: str + days: list[TrainingProgramDay] = Field(default_factory=list) + state: TrainingProgramState | None = None + + +class TrainingMetric(BaseModel): + """A user training metric measurement, stored as history.""" + + id: int | None = None + user_id: str + metric_name: str + value: float + unit: str + recorded_at: str + notes: str | None = None + + class SqliteWorkoutStorage(SqlStorage): """Storage for workout tracking using SQLite via aiosqlite.""" @@ -84,9 +166,26 @@ async def _create_tables(self) -> None: workout_date TEXT NOT NULL, split_name TEXT NOT NULL, notes TEXT, - created_at TEXT NOT NULL + created_at TEXT NOT NULL, + program_id INTEGER, + program_version INTEGER, + cycle_day INTEGER, + session_type TEXT NOT NULL DEFAULT 'resistance', + completion_status TEXT NOT NULL DEFAULT 'completed', + metrics TEXT NOT NULL DEFAULT '{}' ) """) + await self._ensure_columns( + "workout_sessions", + { + "program_id": "INTEGER", + "program_version": "INTEGER", + "cycle_day": "INTEGER", + "session_type": "TEXT NOT NULL DEFAULT 'resistance'", + "completion_status": "TEXT NOT NULL DEFAULT 'completed'", + "metrics": "TEXT NOT NULL DEFAULT '{}'", + }, + ) await self._conn.execute(""" CREATE INDEX IF NOT EXISTS idx_workout_sessions_user_date ON workout_sessions (user_id, workout_date DESC) @@ -95,6 +194,10 @@ async def _create_tables(self) -> None: CREATE INDEX IF NOT EXISTS idx_workout_sessions_user_split ON workout_sessions (user_id, split_name) """) + await self._conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_workout_sessions_user_cycle + ON workout_sessions (user_id, cycle_day, session_type) + """) await self._conn.execute(""" CREATE TABLE IF NOT EXISTS workout_exercises ( @@ -113,6 +216,89 @@ async def _create_tables(self) -> None: ON workout_exercises (session_id) """) + await self._conn.execute(""" + CREATE TABLE IF NOT EXISTS training_programs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + name TEXT NOT NULL, + cycle_length_days INTEGER NOT NULL, + mesocycle_length_days INTEGER NOT NULL, + deload_week_interval INTEGER NOT NULL, + starts_on TEXT NOT NULL, + version INTEGER NOT NULL, + is_active INTEGER NOT NULL DEFAULT 1, + notes TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ) + """) + await self._conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_training_programs_user_active + ON training_programs (user_id, is_active, version DESC) + """) + + await self._conn.execute(""" + CREATE TABLE IF NOT EXISTS training_program_days ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + program_id INTEGER NOT NULL, + cycle_day INTEGER NOT NULL, + focus TEXT NOT NULL, + session_type TEXT NOT NULL, + prescription TEXT, + modality TEXT, + target_zone TEXT, + target_duration_min INTEGER, + exercises TEXT NOT NULL DEFAULT '[]', + rules TEXT NOT NULL DEFAULT '{}', + notes TEXT, + UNIQUE(program_id, cycle_day), + FOREIGN KEY (program_id) + REFERENCES training_programs(id) ON DELETE CASCADE + ) + """) + await self._conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_training_program_days_program_cycle + ON training_program_days (program_id, cycle_day) + """) + + await self._conn.execute(""" + CREATE TABLE IF NOT EXISTS training_program_state ( + user_id TEXT PRIMARY KEY, + program_id INTEGER NOT NULL, + current_cycle_day INTEGER NOT NULL, + current_mesocycle_week INTEGER NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (program_id) + REFERENCES training_programs(id) ON DELETE CASCADE + ) + """) + + await self._conn.execute(""" + CREATE TABLE IF NOT EXISTS training_metrics ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + metric_name TEXT NOT NULL, + value REAL NOT NULL, + unit TEXT NOT NULL, + recorded_at TEXT NOT NULL, + notes TEXT + ) + """) + await self._conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_training_metrics_user_name_time + ON training_metrics (user_id, metric_name, recorded_at DESC) + """) + + async def _ensure_columns(self, table_name: str, columns: dict[str, str]) -> None: + cursor = await self._conn.execute(f"PRAGMA table_info({table_name})") # noqa: S608 + rows = await cursor.fetchall() + existing = {row[1] for row in rows} + for column_name, column_sql in columns.items(): + if column_name not in existing: + await self._conn.execute( + f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_sql}" # noqa: S608 + ) + async def create_session(self, session: WorkoutSession) -> int: """Create session row + all exercises atomically.""" async with self._lock: @@ -121,8 +307,12 @@ async def create_session(self, session: WorkoutSession) -> int: cursor = await self._conn.execute( """ INSERT INTO workout_sessions - (user_id, workout_date, split_name, notes, created_at) - VALUES (?, ?, ?, ?, ?) + ( + user_id, workout_date, split_name, notes, created_at, + program_id, program_version, cycle_day, session_type, + completion_status, metrics + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( session.user_id, @@ -130,6 +320,12 @@ async def create_session(self, session: WorkoutSession) -> int: session.split_name, session.notes, session.created_at, + session.program_id, + session.program_version, + session.cycle_day, + session.session_type, + session.completion_status, + json.dumps(session.metrics), ), ) sid = cursor.lastrowid @@ -281,7 +477,14 @@ async def get_recent_sessions( limit = min(limit, 20) rows = await self._fetch_all( """ - SELECT s.id, s.workout_date, s.split_name, COUNT(e.id) as exercise_count + SELECT + s.id, + s.workout_date, + s.split_name, + s.cycle_day, + s.session_type, + s.completion_status, + COUNT(e.id) as exercise_count FROM workout_sessions s LEFT JOIN workout_exercises e ON s.id = e.session_id WHERE s.user_id = ? @@ -297,10 +500,290 @@ async def get_recent_sessions( workout_date=r["workout_date"], split_name=r["split_name"], exercise_count=r["exercise_count"], + cycle_day=r["cycle_day"], + session_type=r["session_type"], + completion_status=r["completion_status"], ) for r in rows ] + async def create_training_program( + self, + program: TrainingProgram, + current_cycle_day: int, + current_mesocycle_week: int, + ) -> int: + """Create a new active training program and its cycle pointer.""" + version_row = await self._fetch_one( + """ + SELECT COALESCE(MAX(version), 0) + 1 as next_version + FROM training_programs + WHERE user_id = ? + """, + (program.user_id,), + ) + version = int(version_row["next_version"]) if version_row else 1 + + async with self._lock: + await self._conn.execute("BEGIN") + try: + await self._conn.execute( + """ + UPDATE training_programs + SET is_active = 0, updated_at = ? + WHERE user_id = ? + """, + (program.updated_at, program.user_id), + ) + cursor = await self._conn.execute( + """ + INSERT INTO training_programs + ( + user_id, name, cycle_length_days, mesocycle_length_days, + deload_week_interval, starts_on, version, is_active, + notes, created_at, updated_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + program.user_id, + program.name, + program.cycle_length_days, + program.mesocycle_length_days, + program.deload_week_interval, + program.starts_on, + version, + 1, + program.notes, + program.created_at, + program.updated_at, + ), + ) + program_id = cursor.lastrowid + if program_id is None: + raise RuntimeError("Failed to get lastrowid after program insert") + + for day in program.days: + await self._conn.execute( + """ + INSERT INTO training_program_days + ( + program_id, cycle_day, focus, session_type, + prescription, modality, target_zone, + target_duration_min, exercises, rules, notes + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + program_id, + day.cycle_day, + day.focus, + day.session_type, + day.prescription, + day.modality, + day.target_zone, + day.target_duration_min, + json.dumps(day.exercises), + json.dumps(day.rules), + day.notes, + ), + ) + + await self._conn.execute( + """ + INSERT INTO training_program_state + ( + user_id, program_id, current_cycle_day, + current_mesocycle_week, updated_at + ) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(user_id) DO UPDATE SET + program_id = excluded.program_id, + current_cycle_day = excluded.current_cycle_day, + current_mesocycle_week = excluded.current_mesocycle_week, + updated_at = excluded.updated_at + """, + ( + program.user_id, + program_id, + current_cycle_day, + current_mesocycle_week, + program.updated_at, + ), + ) + await self._conn.execute("COMMIT") + return int(program_id) + except Exception: + await self._conn.execute("ROLLBACK") + raise + + async def get_active_training_program(self, user_id: str) -> TrainingProgram | None: + """Return the active training program with days and state.""" + row = await self._fetch_one( + """ + SELECT * FROM training_programs + WHERE user_id = ? AND is_active = 1 + ORDER BY version DESC, id DESC + LIMIT 1 + """, + (user_id,), + ) + if row is None: + return None + + program = self._row_to_training_program(row) + if program.id is None: # pragma: no cover + return program + + day_rows = await self._fetch_all( + """ + SELECT * FROM training_program_days + WHERE program_id = ? + ORDER BY cycle_day ASC + """, + (program.id,), + ) + program.days = [self._row_to_training_program_day(r) for r in day_rows] + program.state = await self.get_training_state(user_id) + return program + + async def get_training_state(self, user_id: str) -> TrainingProgramState | None: + """Return the current training cycle pointer for a user.""" + row = await self._fetch_one( + "SELECT * FROM training_program_state WHERE user_id = ?", + (user_id,), + ) + return self._row_to_training_state(row) if row else None + + async def advance_training_state( + self, user_id: str, days: int, updated_at: str + ) -> TrainingProgramState | None: + """Advance the active program pointer by a number of calendar days.""" + program = await self.get_active_training_program(user_id) + if program is None or program.state is None or program.id is None: + return None + + current_day = program.state.current_cycle_day + current_week = program.state.current_mesocycle_week + for _ in range(days): + if current_day % 7 == 0: + current_week += 1 + if current_week > program.deload_week_interval: + current_week = 1 + current_day = ( + current_day + 1 if current_day < program.cycle_length_days else 1 + ) + + async with self._lock: + await self._conn.execute( + """ + UPDATE training_program_state + SET current_cycle_day = ?, current_mesocycle_week = ?, updated_at = ? + WHERE user_id = ? + """, + (current_day, current_week, updated_at, user_id), + ) + + return TrainingProgramState( + user_id=user_id, + program_id=program.id, + current_cycle_day=current_day, + current_mesocycle_week=current_week, + updated_at=updated_at, + ) + + async def get_training_history( + self, + user_id: str, + cycle_day: int | None = None, + session_type: str | None = None, + exercise_name: str | None = None, + limit: int = 8, + ) -> list[WorkoutSession]: + """Return comparable training sessions by cycle day, type, or exercise.""" + limit = min(max(limit, 1), 20) + joins = "" + where = ["s.user_id = ?"] + values: list[Any] = [user_id] + if exercise_name: + joins = "JOIN workout_exercises e ON e.session_id = s.id" + where.append("e.exercise_name = ?") + values.append(exercise_name.lower()) + if cycle_day is not None: + where.append("s.cycle_day = ?") + values.append(cycle_day) + if session_type is not None: + where.append("s.session_type = ?") + values.append(session_type) + + values.append(limit) + query = f""" + SELECT DISTINCT s.* + FROM workout_sessions s + {joins} + WHERE {" AND ".join(where)} + ORDER BY s.workout_date DESC, s.created_at DESC + LIMIT ? + """ # noqa: S608 + rows = await self._fetch_all(query, tuple(values)) + sessions = [] + for row in rows: + session = await self.get_session(int(row["id"]), user_id) + if session is not None: + sessions.append(session) + return sessions + + async def add_training_metrics(self, metrics: list[TrainingMetric]) -> list[int]: + """Insert metric history rows and return their IDs.""" + ids = [] + async with self._lock: + for metric in metrics: + cursor = await self._conn.execute( + """ + INSERT INTO training_metrics + (user_id, metric_name, value, unit, recorded_at, notes) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + metric.user_id, + metric.metric_name, + metric.value, + metric.unit, + metric.recorded_at, + metric.notes, + ), + ) + if cursor.lastrowid is None: + raise RuntimeError("Failed to get lastrowid after metric insert") + ids.append(cursor.lastrowid) + return ids + + async def get_latest_training_metrics( + self, user_id: str, metric_names: list[str] | None = None + ) -> list[TrainingMetric]: + """Return the latest row for each requested metric name.""" + values: list[Any] = [user_id] + where = "WHERE user_id = ?" + if metric_names: + placeholders = ", ".join("?" for _ in metric_names) + where = f"{where} AND metric_name IN ({placeholders})" + values.extend(metric_names) + + rows = await self._fetch_all( + f""" + SELECT * FROM training_metrics + {where} + ORDER BY metric_name ASC, recorded_at DESC, id DESC + """, # noqa: S608 + tuple(values), + ) + latest_by_name: dict[str, TrainingMetric] = {} + for row in rows: + name = row["metric_name"] + if name not in latest_by_name: + latest_by_name[name] = self._row_to_training_metric(row) + return list(latest_by_name.values()) + async def get_exercise_history( self, user_id: str, exercise_name: str, limit: int = 8 ) -> list[ExerciseHistoryEntry]: @@ -361,6 +844,10 @@ async def delete_session(self, session_id: int, user_id: str) -> bool: return cursor.rowcount > 0 def _row_to_session(self, row: dict[str, Any]) -> WorkoutSession: + metrics_data = row.get("metrics") or "{}" + metrics = ( + json.loads(metrics_data) if isinstance(metrics_data, str) else metrics_data + ) return WorkoutSession( id=int(row["id"]), user_id=row["user_id"], @@ -369,6 +856,12 @@ def _row_to_session(self, row: dict[str, Any]) -> WorkoutSession: notes=row["notes"], created_at=row["created_at"], exercises=[], + program_id=row.get("program_id"), + program_version=row.get("program_version"), + cycle_day=row.get("cycle_day"), + session_type=row.get("session_type") or "resistance", + completion_status=row.get("completion_status") or "completed", + metrics=metrics, ) def _row_to_exercise(self, row: dict[str, Any]) -> WorkoutExercise: @@ -384,6 +877,68 @@ def _row_to_exercise(self, row: dict[str, Any]) -> WorkoutExercise: notes=row["notes"], ) + def _row_to_training_program(self, row: dict[str, Any]) -> TrainingProgram: + return TrainingProgram( + id=int(row["id"]), + user_id=row["user_id"], + name=row["name"], + cycle_length_days=int(row["cycle_length_days"]), + mesocycle_length_days=int(row["mesocycle_length_days"]), + deload_week_interval=int(row["deload_week_interval"]), + starts_on=str(row["starts_on"]), + version=int(row["version"]), + is_active=bool(row["is_active"]), + notes=row["notes"], + created_at=row["created_at"], + updated_at=row["updated_at"], + days=[], + state=None, + ) + + def _row_to_training_program_day(self, row: dict[str, Any]) -> TrainingProgramDay: + exercises_data = row.get("exercises") or "[]" + rules_data = row.get("rules") or "{}" + exercises = ( + json.loads(exercises_data) + if isinstance(exercises_data, str) + else exercises_data + ) + rules = json.loads(rules_data) if isinstance(rules_data, str) else rules_data + return TrainingProgramDay( + id=int(row["id"]), + program_id=int(row["program_id"]), + cycle_day=int(row["cycle_day"]), + focus=row["focus"], + session_type=row["session_type"], + prescription=row["prescription"], + modality=row["modality"], + target_zone=row["target_zone"], + target_duration_min=row["target_duration_min"], + exercises=exercises, + rules=rules, + notes=row["notes"], + ) + + def _row_to_training_state(self, row: dict[str, Any]) -> TrainingProgramState: + return TrainingProgramState( + user_id=row["user_id"], + program_id=int(row["program_id"]), + current_cycle_day=int(row["current_cycle_day"]), + current_mesocycle_week=int(row["current_mesocycle_week"]), + updated_at=row["updated_at"], + ) + + def _row_to_training_metric(self, row: dict[str, Any]) -> TrainingMetric: + return TrainingMetric( + id=int(row["id"]), + user_id=row["user_id"], + metric_name=row["metric_name"], + value=float(row["value"]), + unit=row["unit"], + recorded_at=row["recorded_at"], + notes=row["notes"], + ) + _storage: SqliteWorkoutStorage | None = None diff --git a/src/blacki/workouts/tools.py b/src/blacki/workouts/tools.py index 0772c22..5a3736a 100644 --- a/src/blacki/workouts/tools.py +++ b/src/blacki/workouts/tools.py @@ -7,64 +7,54 @@ from blacki.utils.preferences import get_preferences_storage from blacki.utils.timezone import get_app_timezone, now_utc -from .storage import SetDetail, WorkoutExercise, WorkoutSession, get_storage +from .storage import ( + COMPLETION_STATUSES, + SESSION_TYPES, + SetDetail, + TrainingMetric, + TrainingProgram, + TrainingProgramDay, + WorkoutExercise, + WorkoutSession, + get_storage, +) logger = logging.getLogger(__name__) +LOWER_BACK_SWAP_STATUSES = {"tight", "sore", "pain", "avoid_hinge", "strained"} -async def log_workout( - tool_context: ToolContext, - split_name: str, - exercises: list[dict[str, Any]], - workout_date: str | None = None, - notes: str | None = None, -) -> dict[str, Any]: - """Start or complete a full workout session.""" - user_id = tool_context.user_id - if not user_id: - return {"status": "error", "message": "Missing user_id in tool_context"} - parsed_date = parse_date(workout_date) +def _parse_workout_exercises( + exercises: list[dict[str, Any]] | None, +) -> tuple[list[WorkoutExercise], str | None]: + if not exercises: + return [], None - # Parse exercises parsed_exercises = [] for i, ex_dict in enumerate(exercises): if "name" not in ex_dict or "sets" not in ex_dict: - return { - "status": "error", - "message": "Each exercise must have 'name' and 'sets' keys", - } + return [], "Each exercise must have 'name' and 'sets' keys" sets_data = ex_dict["sets"] - sets_list: list[dict[str, Any]] = [] + sets_list: list[dict[str, Any]] if isinstance(sets_data, int): - # Shorthand: sets=3, reps=8, weight=100 reps = ex_dict.get("reps", 0) weight = ex_dict.get("weight_kg") or ex_dict.get("weight", 0) sets_list = [{"weight_kg": weight, "reps": reps} for _ in range(sets_data)] - elif isinstance(sets_data, dict): # pragma: no cover + elif isinstance(sets_data, dict): sets_list = [sets_data] elif isinstance(sets_data, list): sets_list = sets_data - else: # pragma: no cover - return { - "status": "error", - "message": "'sets' must be a list of dictionaries or an integer", - } + else: + return [], "'sets' must be a list of dictionaries or an integer" sets: list[SetDetail] = [] for set_dict in sets_list: if "weight_kg" not in set_dict and "weight" not in set_dict: - return { - "status": "error", - "message": "Each set must have 'weight_kg' (or 'weight')", - } - if "reps" not in set_dict: # pragma: no cover - return { - "status": "error", - "message": "Each set must have 'reps'", - } + return [], "Each set must have 'weight_kg' (or 'weight')" + if "reps" not in set_dict: + return [], "Each set must have 'reps'" weight_val = set_dict.get("weight_kg") or set_dict.get("weight", 0) sets.append( @@ -78,13 +68,99 @@ async def log_workout( parsed_exercises.append( WorkoutExercise( - exercise_name=ex_dict["name"].lower(), + exercise_name=str(ex_dict["name"]).lower(), sets=sets, exercise_order=i, notes=ex_dict.get("notes"), ) ) + return parsed_exercises, None + + +def _infer_metric_unit(metric_name: str) -> str: + if metric_name.endswith("_kg") or "1rm" in metric_name or "load" in metric_name: + return "kg" + if metric_name.endswith("_bpm") or "heart_rate" in metric_name: + return "bpm" + if metric_name.endswith("_km") or "distance" in metric_name: + return "km" + if metric_name.endswith("_min") or "duration" in metric_name: + return "min" + return "unitless" + + +def _parse_training_metrics( + user_id: str, metrics: dict[str, Any], recorded_at: str +) -> tuple[list[TrainingMetric], str | None]: + if not metrics: + return [], "metrics cannot be empty" + + parsed = [] + for raw_name, raw_value in metrics.items(): + metric_name = str(raw_name).strip().lower() + if not metric_name: + return [], "metric names cannot be empty" + + notes = None + metric_recorded_at = recorded_at + if isinstance(raw_value, dict): + if "value" not in raw_value: + return [], f"Metric '{metric_name}' must include a value" + value_obj = raw_value["value"] + unit = str(raw_value.get("unit") or _infer_metric_unit(metric_name)) + notes = raw_value.get("notes") + metric_recorded_at = str(raw_value.get("recorded_at") or recorded_at) + else: + value_obj = raw_value + unit = _infer_metric_unit(metric_name) + + try: + value = float(value_obj) + except (TypeError, ValueError): + return [], f"Metric '{metric_name}' value must be numeric" + + parsed.append( + TrainingMetric( + user_id=user_id, + metric_name=metric_name, + value=value, + unit=unit, + recorded_at=metric_recorded_at, + notes=notes, + ) + ) + + return parsed, None + + +def _matching_program_day( + program: TrainingProgram, cycle_day: int +) -> TrainingProgramDay | None: + for day in program.days: + if day.cycle_day == cycle_day: + return day + return None + + +async def log_workout( + tool_context: ToolContext, + split_name: str, + exercises: list[dict[str, Any]], + workout_date: str | None = None, + notes: str | None = None, +) -> dict[str, Any]: + """Start or complete a full workout session.""" + user_id = tool_context.user_id + if not user_id: + return {"status": "error", "message": "Missing user_id in tool_context"} + + parsed_date = parse_date(workout_date) + + parsed_exercises, parse_error = _parse_workout_exercises(exercises) + if parse_error: + return {"status": "error", "message": parse_error} + session = WorkoutSession( user_id=user_id, workout_date=parsed_date, @@ -92,6 +168,7 @@ async def log_workout( notes=notes, created_at=now_utc().isoformat(timespec="seconds"), exercises=parsed_exercises, + session_type="resistance", ) storage = get_storage() @@ -143,6 +220,367 @@ async def log_workout( return result +async def set_training_program( + tool_context: ToolContext, + program_config: dict[str, Any], + baseline_metrics: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Store a rotating, multi-modal training program and its cycle pointer.""" + user_id = tool_context.user_id + if not user_id: + return {"status": "error", "message": "Missing user_id in tool_context"} + + days_config = program_config.get("days") + if not isinstance(days_config, list) or not days_config: + return {"status": "error", "message": "program_config.days must be a list"} + + cycle_length_days = int(program_config.get("cycle_length_days") or len(days_config)) + current_cycle_day = int(program_config.get("current_cycle_day") or 1) + current_mesocycle_week = int(program_config.get("current_mesocycle_week") or 1) + if cycle_length_days < 1: + return {"status": "error", "message": "cycle_length_days must be positive"} + if not (1 <= current_cycle_day <= cycle_length_days): + return { + "status": "error", + "message": "current_cycle_day must fit within the cycle length", + } + if current_mesocycle_week < 1: + return {"status": "error", "message": "current_mesocycle_week must be positive"} + + program_days = [] + seen_days: set[int] = set() + for index, day_config in enumerate(days_config, start=1): + if not isinstance(day_config, dict): + return {"status": "error", "message": "Each program day must be an object"} + cycle_day = int(day_config.get("cycle_day") or index) + if not (1 <= cycle_day <= cycle_length_days): + return {"status": "error", "message": "cycle_day is outside cycle length"} + if cycle_day in seen_days: + return {"status": "error", "message": f"Duplicate cycle_day {cycle_day}"} + seen_days.add(cycle_day) + + session_type = str(day_config.get("session_type") or "other").lower() + if session_type not in SESSION_TYPES: + return { + "status": "error", + "message": f"session_type must be one of {sorted(SESSION_TYPES)}", + } + + exercises = day_config.get("exercises") or [] + if not isinstance(exercises, list): + return {"status": "error", "message": "day exercises must be a list"} + rules = day_config.get("rules") or {} + if not isinstance(rules, dict): + return {"status": "error", "message": "day rules must be an object"} + + target_duration = day_config.get("target_duration_min") + program_days.append( + TrainingProgramDay( + cycle_day=cycle_day, + focus=str(day_config.get("focus") or session_type), + session_type=session_type, + prescription=day_config.get("prescription"), + modality=day_config.get("modality"), + target_zone=day_config.get("target_zone"), + target_duration_min=int(target_duration) + if target_duration is not None + else None, + exercises=exercises, + rules=rules, + notes=day_config.get("notes"), + ) + ) + + now = now_utc().isoformat(timespec="seconds") + starts_on = parse_date(program_config.get("starts_on")) + program = TrainingProgram( + user_id=user_id, + name=str(program_config.get("name") or "Training Program"), + cycle_length_days=cycle_length_days, + mesocycle_length_days=int(program_config.get("mesocycle_length_days") or 28), + deload_week_interval=int(program_config.get("deload_week_interval") or 5), + starts_on=starts_on, + notes=program_config.get("notes"), + created_at=now, + updated_at=now, + days=program_days, + ) + + storage = get_storage() + program_id = await storage.create_training_program( + program, + current_cycle_day=current_cycle_day, + current_mesocycle_week=current_mesocycle_week, + ) + + metrics_config = baseline_metrics or program_config.get("baseline_metrics") + metric_ids: list[int] = [] + if isinstance(metrics_config, dict): + parsed_metrics, metric_error = _parse_training_metrics( + user_id, metrics_config, now + ) + if metric_error: + return {"status": "error", "message": metric_error} + metric_ids = await storage.add_training_metrics(parsed_metrics) + + return { + "status": "success", + "message": ( + f"Stored training program '{program.name}' with {len(program_days)} days." + ), + "program_id": program_id, + "cycle_length_days": cycle_length_days, + "current_cycle_day": current_cycle_day, + "current_mesocycle_week": current_mesocycle_week, + "metric_ids": metric_ids, + } + + +async def get_todays_training(tool_context: ToolContext) -> dict[str, Any]: + """Return today's training from the active rotating program pointer.""" + user_id = tool_context.user_id + if not user_id: + return {"status": "error", "message": "Missing user_id in tool_context"} + + storage = get_storage() + program = await storage.get_active_training_program(user_id) + if program is None or program.state is None: + return { + "status": "not_configured", + "message": "No active training program configured.", + } + + cycle_day = program.state.current_cycle_day + program_day = _matching_program_day(program, cycle_day) + if program_day is None: + return { + "status": "error", + "message": f"No program day found for cycle day {cycle_day}", + } + + last_sessions = await storage.get_training_history( + user_id, + cycle_day=cycle_day, + session_type=program_day.session_type, + limit=1, + ) + is_deload = program.state.current_mesocycle_week == program.deload_week_interval + + recommendations = [] + if cycle_day == 6: + day_four_sessions = await storage.get_training_history( + user_id, + cycle_day=4, + session_type="resistance", + limit=1, + ) + if day_four_sessions: + lower_back_status = str( + day_four_sessions[0].metrics.get("lower_back_status", "") + ).lower() + if lower_back_status in LOWER_BACK_SWAP_STATUSES: + recommendations.append( + { + "type": "conditional_swap", + "message": ( + "Lower back status after Day 4 suggests swapping " + "Day 6 and Day 7." + ), + "lower_back_status": lower_back_status, + } + ) + + return { + "status": "success", + "program": { + "id": program.id, + "name": program.name, + "version": program.version, + "cycle_length_days": program.cycle_length_days, + }, + "state": program.state.model_dump(), + "training_day": program_day.model_dump(), + "is_deload": is_deload, + "deload_message": "Use 50% sets and 50% load." if is_deload else None, + "last_comparable_session": last_sessions[0].model_dump() + if last_sessions + else None, + "recommendations": recommendations, + } + + +async def log_training( + tool_context: ToolContext, + session_type: str, + cycle_day: int | None = None, + workout_date: str | None = None, + exercises: list[dict[str, Any]] | None = None, + metrics: dict[str, Any] | None = None, + notes: str | None = None, + completion_status: str = "completed", + advance_day: bool = False, +) -> dict[str, Any]: + """Log a resistance, conditioning, recovery, or rest training session.""" + user_id = tool_context.user_id + if not user_id: + return {"status": "error", "message": "Missing user_id in tool_context"} + + normalized_session_type = session_type.lower() + if normalized_session_type not in SESSION_TYPES: + return { + "status": "error", + "message": f"session_type must be one of {sorted(SESSION_TYPES)}", + } + normalized_completion = completion_status.lower() + if normalized_completion not in COMPLETION_STATUSES: + allowed_statuses = sorted(COMPLETION_STATUSES) + return { + "status": "error", + "message": f"completion_status must be one of {allowed_statuses}", + } + + parsed_exercises, parse_error = _parse_workout_exercises(exercises) + if parse_error: + return {"status": "error", "message": parse_error} + + storage = get_storage() + program = await storage.get_active_training_program(user_id) + if cycle_day is None and program and program.state: + cycle_day = program.state.current_cycle_day + + program_day = ( + _matching_program_day(program, cycle_day) if program and cycle_day else None + ) + split_name = program_day.focus if program_day else normalized_session_type + previous_sessions = await storage.get_training_history( + user_id, + cycle_day=cycle_day, + session_type=normalized_session_type, + limit=1, + ) + now = now_utc().isoformat(timespec="seconds") + session = WorkoutSession( + user_id=user_id, + workout_date=parse_date(workout_date), + split_name=split_name, + notes=notes, + created_at=now, + exercises=parsed_exercises, + program_id=program.id if program else None, + program_version=program.version if program else None, + cycle_day=cycle_day, + session_type=normalized_session_type, + completion_status=normalized_completion, + metrics=metrics or {}, + ) + session_id = await storage.create_session(session) + + advanced_state = None + if advance_day: + advanced_state = await storage.advance_training_state(user_id, 1, now) + + return { + "status": "success", + "session_id": session_id, + "message": f"Logged {normalized_session_type} training session.", + "previous_comparable_session": previous_sessions[0].model_dump() + if previous_sessions + else None, + "advanced_state": advanced_state.model_dump() if advanced_state else None, + } + + +async def advance_training_cycle( + tool_context: ToolContext, + days: int = 1, +) -> dict[str, Any]: + """Explicitly advance the active training program pointer.""" + user_id = tool_context.user_id + if not user_id: + return {"status": "error", "message": "Missing user_id in tool_context"} + if not (1 <= days <= 28): + return {"status": "error", "message": "days must be between 1 and 28"} + + storage = get_storage() + updated_at = now_utc().isoformat(timespec="seconds") + state = await storage.advance_training_state(user_id, days, updated_at) + if state is None: + return { + "status": "not_configured", + "message": "No active training program configured.", + } + + return {"status": "success", "state": state.model_dump()} + + +async def get_training_history( + tool_context: ToolContext, + cycle_day: int | None = None, + session_type: str | None = None, + exercise_name: str | None = None, + limit: int = 8, +) -> dict[str, Any]: + """Get comparable training sessions by cycle day, type, or exercise.""" + user_id = tool_context.user_id + if not user_id: + return {"status": "error", "message": "Missing user_id in tool_context"} + + normalized_session_type = session_type.lower() if session_type else None + if normalized_session_type and normalized_session_type not in SESSION_TYPES: + return { + "status": "error", + "message": f"session_type must be one of {sorted(SESSION_TYPES)}", + } + + storage = get_storage() + sessions = await storage.get_training_history( + user_id, + cycle_day=cycle_day, + session_type=normalized_session_type, + exercise_name=exercise_name, + limit=limit, + ) + return {"status": "success", "sessions": [s.model_dump() for s in sessions]} + + +async def get_training_metrics( + tool_context: ToolContext, + metric_names: list[str] | None = None, +) -> dict[str, Any]: + """Get latest training metrics like 1RMs, max HR, or ruck load.""" + user_id = tool_context.user_id + if not user_id: + return {"status": "error", "message": "Missing user_id in tool_context"} + + normalized_names = [name.lower() for name in metric_names] if metric_names else None + storage = get_storage() + metrics = await storage.get_latest_training_metrics(user_id, normalized_names) + return {"status": "success", "metrics": [m.model_dump() for m in metrics]} + + +async def update_training_metrics( + tool_context: ToolContext, + metrics: dict[str, Any], +) -> dict[str, Any]: + """Record training metric history such as 1RMs or max heart rate.""" + user_id = tool_context.user_id + if not user_id: + return {"status": "error", "message": "Missing user_id in tool_context"} + + now = now_utc().isoformat(timespec="seconds") + parsed_metrics, metric_error = _parse_training_metrics(user_id, metrics, now) + if metric_error: + return {"status": "error", "message": metric_error} + + storage = get_storage() + metric_ids = await storage.add_training_metrics(parsed_metrics) + return { + "status": "success", + "message": f"Recorded {len(metric_ids)} training metrics.", + "metric_ids": metric_ids, + } + + async def get_last_workout( tool_context: ToolContext, split_name: str, diff --git a/tests/test_registry.py b/tests/test_registry.py index 3e5b0c5..ef604f9 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -227,7 +227,7 @@ def test_returns_tools_when_available(self) -> None: tools = _build_workout_tools() - assert len(tools) == 7 + assert len(tools) == 14 class TestBuildSandboxTools: diff --git a/tests/workouts/test_training.py b/tests/workouts/test_training.py new file mode 100644 index 0000000..c686d0a --- /dev/null +++ b/tests/workouts/test_training.py @@ -0,0 +1,1177 @@ +# mypy: disable-error-code="no-untyped-def" +# ruff: noqa: E501 +"""Unit tests for training-program storage, models, and tools.""" + +import asyncio +from unittest.mock import AsyncMock, create_autospec, patch + +import aiosqlite +import pytest +from google.adk.tools import ToolContext + +from blacki.workouts.storage import ( + SetDetail, + SqliteWorkoutStorage, + TrainingMetric, + TrainingProgram, + TrainingProgramDay, + TrainingProgramState, + WorkoutExercise, + WorkoutSession, +) +from blacki.workouts.tools import ( + advance_training_cycle, + get_todays_training, + get_training_history, + get_training_metrics, + log_training, + set_training_program, + update_training_metrics, +) + + +@pytest.fixture +async def conn(): + """Create an in-memory SQLite connection for testing.""" + conn = await aiosqlite.connect(":memory:", isolation_level=None) + conn.row_factory = aiosqlite.Row + yield conn + await conn.close() + + +@pytest.fixture +def lock(): + """Create a lock for write operations.""" + return asyncio.Lock() + + +@pytest.fixture +async def storage(conn, lock): + """Create a storage instance with the test connection.""" + storage = SqliteWorkoutStorage(conn, lock) + await storage.initialize() + yield storage + await storage.close() + + +@pytest.fixture +def mock_tool_context(): + mock_context = create_autospec(ToolContext, spec_set=True, instance=True) + mock_context.state = {} + mock_context.user_id = "user1" + return mock_context + + +class TestSqliteTrainingStorage: + """Tests for the SQLite training-program storage methods.""" + + @pytest.mark.asyncio + async def test_ensure_columns_idempotent(self, conn, lock) -> None: + """Should handle existing columns and ignore them during ALTER.""" + storage = SqliteWorkoutStorage(conn, lock) + await storage.initialize() + # Call it again to prove safety/idempotency + await storage.initialize() + assert storage.is_initialized is True + + @pytest.mark.asyncio + async def test_create_and_get_training_program(self, storage) -> None: + """Should save a program with days and retrieve it as active.""" + program = TrainingProgram( + user_id="user1", + name="Test Program", + cycle_length_days=14, + starts_on="2026-06-07", + created_at="2026-06-07T12:00:00", + updated_at="2026-06-07T12:00:00", + days=[ + TrainingProgramDay( + cycle_day=1, + focus="Legs", + session_type="resistance", + prescription="3x5 Squats", + exercises=[{"name": "squat", "sets": 3}], + ) + ], + ) + + program_id = await storage.create_training_program(program, 1, 1) + assert program_id == 1 + + active = await storage.get_active_training_program("user1") + assert active is not None + assert active.name == "Test Program" + assert len(active.days) == 1 + assert active.days[0].focus == "Legs" + assert active.days[0].exercises == [{"name": "squat", "sets": 3}] + assert active.state is not None + assert active.state.current_cycle_day == 1 + assert active.state.current_mesocycle_week == 1 + + @pytest.mark.asyncio + async def test_advance_training_state_calculations(self, storage) -> None: + """Should correctly advance cycle days and weeks, wrapping 14 to 1 and weeks to week 1.""" + program = TrainingProgram( + user_id="user1", + name="Rotating Program", + cycle_length_days=14, + mesocycle_length_days=28, + deload_week_interval=5, + starts_on="2026-06-07", + created_at="2026-06-07T12:00:00", + updated_at="2026-06-07T12:00:00", + days=[ + TrainingProgramDay(cycle_day=i, focus="Rest", session_type="rest") + for i in range(1, 15) + ], + ) + await storage.create_training_program( + program, current_cycle_day=1, current_mesocycle_week=1 + ) + + # Advance by 6 days (to day 7) + state = await storage.advance_training_state("user1", 6, "2026-06-13T12:00:00") + assert state is not None + assert state.current_cycle_day == 7 + assert state.current_mesocycle_week == 1 + + # Advance by 1 more day (to day 8). Day 7 was the end of week 1, so week advances to 2. + state = await storage.advance_training_state("user1", 1, "2026-06-14T12:00:00") + assert state.current_cycle_day == 8 + assert state.current_mesocycle_week == 2 + + # Advance by 7 days (to day 15 -> day 1). Day 14 was the end of week 2, so week advances to 3. + state = await storage.advance_training_state("user1", 7, "2026-06-21T12:00:00") + assert state.current_cycle_day == 1 + assert state.current_mesocycle_week == 3 + + # Advance state to deload week 5 (week 4 ends at day 28 which is cycle day 14 of second loop) + # Week 3 ends at cycle day 7 of third loop, week 4 ends at cycle day 14 of third loop. + # Let's verify by advancing to week 5 + # Current state is Day 1, Week 3. + # Advance by 13 days to get to Day 14, Week 4 + state = await storage.advance_training_state("user1", 13, "2026-07-04T12:00:00") + assert state.current_cycle_day == 14 + assert state.current_mesocycle_week == 4 + + # Advance 1 day to Day 1, Week 5 (Deload week) + state = await storage.advance_training_state("user1", 1, "2026-07-05T12:00:00") + assert state.current_cycle_day == 1 + assert state.current_mesocycle_week == 5 + + # Day 7 is end of week 5. Advance 6 days to Day 7, Week 5. + state = await storage.advance_training_state("user1", 6, "2026-07-11T12:00:00") + assert state.current_cycle_day == 7 + assert state.current_mesocycle_week == 5 + + # Advance 1 day. End of deload interval (5), week wraps back to 1. Cycle day becomes 8. + state = await storage.advance_training_state("user1", 1, "2026-07-12T12:00:00") + assert state.current_cycle_day == 8 + assert state.current_mesocycle_week == 1 + + @pytest.mark.asyncio + async def test_get_training_history_filters(self, storage) -> None: + """Should filter history by cycle day, session type, or exercise name.""" + now = "2026-06-07T12:00:00" + session1 = WorkoutSession( + user_id="user1", + workout_date="2026-06-07", + split_name="Legs", + created_at=now, + cycle_day=1, + session_type="resistance", + exercises=[ + WorkoutExercise( + exercise_name="squat", + sets=[SetDetail(set_num=1, weight_kg=150.0, reps=5)], + ) + ], + ) + session2 = WorkoutSession( + user_id="user1", + workout_date="2026-06-08", + split_name="Elliptical", + created_at=now, + cycle_day=2, + session_type="zone2", + metrics={"duration_min": 45, "avg_hr_bpm": 135}, + ) + await storage.create_session(session1) + await storage.create_session(session2) + + # Filter by cycle day + history_day1 = await storage.get_training_history("user1", cycle_day=1) + assert len(history_day1) == 1 + assert history_day1[0].split_name == "Legs" + + # Filter by type + history_zone2 = await storage.get_training_history( + "user1", session_type="zone2" + ) + assert len(history_zone2) == 1 + assert history_zone2[0].split_name == "Elliptical" + + # Filter by exercise + history_squat = await storage.get_training_history( + "user1", exercise_name="squat" + ) + assert len(history_squat) == 1 + assert history_squat[0].exercises[0].exercise_name == "squat" + + # Filter by missing exercise + history_bench = await storage.get_training_history( + "user1", exercise_name="bench" + ) + assert len(history_bench) == 0 + + @pytest.mark.asyncio + async def test_training_metrics_history(self, storage) -> None: + """Should store metrics history and return the single latest point for each metric.""" + metric1 = TrainingMetric( + user_id="user1", + metric_name="squat_1rm", + value=150.0, + unit="kg", + recorded_at="2026-06-07T12:00:00", + ) + metric2 = TrainingMetric( + user_id="user1", + metric_name="squat_1rm", + value=152.5, + unit="kg", + recorded_at="2026-06-08T12:00:00", + ) + metric3 = TrainingMetric( + user_id="user1", + metric_name="bench_1rm", + value=100.0, + unit="kg", + recorded_at="2026-06-07T12:00:00", + ) + + ids = await storage.add_training_metrics([metric1, metric2, metric3]) + assert len(ids) == 3 + + # Retrieve all latest metrics + latest = await storage.get_latest_training_metrics("user1") + assert len(latest) == 2 + + # Verify order and values (latest squat_1rm should be 152.5) + bench = next(m for m in latest if m.metric_name == "bench_1rm") + squat = next(m for m in latest if m.metric_name == "squat_1rm") + assert bench.value == 100.0 + assert squat.value == 152.5 + assert squat.recorded_at == "2026-06-08T12:00:00" + + # Retrieve filtered list of names + filtered = await storage.get_latest_training_metrics("user1", ["bench_1rm"]) + assert len(filtered) == 1 + assert filtered[0].metric_name == "bench_1rm" + + +class TestTrainingTools: + """Tests for the high-level training tools and workflows.""" + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_set_training_program_success( + self, mock_get_storage, mock_tool_context + ) -> None: + """Should validate and create a program config and baseline metrics.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + mock_storage.create_training_program.return_value = 1 + mock_storage.add_training_metrics.return_value = [1, 2] + + config = { + "name": "Rotating Plan", + "cycle_length_days": 14, + "starts_on": "today", + "days": [ + { + "cycle_day": 1, + "focus": "Legs Strength", + "session_type": "resistance", + "prescription": "3x5 Squats", + "exercises": [{"name": "squat", "sets": 3}], + }, + { + "cycle_day": 2, + "focus": "Cardio Base", + "session_type": "zone2", + "modality": "elliptical", + "target_duration_min": 45, + }, + ], + "baseline_metrics": { + "squat_1rm": 150.0, + "deadlift_1rm": 200.0, + }, + } + + result = await set_training_program(mock_tool_context, config) + assert result["status"] == "success" + assert result["program_id"] == 1 + assert len(result["metric_ids"]) == 2 + mock_storage.create_training_program.assert_called_once() + mock_storage.add_training_metrics.assert_called_once() + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_set_training_program_validations( + self, mock_get_storage, mock_tool_context + ) -> None: + """Should enforce validation rules on keys, indexes, and session types.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + + # Missing days + assert (await set_training_program(mock_tool_context, {}))["status"] == "error" + + # Cycle day outside bounds + config = { + "days": [{"cycle_day": 20, "session_type": "rest"}], + "cycle_length_days": 14, + } + assert (await set_training_program(mock_tool_context, config))[ + "status" + ] == "error" + + # Duplicate cycle day + config = { + "days": [ + {"cycle_day": 1, "session_type": "rest"}, + {"cycle_day": 1, "session_type": "rest"}, + ], + "cycle_length_days": 14, + } + assert (await set_training_program(mock_tool_context, config))[ + "status" + ] == "error" + + # Invalid session type + config = { + "days": [{"cycle_day": 1, "session_type": "invalid_type"}], + "cycle_length_days": 14, + } + assert (await set_training_program(mock_tool_context, config))[ + "status" + ] == "error" + + # Invalid day config format + config = { + "days": [None], + "cycle_length_days": 14, + } + assert (await set_training_program(mock_tool_context, config))[ + "status" + ] == "error" + + # Invalid metrics structure + config = { + "days": [{"cycle_day": 1, "session_type": "rest"}], + "cycle_length_days": 14, + "baseline_metrics": {"": 100}, + } + assert (await set_training_program(mock_tool_context, config))[ + "status" + ] == "error" + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_get_todays_training_no_program( + self, mock_get_storage, mock_tool_context + ) -> None: + """Should return not_configured when no active program is set.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + mock_storage.get_active_training_program.return_value = None + + result = await get_todays_training(mock_tool_context) + assert result["status"] == "not_configured" + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_get_todays_training_day_six_swap( + self, mock_get_storage, mock_tool_context + ) -> None: + """Should recommend swapping Day 6 and Day 7 if Day 4 deadlifts logged back strain.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + + day = TrainingProgramDay( + cycle_day=6, + focus="VO2 Max", + session_type="vo2", + prescription="4x4 Rower", + ) + program = TrainingProgram( + id=1, + user_id="user1", + name=" Rotating", + starts_on="2026-06-07", + created_at="2026-06-07T12:00:00", + updated_at="2026-06-07T12:00:00", + days=[day], + state=TrainingProgramState( + user_id="user1", + program_id=1, + current_cycle_day=6, + current_mesocycle_week=1, + updated_at="2026-06-12T12:00:00", + ), + ) + mock_storage.get_active_training_program.return_value = program + + # Mock Day 4 log with back strain + day_four_session = WorkoutSession( + user_id="user1", + workout_date="2026-06-10", + split_name="Pull", + created_at="2026-06-10T12:00:00", + cycle_day=4, + metrics={"lower_back_status": "pain"}, + ) + + async def history_side_effect( + user_id, cycle_day, session_type, limit=1, exercise_name=None + ): + if cycle_day == 6: + return [] + if cycle_day == 4: + return [day_four_session] + return [] + + mock_storage.get_training_history.side_effect = history_side_effect + + result = await get_todays_training(mock_tool_context) + assert result["status"] == "success" + assert len(result["recommendations"]) == 1 + assert result["recommendations"][0]["type"] == "conditional_swap" + assert "swapping Day 6 and Day 7" in result["recommendations"][0]["message"] + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_log_training_success_and_advancement( + self, mock_get_storage, mock_tool_context + ) -> None: + """Should save a session with multi-modal metrics and advance state pointer if flag is set.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + mock_storage.create_session.return_value = 123 + + program = TrainingProgram( + id=1, + user_id="user1", + name="Rotating Plan", + starts_on="2026-06-07", + created_at="2026-06-07T12:00:00", + updated_at="2026-06-07T12:00:00", + days=[ + TrainingProgramDay(cycle_day=1, focus="Legs", session_type="resistance") + ], + state=TrainingProgramState( + user_id="user1", + program_id=1, + current_cycle_day=1, + current_mesocycle_week=1, + updated_at="2026-06-07T12:00:00", + ), + ) + mock_storage.get_active_training_program.return_value = program + + advanced_state = TrainingProgramState( + user_id="user1", + program_id=1, + current_cycle_day=2, + current_mesocycle_week=1, + updated_at="2026-06-07T13:00:00", + ) + mock_storage.advance_training_state.return_value = advanced_state + mock_storage.get_training_history.return_value = [] + + result = await log_training( + mock_tool_context, + session_type="resistance", + cycle_day=1, + exercises=[{"name": "squat", "sets": [{"weight_kg": 150.0, "reps": 5}]}], + metrics={"lower_back_status": "ok"}, + advance_day=True, + ) + + assert result["status"] == "success" + assert result["session_id"] == 123 + assert result["advanced_state"] is not None + assert result["advanced_state"]["current_cycle_day"] == 2 + mock_storage.create_session.assert_called_once() + mock_storage.advance_training_state.assert_called_once() + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_advance_training_cycle_errors( + self, mock_get_storage, mock_tool_context + ) -> None: + """Should validate advance steps and error when no program exists.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + + # Invalid days input + result = await advance_training_cycle(mock_tool_context, days=0) + assert result["status"] == "error" + + # Active program missing + mock_storage.advance_training_state.return_value = None + result = await advance_training_cycle(mock_tool_context, days=1) + assert result["status"] == "not_configured" + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_get_training_history_tool( + self, mock_get_storage, mock_tool_context + ) -> None: + """Should filter and validate parameters in history tools.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + mock_storage.get_training_history.return_value = [] + + result = await get_training_history(mock_tool_context, session_type="invalid") + assert result["status"] == "error" + + result = await get_training_history(mock_tool_context, session_type="zone2") + assert result["status"] == "success" + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_get_and_update_metrics_tools( + self, mock_get_storage, mock_tool_context + ) -> None: + """Should record metric history and retrieve filtered values.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + mock_storage.add_training_metrics.return_value = [10] + mock_storage.get_latest_training_metrics.return_value = [ + TrainingMetric( + user_id="user1", + metric_name="squat_1rm", + value=150.0, + unit="kg", + recorded_at="2026-06-07T12:00:00", + ) + ] + + # Update metrics + result = await update_training_metrics(mock_tool_context, {"squat_1rm": 150.0}) + assert result["status"] == "success" + assert result["metric_ids"] == [10] + + # Get metrics + result = await get_training_metrics(mock_tool_context, ["squat_1rm"]) + assert result["status"] == "success" + assert len(result["metrics"]) == 1 + assert result["metrics"][0]["metric_name"] == "squat_1rm" + + @pytest.mark.asyncio + async def test_missing_user_id_error_paths(self) -> None: + """Should handle missing user_id in tool contexts gracefully across all new paths.""" + mock_context = create_autospec(ToolContext, spec_set=True, instance=True) + mock_context.user_id = None + + assert (await set_training_program(mock_context, {}))["status"] == "error" + assert (await get_todays_training(mock_context))["status"] == "error" + assert (await log_training(mock_context, "rest"))["status"] == "error" + assert (await advance_training_cycle(mock_context))["status"] == "error" + assert (await get_training_history(mock_context))["status"] == "error" + assert (await get_training_metrics(mock_context))["status"] == "error" + assert (await update_training_metrics(mock_context, {}))["status"] == "error" + + +class TestTrainingRegistryAndPrompt: + """Tests confirming tools are correctly integrated into the registry and exposed by prompt.""" + + def test_training_tools_built_by_registry(self) -> None: + """Registry build_tools must include all 7 new tools.""" + from blacki.registry import ToolConfig, build_tools + + config = ToolConfig(sqlite_path="/tmp/test_tools.db") + tools = build_tools(config) + tool_names = { + getattr(t, "name", None) or getattr(t, "__name__", "") for t in tools + } + + new_tool_names = { + "set_training_program", + "get_todays_training", + "log_training", + "advance_training_cycle", + "get_training_history", + "get_training_metrics", + "update_training_metrics", + } + for name in new_tool_names: + assert name in tool_names + + def test_prompt_guidance_for_training(self) -> None: + """The system prompt must contain references to training-program specs.""" + from blacki.prompt import return_instruction_root + + prompt = return_instruction_root() + assert "training-program" in prompt + assert "set_training_program" in prompt + assert "get_todays_training" in prompt + assert "log_training" in prompt + assert "advance_training_cycle" in prompt + assert "get_training_metrics" in prompt + assert "update_training_metrics" in prompt + + +class TestTrainingEdgeCasesAndCoverage: + """Additional tests to reach 100% test coverage across training storage and tools.""" + + @pytest.mark.asyncio + async def test_ensure_columns_idempotence_branches(self, conn, lock) -> None: + """Force table initialization when tables exist to cover False branch of column check.""" + storage = SqliteWorkoutStorage(conn, lock) + await storage.initialize() + # Reset schema_ready to check existing table columns path + storage._schema_ready = False + await storage.initialize() + assert storage.is_initialized is True + + @pytest.mark.asyncio + async def test_storage_create_program_insert_failure_rollback( + self, storage + ) -> None: + """Prove transactions roll back correctly if an error is thrown in create_training_program.""" + orig_execute = storage._conn.execute + + class MockAiosqliteHelper: + def __init__(self, ctx, force_fail=False): + self.ctx = ctx + self.force_fail = force_fail + + def __await__(self): + return self._await_impl().__await__() + + async def _await_impl(self): + if self.force_fail: + raise Exception("mock commit fail") + return await self.ctx + + async def __aenter__(self): + if self.force_fail: + raise Exception("mock commit fail") + return await self.ctx.__aenter__() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.ctx.__aexit__(exc_type, exc_val, exc_tb) + + def mock_execute(query, *args, **kwargs): + force_fail = "INSERT INTO training_programs" in query + ctx = orig_execute(query, *args, **kwargs) + return MockAiosqliteHelper(ctx, force_fail=force_fail) + + with patch.object(storage._conn, "execute", side_effect=mock_execute): + program = TrainingProgram( + user_id="user1", + name="Fail Program", + starts_on="2026-06-07", + created_at="2026-06-07T12:00:00", + updated_at="2026-06-07T12:00:00", + ) + + with pytest.raises(Exception, match="mock commit fail"): + await storage.create_training_program(program, 1, 1) + + @pytest.mark.asyncio + async def test_storage_create_program_missing_lastrowid(self, storage) -> None: + """RuntimeError should be raised if lastrowid is None after program insertion.""" + orig_execute = storage._conn.execute + + class MockCursor: + def __init__(self, orig_cursor): + self._orig = orig_cursor + + @property + def lastrowid(self): + return None + + def __getattr__(self, name): + return getattr(self._orig, name) + + class MockAiosqliteHelper: + def __init__(self, ctx, force_none_id=False): + self.ctx = ctx + self.force_none_id = force_none_id + + def __await__(self): + return self._await_impl().__await__() + + async def _await_impl(self): + cursor = await self.ctx + if self.force_none_id: + return MockCursor(cursor) + return cursor + + async def __aenter__(self): + cursor = await self.ctx.__aenter__() + if self.force_none_id: + return MockCursor(cursor) + return cursor + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.ctx.__aexit__(exc_type, exc_val, exc_tb) + + def mock_execute(query, *args, **kwargs): + force_none_id = "INSERT INTO training_programs" in query + ctx = orig_execute(query, *args, **kwargs) + return MockAiosqliteHelper(ctx, force_none_id=force_none_id) + + with patch.object(storage._conn, "execute", side_effect=mock_execute): + program = TrainingProgram( + user_id="user1", + name="Fail Program", + starts_on="2026-06-07", + created_at="2026-06-07T12:00:00", + updated_at="2026-06-07T12:00:00", + ) + + with pytest.raises( + RuntimeError, match="Failed to get lastrowid after program insert" + ): + await storage.create_training_program(program, 1, 1) + + @pytest.mark.asyncio + async def test_storage_create_metrics_missing_lastrowid(self, storage) -> None: + """RuntimeError should be raised if lastrowid is None after metric insertion.""" + orig_execute = storage._conn.execute + + class MockCursor: + def __init__(self, orig_cursor): + self._orig = orig_cursor + + @property + def lastrowid(self): + return None + + def __getattr__(self, name): + return getattr(self._orig, name) + + class MockAiosqliteHelper: + def __init__(self, ctx, force_none_id=False): + self.ctx = ctx + self.force_none_id = force_none_id + + def __await__(self): + return self._await_impl().__await__() + + async def _await_impl(self): + cursor = await self.ctx + if self.force_none_id: + return MockCursor(cursor) + return cursor + + async def __aenter__(self): + cursor = await self.ctx.__aenter__() + if self.force_none_id: + return MockCursor(cursor) + return cursor + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.ctx.__aexit__(exc_type, exc_val, exc_tb) + + def mock_execute(query, *args, **kwargs): + force_none_id = "INSERT INTO training_metrics" in query + ctx = orig_execute(query, *args, **kwargs) + return MockAiosqliteHelper(ctx, force_none_id=force_none_id) + + with patch.object(storage._conn, "execute", side_effect=mock_execute): + metric = TrainingMetric( + user_id="user1", + metric_name="bench", + value=100.0, + unit="kg", + recorded_at="2026-06-07T12:00:00", + ) + + with pytest.raises( + RuntimeError, match="Failed to get lastrowid after metric insert" + ): + await storage.add_training_metrics([metric]) + + @pytest.mark.asyncio + async def test_storage_get_active_program_not_found(self, storage) -> None: + """Querying active program when missing should return None.""" + program = await storage.get_active_training_program("nonexistent") + assert program is None + + @pytest.mark.asyncio + async def test_storage_advance_state_no_program(self, storage) -> None: + """Advancing training state when program is missing should return None.""" + state = await storage.advance_training_state("user1", 1, "2026-06-07T12:00:00") + assert state is None + + @pytest.mark.asyncio + async def test_storage_get_training_history_handles_deleted_sessions( + self, storage + ) -> None: + """Verify get_training_history ignores sessions that return None from get_session.""" + # Insert a raw session that we will fail to load to hit session is None check + await storage._conn.execute( + """ + INSERT INTO workout_sessions (user_id, workout_date, split_name, created_at) + VALUES ('user1', '2026-06-07', 'Legs', '2026-06-07T12:00:00') + """ + ) + # Session ID is 1. If we query user2, get_session returns None because of user isolation. + history = await storage.get_training_history("user2", limit=1) + assert len(history) == 0 + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_tool_log_training_parsing_shorthands( + self, mock_get_storage, mock_tool_context + ) -> None: + """Test optional and shorthand exercise formats in log_training.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + mock_storage.create_session.return_value = 1 + mock_storage.get_active_training_program.return_value = None + mock_storage.get_training_history.return_value = [] + + # Empty exercises + result = await log_training(mock_tool_context, "rest", exercises=None) + assert result["status"] == "success" + + # Shorthand sets list with single set dict + exercises = [{"name": "bench", "sets": {"weight_kg": 100, "reps": 10}}] + result = await log_training( + mock_tool_context, "resistance", exercises=exercises + ) + assert result["status"] == "success" + + # Invalid sets type + exercises_invalid = [{"name": "bench", "sets": "invalid"}] + result = await log_training( + mock_tool_context, "resistance", exercises=exercises_invalid + ) + assert result["status"] == "error" + + # Missing reps in sets + exercises_missing = [{"name": "bench", "sets": [{"weight_kg": 100}]}] + result = await log_training( + mock_tool_context, "resistance", exercises=exercises_missing + ) + assert result["status"] == "error" + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_tool_log_training_metrics_parsing_branches( + self, mock_get_storage, mock_tool_context + ) -> None: + """Test unit inferring and parsing dictionaries with units/notes in log_training/metrics.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + mock_storage.add_training_metrics.return_value = [1] + + # Metric units by name suffix + metrics = { + "cardio_bpm": 145.0, + "hike_km": 5.2, + "ruck_min": 90.0, + "squat_1rm": 150.0, + "random_metric": 42.0, + } + result = await update_training_metrics(mock_tool_context, metrics) + assert result["status"] == "success" + + # Dictionary format with notes and unit override + metrics_dict = { + "bench_1rm": {"value": 105.0, "unit": "lbs", "notes": "sore shoulder"} + } + result = await update_training_metrics(mock_tool_context, metrics_dict) + assert result["status"] == "success" + + # Metric parsing validations + # Missing value + assert (await update_training_metrics(mock_tool_context, {"bench_1rm": {}}))[ + "status" + ] == "error" + # Non-numeric + assert ( + await update_training_metrics(mock_tool_context, {"bench_1rm": "invalid"}) + )["status"] == "error" + # Empty metric dict + assert (await update_training_metrics(mock_tool_context, {}))[ + "status" + ] == "error" + # Empty metric name + assert (await update_training_metrics(mock_tool_context, {"": 10}))[ + "status" + ] == "error" + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_get_todays_training_day_six_swap_branches( + self, mock_get_storage, mock_tool_context + ) -> None: + """Test day 6 branches where Day 4 log doesn't exist, or has different status, or day is missing.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + + # Scenario A: Day 6 is current cycle day, but Day 4 has NO recorded session + day = TrainingProgramDay(cycle_day=6, focus="VO2 Max", session_type="vo2") + program = TrainingProgram( + id=1, + user_id="user1", + name=" Rotating", + starts_on="2026-06-07", + created_at="2026-06-07T12:00:00", + updated_at="2026-06-07T12:00:00", + days=[day], + state=TrainingProgramState( + user_id="user1", + program_id=1, + current_cycle_day=6, + current_mesocycle_week=1, + updated_at="2026-06-12T12:00:00", + ), + ) + mock_storage.get_active_training_program.return_value = program + mock_storage.get_training_history.return_value = [] # Day 4 returns empty list + + result = await get_todays_training(mock_tool_context) + assert result["status"] == "success" + assert len(result["recommendations"]) == 0 + + # Scenario B: Day 4 log exists but back status is OK (no swap recommended) + day_four_session = WorkoutSession( + user_id="user1", + workout_date="2026-06-10", + split_name="Pull", + created_at="2026-06-10T12:00:00", + cycle_day=4, + metrics={"lower_back_status": "ok"}, + ) + + async def history_side_effect( + user_id, cycle_day, session_type, limit=1, exercise_name=None + ): + if cycle_day == 4: + return [day_four_session] + return [] + + mock_storage.get_training_history.side_effect = history_side_effect + + result = await get_todays_training(mock_tool_context) + assert result["status"] == "success" + assert len(result["recommendations"]) == 0 + + # Scenario C: Active program exists but no matching day config for today's cycle day + assert program.state is not None + program.state.current_cycle_day = 10 + result = await get_todays_training(mock_tool_context) + assert result["status"] == "error" + + # Scenario D: Active program day config exists for day 1 (non-6 day config to cover cycle_day != 6 branch on line 368) + day_one = TrainingProgramDay( + cycle_day=1, focus="Legs", session_type="resistance" + ) + program.days.append(day_one) + assert program.state is not None + program.state.current_cycle_day = 1 + mock_storage.get_training_history.side_effect = None + mock_storage.get_training_history.return_value = [] + result = await get_todays_training(mock_tool_context) + assert result["status"] == "success" + assert result["training_day"]["cycle_day"] == 1 + + +class TestFullTestCoverageFillers: + """Explicitly tests missing branch coverage to reach 100% codebase wide.""" + + @pytest.mark.asyncio + async def test_storage_ensure_columns_actually_upgrades(self, conn, lock) -> None: + """Create a table without the new columns first, then ensure _ensure_columns adds them.""" + # Create workout_sessions table with ONLY the legacy columns + await conn.execute(""" + CREATE TABLE workout_sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + workout_date TEXT NOT NULL, + split_name TEXT NOT NULL, + notes TEXT, + created_at TEXT NOT NULL + ) + """) + storage = SqliteWorkoutStorage(conn, lock) + # Manually invoke ensure_columns (normally called inside _create_tables which would do IF NOT EXISTS) + # This will meet the "not in existing" condition and trigger the ALTER TABLE statement (Line 300) + await storage._ensure_columns( + "workout_sessions", + { + "program_id": "INTEGER", + "session_type": "TEXT NOT NULL DEFAULT 'resistance'", + }, + ) + + # Verify the columns are indeed added + cursor = await conn.execute("PRAGMA table_info(workout_sessions)") + rows = await cursor.fetchall() + existing = {row[1] for row in rows} + assert "program_id" in existing + assert "session_type" in existing + + @pytest.mark.asyncio + async def test_storage_get_training_history_true_and_false_branches( + self, storage + ) -> None: + """Ensure get_training_history covers both returning matched sessions and omitting None values.""" + # Add session + session = WorkoutSession( + user_id="user1", + workout_date="2026-06-07", + split_name="Legs", + created_at="2026-06-07T12:00:00", + cycle_day=1, + session_type="resistance", + ) + await storage.create_session(session) + + # Query with user1 (hits True branch of if session is not None) + assert len(await storage.get_training_history("user1", limit=1)) == 1 + + # Query with user1 but mock get_session to return None (hits False branch of if session is not None) + with patch.object(storage, "get_session", return_value=None): + assert len(await storage.get_training_history("user1", limit=1)) == 0 + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_set_training_program_additional_validations( + self, mock_get_storage, mock_tool_context + ) -> None: + """Cover additional validation branches in set_training_program.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + + # cycle_length_days < 1 + config = { + "days": [{"cycle_day": 1, "session_type": "rest"}], + "cycle_length_days": -5, + } + assert (await set_training_program(mock_tool_context, config))[ + "status" + ] == "error" + + # current_cycle_day out of bounds + config = { + "days": [{"cycle_day": 1, "session_type": "rest"}], + "cycle_length_days": 1, + "current_cycle_day": 5, + } + assert (await set_training_program(mock_tool_context, config))[ + "status" + ] == "error" + + # current_mesocycle_week < 1 + config = { + "days": [{"cycle_day": 1, "session_type": "rest"}], + "cycle_length_days": 1, + "current_cycle_day": 1, + "current_mesocycle_week": -5, + } + assert (await set_training_program(mock_tool_context, config))[ + "status" + ] == "error" + + # exercises is not a list + config = { + "days": [{"cycle_day": 1, "session_type": "rest", "exercises": "not_list"}], + "cycle_length_days": 1, + } + assert (await set_training_program(mock_tool_context, config))[ + "status" + ] == "error" + + # rules is not a dict + config = { + "days": [{"cycle_day": 1, "session_type": "rest", "rules": "not_dict"}], + "cycle_length_days": 1, + } + assert (await set_training_program(mock_tool_context, config))[ + "status" + ] == "error" + + # metrics_config is None / empty dict (Line 318->326 False branch) + config = { + "days": [{"cycle_day": 1, "session_type": "rest"}], + "cycle_length_days": 1, + "baseline_metrics": None, + } + result = await set_training_program(mock_tool_context, config) + assert result["status"] == "success" + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_log_training_additional_validations( + self, mock_get_storage, mock_tool_context + ) -> None: + """Cover validation and defaulting branches in log_training.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + mock_storage.create_session.return_value = 1 + mock_storage.get_training_history.return_value = [] + + # session_type invalid + assert (await log_training(mock_tool_context, "invalid_type"))[ + "status" + ] == "error" + + # completion_status invalid + assert ( + await log_training(mock_tool_context, "rest", completion_status="invalid") + )["status"] == "error" + + # cycle_day is None, should default to active program state (Line 448) + program = TrainingProgram( + id=1, + user_id="user1", + name="Rotating Plan", + starts_on="2026-06-07", + created_at="2026-06-07T12:00:00", + updated_at="2026-06-07T12:00:00", + state=TrainingProgramState( + user_id="user1", + program_id=1, + current_cycle_day=3, + current_mesocycle_week=1, + updated_at="2026-06-07T12:00:00", + ), + ) + mock_storage.get_active_training_program.return_value = program + result = await log_training(mock_tool_context, "rest", cycle_day=None) + assert result["status"] == "success" + # Verify the logged session was created with the default cycle day 3 + logged_session = mock_storage.create_session.call_args[0][0] + assert logged_session.cycle_day == 3 + + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_advance_training_cycle_tool_success( + self, mock_get_storage, mock_tool_context + ) -> None: + """Verify successful explicit cycle advancement tool execution.""" + mock_storage = AsyncMock() + mock_get_storage.return_value = mock_storage + state = TrainingProgramState( + user_id="user1", + program_id=1, + current_cycle_day=5, + current_mesocycle_week=1, + updated_at="2026-06-07T12:00:00", + ) + mock_storage.advance_training_state.return_value = state + + result = await advance_training_cycle(mock_tool_context, days=2) + assert result["status"] == "success" + assert result["state"]["current_cycle_day"] == 5 + mock_storage.advance_training_state.assert_called_once() From d668a692cdd3fd3d023fe1252ffefd813b95c95f Mon Sep 17 00:00:00 2001 From: QueryPlanner Date: Sun, 7 Jun 2026 08:15:53 +0530 Subject: [PATCH 2/2] refactor: address code review and optimize queries - Acquire storage lock in advance_training_state to prevent races - Add defensive type checking for LLM-provided inputs in tools - Batch load exercises in get_training_history to prevent N+1 queries - Optimize get_latest_training_metrics using SQLite window functions - Update unit tests to achieve 100% statement and branch coverage --- src/blacki/workouts/storage.py | 99 +++++++++++++-------- src/blacki/workouts/tools.py | 24 +++++ tests/workouts/test_training.py | 151 ++++++++++++++++++++++++++++++-- 3 files changed, 230 insertions(+), 44 deletions(-) diff --git a/src/blacki/workouts/storage.py b/src/blacki/workouts/storage.py index bc9dc8c..5ecc6ac 100644 --- a/src/blacki/workouts/storage.py +++ b/src/blacki/workouts/storage.py @@ -659,22 +659,22 @@ async def advance_training_state( self, user_id: str, days: int, updated_at: str ) -> TrainingProgramState | None: """Advance the active program pointer by a number of calendar days.""" - program = await self.get_active_training_program(user_id) - if program is None or program.state is None or program.id is None: - return None - - current_day = program.state.current_cycle_day - current_week = program.state.current_mesocycle_week - for _ in range(days): - if current_day % 7 == 0: - current_week += 1 - if current_week > program.deload_week_interval: - current_week = 1 - current_day = ( - current_day + 1 if current_day < program.cycle_length_days else 1 - ) - async with self._lock: + program = await self.get_active_training_program(user_id) + if program is None or program.state is None or program.id is None: + return None + + current_day = program.state.current_cycle_day + current_week = program.state.current_mesocycle_week + for _ in range(days): + if current_day % 7 == 0: + current_week += 1 + if current_week > program.deload_week_interval: + current_week = 1 + current_day = ( + current_day + 1 if current_day < program.cycle_length_days else 1 + ) + await self._conn.execute( """ UPDATE training_program_state @@ -684,13 +684,13 @@ async def advance_training_state( (current_day, current_week, updated_at, user_id), ) - return TrainingProgramState( - user_id=user_id, - program_id=program.id, - current_cycle_day=current_day, - current_mesocycle_week=current_week, - updated_at=updated_at, - ) + return TrainingProgramState( + user_id=user_id, + program_id=program.id, + current_cycle_day=current_day, + current_mesocycle_week=current_week, + updated_at=updated_at, + ) async def get_training_history( self, @@ -726,11 +726,34 @@ async def get_training_history( LIMIT ? """ # noqa: S608 rows = await self._fetch_all(query, tuple(values)) - sessions = [] - for row in rows: - session = await self.get_session(int(row["id"]), user_id) - if session is not None: - sessions.append(session) + if not rows: + return [] + + sessions = [self._row_to_session(row) for row in rows] + session_ids = [s.id for s in sessions if s.id is not None] + + if session_ids: + placeholders = ", ".join("?" for _ in session_ids) + ex_rows = await self._fetch_all( + f""" + SELECT * FROM workout_exercises + WHERE session_id IN ({placeholders}) + ORDER BY exercise_order ASC, id ASC + """, # noqa: S608 + tuple(session_ids), + ) + exercises_by_session: dict[int, list[WorkoutExercise]] = { + sid: [] for sid in session_ids + } + for r in ex_rows: + sid = int(r["session_id"]) + if sid in exercises_by_session: + exercises_by_session[sid].append(self._row_to_exercise(r)) + + for session in sessions: + if session.id in exercises_by_session: + session.exercises = exercises_by_session[session.id] + return sessions async def add_training_metrics(self, metrics: list[TrainingMetric]) -> list[int]: @@ -771,18 +794,22 @@ async def get_latest_training_metrics( rows = await self._fetch_all( f""" - SELECT * FROM training_metrics - {where} - ORDER BY metric_name ASC, recorded_at DESC, id DESC + WITH ranked_metrics AS ( + SELECT *, + ROW_NUMBER() OVER ( + PARTITION BY metric_name + ORDER BY recorded_at DESC, id DESC + ) as rn + FROM training_metrics + {where} + ) + SELECT * FROM ranked_metrics + WHERE rn = 1 + ORDER BY metric_name ASC """, # noqa: S608 tuple(values), ) - latest_by_name: dict[str, TrainingMetric] = {} - for row in rows: - name = row["metric_name"] - if name not in latest_by_name: - latest_by_name[name] = self._row_to_training_metric(row) - return list(latest_by_name.values()) + return [self._row_to_training_metric(row) for row in rows] async def get_exercise_history( self, user_id: str, exercise_name: str, limit: int = 8 diff --git a/src/blacki/workouts/tools.py b/src/blacki/workouts/tools.py index 5a3736a..1bfd54c 100644 --- a/src/blacki/workouts/tools.py +++ b/src/blacki/workouts/tools.py @@ -30,8 +30,14 @@ def _parse_workout_exercises( if not exercises: return [], None + if not isinstance(exercises, list): + return [], "Exercises must be a list of dictionaries" # type: ignore[unreachable] + parsed_exercises = [] for i, ex_dict in enumerate(exercises): + if not isinstance(ex_dict, dict): + return [], "Each exercise item must be a dictionary" # type: ignore[unreachable] + if "name" not in ex_dict or "sets" not in ex_dict: return [], "Each exercise must have 'name' and 'sets' keys" @@ -45,6 +51,9 @@ def _parse_workout_exercises( elif isinstance(sets_data, dict): sets_list = [sets_data] elif isinstance(sets_data, list): + for s in sets_data: + if not isinstance(s, dict): + return [], "Each set item in sets list must be a dictionary" sets_list = sets_data else: return [], "'sets' must be a list of dictionaries or an integer" @@ -93,6 +102,9 @@ def _infer_metric_unit(metric_name: str) -> str: def _parse_training_metrics( user_id: str, metrics: dict[str, Any], recorded_at: str ) -> tuple[list[TrainingMetric], str | None]: + if not isinstance(metrics, dict): + return [], "metrics must be a dictionary" # type: ignore[unreachable] + if not metrics: return [], "metrics cannot be empty" @@ -230,6 +242,12 @@ async def set_training_program( if not user_id: return {"status": "error", "message": "Missing user_id in tool_context"} + if not isinstance(program_config, dict): + return {"status": "error", "message": "program_config must be a dictionary"} # type: ignore[unreachable] + + if baseline_metrics is not None and not isinstance(baseline_metrics, dict): + return {"status": "error", "message": "baseline_metrics must be a dictionary"} # type: ignore[unreachable] + days_config = program_config.get("days") if not isinstance(days_config, list) or not days_config: return {"status": "error", "message": "program_config.days must be a list"} @@ -439,6 +457,12 @@ async def log_training( "message": f"completion_status must be one of {allowed_statuses}", } + if metrics is not None and not isinstance(metrics, dict): + return { # type: ignore[unreachable] + "status": "error", + "message": "metrics must be a dictionary", + } + parsed_exercises, parse_error = _parse_workout_exercises(exercises) if parse_error: return {"status": "error", "message": parse_error} diff --git a/tests/workouts/test_training.py b/tests/workouts/test_training.py index c686d0a..5b8c490 100644 --- a/tests/workouts/test_training.py +++ b/tests/workouts/test_training.py @@ -1022,11 +1022,34 @@ async def test_storage_ensure_columns_actually_upgrades(self, conn, lock) -> Non assert "session_type" in existing @pytest.mark.asyncio - async def test_storage_get_training_history_true_and_false_branches( + async def test_storage_get_training_history_no_session_ids(self, storage) -> None: + """Ensure get_training_history handles sessions with None IDs correctly.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-06-07", + split_name="Legs", + created_at="2026-06-07T12:00:00", + cycle_day=1, + session_type="resistance", + ) + await storage.create_session(session) + original_row_to_session = storage._row_to_session + + def mock_row_to_session(row): + s = original_row_to_session(row) + s.id = None + return s + + with patch.object(storage, "_row_to_session", side_effect=mock_row_to_session): + sessions = await storage.get_training_history("user1", limit=1) + assert len(sessions) == 1 + assert sessions[0].id is None + + @pytest.mark.asyncio + async def test_storage_get_training_history_mismatched_exercise_session_id( self, storage ) -> None: - """Ensure get_training_history covers both returning matched sessions and omitting None values.""" - # Add session + """Ensure get_training_history filters out exercises with mismatched session_ids.""" session = WorkoutSession( user_id="user1", workout_date="2026-06-07", @@ -1036,13 +1059,125 @@ async def test_storage_get_training_history_true_and_false_branches( session_type="resistance", ) await storage.create_session(session) + original_fetch_all = storage._fetch_all + + async def mock_fetch_all(query, values): + if "workout_exercises" in query: + return [ + { + "id": 99, + "session_id": 9999, + "exercise_name": "bench press", + "exercise_order": 0, + "sets": "[]", + } + ] + return await original_fetch_all(query, values) + + with patch.object(storage, "_fetch_all", side_effect=mock_fetch_all): + sessions = await storage.get_training_history("user1", limit=1) + assert len(sessions) == 1 + assert len(sessions[0].exercises) == 0 - # Query with user1 (hits True branch of if session is not None) - assert len(await storage.get_training_history("user1", limit=1)) == 1 + @pytest.mark.asyncio + async def test_storage_get_training_history_session_id_not_in_dict( + self, storage + ) -> None: + """Cover session.id not in exercises_by_session dictionary.""" + session = WorkoutSession( + user_id="user1", + workout_date="2026-06-07", + split_name="Legs", + created_at="2026-06-07T12:00:00", + cycle_day=1, + session_type="resistance", + ) + await storage.create_session(session) + original_row_to_session = storage._row_to_session + sessions_ref = [] + + def mock_row_to_session(row): + s = original_row_to_session(row) + sessions_ref.append(s) + return s + + original_fetch_all = storage._fetch_all + + async def mock_fetch_all(query, values): + if "workout_exercises" in query: + for s in sessions_ref: + s.id = 9999 + return [] + return await original_fetch_all(query, values) + + with ( + patch.object(storage, "_row_to_session", side_effect=mock_row_to_session), + patch.object(storage, "_fetch_all", side_effect=mock_fetch_all), + ): + await storage.get_training_history("user1", limit=1) - # Query with user1 but mock get_session to return None (hits False branch of if session is not None) - with patch.object(storage, "get_session", return_value=None): - assert len(await storage.get_training_history("user1", limit=1)) == 0 + @pytest.mark.asyncio + @patch("blacki.workouts.tools.get_storage") + async def test_defensive_type_validations( + self, mock_get_storage, mock_tool_context + ) -> None: + """Test defensive type checking added for LLM-provided arguments in tools.""" + from blacki.workouts.tools import ( + log_training, + set_training_program, + update_training_metrics, + ) + + # 1. exercises is not a list in log_training + res = await log_training( + mock_tool_context, + "resistance", + exercises="not a list", # type: ignore[arg-type] + ) + assert res["status"] == "error" + assert "must be a list" in res["message"] + + # 2. exercise item is not a dict in log_training + res = await log_training( + mock_tool_context, + "resistance", + exercises=["not a dict"], # type: ignore[list-item] + ) + assert res["status"] == "error" + assert "must be a dictionary" in res["message"] + + # 3. set detail is not a dict in sets list in log_training + res = await log_training( + mock_tool_context, + "resistance", + exercises=[{"name": "bench press", "sets": ["not a dict"]}], + ) + assert res["status"] == "error" + assert "must be a dictionary" in res["message"] + + # 4. metrics in log_training is not a dict + res = await log_training(mock_tool_context, "resistance", metrics="not a dict") # type: ignore[arg-type] + assert res["status"] == "error" + assert "must be a dictionary" in res["message"] + + # 5. metrics in update_training_metrics is not a dict + res = await update_training_metrics(mock_tool_context, "not a dict") # type: ignore[arg-type] + assert res["status"] == "error" + assert "must be a dictionary" in res["message"] + + # 6. program_config in set_training_program is not a dict + res = await set_training_program(mock_tool_context, "not a dict") # type: ignore[arg-type] + assert res["status"] == "error" + assert "must be a dictionary" in res["message"] + + # 7. baseline_metrics in set_training_program is not a dict + res = await set_training_program( + mock_tool_context, + {}, + baseline_metrics="not a dict", # type: ignore[arg-type] + ) + assert res["status"] == "error" + assert "must be a dictionary" in res["message"] @pytest.mark.asyncio @patch("blacki.workouts.tools.get_storage")