From 48e4ed6c8d0a27de70d7204ee08edbe9884fbb4e Mon Sep 17 00:00:00 2001 From: Charlie Date: Fri, 6 Feb 2026 15:38:36 -0500 Subject: [PATCH 1/2] Add full run and full sample API endpoints --- sample_registry/app.py | 66 ++++++++++++++++++++++++++++++++++++ sample_registry/registrar.py | 61 ++++++++++++++++++++++++++++----- tests/test_api.py | 36 ++++++++++++++++++++ 3 files changed, 155 insertions(+), 8 deletions(-) diff --git a/sample_registry/app.py b/sample_registry/app.py index ca666f9..b9323f2 100644 --- a/sample_registry/app.py +++ b/sample_registry/app.py @@ -660,6 +660,72 @@ def api_get_annotations(): ) +@app.get("/api/get_full_run") +def api_get_full_run(): + run_accession = request.args.get("run_accession") + if not run_accession: + return api_error("Missing required query parameter: run_accession") + try: + run_accession = int(run_accession) + except ValueError as exc: + return api_error(f"Invalid run_accession value: {exc}") + + with api_registry() as registry: + full_run = registry.get_full_run(run_accession) + + if not full_run: + return jsonify({"status": "ok", "run": None, "samples": []}) + + annotations = full_run["annotations_by_sample_accession"] + return jsonify( + { + "status": "ok", + "run": api_model_to_dict(full_run["run"]), + "samples": [ + { + "sample": api_model_to_dict(sample), + "annotations": [ + api_model_to_dict(annotation) + for annotation in annotations[sample.sample_accession] + ], + } + for sample in full_run["samples"] + ], + } + ) + + +@app.get("/api/get_full_sample") +def api_get_full_sample(): + sample_accession = request.args.get("sample_accession") + if not sample_accession: + return api_error("Missing required query parameter: sample_accession") + try: + sample_accession = int(sample_accession) + except ValueError as exc: + return api_error(f"Invalid sample_accession value: {exc}") + + with api_registry() as registry: + full_sample = registry.get_full_sample(sample_accession) + + return jsonify( + { + "status": "ok", + "sample": ( + api_model_to_dict(full_sample["sample"]) if full_sample else None + ), + "annotations": ( + [ + api_model_to_dict(annotation) + for annotation in full_sample["annotations"] + ] + if full_sample + else [] + ), + } + ) + + @app.route("/description") def show_description(): return render_template("description.html") diff --git a/sample_registry/registrar.py b/sample_registry/registrar.py index 4bd11b5..f510a28 100644 --- a/sample_registry/registrar.py +++ b/sample_registry/registrar.py @@ -94,14 +94,59 @@ def get_samples(self, run_accession: int) -> list[Sample]: ).all() ) - def get_annotations(self, sample_accession: int) -> list[Annotation]: - return list( - self.session.scalars( - select(Annotation).where( - Annotation.sample_accession == sample_accession - ) - ).all() - ) + def get_annotations(self, sample_accession: int) -> list[Annotation]: + return list( + self.session.scalars( + select(Annotation).where( + Annotation.sample_accession == sample_accession + ) + ).all() + ) + + def get_full_sample(self, sample_accession: int) -> dict | None: + """Return a sample and all of its annotations.""" + + sample = self.session.scalar( + select(Sample).where(Sample.sample_accession == sample_accession) + ) + if not sample: + return None + + annotations = self.get_annotations(sample_accession) + return {"sample": sample, "annotations": annotations} + + def get_full_run(self, run_accession: int) -> dict | None: + """Return a run with all samples and per-sample annotations.""" + + run = self.get_run(run_accession) + if not run: + return None + + samples = self.get_samples(run_accession) + sample_accessions = [sample.sample_accession for sample in samples] + annotations = [] + if sample_accessions: + annotations = list( + self.session.scalars( + select(Annotation).where( + Annotation.sample_accession.in_(sample_accessions) + ) + ).all() + ) + + annotations_by_sample_accession: dict[int, list[Annotation]] = { + sample_accession: [] for sample_accession in sample_accessions + } + for annotation in annotations: + annotations_by_sample_accession[annotation.sample_accession].append( + annotation + ) + + return { + "run": run, + "samples": samples, + "annotations_by_sample_accession": annotations_by_sample_accession, + } def register_run( self, diff --git a/tests/test_api.py b/tests/test_api.py index 9549f47..559ebdb 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -288,6 +288,42 @@ def test_api_get_annotations(api_client): ] +def test_api_get_full_run(api_client): + client, _ = api_client + response = client.get("/api/get_full_run", query_string={"run_accession": 1}) + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["run"]["run_accession"] == 1 + samples = sorted( + payload["samples"], key=lambda item: item["sample"]["sample_accession"] + ) + assert [item["sample"]["sample_accession"] for item in samples] == [1, 2] + sample_1_annotations = sorted( + samples[0]["annotations"], key=lambda item: item["key"] + ) + assert sample_1_annotations == [ + {"sample_accession": 1, "key": "key0", "val": "val0"}, + {"sample_accession": 1, "key": "key4", "val": "val0"}, + ] + + +def test_api_get_full_sample(api_client): + client, _ = api_client + response = client.get( + "/api/get_full_sample", query_string={"sample_accession": 1} + ) + assert response.status_code == 200 + payload = response.get_json() + assert payload["status"] == "ok" + assert payload["sample"]["sample_accession"] == 1 + annotations = sorted(payload["annotations"], key=lambda item: item["key"]) + assert annotations == [ + {"sample_accession": 1, "key": "key0", "val": "val0"}, + {"sample_accession": 1, "key": "key4", "val": "val0"}, + ] + + def test_api_modify_annotation(api_client): client, Session = api_client response = client.post( From ce007402314714448234833f1349b2874f8b653e Mon Sep 17 00:00:00 2001 From: Charlie Date: Fri, 6 Feb 2026 15:57:21 -0500 Subject: [PATCH 2/2] Fix full run annotation query for large SQLite runs --- sample_registry/registrar.py | 576 ++++++++++++++++++----------------- tests/test_api.py | 4 +- tests/test_registrar.py | 416 ++++++++++++++----------- 3 files changed, 523 insertions(+), 473 deletions(-) diff --git a/sample_registry/registrar.py b/sample_registry/registrar.py index f510a28..4267188 100644 --- a/sample_registry/registrar.py +++ b/sample_registry/registrar.py @@ -1,30 +1,30 @@ -from typing import Optional -from sqlalchemy import and_, create_engine, delete, insert, select, update -from sqlalchemy.orm import Session, sessionmaker -from sample_registry.db import STANDARD_TAGS -from sample_registry.mapping import SampleTable -from sample_registry.models import Annotation, Sample, Run -from sample_registry.standards import MACHINE_TYPE_MAPPINGS - - -class SampleRegistry: - machines = MACHINE_TYPE_MAPPINGS.values() - kits = ["Nextera XT"] - - def __init__(self, session: Optional[Session] = None, uri: Optional[str] = None): - if session and uri: - raise ValueError("Cannot provide both session and uri") - elif session: - self.session = session - elif uri: - engine = create_engine(uri, echo=False) - SessionLocal = sessionmaker(bind=engine) - self.session = SessionLocal() - else: - from sample_registry import session as imported_session - - self.session = imported_session - +from typing import Optional +from sqlalchemy import and_, create_engine, delete, insert, select, update +from sqlalchemy.orm import Session, sessionmaker +from sample_registry.db import STANDARD_TAGS +from sample_registry.mapping import SampleTable +from sample_registry.models import Annotation, Sample, Run +from sample_registry.standards import MACHINE_TYPE_MAPPINGS + + +class SampleRegistry: + machines = MACHINE_TYPE_MAPPINGS.values() + kits = ["Nextera XT"] + + def __init__(self, session: Optional[Session] = None, uri: Optional[str] = None): + if session and uri: + raise ValueError("Cannot provide both session and uri") + elif session: + self.session = session + elif uri: + engine = create_engine(uri, echo=False) + SessionLocal = sessionmaker(bind=engine) + self.session = SessionLocal() + else: + from sample_registry import session as imported_session + + self.session = imported_session + def check_run_accession(self, acc: int) -> Run: run = self.session.scalar(select(Run).where(Run.run_accession == acc)) if not run: @@ -38,62 +38,62 @@ def check_sample_accession(self, acc: int) -> Sample: if not sample: raise ValueError("Sample does not exist %s" % acc) return sample - - def get_run(self, run_accession: int) -> Run | None: - """Return the ``Run`` record for ``run_accession``. - - Parameters - ---------- - run_accession: - Accession number identifying the sequencing run. - - Returns - ------- - Run | None - The ``Run`` instance if found, otherwise ``None``. - """ - - return self.session.scalar( - select(Run).where(Run.run_accession == run_accession) - ) - - def get_runs_by_data_uri(self, substring: str) -> list[int]: - """Return run accessions whose ``data_uri`` contains ``substring``. - - Parameters - ---------- - substring: - Text that must be contained within the ``data_uri`` field. - - Returns - ------- - list[int] - Run accessions ordered ascending for runs whose ``data_uri`` - contains ``substring``. - """ - - return list( - self.session.scalars( - select(Run.run_accession) - .where(Run.data_uri.contains(substring)) - .order_by(Run.run_accession) - ).all() - ) - - def get_samples(self, run_accession: int) -> list[Sample]: - """Return the list of ``Sample`` records for ``run_accession``. - - Parameters - ---------- - run_accession: - Accession number identifying the sequencing run. - """ - return list( - self.session.scalars( - select(Sample).where(Sample.run_accession == run_accession) - ).all() - ) - + + def get_run(self, run_accession: int) -> Run | None: + """Return the ``Run`` record for ``run_accession``. + + Parameters + ---------- + run_accession: + Accession number identifying the sequencing run. + + Returns + ------- + Run | None + The ``Run`` instance if found, otherwise ``None``. + """ + + return self.session.scalar( + select(Run).where(Run.run_accession == run_accession) + ) + + def get_runs_by_data_uri(self, substring: str) -> list[int]: + """Return run accessions whose ``data_uri`` contains ``substring``. + + Parameters + ---------- + substring: + Text that must be contained within the ``data_uri`` field. + + Returns + ------- + list[int] + Run accessions ordered ascending for runs whose ``data_uri`` + contains ``substring``. + """ + + return list( + self.session.scalars( + select(Run.run_accession) + .where(Run.data_uri.contains(substring)) + .order_by(Run.run_accession) + ).all() + ) + + def get_samples(self, run_accession: int) -> list[Sample]: + """Return the list of ``Sample`` records for ``run_accession``. + + Parameters + ---------- + run_accession: + Accession number identifying the sequencing run. + """ + return list( + self.session.scalars( + select(Sample).where(Sample.run_accession == run_accession) + ).all() + ) + def get_annotations(self, sample_accession: int) -> list[Annotation]: return list( self.session.scalars( @@ -124,15 +124,17 @@ def get_full_run(self, run_accession: int) -> dict | None: samples = self.get_samples(run_accession) sample_accessions = [sample.sample_accession for sample in samples] - annotations = [] - if sample_accessions: - annotations = list( - self.session.scalars( - select(Annotation).where( - Annotation.sample_accession.in_(sample_accessions) - ) - ).all() - ) + + annotations = list( + self.session.scalars( + select(Annotation) + .join( + Sample, + Annotation.sample_accession == Sample.sample_accession, + ) + .where(Sample.run_accession == run_accession) + ).all() + ) annotations_by_sample_accession: dict[int, list[Annotation]] = { sample_accession: [] for sample_accession in sample_accessions @@ -147,198 +149,198 @@ def get_full_run(self, run_accession: int) -> dict | None: "samples": samples, "annotations_by_sample_accession": annotations_by_sample_accession, } - - def register_run( - self, - run_date: str, - machine_type: str, - machine_kit: str, - lane: int, - data_uri: str, - comment: str, - ) -> Optional[int]: - # Using this because there are situations where the autoincrement is untrustworthy - max_run_accession = self.session.scalar( - select(Run.run_accession).order_by(Run.run_accession.desc()).limit(1) - ) - - return self.session.scalar( - insert(Run) - .returning(Run.run_accession) - .values( - { - "run_accession": max_run_accession + 1 if max_run_accession else 1, - "run_date": run_date, - "machine_type": machine_type, - "machine_kit": machine_kit, - "lane": lane, - "data_uri": data_uri, - "comment": comment, - } - ) - ) - - def modify_run(self, run_accession: int, **kwargs): - kwargs = {k: v for k, v in kwargs.items() if v is not None} - self.session.execute( - update(Run).where(Run.run_accession == run_accession).values(**kwargs) - ) - - def check_samples(self, run_accession: int, exists: bool = True) -> list[Sample]: - samples = self.session.scalars( - select(Sample).where(Sample.run_accession == run_accession) - ).all() - if bool(samples) != exists: - s = "exist" if exists else "don't exist" - raise ValueError(f"Samples {s} for run {run_accession}") - return list(samples) - - def register_samples( - self, run_accession: int, sample_table: SampleTable - ) -> list[int]: - sample_tups = [ - (sample_name, barcode_sequence) - for sample_name, barcode_sequence in sample_table.core_info - ] - if self.session.scalars( - select(Sample).where( - and_( - Sample.run_accession == run_accession, - Sample.sample_name.in_([s[0] for s in sample_tups]), - Sample.barcode_sequence.in_([s[1] for s in sample_tups]), - ) - ) - ).first(): - raise ValueError("Samples already registered for run %s" % run_accession) - - # Using this because there are situations where the autoincrement is untrustworthy - max_sample_accession = self.session.scalar( - select(Sample.sample_accession) - .order_by(Sample.sample_accession.desc()) - .limit(1) - ) - - return self.session.scalars( - insert(Sample) - .returning(Sample.sample_accession) - .values( - [ - { - "sample_accession": ( - max_sample_accession + i + 1 - if max_sample_accession - else i + 1 - ), - "run_accession": run_accession, - "sample_name": sample_name, - "barcode_sequence": barcode_sequence, - } - for i, (sample_name, barcode_sequence) in enumerate( - sample_table.core_info - ) - ] - ) - ) - - def modify_sample(self, sample_accession: int, **kwargs): - kwargs = {k: v for k, v in kwargs.items() if v is not None} - self.session.execute( - update(Sample) - .where(Sample.sample_accession == sample_accession) - .values(**kwargs) - ) - - def remove_samples(self, run_accession: int) -> list[int]: - samples = self.session.scalars( - select(Sample.sample_accession).where(Sample.run_accession == run_accession) - ).all() - self.session.execute( - delete(Annotation).where(Annotation.sample_accession.in_(samples)) - ) - self.session.execute( - delete(Sample).where(Sample.run_accession == run_accession) - ) - - return list(samples) - - def register_annotations( - self, run_accession: int, sample_table: SampleTable - ) -> list[tuple[int, str]]: - accessions = self._get_sample_accessions(run_accession, sample_table) - - # Remove existing annotations - self.session.execute( - delete(Annotation).where(Annotation.sample_accession.in_(accessions)) - ) - self.session.execute( - update(Sample) - .where(Sample.sample_accession.in_(accessions)) - .values({k: None for k in STANDARD_TAGS.values()}) - ) - - # Register new annotations - standard_annotation_args = [] - annotation_args = [] - for a, pairs in zip(accessions, sample_table.annotations): - for k, v in pairs: - if k in STANDARD_TAGS: - standard_annotation_args.append((a, STANDARD_TAGS[k], v)) - else: - annotation_args.append((a, k, v)) - - for a, k, v in standard_annotation_args: - self.session.execute( - update(Sample).where(Sample.sample_accession == a).values({k: v}) - ) - - annotation_keys = [] - if annotation_args: - annotation_keys = self.session.scalars( - insert(Annotation) - .returning(Annotation.sample_accession, Annotation.key) - .values( - [ - {"sample_accession": a, "key": k, "val": v} - for a, k, v in annotation_args - ] - ) - ) - - return list(annotation_keys) - - def _get_sample_accessions( - self, run_accession: int, sample_table: SampleTable - ) -> list[int]: - sample_tups = [ - (sample_name, barcode_sequence) - for sample_name, barcode_sequence in sample_table.core_info - ] - accessions = self.session.scalars( - select(Sample.sample_accession).where( - and_( - Sample.run_accession == run_accession, - Sample.sample_name.in_([s[0] for s in sample_tups]), - Sample.barcode_sequence.in_([s[1] for s in sample_tups]), - ) - ) - ).all() - - unaccessioned_recs = [] - for accession, rec in zip(accessions, sample_table.recs): - if accession is None: - unaccessioned_recs.append(rec) - if unaccessioned_recs: - raise IOError("Not accessioned: %s" % unaccessioned_recs) - return list(accessions) - - def modify_annotation(self, sample_accession: int, key: str, val: str): - self.session.execute( - update(Annotation) - .where( - and_( - Annotation.sample_accession == sample_accession, - Annotation.key == key, - ) - ) - .values({"val": val}) - ) + + def register_run( + self, + run_date: str, + machine_type: str, + machine_kit: str, + lane: int, + data_uri: str, + comment: str, + ) -> Optional[int]: + # Using this because there are situations where the autoincrement is untrustworthy + max_run_accession = self.session.scalar( + select(Run.run_accession).order_by(Run.run_accession.desc()).limit(1) + ) + + return self.session.scalar( + insert(Run) + .returning(Run.run_accession) + .values( + { + "run_accession": max_run_accession + 1 if max_run_accession else 1, + "run_date": run_date, + "machine_type": machine_type, + "machine_kit": machine_kit, + "lane": lane, + "data_uri": data_uri, + "comment": comment, + } + ) + ) + + def modify_run(self, run_accession: int, **kwargs): + kwargs = {k: v for k, v in kwargs.items() if v is not None} + self.session.execute( + update(Run).where(Run.run_accession == run_accession).values(**kwargs) + ) + + def check_samples(self, run_accession: int, exists: bool = True) -> list[Sample]: + samples = self.session.scalars( + select(Sample).where(Sample.run_accession == run_accession) + ).all() + if bool(samples) != exists: + s = "exist" if exists else "don't exist" + raise ValueError(f"Samples {s} for run {run_accession}") + return list(samples) + + def register_samples( + self, run_accession: int, sample_table: SampleTable + ) -> list[int]: + sample_tups = [ + (sample_name, barcode_sequence) + for sample_name, barcode_sequence in sample_table.core_info + ] + if self.session.scalars( + select(Sample).where( + and_( + Sample.run_accession == run_accession, + Sample.sample_name.in_([s[0] for s in sample_tups]), + Sample.barcode_sequence.in_([s[1] for s in sample_tups]), + ) + ) + ).first(): + raise ValueError("Samples already registered for run %s" % run_accession) + + # Using this because there are situations where the autoincrement is untrustworthy + max_sample_accession = self.session.scalar( + select(Sample.sample_accession) + .order_by(Sample.sample_accession.desc()) + .limit(1) + ) + + return self.session.scalars( + insert(Sample) + .returning(Sample.sample_accession) + .values( + [ + { + "sample_accession": ( + max_sample_accession + i + 1 + if max_sample_accession + else i + 1 + ), + "run_accession": run_accession, + "sample_name": sample_name, + "barcode_sequence": barcode_sequence, + } + for i, (sample_name, barcode_sequence) in enumerate( + sample_table.core_info + ) + ] + ) + ) + + def modify_sample(self, sample_accession: int, **kwargs): + kwargs = {k: v for k, v in kwargs.items() if v is not None} + self.session.execute( + update(Sample) + .where(Sample.sample_accession == sample_accession) + .values(**kwargs) + ) + + def remove_samples(self, run_accession: int) -> list[int]: + samples = self.session.scalars( + select(Sample.sample_accession).where(Sample.run_accession == run_accession) + ).all() + self.session.execute( + delete(Annotation).where(Annotation.sample_accession.in_(samples)) + ) + self.session.execute( + delete(Sample).where(Sample.run_accession == run_accession) + ) + + return list(samples) + + def register_annotations( + self, run_accession: int, sample_table: SampleTable + ) -> list[tuple[int, str]]: + accessions = self._get_sample_accessions(run_accession, sample_table) + + # Remove existing annotations + self.session.execute( + delete(Annotation).where(Annotation.sample_accession.in_(accessions)) + ) + self.session.execute( + update(Sample) + .where(Sample.sample_accession.in_(accessions)) + .values({k: None for k in STANDARD_TAGS.values()}) + ) + + # Register new annotations + standard_annotation_args = [] + annotation_args = [] + for a, pairs in zip(accessions, sample_table.annotations): + for k, v in pairs: + if k in STANDARD_TAGS: + standard_annotation_args.append((a, STANDARD_TAGS[k], v)) + else: + annotation_args.append((a, k, v)) + + for a, k, v in standard_annotation_args: + self.session.execute( + update(Sample).where(Sample.sample_accession == a).values({k: v}) + ) + + annotation_keys = [] + if annotation_args: + annotation_keys = self.session.scalars( + insert(Annotation) + .returning(Annotation.sample_accession, Annotation.key) + .values( + [ + {"sample_accession": a, "key": k, "val": v} + for a, k, v in annotation_args + ] + ) + ) + + return list(annotation_keys) + + def _get_sample_accessions( + self, run_accession: int, sample_table: SampleTable + ) -> list[int]: + sample_tups = [ + (sample_name, barcode_sequence) + for sample_name, barcode_sequence in sample_table.core_info + ] + accessions = self.session.scalars( + select(Sample.sample_accession).where( + and_( + Sample.run_accession == run_accession, + Sample.sample_name.in_([s[0] for s in sample_tups]), + Sample.barcode_sequence.in_([s[1] for s in sample_tups]), + ) + ) + ).all() + + unaccessioned_recs = [] + for accession, rec in zip(accessions, sample_table.recs): + if accession is None: + unaccessioned_recs.append(rec) + if unaccessioned_recs: + raise IOError("Not accessioned: %s" % unaccessioned_recs) + return list(accessions) + + def modify_annotation(self, sample_accession: int, key: str, val: str): + self.session.execute( + update(Annotation) + .where( + and_( + Annotation.sample_accession == sample_accession, + Annotation.key == key, + ) + ) + .values({"val": val}) + ) diff --git a/tests/test_api.py b/tests/test_api.py index 559ebdb..9616251 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -310,9 +310,7 @@ def test_api_get_full_run(api_client): def test_api_get_full_sample(api_client): client, _ = api_client - response = client.get( - "/api/get_full_sample", query_string={"sample_accession": 1} - ) + response = client.get("/api/get_full_sample", query_string={"sample_accession": 1}) assert response.status_code == 200 payload = response.get_json() assert payload["status"] == "ok" diff --git a/tests/test_registrar.py b/tests/test_registrar.py index 5d51b0c..a2cb9fe 100644 --- a/tests/test_registrar.py +++ b/tests/test_registrar.py @@ -1,183 +1,233 @@ -from typing import Generator -import pytest -from sqlalchemy import create_engine, func, select -from sqlalchemy.orm import Session, sessionmaker -from sample_registry.db import create_test_db -from sample_registry.mapping import SampleTable -from sample_registry.models import ( - Annotation, - Base, - Run, - Sample, -) -from sample_registry.registrar import SampleRegistry - -recs = [ - { - "SampleID": "S1", - "BarcodeSequence": "GCCT", - "HostSpecies": "Human", # Doesn't count towards annotations count, is stored in Sample - "SubjectID": "Hu23", - }, - { - "SampleID": "S2", - "BarcodeSequence": "GCAT", - "key1": "val1", # Counts towards annotations count - "key2": "val2", - }, -] - - -@pytest.fixture() -def db() -> Generator[Session, None, None]: - # This fixture should run before every test and create a new in-memory SQLite test database with identical data - SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:" - engine = create_engine(SQLALCHEMY_DATABASE_URI, echo=False) - Base.metadata.create_all(engine) - - Session = sessionmaker(bind=engine) - session = Session() - - create_test_db(session) - - yield session - - session.rollback() - session.close() - - -def test_check_run_accession(db): - registry = SampleRegistry(db) - assert registry.check_run_accession(1).run_accession == 1 - - -def test_get_run(db): - registry = SampleRegistry(db) - run = registry.get_run(1) - assert run.run_accession == 1 - assert run.run_date == "2024-07-02" - - -def test_get_run_doesnt_exist(db): - registry = SampleRegistry(db) - assert registry.get_run(9999) is None - - -def test_get_runs_by_data_uri(db): - registry = SampleRegistry(db) - assert registry.get_runs_by_data_uri("run1") == [1] - assert registry.get_runs_by_data_uri("raw_data") == [1, 2, 3] - assert registry.get_runs_by_data_uri("not-a-uri") == [] - - -def test_check_run_accession_doesnt_exist(db): - registry = SampleRegistry(db) - with pytest.raises(ValueError): - registry.check_run_accession(9999) - - -def test_register_run(db): - registry = SampleRegistry(db) - assert ( - registry.register_run( - "2021-01-01", - "Illumina-MiSeq", - "Nextera XT", - 1, - "/path/to/data/", - "A comment", - ) - == 4 - ) - - -def test_modify_run(db): - registry = SampleRegistry(db) - registry.modify_run(1, run_date="12/12/12", machine_type="Illumina-MiSeq") - assert db.scalar(select(Run).where(Run.run_accession == 1)).run_date == "12/12/12" - - -def test_check_samples(db): - registry = SampleRegistry(db) - assert len(registry.check_samples(1)) == 2 - - -def test_check_samples_exist(db): - registry = SampleRegistry(db) - with pytest.raises(ValueError): - registry.check_samples(1, exists=False) - - -def test_check_samples_doesnt_exist(db): - registry = SampleRegistry(db) - with pytest.raises(ValueError): - registry.check_samples(9999) - - -def test_register_samples(db): - registry = SampleRegistry(db) - sample_table = SampleTable(recs) - registry.register_samples(3, sample_table) - - assert ( - db.scalar( - select(func.count(Sample.sample_accession)).where(Sample.run_accession == 3) - ) - == 2 - ) - - -def test_modify_samples(db): - registry = SampleRegistry(db) - registry.modify_sample(1, sample_name="New name") - - assert ( - db.scalar(select(Sample).where(Sample.sample_accession == 1)).sample_name - == "New name" - ) - - -def test_register_samples_already_registered(db): - registry = SampleRegistry(db) - sample_table = SampleTable(recs) - registry.register_samples(3, sample_table) - with pytest.raises(ValueError): - registry.register_samples(3, sample_table) - - -def test_remove_samples(db): - registry = SampleRegistry(db) - sample_accessions = registry.remove_samples(1) - assert not db.scalar(select(Sample).where(Sample.run_accession == 1)) - assert not db.scalar( - select(Annotation).where(Annotation.sample_accession.in_(sample_accessions)) - ) - - -def test_register_annotations(db): - registry = SampleRegistry(db) - sample_table = SampleTable(recs) - registry.register_samples(3, sample_table) - registry.register_annotations(3, sample_table) - - assert ( - db.scalar( - select(func.count(Annotation.sample_accession)).where( - Annotation.sample_accession == 3 - ) - ) - == 2 - ) - - -def test_modify_annotation(db): - registry = SampleRegistry(db) - registry.modify_annotation(1, "key0", "new val") - assert ( - db.scalar( - select(Annotation).where( - Annotation.sample_accession == 1, Annotation.key == "key0" - ) - ).val - == "new val" - ) +from typing import Generator +import pytest +from sqlalchemy import create_engine, func, insert, select +from sqlalchemy.orm import Session, sessionmaker +from sample_registry.db import create_test_db +from sample_registry.mapping import SampleTable +from sample_registry.models import ( + Annotation, + Base, + Run, + Sample, +) +from sample_registry.registrar import SampleRegistry + +recs = [ + { + "SampleID": "S1", + "BarcodeSequence": "GCCT", + "HostSpecies": "Human", # Doesn't count towards annotations count, is stored in Sample + "SubjectID": "Hu23", + }, + { + "SampleID": "S2", + "BarcodeSequence": "GCAT", + "key1": "val1", # Counts towards annotations count + "key2": "val2", + }, +] + + +@pytest.fixture() +def db() -> Generator[Session, None, None]: + # This fixture should run before every test and create a new in-memory SQLite test database with identical data + SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:" + engine = create_engine(SQLALCHEMY_DATABASE_URI, echo=False) + Base.metadata.create_all(engine) + + Session = sessionmaker(bind=engine) + session = Session() + + create_test_db(session) + + yield session + + session.rollback() + session.close() + + +def test_check_run_accession(db): + registry = SampleRegistry(db) + assert registry.check_run_accession(1).run_accession == 1 + + +def test_get_run(db): + registry = SampleRegistry(db) + run = registry.get_run(1) + assert run.run_accession == 1 + assert run.run_date == "2024-07-02" + + +def test_get_run_doesnt_exist(db): + registry = SampleRegistry(db) + assert registry.get_run(9999) is None + + +def test_get_runs_by_data_uri(db): + registry = SampleRegistry(db) + assert registry.get_runs_by_data_uri("run1") == [1] + assert registry.get_runs_by_data_uri("raw_data") == [1, 2, 3] + assert registry.get_runs_by_data_uri("not-a-uri") == [] + + +def test_get_full_run_handles_large_sample_sets(db): + registry = SampleRegistry(db) + run_accession = registry.register_run( + "2025-01-01", + "Illumina-MiSeq", + "Nextera XT", + 1, + "raw_data/run4/Undetermined_S0_L001_R1_001.fastq.gz", + "large run", + ) + + sample_count = 1100 + start_accession = db.scalar(select(func.max(Sample.sample_accession))) + 1 + sample_rows = [ + { + "sample_accession": start_accession + i, + "sample_name": f"BulkSample{i}", + "run_accession": run_accession, + "barcode_sequence": f"BC{i:04d}", + "primer_sequence": None, + "sample_type": None, + "subject_id": None, + "host_species": None, + } + for i in range(sample_count) + ] + db.execute(insert(Sample).values(sample_rows)) + + annotation_rows = [ + { + "sample_accession": start_accession + i, + "key": "bulk_key", + "val": f"bulk_val_{i}", + } + for i in range(sample_count) + ] + db.execute(insert(Annotation).values(annotation_rows)) + db.commit() + + full_run = registry.get_full_run(run_accession) + + assert full_run["run"].run_accession == run_accession + assert len(full_run["samples"]) == sample_count + assert len(full_run["annotations_by_sample_accession"]) == sample_count + assert all( + len(sample_annotations) == 1 + for sample_annotations in full_run["annotations_by_sample_accession"].values() + ) + + +def test_check_run_accession_doesnt_exist(db): + registry = SampleRegistry(db) + with pytest.raises(ValueError): + registry.check_run_accession(9999) + + +def test_register_run(db): + registry = SampleRegistry(db) + assert ( + registry.register_run( + "2021-01-01", + "Illumina-MiSeq", + "Nextera XT", + 1, + "/path/to/data/", + "A comment", + ) + == 4 + ) + + +def test_modify_run(db): + registry = SampleRegistry(db) + registry.modify_run(1, run_date="12/12/12", machine_type="Illumina-MiSeq") + assert db.scalar(select(Run).where(Run.run_accession == 1)).run_date == "12/12/12" + + +def test_check_samples(db): + registry = SampleRegistry(db) + assert len(registry.check_samples(1)) == 2 + + +def test_check_samples_exist(db): + registry = SampleRegistry(db) + with pytest.raises(ValueError): + registry.check_samples(1, exists=False) + + +def test_check_samples_doesnt_exist(db): + registry = SampleRegistry(db) + with pytest.raises(ValueError): + registry.check_samples(9999) + + +def test_register_samples(db): + registry = SampleRegistry(db) + sample_table = SampleTable(recs) + registry.register_samples(3, sample_table) + + assert ( + db.scalar( + select(func.count(Sample.sample_accession)).where(Sample.run_accession == 3) + ) + == 2 + ) + + +def test_modify_samples(db): + registry = SampleRegistry(db) + registry.modify_sample(1, sample_name="New name") + + assert ( + db.scalar(select(Sample).where(Sample.sample_accession == 1)).sample_name + == "New name" + ) + + +def test_register_samples_already_registered(db): + registry = SampleRegistry(db) + sample_table = SampleTable(recs) + registry.register_samples(3, sample_table) + with pytest.raises(ValueError): + registry.register_samples(3, sample_table) + + +def test_remove_samples(db): + registry = SampleRegistry(db) + sample_accessions = registry.remove_samples(1) + assert not db.scalar(select(Sample).where(Sample.run_accession == 1)) + assert not db.scalar( + select(Annotation).where(Annotation.sample_accession.in_(sample_accessions)) + ) + + +def test_register_annotations(db): + registry = SampleRegistry(db) + sample_table = SampleTable(recs) + registry.register_samples(3, sample_table) + registry.register_annotations(3, sample_table) + + assert ( + db.scalar( + select(func.count(Annotation.sample_accession)).where( + Annotation.sample_accession == 3 + ) + ) + == 2 + ) + + +def test_modify_annotation(db): + registry = SampleRegistry(db) + registry.modify_annotation(1, "key0", "new val") + assert ( + db.scalar( + select(Annotation).where( + Annotation.sample_accession == 1, Annotation.key == "key0" + ) + ).val + == "new val" + )