From a7574e246c28825635a82389794e861af3b29dc7 Mon Sep 17 00:00:00 2001 From: jakeross Date: Thu, 5 Feb 2026 23:14:32 +1100 Subject: [PATCH 1/4] feat: add data migration to move NMA location notes to Notes table --- AGENTS.MD | 6 +- cli/cli.py | 196 ++++++++++++++---- data_migrations/__init__.py | 1 + data_migrations/base.py | 29 +++ .../20260205_0001_move_nma_location_notes.py | 94 +++++++++ data_migrations/migrations/__init__.py | 1 + data_migrations/migrations/_template.py | 38 ++++ data_migrations/registry.py | 59 ++++++ data_migrations/runner.py | 189 +++++++++++++++++ pyproject.toml | 1 + tests/test_cli_commands.py | 8 +- tests/test_data_migrations.py | 107 ++++++++++ tests/test_data_migrations_cli.py | 93 +++++++++ tests/test_thing_transfer.py | 52 +++++ uv.lock | 62 +++++- 15 files changed, 891 insertions(+), 45 deletions(-) create mode 100644 data_migrations/__init__.py create mode 100644 data_migrations/base.py create mode 100644 data_migrations/migrations/20260205_0001_move_nma_location_notes.py create mode 100644 data_migrations/migrations/__init__.py create mode 100644 data_migrations/migrations/_template.py create mode 100644 data_migrations/registry.py create mode 100644 data_migrations/runner.py create mode 100644 tests/test_data_migrations.py create mode 100644 tests/test_data_migrations_cli.py create mode 100644 tests/test_thing_transfer.py diff --git a/AGENTS.MD b/AGENTS.MD index a25a6021..ae0bc08d 100644 --- a/AGENTS.MD +++ b/AGENTS.MD @@ -21,7 +21,11 @@ these transfers, keep the following rules in mind to avoid hour-long runs: right instance before running destructive suites. - When done, `deactivate` to exit the venv and avoid polluting other shells. +## 3. Data migrations must be idempotent +- Data migrations should be safe to re-run without creating duplicate rows or corrupting data. +- Use upserts or duplicate checks and update source fields only after successful inserts. + Following this playbook keeps ETL runs measured in seconds/minutes instead of hours. EOF ## Activate python venv -Always use `source .venv/bin/activate` to activate the venv running python \ No newline at end of file +Always use `source .venv/bin/activate` to activate the venv running python diff --git a/cli/cli.py b/cli/cli.py index 50625434..bad3b720 100644 --- a/cli/cli.py +++ b/cli/cli.py @@ -13,42 +13,52 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# =============================================================================== -import click +from pathlib import Path + +import typer from dotenv import load_dotenv load_dotenv() - -@click.group() -def cli(): - """Command line interface for managing the application.""" - pass +cli = typer.Typer(help="Command line interface for managing the application.") +water_levels = typer.Typer(help="Water-level utilities") +data_migrations = typer.Typer(help="Data migration utilities") +cli.add_typer(water_levels, name="water-levels") +cli.add_typer(data_migrations, name="data-migrations") -@cli.command() +@cli.command("initialize-lexicon") def initialize_lexicon(): from core.initializers import init_lexicon init_lexicon() -@cli.command() -@click.argument( - "root_directory", - type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True), -) -def associate_assets_command(root_directory: str): +@cli.command("associate-assets") +def associate_assets_command( + root_directory: str = typer.Argument( + ..., + exists=True, + file_okay=False, + dir_okay=True, + readable=True, + ) +): from cli.service_adapter import associate_assets associate_assets(root_directory) -@cli.command() -@click.argument( - "file_path", - type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True), -) -def well_inventory_csv(file_path: str): +@cli.command("well-inventory-csv") +def well_inventory_csv( + file_path: str = typer.Argument( + ..., + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + ) +): """ parse and upload a csv to database """ @@ -58,28 +68,24 @@ def well_inventory_csv(file_path: str): well_inventory_csv(file_path) -@cli.group() -def water_levels(): - """Water-level utilities""" - pass - - @water_levels.command("bulk-upload") -@click.option( - "--file", - "file_path", - type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True), - required=True, - help="Path to CSV file containing water level rows", -) -@click.option( - "--output", - "output_format", - type=click.Choice(["json"], case_sensitive=False), - default=None, - help="Optional output format", -) -def water_levels_bulk_upload(file_path: str, output_format: str | None): +def water_levels_bulk_upload( + file_path: str = typer.Option( + ..., + "--file", + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + help="Path to CSV file containing water level rows", + ), + output_format: str | None = typer.Option( + None, + "--output", + case_sensitive=False, + help="Optional output format", + ), +): """ parse and upload a csv """ @@ -90,6 +96,116 @@ def water_levels_bulk_upload(file_path: str, output_format: str | None): water_levels_csv(file_path, pretty_json=pretty_json) +@data_migrations.command("list") +def data_migrations_list(): + from data_migrations.registry import list_migrations + + migrations = list_migrations() + if not migrations: + typer.echo("No data migrations registered.") + return + for migration in migrations: + repeatable = " (repeatable)" if migration.is_repeatable else "" + typer.echo(f"{migration.id}: {migration.name}{repeatable}") + + +@data_migrations.command("status") +def data_migrations_status(): + from db.engine import session_ctx + from data_migrations.runner import get_status + + with session_ctx() as session: + statuses = get_status(session) + if not statuses: + typer.echo("No data migrations registered.") + return + for status in statuses: + last_applied = ( + status.last_applied_at.isoformat() if status.last_applied_at else "never" + ) + typer.echo( + 
f"{status.id}: applied {status.applied_count} time(s), last={last_applied}" + ) + + +@data_migrations.command("run") +def data_migrations_run( + migration_id: str = typer.Argument(...), + force: bool = typer.Option( + False, "--force", help="Re-run even if already applied." + ), +): + from db.engine import session_ctx + from data_migrations.runner import run_migration_by_id + + with session_ctx() as session: + ran = run_migration_by_id(session, migration_id, force=force) + typer.echo("applied" if ran else "skipped") + + +@data_migrations.command("run-all") +def data_migrations_run_all( + include_repeatable: bool = typer.Option( + False, + "--include-repeatable/--exclude-repeatable", + help="Whether to include repeatable migrations.", + ), + force: bool = typer.Option( + False, "--force", help="Re-run non-repeatable migrations." + ), +): + from db.engine import session_ctx + from data_migrations.runner import run_all + + with session_ctx() as session: + ran = run_all(session, include_repeatable=include_repeatable, force=force) + typer.echo(f"applied {len(ran)} migration(s)") + + +@cli.command("alembic-upgrade-and-data") +def alembic_upgrade_and_data( + revision: str = typer.Argument("head"), + include_repeatable: bool = typer.Option( + False, + "--include-repeatable/--exclude-repeatable", + help="Whether to include repeatable migrations.", + ), + force: bool = typer.Option( + False, "--force", help="Re-run non-repeatable migrations." + ), +): + from alembic import command + from alembic.config import Config + from alembic.runtime.migration import MigrationContext + from alembic.script import ScriptDirectory + from db.engine import engine, session_ctx + from data_migrations.runner import run_all + + root = Path(__file__).resolve().parents[1] + cfg = Config(str(root / "alembic.ini")) + cfg.set_main_option("script_location", str(root / "alembic")) + + command.upgrade(cfg, revision) + + with engine.connect() as conn: + context = MigrationContext.configure(conn) + heads = context.get_current_heads() + script = ScriptDirectory.from_config(cfg) + applied_revisions: set[str] = set() + for head in heads: + for rev in script.iterate_revisions(head, "base"): + applied_revisions.add(rev.revision) + + with session_ctx() as session: + ran = run_all( + session, + include_repeatable=include_repeatable, + force=force, + allowed_alembic_revisions=applied_revisions, + ) + typer.echo(f"applied {len(ran)} migration(s)") + + if __name__ == "__main__": cli() diff --git a/data_migrations/__init__.py b/data_migrations/__init__.py new file mode 100644 index 00000000..2f8d062a --- /dev/null +++ b/data_migrations/__init__.py @@ -0,0 +1 @@ +# Data migrations package diff --git a/data_migrations/base.py b/data_migrations/base.py new file mode 100644 index 00000000..89cc24f3 --- /dev/null +++ b/data_migrations/base.py @@ -0,0 +1,29 @@ +# =============================================================================== +# Copyright 2026 ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =============================================================================== +from dataclasses import dataclass +from typing import Callable + +from sqlalchemy.orm import Session + + +@dataclass(frozen=True) +class DataMigration: + id: str + alembic_revision: str + name: str + description: str + run: Callable[[Session], None] + is_repeatable: bool = False diff --git a/data_migrations/migrations/20260205_0001_move_nma_location_notes.py b/data_migrations/migrations/20260205_0001_move_nma_location_notes.py new file mode 100644 index 00000000..6261ca12 --- /dev/null +++ b/data_migrations/migrations/20260205_0001_move_nma_location_notes.py @@ -0,0 +1,94 @@ +# =============================================================================== +# Copyright 2026 ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +from sqlalchemy import insert, select, update +from sqlalchemy.orm import Session + +from data_migrations.base import DataMigration +from db.location import Location +from db.notes import Notes + +NOTE_TYPE = "General" +BATCH_SIZE = 1000 + + +def _iter_location_notes(session: Session): + stmt = select( + Location.id, + Location.nma_location_notes, + Location.release_status, + ).where(Location.nma_location_notes.isnot(None)) + for row in session.execute(stmt): + note = (row.nma_location_notes or "").strip() + if not note: + continue + yield row.id, note, row.release_status + + +def run(session: Session) -> None: + buffer: list[tuple[int, str, str]] = [] + for item in _iter_location_notes(session): + buffer.append(item) + if len(buffer) >= BATCH_SIZE: + _flush_batch(session, buffer) + buffer.clear() + if buffer: + _flush_batch(session, buffer) + + +def _flush_batch(session: Session, batch: list[tuple[int, str, str]]) -> None: + location_ids = [row[0] for row in batch] + existing = session.execute( + select(Notes.target_id, Notes.content).where( + Notes.target_table == "location", + Notes.note_type == NOTE_TYPE, + Notes.target_id.in_(location_ids), + ) + ).all() + existing_set = {(row.target_id, row.content) for row in existing} + + inserts = [] + for location_id, note, release_status in batch: + if (location_id, note) in existing_set: + continue + inserts.append( + { + "target_id": location_id, + "target_table": "location", + "note_type": NOTE_TYPE, + "content": note, + "release_status": release_status or "draft", + } + ) + + if inserts: + session.execute(insert(Notes), inserts) + + session.execute( + update(Location) + .where(Location.id.in_(location_ids)) + .values(nma_location_notes=None) + ) + session.commit() + + +MIGRATION = DataMigration( + id="20260205_0001_move_nma_location_notes", + alembic_revision="f0c9d8e7b6a5", + name="Move NMA location notes to Notes table", + description="Backfill polymorphic notes from Location.nma_location_notes.", + run=run, + is_repeatable=False, +) diff --git a/data_migrations/migrations/__init__.py b/data_migrations/migrations/__init__.py new file mode 100644 index 
00000000..5c91fffc --- /dev/null +++ b/data_migrations/migrations/__init__.py @@ -0,0 +1 @@ +# Data migrations live here. diff --git a/data_migrations/migrations/_template.py b/data_migrations/migrations/_template.py new file mode 100644 index 00000000..bec1295d --- /dev/null +++ b/data_migrations/migrations/_template.py @@ -0,0 +1,38 @@ +# =============================================================================== +# Copyright 2026 ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +from sqlalchemy.orm import Session + +from data_migrations.base import DataMigration + + +def run(session: Session) -> None: + """ + Implement migration logic here. + + Use SQLAlchemy core for large batches: + session.execute(insert(Model), rows) + """ + return None + + +MIGRATION = DataMigration( + id="YYYYMMDD_0000", + alembic_revision="REVISION_ID", + name="Short migration name", + description="Why this data migration exists.", + run=run, + is_repeatable=False, +) diff --git a/data_migrations/registry.py b/data_migrations/registry.py new file mode 100644 index 00000000..27dc4cc4 --- /dev/null +++ b/data_migrations/registry.py @@ -0,0 +1,59 @@ +# =============================================================================== +# Copyright 2026 ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =============================================================================== +from __future__ import annotations + +import importlib +import pkgutil + +from data_migrations.base import DataMigration + + +def _discover_migration_modules() -> list[str]: + base_pkg = __name__.rsplit(".", 1)[0] + migrations_pkg = f"{base_pkg}.migrations" + try: + package = importlib.import_module(migrations_pkg) + except ModuleNotFoundError: + return [] + package_paths = list(getattr(package, "__path__", [])) + modules: list[str] = [] + for module_info in pkgutil.iter_modules(package_paths): + if module_info.ispkg: + continue + if module_info.name.startswith("_"): + continue + modules.append(f"{migrations_pkg}.{module_info.name}") + return modules + + +def list_migrations() -> list[DataMigration]: + migrations: list[DataMigration] = [] + for module_path in _discover_migration_modules(): + module = importlib.import_module(module_path) + migration = getattr(module, "MIGRATION", None) + if migration is None: + continue + if not isinstance(migration, DataMigration): + raise TypeError(f"{module_path}.MIGRATION must be a DataMigration instance") + migrations.append(migration) + return migrations + + +def get_migration(migration_id: str) -> DataMigration | None: + for migration in list_migrations(): + if migration.id == migration_id: + return migration + return None diff --git a/data_migrations/runner.py b/data_migrations/runner.py new file mode 100644 index 00000000..effc1922 --- /dev/null +++ b/data_migrations/runner.py @@ -0,0 +1,189 @@ +# =============================================================================== +# Copyright 2026 ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =============================================================================== +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + MetaData, + String, + Table, + func, + select, + text, +) +from sqlalchemy.orm import Session + +from data_migrations.base import DataMigration +from data_migrations.registry import get_migration, list_migrations +from transfers.logger import logger + +metadata = MetaData() +data_migration_history = Table( + "data_migration_history", + metadata, + Column("id", String(100), nullable=False), + Column("alembic_revision", String(100), nullable=False), + Column("name", String(255), nullable=False), + Column("is_repeatable", Boolean, nullable=False, default=False), + Column("applied_at", DateTime(timezone=True), nullable=False), + Column("checksum", String(64), nullable=True), +) + + +@dataclass(frozen=True) +class MigrationStatus: + id: str + alembic_revision: str + name: str + is_repeatable: bool + applied_count: int + last_applied_at: datetime | None + + +def ensure_history_table(session: Session) -> None: + metadata.create_all(bind=session.get_bind(), tables=[data_migration_history]) + + +def _applied_counts(session: Session) -> dict[str, int]: + stmt = select(data_migration_history.c.id, func.count().label("count")).group_by( + data_migration_history.c.id + ) + return {row.id: int(row.count) for row in session.execute(stmt).all()} + + +def _last_applied_map(session: Session) -> dict[str, datetime]: + stmt = select( + data_migration_history.c.id, + func.max(data_migration_history.c.applied_at).label("last_applied_at"), + ).group_by(data_migration_history.c.id) + return {row.id: row.last_applied_at for row in session.execute(stmt).all()} + + +def get_status(session: Session) -> list[MigrationStatus]: + ensure_history_table(session) + applied_counts = _applied_counts(session) + last_applied = _last_applied_map(session) + statuses = [] + for migration in list_migrations(): + statuses.append( + MigrationStatus( + id=migration.id, + alembic_revision=migration.alembic_revision, + name=migration.name, + is_repeatable=migration.is_repeatable, + applied_count=applied_counts.get(migration.id, 0), + last_applied_at=last_applied.get(migration.id), + ) + ) + return statuses + + +def _record_migration(session: Session, migration: DataMigration) -> None: + session.execute( + data_migration_history.insert().values( + id=migration.id, + alembic_revision=migration.alembic_revision, + name=migration.name, + is_repeatable=bool(migration.is_repeatable), + applied_at=datetime.now(tz=timezone.utc), + ) + ) + + +def _is_applied(session: Session, migration: DataMigration) -> bool: + stmt = ( + select(func.count()) + .select_from(data_migration_history) + .where(data_migration_history.c.id == migration.id) + ) + return session.execute(stmt).scalar_one() > 0 + + +def _ensure_alembic_applied(session: Session, migration: DataMigration) -> None: + count = session.execute( + text("SELECT COUNT(*) FROM alembic_version WHERE version_num = :rev"), + {"rev": migration.alembic_revision}, + ).scalar_one() + if count == 0: + raise ValueError( + f"Alembic revision {migration.alembic_revision} not applied for " + f"data migration {migration.id}" + ) + + +def run_migration( + session: Session, + migration: DataMigration, + *, + force: bool = False, +) -> bool: + ensure_history_table(session) + _ensure_alembic_applied(session, migration) + + if not migration.is_repeatable and 
not force and _is_applied(session, migration): + logger.info("Skipping data migration %s (already applied)", migration.id) + return False + + logger.info("Running data migration %s - %s", migration.id, migration.name) + migration.run(session) + _record_migration(session, migration) + session.commit() + return True + + +def run_migration_by_id( + session: Session, migration_id: str, *, force: bool = False +) -> bool: + migration = get_migration(migration_id) + if migration is None: + raise ValueError(f"Unknown data migration: {migration_id}") + return run_migration(session, migration, force=force) + + +def run_all( + session: Session, + *, + include_repeatable: bool = False, + force: bool = False, + allowed_alembic_revisions: set[str] | None = None, +) -> list[str]: + ran = [] + for migration in list_migrations(): + if ( + allowed_alembic_revisions is not None + and migration.alembic_revision not in allowed_alembic_revisions + ): + logger.info( + "Skipping data migration %s (alembic revision %s not applied)", + migration.id, + migration.alembic_revision, + ) + continue + _ensure_alembic_applied(session, migration) + if migration.is_repeatable and not include_repeatable: + logger.info( + "Skipping repeatable migration %s (include_repeatable=false)", + migration.id, + ) + continue + if run_migration(session, migration, force=force): + ran.append(migration.id) + return ran diff --git a/pyproject.toml b/pyproject.toml index 22539c00..0110f976 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ dependencies = [ "sqlalchemy-utils==0.42.0", "starlette==0.49.1", "starlette-admin[i18n]>=0.16.0", + "typer>=0.21.1", "typing-extensions==4.15.0", "typing-inspection==0.4.1", "tzdata==2025.2", diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index d31b0bea..ab4dfa9a 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -19,8 +19,8 @@ import uuid from pathlib import Path -from click.testing import CliRunner from sqlalchemy import select +from typer.testing import CliRunner from cli.cli import cli from db import FieldActivity, FieldEvent, Observation, Sample @@ -138,10 +138,12 @@ def test_water_levels_cli_persists_observations(tmp_path, water_well_thing): """ def _write_csv(path: Path, *, well_name: str, notes: str): - csv_text = textwrap.dedent(f"""\ + csv_text = textwrap.dedent( + f"""\ field_staff,well_name_point_id,field_event_date_time,measurement_date_time,sampler,sample_method,mp_height,level_status,depth_to_water_ft,data_quality,water_level_notes CLI Tester,{well_name},2025-02-15T08:00:00-07:00,2025-02-15T10:30:00-07:00,Groundwater Team,electric tape,1.5,stable,42.5,approved,{notes} - """) + """ + ) path.write_text(csv_text) unique_notes = f"pytest-{uuid.uuid4()}" diff --git a/tests/test_data_migrations.py b/tests/test_data_migrations.py new file mode 100644 index 00000000..3b0ce521 --- /dev/null +++ b/tests/test_data_migrations.py @@ -0,0 +1,107 @@ +# =============================================================================== +# Copyright 2026 ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +import importlib + +from sqlalchemy import select + +move_notes = importlib.import_module( + "data_migrations.migrations.20260205_0001_move_nma_location_notes" +) +from db.location import Location +from db.notes import Notes +from db.engine import session_ctx + + +def test_move_nma_location_notes_creates_notes_and_clears_field(): + with session_ctx() as session: + location = Location( + point="POINT (10.2 10.2)", + elevation=0, + release_status="public", + nma_location_notes="Legacy location note", + ) + session.add(location) + session.commit() + session.refresh(location) + + move_notes.run(session) + + notes = ( + session.execute( + select(Notes).where( + Notes.target_table == "location", + Notes.target_id == location.id, + ) + ) + .scalars() + .all() + ) + assert len(notes) == 1 + assert notes[0].content == "Legacy location note" + assert notes[0].note_type == "General" + assert notes[0].release_status == "public" + + session.refresh(location) + assert location.nma_location_notes is None + + session.delete(notes[0]) + session.delete(location) + session.commit() + + +def test_move_nma_location_notes_skips_duplicates(): + with session_ctx() as session: + location = Location( + point="POINT (10.4 10.4)", + elevation=1.0, + release_status="draft", + nma_location_notes="Duplicate note", + ) + session.add(location) + session.commit() + session.refresh(location) + + existing = Notes( + target_id=location.id, + target_table="location", + note_type="General", + content="Duplicate note", + release_status="draft", + ) + session.add(existing) + session.commit() + + move_notes.run(session) + + notes = ( + session.execute( + select(Notes).where( + Notes.target_table == "location", + Notes.target_id == location.id, + Notes.note_type == "General", + ) + ) + .scalars() + .all() + ) + assert len(notes) == 1 + + session.refresh(location) + assert location.nma_location_notes is None + + session.delete(notes[0]) + session.delete(location) + session.commit() diff --git a/tests/test_data_migrations_cli.py b/tests/test_data_migrations_cli.py new file mode 100644 index 00000000..56a19c73 --- /dev/null +++ b/tests/test_data_migrations_cli.py @@ -0,0 +1,93 @@ +# =============================================================================== +# Copyright 2026 ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =============================================================================== +from __future__ import annotations + +from contextlib import contextmanager + +from typer.testing import CliRunner + +from cli.cli import cli +from data_migrations.base import DataMigration + + +@contextmanager +def _fake_session_ctx(): + yield object() + + +def test_data_migrations_list_empty(monkeypatch): + monkeypatch.setattr("data_migrations.registry.list_migrations", lambda: []) + runner = CliRunner() + result = runner.invoke(cli, ["data-migrations", "list"]) + assert result.exit_code == 0 + assert "No data migrations registered" in result.output + + +def test_data_migrations_list_non_empty(monkeypatch): + migrations = [ + DataMigration( + id="20260205_0001", + alembic_revision="000000000000", + name="Backfill Example", + description="Example", + run=lambda session: None, + ) + ] + monkeypatch.setattr("data_migrations.registry.list_migrations", lambda: migrations) + runner = CliRunner() + result = runner.invoke(cli, ["data-migrations", "list"]) + assert result.exit_code == 0 + assert "20260205_0001: Backfill Example" in result.output + + +def test_data_migrations_run_invokes_runner(monkeypatch): + monkeypatch.setattr("db.engine.session_ctx", _fake_session_ctx) + + called = {} + + def fake_run(session, migration_id, force=False): + called["migration_id"] = migration_id + called["force"] = force + return True + + monkeypatch.setattr("data_migrations.runner.run_migration_by_id", fake_run) + + runner = CliRunner() + result = runner.invoke(cli, ["data-migrations", "run", "20260205_0001"]) + + assert result.exit_code == 0 + assert called == {"migration_id": "20260205_0001", "force": False} + assert "applied" in result.output + + +def test_data_migrations_run_all_invokes_runner(monkeypatch): + monkeypatch.setattr("db.engine.session_ctx", _fake_session_ctx) + + called = {} + + def fake_run_all(session, include_repeatable=False, force=False): + called["include_repeatable"] = include_repeatable + called["force"] = force + return ["20260205_0001"] + + monkeypatch.setattr("data_migrations.runner.run_all", fake_run_all) + + runner = CliRunner() + result = runner.invoke(cli, ["data-migrations", "run-all", "--include-repeatable"]) + + assert result.exit_code == 0 + assert called == {"include_repeatable": True, "force": False} + assert "applied 1 migration(s)" in result.output diff --git a/tests/test_thing_transfer.py b/tests/test_thing_transfer.py new file mode 100644 index 00000000..7c5e39c2 --- /dev/null +++ b/tests/test_thing_transfer.py @@ -0,0 +1,52 @@ +import pytest + +from transfers import thing_transfer as tt + + +@pytest.mark.parametrize( + "func_name,site_code,thing_type", + [ + ("transfer_rock_sample_locations", "R", "Rock sample location"), + ( + "transfer_diversion_of_surface_water", + "D", + "Diversion of surface water, etc.", + ), + ("transfer_lake_pond_reservoir", "L", "Lake, pond or reservoir"), + ("transfer_soil_gas_sample_locations", "S", "Soil gas sample location"), + ("transfer_other_site_types", "OT", "Other"), + ( + "transfer_outfall_wastewater_return_flow", + "O", + "Outfall of wastewater or return flow", + ), + ], +) +def test_transfer_new_site_types_calls_transfer_thing( + monkeypatch, func_name, site_code, thing_type +): + calls = [] + + def fake_transfer_thing(session, site_type, make_payload, limit=None): + class Row: + PointID = "PT-1" + PublicRelease = False + + payload = make_payload(Row) + calls.append((site_type, payload, limit)) + + monkeypatch.setattr(tt, "transfer_thing", 
fake_transfer_thing) + + getattr(tt, func_name)(session=None, limit=7) + + assert calls == [ + ( + site_code, + { + "name": "PT-1", + "thing_type": thing_type, + "release_status": "private", + }, + 7, + ) + ] diff --git a/uv.lock b/uv.lock index 67ea6ae0..b1a47719 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.13" [[package]] @@ -874,6 +874,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, +] + [[package]] name = "markupsafe" version = "3.0.2" @@ -902,6 +914,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739, upload-time = "2024-10-18T15:21:42.784Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "multidict" version = "6.6.3" @@ -1101,6 +1122,7 @@ dependencies = [ { name = "sqlalchemy-utils" }, { name = "starlette" }, { name = "starlette-admin", extra = ["i18n"] }, + { name = "typer" }, { name = "typing-extensions" }, { name = "typing-inspection" }, { name = "tzdata" }, @@ -1209,6 +1231,7 @@ requires-dist = [ { name = "sqlalchemy-utils", specifier = "==0.42.0" }, { name = "starlette", specifier = "==0.49.1" }, { name = "starlette-admin", extras = ["i18n"], specifier = ">=0.16.0" }, + { name = "typer", specifier = ">=0.21.1" }, { name = "typing-extensions", specifier = "==4.15.0" }, { name = "typing-inspection", specifier = "==0.4.1" }, { name = "tzdata", specifier = "==2025.2" }, @@ -1766,6 +1789,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, 
upload-time = "2025-08-18T20:46:00.542Z" }, ] +[[package]] +name = "rich" +version = "14.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/99/a4cab2acbb884f80e558b0771e97e21e939c5dfb460f488d19df485e8298/rich-14.3.2.tar.gz", hash = "sha256:e712f11c1a562a11843306f5ed999475f09ac31ffb64281f73ab29ffdda8b3b8", size = 230143, upload-time = "2026-02-01T16:20:47.908Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/45/615f5babd880b4bd7d405cc0dc348234c5ffb6ed1ea33e152ede08b2072d/rich-14.3.2-py3-none-any.whl", hash = "sha256:08e67c3e90884651da3239ea668222d19bea7b589149d8014a21c633420dbb69", size = 309963, upload-time = "2026-02-01T16:20:46.078Z" }, +] + [[package]] name = "rsa" version = "4.9.1" @@ -1835,6 +1871,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ea/f1/5e9b3ba5c7aa7ebfaf269657e728067d16a7c99401c7973ddf5f0cf121bd/shapely-2.1.1-cp313-cp313t-win_amd64.whl", hash = "sha256:8cb8f17c377260452e9d7720eeaf59082c5f8ea48cf104524d953e5d36d4bdb7", size = 1723061, upload-time = "2025-05-19T11:04:40.082Z" }, ] +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -1943,6 +1988,21 @@ i18n = [ { name = "babel" }, ] +[[package]] +name = "typer" +version = "0.21.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/36/bf/8825b5929afd84d0dabd606c67cd57b8388cb3ec385f7ef19c5cc2202069/typer-0.21.1.tar.gz", hash = "sha256:ea835607cd752343b6b2b7ce676893e5a0324082268b48f27aa058bdb7d2145d", size = 110371, upload-time = "2026-01-06T11:21:10.989Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/1d/d9257dd49ff2ca23ea5f132edf1281a0c4f9de8a762b9ae399b670a59235/typer-0.21.1-py3-none-any.whl", hash = "sha256:7985e89081c636b88d172c2ee0cfe33c253160994d47bdfdc302defd7d1f1d01", size = 47381, upload-time = "2026-01-06T11:21:09.824Z" }, +] + [[package]] name = "types-pytz" version = "2025.2.0.20250809" From 060e7d56e9cf3b685ef947c2b1831f11b4b493cf Mon Sep 17 00:00:00 2001 From: jirhiker Date: Thu, 5 Feb 2026 12:15:00 +0000 Subject: [PATCH 2/4] Formatting changes --- tests/test_cli_commands.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index ab4dfa9a..220535ae 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -138,12 +138,10 @@ def test_water_levels_cli_persists_observations(tmp_path, water_well_thing): """ def _write_csv(path: Path, *, well_name: str, notes: str): - csv_text = textwrap.dedent( - f"""\ + csv_text = textwrap.dedent(f"""\ 
field_staff,well_name_point_id,field_event_date_time,measurement_date_time,sampler,sample_method,mp_height,level_status,depth_to_water_ft,data_quality,water_level_notes CLI Tester,{well_name},2025-02-15T08:00:00-07:00,2025-02-15T10:30:00-07:00,Groundwater Team,electric tape,1.5,stable,42.5,approved,{notes} - """ - ) + """) path.write_text(csv_text) unique_notes = f"pytest-{uuid.uuid4()}" From 4a688c94c88f6a64005e03d11186ec5d4f07e671 Mon Sep 17 00:00:00 2001 From: Jake Ross Date: Thu, 5 Feb 2026 23:17:39 +1100 Subject: [PATCH 3/4] Update cli/cli.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- cli/cli.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cli/cli.py b/cli/cli.py index bad3b720..b46fedac 100644 --- a/cli/cli.py +++ b/cli/cli.py @@ -82,7 +82,6 @@ def water_levels_bulk_upload( output_format: str | None = typer.Option( None, "--output", - case_sensitive=False, help="Optional output format", ), ): From a829322854a25e83fd75dfc56e5c84f5dd89f096 Mon Sep 17 00:00:00 2001 From: jakeross Date: Thu, 5 Feb 2026 23:23:00 +1100 Subject: [PATCH 4/4] feat: enhance alembic migration handling and improve output format options --- cli/cli.py | 9 ++++++-- data_migrations/runner.py | 45 ++++++++++++++++++++++++++++-------- tests/test_thing_transfer.py | 12 +++++----- 3 files changed, 49 insertions(+), 17 deletions(-) diff --git a/cli/cli.py b/cli/cli.py index b46fedac..f003dae4 100644 --- a/cli/cli.py +++ b/cli/cli.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== +from enum import Enum from pathlib import Path import typer @@ -27,6 +28,10 @@ cli.add_typer(data_migrations, name="data-migrations") +class OutputFormat(str, Enum): + json = "json" + + @cli.command("initialize-lexicon") def initialize_lexicon(): from core.initializers import init_lexicon @@ -79,7 +84,7 @@ def water_levels_bulk_upload( readable=True, help="Path to CSV file containing water level rows", ), - output_format: str | None = typer.Option( + output_format: OutputFormat | None = typer.Option( None, "--output", help="Optional output format", @@ -91,7 +96,7 @@ def water_levels_bulk_upload( # TODO: use the same helper function used by api to parse and upload a WL csv from cli.service_adapter import water_levels_csv - pretty_json = (output_format or "").lower() == "json" + pretty_json = output_format == OutputFormat.json water_levels_csv(file_path, pretty_json=pretty_json) diff --git a/data_migrations/runner.py b/data_migrations/runner.py index effc1922..6869974d 100644 --- a/data_migrations/runner.py +++ b/data_migrations/runner.py @@ -17,7 +17,11 @@ from dataclasses import dataclass from datetime import datetime, timezone +from pathlib import Path +from alembic.config import Config +from alembic.runtime.migration import MigrationContext +from alembic.script import ScriptDirectory from sqlalchemy import ( Boolean, Column, @@ -27,7 +31,6 @@ Table, func, select, - text, ) from sqlalchemy.orm import Session @@ -117,12 +120,31 @@ def _is_applied(session: Session, migration: DataMigration) -> bool: return session.execute(stmt).scalar_one() > 0 -def _ensure_alembic_applied(session: Session, migration: DataMigration) -> None: - count = session.execute( - text("SELECT COUNT(*) FROM alembic_version WHERE version_num = :rev"), - {"rev": migration.alembic_revision}, - ).scalar_one() - if count == 0: +def _get_applied_alembic_revisions(session: Session) -> 
set[str]: + root = Path(__file__).resolve().parents[1] + cfg = Config(str(root / "alembic.ini")) + cfg.set_main_option("script_location", str(root / "alembic")) + + connection = session.connection() + context = MigrationContext.configure(connection) + heads = context.get_current_heads() + script = ScriptDirectory.from_config(cfg) + + applied: set[str] = set() + for head in heads: + for rev in script.iterate_revisions(head, "base"): + applied.add(rev.revision) + return applied + + +def _ensure_alembic_applied( + session: Session, + migration: DataMigration, + applied_revisions: set[str] | None = None, +) -> None: + if applied_revisions is None: + applied_revisions = _get_applied_alembic_revisions(session) + if migration.alembic_revision not in applied_revisions: raise ValueError( f"Alembic revision {migration.alembic_revision} not applied for " f"data migration {migration.id}" @@ -136,7 +158,8 @@ def run_migration( force: bool = False, ) -> bool: ensure_history_table(session) - _ensure_alembic_applied(session, migration) + applied_revisions = _get_applied_alembic_revisions(session) + _ensure_alembic_applied(session, migration, applied_revisions=applied_revisions) if not migration.is_repeatable and not force and _is_applied(session, migration): logger.info("Skipping data migration %s (already applied)", migration.id) @@ -165,6 +188,8 @@ def run_all( force: bool = False, allowed_alembic_revisions: set[str] | None = None, ) -> list[str]: + if allowed_alembic_revisions is None: + allowed_alembic_revisions = _get_applied_alembic_revisions(session) ran = [] for migration in list_migrations(): if ( @@ -177,7 +202,9 @@ def run_all( migration.alembic_revision, ) continue - _ensure_alembic_applied(session, migration) + _ensure_alembic_applied( + session, migration, applied_revisions=allowed_alembic_revisions + ) if migration.is_repeatable and not include_repeatable: logger.info( "Skipping repeatable migration %s (include_repeatable=false)", diff --git a/tests/test_thing_transfer.py b/tests/test_thing_transfer.py index 7c5e39c2..ea33baf7 100644 --- a/tests/test_thing_transfer.py +++ b/tests/test_thing_transfer.py @@ -6,19 +6,19 @@ @pytest.mark.parametrize( "func_name,site_code,thing_type", [ - ("transfer_rock_sample_locations", "R", "Rock sample location"), + ("transfer_rock_sample_locations", "R", "rock sample location"), ( "transfer_diversion_of_surface_water", "D", - "Diversion of surface water, etc.", + "diversion of surface water, etc.", ), - ("transfer_lake_pond_reservoir", "L", "Lake, pond or reservoir"), - ("transfer_soil_gas_sample_locations", "S", "Soil gas sample location"), - ("transfer_other_site_types", "OT", "Other"), + ("transfer_lake_pond_reservoir", "L", "lake, pond or reservoir"), + ("transfer_soil_gas_sample_locations", "S", "soil gas sample location"), + ("transfer_other_site_types", "OT", "other"), ( "transfer_outfall_wastewater_return_flow", "O", - "Outfall of wastewater or return flow", + "outfall of wastewater or return flow", ), ], )
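-- 
A minimal, self-contained sketch of the idempotent backfill pattern that
AGENTS.MD section 3 describes: duplicate-check before insert, then clear the
source field only after the inserts succeed, so re-running is a no-op. The
schema below (a `location.legacy_note` column and a `notes` table) is a
hypothetical stand-in for illustration, not the project's models; the real
migration in this series (20260205_0001_move_nma_location_notes) follows the
same structure with batching and release-status handling.

# Hypothetical schema for illustration only -- not the project's models.
from sqlalchemy import (
    Column,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
    insert,
    select,
    update,
)

metadata = MetaData()
location = Table(
    "location",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("legacy_note", String, nullable=True),
)
notes = Table(
    "notes",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("target_id", Integer, nullable=False),
    Column("content", String, nullable=False),
)


def backfill(conn) -> None:
    rows = conn.execute(
        select(location.c.id, location.c.legacy_note).where(
            location.c.legacy_note.is_not(None)
        )
    ).all()
    for loc_id, note in rows:
        # Duplicate check: a re-run is a no-op instead of a second insert.
        already = conn.execute(
            select(notes.c.id).where(
                notes.c.target_id == loc_id, notes.c.content == note
            )
        ).first()
        if already is None:
            conn.execute(insert(notes).values(target_id=loc_id, content=note))
    # Clear the source field only after the inserts above have succeeded.
    if rows:
        conn.execute(
            update(location)
            .where(location.c.id.in_([r.id for r in rows]))
            .values(legacy_note=None)
        )


engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(insert(location).values(id=1, legacy_note="hello"))
    backfill(conn)
    backfill(conn)  # second run finds nothing to insert or clear: idempotent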