diff --git a/.gitignore b/.gitignore index 699fa5be0..1d171a48f 100644 --- a/.gitignore +++ b/.gitignore @@ -66,3 +66,6 @@ slurm*.out lightning_logs/ # NOTE: uv.lock is NOT ignored - it should be tracked for reproducibility + +# Local-only planning docs (not for upstream) +applications/dynaclr/docs/DAGs/evaluation_matrix.md diff --git a/applications/airtable/configs/prepare_config.yml b/applications/airtable/configs/prepare_config.yml new file mode 100644 index 000000000..da9eb5f7b --- /dev/null +++ b/applications/airtable/configs/prepare_config.yml @@ -0,0 +1,47 @@ +# Dataset preparation pipeline: NFS -> VAST rechunked zarr v3 +# Usage: prepare run -c prepare_config.yml [--dry-run] + +nfs_root: /hpc/projects/intracellular_dashboard/organelle_dynamics +vast_root: /hpc/projects/organelle_phenotyping/datasets +workspace_dir: /hpc/mydata/eduardo.hirata/repos/viscy + +concatenate: + # null = auto-detect raw channels (Phase3D + raw *). Set explicitly to override. + channel_names: null + chunks_czyx: [1, 16, 256, 256] + shards_ratio: [1, 1, 8, 8, 8] + output_ome_zarr_version: "0.5" + conda_env: biahub + # Override biahub's internal SLURM settings (passed via -sb flag) + # Set to null to use biahub defaults + sbatch_overrides: + partition: cpu + +qc: + channel_names: [Phase3D] + NA_det: 1.35 + lambda_ill: 0.450 + pixel_size: 0.1494 + midband_fractions: [0.125, 0.25] + device: cuda + num_workers: 16 + +preprocess: + channel_names: -1 + num_workers: 32 + block_size: 32 + +# biahub concatenate submits its own SLURM jobs via submitit (no config needed) +# QC and preprocess run as separate SLURM jobs (no race condition) +slurm: + qc: + partition: gpu + gres: "gpu:1" + cpus_per_task: 16 + mem_per_cpu: 4G + time: "00:30:00" + preprocess: + partition: cpu + cpus_per_task: 32 + mem_per_cpu: 4G + time: "04:00:00" diff --git a/applications/airtable/scripts/write_experiment_metadata.py b/applications/airtable/scripts/write_experiment_metadata.py index 192bff024..6b0ce2853 100644 --- 
a/applications/airtable/scripts/write_experiment_metadata.py +++ b/applications/airtable/scripts/write_experiment_metadata.py @@ -68,6 +68,9 @@ def register(position_paths: list[Path], dry_run: bool = False, dataset: str | N if result.updated: db.batch_update(result.updated) logger.info("Updated %d existing records", len(result.updated)) + if result.template_ids_to_delete: + db.batch_delete(result.template_ids_to_delete) + logger.info("Deleted %d well template records", len(result.template_ids_to_delete)) print(format_register_summary(result, dry_run=dry_run)) diff --git a/applications/airtable/src/airtable_utils/database.py b/applications/airtable/src/airtable_utils/database.py index 1cb9ffd06..c1fd19a70 100644 --- a/applications/airtable/src/airtable_utils/database.py +++ b/applications/airtable/src/airtable_utils/database.py @@ -143,3 +143,18 @@ def batch_create(self, records: list[dict]) -> list[dict]: Created records as returned by the Airtable API. """ return self._table.batch_create([r["fields"] for r in records]) + + def batch_delete(self, record_ids: list[str]) -> list[dict]: + """Batch-delete records by ID. + + Parameters + ---------- + record_ids : list[str] + Airtable record IDs to delete. + + Returns + ------- + list[dict] + Deletion confirmations from the Airtable API. 
+ """ + return self._table.batch_delete(record_ids) diff --git a/applications/airtable/src/airtable_utils/prepare.py b/applications/airtable/src/airtable_utils/prepare.py new file mode 100644 index 000000000..d35333a11 --- /dev/null +++ b/applications/airtable/src/airtable_utils/prepare.py @@ -0,0 +1,672 @@ +"""Config-driven dataset preparation: NFS -> VAST rechunked zarr v3.""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from textwrap import dedent + +import yaml +from iohub import open_ome_zarr +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Pydantic config models +# --------------------------------------------------------------------------- + + +class ConcatenateConfig(BaseModel): + """Parameters for biahub concatenate.""" + + channel_names: list[str] | None = None + chunks_czyx: list[int] = [1, 16, 256, 256] + shards_ratio: list[int] = [1, 1, 8, 8, 8] + output_ome_zarr_version: str = "0.5" + conda_env: str = "biahub" + sbatch_overrides: dict[str, str] | None = None + + +class QCParams(BaseModel): + """Focus-slice QC parameters.""" + + channel_names: list[str] = ["Phase3D"] + NA_det: float = 1.35 + lambda_ill: float = 0.450 + pixel_size: float = 0.1494 + midband_fractions: tuple[float, float] = (0.125, 0.25) + device: str = "cuda" + num_workers: int = 16 + + +class PreprocessParams(BaseModel): + """Normalization preprocessing parameters.""" + + channel_names: int | list[str] = -1 + num_workers: int = 48 + block_size: int = 32 + + +class SlurmStageConfig(BaseModel): + """SLURM resource settings for one job stage.""" + + partition: str + cpus_per_task: int = 24 + mem_per_cpu: str = "4G" + time: str = "06:00:00" + gres: str | None = None + constraint: str | None = None + + +class SlurmConfig(BaseModel): + """SLURM settings for QC and preprocess stages (separate jobs). 
class PrepareConfig(BaseModel):
    """Top-level configuration for the prepare pipeline."""

    nfs_root: Path = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics")
    vast_root: Path = Path("/hpc/projects/organelle_phenotyping/datasets")
    workspace_dir: Path = Path("/hpc/mydata/eduardo.hirata/repos/viscy")
    concatenate: ConcatenateConfig = Field(default_factory=ConcatenateConfig)
    qc: QCParams = Field(default_factory=QCParams)
    preprocess: PreprocessParams = Field(default_factory=PreprocessParams)
    slurm: SlurmConfig = Field(default_factory=SlurmConfig)


def resolve_nfs_paths(dataset_name: str, nfs_root: Path) -> dict[str, Path]:
    """Return NFS zarr and tracking paths for a dataset.

    Parameters
    ----------
    dataset_name : str
        Dataset identifier, e.g. ``"2025_01_22_A549_G3BP1_ZIKV_DENV"``.
    nfs_root : Path
        Root of organelle_dynamics on NFS.

    Returns
    -------
    dict[str, Path]
        Keys: ``zarr``, ``tracking``.

    Raises
    ------
    FileNotFoundError
        If the assembled zarr does not exist on NFS.
    """
    dataset_dir = nfs_root / dataset_name
    assembled = dataset_dir / "2-assemble" / f"{dataset_name}.zarr"
    # The tracking store may legitimately be absent; only the assembled
    # zarr is required to exist.
    tracks = dataset_dir / "1-preprocess" / "label-free" / "3-track" / f"{dataset_name}_cropped.zarr"
    if not assembled.exists():
        raise FileNotFoundError(f"NFS zarr not found: {assembled}")
    return {"zarr": assembled, "tracking": tracks}
+ """ + zarr_path = nfs_root / dataset_name / "2-assemble" / f"{dataset_name}.zarr" + tracking_path = nfs_root / dataset_name / "1-preprocess" / "label-free" / "3-track" / f"{dataset_name}_cropped.zarr" + if not zarr_path.exists(): + raise FileNotFoundError(f"NFS zarr not found: {zarr_path}") + return {"zarr": zarr_path, "tracking": tracking_path} + + +def resolve_vast_paths(dataset_name: str, vast_root: Path) -> dict[str, Path]: + """Return expected VAST output paths for a dataset. + + Parameters + ---------- + dataset_name : str + Dataset identifier. + vast_root : Path + Root of datasets directory on VAST. + + Returns + ------- + dict[str, Path] + Keys: ``output_dir``, ``zarr``, ``tracking``. + """ + output_dir = vast_root / dataset_name + return { + "output_dir": output_dir, + "zarr": output_dir / f"{dataset_name}.zarr", + "tracking": output_dir / "tracking.zarr", + } + + +# --------------------------------------------------------------------------- +# Zarr version validation +# --------------------------------------------------------------------------- + + +def check_zarr_version(zarr_path: Path) -> dict[str, int | str | None]: + """Check zarr format version and OME-Zarr version of an existing store. + + Parameters + ---------- + zarr_path : Path + Path to the zarr store root. + + Returns + ------- + dict[str, int | str | None] + Keys: ``zarr_format`` (2, 3, or None), ``ome_version`` (e.g. "0.5" or None). 
+ """ + result: dict[str, int | str | None] = {"zarr_format": None, "ome_version": None} + + zarr_json = zarr_path / "zarr.json" + zgroup = zarr_path / ".zgroup" + + if zarr_json.exists(): + with open(zarr_json) as f: + meta = json.load(f) + result["zarr_format"] = meta.get("zarr_format", 3) + ome = meta.get("attributes", {}).get("ome", {}) + result["ome_version"] = ome.get("version") + elif zgroup.exists(): + with open(zgroup) as f: + meta = json.load(f) + result["zarr_format"] = meta.get("zarr_format", 2) + zattrs = zarr_path / ".zattrs" + if zattrs.exists(): + with open(zattrs) as f: + attrs = json.load(f) + result["ome_version"] = attrs.get("plate", {}).get("version") + + return result + + +def check_preprocessed(zarr_path: Path) -> bool: + """Check if normalization metadata has been written to the zarr store. + + Parameters + ---------- + zarr_path : Path + Path to the zarr store root. + + Returns + ------- + bool + True if normalization stats are present. + """ + zarr_json = zarr_path / "zarr.json" + zattrs = zarr_path / ".zattrs" + + if zarr_json.exists(): + with open(zarr_json) as f: + meta = json.load(f) + return "normalization" in meta.get("attributes", {}) + elif zattrs.exists(): + with open(zattrs) as f: + attrs = json.load(f) + return "normalization" in attrs + + return False + + +# --------------------------------------------------------------------------- +# Discovery (reads NFS zarr via iohub) +# --------------------------------------------------------------------------- + + +def discover_wells(nfs_zarr_path: Path) -> list[str]: + """Enumerate well paths from an NFS OME-Zarr plate. + + Returns well-level paths (e.g. ``"B/1"``) not full position paths. + The ``crop_concat.yml`` format expects ``{zarr}/{well}/*`` globs + so that biahub concatenate can discover positions within each well. + + Parameters + ---------- + nfs_zarr_path : Path + Path to the assembled zarr on NFS. 
+ + Returns + ------- + list[str] + Sorted well paths like ``["A/1", "B/1", "C/2"]``. + """ + wells: list[str] = [] + with open_ome_zarr(str(nfs_zarr_path), mode="r") as plate: + for pos_path, _pos in plate.positions(): + # pos_path is like "A/1/000000" — extract well as "A/1" + well = "/".join(pos_path.split("/")[:2]) + if well not in wells: + wells.append(well) + return sorted(wells) + + +def discover_channels(nfs_zarr_path: Path) -> list[str]: + """Read channel names from an NFS OME-Zarr plate. + + Parameters + ---------- + nfs_zarr_path : Path + Path to the assembled zarr on NFS. + + Returns + ------- + list[str] + Channel names, e.g. ``["Phase3D", "raw GFP EX488 EM525-45", ...]``. + """ + with open_ome_zarr(str(nfs_zarr_path), mode="r") as plate: + return list(plate.channel_names) + + +RAW_CHANNEL_PREFIXES = ("Phase3D", "raw ") + + +def filter_raw_channels(channel_names: list[str]) -> list[str]: + """Filter to only raw imaging channels (Phase3D and raw fluorescence). + + Excludes virtual stains (``nuclei_prediction``, ``membrane_prediction``), + deconvolved channels (``GFP EX488 ...`` without ``raw`` prefix), and + other derived channels (``BF``). + + Parameters + ---------- + channel_names : list[str] + All channel names from the zarr. + + Returns + ------- + list[str] + Only channels starting with ``"Phase3D"`` or ``"raw "``. + """ + return [ch for ch in channel_names if ch.startswith(RAW_CHANNEL_PREFIXES)] + + +# --------------------------------------------------------------------------- +# Config generation +# --------------------------------------------------------------------------- + + +def generate_crop_concat_config( + nfs_zarr_path: Path, + wells: list[str], + channel_names: list[str], + concat_cfg: ConcatenateConfig, +) -> dict: + """Build a crop_concat.yml dict for biahub concatenate. + + Parameters + ---------- + nfs_zarr_path : Path + Path to the source zarr on NFS. + wells : list[str] + Well paths like ``["A/1", "B/2"]`` (row/col level). 
+ Each becomes ``"{zarr}/{well}/*"`` so biahub globs positions within. + channel_names : list[str] + Channel names (repeated once per well entry). + concat_cfg : ConcatenateConfig + Concatenation parameters. + + Returns + ------- + dict + Config dict ready to write as YAML. + """ + concat_data_paths = [f"{nfs_zarr_path}/{well}/*" for well in wells] + return { + "concat_data_paths": concat_data_paths, + "time_indices": "all", + "channel_names": [channel_names] * len(wells), + "X_slice": "all", + "Y_slice": "all", + "Z_slice": "all", + "chunks_czyx": concat_cfg.chunks_czyx, + "shards_ratio": concat_cfg.shards_ratio, + "output_ome_zarr_version": concat_cfg.output_ome_zarr_version, + } + + +def generate_qc_config(data_path: Path, qc_params: QCParams) -> dict: + """Build a QC config dict compatible with ``qc run -c``. + + Parameters + ---------- + data_path : Path + Path to the VAST zarr (target of QC). + qc_params : QCParams + Focus-slice QC parameters. + + Returns + ------- + dict + Config dict ready to write as YAML. + """ + return { + "data_path": str(data_path), + "num_workers": qc_params.num_workers, + "focus_slice": { + "channel_names": qc_params.channel_names, + "NA_det": qc_params.NA_det, + "lambda_ill": qc_params.lambda_ill, + "pixel_size": qc_params.pixel_size, + "midband_fractions": list(qc_params.midband_fractions), + "device": qc_params.device, + }, + } + + +def write_yaml(config: dict, output_path: Path) -> None: + """Write a dict to a YAML file. + + Parameters + ---------- + config : dict + Config to serialize. + output_path : Path + Destination file path. + """ + + # Use a Dumper subclass that avoids YAML anchors/aliases for repeated + # lists. Patching yaml.Dumper directly leaks into every other yaml.dump + # in the same Python process. 
+ class _NoAliasDumper(yaml.Dumper): + def ignore_aliases(self, data: object) -> bool: + return True + + with open(output_path, "w") as f: + yaml.dump(config, f, Dumper=_NoAliasDumper, default_flow_style=False, sort_keys=False) + + +# --------------------------------------------------------------------------- +# SLURM script generation +# --------------------------------------------------------------------------- + + +def _slurm_header(job_name: str, output_dir: Path, cfg: SlurmStageConfig) -> str: + """Build SBATCH header lines.""" + lines = [ + "#!/bin/bash", + f"#SBATCH --job-name={job_name}", + "#SBATCH --nodes=1", + "#SBATCH --ntasks-per-node=1", + f"#SBATCH --partition={cfg.partition}", + f"#SBATCH --cpus-per-task={cfg.cpus_per_task}", + f"#SBATCH --mem-per-cpu={cfg.mem_per_cpu}", + f"#SBATCH --time={cfg.time}", + f"#SBATCH --output={output_dir}/slurm_{job_name}_%j.out", + ] + if cfg.gres: + lines.append(f"#SBATCH --gres={cfg.gres}") + if cfg.constraint: + lines.append(f'#SBATCH --constraint="{cfg.constraint}"') + return "\n".join(lines) + + +def generate_sbatch_override_file(overrides: dict[str, str]) -> str: + """Generate content for a biahub sbatch override file. + + Parameters + ---------- + overrides : dict[str, str] + SLURM directive keys and values, e.g. + ``{"partition": "preempted", "mem-per-cpu": "16G"}``. + + Returns + ------- + str + File content with ``#SBATCH`` lines. + """ + lines = ["#!/bin/bash"] + for key, value in overrides.items(): + lines.append(f"#SBATCH --{key}={value}") + return "\n".join(lines) + "\n" + + +def generate_concatenate_script( + crop_concat_path: Path, + vast_zarr_path: Path, + nfs_tracking_path: Path, + vast_tracking_path: Path, + conda_env: str, + sbatch_override_path: Path | None = None, +) -> str: + """Generate a bash script for biahub concatenate + tracking copy. + + This is NOT a SLURM script. ``biahub concatenate`` submits its own + SLURM jobs internally via submitit. 
def generate_qc_slurm(
    dataset_name: str,
    vast_output_dir: Path,
    qc_config_path: Path,
    workspace_dir: Path,
    slurm_cfg: SlurmStageConfig,
) -> str:
    """Generate SLURM script for focus-slice QC (needs GPU).

    Parameters
    ----------
    dataset_name : str
        Dataset identifier (used for job name).
    vast_output_dir : Path
        Output directory on VAST.
    qc_config_path : Path
        Path to the generated qc_config.yml.
    workspace_dir : Path
        Path to the viscy repo root.
    slurm_cfg : SlurmStageConfig
        SLURM resource parameters.

    Returns
    -------
    str
        Complete SLURM script content.
    """
    hdr = _slurm_header(f"qc_{dataset_name}", vast_output_dir, slurm_cfg)
    # PYTHONNOUSERSITE keeps user site-packages out of the uv environment.
    script_body = dedent(f"""\

        export PYTHONNOUSERSITE=1

        echo "=== QC: focus slice detection ==="
        uv run --project "{workspace_dir}" --package qc \
            qc run -c "{qc_config_path}"
        echo "QC complete."
    """)
    return hdr + "\n" + script_body
def generate_preprocess_slurm(
    dataset_name: str,
    vast_output_dir: Path,
    vast_zarr_path: Path,
    workspace_dir: Path,
    preprocess_params: PreprocessParams,
    slurm_cfg: SlurmStageConfig,
) -> str:
    """Generate SLURM script for normalization preprocessing (CPU only).

    Parameters
    ----------
    dataset_name : str
        Dataset identifier (used for job name).
    vast_output_dir : Path
        Output directory on VAST.
    vast_zarr_path : Path
        Path to the rechunked zarr on VAST.
    workspace_dir : Path
        Path to the viscy repo root.
    preprocess_params : PreprocessParams
        Normalization preprocessing parameters.
    slurm_cfg : SlurmStageConfig
        SLURM resource parameters.

    Returns
    -------
    str
        Complete SLURM script content.
    """
    hdr = _slurm_header(f"preprocess_{dataset_name}", vast_output_dir, slurm_cfg)

    # channel_names is either the sentinel int -1 (all channels) or an
    # explicit list: emit one flag, or one flag per channel name.
    channels = preprocess_params.channel_names
    if isinstance(channels, int):
        ch_flag = f"--channel_names={channels}"
    else:
        ch_flag = " ".join(f"--channel_names={c}" for c in channels)

    script_body = dedent(f"""\

        export PYTHONNOUSERSITE=1

        echo "=== Preprocess: normalization stats ==="
        echo "Data: {vast_zarr_path}"
        uv run --project "{workspace_dir}" --package dynaclr \
            viscy preprocess --data_path "{vast_zarr_path}" \
            {ch_flag} --num_workers {preprocess_params.num_workers} \
            --block_size {preprocess_params.block_size}
        echo "Preprocess complete."
    """)
    return hdr + "\n" + script_body
+ """) + return header + "\n" + body + + +# --------------------------------------------------------------------------- +# Status check +# --------------------------------------------------------------------------- + + +def check_dataset_status(dataset_name: str, nfs_root: Path, vast_root: Path) -> dict[str, str]: + """Check existence and version info for a dataset across NFS and VAST. + + Parameters + ---------- + dataset_name : str + Dataset identifier. + nfs_root : Path + NFS root directory. + vast_root : Path + VAST root directory. + + Returns + ------- + dict[str, str] + Status fields for the dataset. + """ + nfs_zarr = nfs_root / dataset_name / "2-assemble" / f"{dataset_name}.zarr" + vast = resolve_vast_paths(dataset_name, vast_root) + + nfs_exists = nfs_zarr.exists() + vast_zarr_exists = vast["zarr"].exists() + vast_tracking_exists = vast["tracking"].exists() + + zarr_fmt: str = "-" + ome_ver: str = "-" + preprocessed: str = "-" + + if vast_zarr_exists: + ver = check_zarr_version(vast["zarr"]) + zarr_fmt = str(ver["zarr_format"]) if ver["zarr_format"] else "?" + ome_ver = str(ver["ome_version"]) if ver["ome_version"] else "?" + preprocessed = "yes" if check_preprocessed(vast["zarr"]) else "no" + + return { + "dataset": dataset_name, + "nfs": "yes" if nfs_exists else "no", + "vast_zarr": "yes" if vast_zarr_exists else "no", + "zarr_version": zarr_fmt, + "ome_version": ome_ver, + "tracking": "yes" if vast_tracking_exists else "no", + "preprocessed": preprocessed, + } + + +def format_status_table(rows: list[dict[str, str]]) -> str: + """Format dataset status rows as a markdown table. + + Parameters + ---------- + rows : list[dict[str, str]] + Each dict from :func:`check_dataset_status`. + + Returns + ------- + str + Markdown table string. 
+ """ + headers = [ + "dataset", + "nfs", + "vast_zarr", + "zarr_version", + "ome_version", + "tracking", + "preprocessed", + ] + col_widths = {h: max(len(h), *(len(r[h]) for r in rows)) for h in headers} + + header_line = "| " + " | ".join(h.ljust(col_widths[h]) for h in headers) + " |" + sep_line = "| " + " | ".join("-" * col_widths[h] for h in headers) + " |" + data_lines = ["| " + " | ".join(r[h].ljust(col_widths[h]) for h in headers) + " |" for r in rows] + return "\n".join([header_line, sep_line, *data_lines]) diff --git a/applications/airtable/src/airtable_utils/prepare_cli.py b/applications/airtable/src/airtable_utils/prepare_cli.py new file mode 100644 index 000000000..c4e9486bb --- /dev/null +++ b/applications/airtable/src/airtable_utils/prepare_cli.py @@ -0,0 +1,259 @@ +"""CLI for config-driven dataset preparation (NFS -> VAST).""" + +from __future__ import annotations + +import logging +import re +import subprocess + +import click + +from airtable_utils.prepare import ( + PrepareConfig, + check_dataset_status, + check_preprocessed, + check_zarr_version, + discover_channels, + discover_wells, + filter_raw_channels, + format_status_table, + generate_concatenate_script, + generate_crop_concat_config, + generate_preprocess_slurm, + generate_qc_config, + generate_qc_slurm, + generate_sbatch_override_file, + resolve_nfs_paths, + resolve_vast_paths, + write_yaml, +) + +logger = logging.getLogger(__name__) + +CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]} + + +def _load_prepare_config(config_path: str) -> PrepareConfig: + """Load and validate a prepare config YAML.""" + from viscy_utils.cli_utils import load_config + + raw = load_config(config_path) + return PrepareConfig(**raw) + + +def _parse_slurm_job_id(sbatch_output: str) -> str: + """Extract job ID from sbatch stdout like 'Submitted batch job 12345'.""" + match = re.search(r"Submitted batch job (\d+)", sbatch_output) + if not match: + raise RuntimeError(f"Could not parse sbatch output: 
{sbatch_output}") + return match.group(1) + + +@click.group(context_settings=CONTEXT_SETTINGS) +def prepare(): + """Prepare datasets for training on VAST storage.""" + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +@prepare.command() +@click.argument("dataset_name") +@click.option( + "-c", + "--config", + "config_path", + required=True, + type=click.Path(exists=True), + help="Path to prepare config YAML.", +) +@click.option("--dry-run", is_flag=True, help="Generate configs without submitting SLURM jobs.") +@click.option("--force", is_flag=True, help="Overwrite existing VAST zarr even if it is zarr v2.") +def run(dataset_name: str, config_path: str, dry_run: bool, force: bool) -> None: + """Run the full preparation pipeline for DATASET_NAME. + + Steps: Airtable validation -> discover positions/channels -> generate + crop_concat.yml + qc_config.yml + SLURM scripts -> submit jobs. + """ + cfg = _load_prepare_config(config_path) + + # 1. Validate dataset is registered in Airtable + click.echo(f"Validating {dataset_name} in Airtable...") + from airtable_utils.database import AirtableDatasets + + db = AirtableDatasets() + records = db.get_dataset_records(dataset_name) + if not records: + raise click.ClickException( + f"Dataset '{dataset_name}' not found in Airtable. Register it first with the airtable-register workflow." + ) + click.echo(f" Found {len(records)} FOV records in Airtable.") + + # 2. Resolve NFS paths + nfs = resolve_nfs_paths(dataset_name, cfg.nfs_root) + click.echo(f" NFS zarr: {nfs['zarr']}") + + # 3. Resolve VAST paths + vast = resolve_vast_paths(dataset_name, cfg.vast_root) + click.echo(f" VAST output: {vast['output_dir']}") + + # 4. 
Check existing VAST zarr + if vast["zarr"].exists(): + ver = check_zarr_version(vast["zarr"]) + is_v3 = ver["zarr_format"] == 3 + is_ome05 = ver["ome_version"] == "0.5" + is_preprocessed = check_preprocessed(vast["zarr"]) + + if is_v3 and is_ome05 and is_preprocessed: + click.echo( + f" VAST zarr already exists: zarr v{ver['zarr_format']}, " + f"OME {ver['ome_version']}, preprocessed. Skipping." + ) + return + + if not force: + msg = ( + f"VAST zarr already exists at {vast['zarr']} " + f"(zarr v{ver['zarr_format']}, OME {ver['ome_version']}, " + f"preprocessed={is_preprocessed}). " + "Use --force to overwrite." + ) + raise click.ClickException(msg) + + click.echo(f" WARNING: Overwriting existing VAST zarr (zarr v{ver['zarr_format']}, OME {ver['ome_version']}).") + + # 5. Discover wells and resolve channels from NFS zarr + click.echo("Discovering wells and channels from NFS zarr...") + wells = discover_wells(nfs["zarr"]) + zarr_channels = discover_channels(nfs["zarr"]) + + if cfg.concatenate.channel_names is not None: + concat_channels = cfg.concatenate.channel_names + missing = [ch for ch in concat_channels if ch not in zarr_channels] + if missing: + raise click.ClickException(f"Channels {missing} from config not found in zarr. Available: {zarr_channels}") + else: + concat_channels = filter_raw_channels(zarr_channels) + if not concat_channels: + raise click.ClickException(f"No raw channels found in zarr. Available: {zarr_channels}") + + click.echo(f" Wells: {wells}") + click.echo(f" Zarr channels: {zarr_channels}") + click.echo(f" Extracting: {concat_channels}") + + # 6. Create output directory + vast["output_dir"].mkdir(parents=True, exist_ok=True) + + # 7. Generate crop_concat.yml + crop_concat_cfg = generate_crop_concat_config(nfs["zarr"], wells, concat_channels, cfg.concatenate) + crop_concat_path = vast["output_dir"] / "crop_concat.yml" + write_yaml(crop_concat_cfg, crop_concat_path) + click.echo(f" Wrote: {crop_concat_path}") + + # 8. 
Generate qc_config.yml + qc_cfg = generate_qc_config(vast["zarr"], cfg.qc) + qc_config_path = vast["output_dir"] / "qc_config.yml" + write_yaml(qc_cfg, qc_config_path) + click.echo(f" Wrote: {qc_config_path}") + + # 9. Generate scripts + sbatch_override_path = None + if cfg.concatenate.sbatch_overrides: + sbatch_content = generate_sbatch_override_file(cfg.concatenate.sbatch_overrides) + sbatch_override_path = vast["output_dir"] / "sbatch_overrides.sh" + sbatch_override_path.write_text(sbatch_content) + click.echo(f" Wrote: {sbatch_override_path}") + + concat_script = generate_concatenate_script( + crop_concat_path=crop_concat_path, + vast_zarr_path=vast["zarr"], + nfs_tracking_path=nfs["tracking"], + vast_tracking_path=vast["tracking"], + conda_env=cfg.concatenate.conda_env, + sbatch_override_path=sbatch_override_path, + ) + concat_script_path = vast["output_dir"] / "01_concatenate.sh" + concat_script_path.write_text(concat_script) + click.echo(f" Wrote: {concat_script_path}") + + qc_script = generate_qc_slurm( + dataset_name=dataset_name, + vast_output_dir=vast["output_dir"], + qc_config_path=qc_config_path, + workspace_dir=cfg.workspace_dir, + slurm_cfg=cfg.slurm.qc, + ) + qc_script_path = vast["output_dir"] / "02_qc.sh" + qc_script_path.write_text(qc_script) + click.echo(f" Wrote: {qc_script_path}") + + preprocess_script = generate_preprocess_slurm( + dataset_name=dataset_name, + vast_output_dir=vast["output_dir"], + vast_zarr_path=vast["zarr"], + workspace_dir=cfg.workspace_dir, + preprocess_params=cfg.preprocess, + slurm_cfg=cfg.slurm.preprocess, + ) + preprocess_script_path = vast["output_dir"] / "03_preprocess.sh" + preprocess_script_path.write_text(preprocess_script) + click.echo(f" Wrote: {preprocess_script_path}") + + if dry_run: + click.echo("\n--dry-run: configs and scripts generated, nothing executed.") + return + + # 10. 
@prepare.command()
@click.argument("dataset_names", nargs=-1, required=True)
@click.option(
    "-c",
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to prepare config YAML.",
)
def status(dataset_names: tuple[str, ...], config_path: str) -> None:
    """Check NFS/VAST existence and version status for one or more datasets."""
    cfg = _load_prepare_config(config_path)
    # One status row per requested dataset, rendered as a markdown table.
    rows = []
    for name in dataset_names:
        rows.append(check_dataset_status(name, cfg.nfs_root, cfg.vast_root))
    click.echo(format_status_table(rows))


def main() -> None:
    """Entry point for the prepare CLI."""
    prepare()


if __name__ == "__main__":
    main()
a/applications/airtable/src/airtable_utils/registration.py +++ b/applications/airtable/src/airtable_utils/registration.py @@ -35,6 +35,10 @@ "seeding_density", "treatment_concentration_nm", "fluorescence_modality", + "microscope", + "labelfree_modality", + "treatment", + "hours_post_treatment", ) @@ -49,6 +53,7 @@ class RegisterResult: channel_names: list[str] = field(default_factory=list) pixel_size_xy_um: float | None = None pixel_size_z_um: float | None = None + template_ids_to_delete: list[str] = field(default_factory=list) def parse_position_path(position_path: Path) -> tuple[Path, str]: @@ -264,6 +269,7 @@ def format_register_summary(result: RegisterResult, dry_run: bool = False) -> st f"| created | {len(result.created)} |", f"| updated | {len(result.updated)} |", f"| unmatched | {len(result.unmatched)} |", + f"| templates_to_delete | {len(result.template_ids_to_delete)} |", f"| pixel_size_xy_um | {xy} |", f"| pixel_size_z_um | {z} |", f"| status | {status} |", @@ -421,8 +427,8 @@ def register_fovs( result = RegisterResult(dataset=dataset_name) - # Filter to directories only — glob("*/*/*") also picks up .zattrs/.zgroup files - pos_names = [p for p in pos_names if not Path(zarr_root / p).name.startswith(".")] + # Filter to directories only — glob("*/*/*") also picks up zarr.json, .zattrs, .zgroup files + pos_names = [p for p in pos_names if (zarr_root / p).is_dir()] with open_ome_zarr(str(zarr_root), mode="r") as plate: result.channel_names = plate.channel_names @@ -453,7 +459,13 @@ def register_fovs( # Resolve cell_line linked records -> registry entries -> marker rec_for_marker = fov_records.get((well_id, fov)) or well_templates.get(well_id) - if rec_for_marker is not None and rec_for_marker.cell_line: + if rec_for_marker is not None: + if not rec_for_marker.cell_line: + raise ValueError( + f"Well '{well_id}' has no cell_line set in Airtable. " + "cell_line is required for channel marker derivation — " + "fill it in the platemap before registering." 
+ ) marker_entries = [registry[rid] for rid in rec_for_marker.cell_line if rid in registry] marker_fields = derive_channel_marker(result.channel_names, marker_entries) zarr_fields.update(marker_fields) @@ -478,4 +490,11 @@ def register_fovs( } result.created.append({"fields": fields}) + # Collect well template record IDs to delete — only for wells where at least + # one FOV was created from the template in this batch. + used_wells: set[str] = {rec["fields"]["well_id"] for rec in result.created} + for well_id, template in well_templates.items(): + if well_id in used_wells and template.record_id: + result.template_ids_to_delete.append(template.record_id) + return result diff --git a/applications/airtable/src/airtable_utils/schemas.py b/applications/airtable/src/airtable_utils/schemas.py index c84dd2930..1d608178b 100644 --- a/applications/airtable/src/airtable_utils/schemas.py +++ b/applications/airtable/src/airtable_utils/schemas.py @@ -131,7 +131,7 @@ class DatasetRecord(FOVRecord): @model_validator(mode="after") def _derive_channel_names(self) -> DatasetRecord: - """Populate ``channel_names`` from ``channel_0..7_name`` fields.""" + """Populate ``channel_names`` and ``channel_markers`` from ``channel_0..7_name/marker`` fields.""" if not self.channel_names: names = [] for i in range(MAX_CHANNELS): @@ -139,6 +139,14 @@ def _derive_channel_names(self) -> DatasetRecord: if name is not None: names.append(name) self.channel_names = names + if not self.channel_markers: + markers: dict[str, str] = {} + for i in range(MAX_CHANNELS): + name = getattr(self, f"channel_{i}_name") + marker = getattr(self, f"channel_{i}_marker") + if name is not None and marker is not None: + markers[name] = marker + self.channel_markers = markers return self @classmethod @@ -191,6 +199,10 @@ def _multi_select_val(v): data_path=fields.get("data_path"), tracks_path=fields.get("tracks_path"), fluorescence_modality=_select_val(fields.get("fluorescence_modality")), + 
microscope=_select_val(fields.get("microscope")), + labelfree_modality=_select_val(fields.get("labelfree_modality")), + treatment=_select_val(fields.get("treatment")), + hours_post_treatment=fields.get("hours post treatment"), t_shape=fields.get("t_shape"), c_shape=fields.get("c_shape"), z_shape=fields.get("z_shape"), diff --git a/applications/airtable/tests/conftest.py b/applications/airtable/tests/conftest.py index 728f0016a..2a3b7fddd 100644 --- a/applications/airtable/tests/conftest.py +++ b/applications/airtable/tests/conftest.py @@ -37,6 +37,10 @@ "channel_3_marker": None, "data_path": "/hpc/datasets/alpha.zarr", "fluorescence_modality": {"name": "widefield"}, + "microscope": {"name": "mantis"}, + "labelfree_modality": {"name": "widefield"}, + "treatment": {"name": "DMSO"}, + "hours post treatment": 2.0, "t_shape": 50, "c_shape": 2, "z_shape": 30, @@ -70,6 +74,10 @@ "channel_3_marker": None, "data_path": "/hpc/datasets/beta.zarr", "fluorescence_modality": None, + "microscope": "dragonfly", + "labelfree_modality": "oblique", + "treatment": None, + "hours post treatment": None, "t_shape": 100, "c_shape": 2, "z_shape": 15, diff --git a/applications/airtable/tests/test_database.py b/applications/airtable/tests/test_database.py index 15cbb5634..42f483fba 100644 --- a/applications/airtable/tests/test_database.py +++ b/applications/airtable/tests/test_database.py @@ -22,8 +22,8 @@ def test_init_with_env_vars(self, mock_env, mock_api): AirtableDatasets() # Api was called with the fake key mock_api.assert_called_once_with("patFAKEKEY123") - # .table() was called with the fake base id and TABLE_NAME - mock_api.return_value.table.assert_called_once_with("appFAKEBASE456", "Datasets") + # .table() is called twice: once for Datasets, once for Marker Registry + mock_api.return_value.table.assert_any_call("appFAKEBASE456", "Datasets") def test_init_raises_when_api_key_missing(self, monkeypatch): """ValueError is raised when AIRTABLE_API_KEY is not set.""" @@ -183,15 +183,43 
@@ def test_dataframe_columns(self, airtable_datasets, mock_table, sample_airtable_ "seeding_density", "treatment_concentration_nm", "channel_names", + "channel_markers", *(f"channel_{i}_{attr}" for i in range(8) for attr in ("name", "marker")), "data_path", "tracks_path", "fluorescence_modality", + "microscope", + "labelfree_modality", + "treatment", + "hours_post_treatment", "t_shape", "c_shape", "z_shape", "y_shape", "x_shape", + "pixel_size_xy_um", + "pixel_size_z_um", "record_id", } assert set(df.columns) == expected_cols + + +# --------------------------------------------------------------------------- +# batch_delete +# --------------------------------------------------------------------------- + + +class TestBatchDelete: + """Test AirtableDatasets.batch_delete().""" + + def test_delegates_to_table(self, airtable_datasets, mock_table): + mock_table.batch_delete.return_value = [{"id": "rec001", "deleted": True}] + result = airtable_datasets.batch_delete(["rec001"]) + mock_table.batch_delete.assert_called_once_with(["rec001"]) + assert result == [{"id": "rec001", "deleted": True}] + + def test_passes_multiple_ids(self, airtable_datasets, mock_table): + ids = ["rec001", "rec002", "rec003"] + mock_table.batch_delete.return_value = [] + airtable_datasets.batch_delete(ids) + mock_table.batch_delete.assert_called_once_with(ids) diff --git a/applications/airtable/tests/test_register_fovs.py b/applications/airtable/tests/test_register_fovs.py index 0e9964c2f..aaddcd8e0 100644 --- a/applications/airtable/tests/test_register_fovs.py +++ b/applications/airtable/tests/test_register_fovs.py @@ -29,6 +29,7 @@ def _make_well_template(well_id: str, record_id: str | None = None, **overrides) "fov": None, "cell_type": "A549", "cell_state": "Live", + "cell_line": ["recCELLLINE1"], "marker": "TOMM20", "organelle": "mitochondria", "perturbation": "ZIKV", @@ -36,6 +37,10 @@ def _make_well_template(well_id: str, record_id: str | None = None, **overrides) "moi": 5.0, 
"time_interval_min": 30.0, "fluorescence_modality": "Light-sheet", + "microscope": "mantis", + "labelfree_modality": "widefield", + "treatment": "DMSO", + "hours_post_treatment": 2.0, "channel_0_marker": "brightfield", "channel_1_marker": "mitochondria", "record_id": record_id, @@ -51,6 +56,7 @@ def _make_fov_record(well_id: str, fov: str, record_id: str, **overrides) -> Dat "well_id": well_id, "fov": fov, "cell_type": "A549", + "cell_line": ["recCELLLINE1"], "marker": "TOMM20", "organelle": "mitochondria", "record_id": record_id, @@ -133,7 +139,10 @@ def test_creates_new_fov_records_from_well_templates(self): Path("/data/test_dataset.zarr/A/1/000000"), Path("/data/test_dataset.zarr/A/1/000001"), ] - with patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate): + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): result = register_fovs(paths, db=db) assert result.dataset == "test_dataset" @@ -159,8 +168,13 @@ def test_creates_new_fov_records_from_well_templates(self): assert rec0["organelle"] == "mitochondria" assert rec0["perturbation"] == "ZIKV" assert rec0["moi"] == 5.0 + assert rec0["microscope"] == "mantis" + assert rec0["labelfree_modality"] == "widefield" + assert rec0["treatment"] == "DMSO" + assert rec0["hours_post_treatment"] == 2.0 assert rec0["channel_0_marker"] == "brightfield" assert rec0["channel_1_marker"] == "mitochondria" + assert result.template_ids_to_delete == ["recWELL1"] def test_updates_existing_fov_records(self): """Existing per-FOV records get updated with zarr-derived fields only.""" @@ -172,7 +186,10 @@ def test_updates_existing_fov_records(self): mock_plate = _make_mock_plate(positions) paths = [Path("/data/test_dataset.zarr/A/1/000000")] - with patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate): + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + 
patch("pathlib.Path.is_dir", return_value=True), + ): result = register_fovs(paths, db=db) assert len(result.created) == 0 @@ -202,7 +219,10 @@ def test_unmatched_positions(self): Path("/data/test_dataset.zarr/A/1/000000"), Path("/data/test_dataset.zarr/B/2/000000"), ] - with patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate): + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): result = register_fovs(paths, db=db) assert len(result.created) == 1 @@ -226,7 +246,10 @@ def test_mixed_create_and_update(self): Path("/data/test_dataset.zarr/A/1/000000"), Path("/data/test_dataset.zarr/A/1/000001"), ] - with patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate): + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): result = register_fovs(paths, db=db) assert len(result.updated) == 1 @@ -259,6 +282,23 @@ def test_raises_on_mixed_zarr_stores(self): with pytest.raises(ValueError, match="same zarr store"): register_fovs(paths, db=db) + def test_raises_when_cell_line_missing(self): + """ValueError raised when a well template has no cell_line set.""" + template_no_cell_line = _make_well_template("A/1", cell_line=None) + db = MagicMock() + db.get_dataset_records.return_value = [template_no_cell_line] + + positions = {"A/1/000000": (10, 3, 1, 512, 512)} + mock_plate = _make_mock_plate(positions) + + paths = [Path("/data/test_dataset.zarr/A/1/000000")] + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): + with pytest.raises(ValueError, match="cell_line is required"): + register_fovs(paths, db=db) + def test_all_records_already_per_fov_no_templates(self): """When all records are per-FOV and no templates exist, only updates happen.""" existing = _make_fov_record("A/1", 
"000000", record_id="recFOV1") @@ -275,7 +315,10 @@ def test_all_records_already_per_fov_no_templates(self): Path("/data/test_dataset.zarr/A/1/000000"), Path("/data/test_dataset.zarr/A/1/000001"), ] - with patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate): + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): result = register_fovs(paths, db=db) assert len(result.updated) == 1 @@ -341,12 +384,112 @@ def test_copies_non_none_fields(self): assert fields["perturbation"] == "ZIKV" assert fields["moi"] == 5.0 assert fields["time_interval_min"] == 30.0 + assert fields["microscope"] == "mantis" + assert fields["labelfree_modality"] == "widefield" + assert fields["treatment"] == "DMSO" + assert fields["hours_post_treatment"] == 2.0 assert fields["channel_0_marker"] == "brightfield" assert fields["channel_1_marker"] == "mitochondria" def test_skips_none_fields(self): - template = _make_well_template("A/1", seeding_density=None, treatment_concentration_nm=None) + template = _make_well_template( + "A/1", + seeding_density=None, + treatment_concentration_nm=None, + microscope=None, + labelfree_modality=None, + ) fields = copy_well_template_fields(template) assert "seeding_density" not in fields assert "treatment_concentration_nm" not in fields + assert "microscope" not in fields + assert "labelfree_modality" not in fields + + +# --------------------------------------------------------------------------- +# template deletion tracking +# --------------------------------------------------------------------------- + + +class TestTemplateDeletion: + """Tests for template_ids_to_delete population in register_fovs.""" + + def test_template_deleted_when_fov_created(self): + """Template record ID appears in deletion list when FOVs are created from it.""" + template_a1 = _make_well_template("A/1", record_id="recWELL1") + db = MagicMock() + 
db.get_dataset_records.return_value = [template_a1] + + positions = {"A/1/000000": (10, 3, 1, 512, 512)} + mock_plate = _make_mock_plate(positions) + + paths = [Path("/data/test_dataset.zarr/A/1/000000")] + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): + result = register_fovs(paths, db=db) + + assert len(result.created) == 1 + assert result.template_ids_to_delete == ["recWELL1"] + + def test_template_not_deleted_when_all_positions_unmatched(self): + """Template with no created FOVs is not in deletion list.""" + template_a1 = _make_well_template("A/1", record_id="recWELL1") + db = MagicMock() + db.get_dataset_records.return_value = [template_a1] + + # B/2 has no template — will be unmatched + positions = {"B/2/000000": (10, 3, 1, 512, 512)} + mock_plate = _make_mock_plate(positions) + + paths = [Path("/data/test_dataset.zarr/B/2/000000")] + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): + result = register_fovs(paths, db=db) + + assert len(result.unmatched) == 1 + assert result.template_ids_to_delete == [] + + def test_only_used_templates_deleted(self): + """Only templates where at least one FOV was created appear in deletion list.""" + template_a1 = _make_well_template("A/1", record_id="recWELL_A1") + template_b2 = _make_well_template("B/2", record_id="recWELL_B2") + db = MagicMock() + db.get_dataset_records.return_value = [template_a1, template_b2] + + # A/1 gets a FOV; B/2 gets no positions in this batch + positions = {"A/1/000000": (10, 3, 1, 512, 512)} + mock_plate = _make_mock_plate(positions) + + paths = [Path("/data/test_dataset.zarr/A/1/000000")] + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): + result = register_fovs(paths, db=db) + + assert len(result.created) == 1 + 
assert result.template_ids_to_delete == ["recWELL_A1"] + + def test_template_without_record_id_not_added(self): + """Template with no record_id is skipped in deletion list.""" + template_a1 = _make_well_template("A/1", record_id=None) + db = MagicMock() + db.get_dataset_records.return_value = [template_a1] + + positions = {"A/1/000000": (10, 3, 1, 512, 512)} + mock_plate = _make_mock_plate(positions) + + paths = [Path("/data/test_dataset.zarr/A/1/000000")] + with ( + patch("airtable_utils.registration.open_ome_zarr", return_value=mock_plate), + patch("pathlib.Path.is_dir", return_value=True), + ): + result = register_fovs(paths, db=db) + + assert len(result.created) == 1 + assert result.template_ids_to_delete == [] diff --git a/applications/airtable/tests/test_schemas.py b/applications/airtable/tests/test_schemas.py index 7917af8ba..11e611355 100644 --- a/applications/airtable/tests/test_schemas.py +++ b/applications/airtable/tests/test_schemas.py @@ -164,6 +164,10 @@ def test_full_record_with_select_dicts(self, sample_airtable_records): assert rec.channel_1_marker == "Endoplasmic Reticulum" assert rec.data_path == "/hpc/datasets/alpha.zarr" assert rec.fluorescence_modality == "widefield" + assert rec.microscope == "mantis" + assert rec.labelfree_modality == "widefield" + assert rec.treatment == "DMSO" + assert rec.hours_post_treatment == 2.0 assert rec.t_shape == 50 assert rec.c_shape == 2 assert rec.z_shape == 30 @@ -181,6 +185,10 @@ def test_record_with_plain_string_fields(self, sample_airtable_records): assert rec.perturbation == "ZIKV" assert rec.moi == 0.5 assert rec.cell_line is None + assert rec.microscope == "dragonfly" + assert rec.labelfree_modality == "oblique" + assert rec.treatment is None + assert rec.hours_post_treatment is None def test_minimal_record(self): """Record with only required fields.""" diff --git a/applications/dynaclr/configs/cellanome/embed_all.sh b/applications/dynaclr/configs/cellanome/embed_all.sh new file mode 100755 index 
000000000..3671d4bd5 --- /dev/null +++ b/applications/dynaclr/configs/cellanome/embed_all.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# SLURM array job: generate DINOv3 + DynaCLR embeddings for all 5 cellanome datasets. +# Array index: 0-9 (5 datasets × 2 models) +# 0-4 → DINOv3 +# 5-9 → DynaCLR +# +# Usage: +# sbatch embed_all.sh +# # or a single task interactively: +# SLURM_ARRAY_TASK_ID=0 bash embed_all.sh + +#SBATCH --job-name=cellanome_embed +#SBATCH --array=0-9 +#SBATCH --partition=gpu +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=64G +#SBATCH --time=4:00:00 +#SBATCH --output=/hpc/mydata/eduardo.hirata/logs/cellanome_embed_%A_%a.out +#SBATCH --error=/hpc/mydata/eduardo.hirata/logs/cellanome_embed_%A_%a.err + +export PYTHONNOUSERSITE=1 + +REPO=/home/eduardo.hirata/repos/viscy +CFG_ROOT="${REPO}/applications/dynaclr/configs/cellanome" + +DATASETS=( + "20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes" + "20260211112411_P-05_R000439_FC_2026_02_11_manual_loading_mixed_GFP+RFP" + "20260220144306_P-05_R000476_FC_2026_02_20_A549_GFP_RFP_Org_Cells" + "20260310112219_P-05_R000486_FC_2026_03_10_A549_pAL27+ISG15_off_on_DENV" + "20260324133209_P-05_R000497_FC_2026_03_24_A549_SEC61B_G3BP1_pAL40_DENV_rerun" +) + +TASK=${SLURM_ARRAY_TASK_ID} +N=${#DATASETS[@]} # 5 + +DATASET_IDX=$(( TASK % N )) +MODEL_IDX=$(( TASK / N )) # 0 = DINOv3, 1 = DynaCLR + +DATASET="${DATASETS[$DATASET_IDX]}" + +if [ "$MODEL_IDX" -eq 0 ]; then + SCRIPT="${REPO}/applications/dynaclr/scripts/cellanome/embed_dinov3.py" + CONFIG="${CFG_ROOT}/${DATASET}/embed_dinov3.yml" +else + SCRIPT="${REPO}/applications/dynaclr/scripts/cellanome/embed_dynaclr.py" + CONFIG="${CFG_ROOT}/${DATASET}/embed_dynaclr.yml" +fi + +echo "Task ${TASK}: dataset=${DATASET} model_idx=${MODEL_IDX}" +echo "Config: ${CONFIG}" + +cd "${REPO}" +uv run python "${SCRIPT}" "${CONFIG}" diff --git a/applications/dynaclr/configs/cellanome/embed_dinov3.yml 
b/applications/dynaclr/configs/cellanome/embed_dinov3.yml new file mode 100644 index 000000000..0314593a6 --- /dev/null +++ b/applications/dynaclr/configs/cellanome/embed_dinov3.yml @@ -0,0 +1,38 @@ +# DINOv3 embedding extraction for cellanome dataset. +# Reads primary_analysis.csv directly, outputs cell-level anndata. +# Usage: uv run python applications/dynaclr/scripts/cellanome/embed_dinov3.py configs/cellanome/embed_dinov3.yml + +# --- Data paths --- +zarr_store: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes.zarr +analysis_base: /hpc/instruments/cm.r3200/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/image_analysis_output-12032025-143316 +transcriptome_anndata: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/anndata/seurat-bc3a-l_all.zarr +output_path: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/2-embeddings/dinov3-convnext-tiny-BF.zarr + +# --- Experiment --- +# Omit to auto-discover all scans/lanes under analysis_base. +# scan_ids: [5] +# lane_ids: [3, 4, 5, 6] + +# --- Model --- +model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m + +# --- Channels --- +channels: + - White + +# --- Crop --- +patch_size: 96 +reference_pixel_size: 1.0 +source_pixel_size: 1.0 + +# --- Filtering --- +# Dict of column_name: {min, max, eq, isin} applied to primary_analysis.csv. 
+filters: + object_class: + isin: [cell, cell-adhered] + object_radius_px: + min: 39 + +# --- Inference --- +batch_size: 128 +device: cuda diff --git a/applications/dynaclr/configs/cellanome/embed_dynaclr.yml b/applications/dynaclr/configs/cellanome/embed_dynaclr.yml new file mode 100644 index 000000000..a6023b717 --- /dev/null +++ b/applications/dynaclr/configs/cellanome/embed_dynaclr.yml @@ -0,0 +1,46 @@ +# DynaCLR embedding extraction for cellanome dataset. +# Reads primary_analysis.csv directly, outputs cell-level anndata. +# Usage: uv run python applications/dynaclr/scripts/cellanome/embed_dynaclr.py configs/cellanome/embed_dynaclr.yml + +# --- Data paths --- +zarr_store: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes.zarr +analysis_base: /hpc/instruments/cm.r3200/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/image_analysis_output-12032025-143316 +transcriptome_anndata: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/anndata/seurat-bc3a-l_all.zarr +output_path: /hpc/projects/multimodal/datasets/20251203141914_P-05_R000414_FC_BH_120325_try4_Adherent_with_SRA_training_4lanes/2-embeddings/dynaclr-2d-boc-BF.zarr + +# --- Experiment --- +# scan_ids: [5] +# lane_ids: [3, 4, 5, 6] + +# --- Model --- +ckpt_path: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/organelle_sensor_phase_maxproj_ver3_150epochs/saved_checkpoints/epoch=104-step=53760.ckpt +encoder_config: + backbone: convnext_tiny + in_channels: 1 + in_stack_depth: 1 + stem_kernel_size: [1, 4, 4] + stem_stride: [1, 4, 4] + embedding_dim: 768 + projection_dim: 32 + +# --- Channel --- +channel_name: White + +# --- Crop --- +# Trained on 160x160 at 0.149 µm/px (~23.8 µm physical). +# Cellanome is 20x at 0.247 µm/px. 
+# raw_crop = 160 * 0.149 / 0.247 = 96 px, resized to 160. +patch_size: 160 +reference_pixel_size: 0.149 +source_pixel_size: 0.247 + +# --- Filtering --- +filters: + object_class: + isin: [cell, cell-adhered] + object_radius_px: + min: 39 + +# --- Inference --- +batch_size: 128 +device: cuda diff --git a/applications/dynaclr/configs/collections/DynaCLR-2D-BagOfChannels-v3.yml b/applications/dynaclr/configs/collections/DynaCLR-2D-BagOfChannels-v3.yml index 95047fdc5..415c9238f 100644 --- a/applications/dynaclr/configs/collections/DynaCLR-2D-BagOfChannels-v3.yml +++ b/applications/dynaclr/configs/collections/DynaCLR-2D-BagOfChannels-v3.yml @@ -1,5 +1,5 @@ name: DynaCLR-2D-BagOfChannels-v3 -description: "Multi-organelle bag-of-channels DynaCLR training collection. Includes SEC61B (ER), TOMM20 (mitochondria), and G3BP1 (stress granules) experiments with ZIKV/DENV infection. All 3 channels (Phase3D, GFP, mCherry) trained jointly." +description: "[LEGACY] Multi-organelle bag-of-channels DynaCLR training collection. Includes SEC61B (ER), TOMM20 (mitochondria), and G3BP1 (stress granules) experiments with ZIKV/DENV infection. All 3 channels (Phase3D, GFP, mCherry) trained jointly." provenance: airtable_base_id: app8vqaoWyOwa0sB5 diff --git a/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-annotated.yml b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-annotated.yml new file mode 100644 index 000000000..960779080 --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-annotated.yml @@ -0,0 +1,119 @@ +name: DynaCLR-2D-MIP-BagOfChannels-Annotated +description: "Subset of DynaCLR-2D-MIP-BagOfChannels-MultiCell with available cell annotations. + Includes 2025_01_28 G3BP1 and 2025_07_24 multi-channel experiments. + Used for linear classifier evaluation. ALFI excluded." 
+datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}))" + record_ids: [] + created_at: "2026-04-08T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ── G3BP1 (stress granules) — 2025_01_28 ── + # Annotations: B/4 (uninfected), C/4 (infected) + # NOTE(review): B/2 is also listed under `infected` below — confirm B/2 annotations exist, or drop B/2 from this annotated subset + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_24 multi-channel — G3BP1, SEC61B, viral sensor, Phase3D ── + # Annotations: A/2, C/1, C/2 (TOMM20 wells B/1, B/2 not annotated — excluded) + # NOTE(review): B/1 and B/2 do appear in the viral_sensor and Phase3D entries of this file — confirm they are intentionally included there despite the exclusion note above + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells:
+ uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_viral_sensor + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_Phase3D + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 diff --git a/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-v2.yml b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-v2.yml new file mode 100644 index 000000000..27db0e2e7 --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-v2.yml @@ -0,0 +1,615 @@ +name: DynaCLR-2D-MIP-BagOfChannels-MultiCell +description: "Multi-cell-type bag-of-channels 2D DynaCLR training collection with z-reduction. Combines A549 infectomics (3D z-stacks from VAST, MIP for fluorescence / center-slice for Phase3D), microglia dynamorph (BF, Phase3D, Retardance), ALFI mitosis (DIC, U2OS/RPE-1/HeLa), and dragonfly confocal. 
All data paths point to /hpc/projects/organelle_phenotyping/datasets/." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_10_09_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_11_05_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_16_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_31_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_04_15_A549_H2B_CAAX_ZIKV_DENV\", {dataset}), SEARCH(\"2025_04_17_A549_H2B_CAAX_DENV\", {dataset}), SEARCH(\"2024_08_14_ZIKV_pal17_48h\", {dataset}), SEARCH(\"20191107_1209_1_GW23_dynamorph\", {dataset}), SEARCH(\"ALFI\", {dataset}))" + record_ids: [] + created_at: "2026-03-30T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ══════════════════════════════════════════════════════════════════════ + # A549 infectomics — 3D z-stacks on VAST (single-channel bags) + # ══════════════════════════════════════════════════════════════════════ + + # ── G3BP1 (stress granules) ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + 
perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── CAAX (membrane) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: 
${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── H2B (chromatin) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── TOMM20 (mitochondria) ── + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 
2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_TOMM20_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/1 + infected: + - B/2 + interval_minutes: 30.0 + start_hpi: 3.5 + 
marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── SEC61B (endoplasmic reticulum) ── + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 
EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Viral sensor (mCherry) ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 
2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── A549 Phase3D (label-free) ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + 
start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Dragonfly confocal — viral sensor (pAL10) ── + - name: 2024_08_14_ZIKV_pal17_48h_pAL10 + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + - name: 2024_08_14_ZIKV_pal17_48h_Phase3D + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: Phase3D + organelle: Phase3D + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + # ══════════════════════════════════════════════════════════════════════ + # Microglia dynamorph — 2D label-free (Phase3D only). + # Brightfield + Retardance dropped: same physical cells as the Phase3D + # entry, so they tripled this experiment's row count and biased + # marker/experiment sampling without adding biological signal. 
+ # ══════════════════════════════════════════════════════════════════════ + + - name: 20191107_GW23_dynamorph_Phase3D + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Phase3D + organelle: Phase3D + pixel_size_xy_um: 0.325 + + # ══════════════════════════════════════════════════════════════════════ + # ALFI — 2D DIC mitosis datasets (multiple cell types) + # ══════════════════════════════════════════════════════════════════════ + + - name: ALFI_U2OS_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/ALFI_U2OS_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI02/0 + MLN8237: + - MI01/0 + - MI03/0 + - MI04/0 + - MI05/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.1766 + + - name: ALFI_RPE1_untreated + data_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/ALFI_RPE1_untreated.zarr + tracks_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + untreated: + - MI07/0 + - MI08/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 + + - name: ALFI_HeLa_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/ALFI_HeLa_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI06/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + 
pixel_size_xy_um: 0.2631 diff --git a/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-v3.yml b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-v3.yml new file mode 100644 index 000000000..4f7ee6b68 --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels-v3.yml @@ -0,0 +1,615 @@ +name: DynaCLR-2D-MIP-BagOfChannels-v3 +description: "v3: drops dynamorph Brightfield + Retardance entries (same physical cells as Phase3D, were inflating dynamorph row count and biasing the sampler). Combines A549 infectomics (3D z-stacks from VAST, MIP for fluorescence / center-slice for Phase3D), microglia dynamorph (Phase3D only), ALFI mitosis (DIC, U2OS/RPE-1/HeLa), and dragonfly confocal. All data paths point to /hpc/projects/organelle_phenotyping/datasets/." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_10_09_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_11_05_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_16_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_31_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_04_15_A549_H2B_CAAX_ZIKV_DENV\", {dataset}), SEARCH(\"2025_04_17_A549_H2B_CAAX_DENV\", {dataset}), SEARCH(\"2024_08_14_ZIKV_pal17_48h\", {dataset}), SEARCH(\"20191107_1209_1_GW23_dynamorph\", {dataset}), SEARCH(\"ALFI\", {dataset}))" + record_ids: [] + created_at: "2026-03-30T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ══════════════════════════════════════════════════════════════════════ + # A549 infectomics — 3D z-stacks on VAST (single-channel bags) + # ══════════════════════════════════════════════════════════════════════ + + # ── G3BP1 (stress granules) ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: 
${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + 
pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── CAAX (membrane) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── H2B (chromatin) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + 
DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── TOMM20 (mitochondria) ── + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: 
${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_TOMM20_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/1 + infected: + - B/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── SEC61B (endoplasmic reticulum) ── + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - 
name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Viral sensor (mCherry) ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: 
+ - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── A549 Phase3D (label-free) ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: 
${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Dragonfly confocal — viral sensor (pAL10) ── + - name: 2024_08_14_ZIKV_pal17_48h_pAL10 + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + - name: 2024_08_14_ZIKV_pal17_48h_Phase3D + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: Phase3D + 
organelle: Phase3D + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + # ══════════════════════════════════════════════════════════════════════ + # Microglia dynamorph — 2D label-free (Phase3D only). + # Brightfield + Retardance dropped: same physical cells as the Phase3D + # entry, so they tripled this experiment's row count and biased + # marker/experiment sampling without adding biological signal. + # ══════════════════════════════════════════════════════════════════════ + + - name: 20191107_GW23_dynamorph_Phase3D + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Phase3D + organelle: Phase3D + pixel_size_xy_um: 0.325 + + # ══════════════════════════════════════════════════════════════════════ + # ALFI — 2D DIC mitosis datasets (multiple cell types) + # ══════════════════════════════════════════════════════════════════════ + + - name: ALFI_U2OS_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/ALFI_U2OS_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI02/0 + MLN8237: + - MI01/0 + - MI03/0 + - MI04/0 + - MI05/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.1766 + + - name: ALFI_RPE1_untreated + data_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/ALFI_RPE1_untreated.zarr + tracks_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + untreated: + - MI07/0 + - MI08/0 + interval_minutes: 7.0 + start_hpi: 0.0 + 
marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 + + - name: ALFI_HeLa_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/ALFI_HeLa_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI06/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 diff --git a/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels.yml b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels.yml new file mode 100644 index 000000000..fb52e3f1e --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-2D-MIP-BagOfChannels.yml @@ -0,0 +1,658 @@ +name: DynaCLR-2D-MIP-BagOfChannels-MultiCell +description: "Multi-cell-type bag-of-channels 2D DynaCLR training collection with z-reduction. Combines A549 infectomics (3D z-stacks from VAST, MIP for fluorescence / center-slice for Phase3D), microglia dynamorph (BF, Phase3D, Retardance), ALFI mitosis (DIC, U2OS/RPE-1/HeLa), and dragonfly confocal. All data paths point to /hpc/projects/organelle_phenotyping/datasets/." 
+datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_10_09_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_11_05_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_16_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_31_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_04_15_A549_H2B_CAAX_ZIKV_DENV\", {dataset}), SEARCH(\"2025_04_17_A549_H2B_CAAX_DENV\", {dataset}), SEARCH(\"2024_08_14_ZIKV_pal17_48h\", {dataset}), SEARCH(\"20191107_1209_1_GW23_dynamorph\", {dataset}), SEARCH(\"ALFI\", {dataset}))" + record_ids: [] + created_at: "2026-03-30T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ══════════════════════════════════════════════════════════════════════ + # A549 infectomics — 3D z-stacks on VAST (single-channel bags) + # ══════════════════════════════════════════════════════════════════════ + + # ── G3BP1 (stress granules) ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + 
interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── CAAX (membrane) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw 
mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── H2B (chromatin) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── TOMM20 (mitochondria) ── + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: 
${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_TOMM20_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/1 + infected: + - B/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + 
pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── SEC61B (endoplasmic reticulum) ── + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + 
uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Viral sensor (mCherry) ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: 
${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── A549 Phase3D (label-free) ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + 
moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Dragonfly confocal — viral sensor (pAL10) ── + - name: 2024_08_14_ZIKV_pal17_48h_pAL10 + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + - name: 2024_08_14_ZIKV_pal17_48h_Phase3D + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: Phase3D + organelle: Phase3D + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + # ══════════════════════════════════════════════════════════════════════ + # Microglia dynamorph — 2D label-free (BF, Phase3D, Retardance) + # ══════════════════════════════════════════════════════════════════════ + + - name: 20191107_GW23_dynamorph_Brightfield + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Brightfield + marker: Brightfield + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Brightfield + organelle: Brightfield + pixel_size_xy_um: 0.325 + + - name: 
20191107_GW23_dynamorph_Phase3D + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Phase3D + organelle: Phase3D + pixel_size_xy_um: 0.325 + + - name: 20191107_GW23_dynamorph_Retardance + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Retardance + marker: Retardance + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Retardance + organelle: Retardance + pixel_size_xy_um: 0.325 + + # ══════════════════════════════════════════════════════════════════════ + # ALFI — 2D DIC mitosis datasets (multiple cell types) + # ══════════════════════════════════════════════════════════════════════ + + - name: ALFI_U2OS_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/ALFI_U2OS_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI02/0 + MLN8237: + - MI01/0 + - MI03/0 + - MI04/0 + - MI05/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.1766 + + - name: ALFI_RPE1_untreated + data_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/ALFI_RPE1_untreated.zarr + tracks_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + untreated: + - MI07/0 + - MI08/0 + 
interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 + + - name: ALFI_HeLa_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/ALFI_HeLa_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI06/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 diff --git a/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml index d15d22bc6..c71b97a79 100644 --- a/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml +++ b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml @@ -1,5 +1,6 @@ name: DynaCLR-3D-BagOfChannels-v2 description: "Multi-organelle bag-of-channels 3D DynaCLR training collection. Each experiment entry is a single-channel bag: H2B (chromatin), CAAX (membrane), TOMM20 (mitochondria), SEC61B (ER), G3BP1 (stress granules), viral sensor (mCherry/pAL10), and Phase3D (label-free). Includes dragonfly confocal (2024_08_14_ZIKV_pal17_48h) for cross-microscope training. All data paths point to VAST (zarr v3, rechunked)." 
+datasets_root: /hpc/projects/organelle_phenotyping provenance: airtable_base_id: app8vqaoWyOwa0sB5 @@ -11,8 +12,8 @@ provenance: experiments: # ── G3BP1 (stress granules) ── - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: G3BP1 @@ -30,9 +31,49 @@ experiments: pixel_size_xy_um: 0.1494 pixel_size_z_um: 0.174 + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + - name: 2025_07_24_A549_G3BP1_ZIKV - data_path: 
/hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: G3BP1 @@ -51,8 +92,8 @@ experiments: # ── CAAX (membrane) ── - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr channels: - name: raw mCherry EX561 EM600-37 marker: CAAX @@ -70,8 +111,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr channels: - name: raw mCherry EX561 EM600-37 marker: CAAX @@ -90,8 +131,8 @@ experiments: # ── H2B (chromatin) ── - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr - tracks_path: 
/hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr channels: - name: raw Cy5 EX639 EM698-70 marker: HIST2H2BE @@ -109,8 +150,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr channels: - name: raw Cy5 EX639 EM698-70 marker: HIST2H2BE @@ -129,8 +170,8 @@ experiments: # ── TOMM20 (mitochondria) ── - name: 2024_10_09_A549_TOMM20_ZIKV_DENV - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: TOMM20 @@ -147,9 +188,47 @@ experiments: pixel_size_xy_um: 0.1494 pixel_size_z_um: 0.174 + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + 
interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: TOMM20 @@ -167,8 +246,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2025_07_24_A549_TOMM20_ZIKV - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: TOMM20 @@ -187,8 +266,8 @@ experiments: # ── SEC61B (endoplasmic reticulum) ── - name: 2024_10_16_A549_SEC61_ZIKV_DENV - data_path: 
/hpc/projects/organelle_phenotyping/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: SEC61B @@ -205,9 +284,47 @@ experiments: pixel_size_xy_um: 0.1494 pixel_size_z_um: 0.174 + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + 
tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: SEC61B @@ -225,8 +342,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2025_07_24_A549_SEC61_ZIKV - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr channels: - name: raw GFP EX488 EM525-45 marker: SEC61B @@ -244,9 +361,9 @@ experiments: pixel_size_z_um: 0.174 # ── Viral sensor (mCherry) ── - - name: 2025_07_24_A549_viral_sensor - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr channels: - name: raw mCherry EX561 EM600-37 marker: viral_sensor @@ -267,8 +384,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: 
${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr channels: - name: raw mCherry EX561 EM600-37 marker: viral_sensor @@ -286,8 +403,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr channels: - name: raw mCherry EX561 EM600-37 marker: viral_sensor @@ -305,9 +422,9 @@ experiments: pixel_size_z_um: 0.174 # ── Phase3D (label-free) ── - - name: 2025_07_24_A549_Phase3D - data_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr channels: - name: Phase3D marker: Phase3D @@ -328,8 +445,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr channels: - name: Phase3D marker: Phase3D 
@@ -347,8 +464,8 @@ experiments: pixel_size_z_um: 0.174 - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr - tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr channels: - name: Phase3D marker: Phase3D @@ -367,8 +484,8 @@ experiments: # ── Dragonfly confocal — viral sensor (pAL10) ── - name: 2024_08_14_ZIKV_pal17_48h_pAL10 - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr - tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_08_14_ZIKV_pal17_48h/5-tracking/2024_08_14_ZIKV_pal17_48h.zarr + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr channels: - name: MultiCam_GFP_BF marker: pAL10 @@ -389,8 +506,8 @@ experiments: # ── Dragonfly confocal — Phase3D (label-free) ── - name: 2024_08_14_ZIKV_pal17_48h_Phase3D - data_path: /hpc/projects/organelle_phenotyping/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr - tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_08_14_ZIKV_pal17_48h/5-tracking/2024_08_14_ZIKV_pal17_48h.zarr + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr channels: - name: Phase3D marker: Phase3D diff --git a/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v3.yml b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v3.yml new file mode 100644 index 
000000000..fc53aab45 --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v3.yml @@ -0,0 +1,527 @@ +name: DynaCLR-3D-BagOfChannels-v3 +description: "Multi-organelle bag-of-channels 3D DynaCLR training collection. Each experiment entry is a single-channel bag: H2B (chromatin), CAAX (membrane), TOMM20 (mitochondria), SEC61B (ER), G3BP1 (stress granules), viral sensor (mCherry/pAL10), and Phase3D (label-free). Includes dragonfly confocal (2024_08_14_ZIKV_pal17_48h) for cross-microscope training. All data paths point to VAST (zarr v3, rechunked)." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_10_09_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_11_05_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_16_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_31_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_04_15_A549_H2B_CAAX_ZIKV_DENV\", {dataset}), SEARCH(\"2025_04_17_A549_H2B_CAAX_DENV\", {dataset}), SEARCH(\"2024_08_14_ZIKV_pal17_48h\", {dataset}))" + record_ids: [] + created_at: "2026-04-10T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ── G3BP1 (stress granules) ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: 
${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── CAAX (membrane) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: CAAX + organelle: 
membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── H2B (chromatin) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── TOMM20 (mitochondria) ── + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + 
uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_TOMM20_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: 
${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/1 + infected: + - B/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── SEC61B (endoplasmic reticulum) ── + - name: 2024_10_16_A549_SEC61_ZIKV_DENV + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 
2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Viral sensor (mCherry) ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + 
perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Phase3D (label-free) ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D + data_path: 
${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Dragonfly confocal — viral sensor (pAL10) ── + - name: 2024_08_14_ZIKV_pal17_48h_pAL10 + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + # ── Dragonfly confocal — Phase3D (label-free) ── + - name: 2024_08_14_ZIKV_pal17_48h_Phase3D + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: Phase3D + organelle: label_free + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 diff --git a/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v4.yml b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v4.yml new file mode 100644 index 000000000..23787e77d --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v4.yml @@ -0,0 +1,527 @@ +name: DynaCLR-3D-BagOfChannels-v4 +description: 
"Multi-organelle bag-of-channels 3D DynaCLR training collection. Each experiment entry is a single-channel bag: H2B (chromatin), CAAX (membrane), TOMM20 (mitochondria), SEC61B (ER), G3BP1 (stress granules), viral sensor (mCherry/pAL10), and Phase3D (label-free). Includes dragonfly confocal (2024_08_14_ZIKV_pal17_48h) for cross-microscope training. All data paths point to VAST (zarr v3, rechunked)." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_10_09_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_11_05_A549_TOMM20_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_16_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2024_10_31_A549_SEC61_ZIKV_DENV\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_04_15_A549_H2B_CAAX_ZIKV_DENV\", {dataset}), SEARCH(\"2025_04_17_A549_H2B_CAAX_DENV\", {dataset}), SEARCH(\"2024_08_14_ZIKV_pal17_48h\", {dataset}))" + record_ids: [] + created_at: "2026-03-27T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ── G3BP1 (stress granules) ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + 
marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── CAAX (membrane) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_CAAX + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + 
tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: CAAX + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: CAAX + organelle: membrane + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── H2B (chromatin) ── + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_17_A549_H2B_CAAX_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/2025_04_17_A549_H2B_CAAX_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_17_A549_H2B_CAAX_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 10.0 + start_hpi: 3.0 + marker: HIST2H2BE + organelle: chromatin + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── TOMM20 (mitochondria) ── + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 
2024_10_09_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_09_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/2024_10_09_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_09_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/4 + infected: + - B/4 + interval_minutes: 30.0 + start_hpi: 5.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_TOMM20 + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_TOMM20_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: TOMM20 + perturbation_wells: + uninfected: + - B/1 + infected: + - B/2 + interval_minutes: 30.0 + start_hpi: 3.5 + 
marker: TOMM20 + organelle: mitochondria + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── SEC61B (endoplasmic reticulum) ── + - name: 2024_10_16_A549_SEC61_ZIKV_DENV + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_16_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_16_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_SEC61 + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 
EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Viral sensor (mCherry) ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 
2024_10_31_A549_SEC61_ZIKV_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Phase3D (label-free) ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_05_A549_TOMM20_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/2024_11_05_A549_TOMM20_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_05_A549_TOMM20_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.5 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_10_31_A549_SEC61_ZIKV_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/2024_10_31_A549_SEC61_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_10_31_A549_SEC61_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + 
start_hpi: 4.0 + marker: Phase3D + organelle: label_free + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── Dragonfly confocal — viral sensor (pAL10) ── + - name: 2024_08_14_ZIKV_pal17_48h_pAL10 + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 + + # ── Dragonfly confocal — Phase3D (label-free) ── + - name: 2024_08_14_ZIKV_pal17_48h_Phase3D + data_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: ${datasets_root}/datasets/2024_08_14_ZIKV_pal17_48h/tracking.zarr/2024_08_14_ZIKV_pal17_48h.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + - "0/5" + - "0/6" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: Phase3D + organelle: label_free + moi: 1.0 + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878 diff --git a/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1-test.yml b/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1-test.yml new file mode 100644 index 000000000..0b03d5401 --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1-test.yml @@ -0,0 +1,174 @@ +name: DynaCLR-BoC-lc-evaluation-v1-test +description: "Minimal subset of DynaCLR-BoC-lc-evaluation-v1 for fast end-to-end + testing of MMD and linear classifier evaluation. Three markers (G3BP1, Phase3D, + viral_sensor) across two dates (2025_07_22 and 2025_07_24) + one G3BP1-only + experiment (2025_01_28). 
Enables cross-experiment MMD for all three markers." +datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}))" + record_ids: [] + created_at: "2026-04-09T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ── 2025_01_28: G3BP1, ZIKV + DENV — 1 uninfected + 1 infected well ── + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_01_28: Phase3D, ZIKV + DENV — 1 uninfected + 1 infected well ── + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/4 + infected: + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_24: G3BP1, ZIKV — 1 uninfected + 1 infected well ── + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 
+ marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_24: Phase3D, ZIKV — 1 uninfected + 1 infected well ── + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_22: G3BP1, ZIKV — 1 uninfected (C/1) + 1 infected (C/2) ── + - name: 2025_07_22_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_22: Phase3D, ZIKV — 1 uninfected (C/1) + 1 infected (C/2) ── + - name: 2025_07_22_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + 
pixel_size_z_um: 0.174 + + # ── 2025_07_22: viral_sensor, ZIKV — 1 uninfected (C/1) + 1 infected (C/2) ── + - name: 2025_07_22_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_24: viral_sensor, ZIKV — 1 uninfected + 1 infected well ── + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 diff --git a/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1.yml b/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1.yml new file mode 100644 index 000000000..f96572fd6 --- /dev/null +++ b/applications/dynaclr/configs/collections/DynaCLR-BoC-lc-evaluation-v1.yml @@ -0,0 +1,401 @@ +name: DynaCLR-BoC-lc-evaluation-v1 +description: "Annotated experiments for linear classifier evaluation of bag-of-channels DynaCLR models. + Includes all datasets with infection_state / cell_division_state annotations and processed zarr stores." 
+datasets_root: /hpc/projects/organelle_phenotyping + +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2024_11_07\", {dataset}), SEARCH(\"2025_01_24\", {dataset}), SEARCH(\"2025_01_28_A549_G3BP1_ZIKV_DENV\", {dataset}), SEARCH(\"2025_07_22\", {dataset}), SEARCH(\"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV\", {dataset}), SEARCH(\"2025_08_26\", {dataset}))" + record_ids: [] + created_at: "2026-04-09T00:00:00" + created_by: "eduardo.hirata" + +experiments: + # ── 2025_01_28: G3BP1 (stress granules), ZIKV + DENV ── + # Annotated wells: B/4 (uninfected), C/4 (infected) + - name: 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_viral_sensor_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_28_A549_Phase3D_ZIKV_DENV + data_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_28_A549_G3BP1_ZIKV_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - 
B/4 + infected: + - B/2 + - C/4 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_24: multi-channel (G3BP1, SEC61B, viral sensor, Phase3D), ZIKV ── + # Annotated wells: A/2 (infected), C/1 (uninfected), C/2 (infected) + # TOMM20 wells B/1, B/2 not annotated — excluded from this collection + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/1 + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + 
moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_24_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/1 + - C/1 + - B/1 + infected: + - C/2 + - B/2 + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2024_11_07: SEC61B (ER), DENV ── + # Annotated wells: B/3 (uninfected), C/2 (infected+uninfected) + - name: 2024_11_07_A549_SEC61_DENV_SEC61B + data_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - B/3 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_07_A549_SEC61_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/3 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2024_11_07_A549_SEC61_DENV_Phase3D + data_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: ${datasets_root}/datasets/2024_11_07_A549_SEC61_DENV/tracking.zarr + channels: + - name: 
Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/3 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_01_24: G3BP1 (stress granules), DENV ── + # Annotated wells: B/1 (uninfected), B/2 (infected), B/3 (uninfected), C/2 (infected) + - name: 2025_01_24_A549_G3BP1_DENV_G3BP1 + data_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - B/1 + - B/3 + infected: + - B/2 + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_24_A549_G3BP1_DENV_viral_sensor + data_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - B/3 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_01_24_A549_G3BP1_DENV_Phase3D + data_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_01_24_A549_G3BP1_DENV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - B/1 + - B/3 + infected: + - B/2 + - C/2 + interval_minutes: 30.0 + start_hpi: 4.0 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_07_22: G3BP1 (stress granules) + pAL17 (viral sensor), ZIKV ── + # 
Annotated wells: C/1 (uninfected), C/2 (infected) + - name: 2025_07_22_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_22_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_07_22_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + # ── 2025_08_26: SEC61B (ER), ZIKV ── + # Annotated wells: A/1 (uninfected), B/1 (infected+uninfected) + - name: 2025_08_26_A549_SEC61_ZIKV + data_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/tracking.zarr + 
channels: + - name: raw GFP EX488 EM525-45 + marker: SEC61B + perturbation_wells: + uninfected: + - A/1 + infected: + - B/1 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_08_26_A549_viral_sensor_ZIKV + data_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/tracking.zarr + channels: + - name: raw mCherry EX561 EM600-37 + marker: viral_sensor + perturbation_wells: + uninfected: + - A/1 + infected: + - B/1 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: viral_sensor + organelle: viral_sensor + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_08_26_A549_Phase3D_ZIKV + data_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV.zarr + tracks_path: ${datasets_root}/datasets/2025_08_26_A549_SEC61_TOMM20_ZIKV/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + uninfected: + - A/1 + infected: + - B/1 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: Phase3D + organelle: Phase3D + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 diff --git a/applications/dynaclr/configs/collections/alfi-eval.yml b/applications/dynaclr/configs/collections/alfi-eval.yml new file mode 100644 index 000000000..3b66830f8 --- /dev/null +++ b/applications/dynaclr/configs/collections/alfi-eval.yml @@ -0,0 +1,55 @@ +name: alfi-eval +description: "ALFI mitosis evaluation collection. All 3 cell lines (HeLa MI06, RPE1 MI07/MI08, U2OS MI01-MI05), DIC channel. Analysis done per cell line." 
+datasets_root: /hpc/projects/organelle_phenotyping + +experiments: + - name: ALFI_HeLa_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/ALFI_HeLa_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_HeLa_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI06/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 + + - name: ALFI_RPE1_untreated + data_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/ALFI_RPE1_untreated.zarr + tracks_path: ${datasets_root}/datasets/ALFI_RPE1_untreated/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + untreated: + - MI07/0 + - MI08/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.2631 + + - name: ALFI_U2OS_DMSO_MLN8237 + data_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/ALFI_U2OS_DMSO_MLN8237.zarr + tracks_path: ${datasets_root}/datasets/ALFI_U2OS_DMSO_MLN8237/tracking.zarr + channels: + - name: DIC + marker: DIC + perturbation_wells: + DMSO: + - MI02/0 + MLN8237: + - MI01/0 + - MI03/0 + - MI04/0 + - MI05/0 + interval_minutes: 7.0 + start_hpi: 0.0 + marker: DIC + organelle: DIC + pixel_size_xy_um: 0.1766 diff --git a/applications/dynaclr/configs/collections/benchmark_2exp.yml b/applications/dynaclr/configs/collections/benchmark_2exp.yml new file mode 100644 index 000000000..eeada4a1c --- /dev/null +++ b/applications/dynaclr/configs/collections/benchmark_2exp.yml @@ -0,0 +1,36 @@ +name: benchmark_2exp +description: "Benchmark collection: G3BP1 (2025_07_24) + H2B (2025_04_15) for dataloader profiling" +datasets_root: /hpc/projects/organelle_phenotyping + +experiments: + - name: 2025_07_24_A549_G3BP1_ZIKV + data_path: ${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: 
${datasets_root}/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr + channels: + - name: raw GFP EX488 EM525-45 + marker: G3BP1 + perturbation_wells: + uninfected: + - C/1 + infected: + - C/2 + interval_minutes: 30.0 + start_hpi: 3.5 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + + - name: 2025_04_15_A549_H2B_CAAX_ZIKV_DENV_H2B + data_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr + tracks_path: ${datasets_root}/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr + channels: + - name: raw Cy5 EX639 EM698-70 + marker: HIST2H2BE + perturbation_wells: + uninfected: + - B/1 + DENV: + - B/2 + interval_minutes: 30.0 + start_hpi: 4.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 diff --git a/applications/dynaclr/configs/collections/example_mantis_dragonfly.yml b/applications/dynaclr/configs/collections/example_mantis_dragonfly.yml new file mode 100644 index 000000000..97953cea3 --- /dev/null +++ b/applications/dynaclr/configs/collections/example_mantis_dragonfly.yml @@ -0,0 +1,54 @@ +name: example_mantis_dragonfly +description: "Example collection combining mantis (lightsheet) and dragonfly (confocal) datasets. SEC61B from 2025_07_24 ZIKV experiment and pAL10 viral sensor from 2024_08_14 ZIKV experiment." 
+ +provenance: + airtable_base_id: app8vqaoWyOwa0sB5 + airtable_query: "OR(SEARCH(\"2025_07_24\", {dataset}), SEARCH(\"2024_08_14\", {dataset}))" + record_ids: [] + created_at: "2026-04-01T00:00:00" + created_by: "eduardo.hirata" + +experiments: + - name: 2025_07_24_A549_SEC61_ZIKV + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/1-preprocess/label-free/3-track/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_cropped.zarr + channels: + - name: Phase3D + marker: Phase3D + - name: GFP EX488 EM525-45 + marker: SEC61B + - name: mCherry EX561 EM600-37 + marker: mCherry + perturbation_wells: + ZIKV: + - A/2 + interval_minutes: 30.0 + start_hpi: 3.5 + marker: SEC61B + organelle: endoplasmic_reticulum + microscope: mantis + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 + moi: 5.0 + + - name: 2024_08_14_ZIKV_pal17_48h + data_path: /hpc/projects/organelle_phenotyping/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_sharded.zarr + tracks_path: /hpc/projects/organelle_phenotyping/datasets/2024_08_14_ZIKV_pal17_48h/2024_08_14_ZIKV_pal17_48h_timeaware_tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + - name: MultiCam_GFP_BF + marker: pAL10 + perturbation_wells: + uninfected: + - "0/3" + ZIKV: + - "0/4" + interval_minutes: 30.0 + start_hpi: 3.0 + marker: pAL10 + organelle: viral_sensor + microscope: dragonfly + pixel_size_xy_um: 0.206 + pixel_size_z_um: 0.2878069639205931 + moi: 5.0 diff --git a/applications/dynaclr/configs/collections/microglia-eval.yml b/applications/dynaclr/configs/collections/microglia-eval.yml new file mode 100644 index 000000000..db2c13f00 --- /dev/null +++ b/applications/dynaclr/configs/collections/microglia-eval.yml @@ -0,0 +1,73 @@ +name: microglia-eval +description: "Microglia dynamorph 
evaluation collection. All 3 label-free channels (Brightfield, Phase3D, Retardance), all 5 perturbation conditions." +datasets_root: /hpc/projects/organelle_phenotyping + +experiments: + - name: 20191107_GW23_dynamorph_Brightfield + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Brightfield + marker: Brightfield + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Brightfield + organelle: Brightfield + pixel_size_xy_um: 0.325 + + - name: 20191107_GW23_dynamorph_Phase3D + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Phase3D + marker: Phase3D + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Phase3D + organelle: Phase3D + pixel_size_xy_um: 0.325 + + - name: 20191107_GW23_dynamorph_Retardance + data_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/20191107_1209_1_GW23_dynamorph.zarr + tracks_path: ${datasets_root}/datasets/20191107_1209_1_GW23_dynamorph/tracking.zarr + channels: + - name: Retardance + marker: Retardance + perturbation_wells: + untreated: + - C/5 + IL17: + - B/4 + IFN-beta: + - B/5 + Rubella: + - C/4 + Glioblastoma_supernatant: + - B/2 + interval_minutes: 9.0 + start_hpi: 0.0 + marker: Retardance + organelle: Retardance + pixel_size_xy_um: 0.325 diff --git a/applications/dynaclr/configs/dimensionality_reduction/multi-dataset-dim-reduction.yml b/applications/dynaclr/configs/dimensionality_reduction/multi-dataset-dim-reduction.yml new file mode 100644 
index 000000000..878087e3f --- /dev/null +++ b/applications/dynaclr/configs/dimensionality_reduction/multi-dataset-dim-reduction.yml @@ -0,0 +1,31 @@ +datasets: + "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV": + hcs_plate: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_2.zarr + anndata: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3/timeaware_phase_160patch_104ckpt.zarr + "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV": + hcs_plate: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr + anndata: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3/timeaware_phase_160patch_104ckpt.zarr + +# Usage: +# dynaclr combined-dim-reduction -c multi-dataset-dim-reduction.yml +# +# Notes: +# - `datasets[*].anndata` are the AnnData zarrs that will be concatenated to fit the joint reductions. +# - Remove any method section (pca/umap/phate) to skip computing it. +reduce_combined: + overwrite_keys: false + + # PCA configuration (remove this section to skip PCA) + pca: + # Number of components. null = keep all components. 
+ n_components: 32 + normalize_features: true + + # PHATE configuration (remove this section to skip PHATE) + phate: + n_components: 2 + knn: 5 + decay: 40 + scale_embeddings: true + random_state: 42 + n_jobs: -1 diff --git a/applications/dynaclr/configs/evaluation/DINOv3-frozen/infectomics-annotated.yaml b/applications/dynaclr/configs/evaluation/DINOv3-frozen/infectomics-annotated.yaml new file mode 100644 index 000000000..84a28ca70 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DINOv3-frozen/infectomics-annotated.yaml @@ -0,0 +1,28 @@ +# Evaluation config: DINOv3-frozen × infectomics-annotated (Wave 1, baseline). +# +# DINOv3-frozen uses the raw HuggingFace DINOv3 convnext-tiny weights with no +# contrastive fine-tuning and no trainable projection. Tests whether +# DynaCLR's contrastive training adds anything beyond pre-trained DINOv3 +# features alone. +# +# `ckpt_path: null` — DINOv3 weights are loaded from HuggingFace inside +# `DINOv3Model.__init__`. The eval orchestrator omits `ckpt_path` from the +# generated predict YAML so Lightning skips checkpoint restoration. +# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DINOv3-frozen/run_infectomics_annotated.sh + +base: + - ../recipes/predict.yml + - ../recipes/reduce.yml + - ../recipes/plot_infectomics.yml + - ../recipes/linear_classifiers_infectomics.yml + - ../recipes/infectomics-annotated.yml + +training_config: /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/evaluation/DINOv3-frozen/training_config_dinov3_frozen.yaml +ckpt_path: null +output_dir: /hpc/projects/organelle_phenotyping/models/DINOv3-frozen/evaluations/infectomics-annotated/ + +# Publish to its own sub-registry so the LC bundle is identifiable. 
+linear_classifiers: + publish_dir: /hpc/projects/organelle_phenotyping/models/linear_classifiers/DINOv3-frozen/infectomics/ diff --git a/applications/dynaclr/configs/evaluation/DINOv3-frozen/run_infectomics_annotated.sh b/applications/dynaclr/configs/evaluation/DINOv3-frozen/run_infectomics_annotated.sh new file mode 100644 index 000000000..60939fe13 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DINOv3-frozen/run_infectomics_annotated.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Wave-1 evaluation: DINOv3-frozen × infectomics-annotated. +# +# Baseline that pulls raw DINOv3 convnext-tiny weights from HuggingFace — +# no contrastive fine-tuning, no projection head. Tests whether DynaCLR's +# training adds value beyond pre-trained DINOv3 features. +# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DINOv3-frozen/run_infectomics_annotated.sh + +#SBATCH --job-name=eval_w1_dinov3_frozen_infectomics +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --partition=cpu +#SBATCH --time=1-00:00:00 +#SBATCH --output=/hpc/projects/organelle_phenotyping/models/DINOv3-frozen/evaluations/infectomics-annotated/nextflow_logs/%x-%j.out + +export PYTHONNOUSERSITE=1 + +module load nextflow/24.10.5 + +WORKSPACE="/hpc/mydata/eduardo.hirata/repos/viscy" +CONFIG="$WORKSPACE/applications/dynaclr/configs/evaluation/DINOv3-frozen/infectomics-annotated.yaml" +LOGDIR="/hpc/projects/organelle_phenotyping/models/DINOv3-frozen/evaluations/infectomics-annotated/nextflow_logs" + +mkdir -p "$LOGDIR" +cd "$LOGDIR" + +nextflow run "$WORKSPACE/applications/dynaclr/nextflow/main.nf" -entry evaluation --eval_config "$CONFIG" --workspace_dir "$WORKSPACE" -resume diff --git a/applications/dynaclr/configs/evaluation/DINOv3-frozen/training_config_dinov3_frozen.yaml b/applications/dynaclr/configs/evaluation/DINOv3-frozen/training_config_dinov3_frozen.yaml new file mode 100644 index 000000000..bd1310b5d --- /dev/null +++ 
b/applications/dynaclr/configs/evaluation/DINOv3-frozen/training_config_dinov3_frozen.yaml @@ -0,0 +1,187 @@ +# Synthetic "training_config" for DINOv3-frozen evaluation baseline. +# +# DINOv3-frozen uses raw HuggingFace DINOv3 weights — no contrastive +# fine-tuning, no trainable projection head. Embeddings are the convnext-tiny +# backbone output (768-dim). +# +# This file is consumed by `dynaclr prepare-eval-configs` to generate a +# Lightning predict YAML. There is no Lightning checkpoint — the leaf YAML +# sets `ckpt_path: null`, so the orchestrator omits `ckpt_path` from the +# predict YAML, and DINOv3Model.__init__ pulls weights from HuggingFace. +# +# Mirrors `DINOv3-temporal-MLP-2D-BagOfChannels/v1/config_updated.yaml` +# byte-for-byte except `model.init_args.encoder.init_args.projection: null`. +# Same backbone, same data preprocessing → apples-to-apples comparison +# isolating the contribution of the trained MLP projection head. + +# lightning.pytorch==2.6.1 +seed_everything: 42 +trainer: + accelerator: gpu + strategy: auto + devices: 1 + num_nodes: 1 + precision: 32-true + fast_dev_run: false + max_epochs: 100 + log_every_n_steps: 10 + enable_checkpointing: false + accumulate_grad_batches: 1 + inference_mode: true +model: + class_path: dynaclr.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy_models.foundation.dinov3.DINOv3Model + init_args: + model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m + freeze: true + projection: null + loss_function: + class_path: pytorch_metric_learning.losses.NTXentLoss + init_args: + temperature: 0.5 + embedding_regularizer: null + embedding_reg_weight: 1 + reducer: null + distance: null + collect_stats: null + lr: 0.0001 + schedule: Constant + log_batches_per_epoch: 3 + log_samples_per_batch: 3 + log_embeddings_every_n_epochs: 10 + pca_color_keys: condition + log_negative_metrics_every_n_epochs: 2 + example_input_array_shape: + - 1 + - 1 + - 1 + - 160 + - 160 + ckpt_path: null + freeze_backbone: 
false + projection: null +data: + class_path: dynaclr.data.datamodule.MultiExperimentDataModule + init_args: + cell_index_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/models/DINOv3-temporal-MLP-2D-BagOfChannels/DynaCLR-2D-BagOfChannels-v3.parquet + z_window: 1 + yx_patch_size: + - 256 + - 256 + final_yx_patch_size: + - 160 + - 160 + val_experiments: [] + split_ratio: 0.8 + tau_range: + - 0.5 + - 2.0 + tau_decay_rate: 0.0 + batch_size: 256 + num_workers: 1 + stratify_by: + - condition + leaky: 0.0 + temporal_enrichment: false + temporal_window_hours: 2.0 + temporal_global_fraction: 0.3 + channels_per_sample: 1 + channel_dropout_channels: null + channel_dropout_prob: 0.0 + normalizations: + - class_path: viscy_transforms.ScaleIntensityRangePercentilesd + init_args: + keys: + - channel_0 + lower: 50.0 + upper: 99.0 + b_min: 0.0 + b_max: 1.0 + clip: false + relative: false + channel_wise: false + allow_missing_keys: false + augmentations: + - class_path: viscy_transforms.BatchedRandAffined + init_args: + keys: + - channel_0 + prob: 0.8 + rotate_range: + - 3.14 + - 0.0 + - 0.0 + shear_range: + - 0.05 + - 0.05 + - 0.0 + - 0.05 + - 0.0 + - 0.05 + translate_range: null + scale_range: + - - 0.8 + - 1.2 + - - 0.8 + - 1.2 + - - 0.8 + - 1.2 + mode: bilinear + allow_missing_keys: false + - class_path: viscy_transforms.BatchedRandAdjustContrastd + init_args: + keys: + - channel_0 + gamma: + - 0.8 + - 1.2 + prob: 0.5 + invert_image: false + retain_stats: false + allow_missing_keys: false + - class_path: viscy_transforms.BatchedRandScaleIntensityd + init_args: + keys: + - channel_0 + factors: 0.5 + prob: 0.5 + channel_wise: false + allow_missing_keys: false + - class_path: viscy_transforms.BatchedRandGaussianSmoothd + init_args: + keys: + - channel_0 + sigma_x: + - 0.25 + - 0.75 + sigma_y: + - 0.25 + - 0.75 + sigma_z: + - 0.0 + - 0.0 + truncated: 4.0 + prob: 0.5 + border_type: constant + allow_missing_keys: false + - class_path: 
viscy_transforms.BatchedRandGaussianNoised + init_args: + keys: + - channel_0 + prob: 0.5 + mean: 0.0 + std: 0.2 + allow_missing_keys: false + sample_std: true + cache_pool_bytes: 0 + seed: 0 + include_wells: null + exclude_fovs: null + focus_channel: null + reference_pixel_size_xy_um: null + reference_pixel_size_z_um: null +optimizer: null +lr_scheduler: null +ckpt_path: null diff --git a/applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1/infectomics-annotated.yaml b/applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1/infectomics-annotated.yaml new file mode 100644 index 000000000..f2f60ac29 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1/infectomics-annotated.yaml @@ -0,0 +1,27 @@ +# Evaluation config: DINOv3-temporal-MLP-2D-BagOfChannels-v1 × infectomics-annotated (Wave 1) +# Trains LC pipelines on 14 infectomics experiments using DINOv3-temporal-MLP +# 128-dim projections (.X). Publishes pipelines to the central LC registry. 
+# +# Usage: +# nextflow run applications/dynaclr/nextflow/main.nf -entry evaluation \ +# --eval_config applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1/infectomics-annotated.yaml \ +# --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +# -resume + +base: + - ../recipes/predict.yml + - ../recipes/reduce.yml + - ../recipes/plot_infectomics.yml + - ../recipes/linear_classifiers_infectomics.yml + - ../recipes/infectomics-annotated.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels/v1/config_updated.yaml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels/v1/DINOv3-temporal-MLP-2D-BagOfChannels-v1/20260319-235942/checkpoints/epoch=71-step=14040.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels/v1/evaluations/infectomics-annotated/ + +# Use the trained 128-dim MLP projection head as adata.X. Without this, the +# EmbeddingWriter defaults to "features" and writes the frozen DINOv3 backbone +# output — making this row a duplicate of DINOv3-frozen. The MLP head carries +# all the learned task signal here; the backbone is frozen during training. +predict: + embedding_key: projections diff --git a/applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1/run_infectomics_annotated.sh b/applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1/run_infectomics_annotated.sh new file mode 100644 index 000000000..aaae57e24 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1/run_infectomics_annotated.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Wave-1 evaluation: DINOv3-temporal-MLP-2D-BagOfChannels-v1 x infectomics-annotated. +# Sibling comparison run alongside the DynaCLR families (different architecture). 
+# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1/run_infectomics_annotated.sh + +#SBATCH --job-name=eval_w1_dinov3_tmlp_v1_infectomics +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --partition=cpu +#SBATCH --time=1-00:00:00 +#SBATCH --output=/hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels/v1/evaluations/infectomics-annotated/nextflow_logs/%x-%j.out + +export PYTHONNOUSERSITE=1 + +module load nextflow/24.10.5 + +WORKSPACE="/hpc/mydata/eduardo.hirata/repos/viscy" +CONFIG="$WORKSPACE/applications/dynaclr/configs/evaluation/DINOv3-temporal-MLP-2D-BagOfChannels-v1/infectomics-annotated.yaml" +LOGDIR="/hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels/v1/evaluations/infectomics-annotated/nextflow_logs" + +mkdir -p "$LOGDIR" +cd "$LOGDIR" + +nextflow run "$WORKSPACE/applications/dynaclr/nextflow/main.nf" -entry evaluation --eval_config "$CONFIG" --workspace_dir "$WORKSPACE" -resume diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3/infectomics-annotated.yaml b/applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3/infectomics-annotated.yaml new file mode 100644 index 000000000..5a79450cd --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3/infectomics-annotated.yaml @@ -0,0 +1,20 @@ +# Evaluation config: DynaCLR-2D-BagOfChannels-v3 × infectomics-annotated (Wave 1) +# Trains LC pipelines on 14 infectomics experiments using DynaCLR-2D-BoC-v3 +# 768-dim features. Publishes pipelines to the central LC registry. 
+# +# Usage: +# nextflow run applications/dynaclr/nextflow/main.nf -entry evaluation \ +# --eval_config applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3/infectomics-annotated.yaml \ +# --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +# -resume + +base: + - ../recipes/predict.yml + - ../recipes/reduce.yml + - ../recipes/plot_infectomics.yml + - ../recipes/linear_classifiers_infectomics.yml + - ../recipes/infectomics-annotated.yml + +training_config: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/organelle_sensor_phase_maxproj_ver3_150epochs/DynaCLR-2D-BagOfChannels-v3.yml +ckpt_path: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/organelle_sensor_phase_maxproj_ver3_150epochs/saved_checkpoints/epoch=104-step=53760.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/v3/evaluations/infectomics-annotated/ diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3/run_infectomics_annotated.sh b/applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3/run_infectomics_annotated.sh new file mode 100644 index 000000000..419c91b10 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3/run_infectomics_annotated.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Wave-1 evaluation: DynaCLR-2D-BagOfChannels-v3 x infectomics-annotated. +# Sibling comparison run alongside the 2D-MIP-BagOfChannels family. 
+# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3/run_infectomics_annotated.sh + +#SBATCH --job-name=eval_w1_2dboc_v3_infectomics +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --partition=cpu +#SBATCH --time=1-00:00:00 +#SBATCH --output=/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/v3/evaluations/infectomics-annotated/nextflow_logs/%x-%j.out + +export PYTHONNOUSERSITE=1 + +module load nextflow/24.10.5 + +WORKSPACE="/hpc/mydata/eduardo.hirata/repos/viscy" +CONFIG="$WORKSPACE/applications/dynaclr/configs/evaluation/DynaCLR-2D-BagOfChannels-v3/infectomics-annotated.yaml" +LOGDIR="/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/v3/evaluations/infectomics-annotated/nextflow_logs" + +mkdir -p "$LOGDIR" +cd "$LOGDIR" + +nextflow run "$WORKSPACE/applications/dynaclr/nextflow/main.nf" -entry evaluation --eval_config "$CONFIG" --workspace_dir "$WORKSPACE" -resume diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler/infectomics-annotated.yaml b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler/infectomics-annotated.yaml new file mode 100644 index 000000000..8faf4ec9c --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler/infectomics-annotated.yaml @@ -0,0 +1,21 @@ +# Evaluation config: DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler × infectomics-annotated (Wave 1) +# Sibling row to DynaCLR-2D-MIP-BagOfChannels for direct comparison. +# mixed-markers (vs single-marker) variant + shuffler fix. +# Run-id: dlzt3s65 (most recent under this training dir). 
+# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler/run_infectomics_annotated.sh + +base: + - ../recipes/predict.yml + - ../recipes/reduce.yml + - ../recipes/plot_infectomics.yml + - ../recipes/linear_classifiers_infectomics.yml + - ../recipes/infectomics-annotated.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-mixed-markers-fix-shuffler/config.yaml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-mixed-markers-fix-shuffler/DynaCLR-2D-MIP-BagOfChannels/dlzt3s65/checkpoints/last.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-mixed-markers-fix-shuffler/evaluations/infectomics-annotated/ + +linear_classifiers: + publish_dir: /hpc/projects/organelle_phenotyping/models/linear_classifiers/DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler/infectomics/ diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler/run_infectomics_annotated.sh b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler/run_infectomics_annotated.sh new file mode 100644 index 000000000..f93cf24f4 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler/run_infectomics_annotated.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Wave-1: DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler × infectomics-annotated. +# Sibling comparison run. 
+# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler/run_infectomics_annotated.sh + +#SBATCH --job-name=eval_w1_2dmip_mmfs_infectomics +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --partition=cpu +#SBATCH --time=1-00:00:00 +#SBATCH --output=/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-mixed-markers-fix-shuffler/evaluations/infectomics-annotated/nextflow_logs/%x-%j.out + +export PYTHONNOUSERSITE=1 + +module load nextflow/24.10.5 + +WORKSPACE="/hpc/mydata/eduardo.hirata/repos/viscy" +CONFIG="$WORKSPACE/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler/infectomics-annotated.yaml" +LOGDIR="/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-mixed-markers-fix-shuffler/evaluations/infectomics-annotated/nextflow_logs" + +mkdir -p "$LOGDIR" +cd "$LOGDIR" + +nextflow run "$WORKSPACE/applications/dynaclr/nextflow/main.nf" -entry evaluation \ + --eval_config "$CONFIG" \ + --workspace_dir "$WORKSPACE" \ + -resume diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-192/infectomics-annotated.yaml b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-192/infectomics-annotated.yaml new file mode 100644 index 000000000..c3f9f4327 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-192/infectomics-annotated.yaml @@ -0,0 +1,25 @@ +# Evaluation: DynaCLR-2D-MIP-BagOfChannels (single-marker, 384->192, zext16) x infectomics-annotated. +# Variant: single-marker batching at the larger 384->216->192 patch with zext16. +# Checkpoint: p6vlebcu/last.ckpt (2026-04-28). 
+# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-192/run_infectomics_annotated.sh + +base: + - ../recipes/predict.yml + - ../recipes/reduce.yml + - ../recipes/plot_infectomics.yml + - ../recipes/linear_classifiers_infectomics.yml + - ../recipes/infectomics-annotated.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-384to192-zext16-single-marker-fix-shuffler/config.yaml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-384to192-zext16-single-marker-fix-shuffler/DynaCLR-2D-MIP-BagOfChannels/p6vlebcu/checkpoints/last.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-384to192-zext16-single-marker-fix-shuffler/evaluations/infectomics-annotated/ + +# Larger 384->192 patches need a smaller predict batch to fit on gpu_2d. +# The default 400 OOMs (exit 137) on this checkpoint; 128 is safe. +predict: + batch_size: 128 + +linear_classifiers: + publish_dir: /hpc/projects/organelle_phenotyping/models/linear_classifiers/DynaCLR-2D-MIP-BagOfChannels-single-marker-192/infectomics/ diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-192/run_infectomics_annotated.sh b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-192/run_infectomics_annotated.sh new file mode 100644 index 000000000..2972e4750 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-192/run_infectomics_annotated.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Wave-1 evaluation: DynaCLR-2D-MIP-BagOfChannels (single-marker, 384->192, zext16) x infectomics-annotated. 
+# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-192/run_infectomics_annotated.sh + +#SBATCH --job-name=eval_w1_2dmip_sm192_infectomics +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --partition=cpu +#SBATCH --time=1-00:00:00 +#SBATCH --output=%x-%j.out + +export PYTHONNOUSERSITE=1 + +module load nextflow/24.10.5 + +WORKSPACE="/hpc/mydata/eduardo.hirata/repos/viscy" +CONFIG="$WORKSPACE/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-192/infectomics-annotated.yaml" +LOGDIR="/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-384to192-zext16-single-marker-fix-shuffler/evaluations/infectomics-annotated/nextflow_logs" + +mkdir -p "$LOGDIR" +cd "$LOGDIR" + +nextflow run "$WORKSPACE/applications/dynaclr/nextflow/main.nf" -entry evaluation \ + --eval_config "$CONFIG" \ + --workspace_dir "$WORKSPACE" \ + -resume diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler/infectomics-annotated.yaml b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler/infectomics-annotated.yaml new file mode 100644 index 000000000..c4b49021c --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler/infectomics-annotated.yaml @@ -0,0 +1,21 @@ +# Evaluation config: DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler × infectomics-annotated (Wave 1) +# Sibling row to DynaCLR-2D-MIP-BagOfChannels for direct comparison. +# Same training-config family, different run (single-marker variant + shuffler fix). 
+# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler/run_infectomics_annotated.sh + +base: + - ../recipes/predict.yml + - ../recipes/reduce.yml + - ../recipes/plot_infectomics.yml + - ../recipes/linear_classifiers_infectomics.yml + - ../recipes/infectomics-annotated.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker-fix-shuffler/config.yaml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker-fix-shuffler/DynaCLR-2D-MIP-BagOfChannels/jbrwhzr3/checkpoints/last.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker-fix-shuffler/evaluations/infectomics-annotated/ + +# Publish to its own sub-registry: separate model, separate task domain. +linear_classifiers: + publish_dir: /hpc/projects/organelle_phenotyping/models/linear_classifiers/DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler/infectomics/ diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler/run_infectomics_annotated.sh b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler/run_infectomics_annotated.sh new file mode 100644 index 000000000..021814dae --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler/run_infectomics_annotated.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# Wave-1 evaluation: DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler x infectomics-annotated. +# Sibling comparison run to DynaCLR-2D-MIP-BagOfChannels — same training-config family, +# different ckpt (single-marker + shuffler fix variant). 
+# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler/run_infectomics_annotated.sh + +#SBATCH --job-name=eval_w1_2dmip_smfs_infectomics +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --partition=cpu +#SBATCH --time=1-00:00:00 +#SBATCH --output=/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker-fix-shuffler/evaluations/infectomics-annotated/nextflow_logs/%x-%j.out + +export PYTHONNOUSERSITE=1 + +module load nextflow/24.10.5 + +WORKSPACE="/hpc/mydata/eduardo.hirata/repos/viscy" +CONFIG="$WORKSPACE/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler/infectomics-annotated.yaml" +LOGDIR="/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker-fix-shuffler/evaluations/infectomics-annotated/nextflow_logs" + +mkdir -p "$LOGDIR" +cd "$LOGDIR" + +nextflow run "$WORKSPACE/applications/dynaclr/nextflow/main.nf" -entry evaluation \ + --eval_config "$CONFIG" \ + --workspace_dir "$WORKSPACE" \ + -resume diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/alfi.yaml b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/alfi.yaml new file mode 100644 index 000000000..65b89673b --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/alfi.yaml @@ -0,0 +1,34 @@ +# Evaluation config: DynaCLR-2D-MIP-BagOfChannels × alfi (Wave 2) +# Applies LC pipelines from the central registry (trained on infectomics-annotated). +# Data: HeLa (MI06), RPE1 (MI07/MI08), U2OS (MI01-MI05), DIC channel. +# Annotations are appended for plot coloring; no LC training happens here. 
+# +# Usage: +# nextflow run applications/dynaclr/nextflow/main.nf -entry evaluation \ +# --eval_config applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/alfi.yaml \ +# --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +# -resume + +base: + - ../recipes/predict.yml + - ../recipes/reduce.yml + - ../recipes/alfi.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels.yml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/alfi/ + +predict: + batch_size: 256 + num_workers: 4 + precision: 32-true + devices: 1 + +append_predictions: + # Once an alfi-annotated Wave-1 leaf trains DIC LCs, switch this to + # `.../alfi/latest`. Today's run uses the infectomics sub-registry — + # marker mismatch (DIC vs G3BP1/SEC61B/Phase3D/viral_sensor) means + # predicted_* columns will be NaN for every cell. The reader logs + # "0/N markers covered" and continues. This is a smoke test for the + # reader path, not a real prediction run. 
+ pipelines_dir: /hpc/projects/organelle_phenotyping/models/linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/infectomics/latest diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/infectomics-annotated.yaml b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/infectomics-annotated.yaml new file mode 100644 index 000000000..77b11f559 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/infectomics-annotated.yaml @@ -0,0 +1,27 @@ +# Evaluation config: DynaCLR-2D-MIP-BagOfChannels × infectomics-annotated (Wave 1) +# Trains LC pipelines on 14 infectomics experiments (ZIKV + DENV; G3BP1, SEC61B, +# Phase3D, viral_sensor markers). Publishes pipelines to the central LC registry. +# +# Usage: +# nextflow run applications/dynaclr/nextflow/main.nf -entry evaluation \ +# --eval_config applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/infectomics-annotated.yaml \ +# --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +# -resume + +base: + - ../recipes/predict.yml + - ../recipes/reduce.yml + - ../recipes/plot_infectomics.yml + - ../recipes/linear_classifiers_infectomics.yml + - ../recipes/infectomics-annotated.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels.yml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/infectomics-annotated/ + +# Publish trained LCs to the per-model + per-task-domain sub-registry. +# Atomically promotes to {publish_dir}/vN/ and updates `latest` symlink. +# Sub-registries (e.g. 
infectomics/, alfi/) hold pipelines that share an +# annotation domain; Wave-2 leaves choose which sub-registry to fetch. +linear_classifiers: + publish_dir: /hpc/projects/organelle_phenotyping/models/linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/infectomics/ diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/microglia.yaml b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/microglia.yaml new file mode 100644 index 000000000..dd4f053b0 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/microglia.yaml @@ -0,0 +1,34 @@ +# Evaluation config: DynaCLR-2D-MIP-BagOfChannels × microglia (Wave 2) +# Applies LC pipelines from the central registry (trained on infectomics-annotated). +# Data: 20191107_1209_1_GW23_dynamorph — Brightfield, Phase3D, Retardance. +# Perturbations: untreated, IL17, IFN-beta, Rubella, Glioblastoma_supernatant. +# Microglia has no annotations, so append_annotations is omitted; markers absent +# from the registry manifest (Brightfield, Retardance) are skipped and logged as +# "missing" in the coverage report.
+# +# Usage: +# nextflow run applications/dynaclr/nextflow/main.nf -entry evaluation \ +# --eval_config applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/microglia.yaml \ +# --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +# -resume + +base: + - ../recipes/predict.yml + - ../recipes/reduce.yml + - ../recipes/microglia.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels.yml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/microglia/ + +predict: + batch_size: 256 + num_workers: 4 + precision: 32-true + devices: 1 + +append_predictions: + # Microglia has Brightfield/Retardance markers that aren't in the + # infectomics manifest, so the coverage report will log "missing" for + # those. Phase3D may overlap. Switching to a microglia-trained + # sub-registry (e.g. `.../microglia/latest`) is a future option. + pipelines_dir: /hpc/projects/organelle_phenotyping/models/linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/infectomics/latest diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/run_infectomics_annotated.sh b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/run_infectomics_annotated.sh new file mode 100644 index 000000000..7f11b521c --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/run_infectomics_annotated.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Wave-1 evaluation: DynaCLR-2D-MIP-BagOfChannels x infectomics-annotated. 
# Trains linear classifiers on the 14 ZIKV+DENV infectomics experiments and +# publishes them to the central LC registry at +# /hpc/projects/organelle_phenotyping/models/linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/infectomics/vN/ +# with a `latest` symlink updated atomically at the end of the LC step. +# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/run_infectomics_annotated.sh + +#SBATCH --job-name=eval_w1_2dmip_infectomics +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --partition=cpu +#SBATCH --time=1-00:00:00 +#SBATCH --output=%x-%j.out +# Wrapper hosts the Nextflow head + any local-executor processes (per-experiment +# PLOT runs locally; PLOT_COMBINED + REDUCE_COMBINED + LC + PREDICT go to slurm). +# 32G + 4 cpus is enough for 19 sequential per-experiment plots on 350k cells. + +export PYTHONNOUSERSITE=1 + +module load nextflow/24.10.5 + +WORKSPACE="/hpc/mydata/eduardo.hirata/repos/viscy" +CONFIG="$WORKSPACE/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/infectomics-annotated.yaml" +LOGDIR="/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/infectomics-annotated/nextflow_logs" + +mkdir -p "$LOGDIR" +cd "$LOGDIR" + +nextflow run "$WORKSPACE/applications/dynaclr/nextflow/main.nf" -entry evaluation \ + --eval_config "$CONFIG" \ + --workspace_dir "$WORKSPACE" \ + -resume diff --git a/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/run_microglia.sh b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/run_microglia.sh new file mode 100644 index 000000000..a8cc4d63c --- /dev/null +++ b/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/run_microglia.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Wave-2 evaluation: DynaCLR-2D-MIP-BagOfChannels x microglia.
+# Applies LC pipelines from the central registry's infectomics sub-registry: +# /hpc/projects/.../linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/infectomics/latest +# No annotations to append. Marker mismatch (microglia has Brightfield, +# Phase3D, Retardance vs registry trained on G3BP1/SEC61B/Phase3D/viral_sensor) +# means only Phase3D cells get predictions. Coverage report logged. +# +# Submit: +# sbatch applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/run_microglia.sh + +#SBATCH --job-name=eval_w2_2dmip_microglia +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=32G +#SBATCH --partition=cpu +#SBATCH --time=1-00:00:00 +#SBATCH --output=%x-%j.out + +export PYTHONNOUSERSITE=1 + +module load nextflow/24.10.5 + +WORKSPACE="/hpc/mydata/eduardo.hirata/repos/viscy" +CONFIG="$WORKSPACE/applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/microglia.yaml" +LOGDIR="/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/microglia/nextflow_logs" + +mkdir -p "$LOGDIR" +cd "$LOGDIR" + +nextflow run "$WORKSPACE/applications/dynaclr/nextflow/main.nf" -entry evaluation \ + --eval_config "$CONFIG" \ + --workspace_dir "$WORKSPACE" \ + -resume diff --git a/applications/dynaclr/configs/evaluation/eval_registry.yaml b/applications/dynaclr/configs/evaluation/eval_registry.yaml new file mode 100644 index 000000000..212566a4f --- /dev/null +++ b/applications/dynaclr/configs/evaluation/eval_registry.yaml @@ -0,0 +1,58 @@ +# Eval registry — input to `compare_evals.py` for cross-model comparison. +# +# Each entry points to a model's eval `output_dir` (where smoothness/, +# linear_classifiers/, mmd/ artifacts land). One entry per (model x dataset) +# pair you want overlaid in the comparison plots / CSVs. +# +# Mental model: +# - Each leaf YAML at configs/evaluation/{model}/{dataset}.yaml declares an +# `output_dir`. 
After `nextflow run -entry evaluation` finishes, that +# directory is the `eval_dir` listed here. +# - To compare 3 models on infectomics-annotated, list 3 entries — all with +# the same dataset suffix in their `eval_dir`. +# - To compare 1 model across 3 datasets, list 3 entries — different dataset +# suffixes. +# +# Usage: +# python applications/dynaclr/scripts/evaluation/compare_evals.py \ +# -c applications/dynaclr/configs/evaluation/eval_registry.yaml +# +# Status: +# ✅ = eval has run, artifacts exist +# ⏳ = eval queued / running +# ⬜ = leaf YAML exists but no run yet — uncomment after first successful run + +output_dir: /hpc/projects/organelle_phenotyping/comparisons/infectomics-annotated/ +fdr_threshold: 0.05 + +models: + # --- DynaCLR-2D-MIP-BagOfChannels family (siblings: same recipe, different ckpt) --- + # Baseline: pre fix-shuffler mixed-markers, 192->160, zext11. + - name: DynaCLR-2D-MIP-BagOfChannels + eval_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/infectomics-annotated/ + + # Single-marker batching (batch_group_by=marker), 192->160, zext11. + - name: DynaCLR-2D-MIP-BagOfChannels-single-marker-fix-shuffler + eval_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker-fix-shuffler/evaluations/infectomics-annotated/ + + # Mixed-markers post fix-shuffler, 192->160, zext11. + - name: DynaCLR-2D-MIP-BagOfChannels-mixed-markers-fix-shuffler + eval_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-mixed-markers-fix-shuffler/evaluations/infectomics-annotated/ + + # Single-marker, larger patch (384->216->192) with zext16 — tests whether more + # subcellular detail at the cost of fewer samples per batch helps. 
+ - name: DynaCLR-2D-MIP-BagOfChannels-single-marker-192 + eval_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-384to192-zext16-single-marker-fix-shuffler/evaluations/infectomics-annotated/ + + # --- Other architectures --- + - name: DynaCLR-2D-BagOfChannels-v3 + eval_dir: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/v3/evaluations/infectomics-annotated/ + + - name: DINOv3-temporal-MLP-2D-BagOfChannels-v1 + eval_dir: /hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels/v1/evaluations/infectomics-annotated/ + + # DINOv3-frozen baseline: raw HuggingFace DINOv3 convnext-tiny weights, no + # fine-tuning, no projection. Tests whether contrastive training adds value + # over pre-trained features alone. + - name: DINOv3-frozen + eval_dir: /hpc/projects/organelle_phenotyping/models/DINOv3-frozen/evaluations/infectomics-annotated/ diff --git a/applications/dynaclr/configs/evaluation/eval_registry_infectomics.yaml b/applications/dynaclr/configs/evaluation/eval_registry_infectomics.yaml new file mode 100644 index 000000000..e467f501f --- /dev/null +++ b/applications/dynaclr/configs/evaluation/eval_registry_infectomics.yaml @@ -0,0 +1,57 @@ +# Registry for compare_evals.py — infectomics-annotated (Wave-1) cross-model comparison. +# +# Lists every model that has produced +# {model_root}/evaluations/infectomics-annotated/linear_classifiers/metrics_summary.csv +# under /hpc/projects/organelle_phenotyping/models/. Each `name` is the legend +# label in the comparison plots; the model→color palette is built from the +# `name` strings, so keep them stable across registries when comparing the +# same model on different datasets. 
+# +# Status legend in comments: +# ✅ landed — eval ran, artifacts exist +# 🔄 running / queued +# ⬜ pending — leaf YAML exists but no run yet (uncomment after first run) +# +# Usage: +# uv run python applications/dynaclr/scripts/evaluation/compare_evals.py \ +# -c applications/dynaclr/configs/evaluation/eval_registry_infectomics.yaml + +models: + # --- Baselines ----------------------------------------------------------- + # ✅ DINOv3-frozen — raw HF convnext-tiny, no fine-tuning, no projection + - name: DINOv3-frozen + eval_dir: /hpc/projects/organelle_phenotyping/models/DINOv3-frozen/evaluations/infectomics-annotated + + # ✅ DINOv3-temporal-MLP-v1 — fairly-matched SSL baseline (HF DINOv3 + MLP head) + - name: DINOv3-temporal-MLP-v1 + eval_dir: /hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels/v1/evaluations/infectomics-annotated + + # ⬜ DynaCLR-classical — historical 150-epoch ckpt (older dataset). + # Leaf scaffolded 2026-05-05; submit once Wave-1 head runs. 
+ # - name: DynaCLR-classical (historical) + # eval_dir: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/classical/dynaclr_gfp_rfp_ph_2D/organelle_sensor_phase_maxproj_ver1_150epochs/evaluations/infectomics-annotated + + # --- DynaCLR-2D-MIP-BagOfChannels family (BoC variant bake-off) ---------- + # ✅ Baseline: pre fix-shuffler mixed-markers, 192→160, zext11 + - name: DynaCLR-2D-BoC (192to160, zext11) + eval_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/infectomics-annotated + + # ✅ Single-marker batching (batch_group_by=marker), 192→160, zext11 + - name: DynaCLR-2D-BoC (192to160, single-marker-fix) + eval_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker-fix-shuffler/evaluations/infectomics-annotated + + # ✅ Mixed-markers post fix-shuffler, 192→160, zext11 + - name: DynaCLR-2D-BoC (192to160, mixed-markers-fix) + eval_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-mixed-markers-fix-shuffler/evaluations/infectomics-annotated + + # ✅ Single-marker, larger patch (384→216→192) with zext16 + - name: DynaCLR-2D-BoC (384to192, zext16, single-marker-fix) + eval_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-384to192-zext16-single-marker-fix-shuffler/evaluations/infectomics-annotated + + # --- DynaCLR-2D-BagOfChannels-v3 (SEC61/TOMM20/G3BP1/Sensor, time-interval) --- + # ✅ Legacy ContrastiveModule, wrapped via dynaclr.engine for inference + - name: DynaCLR-2D-BagOfChannels-v3 + eval_dir: /hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_Ph/v3/evaluations/infectomics-annotated + +output_dir: /hpc/projects/organelle_phenotyping/comparisons/infectomics-annotated +fdr_threshold: 0.05 diff --git 
a/applications/dynaclr/configs/evaluation/export_onnx_2d_mip_boc.yml b/applications/dynaclr/configs/evaluation/export_onnx_2d_mip_boc.yml new file mode 100644 index 000000000..52d70fc76 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/export_onnx_2d_mip_boc.yml @@ -0,0 +1,24 @@ +model: + class_path: dynaclr.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy_models.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 1 + in_stack_depth: 1 + stem_kernel_size: [1, 4, 4] + stem_stride: [1, 4, 4] + embedding_dim: 768 + projection_dim: 32 + drop_path_rate: 0.1 + loss_function: + class_path: viscy_models.contrastive.loss.NTXentLoss + init_args: + temperature: 0.2 + lr: 0.00002 + example_input_array_shape: [1, 1, 1, 160, 160] + +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt + +export_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/onnx/last.onnx diff --git a/applications/dynaclr/configs/evaluation/predict_microglia_alfi.sh b/applications/dynaclr/configs/evaluation/predict_microglia_alfi.sh new file mode 100644 index 000000000..14736dac4 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/predict_microglia_alfi.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Predict embeddings for microglia and ALFI datasets +# Uses DynaCLR-2D-MIP-BagOfChannels checkpoint. 
+# +# Usage: +# sbatch applications/dynaclr/configs/evaluation/predict_microglia_alfi.sh + +#SBATCH --job-name=dynaclr_predict_microglia_alfi +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=8 +#SBATCH --mem-per-cpu=16G +#SBATCH --time=3:00:00 + +export PYTHONNOUSERSITE=1 +WORKSPACE_DIR="/hpc/mydata/eduardo.hirata/repos/viscy" + +# echo "=== Microglia predict ===" +# srun uv run --project "$WORKSPACE_DIR" viscy predict --config /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/microglia/configs/predict.yml + +echo "=== ALFI predict ===" +srun uv run --project "$WORKSPACE_DIR" viscy predict --config /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/evaluations/alfi/configs/predict.yml diff --git a/applications/dynaclr/configs/evaluation/recipes/alfi.yml b/applications/dynaclr/configs/evaluation/recipes/alfi.yml new file mode 100644 index 000000000..48c7147d4 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/alfi.yml @@ -0,0 +1,51 @@ +# Dataset recipe: alfi (mitosis) +# ============================================================================= +# Wave-2 column of the evaluation matrix. Applies LC pipelines published by the +# same model's infectomics-annotated run (central registry). No LC training, +# no MMD. Leaves override training_config + ckpt_path + output_dir + +# append_predictions.pipelines_dir (per-model registry path). +# +# Cell index covers HeLa (MI06), RPE1 (MI07/MI08), U2OS (MI01-MI05) — DIC channel. +# Annotations: ALFI_combined_annotations.csv (cell_division_state, cell_death_state). 
+# +# Leaves merge this via: +# base: +# - recipes/predict.yml +# - recipes/reduce.yml +# - recipes/alfi.yml + +cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/alfi-eval.parquet + +steps: + - predict + - split + - reduce_dimensionality + - reduce_combined + - smoothness + - append_annotations + - append_predictions + - plot + +append_annotations: + annotations: + - experiment: "ALFI_HeLa_DMSO_MLN8237" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/ALFI/ALFI_combined_annotations.csv + - experiment: "ALFI_RPE1_untreated" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/ALFI/ALFI_combined_annotations.csv + - experiment: "ALFI_U2OS_DMSO_MLN8237" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/ALFI/ALFI_combined_annotations.csv + +plot: + embedding_keys: + - X_pca + combined_embedding_keys: + - X_pca_combined + - X_phate_combined + color_by: + - cell_division_state + - cell_death_state + - experiment + - marker + point_size: 1.0 + components: [0, 1] + format: pdf diff --git a/applications/dynaclr/configs/evaluation/recipes/infectomics-annotated.yml b/applications/dynaclr/configs/evaluation/recipes/infectomics-annotated.yml new file mode 100644 index 000000000..dc0288a2d --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/infectomics-annotated.yml @@ -0,0 +1,34 @@ +# Dataset recipe: infectomics-annotated +# ============================================================================= +# Wave-1 column of the evaluation matrix. Used by all model leaves to train +# linear classifiers on annotated infectomics data and publish them to the +# central LC registry. Leaf configs override training_config + ckpt_path + +# output_dir + linear_classifiers.publish_dir (the registry path is per-model). +# +# Composition: +# - cell_index_path — shared infectomics parquet (14 experiments) +# - steps — full pipeline incl. 
LC + append_annotations + append_predictions +# - linear_classifiers — annotations + tasks (inherited from linear_classifiers_infectomics.yml) +# - plot_combined — combined-only scatter (skips per-exp fan-out); +# inherits combined_embedding_keys from plot_infectomics.yml +# +# Leaves merge this via: +# base: +# - recipes/predict.yml +# - recipes/reduce.yml +# - recipes/plot_infectomics.yml +# - recipes/linear_classifiers_infectomics.yml +# - recipes/infectomics-annotated.yml + +cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-BoC-lc-evaluation-v1.parquet + +steps: + - predict + - split + - reduce_dimensionality + - reduce_combined + - smoothness + - linear_classifiers + - append_annotations + - append_predictions + - plot_combined diff --git a/applications/dynaclr/configs/evaluation/recipes/linear_classifiers_infectomics.yml b/applications/dynaclr/configs/evaluation/recipes/linear_classifiers_infectomics.yml new file mode 100644 index 000000000..a4e15e44c --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/linear_classifiers_infectomics.yml @@ -0,0 +1,59 @@ +# Linear classifier settings for the infectomics benchmark. +# Covers ZIKV + DENV datasets across G3BP1, SEC61B, Phase3D, viral_sensor markers. +# Every experiment here needs an annotation CSV — when an experiment is listed +# without a matching CSV (or the CSV's tracks don't overlap the zarr obs), the +# LC step writes nothing and downstream scripts (Stage 3d label-timing) quietly +# get `predicted_* = NaN`. Add every zarr that needs predictions; missing the +# sensor-channel zarrs is how we lost ZIKV pool coverage in v1. 
+linear_classifiers: + annotations: + - experiment: "2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv + - experiment: "2025_01_28_A549_viral_sensor_ZIKV_DENV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv + - experiment: "2025_01_28_A549_Phase3D_ZIKV_DENV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv + - experiment: "2025_07_24_A549_G3BP1_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_24_A549_SEC61_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_24_A549_viral_sensor_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_24_A549_Phase3D_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2024_11_07_A549_SEC61_DENV_SEC61B" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61B_DENV_combined_annotations.csv + - experiment: "2024_11_07_A549_SEC61_DENV_Phase3D" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61B_DENV_combined_annotations.csv + - experiment: "2024_11_07_A549_SEC61_DENV_viral_sensor" + path: 
/hpc/projects/organelle_phenotyping/datasets/annotations/2024_11_07_A549_SEC61_DENV/2024_11_07_A549_SEC61B_DENV_combined_annotations.csv + - experiment: "2025_01_24_A549_G3BP1_DENV_G3BP1" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV_combined_annotations.csv + - experiment: "2025_01_24_A549_G3BP1_DENV_Phase3D" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV_combined_annotations.csv + - experiment: "2025_01_24_A549_G3BP1_DENV_viral_sensor" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_24_A549_G3BP1_DENV/2025_01_24_A549_G3BP1_DENV_combined_annotations.csv + - experiment: "2025_07_22_A549_G3BP1_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_22_A549_viral_sensor_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_22_A549_Phase3D_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_08_26_A549_SEC61_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV_combined_annotations.csv + - experiment: "2025_08_26_A549_viral_sensor_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV_combined_annotations.csv + - experiment: "2025_08_26_A549_Phase3D_ZIKV" + path: 
/hpc/projects/organelle_phenotyping/datasets/annotations/2025_08_26_A549_SEC61_TOMM20_ZIKV/2025_08_26_A549_SEC61_TOMM20_ZIKV_combined_annotations.csv + tasks: + - task: infection_state + - task: cell_division_state + - task: organelle_state + marker_filters: + - G3BP1 + - SEC61B + - task: cell_death_state + use_scaling: true + use_pca: false + split_train_data: 0.8 + random_seed: 42 diff --git a/applications/dynaclr/configs/evaluation/recipes/microglia.yml b/applications/dynaclr/configs/evaluation/recipes/microglia.yml new file mode 100644 index 000000000..3de411e87 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/microglia.yml @@ -0,0 +1,42 @@ +# Dataset recipe: microglia (dynamorph) +# ============================================================================= +# Wave-2 column of the evaluation matrix. Applies LC pipelines published by the +# same model's infectomics-annotated run (central registry). Microglia has no +# annotation CSVs, so append_annotations is omitted; append_predictions still +# runs and silently skips cells whose markers (Brightfield, Retardance) are +# absent from the registry manifest. +# +# Data: 20191107_1209_1_GW23_dynamorph (Brightfield, Phase3D, Retardance). +# Perturbations: untreated, IL17, IFN-beta, Rubella, Glioblastoma_supernatant. 
+# +# Leaves merge this via: +# base: +# - recipes/predict.yml +# - recipes/reduce.yml +# - recipes/microglia.yml + +cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/microglia-eval.parquet + +steps: + - predict + - split + - reduce_dimensionality + - reduce_combined + - smoothness + - append_predictions + - plot + +plot: + embedding_keys: + - X_pca + combined_embedding_keys: + - X_pca_combined + - X_phate_combined + color_by: + - perturbation + - hours_post_perturbation + - experiment + - marker + point_size: 1.0 + components: [0, 1] + format: pdf diff --git a/applications/dynaclr/configs/evaluation/recipes/mmd_defaults.yml b/applications/dynaclr/configs/evaluation/recipes/mmd_defaults.yml new file mode 100644 index 000000000..8370035b4 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/mmd_defaults.yml @@ -0,0 +1,29 @@ +# Default MMD algorithm settings shared across all MMD eval configs. +# Use as a base: reference in per-experiment or pooled MMD configs to avoid +# repeating these parameters. Override any field in the leaf config. +# +# Usage: +# base: recipes/mmd_defaults.yml +# input_path: /path/to/embeddings.zarr +# output_dir: /path/to/output +# comparisons: +# - cond_a: uninfected +# cond_b: ZIKV +# label: "uninfected vs ZIKV" + +group_by: perturbation +save_plots: true + +mmd: + n_permutations: 1000 + max_cells: 2000 + min_cells: 20 + seed: 42 + balance_samples: false + share_bandwidth_from: null + +map_settings: + enabled: false + distance: cosine + null_size: 10000 + seed: 0 diff --git a/applications/dynaclr/configs/evaluation/recipes/plot_infectomics.yml b/applications/dynaclr/configs/evaluation/recipes/plot_infectomics.yml new file mode 100644 index 000000000..7b1587ffa --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/plot_infectomics.yml @@ -0,0 +1,16 @@ +# Default plot settings for infectomics DynaCLR evaluation. 
+plot: + embedding_keys: + - X_pca + combined_embedding_keys: + - X_pca_combined + - X_phate_combined + color_by: + - perturbation + - hours_post_perturbation + - experiment + - marker + point_size: 1.0 + components: [0, 1] + pairplot_components: 4 # render PC1..PC4 grid (4x4 = 16 panels per coloring); bump for paper figures + format: pdf diff --git a/applications/dynaclr/configs/evaluation/recipes/predict.yml b/applications/dynaclr/configs/evaluation/recipes/predict.yml new file mode 100644 index 000000000..1dcc4951e --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/predict.yml @@ -0,0 +1,6 @@ +# Default predict step settings for DynaCLR evaluation. +predict: + batch_size: 400 + num_workers: 4 + precision: 32-true + devices: 1 diff --git a/applications/dynaclr/configs/evaluation/recipes/reduce.yml b/applications/dynaclr/configs/evaluation/recipes/reduce.yml new file mode 100644 index 000000000..22eaaa3fc --- /dev/null +++ b/applications/dynaclr/configs/evaluation/recipes/reduce.yml @@ -0,0 +1,29 @@ +# Default dimensionality reduction settings for DynaCLR evaluation. +# PHATE runs only in reduce_combined; per-experiment reduce_dimensionality uses PCA only. +# Override n_jobs for reduce_combined.phate in the leaf config if needed. +reduce_dimensionality: + overwrite_keys: true + pca: + n_components: 32 + normalize_features: true + +reduce_combined: + overwrite_keys: true + pca: + n_components: 32 + normalize_features: true + + phate: + n_components: 2 + knn: 5 + decay: 40 + knn_dist: cosine + scale_embeddings: false + random_state: 42 + n_pca: null # skip PHATE's internal PCA — fit on X_pca_combined + subsample: 100 # per-store lineage cap — fast iteration; bump for paper figures (e.g. 5000) + # -1 = sklearn convention "use all CPUs", resolved SLURM-aware via + # viscy_utils.mp_utils.available_cpus (reads SLURM_CPUS_PER_TASK). 
+ # KNN search dominates wall time; BLAS env vars in reduce_combined.nf + # are pinned to 1 to avoid n_jobs * BLAS oversubscription. + n_jobs: -1 diff --git a/applications/dynaclr/configs/evaluation/test_evaluation.yaml b/applications/dynaclr/configs/evaluation/test_evaluation.yaml new file mode 100644 index 000000000..02646e1e0 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/test_evaluation.yaml @@ -0,0 +1,77 @@ +# Minimal test config for MMD + linear classifier evaluation. +# Collection: DynaCLR-BoC-lc-evaluation-v1-test (7 experiments, 3 markers x 2 dates) +# 2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1: B/4 (uninfected), C/4 (infected) — G3BP1 +# 2025_07_22_A549_G3BP1_ZIKV: C/2 (infected) — G3BP1 +# 2025_07_22_A549_Phase3D_ZIKV: C/2 (infected) — Phase3D +# 2025_07_22_A549_viral_sensor_ZIKV: C/2 (infected) — viral_sensor +# 2025_07_24_A549_G3BP1_ZIKV: C/1 (uninfected), C/2 (infected) — G3BP1 +# 2025_07_24_A549_Phase3D_ZIKV: C/1 (uninfected), C/2 (infected) — Phase3D +# 2025_07_24_A549_viral_sensor_ZIKV: C/1 (uninfected), C/2 (infected) — viral_sensor +# +# nextflow run applications/dynaclr/nextflow/main.nf \ +# --eval_config applications/dynaclr/configs/evaluation/test_evaluation.yaml \ +# --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +# -resume -profile local + +base: + - recipes/predict.yml + - recipes/reduce.yml + +training_config: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels.yml +ckpt_path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt +cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-BoC-lc-evaluation-v1-test.parquet +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluation_test_lc_2 + 
+steps: + - predict + - split + - reduce_dimensionality + - reduce_combined + - linear_classifiers + - smoothness + +# Override n_jobs for smaller test run +reduce_combined: + phate: + n_jobs: 12 + +mmd: + - name: perturbation + group_by: perturbation + comparisons: + - cond_a: uninfected + cond_b: infected + label: "uninfected vs infected" + temporal_bin_size: 4.0 + combined_temporal_bin_size: null + combined_mode: true + +linear_classifiers: + annotations: + - experiment: "2025_01_28_A549_G3BP1_ZIKV_DENV_G3BP1" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv + - experiment: "2025_01_28_A549_Phase3D_ZIKV_DENV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_01_28_A549_G3BP1_ZIKV_DENV/2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv + - experiment: "2025_07_24_A549_G3BP1_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_24_A549_Phase3D_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_24_A549_viral_sensor_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_22_A549_G3BP1_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: "2025_07_22_A549_Phase3D_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + - experiment: 
"2025_07_22_A549_viral_sensor_ZIKV" + path: /hpc/projects/organelle_phenotyping/datasets/annotations/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv + tasks: + - task: infection_state + - task: cell_division_state + - task: organelle_state + marker_filters: + - G3BP1 + - task: cell_death_state + use_scaling: true + use_pca: false + split_train_data: 0.8 + random_seed: 42 diff --git a/applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc.yaml b/applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc.yaml new file mode 100644 index 000000000..7c68590c7 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc.yaml @@ -0,0 +1,20 @@ +models: + - path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/onnx/last.onnx + label: DynaCLR-2D-MIP + pixel_size_um: 0.149 # training pixel size (ALFI dragonfly) + - path: null + label: baseline-iou + +datasets: + - path: /hpc/reference/group.royer/CTC/training/DIC-C2DH-HeLa + sequences: ["01", "02"] + pixel_size_um: 0.190 # DIC-C2DH-HeLa from TIFF XResolution metadata + +ctc_metadata_path: /hpc/reference/group.royer/CTC/metadata.yaml +model_input_shape: [160, 160] +distance_threshold: 325.0 +n_neighbors: 10 +delta_t: 5 +batch_size: 128 +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/ctc_tracking/ +show_napari: false diff --git a/applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc_all.sh b/applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc_all.sh new file mode 100644 index 000000000..10bb6ba95 --- /dev/null +++ b/applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc_all.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# CTC tracking accuracy benchmark — DynaCLR-2D-MIP vs IoU baseline +# 
Runs on all 9 2D CTC training datasets. +# +# sbatch applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc_all.sh + +#SBATCH --job-name=ctc_tracking +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=64G +#SBATCH --partition=gpu +#SBATCH --gres=gpu:1 +#SBATCH --time=0-02:00:00 +#SBATCH --output=%x-%j.out + +export PYTHONNOUSERSITE=1 +export GRB_LICENSE_FILE=/home/eduardo.hirata/gurobi/gurobi.lic + +WORKSPACE="/hpc/mydata/eduardo.hirata/repos/viscy" +CONFIG="$WORKSPACE/applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc_all.yaml" + +uv run --project "$WORKSPACE" dynaclr evaluate-tracking-accuracy -c "$CONFIG" diff --git a/applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc_all.yaml b/applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc_all.yaml new file mode 100644 index 000000000..9b81b3dca --- /dev/null +++ b/applications/dynaclr/configs/evaluation/tracking/ctc_tracking_2d_mip_boc_all.yaml @@ -0,0 +1,37 @@ +models: + - path: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/onnx/last.onnx + label: DynaCLR-2D-MIP + pixel_size_um: 0.149 # training pixel size (Mantis-v1) + - path: null + label: baseline-iou + +# 2D datasets only — 3D datasets excluded (model is 2D-only) +# pixel_size_um is auto-detected from TIFF XResolution metadata +datasets: + - path: /hpc/reference/group.royer/CTC/training/BF-C2DL-HSC + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/BF-C2DL-MuSC + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/DIC-C2DH-HeLa + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/Fluo-C2DL-MSC + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/Fluo-N2DH-GOWT1 + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/Fluo-N2DH-SIM+ + sequences: ["01", "02"] + - path: 
/hpc/reference/group.royer/CTC/training/Fluo-N2DL-HeLa + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/PhC-C2DH-U373 + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/PhC-C2DL-PSC + sequences: ["01", "02"] + +ctc_metadata_path: /hpc/reference/group.royer/CTC/metadata.yaml +model_input_shape: [160, 160] +distance_threshold: 325.0 +n_neighbors: 10 +delta_t: 5 +batch_size: 128 +output_dir: /hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluations/ctc_tracking_all/ +show_napari: false diff --git a/applications/dynaclr/configs/linear_classifiers/evaluate_dataset_example.yaml b/applications/dynaclr/configs/linear_classifiers/evaluate_dataset_example.yaml deleted file mode 100644 index c2514d04e..000000000 --- a/applications/dynaclr/configs/linear_classifiers/evaluate_dataset_example.yaml +++ /dev/null @@ -1,38 +0,0 @@ -# Example configuration for evaluate_dataset.py -# -# Usage: -# python evaluate_dataset.py -c configs/evaluate_dataset_example.yaml -# python evaluate_dataset.py -c configs/evaluate_dataset_example.yaml --report - -dataset_name: my_test_dataset -test_annotations_csv: /path/to/test_annotations.csv -output_dir: /path/to/output - -models: - 2D: - name: DynaCLR-2D-BagOfChannels-timeaware - version: v3 - wandb_project: linearclassifiers-DynaCLR-2D-BagOfChannels-timeaware-v3 - test_embeddings_dir: /path/to/2D/embeddings/ - train_datasets: - - embeddings_dir: /path/to/train_ds1/embeddings/ - annotations: /path/to/train_ds1/annotations.csv - - embeddings_dir: /path/to/train_ds2/embeddings/ - annotations: /path/to/train_ds2/annotations.csv - -# Optional: auto-detected from test CSV if omitted -task_channels: - infection_state: [phase, sensor] - cell_division_state: [phase] - -# Classifier hyperparams (all optional, shown with defaults) -use_scaling: true -n_pca_components: null -max_iter: 1000 -class_weight: balanced -solver: liblinear 
-split_train_data: 0.8 -random_seed: 42 - -# W&B logging (set to false for local-only runs) -wandb_logging: true diff --git a/applications/dynaclr/configs/prediction/predict.yml b/applications/dynaclr/configs/prediction/predict.yml index a76cf05c6..0f560fa8c 100644 --- a/applications/dynaclr/configs/prediction/predict.yml +++ b/applications/dynaclr/configs/prediction/predict.yml @@ -11,6 +11,9 @@ trainer: num_nodes: 1 precision: 32-true callbacks: + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: 10 - class_path: viscy_utils.callbacks.embedding_writer.EmbeddingWriter init_args: output_path: #TODO point to the path to save the embeddings diff --git a/applications/dynaclr/configs/pseudotime/candidates.yaml b/applications/dynaclr/configs/pseudotime/candidates.yaml new file mode 100644 index 000000000..4691ee5cd --- /dev/null +++ b/applications/dynaclr/configs/pseudotime/candidates.yaml @@ -0,0 +1,92 @@ +# Stage 0 config — candidate selection, lineage reconnection, cohort tagging. +# Load via: --datasets datasets.yaml --config candidates.yaml +# +# Defines candidate sets that select the productive cohort (and optionally +# a mock cohort from uninfected wells via well_pattern). Bystander and +# abortive cohorts are derived from the productive set's wells using the +# LC predictions (per discussion §3.2 and the locked execution plan). + +candidate_sets: + zikv_productive_07_24: + description: | + Productive ZIKV cohort from 07_24, SEC61 (A/2) + G3BP1 (C/2) wells. + The `productive` cohort uses the existing transitioning filter as + the anchor source. `bystander` and `abortive` are LC-derived from + the same wells. `mock` comes from uninfected control wells (A/1, C/1). 
+ datasets: ["2025_07_24_SEC61", "2025_07_24_G3BP1"] + productive_filter: + anchor_label: infection_state + anchor_positive: infected + anchor_negative: uninfected + min_pre_minutes: 240 # discussion §3.6: target h_pre = 240–360 min + min_post_minutes: 360 + crop_window_minutes: 360 # generous; readout windows pick within this + cohort_rules: + bystander_uninfected_fraction: 0.8 + abortive_min_run: 3 + mock_well_patterns: + # Uninfected control wells per dataset entries. Read from + # datasets.yaml control_fov_pattern when available; this list + # overrides for explicitness. + - "A/1" + - "C/1" + lineage_rules: + reconnect: true + return_both_branches: true # keep both daughters; daughter handling regime-dependent downstream + transition_window_k_pre_minutes: 60 + transition_window_k_post_minutes: 120 + max_lineages: 200 # cap; longer lineages preferred when over + + zikv_productive_pooled: + description: | + Pooled productive cohort across 07_24 (SEC61 + G3BP1), 07_22 (G3BP1 + only), and 08_26 (SEC61 only). Used to increase n for + compare_phase_to_fluor — the 07_24-only set was n=6/7. + + Filter is loosened to a 300-min track-length floor (10 frames at + 30 min/frame, 30 frames at 10 min/frame) to recover 07_22 cells + that the strict 600-min filter excluded. 08_26 uses + ``productive_source: lc_zarr`` because its annotation CSV is + per-frame rather than track-linked. Mock cohort uses + ``min_non_productive_minutes`` decoupled from the productive + filter so short tracks can still serve as null distributions. 
+ datasets: + - "2025_07_24_SEC61" + - "2025_07_24_G3BP1" + - "2025_07_22_G3BP1" + - "2025_08_26_SEC61" + productive_filter: + anchor_label: infection_state + anchor_positive: infected + anchor_negative: uninfected + min_pre_minutes: 120 # was 240; loosened to 300-min total floor + min_post_minutes: 180 # was 360; loosened to 300-min total floor + crop_window_minutes: 360 + cohort_rules: + bystander_uninfected_fraction: 0.8 + abortive_min_run: 3 + min_non_productive_minutes: 300 # decoupled from productive pre/post; lets + # short tracks contribute as bystander/mock + mock_well_patterns: + - "A/1" + - "C/1" + lineage_rules: + reconnect: true + return_both_branches: true + transition_window_k_pre_minutes: 60 + transition_window_k_post_minutes: 120 + max_lineages: 400 + + manual_debug_zikv: + description: "Manual debug: 4 hand-picked tracks from manual_candidates.py, ZIKV 07_24 A/2" + source: manual # tells select_candidates.py to skip auto path + manual_script: manual_candidates.py + cohort_rules: + mock_well_patterns: + - "A/1" + - "C/1" + lineage_rules: + reconnect: true + return_both_branches: true + transition_window_k_pre_minutes: 60 + transition_window_k_post_minutes: 120 diff --git a/applications/dynaclr/configs/pseudotime/compare_tracks.yaml b/applications/dynaclr/configs/pseudotime/compare_tracks.yaml new file mode 100644 index 000000000..c5a2de434 --- /dev/null +++ b/applications/dynaclr/configs/pseudotime/compare_tracks.yaml @@ -0,0 +1,29 @@ +# Stage 4 config — cross-track comparison of organelle readouts. +# Load via: --datasets datasets.yaml --config compare_tracks.yaml + +comparisons: + zikv_07_24_full: + description: "Full ZIKV 07_24 comparison: SEC61 + G3BP1 + phase under all three tracks." + candidate_set: zikv_productive_07_24 + organelles: [sec61, g3bp1, phase] + tracks: [A-anno, A-LC, B] + cohorts: [productive, mock, bystander] + # Target headline metric for the methodological claim (per DAG §9.1 + # and discussion §2.2). 
Path B succeeds if its IQR width at the + # primary metric is ≥ 25% tighter than the better of A-anno and A-LC. + methodological_threshold_fraction: 0.25 + # Real-time bins for population aggregation. + bin_minutes: 30 + bin_range_minutes: [-360, 540] + + zikv_pooled_full: + description: | + Pooled comparison across 07_24 (SEC61+G3BP1), 07_22 (G3BP1), and + 08_26 (SEC61). Increases sample size for compare_phase_to_fluor. + candidate_set: zikv_productive_pooled + organelles: [sec61, g3bp1, phase] + tracks: [A-anno, A-LC, B] + cohorts: [productive, mock, bystander] + methodological_threshold_fraction: 0.25 + bin_minutes: 30 + bin_range_minutes: [-360, 540] diff --git a/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.sh b/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.sh new file mode 100644 index 000000000..8802b0b01 --- /dev/null +++ b/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# DINOv3-temporal-MLP-2D-BagOfChannels +# +# New run: +# sbatch applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.sh +# +# Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR: +# sbatch /hpc/projects/.../DINOv3-temporal-MLP-2D-BagOfChannels.sh + +#SBATCH --job-name=dinov3_mlp_2d +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 +#SBATCH --constraint="h100|h200" +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=15 +#SBATCH --mem-per-cpu=8G +#SBATCH --time=2-00:00:00 + +# ── Run identity ────────────────────────────────────────────────────── +export PROJECT="DINOv3-temporal-MLP-2D-BagOfChannels-v1" +export RUN_NAME="dinov3-mlp-2d-mip-ntxent-t0p5-lr1e4-bs512" +export CONFIGS="applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.yml" + +# ── Resume (uncomment to continue from checkpoint) ──────────────────── +export 
CKPT_PATH="/hpc/projects/organelle_phenotyping/models/DINOv3-temporal-MLP-2D-BagOfChannels-v1/dinov3-mlp-2d-mip-ntxent-t0p5-lr1e4-bs512/DINOv3-temporal-MLP-2D-BagOfChannels-v1/20260403-223550/checkpoints/last.ckpt" +export WANDB_RUN_ID="20260403-223550" + +source "$(dirname "$0")/../slurm/train.sh" diff --git a/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.yml b/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.yml new file mode 100644 index 000000000..6f3ccda42 --- /dev/null +++ b/applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.yml @@ -0,0 +1,120 @@ +# DINOv3-temporal-MLP-2D-BagOfChannels +# ========================================= +# Frozen DINOv3 backbone + trainable MLP projection head. +# 2D bag-of-channels with MIP z-reduction (same data pipeline as +# DynaCLR-2D-MIP-BagOfChannels). +# +# Launch: +# sbatch applications/dynaclr/configs/training/DINOv3/DINOv3-temporal-MLP-2D-BagOfChannels.sh +# +# Resume: +# CKPT_PATH=.../last.ckpt sbatch .../DINOv3-temporal-MLP-2D-BagOfChannels.sh + +base: + - ../recipes/trainer.yml + - ../recipes/model/dinov3_frozen_mlp.yml + +trainer: + strategy: ddp + devices: 2 + precision: bf16-mixed + max_epochs: 100 + logger: + init_args: + project: DINOv3-temporal-MLP-2D-BagOfChannels-v1 + name: null + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: loss/val + every_n_epochs: 1 + save_top_k: 5 + save_last: true + - class_path: viscy_utils.callbacks.OnlineEvalCallback + init_args: + every_n_epochs: 5 + label_key: perturbation + k: 20 + track_id_key: global_track_id + timepoint_key: t + +model: + init_args: + pca_color_keys: [perturbation, hours_post_perturbation, experiment, marker] + log_negative_metrics_every_n_epochs: 2 + example_input_array_shape: [1, 1, 1, 160, 160] + +data: + 
class_path: dynaclr.data.datamodule.MultiExperimentDataModule + init_args: + cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-MIP-BagOfChannels-MultiCell.parquet + focus_channel: Phase3D + reference_pixel_size_xy_um: 0.1494 + z_window: 1 + z_extraction_window: 11 + z_focus_offset: 0.5 + yx_patch_size: [256, 256] + final_yx_patch_size: [160, 160] + channels_per_sample: 1 + positive_cell_source: lookup + positive_match_columns: [lineage_id] + positive_channel_source: same + tau_range: [0.5, 2.0] + tau_decay_rate: 2.0 + stratify_by: [perturbation, marker] + split_ratio: 0.8 + batch_size: 512 + num_workers: 2 + seed: 42 + normalizations: + - class_path: viscy_transforms.NormalizeSampled + init_args: + keys: [channel_0] + level: timepoint_statistics + subtrahend: mean + divisor: std + augmentations: + - class_path: viscy_transforms.BatchedRandAffined + init_args: + keys: [channel_0] + prob: 0.8 + scale_range: [[0.9, 1.1], [0.9, 1.1], [0.9, 1.1]] + rotate_range: [3.14, 0.0, 0.0] + shear_range: [0.05, 0.05, 0.0, 0.05, 0.0, 0.05] + - class_path: viscy_transforms.BatchedRandFlipd + init_args: + keys: [channel_0] + spatial_axes: [1, 2] + prob: 0.5 + - class_path: viscy_transforms.BatchedRandAdjustContrastd + init_args: + keys: [channel_0] + prob: 0.5 + gamma: [0.8, 1.3] + - class_path: viscy_transforms.BatchedRandScaleIntensityd + init_args: + keys: [channel_0] + prob: 0.5 + factors: 0.5 + - class_path: viscy_transforms.BatchedRandGaussianSmoothd + init_args: + keys: [channel_0] + prob: 0.5 + sigma_x: [0.25, 0.50] + sigma_y: [0.25, 0.50] + sigma_z: [0.0, 0.0] + - class_path: viscy_transforms.BatchedRandGaussianNoised + init_args: + keys: [channel_0] + prob: 0.5 + mean: 0.0 + std: 0.08 + # Z-reduction: MIP for fluorescence, center-slice for label-free. + # Must be LAST augmentation (before implicit final spatial crop). 
+ - class_path: viscy_transforms.BatchedChannelWiseZReductiond + init_args: + keys: [channel_0] + allow_missing_keys: true diff --git a/applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.sh b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.sh similarity index 83% rename from applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.sh rename to applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.sh index feb6edadd..9fc4b3be3 100755 --- a/applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.sh +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.sh @@ -2,7 +2,7 @@ # DynaCLR-2D-BagOfChannels-v3 # # New run: -# sbatch applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.sh +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.sh # # Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR. @@ -18,10 +18,10 @@ # ── Run identity ────────────────────────────────────────────────────── export PROJECT="DynaCLR-2D-BagOfChannels-v3" export RUN_NAME="phase1-ntxent-temp0p2" -export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.yml" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.yml" # ── Resume (uncomment to continue from checkpoint) ──────────────────── # export CKPT_PATH="" # export WANDB_RUN_ID="" -source "$(dirname "$0")/slurm/train.sh" +source "$(dirname "$0")/../slurm/train.sh" diff --git a/applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.yml b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.yml similarity index 78% rename from applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.yml rename to applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.yml index ff4eba7b5..e50e2ba10 100644 --- 
a/applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.yml +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.yml @@ -5,22 +5,17 @@ # Temporal positive pairs (same lineage at t+tau), stratified by perturbation + marker. # # Launch: -# sbatch applications/dynaclr/configs/training/DynaCLR-2D-BagOfChannels-v3.sh +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-BagOfChannels-v3.sh -seed_everything: 42 +base: + - ../recipes/trainer.yml + - ../recipes/model/contrastive_encoder_convnext_tiny.yml trainer: - accelerator: gpu strategy: ddp devices: 2 - num_nodes: 1 precision: bf16-mixed max_epochs: 150 - log_every_n_steps: 10 - enable_checkpointing: true - enable_model_summary: false - inference_mode: true - use_distributed_sampler: false callbacks: - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: @@ -36,37 +31,27 @@ trainer: every_n_epochs: 5 label_key: perturbation k: 20 - - class_path: viscy_utils.callbacks.SaveConfigToWandb model: - class_path: dynaclr.engine.ContrastiveModule init_args: encoder: - class_path: viscy_models.contrastive.ContrastiveEncoder init_args: - backbone: convnext_tiny - in_channels: 1 in_stack_depth: 1 stem_kernel_size: [1, 4, 4] stem_stride: [1, 4, 4] - embedding_dim: 768 projection_dim: 32 drop_path_rate: 0.1 loss_function: - class_path: viscy_models.contrastive.loss.NTXentLoss init_args: temperature: 0.2 lr: 0.00002 - log_batches_per_epoch: 3 - log_samples_per_batch: 3 - log_embeddings_every_n_epochs: 10 pca_color_keys: "[perturbation,hours_post_perturbation]" example_input_array_shape: [1, 1, 1, 160, 160] data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: - cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/DynaCLR-2D-BagOfChannels-v3.parquet + cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-BagOfChannels-v3.parquet z_window: 1 yx_patch_size: [192, 
192] final_yx_patch_size: [160, 160] diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-192.sh b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-192.sh new file mode 100644 index 000000000..861b67fe8 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-192.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# DynaCLR-2D-MIP-BagOfChannels SINGLE-MARKER 192px variant. +# Same recipe as single-marker.sh but with 384->256->192 crops instead +# of 256->192->160. Larger final input preserves more subcellular detail +# at ~2x the I/O cost per batch. +# +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-192.sh + +#SBATCH --job-name=dynaclr_2d_sm192 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 +#SBATCH --constraint="h100|h200" +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=15 +# 17 GB/CPU × 15 CPUs = 255 GB/rank, 510 GB/node on 2 GPUs. Bumped from 14G +# after rank 3 host-RAM OOM on 384² patches (job 31449149). Combined with +# prefetch_factor=1 in datamodule. Dropped from 4 GPUs to 2 to ease queueing; +# batch_size kept at 256/rank — host RAM was the OOM driver (cgroup), not +# VRAM, and that scales with workers × prefetch, not batch_size. If this +# still OOMs, suspect a real leak (loky semaphores, tensorstore decoder +# scratch) — investigate rather than papering over with more RAM. 
+#SBATCH --mem-per-cpu=17G +#SBATCH --time=3-00:00:00 + +export PROJECT="DynaCLR-2D-MIP-BagOfChannels" +export RUN_NAME="2d-mip-ntxent-t0p2-lr2e5-bs256-384to192-zext16-single-marker-fix-shuffler" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-192.yml" + +# Warm-start disabled: prior attempt 31442612 hit a 30-min NCCL all-reduce +# timeout in optimizer.step. Suspected interaction between the warm-start +# (160-input encoder weights loaded into a 192-input model) and the +# augmentation pipeline causing rank divergence. Train from random init +# to remove that confound; if the fresh-init run trains cleanly we can +# revisit warm-start in v2. +# export EXTRA_ARGS="--model.init_args.ckpt_path=/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker-fix-shuffler/DynaCLR-2D-MIP-BagOfChannels/0rhpwh77/checkpoints/last.ckpt" + +source /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/training/slurm/train.sh diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-192.yml b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-192.yml new file mode 100644 index 000000000..b8cc2c044 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-192.yml @@ -0,0 +1,74 @@ +# Override: single-marker batches at 384->256->192 patch sizes (vs the +# default single-marker at 256->192->160). Larger final crop preserves more +# subcellular detail; affine corner safety holds at scale_range=[0.8, 1.3] +# under any rotation because 384 / sqrt(2) * 0.8 = 217 > 192. 
+# +# Layered on top of the base single-marker override: +# --config applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml +# --config applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml +# --config applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-192.yml + +trainer: + devices: 2 + +model: + init_args: + example_input_array_shape: [1, 1, 1, 192, 192] + +data: + init_args: + yx_patch_size: [384, 384] + final_yx_patch_size: [192, 192] + augmentations: + - class_path: viscy_transforms.BatchedRandAffined + init_args: + keys: [channel_0] + prob: 0.8 + scale_range: [[0.8, 1.3], [0.8, 1.3], [0.8, 1.3]] + rotate_range: [3.14, 0.0, 0.0] + shear_range: [0.05, 0.05, 0.0, 0.05, 0.0, 0.05] + - class_path: viscy_transforms.BatchedRandFlipd + init_args: + keys: [channel_0] + spatial_axes: [1, 2] + prob: 0.5 + - class_path: viscy_transforms.BatchedRandAdjustContrastd + init_args: + keys: [channel_0] + prob: 0.5 + gamma: [0.6, 1.6] + - class_path: viscy_transforms.BatchedRandScaleIntensityd + init_args: + keys: [channel_0] + prob: 0.5 + factors: 0.5 + - class_path: viscy_transforms.BatchedRandGaussianSmoothd + init_args: + keys: [channel_0] + prob: 0.5 + sigma_x: [0.25, 0.50] + sigma_y: [0.25, 0.50] + sigma_z: [0.0, 0.0] + - class_path: viscy_transforms.BatchedRandGaussianNoised + init_args: + keys: [channel_0] + prob: 0.5 + mean: 0.0 + std: 0.1 + # Random Z + YX crop sized to fit inside the affine-safe inscribed + # region: at scale_range=[0.8, 1.3] under any rotation, the safe + # inscribed square is 384 / sqrt(2) * 0.8 = 217 px. We crop to 216 + # to land fully inside; the implicit center-crop to 192 then keeps + # 12 px of margin per side. A 256×256 random crop (prior version) + # spilled ~20 px outside the safe zone on each side, so some + # batches contained zero-padded affine corners — likely cause of + # gradient magnitude divergence and DDP all-reduce timeout. 
+ - class_path: viscy_transforms.BatchedRandSpatialCropd + init_args: + keys: [channel_0] + roi_size: [10, 216, 216] + # Z-reduction stays last (before implicit final spatial crop to 192). + - class_path: viscy_transforms.BatchedChannelWiseZReductiond + init_args: + keys: [channel_0] + allow_missing_keys: true diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.sh b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.sh new file mode 100644 index 000000000..63fd28fc3 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# DynaCLR-2D-MIP-BagOfChannels single-marker — A40 interactive single-GPU variant. +# For smoke-testing and small-scale iteration on the interactive partition +# without queueing on the gpu partition. +# +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.sh + +#SBATCH --job-name=dynaclr_2d_sm_a40 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:a40:1 +#SBATCH --partition=interactive +#SBATCH --cpus-per-task=16 +#SBATCH --mem-per-cpu=14G +#SBATCH --time=4-00:00:00 + +export PROJECT="DynaCLR-2D-MIP-BagOfChannels" +export RUN_NAME="2d-mip-ntxent-t0p2-lr2e5-bs128-A40-single-marker" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.yml" + +source "$(dirname "$0")/../slurm/train.sh" diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.yml b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.yml new file mode 100644 index 000000000..1a85a68a5 --- 
/dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.yml @@ -0,0 +1,12 @@ +# Single-GPU A40 override for DynaCLR-2D-MIP-BagOfChannels single-marker. +# Chains on top of the 4-GPU base + single-marker override; strips DDP and +# halves batch size to fit the A40's 48 GB VRAM. + +trainer: + strategy: auto + devices: 1 + +data: + init_args: + batch_size: 128 + num_workers: 1 diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.sh b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.sh new file mode 100755 index 000000000..215492166 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# DynaCLR-2D-MIP-BagOfChannels SINGLE-MARKER variant. +# Every batch contains only one marker (OPS-style). +# +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.sh + +#SBATCH --job-name=dynaclr_2d_sm +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 +#SBATCH --constraint="h100|h200" +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=15 +#SBATCH --mem-per-cpu=8G +#SBATCH --time=3-00:00:00 + +export PROJECT="DynaCLR-2D-MIP-BagOfChannels" +export RUN_NAME="2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker-fix-shuffler" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml" + +# Warm-start at epoch 0 from THIS run's prior attempt (0rhpwh77/last.ckpt, +# Apr 24, epoch=0-step=800). Job 31410692 trained for 1 epoch + val before +# hanging on a OnlineEvalCallback DDP logging deadlock (rank-0-only log +# triggers an unmatched all-reduce on epoch end). 
Fix landed in +# online_eval.py — switching to sync_dist=True and computing on every +# rank. Loads encoder weights only via engine.py:76-86; optimizer state +# and epoch counter still reset. +export EXTRA_ARGS="--model.init_args.ckpt_path=/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-single-marker-fix-shuffler/DynaCLR-2D-MIP-BagOfChannels/0rhpwh77/checkpoints/last.ckpt" + +source "$(dirname "$0")/../slurm/train.sh" diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml new file mode 100644 index 000000000..5fd34915c --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels-single-marker.yml @@ -0,0 +1,30 @@ +# Override: single-marker batches for DynaCLR-2D-MIP-BoC. +# Matches the OPS strategy — every batch is one marker, forcing the model +# to learn cellular features instead of channel-filter shortcuts. + +data: + init_args: + # v3 parquet drops dynamorph Brightfield + Retardance (same physical + # cells as Phase3D, were inflating that experiment's row count). + cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-MIP-BagOfChannels-v3.parquet + batch_group_by: marker + # Within a marker's draw, balance across the experiments containing + # that marker. Without this, Phase3D batches were 74% dynamorph cells + # because dynamorph is by far the largest Phase3D experiment. + stratify_by: experiment + # Marker-uniform weights. Without these, batch_group_by + stratify_by + # weights marker draws by the (marker, experiment) cross-product cell + # counts — Retardance/Brightfield-dominant experiments would skew the + # marker distribution. Setting equal weights per marker forces uniform + # P(marker)=1/n_markers per batch. 
Diagnostic: + # applications/dynaclr/scripts/dataloader_inspection/check_marker_experiment_stratify.py + group_weights: + Phase3D: 1.0 + pAL10: 1.0 + viral_sensor: 1.0 + G3BP1: 1.0 + SEC61B: 1.0 + TOMM20: 1.0 + CAAX: 1.0 + HIST2H2BE: 1.0 + DIC: 1.0 diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.sh b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.sh new file mode 100644 index 000000000..a74b670e9 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# DynaCLR-2D-MIP-BagOfChannels +# Multi-cell-type 2D contrastive learning with channel-wise z-reduction. +# +# New run: +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.sh +# +# Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch. + +#SBATCH --job-name=dynaclr_2d_mip +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 +#SBATCH --constraint="h100|h200" +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=15 +#SBATCH --mem-per-cpu=8G +#SBATCH --time=3-00:00:00 + +# ── Run identity ────────────────────────────────────────────────────── +# Warm-started from prior mixed-markers run (s1f8kgtp/last.ckpt, Apr 22) +# at epoch 0. Picks up the FlexibleBatchSampler reshuffle fix +# (commit f4f40c38) and the profiling-pass defaults (nw=4, ts.Batch +# overlap, file_io_concurrency=128, z_extraction_window=16, cuDNN +# benchmark, TF32 matmul). Optimizer state and epoch counter reset so +# AdamW moments don't carry over biased gradients from the broken sampler. 
+export PROJECT="DynaCLR-2D-MIP-BagOfChannels" +export RUN_NAME="2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-mixed-markers-fix-shuffler" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml" + +# ── Warm-start at epoch 0 (state_dict only — not Lightning's full resume) ── +export EXTRA_ARGS="--model.init_args.ckpt_path=/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11-mixed-markers/DynaCLR-2D-MIP-BagOfChannels/s1f8kgtp/checkpoints/last.ckpt" + +# ── Resume (Lightning full state, NOT what we want here) ────────────── +# export CKPT_PATH="/path/to/last.ckpt" +# export WANDB_RUN_ID="" + +source "$(dirname "$0")/../slurm/train.sh" diff --git a/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml new file mode 100644 index 000000000..bfae82546 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml @@ -0,0 +1,156 @@ +# DynaCLR-2D-MIP-BagOfChannels +# ============================== +# 2D bag-of-channels contrastive learning with channel-wise z-reduction. +# Extracts a 20-slice z-stack around focus, randomly crops to 10 slices +# (Z-invariance), then applies MIP for fluorescence and center-slice for +# label-free (Phase3D, BF, DIC, Retardance). +# Multi-cell-type: A549 infectomics, microglia dynamorph, ALFI mitosis. +# +# Launch: +# sbatch applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.sh +# +# Resume: +# CKPT_PATH=.../last.ckpt sbatch .../DynaCLR-2D-MIP-BagOfChannels.sh + +base: + - ../recipes/trainer.yml + - ../recipes/model/contrastive_encoder_convnext_tiny.yml + +trainer: + strategy: ddp + devices: 4 + precision: bf16-mixed + max_epochs: 150 + # 2.7M train anchors × bs=256 = ~10.5k full-epoch batches → 85 min/epoch + # at current ~132 samples/s. 
Cap epoch length so wall time stays bounded + # and the val signal lands often. Matches the pattern used by OPS-1000genes. + limit_train_batches: 800 + limit_val_batches: 200 + logger: + init_args: + project: DynaCLR-2D-MIP-BagOfChannels + name: null + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: loss/val + every_n_epochs: 1 + save_top_k: 5 + save_last: true + - class_path: viscy_utils.callbacks.OnlineEvalCallback + init_args: + every_n_epochs: 5 + label_key: perturbation + k: 20 + track_id_key: global_track_id + timepoint_key: t + +model: + init_args: + encoder: + init_args: + in_stack_depth: 1 + stem_kernel_size: [1, 4, 4] + stem_stride: [1, 4, 4] + projection_dim: 32 + drop_path_rate: 0.1 + loss_function: + init_args: + temperature: 0.2 + lr: 0.00002 + pca_color_keys: "[perturbation,hours_post_perturbation,experiment,marker]" + log_negative_metrics_every_n_epochs: 2 + example_input_array_shape: [1, 1, 1, 160, 160] + +data: + class_path: dynaclr.data.datamodule.MultiExperimentDataModule + init_args: + cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-MIP-BagOfChannels-v3.parquet + focus_channel: Phase3D + reference_pixel_size_xy_um: 0.1494 + z_window: 1 + z_extraction_window: 16 + z_focus_offset: 0.3 + yx_patch_size: [256, 256] + final_yx_patch_size: [160, 160] + channels_per_sample: 1 + positive_cell_source: lookup + positive_match_columns: [lineage_id] + positive_channel_source: same + tau_range: [0.5, 2.0] + tau_decay_rate: 2.0 + stratify_by: [perturbation, marker] + split_ratio: 0.8 + batch_size: 256 + num_workers: 4 + # Per-rank memory budget on h100/h200 nodes (8 GB/CPU × 15 CPUs = 120 GB). + # 26 plates × 500 MB per-plate cache pools = ~13 GB of dead weight (iohub + # 0.3.x creates one ts.Context per open_ome_zarr; random sampling across + # plates → near-zero hit rate). 
Disable. Cap ThreadBuffer queue to one + batch. prefetch_factor=1 (was 2) after rank 3 host-RAM OOM on 384² + # patches (job 31449149) — halves in-flight batch count per worker. + prefetch_factor: 1 + buffer_size: 1 + cache_pool_bytes: 0 + # Match historical iohub default; the 128 we tried for +8.6% A/B amplifies + # tensorstore decoder/IO buffers and contributes to OOM under DDP. + file_io_concurrency: 32 + seed: 42 + normalizations: + - class_path: viscy_transforms.NormalizeSampled + init_args: + keys: [channel_0] + level: timepoint_statistics + subtrahend: mean + divisor: std + augmentations: + - class_path: viscy_transforms.BatchedRandAffined + init_args: + keys: [channel_0] + prob: 0.8 + scale_range: [[0.8, 1.3], [0.8, 1.3], [0.8, 1.3]] + rotate_range: [3.14, 0.0, 0.0] + shear_range: [0.05, 0.05, 0.0, 0.05, 0.0, 0.05] + - class_path: viscy_transforms.BatchedRandFlipd + init_args: + keys: [channel_0] + spatial_axes: [1, 2] + prob: 0.5 + - class_path: viscy_transforms.BatchedRandAdjustContrastd + init_args: + keys: [channel_0] + prob: 0.5 + gamma: [0.6, 1.6] + - class_path: viscy_transforms.BatchedRandScaleIntensityd + init_args: + keys: [channel_0] + prob: 0.5 + factors: 0.5 + - class_path: viscy_transforms.BatchedRandGaussianSmoothd + init_args: + keys: [channel_0] + prob: 0.5 + sigma_x: [0.25, 0.50] + sigma_y: [0.25, 0.50] + sigma_z: [0.0, 0.0] + - class_path: viscy_transforms.BatchedRandGaussianNoised + init_args: + keys: [channel_0] + prob: 0.5 + mean: 0.0 + std: 0.1 + # Random Z crop: select 10 of 16 extracted slices for Z-invariance. + # Must come before ZReduction so MIP sees a variable sub-stack. + - class_path: viscy_transforms.BatchedRandSpatialCropd + init_args: + keys: [channel_0] + roi_size: [10, 192, 192] + # Z-reduction: MIP for fluorescence, center-slice for label-free. + # Must be LAST augmentation (before implicit final spatial crop).
+ - class_path: viscy_transforms.BatchedChannelWiseZReductiond + init_args: + keys: [channel_0] + allow_missing_keys: true diff --git a/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.sh b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.sh new file mode 100755 index 000000000..610d2a9d4 --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# DynaCLR-3D-BagOfChannels-v2 SINGLE-MARKER variant (fresh, no resume). +# +# sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.sh + +#SBATCH --job-name=dynaclr_3d_sm +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 +#SBATCH --constraint="h100|h200" +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=15 +#SBATCH --mem-per-cpu=12G +#SBATCH --time=4-00:00:00 + +export PROJECT="DynaCLR-3D-BagOfChannels-v2" +export RUN_NAME="3d-z32-256to228to160-ntxent-t0p2-single-marker" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.yml applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.yml" + +source "$(dirname "$0")/../slurm/train.sh" diff --git a/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.yml b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.yml new file mode 100644 index 000000000..95ffb127e --- /dev/null +++ b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2-single-marker.yml @@ -0,0 +1,7 @@ +# Override: single-marker batches for DynaCLR-3D-BoC-v2. +# Matches the OPS strategy — every batch is one marker. 
+ +data: + init_args: + batch_group_by: marker + stratify_by: null diff --git a/applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.sh b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh similarity index 54% rename from applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.sh rename to applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh index d8f73fd63..e1f1e5752 100755 --- a/applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.sh +++ b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh @@ -2,28 +2,29 @@ # DynaCLR-3D-BagOfChannels-v2 # # New run: -# sbatch applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.sh +# sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh # # Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR: -# sbatch /hpc/projects/.../3d-z16-.../DynaCLR-3D-BagOfChannels-v2.sh +# sbatch /hpc/projects/.../3d-z32-.../DynaCLR-3D-BagOfChannels-v2.sh #SBATCH --job-name=dynaclr_3d_v2 #SBATCH --nodes=1 -#SBATCH --ntasks-per-node=4 -#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 #SBATCH --constraint="h100|h200" #SBATCH --partition=gpu #SBATCH --cpus-per-task=15 -#SBATCH --mem-per-cpu=8G -#SBATCH --time=0-22:00:00 +#SBATCH --mem-per-cpu=12G +#SBATCH --time=4-00:00:00 # ── Run identity ────────────────────────────────────────────────────── export PROJECT="DynaCLR-3D-BagOfChannels-v2" -export RUN_NAME="3d-z16-ntxent-t0p2-lr2e5-bs512-192to160-zext45" -export CONFIGS="applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.yml" +export RUN_NAME="3d-z32-256to228to160-ntxent-t0p2-mixed-markers" +export CONFIGS="applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.yml" # ── Resume (uncomment to continue from checkpoint) ──────────────────── -# export 
CKPT_PATH="/hpc/projects/organelle_phenotyping/models/DynaCLR-3D-BagOfChannels-v2/3d-z16-ntxent-t0p2-lr2e5-bs512-192to160-zext45/checkpoints/last.ckpt" -# export WANDB_RUN_ID="20260329-063341" +# Commented out for fresh A/B comparison run against single-marker variant. +# export CKPT_PATH="/hpc/projects/organelle_phenotyping/models/DynaCLR-3D-BagOfChannels-v2/3d-z32-256to228to160-ntxent-t0p2/DynaCLR-3D-BagOfChannels-v2/20260402-185442/checkpoints/last.ckpt" +# export WANDB_RUN_ID="20260402-185442" -source "$(dirname "$0")/slurm/train.sh" +source "$(dirname "$0")/../slurm/train.sh" diff --git a/applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.yml b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.yml similarity index 73% rename from applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.yml rename to applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.yml index 3d6392ce7..b9272212d 100644 --- a/applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.yml +++ b/applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.yml @@ -1,35 +1,32 @@ # DynaCLR-3D-BagOfChannels-v2 # ============================ # 3D bag-of-channels contrastive learning. -# One random fluorescence channel per sample, 16-slice Z window. +# One random fluorescence channel per sample, 32-slice Z window. # Temporal positive pairs (same lineage at t+tau), stratified by perturbation. 
# +# Augmentation pipeline: +# extract (45,256,256) → normalize → affine → RandCrop (40,228,228) +# → flip/contrast/noise → CenterCrop (32,160,160) [auto-appended] +# # Launch: -# sbatch applications/dynaclr/configs/training/DynaCLR-3D-BagOfChannels-v2.sh +# sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh # # Resume: # CKPT_PATH=.../last.ckpt sbatch .../DynaCLR-3D-BagOfChannels-v2.sh -seed_everything: 42 +base: + - ../recipes/trainer.yml + - ../recipes/model/contrastive_encoder_convnext_tiny.yml trainer: - accelerator: gpu strategy: ddp - devices: 4 - num_nodes: 1 + devices: 2 precision: bf16-mixed max_epochs: 150 - log_every_n_steps: 10 - enable_checkpointing: true - enable_model_summary: false - inference_mode: true - use_distributed_sampler: false logger: - class_path: lightning.pytorch.loggers.WandbLogger init_args: - entity: computational_imaging project: DynaCLR-3D-BagOfChannels-v2 - name: unnamed-run + name: 3d-z32-256to228to160-ntxent-t0p2 callbacks: - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: @@ -47,46 +44,35 @@ trainer: k: 20 track_id_key: global_track_id timepoint_key: t - - class_path: viscy_utils.callbacks.SaveConfigToWandb model: - class_path: dynaclr.engine.ContrastiveModule init_args: encoder: - class_path: viscy_models.contrastive.ContrastiveEncoder init_args: - backbone: convnext_tiny - in_channels: 1 - in_stack_depth: 16 + in_stack_depth: 32 stem_kernel_size: [4, 4, 4] stem_stride: [4, 4, 4] - embedding_dim: 768 projection_dim: 32 drop_path_rate: 0.1 loss_function: - class_path: viscy_models.contrastive.loss.NTXentLoss init_args: temperature: 0.2 lr: 0.00002 - log_batches_per_epoch: 3 - log_samples_per_batch: 3 - log_embeddings_every_n_epochs: 10 pca_color_keys: "[perturbation,hours_post_perturbation,experiment,marker]" log_negative_metrics_every_n_epochs: 2 - example_input_array_shape: [1, 1, 16, 160, 160] + example_input_array_shape: [1, 1, 32, 160, 160] data: class_path: 
dynaclr.data.datamodule.MultiExperimentDataModule init_args: - collection_path: applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml - cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-3D-BagOfChannels-v2.parquet + cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-3D-BagOfChannels-v4.parquet focus_channel: Phase3D reference_pixel_size_xy_um: 0.1494 reference_pixel_size_z_um: 0.174 - z_window: 16 + z_window: 32 z_extraction_window: 45 z_focus_offset: 0.3 - yx_patch_size: [192, 192] + yx_patch_size: [256, 256] final_yx_patch_size: [160, 160] channels_per_sample: 1 positive_cell_source: lookup @@ -96,8 +82,8 @@ data: tau_decay_rate: 2.0 stratify_by: [perturbation] split_ratio: 0.8 - batch_size: 512 - num_workers: 1 + batch_size: 256 + num_workers: 4 seed: 42 normalizations: - class_path: viscy_transforms.NormalizeSampled @@ -114,6 +100,13 @@ data: scale_range: [[0.9, 1.1], [0.9, 1.1], [0.9, 1.1]] rotate_range: [3.14, 0.0, 0.0] shear_range: [0.05, 0.05, 0.0, 0.05, 0.0, 0.05] + # Random crop: Z for focus invariance + YX for translation augmentation. + # The datamodule auto-appends a CenterCrop to [32, 160, 160] after this + # to remove rotation zero-fill artifacts at the edges. + - class_path: viscy_transforms.BatchedRandSpatialCropd + init_args: + keys: [channel_0] + roi_size: [40, 228, 228] - class_path: viscy_transforms.BatchedRandFlipd init_args: keys: [channel_0] diff --git a/applications/dynaclr/configs/training/OPS-1000genes-lite.sh b/applications/dynaclr/configs/training/OPS-1000genes-lite.sh deleted file mode 100755 index ebc569469..000000000 --- a/applications/dynaclr/configs/training/OPS-1000genes-lite.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -# OPS 1000-gene DynaCLR with cosine gene classifier head (lite dataset) -# -# New run: -# sbatch applications/dynaclr/configs/training/OPS-1000genes-lite.sh -# -# Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR. 
- -#SBATCH --job-name=dynaclr_ops_1k -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=4 -#SBATCH --gres=gpu:4 -#SBATCH --partition=gpu -#SBATCH --constraint="h100|h200" -#SBATCH --cpus-per-task=15 -#SBATCH --mem-per-cpu=8G -#SBATCH --time=0-22:00:00 - -# ── Run identity ────────────────────────────────────────────────────── -export PROJECT="OPS" -export RUN_NAME="OPS-1000genes-lite-CosineClassifier" -export EXTRA_ARGS="--trainer.logger.init_args.project=OPS-1000genes-lite-CosineClassifier" -export CONFIGS="applications/dynaclr/configs/training/OPS-1000genes-lite.yml" - -# ── Resume (uncomment to continue from checkpoint) ──────────────────── -# export CKPT_PATH="" -# export WANDB_RUN_ID="" - -WORKSPACE_DIR="${WORKSPACE_DIR:-/hpc/mydata/eduardo.hirata/repos/viscy}" -source "${WORKSPACE_DIR}/applications/dynaclr/configs/training/slurm/train.sh" diff --git a/applications/dynaclr/configs/training/OPS-1000genes-lite.yml b/applications/dynaclr/configs/training/OPS-1000genes-lite.yml deleted file mode 100644 index 88b9ef4c9..000000000 --- a/applications/dynaclr/configs/training/OPS-1000genes-lite.yml +++ /dev/null @@ -1,143 +0,0 @@ -# OPS 1000-gene DynaCLR with cosine gene classifier head (lite dataset) -# ====================================================================== -# Lite dataset: 11M cells, 1001 perturbations, 22 reporters, 74 experiments. -# Percentile normalization (50-99), bag-of-channels, gene+reporter positive pairs. 
-# -# Launch: -# sbatch applications/dynaclr/configs/training/OPS-1000genes-lite.sh - -seed_everything: 42 - -trainer: - accelerator: gpu - strategy: ddp - devices: 4 - num_nodes: 1 - precision: bf16-mixed - max_epochs: 300 - limit_train_batches: 400 - limit_val_batches: 100 - log_every_n_steps: 5 - enable_checkpointing: true - enable_model_summary: false - inference_mode: true - use_distributed_sampler: false - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: loss/val - every_n_epochs: 1 - save_top_k: 5 - save_last: true - - class_path: viscy_utils.callbacks.OnlineEvalCallback - init_args: - every_n_epochs: 5 - label_key: perturbation - k: 20 - - class_path: viscy_utils.callbacks.SaveConfigToWandb - -model: - class_path: dynaclr.engine.ContrastiveModule - init_args: - encoder: - class_path: viscy_models.contrastive.ContrastiveEncoder - init_args: - backbone: convnext_tiny - in_channels: 1 - in_stack_depth: 1 - stem_kernel_size: [1, 4, 4] - stem_stride: [1, 4, 4] - embedding_dim: 768 - projection_dim: 256 - drop_path_rate: 0.0 - loss_function: - class_path: viscy_models.contrastive.loss.NTXentLoss - init_args: - temperature: 0.5 - auxiliary_heads: - gene: - class_path: viscy_models.components.heads.ClassificationHead - init_args: - head_name: gene - batch_key: gene_label - in_dims: 768 - hidden_dims: 256 - num_classes: 1001 - cosine_classifier: true - loss_weight: 0.5 - top_k: 5 - weight_schedule: cosine - weight_start: 0.0 - weight_warmup_epochs: 30 - lr: 0.0002 - log_batches_per_epoch: 8 - log_samples_per_batch: 1 - log_embeddings_every_n_epochs: 10 - example_input_array_shape: [1, 1, 1, 128, 128] - -data: - class_path: dynaclr.data.datamodule.MultiExperimentDataModule - init_args: - cell_index_path: /hpc/projects/organelle_phenotyping/datasets/ops/training_labels_1000genes_lite_v2_valid.parquet - z_window: 1 - 
yx_patch_size: [224, 224] - final_yx_patch_size: [128, 128] - channels_per_sample: 1 - positive_cell_source: lookup - positive_match_columns: [perturbation, marker] - stratify_by: marker - split_ratio: 0.8 - batch_size: 512 - num_workers: 4 - seed: 0 - shuffle_val: true - label_columns: - gene_label: perturbation - normalizations: - - class_path: viscy_transforms.BatchedScaleIntensityRangePercentilesd - init_args: - keys: [channel_0] - lower: 50 - upper: 99 - b_min: 0.0 - b_max: 1.0 - clip: true - augmentations: - - class_path: viscy_transforms.BatchedRandAffined - init_args: - keys: [channel_0] - prob: 0.8 - scale_range: [[1.0, 1.0], [0.9, 1.1], [0.9, 1.1]] - rotate_range: [3.14, 0.0, 0.0] - shear_range: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - - class_path: viscy_transforms.BatchedRandFlipd - init_args: - keys: [channel_0] - spatial_axes: [1, 2] - prob: 0.5 - - class_path: viscy_transforms.BatchedRandAdjustContrastd - init_args: - keys: [channel_0] - prob: 0.5 - gamma: [0.8, 1.2] - - class_path: viscy_transforms.BatchedRandScaleIntensityd - init_args: - keys: [channel_0] - prob: 0.5 - factors: 0.5 - - class_path: viscy_transforms.BatchedRandGaussianSmoothd - init_args: - keys: [channel_0] - prob: 0.5 - sigma_x: [0.2, 0.5] - sigma_y: [0.2, 0.5] - sigma_z: [0.0, 0.0] - - class_path: viscy_transforms.BatchedRandGaussianNoised - init_args: - keys: [channel_0] - prob: 0.5 - mean: 0.0 - std: 0.08 diff --git a/applications/dynaclr/configs/training/OPS-373genes.sh b/applications/dynaclr/configs/training/OPS-373genes.sh deleted file mode 100755 index 1a7134086..000000000 --- a/applications/dynaclr/configs/training/OPS-373genes.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -# OPS 373-gene DynaCLR with gene classifier head -# -# New run: -# sbatch applications/dynaclr/configs/training/OPS-373genes.sh -# -# Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR. 
- -#SBATCH --job-name=dynaclr_ops -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=4 -#SBATCH --gres=gpu:4 -#SBATCH --partition=gpu -#SBATCH --cpus-per-task=15 -#SBATCH --mem-per-cpu=8G -#SBATCH --time=0-22:00:00 - -# ── Run identity ────────────────────────────────────────────────────── -export PROJECT="dynaclr" -export RUN_NAME="OPS-373genes-GeneClassifier" -export CONFIGS="applications/dynaclr/configs/training/OPS-373genes.yml" - -# ── Resume (uncomment to continue from checkpoint) ──────────────────── -# export CKPT_PATH="" -# export WANDB_RUN_ID="" - -source "$(dirname "$0")/slurm/train.sh" diff --git a/applications/dynaclr/configs/training/OPS-373genes.yml b/applications/dynaclr/configs/training/OPS-373genes.yml deleted file mode 100644 index 875f17714..000000000 --- a/applications/dynaclr/configs/training/OPS-373genes.yml +++ /dev/null @@ -1,124 +0,0 @@ -# OPS 373-gene DynaCLR with gene classifier head -# ================================================= -# Fine-tune from pre-trained OPS checkpoint with cosine classifier. -# Gene+reporter positive pairs, stratified by marker (reporter). 
-# -# Launch: -# sbatch applications/dynaclr/configs/training/OPS-373genes.sh - -seed_everything: 42 - -trainer: - accelerator: gpu - strategy: ddp - devices: 4 - num_nodes: 1 - precision: bf16-mixed - max_epochs: 300 - limit_train_batches: 400 - limit_val_batches: 100 - log_every_n_steps: 5 - enable_checkpointing: true - enable_model_summary: false - inference_mode: true - use_distributed_sampler: false - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: loss/val - every_n_epochs: 1 - save_top_k: 5 - save_last: true - - class_path: viscy_utils.callbacks.SaveConfigToWandb - -model: - class_path: dynaclr.engine.ContrastiveModule - init_args: - encoder: - class_path: viscy_models.contrastive.ContrastiveEncoder - init_args: - backbone: convnext_tiny - in_channels: 1 - in_stack_depth: 1 - stem_kernel_size: [1, 4, 4] - stem_stride: [1, 4, 4] - embedding_dim: 768 - projection_dim: 256 - drop_path_rate: 0.0 - loss_function: - class_path: viscy_models.contrastive.loss.NTXentLoss - init_args: - temperature: 0.5 - ckpt_path: /hpc/projects/intracellular_dashboard/ops/models/logs/dynaclr/ops_bagofchannels_gene_n_reporter_grouped_reporter_256proj_373genes_convnext_tiny_temp0p5_512bs_lr1e-4_pretrained_self/version_0/checkpoints/last.ckpt - lr: 0.0001 - log_batches_per_epoch: 8 - log_samples_per_batch: 1 - log_embeddings_every_n_epochs: 10 - example_input_array_shape: [1, 1, 1, 128, 128] - -data: - class_path: dynaclr.data.datamodule.MultiExperimentDataModule - init_args: - cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/ops_373genes.parquet - z_window: 1 - yx_patch_size: [224, 224] - final_yx_patch_size: [128, 128] - channels_per_sample: 1 - positive_cell_source: lookup - positive_match_columns: [perturbation, marker] - stratify_by: marker - split_ratio: 0.8 - batch_size: 512 - 
num_workers: 4 - seed: 0 - shuffle_val: true - label_columns: - gene_label: perturbation - normalizations: - - class_path: viscy_transforms.BatchedScaleIntensityRangePercentilesd - init_args: - keys: [channel_0] - lower: 50 - upper: 99 - b_min: 0.0 - b_max: 1.0 - clip: true - augmentations: - - class_path: viscy_transforms.BatchedRandAffined - init_args: - keys: [channel_0] - prob: 0.8 - scale_range: [[1.0, 1.0], [0.9, 1.1], [0.9, 1.1]] - rotate_range: [3.14, 0.0, 0.0] - shear_range: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - - class_path: viscy_transforms.BatchedRandFlipd - init_args: - keys: [channel_0] - spatial_axes: [1, 2] - prob: 0.5 - - class_path: viscy_transforms.BatchedRandAdjustContrastd - init_args: - keys: [channel_0] - prob: 0.5 - gamma: [0.8, 1.2] - - class_path: viscy_transforms.BatchedRandScaleIntensityd - init_args: - keys: [channel_0] - prob: 0.5 - factors: 0.5 - - class_path: viscy_transforms.BatchedRandGaussianSmoothd - init_args: - keys: [channel_0] - prob: 0.5 - sigma_x: [0.2, 0.5] - sigma_y: [0.2, 0.5] - sigma_z: [0.0, 0.0] - - class_path: viscy_transforms.BatchedRandGaussianNoised - init_args: - keys: [channel_0] - prob: 0.5 - mean: 0.0 - std: 0.08 diff --git a/applications/dynaclr/configs/training/OPS/OPS-1000genes-allmarkers.sh b/applications/dynaclr/configs/training/OPS/OPS-1000genes-allmarkers.sh new file mode 100644 index 000000000..03f97a327 --- /dev/null +++ b/applications/dynaclr/configs/training/OPS/OPS-1000genes-allmarkers.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# OPS 1000-gene × ALL-markers DynaCLR — single-marker SupCon batches with +# sqrt-weighted marker sampling, warm-started from OPS-1000genes-lite epoch 192. +# +# New run: +# sbatch applications/dynaclr/configs/training/OPS/OPS-1000genes-allmarkers.sh +# +# Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR. 
+ +#SBATCH --job-name=dynaclr_ops_allmk +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 +#SBATCH --partition=gpu +#SBATCH --constraint="h100|h200" +# gpu-h-5 has pathological NFS read performance on /hpc/projects/ — +# FOV-split Arrow-take takes ~26 min vs ~15s on gpu-h-2/gpu-f-4. +# Exclude it until the underlying storage issue is fixed or we move the +# dataset to faster storage. +#SBATCH --exclude=gpu-h-5 +#SBATCH --cpus-per-task=15 +# 14 GB/CPU × 60 CPUs = 840 GB/node. Needed because the 81M-row OPS +# cell_index × 4 DDP ranks × dataloader worker fork-copies blows past the +# original 480 GB budget. Pandas reference-counting defeats CoW so workers +# end up duplicating the full cached DataFrame. +#SBATCH --mem-per-cpu=14G +#SBATCH --time=3-00:00:00 + +# ── Run identity ────────────────────────────────────────────────────── +# Warm-started from prior OPS run (t89f7q4n/last.ckpt, Apr 20) at epoch 0. +# Picks up the FlexibleBatchSampler reshuffle fix (commit f4f40c38) plus +# the profiling-pass defaults (file_io_concurrency=128, ts.Batch overlap, +# cuDNN benchmark, TF32). num_workers stays at 1 for OPS due to per-rank +# memory pressure on the 81M-row cell_index. Optimizer state and epoch +# counter reset. 
+export PROJECT="OPS" +export RUN_NAME="OPS-1000genes-allmarkers-fix-shuffler" +WARMSTART_CKPT="/hpc/projects/organelle_phenotyping/models/OPS/OPS-1000genes-allmarkers/OPS-1000genes-allmarkers/t89f7q4n/checkpoints/last.ckpt" +export EXTRA_ARGS="--trainer.logger.init_args.project=OPS-1000genes-allmarkers --model.init_args.ckpt_path=${WARMSTART_CKPT}" +export CONFIGS="applications/dynaclr/configs/training/OPS/OPS-1000genes-allmarkers.yml" + +# ── Resume (Lightning full state, NOT what we want here) ────────────── +# export CKPT_PATH="" +# export WANDB_RUN_ID="" + +WORKSPACE_DIR="${WORKSPACE_DIR:-/hpc/mydata/eduardo.hirata/repos/viscy}" +source "${WORKSPACE_DIR}/applications/dynaclr/configs/training/slurm/train.sh" diff --git a/applications/dynaclr/configs/training/Phase-contrastive-timeaware.sh b/applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.sh similarity index 93% rename from applications/dynaclr/configs/training/Phase-contrastive-timeaware.sh rename to applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.sh index 6637b6634..96dbf1a99 100755 --- a/applications/dynaclr/configs/training/Phase-contrastive-timeaware.sh +++ b/applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.sh @@ -2,7 +2,7 @@ # Phase contrastive timeaware — DINOv3 frozen backbone + temporal MLP # # New run: -# sbatch applications/dynaclr/configs/training/Phase-contrastive-timeaware.sh +# sbatch applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.sh # # Resume: edit CKPT_PATH and WANDB_RUN_ID below, then sbatch from RUN_DIR. 
@@ -18,7 +18,7 @@ # ── Run identity ────────────────────────────────────────────────────── export PROJECT="Phase-contrastive-timeaware" export RUN_NAME="dinov3-mlp-temp0p5" -export CONFIGS="applications/dynaclr/configs/training/Phase-contrastive-timeaware.yml" +export CONFIGS="applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.yml" # ── Resume (uncomment to continue from checkpoint) ──────────────────── # export CKPT_PATH="" diff --git a/applications/dynaclr/configs/training/Phase-contrastive-timeaware.yml b/applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.yml similarity index 75% rename from applications/dynaclr/configs/training/Phase-contrastive-timeaware.yml rename to applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.yml index 5f50eed02..d0007b902 100644 --- a/applications/dynaclr/configs/training/Phase-contrastive-timeaware.yml +++ b/applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.yml @@ -5,22 +5,17 @@ # Reproduces legacy Phase contrastive timeaware ablations. 
# # Launch: -# sbatch applications/dynaclr/configs/training/Phase-contrastive-timeaware.sh +# sbatch applications/dynaclr/configs/training/Phase-contrastive/Phase-contrastive-timeaware.sh -seed_everything: 42 +base: + - ../recipes/trainer.yml + - ../recipes/model/dinov3_frozen_mlp.yml trainer: - accelerator: gpu strategy: auto devices: 1 - num_nodes: 1 precision: 32-true max_epochs: 150 - log_every_n_steps: 10 - enable_checkpointing: true - enable_model_summary: false - inference_mode: true - use_distributed_sampler: false callbacks: - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: @@ -36,39 +31,15 @@ trainer: every_n_epochs: 5 label_key: perturbation k: 20 - - class_path: viscy_utils.callbacks.SaveConfigToWandb model: - class_path: dynaclr.engine.ContrastiveModule init_args: - encoder: - class_path: viscy_models.foundation.DINOv3Model - init_args: - model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m - freeze: true - projection: - class_path: viscy_models.components.heads.MLP - init_args: - in_dims: 768 - hidden_dims: 768 - out_dims: 128 - norm: ln - activation: relu - loss_function: - class_path: viscy_models.contrastive.loss.NTXentLoss - init_args: - temperature: 0.5 - lr: 0.0001 - log_batches_per_epoch: 3 - log_samples_per_batch: 3 - log_embeddings_every_n_epochs: 10 pca_color_keys: "[perturbation,hours_post_perturbation]" example_input_array_shape: [1, 1, 30, 192, 192] data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: - collection_path: applications/dynaclr/configs/collections/Phase-contrastive-timeaware.yml cell_index_path: applications/dynaclr/configs/cell_index/Phase-contrastive-timeaware.parquet z_window: 30 z_extraction_window: 40 diff --git a/applications/dynaclr/configs/training/README.md b/applications/dynaclr/configs/training/README.md index f9f933da4..599d1fd45 100644 --- a/applications/dynaclr/configs/training/README.md +++ b/applications/dynaclr/configs/training/README.md @@ -1,96 +1,110 
@@ # DynaCLR Training Configs -Composable training configuration using LightningCLI `--config` stacking. -Each layer is a YAML fragment; later configs deep-merge into earlier ones -(dicts merge, lists replace). +Training configuration stack for LightningCLI `--config`. Later configs +deep-merge into earlier ones (dicts merge, lists replace). Each leaf +YAML declares a `base:` list of recipes to compose on top of. -## Structure +## Directory layout ``` configs/training/ - _base.yml Trainer + model defaults (callbacks, optimizer, encoder) - arch/ Encoder geometry (stem, z_depth, patch size) - 2d_z1.yml stem=[1,4,4], z_window=1 - 3d_z16.yml stem=[4,4,4], z_window=16, random Z crop - 3d_z30.yml stem=[5,4,4], z_window=30, 192px patch - data/ Data pipeline: sampling + normalization + augmentations - boc_{dim}_{positive_pair}_{batch_composition}.yml - demo/ Self-contained configs for smoke tests (single --config) - slurm/ SLURM experiment scripts (sbatch entry points) - train.sh Shared launcher (sourced, not sbatch'd directly) - _legacy/ Old monolithic configs (reference only) + DynaCLR-2D/ # 2D (and MIP) time-lapse contrastive runs + DynaCLR-2D-BagOfChannels-v3.{yml,sh} + DynaCLR-2D-MIP-BagOfChannels.{yml,sh} + DynaCLR-2D-MIP-BagOfChannels-single-marker.{yml,sh} + DynaCLR-2D-MIP-BagOfChannels-single-marker-A40.{yml,sh} + DynaCLR-3D/ # 3D time-lapse contrastive runs + DynaCLR-3D-BagOfChannels-v2.{yml,sh} + DynaCLR-3D-BagOfChannels-v2-single-marker.{yml,sh} + DINOv3/ # DINOv3 frozen-encoder + MLP probes + DINOv3-temporal-MLP-2D-BagOfChannels.{yml,sh} + Phase-contrastive/ + Phase-contrastive-timeaware.{yml,sh} + + recipes/ # Reusable building blocks (referenced via base:) + trainer.yml Trainer + logger + common callbacks + model/ Encoder and head architectures + data/ Sampling / positive-pair strategies + augmentations/ Augmentation pipelines (ops_2d_mild, etc.) 
+ + debug/ # Fast-dev-run / tiny configs for reproducing hangs / OOMs + demo/ # Self-contained single-file demos for smoke tests + slurm/ + train.sh Shared launcher sourced by every sbatch script + preprocess.yml Preprocessing config (not a training run) ``` -## Data config naming convention +Each top-level model family lives in its own folder. The `yml` and `sh` +for a given run share a name and a directory so `CONFIGS=` references +stay local. -``` -{channel_mode}_{dim}_{positive_pair_strategy}_{batch_composition}.yml -``` - -| Segment | Values | Meaning | -|---------|--------|---------| -| channel_mode | `boc` | bag-of-channels (1 random channel per sample) | -| dim | `2d`, `3d` | spatial dimensionality | -| positive_pair | `temporal` | same cell lineage at t+tau | -| | `gene-reporter` | same gene + same reporter (OPS) | -| | `self` | SimCLR-style (same crop, different augmentation) | -| batch_composition | `stratify-perturbation` | balance infected/uninfected | -| | `stratify-perturbation-marker` | balance perturbation and organelle marker | -| | `stratify-marker` | balance by reporter/marker only | - -## Composition +## Composition via `base:` -Stack three configs: `_base.yml` + `arch/*.yml` + `data/*.yml`, then -pass experiment-specific values as CLI overrides in the SLURM script. 
+Each leaf YAML starts with a `base:` list pointing at recipe fragments +(paths are relative to the YAML's directory; since all leaf YAMLs live +one level below `recipes/`, they use `../recipes/...`): -```bash -viscy fit \ - --config _base.yml \ - --config arch/3d_z16.yml \ - --config data/boc_3d_temporal_stratify-perturbation.yml \ - --trainer.devices 4 \ - --data.init_args.batch_size 512 \ - --data.init_args.collection_path path/to/collection.yml +```yaml +# DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml +base: + - ../recipes/trainer.yml + - ../recipes/model/contrastive_encoder_convnext_tiny.yml ``` +`viscy_utils.compose.load_composed_config` walks the `base:` chain, +deep-merges dicts, and replaces lists. + ## SLURM scripts -Each experiment is a thin `.sh` that sets `PROJECT`, `RUN_NAME`, `CONFIGS`, -experiment-specific `EXTRA_ARGS`, and sources `train.sh`: +Each experiment is a thin `.sh` that sets `PROJECT`, `RUN_NAME`, +`CONFIGS`, optional `EXTRA_ARGS`, and sources `slurm/train.sh`: ```bash -# Submit -sbatch slurm/DynaCLR-3D-BagOfChannels-v2.sh +sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh -# Override run name -RUN_NAME=phase2-hcl sbatch slurm/DynaCLR-3D-BagOfChannels-v2.sh +RUN_NAME=phase2-hcl sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh -# Parameter sweep for TEMP in 0.1 0.2 0.5; do RUN_NAME="sweep-temp${TEMP}" \ EXTRA_ARGS="--model.init_args.loss_function.init_args.temperature ${TEMP}" \ - sbatch slurm/DynaCLR-3D-BagOfChannels-v2.sh + sbatch applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh done ``` `train.sh` handles: -- `PYTHONNOUSERSITE=1` (prevents `~/.local/` shadowing conda) -- Creates `${MODEL_ROOT}/${PROJECT}/${RUN_NAME}/` output directory -- Copies config files into the run directory for reproducibility -- Sets WandB logger project/name/save_dir via CLI overrides -- Sets checkpoint dirpath via CLI override +- `export PYTHONNOUSERSITE=1` 
(prevents `~/.local/` shadowing conda) +- Creates `${MODEL_ROOT}/${PROJECT}/${RUN_NAME}/` output dir +- Rotates `config.yaml` from any previous run +- Copies the calling sbatch script into the run dir for reproducibility +- Sets WandB logger project / name / save_dir via CLI overrides +- Optional `CKPT_PATH` resume and `WANDB_RUN_ID` to continue a run -## Adding a new experiment +## Resuming a run -1. Check if an existing `data/*.yml` matches your sampling strategy. - If not, create a new one following the naming convention. -2. Create a new `slurm/.sh` with SBATCH directives and overrides. -3. Submit with `sbatch slurm/.sh`. +```bash +CKPT_PATH=/hpc/projects/.../checkpoints/last.ckpt \ +WANDB_RUN_ID=<run-id> \ + sbatch --export=ALL,CKPT_PATH,WANDB_RUN_ID \ + applications/dynaclr/configs/training/DynaCLR-3D/DynaCLR-3D-BagOfChannels-v2.sh +``` -## Demo configs +`WANDB_RUN_ID` appends `--trainer.logger.init_args.id=<run-id> +--trainer.logger.init_args.resume=must` so metrics land on the same +W&B timeline. -Self-contained single-file configs for quick testing: +## Adding a new experiment -```bash -viscy fit --config demo/demo_3d_fit.yml --trainer.fast_dev_run true -``` +1. Find the closest existing run in the matching model family + folder. Copy the `.yml` and `.sh` alongside it with a new name. +2. Edit `base:` in the YAML to pick the right recipes. +3. Override training-specific values in the YAML (or via `EXTRA_ARGS` + in the sbatch script for one-off sweeps). +4. `sbatch applications/dynaclr/configs/training/<family>/<name>.sh`. + +## Debug / demo configs + +- `debug/` — fastdev, tiny, and DDP-reproducer configs used to isolate + SLURM hangs, memory spikes, and DDP sync issues. Launched with + `uv run viscy fit --config <run>.yml --config debug/<override>.yml`. +- `demo/` — self-contained single-file configs for quick local smoke + tests (no base chain). 
diff --git a/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.sh b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.sh new file mode 100755 index 000000000..e085be446 --- /dev/null +++ b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Fast-dev-run smoke test of BoC training on 4-GPU DDP. +# Goal: validate sampler generator + FOV split + NCCL init + first +# batch end-to-end with the 20k-row boc_tiny parquet. + +#SBATCH --job-name=boc_fastdev_ddp +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=8 +#SBATCH --mem-per-cpu=4G +#SBATCH --time=0-00:30:00 +#SBATCH --output=/hpc/mydata/eduardo.hirata/repos/viscy/tmp/boc_fastdev_ddp_%j.out + +export PYTHONNOUSERSITE=1 +export NCCL_DEBUG=WARN + +cd "$(dirname "$0")/../../../../.." + +srun uv run --project . viscy fit \ + --config applications/dynaclr/configs/training/DynaCLR-2D/DynaCLR-2D-MIP-BagOfChannels.yml \ + --config applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.yml diff --git a/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.yml b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.yml new file mode 100644 index 000000000..59e42559b --- /dev/null +++ b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.yml @@ -0,0 +1,29 @@ +# SLURM fast-dev-run override for DynaCLR-2D-MIP-BagOfChannels. +# Tests DDP end-to-end on a 20k-row slice: 4 ranks × sampler __iter__ + +# NCCL init + first batch + backward + val. 
+# +# Launch: +# sbatch applications/dynaclr/configs/training/DynaCLR-2D-MIP-BagOfChannels-fastdev-ddp.sh + +trainer: + strategy: ddp + devices: 4 + num_nodes: 1 + fast_dev_run: 5 + logger: null + callbacks: [] + max_epochs: 1 + +data: + init_args: + # 20k-row slice, enough to exercise sampler/dataset without load cost. + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/boc_tiny.parquet + batch_size: 8 + num_workers: 1 + prefetch_factor: 2 + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev.yml b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev.yml new file mode 100644 index 000000000..ad9aa883a --- /dev/null +++ b/applications/dynaclr/configs/training/debug/DynaCLR-2D-MIP-BagOfChannels-fastdev.yml @@ -0,0 +1,32 @@ +# Local fast-dev-run override for DynaCLR-2D-MIP-BagOfChannels. +# Goal: verify training starts and completes ≥1 train + val batch end-to-end +# on a single GPU. Uses the smallest DynaCLR parquet (3.4M rows). +# +# Run: +# uv run viscy fit \ +# --config applications/dynaclr/configs/training/DynaCLR-2D-MIP-BagOfChannels.yml \ +# --config applications/dynaclr/configs/training/DynaCLR-2D-MIP-BagOfChannels-fastdev.yml + +trainer: + strategy: auto + devices: 1 + num_nodes: 1 + fast_dev_run: 5 + logger: null + callbacks: [] + max_epochs: 1 + +data: + init_args: + # ~20k-row slice of the full BoC parquet — enough to exercise every + # sampling path without a 3-minute parquet load. 
+ cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/boc_tiny.parquet + batch_size: 8 + num_workers: 0 + prefetch_factor: null + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.sh b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.sh new file mode 100755 index 000000000..812d12420 --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Minimal OPS fast_dev_run on 4-GPU DDP to localize the post-LOCAL_RANK hang. +# Strips callbacks, logger, wandb. + +#SBATCH --job-name=ops_fastdev_ddp +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --gres=gpu:4 +#SBATCH --partition=gpu +#SBATCH --constraint="h100|h200" +#SBATCH --exclude=gpu-h-5 +#SBATCH --cpus-per-task=15 +#SBATCH --mem-per-cpu=14G +#SBATCH --time=0-01:00:00 +#SBATCH --output=/hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_fastdev_ddp_%j.out + +export PYTHONNOUSERSITE=1 +export NCCL_DEBUG=WARN + +cd "$(dirname "$0")/../../../../.." + +srun uv run --project . viscy fit \ + --config applications/dynaclr/configs/training/OPS/OPS-1000genes-allmarkers.yml \ + --config applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.yml diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.yml new file mode 100644 index 000000000..e2689417c --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.yml @@ -0,0 +1,31 @@ +# SLURM fast_dev_run=5 override for OPS to localize the post-LOCAL_RANK hang. +# Strips all callbacks, logger, wandb — just data + model + one train + val batch. 
+# +# Launch: +# sbatch applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev-ddp.sh + +trainer: + strategy: ddp + devices: 4 + num_nodes: 1 + # Narrowing: wandb logger (31264775) + OnlineEvalCallback (31264776) + # both confirmed harmless. Now testing val_check_interval and limit_* + # knobs. Dropping fast_dev_run so these actually take effect. + callbacks: [] + max_epochs: 1 + limit_train_batches: 10 + limit_val_batches: 5 + val_check_interval: 0.5 + num_sanity_val_steps: 0 + +data: + init_args: + batch_size: 16 + num_workers: 1 + prefetch_factor: 2 + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev.yml new file mode 100644 index 000000000..5d37feb38 --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev.yml @@ -0,0 +1,30 @@ +# Local fast-dev-run override for OPS-1000genes-allmarkers. +# Goal: reproduce the OOM path from job 31264591 on a single A40 locally. +# +# Run: +# uv run viscy fit \ +# --config applications/dynaclr/configs/training/OPS/OPS-1000genes-allmarkers.yml \ +# --config applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-fastdev.yml + +trainer: + strategy: auto + devices: 1 + num_nodes: 1 + fast_dev_run: true + logger: null + callbacks: [] + # fast_dev_run already caps batches/epochs; the explicit limits are defensive. + limit_train_batches: 1 + limit_val_batches: 1 + max_epochs: 1 + +data: + init_args: + batch_size: 8 + num_workers: 0 + prefetch_factor: null + # Skip warm-start checkpoint to keep this self-contained. 
+ +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp-local.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp-local.yml new file mode 100644 index 000000000..002ff6e6d --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp-local.yml @@ -0,0 +1,33 @@ +# Local single-GPU "DDP" reproducer: strategy=ddp, devices=1. +# Exercises the DDP wrap path without needing 2 GPUs. Keeps wandb ENABLED +# (inherited from the parent OPS config) so we can test whether DDP-wrap +# + wandb is the hang, without SLURM. +# +# Run: +# uv run viscy fit \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers.yml \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers-tiny-ddp-local.yml + +trainer: + strategy: ddp + devices: 1 + num_nodes: 1 + max_epochs: 1 + limit_train_batches: 10 + limit_val_batches: 5 + val_check_interval: 0.5 + num_sanity_val_steps: 0 + callbacks: [] + +data: + init_args: + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_tiny.parquet + batch_size: 8 + num_workers: 0 + prefetch_factor: null + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.sh b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.sh new file mode 100755 index 000000000..1a998d30f --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# 4-GPU DDP on OPS tiny (346k rows) WITHOUT fast_dev_run. +# Isolates DDP+wandb+val_check_interval from dataset-size effects. + +#SBATCH --job-name=ops_tiny_ddp +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 +#SBATCH --partition=gpu +# Drop GPU-type constraint to clear the queue faster. 
nodes=1 guarantees +# the two ranks share a single GPU model, which is what matters for DDP. +#SBATCH --exclude=gpu-h-5 +#SBATCH --cpus-per-task=8 +#SBATCH --mem-per-cpu=8G +#SBATCH --time=0-00:30:00 +#SBATCH --output=/hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_tiny_ddp_%j.out + +export PYTHONNOUSERSITE=1 +export NCCL_DEBUG=WARN + +cd "$(dirname "$0")/../../../../.." + +srun uv run --project . viscy fit \ + --config applications/dynaclr/configs/training/OPS/OPS-1000genes-allmarkers.yml \ + --config applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.yml diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.yml new file mode 100644 index 000000000..d6328fe24 --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.yml @@ -0,0 +1,33 @@ +# 2-GPU DDP test on OPS tiny (346k rows) WITHOUT fast_dev_run. +# Purpose: narrow the hang — does it need full OPS scale (55M), or does +# DDP+wandb+val_check_interval on any OPS-flavored data reproduce it? +# +# Launch: +# sbatch applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-ddp.sh + +trainer: + strategy: ddp + devices: 2 + num_nodes: 1 + max_epochs: 1 + limit_train_batches: 10 + limit_val_batches: 5 + val_check_interval: 0.5 + num_sanity_val_steps: 0 + callbacks: [] + # 31265169 TIMEOUT with wandb logger on. Now disabling it to isolate + # whether wandb + DDP + no-fastdev is the bug. 
+ logger: null + +data: + init_args: + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_tiny.parquet + batch_size: 8 + num_workers: 1 + prefetch_factor: 2 + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-full.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-full.yml new file mode 100644 index 000000000..73411a4ef --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny-full.yml @@ -0,0 +1,34 @@ +# Local reproducer for the post-LOCAL_RANK hang on OPS tiny (346k rows). +# Same as OPS-1000genes-allmarkers-tiny.yml but DROPS fast_dev_run — because +# every passing run used fast_dev_run and every hanging run did not. +# Single GPU to rule out DDP as the variable. +# +# Run: +# uv run viscy fit \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers.yml \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers-tiny-full.yml + +trainer: + strategy: auto + devices: 1 + num_nodes: 1 + max_epochs: 1 + limit_train_batches: 10 + limit_val_batches: 5 + val_check_interval: 0.5 + num_sanity_val_steps: 0 + logger: null + callbacks: [] + +data: + init_args: + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_tiny.parquet + batch_size: 8 + num_workers: 0 + prefetch_factor: null + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny.yml b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny.yml new file mode 100644 index 000000000..4bf557a2b --- /dev/null +++ b/applications/dynaclr/configs/training/debug/OPS-1000genes-allmarkers-tiny.yml @@ -0,0 +1,32 @@ +# Local fast_dev_run override for OPS-1000genes-allmarkers on a tiny slice. 
+# Purpose: reproduce the full-config hang on a single GPU locally, where +# iteration is ~60 sec/test instead of ~10 min/SLURM-cycle. +# +# Run: +# uv run viscy fit \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers.yml \ +# --config applications/dynaclr/configs/training/OPS-1000genes-allmarkers-tiny.yml + +trainer: + strategy: auto + devices: 1 + num_nodes: 1 + fast_dev_run: 5 + logger: null + callbacks: [] + max_epochs: 1 + +data: + init_args: + # 346k-row slice: 2 experiments × 5 markers × 20 genes + # Preserves [gene_name, marker] SupCon pairing and batch_group_by=marker. + cell_index_path: /hpc/mydata/eduardo.hirata/repos/viscy/tmp/ops_tiny.parquet + batch_size: 8 + num_workers: 0 + prefetch_factor: null + buffer_size: 2 + cache_pool_bytes: 0 + +model: + init_args: + ckpt_path: null diff --git a/applications/dynaclr/configs/training/demo/demo_2d_fit.yml b/applications/dynaclr/configs/training/demo/demo_2d_fit.yml index b35f31e98..5d2febb41 100644 --- a/applications/dynaclr/configs/training/demo/demo_2d_fit.yml +++ b/applications/dynaclr/configs/training/demo/demo_2d_fit.yml @@ -58,11 +58,6 @@ data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: # ── Data source ────────────────────────────────────────────────────────── - # For production: use the full v3 collection + parquet - # collection_path: applications/dynaclr/configs/collections/DynaCLR-2D-BagOfChannels-v3.yml - # cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-BagOfChannels-v3.parquet - # For demo: single zarr, fast startup - collection_path: null cell_index_path: applications/dynaclr/configs/cell_index/example_flat.parquet # ── Patch extraction ───────────────────────────────────────────────────── diff --git a/applications/dynaclr/configs/training/demo/demo_3d_fit.yml b/applications/dynaclr/configs/training/demo/demo_3d_fit.yml index b5a0a0573..c809968f9 100644 --- 
a/applications/dynaclr/configs/training/demo/demo_3d_fit.yml +++ b/applications/dynaclr/configs/training/demo/demo_3d_fit.yml @@ -58,13 +58,6 @@ data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: # ── Data source ────────────────────────────────────────────────────────── - # Provide one of collection_path or cell_index_path. - # cell_index_path is faster (skips zarr enumeration at startup). - # For production: use the full v2 collection + parquet - # collection_path: applications/dynaclr/configs/collections/DynaCLR-3D-BagOfChannels-v2.yml - # cell_index_path: /hpc/projects/organelle_phenotyping/models/collections/DynaCLR-3D-BagOfChannels-v2.parquet - # For demo: single zarr, fast startup - collection_path: null cell_index_path: applications/dynaclr/configs/cell_index/example_flat.parquet # ── Patch extraction ───────────────────────────────────────────────────── diff --git a/applications/dynaclr/configs/training/demo/demo_bag_of_channels_v3_fit.yml b/applications/dynaclr/configs/training/demo/demo_bag_of_channels_v3_fit.yml index 8f29ecad4..96eb260b0 100644 --- a/applications/dynaclr/configs/training/demo/demo_bag_of_channels_v3_fit.yml +++ b/applications/dynaclr/configs/training/demo/demo_bag_of_channels_v3_fit.yml @@ -53,8 +53,7 @@ model: data: class_path: dynaclr.data.datamodule.MultiExperimentDataModule init_args: - collection_path: applications/dynaclr/configs/collections/demo_bag_of_channels_v3.yml - cell_index_path: null + cell_index_path: applications/dynaclr/configs/cell_index/demo_bag_of_channels_v3.parquet z_window: 30 yx_patch_size: [288, 288] final_yx_patch_size: [192, 192] diff --git a/applications/dynaclr/configs/training/recipes/augmentations/ops_2d_mild.yml b/applications/dynaclr/configs/training/recipes/augmentations/ops_2d_mild.yml new file mode 100644 index 000000000..763d1cc68 --- /dev/null +++ b/applications/dynaclr/configs/training/recipes/augmentations/ops_2d_mild.yml @@ -0,0 +1,41 @@ +# Augmentation recipe: mild 
2D augmentations for OPS data. +# Lighter affine (no Z scaling, no shear), narrower gamma, lower noise std. + +data: + init_args: + augmentations: + - class_path: viscy_transforms.BatchedRandAffined + init_args: + keys: [channel_0] + prob: 0.8 + scale_range: [[1.0, 1.0], [0.9, 1.1], [0.9, 1.1]] + rotate_range: [3.14, 0.0, 0.0] + shear_range: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + - class_path: viscy_transforms.BatchedRandFlipd + init_args: + keys: [channel_0] + spatial_axes: [1, 2] + prob: 0.5 + - class_path: viscy_transforms.BatchedRandAdjustContrastd + init_args: + keys: [channel_0] + prob: 0.5 + gamma: [0.8, 1.2] + - class_path: viscy_transforms.BatchedRandScaleIntensityd + init_args: + keys: [channel_0] + prob: 0.5 + factors: 0.5 + - class_path: viscy_transforms.BatchedRandGaussianSmoothd + init_args: + keys: [channel_0] + prob: 0.5 + sigma_x: [0.2, 0.5] + sigma_y: [0.2, 0.5] + sigma_z: [0.0, 0.0] + - class_path: viscy_transforms.BatchedRandGaussianNoised + init_args: + keys: [channel_0] + prob: 0.5 + mean: 0.0 + std: 0.08 diff --git a/applications/dynaclr/configs/training/recipes/data/ops_gene_reporter.yml b/applications/dynaclr/configs/training/recipes/data/ops_gene_reporter.yml new file mode 100644 index 000000000..f93941c8f --- /dev/null +++ b/applications/dynaclr/configs/training/recipes/data/ops_gene_reporter.yml @@ -0,0 +1,20 @@ +# Data recipe: OPS gene+reporter contrastive learning defaults. +# Leaf configs override: cell_index_path, normalizations (lower percentile differs). 
+ +data: + class_path: dynaclr.data.datamodule.MultiExperimentDataModule + init_args: + z_window: 1 + yx_patch_size: [224, 224] + final_yx_patch_size: [128, 128] + channels_per_sample: 1 + positive_cell_source: lookup + positive_match_columns: [perturbation, marker] + stratify_by: marker + split_ratio: 0.8 + batch_size: 512 + num_workers: 4 + seed: 0 + shuffle_val: true + label_columns: + gene_label: perturbation diff --git a/applications/dynaclr/configs/training/recipes/model/contrastive_encoder_convnext_tiny.yml b/applications/dynaclr/configs/training/recipes/model/contrastive_encoder_convnext_tiny.yml new file mode 100644 index 000000000..3b70366e8 --- /dev/null +++ b/applications/dynaclr/configs/training/recipes/model/contrastive_encoder_convnext_tiny.yml @@ -0,0 +1,18 @@ +# Model recipe: ContrastiveModule with ConvNeXt-Tiny encoder. +# Leaf configs override: in_stack_depth, stem_kernel_size, stem_stride, +# projection_dim, drop_path_rate, temperature, lr, and logging args. + +model: + class_path: dynaclr.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy_models.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 1 + embedding_dim: 768 + loss_function: + class_path: viscy_models.contrastive.loss.NTXentLoss + log_batches_per_epoch: 3 + log_samples_per_batch: 3 + log_embeddings_every_n_epochs: 10 diff --git a/applications/dynaclr/configs/training/recipes/model/dinov3_frozen_mlp.yml b/applications/dynaclr/configs/training/recipes/model/dinov3_frozen_mlp.yml new file mode 100644 index 000000000..1e2e71699 --- /dev/null +++ b/applications/dynaclr/configs/training/recipes/model/dinov3_frozen_mlp.yml @@ -0,0 +1,27 @@ +# Model recipe: Frozen DINOv3-ConvNeXt-Tiny backbone + trainable MLP projection. +# Leaf configs override: pca_color_keys, example_input_array_shape. 
+ +model: + class_path: dynaclr.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy_models.foundation.DINOv3Model + init_args: + model_name: facebook/dinov3-convnext-tiny-pretrain-lvd1689m + freeze: true + projection: + class_path: viscy_models.components.heads.MLP + init_args: + in_dims: 768 + hidden_dims: 768 + out_dims: 128 + norm: ln + activation: relu + loss_function: + class_path: viscy_models.contrastive.loss.NTXentLoss + init_args: + temperature: 0.5 + lr: 0.0001 + log_batches_per_epoch: 3 + log_samples_per_batch: 3 + log_embeddings_every_n_epochs: 10 diff --git a/applications/dynaclr/configs/training/recipes/trainer.yml b/applications/dynaclr/configs/training/recipes/trainer.yml new file mode 100644 index 000000000..92a0bcdbf --- /dev/null +++ b/applications/dynaclr/configs/training/recipes/trainer.yml @@ -0,0 +1,35 @@ +# Trainer recipe: DynaCLR shared trainer defaults. +# Includes WandB logger (project/name/save_dir set by train.sh CLI overrides), +# LR monitor, and model checkpoint. Config is saved to trainer.log_dir by +# Lightning's default SaveConfigCallback; the wandb files tab picks it up +# automatically when save_dir matches. +# +# Leaf configs override: strategy, devices, precision, max_epochs, +# logger.init_args.project/name, and optionally re-list callbacks +# to add OnlineEvalCallback (callbacks is a list — it replaces entirely). 
+ +seed_everything: 42 + +trainer: + accelerator: gpu + num_nodes: 1 + log_every_n_steps: 10 + enable_checkpointing: true + enable_model_summary: false + inference_mode: true + use_distributed_sampler: false + benchmark: true + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + entity: computational_imaging + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: loss/val + every_n_epochs: 1 + save_top_k: 5 + save_last: true diff --git a/applications/dynaclr/configs/training/slurm/train.sh b/applications/dynaclr/configs/training/slurm/train.sh index 807ce10dd..386435f07 100755 --- a/applications/dynaclr/configs/training/slurm/train.sh +++ b/applications/dynaclr/configs/training/slurm/train.sh @@ -28,6 +28,9 @@ RUN_DIR="${MODEL_ROOT}/${PROJECT}/${RUN_NAME}" export PYTHONNOUSERSITE=1 export NCCL_DEBUG=INFO export PYTHONFAULTHANDLER=1 +# bf16-mixed already lets matmuls use TF32; "high" instructs Lightning to +# enable TF32 for any remaining float32 matmuls (silences the runtime warning). +export TORCH_FLOAT32_MATMUL_PRECISION=high function cleanup() { rm -rf /tmp/$SLURM_JOB_ID/*.zarr diff --git a/applications/dynaclr/docs/DAGs/ai_ready_datasets.md b/applications/dynaclr/docs/DAGs/ai_ready_datasets.md new file mode 100644 index 000000000..8e000769f --- /dev/null +++ b/applications/dynaclr/docs/DAGs/ai_ready_datasets.md @@ -0,0 +1,163 @@ +# Data Preparation DAG + +## Entry point + +`prepare run -c prepare_config.yaml` (from `airtable_utils`) discovers wells and +channels from NFS, generates all configs and SLURM scripts, and submits the pipeline. 
+ +```bash +prepare run 2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV \ + -c /path/to/prepare_config.yaml + +# Dry-run: generate configs/scripts without submitting +prepare run 2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV \ + -c /path/to/prepare_config.yaml \ + --dry-run +``` + +## Step-by-step detail + +``` +NFS assembled zarr (intracellular_dashboard/organelle_dynamics/{dataset}/2-assemble/) + │ + ▼ +prepare run # discovers wells + channels from NFS zarr + │ airtable_utils.prepare_cli # validates dataset is in Airtable + │ airtable_utils.prepare # generates all configs and scripts + ▼ +{vast_output_dir}/ + ├── crop_concat.yml # biahub concatenate config (wells × channels) + ├── qc_config.yml # focus-slice QC config + ├── sbatch_overrides.sh # optional SLURM overrides for biahub's internal jobs + ├── 01_concatenate.sh # bash (not SLURM): runs biahub + rsync tracking + ├── 02_qc.sh # SLURM: GPU focus-slice detection + └── 03_preprocess.sh # SLURM: CPU normalization stats + │ + ▼ +bash 01_concatenate.sh # NOT a SLURM job — runs interactively + │ Step 1: conda run biahub concatenate -c crop_concat.yml -o {dataset}.zarr -m + │ biahub submits its own SLURM jobs internally via submitit; -m blocks until done + │ Step 2: rsync tracking zarr (NFS → VAST) + ▼ +{dataset}.zarr (OME-Zarr v0.5 / zarr v3, rechunked) +tracking.zarr (cell tracking results) + │ + ├──► sbatch 02_qc.sh # GPU (~30 min) + │ qc run -c qc_config.yml # focus-slice detection on Phase3D channel + │ → writes focus_slice metadata into {dataset}.zarr + │ + └──► sbatch 03_preprocess.sh # CPU, preempted partition (~4 hrs) + viscy preprocess # computes per-channel normalization stats + --data_path {dataset}.zarr + → writes normalization metadata into {dataset}.zarr +``` + +## Pipeline DAG (process dependency) + +``` +NFS zarr (assembled) + │ + ▼ +prepare run ──── generates configs + scripts + │ + ▼ +01_concatenate.sh (interactive bash, blocks until biahub SLURM jobs finish) + │ + ▼ +{dataset}.zarr + tracking.zarr + │ + ├──► 
02_qc.sh (SLURM, GPU) → focus_slice metadata in zarr + └──► 03_preprocess.sh (SLURM, CPU) → normalization metadata in zarr +``` + +02_qc and 03_preprocess run in parallel (no dependency between them). +Both write metadata back to the same zarr; their outputs are checked by +`check_preprocessed()` before downstream training or evaluation. + +## Key commands + + +| Step | Command | Input | Output | +| ----------------- | ------------------------------------------------- | ------------------ | --------------------------------------------------------------- | +| Generate + submit | `prepare run -c prepare_config.yaml` | NFS assembled zarr | scripts + configs, submits jobs | +| Status check | `prepare status -c prepare_config.yaml` | - | markdown table (NFS/VAST existence, zarr version, preprocessed) | +| Concatenate | `bash 01_concatenate.sh` | crop_concat.yml | {dataset}.zarr + tracking.zarr | +| QC | `sbatch 02_qc.sh` | qc_config.yml | focus_slice metadata in zarr | +| Preprocess | `sbatch 03_preprocess.sh` | {dataset}.zarr | normalization metadata in zarr | + + +## prepare_config.yaml format + +```yaml +nfs_root: /hpc/projects/intracellular_dashboard/organelle_dynamics +vast_root: /hpc/projects/organelle_phenotyping/datasets +workspace_dir: /hpc/mydata/eduardo.hirata/repos/viscy + +concatenate: + channel_names: null # null = auto-detect raw channels (Phase3D + "raw " prefix) + chunks_czyx: [1, 16, 256, 256] + shards_ratio: [1, 1, 8, 8, 8] + output_ome_zarr_version: "0.5" + conda_env: biahub + sbatch_overrides: # optional: overrides for biahub's internal SLURM jobs + partition: preempted + mem-per-cpu: 16G + +qc: + channel_names: [Phase3D] + NA_det: 1.35 + lambda_ill: 0.450 + pixel_size: 0.1494 + midband_fractions: [0.125, 0.25] + device: cuda + num_workers: 16 + +preprocess: + channel_names: -1 # -1 = all channels + num_workers: 32 + block_size: 32 + +slurm: + qc: + partition: gpu + gres: gpu:1 + cpus_per_task: 16 + mem_per_cpu: 4G + time: "00:30:00" + preprocess: + 
partition: preempted + cpus_per_task: 32 + mem_per_cpu: 4G + time: "04:00:00" +``` + +## Notes + +- `prepare run` validates the dataset exists in Airtable before generating anything. +Use `--force` to overwrite an existing VAST zarr (e.g. to upgrade from zarr v2 to v0.5). +- `01_concatenate.sh` is an interactive bash script, not a SLURM job. Run it from a login +node or an interactive session; it blocks until biahub's internal SLURM jobs finish (`-m` flag). +- `02_qc.sh` and `03_preprocess.sh` are independent — submit both immediately after +`01_concatenate.sh` completes; no need to wait for QC before running preprocess. +- Channel auto-detection (`channel_names: null`) keeps channels with prefix `Phase3D` or `raw` . +Virtual stains (`nuclei_prediction`, `membrane_prediction`) and deconvolved channels are excluded. +- `check_preprocessed()` checks for `normalization` key in zarr metadata; used by `prepare status` +and as a gate before evaluation. +- Raw channel names written to `crop_concat.yml` are repeated once per well entry — this is a +biahub concatenate requirement. + +## Path convention + +All AI-ready data lives under `/hpc/projects/organelle_phenotyping/`: + + +| Directory | Contents | +| -------------------------- | --------------------------------------------- | +| `datasets//` | Zarr v3 store + `tracking.zarr` | +| `datasets/annotations/` | Per-experiment annotation CSVs | +| `models/collections/` | Cell index parquets (one per collection YAML) | +| `models//` | Training runs (checkpoints, WandB configs) | + + +Collection YAMLs use `datasets_root: /hpc/projects/organelle_phenotyping` and +`${datasets_root}/datasets/...` placeholders — resolved at load time by `load_collection()`. 
diff --git a/applications/dynaclr/docs/DAGs/evaluation.md b/applications/dynaclr/docs/DAGs/evaluation.md new file mode 100644 index 000000000..5322c9202 --- /dev/null +++ b/applications/dynaclr/docs/DAGs/evaluation.md @@ -0,0 +1,682 @@ +# Evaluation DAG + +This document describes the **per-run** evaluation pipeline (one model on +one dataset). For the cross-model, cross-dataset matrix layout — including +the central linear-classifier registry that lets Wave-2 datasets fetch LC +pipelines trained on Wave-1 (infectomics-annotated) — see the companion +[`evaluation_matrix.md`](evaluation_matrix.md). + +## Running with Nextflow (recommended) + +```bash +module load nextflow/24.10.5 + +nextflow run applications/dynaclr/nextflow/main.nf -entry evaluation \ + --eval_config applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels/infectomics-annotated.yaml \ + --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ + -resume +``` + +`-resume` makes Nextflow skip steps whose outputs already exist. Re-run the same command after a failure — Nextflow picks up from where it left off. + +### Local test (no SLURM) + +```bash +nextflow run applications/dynaclr/nextflow/main.nf \ + --eval_config applications/dynaclr/configs/evaluation/DynaCLR-2D-MIP-BagOfChannels_test.yaml \ + --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ + -profile local \ + -resume +``` + +## Pipeline entry point + +`dynaclr prepare-eval-configs` (also aliased as `dynaclr evaluate`) generates all YAML configs +under `output_dir/configs/` and prints a JSON manifest to stdout. Nextflow reads the manifest +to wire steps together. 
+ +``` +eval_config.yaml + │ + ▼ +dynaclr prepare-eval-configs -c eval_config.yaml # writes configs/ + manifest JSON + │ + ▼ +output_dir/configs/ + ├── eval.yaml # copy of input config (for re-runs) + ├── predict.yml # GPU step: viscy predict + ├── reduce.yaml # template: dynaclr reduce-dimensionality (per-experiment) + ├── reduce_combined.yaml # CPU step: dynaclr combined-dim-reduction (joint) + ├── smoothness.yaml # template: dynaclr evaluate-smoothness (per-experiment) + ├── plot.yaml # template: dynaclr plot-embeddings (per-experiment) — only when "plot" in steps + ├── plot_combined.yaml # CPU step: dynaclr plot-embeddings (all experiments) — only when "plot_combined" in steps + ├── {block_name}.yaml # template: dynaclr compute-mmd (per-experiment, per-block) + ├── {block_name}_cross_exp.yaml # CPU step: dynaclr compute-mmd --combined (per-block) + └── linear_classifiers.yaml # CPU step (optional) +``` + +## Step-by-step detail + +``` +checkpoint.ckpt + cell_index.parquet + │ + ▼ +viscy predict -c predict.yml # MultiExperimentDataModule predict mode + │ EmbeddingWriter callback # normalizations + z_reduction, no augmentations + ▼ # obs: fov_name, id, t, track_id, +embeddings/embeddings.zarr # experiment, marker, perturbation, + │ (AnnData: .X=features, # hours_post_perturbation, organelle, well, microscope + │ .obs=cell metadata) + │ + ▼ +dynaclr split-embeddings \ + --input embeddings/embeddings.zarr \ + --output-dir embeddings/ + │ Splits by obs["experiment"], deletes combined zarr + │ Also writes configs/viewer.yaml (datasets: {exp: {hcs_plate, anndata}}) + │ hcs_plate read from obs["store_path"] of each split zarr + ▼ +embeddings/{experiment_A}.zarr +embeddings/{experiment_B}.zarr + ... +configs/viewer.yaml # nd-embedding viewer config (also valid input + ... 
# for combined-dim-reduction via datasets: key) + │ + ├──► dynaclr reduce-dimensionality # PCA only (per experiment, parallel SLURM jobs) + │ -c reduce.yaml # __ZARR_PATH__ substituted by Nextflow + │ → {experiment}.zarr (obsm: X_pca) + │ NOTE: skip PHATE here to avoid computing it twice + │ + │ (after reduce-dimensionality finishes for ALL experiments) + │ + ├──► dynaclr combined-dim-reduction # joint PCA + PHATE across all experiments + │ -c reduce_combined.yaml # fits on concatenated embeddings + │ → {experiment}.zarr (obsm: X_pca_combined, X_phate_combined) + │ + │ (after combined-dim-reduction finishes) + │ + ├──► dynaclr plot-embeddings # per-experiment PCA scatter (X_pca) — when "plot" in steps + │ -c plot.yaml # parallel SLURM jobs, one per experiment + │ → plots/{experiment}/*.pdf + │ + ├──► dynaclr plot-embeddings # all-experiments combined (X_pca_combined, X_phate_combined) + │ -c plot_combined.yaml # only when "plot_combined" in steps; one job + │ → plots/combined/*.pdf + │ + ├──► dynaclr evaluate-smoothness # temporal smoothness + dynamic range + │ -c smoothness.yaml # parallel SLURM jobs, one per experiment + │ → smoothness/{model}_per_marker_smoothness.csv # one row per marker + │ → smoothness/{model}_smoothness_stats.csv # mean ± std across markers + │ → smoothness/*.pdf # per-marker + per-model plots + │ + ├──► dynaclr compute-mmd # one SLURM job per (experiment, block) + │ -c {block_name}.yaml # __ZARR_PATH__ substituted by Nextflow + │ → mmd/{block_name}/mmd_results.csv + │ → mmd/{block_name}/kinetics.pdf + │ → mmd/{block_name}/activity_heatmap.pdf + │ + ├──► dynaclr compute-mmd --combined # pairwise cross-experiment batch effect detection + │ -c {block_name}_cross_exp.yaml # only generated when combined_mode: true + │ # For each marker shared by a pair of experiments, runs MMD per + │ # (condition, time_bin) after per-pair mean centering. + │ # Conditions are auto-discovered from data intersection. 
+ │ → mmd/{block_name}_cross_exp/combined_mmd_results.csv + │ → mmd/{block_name}_cross_exp/kinetics.pdf + │ → mmd/{block_name}_cross_exp/activity_heatmap.pdf + │ + ├──► dynaclr run-linear-classifiers # logistic regression probe + │ -c linear_classifiers.yaml # reads per-experiment zarrs directory + annotation CSVs + │ # joins annotations on (fov_name, t, track_id); trains one LogisticRegression + │ # per (task, marker); marker_filters omitted → auto-discovers all markers + │ # writes trained pipelines to linear_classifiers/pipelines/ (in-run staging) + │ # if publish_dir is set: atomically promotes the bundle to the central + │ # LC registry as {publish_dir}/vN/ and updates the `latest` symlink. + │ → linear_classifiers/metrics_summary.csv + │ → linear_classifiers/{task}_summary.pdf + │ → linear_classifiers/pipelines/{task}_{marker}.joblib + │ → linear_classifiers/pipelines/manifest.json + │ → {publish_dir}/vN/{task}_{marker}.joblib (when publish_dir set) + │ → {publish_dir}/vN/manifest.json (when publish_dir set) + │ → {publish_dir}/latest -> vN (atomic symlink swap) + │ + ├──► dynaclr append-annotations # persist ground truth labels to per-experiment zarrs + │ -c append_annotations.yaml # reads annotation CSVs + writes task columns to zarr obs + │ # only experiments with AnnotationSource entries are processed; others skipped + │ → {experiment}.zarr (obs: infection_state, organelle_state, ...) + │ + └──► dynaclr append-predictions # apply saved classifiers + -c append_predictions.yaml # predicts on ALL cells per marker, not just annotated ones + # pipelines_dir may be either: + # (a) in-run: {output_dir}/linear_classifiers/pipelines/ (default), or + # (b) external: a `latest` symlink into the central LC registry + # (e.g., /hpc/.../linear_classifiers/{model_name}/latest) + # The symlink is resolved once at startup so the run is consistent + # even if a new bundle is published mid-run. 
Logs feature_space (= + # registry/{model_name}) and version (= vN) for traceability. + → {experiment}.zarr (obs: predicted_infection_state, ...) + → {experiment}.zarr (obsm: predicted_infection_state_proba, ...) + → {experiment}.zarr (uns: predicted_infection_state_classes, + predicted_infection_state_lc_version, + predicted_infection_state_lc_feature_space, + predicted_infection_state_lc_path, ...) + +checkpoint.ckpt (independent of predict/split — runs in parallel) + │ + ▼ +viscy export -c export_onnx.yml # export backbone to ONNX + │ + ▼ +model.onnx + CTC datasets ({seq}_ERR_SEG/, {seq}/, {seq}_GT/TRA/) + │ + ▼ +dynaclr evaluate-tracking-accuracy \ # ILP tracking on CTC benchmarks + -c tracking_accuracy.yaml # loops over (model, dataset, sequence) + │ builds tracksdata graph from segmentation masks + │ runs ONNX inference on cell crops → dynaclr_similarity edge cost + │ solves ILP; compares to GT via evaluate_ctc_metrics + │ set show_napari: true for interactive inspection + ▼ +tracking_accuracy/results.csv # one row per (model, dataset, sequence) +tracking_accuracy/ # grouped mean summary printed to stdout +``` + +After all enrichment steps complete, per-experiment zarrs contain: + +- `.obs`: embeddings metadata + annotations (`infection_state`, etc.) + predictions (`predicted_infection_state`, etc.) +- `.obsm`: `X_pca`, `X_pca_combined`, `X_phate_combined`, `predicted_{task}_proba` +- `.uns`: `predicted_{task}_classes`, `predicted_{task}_lc_version`, `predicted_{task}_lc_feature_space`, `predicted_{task}_lc_path` + +This enables plots colored by experiment, perturbation, annotation, and prediction from a single zarr. The `_lc_*` uns fields record exactly which LC bundle produced each predicted column (registry path, version tag, feature_space). + +## Central LC registry + +Linear-classifier pipelines can be **published** to a central per-model +registry instead of (or in addition to) the per-run `output_dir`. 
This lets +later evaluations on different datasets reuse the same trained classifiers +without retraining. + +### Layout + +``` +/hpc/projects/organelle_phenotyping/models/linear_classifiers/ +├── DynaCLR-2D-MIP-BagOfChannels/ +│ ├── latest -> v3 # symlink (relative target) +│ ├── v1/ {manifest.json, *.joblib} +│ ├── v2/ +│ └── v3/ +├── DynaCLR-2D-BagOfChannels-v3/ { same } +├── DynaCLR-classical/ { same } +├── DINOv3-temporal-MLP-2D-BagOfChannels-v1/ { same } +└── DINOv3-frozen/ { same } +``` + +The directory name (e.g. `DynaCLR-2D-MIP-BagOfChannels`) is the +**feature_space** identifier — pipelines from one model's registry are +*not* applicable to a different model's embeddings (different dim, different +distribution). The model name follows the training-config-stem convention +(see `evaluation_matrix.md` §7). + +### Publishing (writer) + +A Wave-1 leaf (training run) sets `linear_classifiers.publish_dir`: + +```yaml +linear_classifiers: + publish_dir: /hpc/projects/organelle_phenotyping/models/linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/ + # ... annotations, tasks, ... +``` + +`run-linear-classifiers` writes pipelines to a temp staging directory, +atomically renames to `vN/` (next available version), then atomically +swaps the `latest` symlink. Crash-safe: a partial bundle never appears as +`vN/`. + +### Fetching (reader) + +A Wave-2 leaf (evaluation on a different dataset) sets +`append_predictions.pipelines_dir`: + +```yaml +append_predictions: + pipelines_dir: /hpc/projects/organelle_phenotyping/models/linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/latest +``` + +`append-predictions` resolves the symlink **once** at startup and uses the +resolved `vN/` for the rest of the run, so a publish during the run does +not affect output. The resolved path's parent name (`DynaCLR-2D-MIP-BagOfChannels`) +becomes `feature_space` in the manifest log. 
+ +### Manifest format + +```json +{ + "trained_at": "2026-04-24T15:33:21+00:00", + "pipelines": [ + {"task": "infection_state", "marker_filter": "G3BP1", "path": "infection_state_G3BP1.joblib"}, + {"task": "infection_state", "marker_filter": "SEC61B", "path": "infection_state_SEC61B.joblib"} + ] +} +``` + +Lineage (model name + version) lives in the directory structure, not the +manifest. Reproducibility comes from pinning a specific `vN` (instead of +`latest`) in paper-rerun scripts. + +### Pinning vs. latest + +```yaml +# active development — picks up the latest published bundle +pipelines_dir: /hpc/.../linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/latest + +# paper rerun — frozen at submission time +pipelines_dir: /hpc/.../linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/v2 +``` + +## Nextflow DAG (process dependency graph) + +``` +checkpoint.ckpt ──────────────────────────────────────────────────────────────┐ + │ │ + ▼ ▼ +PREPARE_CONFIGS EXPORT_ONNX (optional) + │ │ + ▼ ▼ +PREDICT (GPU) model.onnx + CTC datasets + │ │ + ▼ ▼ +SPLIT (CPU light) TRACKING_ACCURACY (CPU) + │ → results.csv + ├─[scatter]─► REDUCE ─[gather]─► REDUCE_COMBINED ─┐ + │ │ + ├─► APPEND_ANNOTATIONS ───────────────────────────►├─[scatter]─► PLOT (only if "plot" in steps) + │ │ [gather] ─► PLOT_COMBINED (only if "plot_combined" in steps) + ├─► LINEAR_CLASSIFIERS ─► APPEND_PREDICTIONS ─────►┘ + │ + ├─[scatter]─► SMOOTHNESS ─[gather]─► SMOOTHNESS_GATHER + ├─[scatter per (exp,block)]─► MMD ─[gather]─► MMD_PLOT_HEATMAP + └─[gather per block]─► MMD_COMBINED +``` + +Key: **scatter** = one SLURM job per experiment (parallel). **gather** = waits for all scatter jobs. + +`TRACKING_ACCURACY` is independent of the embedding pipeline — it reads directly from an ONNX +model and CTC-format data. Run it manually or as a separate Nextflow job alongside the main DAG. + +`PLOT` (per-experiment fan-out) and `PLOT_COMBINED` (single combined figure) are +**independently togglable** via `steps:`. 
List `plot` for per-experiment scatter only, +`plot_combined` for the joint figure only, both for both, or neither for a metrics-only +run. `APPEND_ANNOTATIONS` and `APPEND_PREDICTIONS` emit a `'skip'` signal when not +present in `steps`, so plotting always proceeds once `REDUCE_COMBINED` finishes — +whichever plotting steps are listed. + +## CTC Tracking Accuracy Benchmark + +Standalone benchmark that evaluates whether DynaCLR embeddings improve cell tracking +accuracy on [Cell Tracking Challenge](https://celltrackingchallenge.net/) datasets. +**Not part of the Nextflow embedding pipeline** — run independently after exporting an ONNX model. + +### Approach + +``` +CTC segmentation masks + raw images + │ + ▼ +tracksdata graph (RegionPropsNodes + DistanceEdges) + │ + ├── baseline: IoU edge weights (no model) + │ + └── DynaCLR: ONNX inference on cell crops + → dynaclr_similarity × spatial_dist_weight as ILP edge cost + │ + ▼ +ILPSolver → tracked graph + │ + ▼ +evaluate_ctc_metrics vs. ground truth + │ + ▼ +results.csv (model × dataset × sequence × CTC metrics) +``` + +### Usage + +```bash +dynaclr evaluate-tracking-accuracy -c tracking_accuracy_config.yaml +``` + +### Config format + +```yaml +models: + - path: /hpc/projects/.../model_ckpt146.onnx + label: DynaCLR-classical + - path: /hpc/projects/.../model_ckpt185.onnx + label: DynaCLR-timeaware + - path: null # baseline: IoU + spatial distance only + label: baseline-iou + +datasets: + - path: /hpc/reference/group.royer/CTC/training/BF-C2DL-HSC + sequences: ["01", "02"] + - path: /hpc/reference/group.royer/CTC/training/Fluo-C2DL-Huh7 + sequences: ["01", "02"] + +crop_shape: [64, 64] # must match the model's training resolution +distance_threshold: 325.0 # spatial candidate edge threshold (pixels) +n_neighbors: 10 +delta_t: 5 # max frame gap for candidate edges +batch_size: 128 +output_dir: /path/to/tracking_accuracy_results +``` + +### Output + +**`results.csv`** — one row per (model, dataset, sequence): + +| Column 
| Description | +|--------|-------------| +| `model` | Model label | +| `dataset` | CTC dataset name | +| `sequence` | Sequence number (01, 02) | +| `LNK` | CTC Linking metric | +| `TRA` | Tracking metric | +| `DET` | Detection metric | +| `CHOTA` | Cell-specific HOTA | +| `HOTA` | Higher Order Tracking Accuracy | +| `MOTA` | Multiple Object Tracking Accuracy | +| `IDF1` | ID F1 score | +| `BIO(0)` | Biological metric | +| `OP_CLB(0)` | Combined linking+bio score | + +Prints a grouped summary (mean over sequences) at the end. + +### Prerequisites + +1. Export the model to ONNX: + ```bash + viscy export -c export_onnx.yml + ``` +2. CTC datasets must have `{seq}_ERR_SEG/`, `{seq}/`, and `{seq}_GT/TRA/` subdirectories. +3. Install eval dependencies: `uv sync --all-packages --extra eval` + +## Pseudotime alignment benchmark + +Standalone benchmark that quantifies how well DTW on DynaCLR embeddings recovers per-cell biological +event onsets (e.g. infection onset from the NS3 sensor channel). **Not part of the Nextflow embedding +pipeline** — runs after `linear_classifiers` + `append_predictions` so it can use either human +`infection_state` or model-predicted `predicted_infection_state` as ground truth. + +Full pipeline (template build + DTW alignment) lives in `applications/dynaclr/scripts/pseudotime/` — +see [`pseudotime.md`](pseudotime.md). The scoring step described below consumes the Stage 2a +alignment parquet that pipeline produces. 
+
+### Approach
+
+```
+embeddings/{experiment}.zarr (with ground-truth obs)
+    │
+    ▼
+pseudotime/2-align_cells/alignments/
+    {template}_{flavor}_on_{query_set}.parquet
+    │   (per-frame: pseudotime, estimated_t_rel_minutes, alignment_region)
+    ▼
+score_alignment.py --method {dtw | no_align}
+    │
+    │   Per cell: onset_error_minutes = estimated_t_rel_minutes at the first
+    │   aligned positive frame preceded by a negative frame
+    │   Population: AUROC + F1@0 over (cell, frame) pairs in aligned region
+    ▼
+scoring/{template}_{flavor}_on_{query_set}_{method}_per_cell.parquet
+scoring/{template}_{flavor}_on_{query_set}_{method}_summary.md
+scoring/results.csv (one row per run, accumulates across runs)
+    │
+    ▼
+compare_methods.py
+    │
+    ▼
+scoring/compare_methods.md (paper table)
+scoring/compare_methods.png (4-panel bar chart: med|Δt|, IQR, AUROC, F1@0)
+```
+
+### Method comparison philosophy
+
+The benchmark compares **alignment methods on the same DynaCLR embedding**, not different
+embeddings. Two methods are bundled:
+
+- **`dtw`** — uses `estimated_t_rel_minutes` from the Stage 2a parquet (DBA template + subsequence
+  DTW on the embedding trajectory).
+- **`no_align`** — substitutes `estimated_t_rel_minutes` with each cell's frame index relative to
+  its track midpoint, with no learning. This is the lower bound that DTW must beat.
+
+This is the right comparison for the paper claim *"DTW on DynaCLR embeddings recovers infection
+onset"* — the embedding is held fixed, so the metric attributes the gain to the alignment step,
+not to representation quality.
+ +### Usage + +```bash +cd applications/dynaclr/scripts/pseudotime/2-align_cells +# Score one alignment parquet (DTW) +uv run python score_alignment.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_sensor --flavor raw \ + --query-set sensor_all_07_24 \ + --truth-column infection_state --truth-positive infected \ + --method dtw + +# No-align baseline on the same parquet +uv run python score_alignment.py ... --method no_align + +# Render comparison artifacts +uv run python compare_methods.py +``` + +### Output schema (`scoring/results.csv`) + +| Column | Description | +|---|---| +| `template`, `flavor`, `query_set` | identifies the alignment parquet | +| `method` | `dtw` or `no_align` | +| `truth_column`, `truth_positive` | which obs column was the ground truth | +| `n_cells_scored` | cells with a usable negative→positive transition in the aligned region | +| `median_abs_onset_error_minutes` | robust center of \|Δt_onset\|; **primary metric** | +| `iqr_abs_onset_error_minutes` | spread of \|Δt_onset\| | +| `median_signed_onset_error_minutes` | systematic bias (≈0 if unbiased) | +| `auroc` | over aligned (cell, frame) pairs ranked by warped time; chance = 0.5 | +| `f1_at_zero` | F1 at the threshold `estimated_t_rel_minutes ≥ 0` | +| `n_pairs` | aligned (cell, frame) pairs used by AUROC/F1 | +| `timestamp_utc` | run timestamp | + +### Prerequisites + +1. A built template under `1-build_template/templates/` (see `pseudotime.md`). +2. A Stage 2a alignment parquet under `2-align_cells/alignments/`. +3. The query embedding zarrs must carry the requested `--truth-column` (human `infection_state` + on 07_22/07_24; `predicted_infection_state` on 08_26/01_28 — populated by + `APPEND_PREDICTIONS`). 
+ +## Cross-model comparison + +After running evals for multiple models, compare results with: + +```bash +python applications/dynaclr/scripts/evaluation/compare_evals.py -c eval_registry.yml +``` + +Registry format: + +```yaml +models: + - name: DynaCLR-v3 + eval_dir: /path/to/eval_v3 # = output_dir of that model's eval run + - name: DINOv3-MLP + eval_dir: /path/to/eval_dino +output_dir: /path/to/comparison_output +fdr_threshold: 0.05 +``` + +`eval_dir` is the **`output_dir`** declared in each leaf eval YAML (where +`smoothness/`, `linear_classifiers/`, `mmd/` land). One row per model × dataset +pair you want to compare side-by-side. Auto-discovers results and produces +overlaid plots and summary CSVs for smoothness, linear classifiers, and MMD. + +The shipping registry at `applications/dynaclr/configs/evaluation/eval_registry.yaml` +is checked into git and updated as new (model × dataset) eval runs complete. + +## Key commands + +| Step | Command | Input | Output | +|------|---------|-------|--------| +| Config gen | `dynaclr prepare-eval-configs -c eval.yaml` | eval config | configs/ + manifest JSON | +| Predict | `viscy predict -c predict.yml` | checkpoint + parquet | embeddings/embeddings.zarr | +| Split | `dynaclr split-embeddings --input ... 
--output-dir ...` | combined zarr | per-experiment zarrs + `configs/viewer.yaml` | +| Dim reduction | `dynaclr reduce-dimensionality -c reduce.yaml` | {experiment}.zarr | zarr with X_pca | +| Combined reduction | `dynaclr combined-dim-reduction -c reduce_combined.yaml` | all {experiment}.zarr | zarrs with X_pca_combined/X_phate_combined | +| Plots (per-exp) | `dynaclr plot-embeddings -c plot.yaml` | {experiment}.zarr | plots/{experiment}/*.pdf | +| Plots (combined) | `dynaclr plot-embeddings -c plot_combined.yaml` | all {experiment}.zarr | plots/combined/*.pdf | +| Smoothness | `dynaclr evaluate-smoothness -c smoothness.yaml` | {experiment}.zarr | per_marker_smoothness.csv, smoothness_stats.csv | +| MMD (per-exp) | `dynaclr compute-mmd -c {block}.yaml` | {experiment}.zarr | mmd/{block}/mmd_results.csv | +| MMD (combined) | `dynaclr compute-mmd --combined -c {block}_cross_exp.yaml` | all {experiment}.zarr | mmd/{block}_cross_exp/combined_mmd_results.csv | +| MMD (pooled) | `dynaclr compute-mmd --pooled -c pooled.yaml` | all {experiment}.zarr | mmd_results.csv | +| Linear probe | `dynaclr run-linear-classifiers -c clf.yaml` | per-experiment zarrs + annotations | metrics_summary.csv, {task}_summary.pdf, pipelines/ | +| Append annotations | `dynaclr append-annotations -c append_annotations.yaml` | per-experiment zarrs + annotation CSVs | zarrs with obs annotation columns | +| Append predictions | `dynaclr append-predictions -c append_predictions.yaml` | per-experiment zarrs + pipelines/ | zarrs with predicted_{task} in obs/obsm/uns | +| Compare models | `python compare_evals.py -c eval_registry.yml` | multiple eval dirs | comparison CSVs + plots | +| CTC tracking | `dynaclr evaluate-tracking-accuracy -c tracking_accuracy.yaml` | ONNX model + CTC datasets | tracking_accuracy/results.csv | + +## Placeholder pattern + +Template YAMLs (`reduce.yaml`, `smoothness.yaml`, `{block}.yaml`, `plot.yaml`) contain `__ZARR_PATH__` +as a placeholder for `input_path`. 
`plot.yaml` also contains `__PLOT_DIR__`. Nextflow process +scripts substitute these inline with Python one-liners before calling the CLI command: + +```python +import yaml +with open('reduce.yaml') as f: + cfg = yaml.safe_load(f) +cfg['input_path'] = '/path/to/experiment.zarr' +with open('reduce_patched.yaml', 'w') as f: + yaml.dump(cfg, f, default_flow_style=False, sort_keys=False) +``` + +For `reduce_combined.yaml`, `plot_combined.yaml`, and `{block}_cross_exp.yaml`, Nextflow collects +all zarr paths and writes the `input_paths` list directly. + +## Notes + +- `MultiExperimentDataModule` supports `stage="predict"` since the eval orchestrator was added. + It uses the full cell index (no train/val split), applies only normalizations + z-reduction (no augmentations). +- `BatchedChannelWiseZReductiond` is architecturally required for 2D models even at inference time + (converts 3D z-stack → 2D MIP/center-slice). The orchestrator moves it from `augmentations` + to `normalizations` in the generated predict config. +- Dimensionality reductions (PCA, PHATE) are **not** computed inline during predict. + They run as separate CPU steps after splitting, keeping predict fast. +- The `combined-dim-reduction` step fits reductions on all experiments jointly and writes + `X_pca_combined` / `X_phate_combined` back to each per-experiment zarr. +- PHATE is not computed per-experiment by default (`reduce_dimensionality.phate: null`). Run it only jointly via `reduce_combined`. +- `configs/viewer.yaml` is generated after split and can be passed directly to `dynaclr combined-dim-reduction`. +- MMD reads `.X` (raw backbone embeddings) by default. It can also run on `X_pca` or `X_pca_combined` via `embedding_key`. +- Embeddings obs carries `organelle`, `well`, and `microscope` in addition to `experiment`, `marker`, `perturbation`, `hours_post_perturbation`. 
+ +## MMD config format + +Use `configs/evaluation/recipes/mmd_defaults.yml` as a base to avoid repeating MMD algorithm parameters: + +```yaml +# Per-experiment (template — __ZARR_PATH__ substituted at runtime) +base: recipes/mmd_defaults.yml +input_path: __ZARR_PATH__ +output_dir: /path/to/evaluation/mmd/perturbation/ +group_by: perturbation +comparisons: + - cond_a: uninfected + cond_b: ZIKV + label: "uninfected vs ZIKV" +embedding_key: null # null = raw .X; or "X_pca", "X_pca_combined" +temporal_bin_size: 4.0 # uniform bin width in hours (null = aggregate) +# temporal_bins: [0, 6, 12, 24] # alternative: explicit bin edges (mutually exclusive) +mmd: + balance_samples: true # subsample larger group to match smaller + share_bandwidth_from: "uninfected vs uninfected" # reuse bandwidth from baseline comparison +map_settings: + enabled: true # compute mAP via copairs alongside MMD + +# Cross-experiment ({block}_cross_exp.yaml — input_paths substituted at runtime) +# No comparisons — conditions auto-discovered from data intersection. 
+base: recipes/mmd_defaults.yml +input_paths: [__ZARR_PATH__] +output_dir: /path/to/evaluation/mmd/perturbation_cross_exp/ +group_by: perturbation +temporal_bin_size: 4.0 + +# Pooled (standalone CLI only — not generated by orchestrator) +base: recipes/mmd_defaults.yml +input_paths: + - /path/to/exp_A.zarr + - /path/to/exp_B.zarr +output_dir: /path/to/evaluation/mmd/pooled/ +comparisons: + - cond_a: uninfected + cond_b: ZIKV + label: "uninfected vs ZIKV" +condition_aliases: + uninfected: [uninfected, uninfected1, uninfected2] # map variants to canonical name +``` + +## MMD output columns + +### Per-experiment and pooled (`mmd_results.csv`) + +| Column | Description | +|--------|-------------| +| `experiment` | Experiment name (absent in pooled output) | +| `marker` | Organelle marker (e.g., "TOMM20", "G3BP1") | +| `cond_a` | Reference/control condition | +| `cond_b` | Treatment condition | +| `label` | Human-readable comparison label | +| `hours_bin_start` | Start of temporal bin (NaN if no binning) | +| `hours_bin_end` | End of temporal bin (NaN if no binning) | +| `n_a` | Cells from `cond_a` used after subsampling | +| `n_b` | Cells from `cond_b` used after subsampling | +| `mmd2` | Unbiased MMD² estimate | +| `p_value` | Permutation test p-value (Phipson & Smyth smoothed) | +| `q_value` | BH-corrected FDR (pooled mode only) | +| `bandwidth` | Gaussian RBF bandwidth | +| `effect_size` | mmd2 / bandwidth (scale-free) | +| `activity_zscore` | (mmd2 − null_mean) / null_std — normalized against permutation null | +| `map_value` | Mean Average Precision (NaN if map_settings.enabled=false) | +| `map_p_value` | mAP permutation p-value (NaN if map_settings.enabled=false) | +| `embedding_key` | Embedding used ("X" or obsm key) | + +### Cross-experiment (`combined_mmd_results.csv`) + +| Column | Description | +|--------|-------------| +| `marker` | Organelle marker | +| `exp_a` | First experiment in the pair | +| `exp_b` | Second experiment in the pair | +| `condition` | 
Condition value matched across experiments | +| `hours_bin_start` | Start of temporal bin (NaN if no binning) | +| `hours_bin_end` | End of temporal bin (NaN if no binning) | +| `n_a` | Cells from `exp_a` used | +| `n_b` | Cells from `exp_b` used | +| `mmd2` | Unbiased MMD² estimate | +| `p_value` | Permutation test p-value | +| `bandwidth` | Gaussian RBF bandwidth | +| `effect_size` | mmd2 / bandwidth | +| `activity_zscore` | (mmd2 − null_mean) / null_std | +| `embedding_key` | Embedding used | + +## Linear classifiers output columns + +| Column | Description | +|--------|-------------| +| `task` | Classification task (e.g., `infection_state`) | +| `marker_filter` | Marker used to filter cells (one row per marker per task) | +| `n_samples` | Total annotated cells used | +| `val_accuracy` | Validation accuracy | +| `val_weighted_f1` | Validation weighted F1 | +| `val_auroc` | Validation AUROC (OvR macro for multiclass) | +| `train_*` | Training set counterparts of the above | +| `val_{class}_f1` | Per-class F1 on validation set | diff --git a/applications/dynaclr/docs/DAGs/pseudotime.md b/applications/dynaclr/docs/DAGs/pseudotime.md new file mode 100644 index 000000000..298437487 --- /dev/null +++ b/applications/dynaclr/docs/DAGs/pseudotime.md @@ -0,0 +1,731 @@ +# Pseudotime DAG + +This document describes how the pseudotime pipeline runs. For why it +runs this way — methodology decisions, claims, falsification protocol — +see the source-of-truth discussion document at +`/home/eduardo.hirata/repos/DynaCLR/.planning/dynaclr_dtw_discussion.md`. + +## 1. Goals + +We measure when SEC61 (ER), G3BP1 (stress granules), and quantitative +phase morphology change relative to per-cell NS3 sensor translocation +in single A549 cells, then compare three alignment tracks to see which +sharpens the population timing readouts. 
+ +The pipeline produces three parallel sets of organelle-remodeling +plots, one per track, indexed on the same lineage-reconnected cohort +so the side-by-side comparison is direct. + +| Track | Anchor | Alignment | Outputs in | +|---|---|---|---| +| **A-anno** | Human `infection_state` first-positive frame | Per-cell shift, real-time | `3-organelle-remodeling/A-anno/` | +| **A-LC** | Linear classifier `predicted_infection_state` first-positive frame | Per-cell shift, real-time | `3-organelle-remodeling/A-LC/` | +| **B** | NS3 embedding band on transition window | Hybrid DTW: warp transition only, propagated to organelle and phase | `3-organelle-remodeling/B/` | + +Path B is the methodological contribution. Paths A-anno and A-LC are +baselines. + +### What `t_rel = 0` means + +`t_rel = 0` is when the NS3 protease sensor crosses our chosen anchor: +either the human `infection_state` first-positive frame (Path A-anno), +the LC threshold-crossing (Path A-LC), or the NS3 embedding's +half-rise band on the transition window (Path B). All three are +landmarks downstream of viral entry, polyprotein translation, and ER +remodeling — `t_rel = 0` is a fiducial clock, not the start of +infection. See discussion §1.1 and §2.1 for biological framing. + +## 2. Pipeline stages + +``` +0-select_candidates → 1-build_template → 2-align_cells → 3-organelle-remodeling → 4-compare_tracks + (lineage-reconnect, (Path B only — (Path A-anno, (per-track per-organelle (side-by-side + cohort tagging, NS3 transition A-LC, B readouts: SEC61 comparison, + manual + auto) template) alignments) cosine-distance, warp-vs-no-warp, + G3BP1 oscillation, bimodality test, + phase distance) robustness) +``` + +Stages 0 and 1 share data across tracks. Stages 2, 3, and 4 fork +per-track, then re-converge in Stage 4 for comparison. + +## 3. 
Directory layout + +``` +applications/dynaclr/ +├── configs/pseudotime/ +│ ├── datasets.yaml # datasets + embedding patterns +│ ├── candidates.yaml # candidate sets, cohort tags, lineage rules +│ ├── build_template.yaml # Stage 1 (Path B template build) +│ ├── align_cells.yaml # Stage 2 (per-track query sets) +│ └── compare_tracks.yaml # Stage 4 (cross-track headlines) +├── docs/DAGs/pseudotime.md # this file +├── src/dynaclr/pseudotime/ # library +│ ├── dtw_alignment.py # template fit + warp solver +│ ├── io.py # template-zarr layout + provenance +│ ├── alignment.py # lineage reconnection, daughter handling +│ ├── signals.py # extract annotation / LC / embedding signals +│ ├── metrics.py # onset, t_50, peak, oscillation stats +│ └── plotting.py # response curves, heatmaps, comparisons +└── scripts/pseudotime/ + ├── 0-select_candidates/ + │ ├── select_candidates.py # auto path (from annotations) + │ ├── manual_candidates.py # manual path (hand-picked tracks) + │ ├── reconnect_lineages.py # mother+daughter stitching, divides flag + │ ├── tag_cohorts.py # productive / bystander / abortive / mock + │ ├── inspect_candidates.py # per-track image montage QC + │ └── candidates/ # output: lineage-aware annotation CSVs + ├── 1-build_template/ # Path B only + │ ├── build_template.py # NS3 transition template via DBA + │ ├── evaluate_template.py # self-align sanity check + │ ├── templates/ # output: template_*.zarr + │ └── plots/ # output: build-set diagnostics + ├── 2-align_cells/ + │ ├── align_anno.py # Path A-anno: real-time shift on infection_state + │ ├── align_lc.py # Path A-LC: real-time shift on LC predictions + │ ├── align_embedding.py # Path B: hybrid DTW + warp propagation + │ ├── A-anno/alignments/ # per-track output parquets + │ ├── A-LC/alignments/ + │ └── B/alignments/ + ├── 3-organelle-remodeling/ + │ ├── readout_sec61.py # cosine-distance-from-baseline + │ ├── readout_g3bp1.py # oscillation excursion stats + │ ├── readout_phase.py # phase embedding distance + 
│ ├── A-anno/ # per-track per-organelle outputs + │ ├── A-LC/ + │ └── B/ + └── 4-compare_tracks/ + ├── compare_onsets.py # side-by-side SEC61, G3BP1, phase headlines + ├── compare_phase_to_fluor.py # claim (a') Spearman ρ + ├── warp_vs_no_warp.py # mandatory comparator (Path B only) + ├── bimodality_check.py # dip-test on every back-projected distribution + └── headline_figure.py # the figure-1 of the paper +``` + +Status of stage scripts as of this revision: + +- **Implemented and current:** `0-select_candidates/select_candidates.py`, + `manual_candidates.py`, `inspect_candidates.py`, + `1-build_template/build_template.py`, `evaluate_template.py`, + `2-align_cells/align_cells.py` (current name covers what becomes + `align_embedding.py` after split), per-stage plotting scripts. +- **Implemented but in worktree, not in DAG structure:** + `annotation_remodeling.py` and `prediction_remodeling.py` cover + Path A-anno and Path A-LC respectively; live in + `.claude/worktrees/cytoland-virtual-staining-examples/applications/dynaclr/scripts/pseudotime/`. + Need refresh against new directory structure and configs. +- **TODO: not yet implemented:** `reconnect_lineages.py`, + `tag_cohorts.py`, the per-track split in Stage 2, per-organelle + readouts as separate scripts, all of Stage 4. + +The current scripts still produce useful output but operate on a +single track (Path A-anno-equivalent) without lineage reconnection, +cohort tagging, or cross-track comparison. The phases below describe +the migration. + +## 4. Cohorts + +Stage 0 emits one annotation CSV per cohort. All four cohorts share the +schema. Each cell carries a `cohort` column. 
+ +| Cohort | Definition | Used by | +|---|---|---| +| `productive` | Lineage with NS3 rise within imaging window; manual `[t_before, t_after]` and `t_key_event` from `manual_candidates.py` | Primary cohort; all three tracks | +| `bystander` | Lineage in infected wells with no NS3 rise across imaging duration | Negative control for claims (a), (a'), (b) | +| `abortive` | Lineage with NS3 channel embedding present, no rise | Claim (b) bifurcation comparison (caveat: censored-data category, see discussion §3.2) | +| `mock` | Lineage from uninfected wells | Per-frame null distribution for organelle distance comparisons | + +Mock cells do not get a synthetic `t_zero`. Every mock cell × frame +contributes to an FOV-stratified per-frame null distribution. + +## 5. Stage 0 — Select candidates and reconnect lineages + +Stage 0 emits per-cohort annotation CSVs. Two entry points feed it: +`select_candidates.py` (auto, from existing annotations) and +`manual_candidates.py` (hand-picked tracks). + +The single output artifact is `{cohort}_annotations.csv`, one row per +`(dataset_id, fov_name, lineage_id, t)` over the per-cell crop window. + +### 5.1 Auto path + +```bash +cd applications/dynaclr/scripts/pseudotime/0-select_candidates +uv run python select_candidates.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/candidates.yaml \ + --candidate-set infection_transitioning_nondiv +``` + +Filters tracks per `config["candidate_sets"][NAME]["filter"]` (anchor +label, anchor_positive, min_pre/post_minutes, crop_window_minutes), +expands each track into per-frame rows, copies real annotation labels +onto each row, then runs lineage reconnection (§5.3) and cohort tagging +(§5.4) before writing. + +### 5.2 Manual path + +```bash +cd applications/dynaclr/scripts/pseudotime/0-select_candidates +python manual_candidates.py +``` + +Each track spec is `{t_before, t_after, labels: {...}}` in a Python +dict. 
Frames in `[t_before, t_after]` get the positive label if they +fall inside an interval. The CSV schema is the only contract with +downstream stages. + +### 5.3 Lineage reconnection + +**TODO: implement.** Stitches mother + daughter chains into single +lineages using `parent_track_id`. Daughter handling regime-dependent: + +- Division before `t_zero`: keep both daughters as paired observations; + siblings are biologically equivalent at infection. +- Division after `t_zero`: keep the daughter with more pre-`t_zero` + footage. Daughters at this stage carry different viral loads. +- Division during the transition window: tag as separate cohort outside + DTW alignment; mitotic ER fragmentation distorts templates. + +Each lineage record carries `divides ∈ {none, pre, during, post}` and +`lineage_id` (replaces `track_id` as the canonical unit downstream). + +### 5.4 Cohort tagging + +**TODO: implement.** For each lineage, derive `cohort` from the +NS3 channel signal: + +- `productive`: lineage has manual `t_key_event` and survives window + cropping in Stage 1. +- `bystander`: in infected well, no NS3 rise across imaging window. +- `abortive`: NS3 channel embedding present, no rise. +- `mock`: from uninfected wells. + +### 5.5 Inspect + +```bash +cd applications/dynaclr/scripts/pseudotime/0-select_candidates +uv run python inspect_candidates.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/candidates.yaml \ + --candidate-set infection_transitioning_nondiv +``` + +Renders a per-cell-anchored image montage with `t_key_event` marked. +Also writes a sidecar `_qc.csv` with per-track stats (n_frames, +pre_frames, post_frames, fov, divides) for non-visual QC. + +## 6. Stage 1 — Build NS3 transition template (Path B only) + +Stage 1 produces the canonical NS3 transition template against which +Path B aligns query cells. Paths A-anno and A-LC do not use a +template. 
+ +### 6.1 Template build + +```bash +cd applications/dynaclr/scripts/pseudotime/1-build_template +uv run python build_template.py \ + --config ../../../configs/pseudotime/build_template.yaml \ + --template infection_nondividing_zikv +``` + +The builder: + +1. Reads `productive` cohort annotations. +2. Crops each lineage to `[t_zero - h_pre, t_zero + h_post]` real-time. + Default: `h_pre = 240 min`, `h_post = 360 min` (`540 min` for G3BP1 + readout downstream). See discussion §3.6. +3. Pulls NS3 channel embeddings within the transition sub-window + `[t_zero - k_pre, t_zero + k_post]`. Default: `k_pre = 60 min`, + `k_post = 120 min`. Use the 10 min/frame cohort if available + (target frame count ≥ 12 for DBA stability). +4. Computes per-cell pre-baseline = mean NS3 embedding across pre-window + frames. Cosine distance against this per-cell baseline. +5. Runs DBA on the transition sub-window only. +6. Saves the template zarr. + +### 6.2 Template zarr contents + +| Path | Type | Description | +|---|---|---| +| `template` | (T, D) array | DBA template in the transition window | +| `time_calibration` | (T,) array | mean `t_relative_minutes` per template position | +| `template_labels/{col}` | (T,) array | per-position label fractions | +| `tau_event_band` | (2,) array | `[τ_lo, τ_hi]`: half-rise band of `||dT/dτ||`. The event identifier per discussion §3.4. | +| `lineage_ids` | list (attrs) | `[dataset_id, fov_name, lineage_id]` per build-set lineage | +| `aggregator` | str (attrs) | `"dba"` | +| `template_duration_minutes` | float (attrs) | `time_calibration[-1] - time_calibration[0]` | +| `build_frame_intervals_minutes` | dict (attrs) | per-dataset frame interval at build time | +| `viscy_git_sha`, `dtaidistance_version`, `scikit_learn_version`, `numpy_version` | str (attrs) | provenance | + +**TODO:** add `tau_event_band` to the zarr writer. Currently the +template stores derivative-argmax as a point. 
+ +### 6.3 Self-consistency check + +```bash +uv run python evaluate_template.py \ + --config ../../../configs/pseudotime/build_template.yaml \ + --template infection_nondividing_zikv +``` + +Re-aligns the build-set lineages onto the template they built. Sanity +check, not generalization evidence. + +## 7. Stage 2 — Align query cells per track + +Stage 2 forks per track. Each track produces an alignment parquet with +the same schema (described in §7.4) so Stage 3 readouts and Stage 4 +comparisons read uniformly. + +### 7.1 Path A-anno: annotation-anchored real-time shift + +**TODO: implement** as `align_anno.py`. Replaces `annotation_remodeling.py`'s +alignment step (currently in +`.claude/worktrees/.../annotation_remodeling.py`). + +Per-cell `t_zero` = first frame where `infection_state == "infected"` +in the manual or auto annotation CSV. Real-time per-cell shift: +`t_rel = (t - t_zero) * frame_interval_minutes`. No DTW, no template, +no warping. + +```bash +# TODO: command shape +uv run python align_anno.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --query-set zikv_07_24 +``` + +### 7.2 Path A-LC: LC-anchored real-time shift + +**TODO: implement** as `align_lc.py`. Replaces `prediction_remodeling.py`. + +Per-cell `t_zero` = first frame of the longest run of `predicted_infection_state == "infected"` in the NS3 channel embedding zarr. `min_run` parameter (default 3) prevents single-frame flickers from defining `t_zero`. + +```bash +# TODO: command shape +uv run python align_lc.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --query-set zikv_07_24 --min-run 3 +``` + +### 7.3 Path B: hybrid DTW + warp propagation + +Existing `align_cells.py` covers most of this. Renaming and feature gaps +listed below. 
+ +```bash +cd applications/dynaclr/scripts/pseudotime/2-align_cells +uv run python align_embedding.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_zikv \ + --query-set zikv_07_24 \ + --min-match-minutes 360 --max-skew 0.7 +``` + +The aligner: + +1. Loads `templates/template_{name}.zarr` and the cohort-tagged query + annotations. +2. Pulls NS3 channel embeddings, applies build-time L2 normalization + (no refit at alignment time). +3. For each query lineage, runs subsequence DTW on the transition + sub-window. The template (length T) must match fully; the query + floats. Returns a warp path, best-match window `[q_start, q_end]`, + cost, and `path_skew`. +4. **TODO:** propagate the warp path to organelle and phase channel + embeddings within the transition sub-window. Pre and post stay in + real-time. +5. **TODO:** for each cell, back-project the τ_event band to a + per-cell real-time interval. Report median + IQR per cohort. +6. Applies guards (§7.6) and writes alignment parquet. +7. Writes a sidecar `{template}_on_{qset}.drop_log.json` with + per-filter drop counts. + +### 7.4 Alignment parquet schema (all three tracks) + +One row per `(dataset_id, fov_name, lineage_id, t)`. Per-lineage columns +are repeated on every frame so downstream scripts can filter rows +without a separate join. + +| Column | Type | Per-frame? 
| Tracks | Notes | +|---|---|---|---|---| +| `dataset_id`, `fov_name`, `lineage_id`, `t` | ids | yes | all | identifiers (`lineage_id` replaces `track_id`) | +| `cohort` | str | yes | all | `productive` / `bystander` / `abortive` / `mock` | +| `divides` | str | yes | all | `none` / `pre` / `during` / `post` | +| `t_zero` | int | per-lineage | all | per-cell anchor frame | +| `t_rel_minutes` | float | yes | all | real-time minutes from `t_zero` | +| `track_path` | str | per-lineage | all | `A-anno` / `A-LC` / `B` | +| `pseudotime` | float ∈ [0, 1] | yes | B only | warp-path template position | +| `alignment_region` | str | yes | B only | `pre` / `aligned` / `post` | +| `t_rel_minutes_warped` | float | yes | B only | back-projected real-time at template position; equals `t_rel_minutes` outside `aligned` | +| `dtw_cost` | float | per-lineage | B only | raw DTW cost | +| `length_normalized_cost` | float | per-lineage | B only | `dtw_cost / len(warp_path)` | +| `path_skew` | float ∈ [0, 1] | per-lineage | B only | mean deviation from ideal diagonal | +| `match_q_start`, `match_q_end` | int | per-lineage | B only | absolute query frames bounding the matched window | +| `template_id` | str | per-lineage | B only | UUID linking to template zarr | + +`t_rel_minutes` is shared across tracks. For Paths A, it's the only +time coordinate. For Path B, it covers pre/post windows; the transition +window also has `t_rel_minutes_warped` from the back-projection. + +### 7.5 Diagnostic plots per track + +```bash +# Path B only — same scripts as before, renamed +uv run python rank_by_cost.py --query-set zikv_07_24 +uv run python plot_top_n_montage.py --query-set zikv_07_24 --top-n 30 --worst-n 10 +uv run python plot_pcs_aligned.py --query-set zikv_07_24 --top-n 50 +``` + +### 7.6 Guards and frame-rate invariance + +DTW with generous psi can collapse the template onto a single query +frame. 
Five guards prevent and surface this: + +| Guard | CLI flag | Default | Rejects | +|---|---|---|---| +| Non-finite cost | always on | — | tracks too short for the solver | +| Path skew | `--max-skew` | 0.7 | degenerate non-diagonal warps (primary gate per discussion §3.8 #2) | +| Length-normalized cost gate | `--cost-gate` | sweep | stereotypy filter; sweep `{0, 10, 20, 30, 50}%` | +| Minute-based match window | `--min-match-minutes` | 360 | template compressed onto tiny real-time window | +| Pre/post headroom | per query-set YAML | 0 | lineages without real footage on either side | + +Path skew is the primary gate; cost is secondary (per discussion §3.8 +#2: skew rejects DTW failures without rejecting biological variance). +**TODO:** wire path-skew-as-primary into the existing two-pass filter +(currently cost-only). + +`--max-psi-minutes` defaults to half the template duration, read from +template attrs. Per-track psi is `round(max_psi_minutes / dataset_frame_interval_minutes)`. + +## 8. Stage 3 — Organelle remodeling readouts + +Stage 3 forks per track per organelle. Each readout reads its track's +alignment parquet and produces population curves and per-cell timing +metrics. The plotting scripts in this stage replace the current +`plot_organelle_remodeling.py` with per-organelle scripts. + +### 8.1 SEC61 readout + +**TODO: implement** as `readout_sec61.py`. Cosine distance of SEC61 +embedding from per-cell baseline (= mean SEC61 embedding across +pre-window frames). Per-cell trajectory binned by `t_rel_minutes`. +Population curve = binned median + IQR. + +```bash +uv run python readout_sec61.py \ + --track {A-anno|A-LC|B} \ + --query-set zikv_07_24 +``` + +Headline metric: real-time `t_rel` at which productive median exceeds +FOV-paired mock 95th percentile. Reported per-cohort. For Path B, +back-projected real-time IQR is reported alongside. + +### 8.2 G3BP1 readout + +**TODO: implement** as `readout_g3bp1.py`. 
Oscillation-aware metrics +on the post-window (real-time, never warped per discussion §3.6 — stress +granule kinetics are sub-frame, warping by NS3 is meaningless): + +| Metric | Definition | +|---|---| +| `excursion_count` | Number of distance threshold crossings in the post-window | +| `dwell_time_minutes` | Total time above threshold | +| `largest_excursion_amplitude` | Max distance above baseline | +| `largest_excursion_duration` | Duration of the longest contiguous excursion | + +Threshold = mock pulsation 95th percentile, FOV-stratified. + +### 8.3 Phase readout + +**TODO: implement** as `readout_phase.py`. Cosine distance of phase +embedding from per-cell baseline. Same structure as SEC61. + +For claim (a'), phase per-cell `t_50` (or onset-time metric) is +extracted and matched to the same cell's SEC61 fluorescence onset +time. The pair `(t_phase, t_sec61)` per cell feeds Stage 4's Spearman +ρ comparison. + +### 8.4 Per-cell timing metrics + +**TODO: implement** per-track `compute_timing_metrics.py`. Replaces +the existing single-track version under `3-organelle-remodeling/`. + +Per-cell scalars (computed on the aligned region): + +| Metric | Definition | Why | +|---|---|---| +| `t_onset_abs` | First `t_rel` where `distance − pre_median` crosses `+0.10` | SNR-robust | +| `t_50` | First `t_rel` crossing `pre_median + 0.5 × Δpeak`, last 2 frames excluded | Half-rise timing | +| `t_peak` | `argmax` of distance over interior aligned region | Time of maximum embedding divergence | +| `delta_peak` | `max(aligned distance) − median(pre distance)` | Amplitude in cosine units | +| `rise_rate_per_hour` | OLS slope of distance vs `t_rel` over aligned region × 60 | Per-cell aggregate speed | + +Outputs: `timing/{stem}_per_cell.parquet` + `timing/{stem}_summary.md`. + +## 9. Stage 4 — Cross-track comparison and robustness + +**TODO: implement.** Stage 4 is the headline-figure stage of the paper. 
+ 
+### 9.1 Side-by-side onset comparison
+
+```bash
+cd applications/dynaclr/scripts/pseudotime/4-compare_tracks
+uv run python compare_onsets.py \
+  --query-set zikv_07_24 \
+  --organelles sec61,g3bp1,phase \
+  --tracks A-anno,A-LC,B
+```
+
+For each `(organelle, track)`, plot the population curve + IQR on a
+real-time axis. Three columns (organelles), three rows (tracks). The
+methodological-claim verdict is whether Path B's IQR is ≥25% tighter
+than the better of A-anno and A-LC at the headline metric (per
+discussion §2.2).
+
+Outputs:
+- `compare_onsets_{qset}.png`: the headline figure
+- `compare_onsets_{qset}_summary.csv`: per `(organelle, track, cohort)`
+  headline number with CI, dip-test, dropped-cohort comparison
+
+### 9.2 Phase-to-fluorescence correlation (claim a')
+
+```bash
+uv run python compare_phase_to_fluor.py \
+  --query-set zikv_07_24 \
+  --organelle {sec61|g3bp1} \
+  --track B
+```
+
+For each productive lineage, extract phase and matched fluorescent
+marker `t_50`. Compute Spearman ρ across the cohort. Permutation null
+(1000 shuffles) for the p-value. SEC61 carries falsification weight
+(per discussion §7 claim a'); G3BP1 expected-null result reports as
+positive evidence for fluorescence-and-phase complementarity.
+
+### 9.3 Warp-vs-no-warp comparator (Path B only)
+
+```bash
+uv run python warp_vs_no_warp.py \
+  --query-set zikv_07_24 \
+  --organelles sec61,g3bp1,phase
+```
+
+For Path B, regenerate every organelle/phase readout twice: once with
+the NS3 warp propagated (current behavior), once without (organelle
+embedding kept on its own real-time axis). Side-by-side population
+curves and headline numbers. Per discussion §3.8 #10 and §4.5: if the
+two agree within 25%, warp propagation is neutral; keep it. If they
+diverge, investigate which masks real timing.
+
+### 9.4 Bimodality test
+
+```bash
+uv run python bimodality_check.py \
+  --input compare_onsets_{qset}_summary.csv
+```
+
+Hartigan's dip-test (or 1- vs. 
2-component GMM BIC) on every +back-projected real-time distribution from Path B. Per discussion +§3.8 #11. Multimodal distributions get reported as mode-stratified +medians or histogram-as-headline rather than a single point. + +### 9.5 Robustness panel + +| Check | Implements | Reference | +|---|---|---| +| Cost-gate sweep `{0, 10, 20, 30, 50}%` | Path B only | Discussion §3.8 #1 | +| Path-skew gate as primary | Path B only | Discussion §3.8 #2 | +| Window ablation ±50% on `h_pre`, `h_post`, `k_pre`, `k_post` | Path B only | Discussion §3.8 #3 | +| Within-condition shuffle null | All tracks | Discussion §3.8 #4 | +| DBA-init K=10 stability | Path B only | Discussion §3.8 #5 | +| Per-cell baseline-noise bootstrap | Path B only | Discussion §3.8 #6 | +| Funnel transparency table | All tracks | Discussion §3.8 #7 | +| Mock-FOV stability | All tracks | Discussion §3.8 #8 | +| Inter-annotator agreement on `t_zero` | All tracks | Discussion §3.8 #9 | +| Warp-vs-no-warp comparator | Path B only | §9.3 above | +| Bimodality test | Path B only | §9.4 above | +| Cost-gate kept-vs-dropped symmetric reporting | Path B only | Discussion §3.8 #12 | + +All robustness outputs land in `4-compare_tracks/robustness/`. + +## 10. Configs + +Three YAMLs split across `configs/pseudotime/`. Each is loaded with +`datasets.yaml` via the `--datasets` + `--config` CLI pair. 
+ +| File | Contains | Used by | +|---|---|---| +| `datasets.yaml` | `data_zarr`, `embeddings` glob patterns, `datasets` list (pred_dir, annotations_path, fov_pattern, frame_interval_minutes) | every stage | +| `candidates.yaml` | `candidate_sets.{name}`, lineage-reconnect rules, cohort-tagging rules | Stage 0 | +| `build_template.yaml` | `templates.{name}` (Path B template build) | Stage 1 | +| `align_cells.yaml` | `query_sets.{name}` per track | Stage 2 | +| `compare_tracks.yaml` | Stage 4 headlines (which `(organelle, track, cohort)` triples to plot, sweep ranges) | Stage 4 | + +### 10.1 Example `candidate_sets` entry + +```yaml +candidate_sets: + zikv_productive_07_24: + datasets: ["2025_07_24_SEC61", "2025_07_24_G3BP1"] + cohort_filter: + productive: + anchor_label: infection_state + anchor_positive: infected + anchor_negative: uninfected + min_pre_minutes: 240 + min_post_minutes: 360 + bystander: + anchor_label: infection_state + all_negative: uninfected + min_imaging_minutes: 600 + abortive: + # NS3 channel present but no rise; defined operationally per + # the abortive-detection step + require_ns3_channel: true + max_ns3_rise_amplitude: 0.05 # cosine units + mock: + well_pattern: "B/*" # uninfected wells + lineage_rules: + reconnect: true + daughter_handling: + pre_t_zero: paired + post_t_zero: longer_pre_window + during_transition: separate_cohort + crop_window: + h_pre_minutes: 240 + h_post_minutes: 360 + h_post_minutes_g3bp1: 540 + max_lineages: 200 +``` + +### 10.2 Example template entry (Path B) + +```yaml +templates: + infection_nondividing_zikv: + candidate_set: zikv_productive_07_24 + cohort: productive + channel: sensor + transition_window: + k_pre_minutes: 60 + k_post_minutes: 120 + preprocessing: + zscore: none + l2_normalize: true + aggregator: dba + dba: + max_iter: 30 + tol: 1.0e-5 + init: medoid + cost_gate: + mode: sweep + values: [0.0, 0.1, 0.2, 0.3, 0.5] + primary_gate: path_skew + max_skew: 0.7 + tau_event: + mode: half_rise_band + 
threshold_fraction: 0.5 +``` + +### 10.3 Example query-set entry + +```yaml +query_sets: + zikv_07_24: + candidate_set: zikv_productive_07_24 + cohorts: [productive, bystander, abortive, mock] + channel: sensor + datasets: + - dataset_id: "2025_07_24_SEC61" + - dataset_id: "2025_07_24_G3BP1" + track_paths: [A-anno, A-LC, B] + min_pre_minutes: 240 + min_post_minutes: 360 +``` + +## 11. Annotations CSV schema + +One file per cohort at `0-select_candidates/candidates/{cohort}_{candidate_set}_annotations.csv`. +One row per `(dataset_id, fov_name, lineage_id, t)`. + +| Column | Type | Notes | +|---|---|---| +| `dataset_id` | str | matches a key in `config["datasets"]` | +| `fov_name` | str | e.g. `A/2/000000` | +| `lineage_id` | int | replaces `track_id`; reconnects mother + daughter | +| `track_id` | int | original track id, retained for traceability | +| `parent_track_id` | int | from tracking; `-1` if root | +| `t` | int | absolute frame index | +| `cohort` | str | `productive` / `bystander` / `abortive` / `mock` | +| `divides` | str | `none` / `pre` / `during` / `post` | +| `infection_state` | str | `infected` / `uninfected` / blank | +| `organelle_state` | str | `remodeled` / `noremodeled` / blank | +| `cell_division_state` | str | `mitosis` / `interphase` / blank | + +Stage 1 derives the per-lineage crop window and `t_key_event` from the +annotations; these are not CSV columns. + +## 12. 
How to run end-to-end + +Once all phases land: + +```bash +# Stage 0 +cd applications/dynaclr/scripts/pseudotime/0-select_candidates +uv run python select_candidates.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/candidates.yaml \ + --candidate-set zikv_productive_07_24 + +# Stage 1 (Path B only) +cd ../1-build_template +uv run python build_template.py \ + --config ../../../configs/pseudotime/build_template.yaml \ + --template infection_nondividing_zikv + +# Stage 2 (three tracks) +cd ../2-align_cells +uv run python align_anno.py --query-set zikv_07_24 +uv run python align_lc.py --query-set zikv_07_24 --min-run 3 +uv run python align_embedding.py --query-set zikv_07_24 --template infection_nondividing_zikv + +# Stage 3 (per track per organelle) +cd ../3-organelle-remodeling +for track in A-anno A-LC B; do + uv run python readout_sec61.py --track $track --query-set zikv_07_24 + uv run python readout_g3bp1.py --track $track --query-set zikv_07_24 + uv run python readout_phase.py --track $track --query-set zikv_07_24 +done + +# Stage 4 (comparison + robustness) +cd ../4-compare_tracks +uv run python compare_onsets.py --query-set zikv_07_24 +uv run python compare_phase_to_fluor.py --query-set zikv_07_24 --organelle sec61 +uv run python warp_vs_no_warp.py --query-set zikv_07_24 +uv run python bimodality_check.py --input compare_onsets_zikv_07_24_summary.csv +``` + +## 13. Limitations and pointers + +This document describes pipeline operations. For the *why* — claims, +falsification protocol, decisions, alternatives considered — see +`/home/eduardo.hirata/repos/DynaCLR/.planning/dynaclr_dtw_discussion.md`. + +Key limitations carried by the pipeline: + +- No smFISH for per-cell viral RNA → "productive vs abortive" is + partly a censored-data category, not pure biology. +- No entry inhibitor → phase-tracks-fluorescence (claim a') is + correlation, not causation, even at high ρ. 
+- No live cell-cycle marker → the `divides` flag confounds cell-cycle, + division timing, and survivor bias. +- DENV deferred → claim (e) is not in scope for v1. + +See discussion §8 for the full limitations and future-work priority. diff --git a/applications/dynaclr/docs/DAGs/training.md b/applications/dynaclr/docs/DAGs/training.md new file mode 100644 index 000000000..b5d1fb8de --- /dev/null +++ b/applications/dynaclr/docs/DAGs/training.md @@ -0,0 +1,161 @@ +# Training DAG + +## Prerequisites + +Datasets must be AI-ready before building a collection. See [ai_ready_datasets.md](ai_ready_datasets.md) +for the full data preparation pipeline (`prepare run` → concatenate → QC → preprocess). + +A dataset is ready when `prepare status` shows `preprocessed: yes` — meaning both +`normalization` and `focus_slice` metadata exist in the zarr zattrs. + +## Step-by-step detail + +``` +zarr stores (preprocessed: normalization + focus_slice in zattrs) +tracking.zarr (per-dataset, synced from NFS) + │ + ├──► collection.yml # defines experiments, channels, perturbation_wells + │ # versioned in git under configs/collections/ + ▼ +dynaclr build-cell-index \ + configs/collections/.yml \ + /hpc/projects/organelle_phenotyping/models/collections/.parquet \ + --num-workers 8 + │ reads tracking CSVs + zarr shape metadata + │ one row per (cell, timepoint, channel) + │ sets z=0 placeholder (overwritten in next step) + ▼ +.parquet (raw: shape columns, z=0, no norm stats) + │ + ▼ +dynaclr preprocess-cell-index \ + /hpc/.../collections/.parquet \ + --focus-channel Phase3D + │ opens each unique FOV once from zarr zattrs: + │ norm_mean/std/median/iqr/max/min — per (cell, timepoint, channel) + │ z_focus_mean — per FOV (mean across timepoints) + │ z — per timepoint focus slice index + │ drops empty frames (max == 0) + ▼ +.parquet (ready: self-contained, no zarr reads at training time) + │ + ▼ +viscy fit --config configs/training/.yml + │ OR: sbatch configs/training/.sh (SLURM, recommended) + │ 
MultiExperimentDataModule reads parquet only at init + │ tensorstore opens zarr lazily on first batch + │ ExperimentRegistry reads plate.zattrs["focus_slice"] once at startup + │ for z_ranges (z_extraction_window centered on dataset z_focus_mean) + ▼ +checkpoints/ + wandb logs +``` + +## Pipeline DAG (process dependency) + +``` +collection.yml + │ + ▼ +build-cell-index (CPU, ~1 min) + │ + ▼ +preprocess-cell-index (CPU, ~5 min, I/O bound) + │ + ▼ +viscy fit (GPU, hours–days) +``` + +## Key commands + + +| Step | Command | Input | Output | +| --------------------- | ------------------------------------------------------------------------- | -------------------------------------- | --------------------------------------------------------- | +| Build cell index | `dynaclr build-cell-index --num-workers 8` | collection YAML + zarr + tracking CSVs | parquet with TCZYX shape columns | +| Preprocess cell index | `dynaclr preprocess-cell-index --focus-channel Phase3D` | parquet + zarr zattrs | parquet with norm stats, per-timepoint z, empties removed | +| Train (interactive) | `uv run viscy fit --config configs/training/.yml` | training config + parquet | checkpoints + logs | +| Train (SLURM) | `sbatch configs/training/.sh` | training config + parquet | checkpoints + logs | +| Resume (SLURM) | `CKPT_PATH=.../last.ckpt sbatch configs/training/.sh` | checkpoint path env var | resumed checkpoints | + + +## What lives where + + +| Data | Location | When written | +| --------------------------------------- | --------------------------------------------------------- | -------------------------------------------- | +| Pixel data (TCZYX arrays) | zarr store on VAST | `prepare run` → concatenate | +| Cell tracking (y, x, t, track_id) | tracking.zarr on VAST | `prepare run` → concatenate | +| Normalization stats (per FOV/timepoint) | zarr zattrs → parquet `norm_*` columns | `viscy preprocess` → `preprocess-cell-index` | +| Focus slice (per timepoint) | zarr zattrs → parquet `z` 
column | `viscy preprocess` → `preprocess-cell-index` | +| Focus slice mean (per FOV) | zarr zattrs → parquet `z_focus_mean` | `viscy preprocess` → `preprocess-cell-index` | +| TCZYX shape per FOV | parquet columns | `build-cell-index` | +| Collection definition | `configs/collections/.yml` in git | manually authored | +| Parquet | `/hpc/projects/organelle_phenotyping/models/collections/` | `build-cell-index` | + + +## collection.yml format + +```yaml +name: +description: "..." + +experiments: + - name: # {date}_{cell}_{marker}_{perturbation} + data_path: /hpc/projects/.../dataset.zarr + tracks_path: /hpc/projects/.../tracking.zarr + channels: + - name: "raw GFP EX488 EM525-45" # zarr channel name (exact match) + marker: G3BP1 # protein label used in parquet + perturbation_wells: + uninfected: [C/1] + infected: [C/2] + interval_minutes: 30.0 + start_hpi: 3.5 + marker: G3BP1 + organelle: stress_granules + moi: 5.0 + pixel_size_xy_um: 0.1494 + pixel_size_z_um: 0.174 +``` + +Experiment name convention: `{date}_{cell_line}_{marker}_{perturbation}` — +perturbation suffix is always included (e.g., `_ZIKV`, `_DENV`, `_ZIKV_DENV`). + +## Training config structure + +Training configs use Lightning CLI `base:` inheritance: + +```yaml +base: + - recipes/trainer.yml # seed, accelerator, logger, callbacks + - recipes/model/contrastive_encoder_convnext_tiny.yml # or dinov3_frozen_mlp.yml + +trainer: + strategy: ddp + devices: 2 + precision: bf16-mixed + max_epochs: 150 + +data: + cell_index_path: /hpc/.../collections/.parquet + ... +``` + +SLURM `.sh` scripts export `PYTHONNOUSERSITE=1` and launch via `srun` for DDP. + +## Reproducibility + +Version `collection.yml` in git. The parquet is derived deterministically from: + +1. The collection YAML (experiment definitions, channels, wells) +2. Tracking zarrs (cell positions) +3. 
Zarr zattrs (normalization + focus stats from `viscy preprocess` + `qc run`) + +To reproduce: `build-cell-index` → `preprocess-cell-index` from the same collection YAML. + +## Notes + +- `preprocess-cell-index` overwrites the parquet in-place by default. Pass `--output` to write elsewhere. +- `--focus-channel Phase3D` selects which channel's `per_timepoint` focus indices are written to the `z` column. Use the channel that has the sharpest axial contrast (label-free Phase3D for most experiments). +- At training time, `ExperimentRegistry.__post_init__` reads `plate.zattrs["focus_slice"][channel]["dataset_statistics"]["z_focus_mean"]` to compute per-experiment z_ranges for patch extraction. This is the only zarr metadata read at training startup; the parquet is self-contained for all per-cell data. +- The `z` column in the parquet is carried through to embeddings obs during predict — downstream consumers (e.g., visualization) can use it to recover the in-focus plane for each cell at each timepoint. +- For performance tuning (num_workers, pin_memory, batch_size, augmentation placement), see [profiling.md](profiling.md) — authored after the first validated profiling sweep. 
diff --git a/applications/dynaclr/docs/linear_classifiers/README.md b/applications/dynaclr/docs/linear_classifiers/README.md index a4f893c9c..a0486d299 100644 --- a/applications/dynaclr/docs/linear_classifiers/README.md +++ b/applications/dynaclr/docs/linear_classifiers/README.md @@ -9,14 +9,13 @@ This directory contains: | File | Description | |------|-------------| | `src/utils.py` | Shared functions for discovering predictions, annotations, channel resolution, and path utilities | -| `src/report.py` | PDF report generation for cross-validation and evaluation (optional) | +| `src/report.py` | PDF report generation for cross-validation (optional, `--report` flag) | | `scripts/generate_prediction_scripts.py` | Generates SLURM `.sh`/`.yml` scripts for datasets missing embeddings | | `scripts/generate_batch_predictions.py` | Batch prediction config & SLURM script generator with auto z-range | | `scripts/generate_train_config.py` | Generates training YAML configs for all valid task x channel combinations | | `scripts/train_linear_classifier.py` | CLI for training a classifier from a config | | `scripts/apply_linear_classifier.py` | CLI for applying a trained classifier to new embeddings | | `scripts/cross_validation.py` | Leave-one-dataset-out CV with impact scoring (helps/hurts/uncertain) | -| `scripts/evaluate_dataset.py` | Compare embedding models (e.g. 2D vs 3D) on a held-out test set | ## Prerequisites @@ -80,8 +79,8 @@ dynaclr apply-linear-classifier -c configs/example_linear_classifier_inference.y Determine which training datasets help or hurt classifier performance using rotating leave-one-dataset-out CV. 
Run from the `linear_classifiers/` directory: ```bash -python scripts/cross_validation.py -c configs/cross_validate_example.yaml -python scripts/cross_validation.py -c configs/cross_validate_example.yaml --report # with PDF +dynaclr cross-validate -c configs/cross_validate_example.yaml +dynaclr cross-validate -c configs/cross_validate_example.yaml --report # with PDF ``` Outputs: @@ -96,24 +95,6 @@ Each dataset is labeled as: - **uncertain** — delta within noise - **unsafe** — fold skipped due to insufficient class samples -### 6. Evaluate models on a held-out test set - -Compare embedding models by training classifiers and evaluating on a held-out dataset: - -```bash -python scripts/evaluate_dataset.py -c configs/evaluate_dataset_example.yaml -python scripts/evaluate_dataset.py -c configs/evaluate_dataset_example.yaml --report # with PDF -``` - -Outputs per model: -- `{model}/{task}_{channel}_pipeline.joblib` — trained classifier -- `{model}/{task}_{channel}_predictions.zarr` — test predictions -- `{model}/metrics_summary.csv` — per-model metrics - -Combined outputs: -- `train_metrics_comparison.csv` — validation metrics across models -- `test_metrics_comparison.csv` — test metrics across models - ## Training Configuration Create a YAML config file (see `configs/example_linear_classifier_train.yaml`): diff --git a/applications/dynaclr/nextflow/README.md b/applications/dynaclr/nextflow/README.md new file mode 100644 index 000000000..f2d4f9184 --- /dev/null +++ b/applications/dynaclr/nextflow/README.md @@ -0,0 +1,153 @@ +# DynaCLR Nextflow Pipelines + +Multi-workflow Nextflow layout. `main.nf` is a thin router that dispatches to a +named sub-workflow via `-entry`. Each entry workflow owns its own DAG and lives +under `workflows/`; processes live under `modules//`. 
+ +## Layout + +``` +applications/dynaclr/nextflow/ +├── main.nf # thin router — -entry +├── nextflow.config # shared params + SLURM resource labels +├── workflows/ +│ ├── evaluation.nf # workflow EVALUATION { take: ... } +│ └── training_preprocessing.nf # workflow TRAINING_PREPROCESSING { take: ... } +└── modules/ + ├── evaluation/ # processes used only by evaluation + │ ├── prepare_configs.nf + │ ├── predict.nf + │ ├── split.nf + │ └── ... + ├── preprocessing/ # processes used by training_preprocessing + │ ├── build_cell_index.nf + │ └── preprocess_cell_index.nf + └── shared/ # processes reused across workflows +``` + +## Running + +```bash +module load nextflow/24.10.5 + +# Evaluation +nextflow run applications/dynaclr/nextflow/main.nf -entry evaluation \ + --eval_config applications/dynaclr/configs/evaluation/.yaml \ + --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ + -resume + +# Training preprocessing (collection YAML → training-ready parquet) +nextflow run applications/dynaclr/nextflow/main.nf -entry training_preprocessing \ + --collection_yaml applications/dynaclr/configs/collections/.yml \ + --parquet_out /hpc/projects/organelle_phenotyping/models/collections/.parquet \ + --focus_channel Phase3D \ + --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ + -resume + +# Local test (no SLURM) — append `-profile local` +``` + +Running `main.nf` without `-entry` fails loudly with the list of valid entries. + +## Predict-only runs + +Use the `evaluation` entry with `steps: [predict, split]` in your eval config +to run inference on a new dataset without any downstream evals. Rerun with +more steps later using `-resume` — predict/split are skipped because +`embeddings.zarr` and `{exp}.zarr` already exist on disk. See +[docs/DAGs/evaluation.md](../docs/DAGs/evaluation.md#predict-only-runs-inference-without-downstream-evals) +for the full pattern. + +## Adding a new workflow + +Follow this four-step recipe. 
The `training_preprocessing` workflow is the +reference example — copy its structure. + +### 1. Create process modules + +Each process is a single `.nf` file under `modules/<workflow>/`. Prefer +`val` inputs (not `path`) to avoid Nextflow staging zarr/parquet files — +everything is read/written in place on VAST. + +```groovy +// modules/<workflow>/my_step.nf +process MY_STEP { + label 'cpu' // picks SLURM resources from nextflow.config + + input: + val input_path + val workspace_dir + + output: + val input_path, emit: result + + script: + """ + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr my-command ${input_path} + """ +} +``` + +Reuse `modules/shared/` for processes used by more than one workflow. + +### 2. Create a named sub-workflow + +```groovy +// workflows/my_workflow.nf +include { MY_STEP } from '../modules/my_workflow/my_step' + +workflow MY_WORKFLOW { + take: + input_path + workspace_dir + + main: + MY_STEP(input_path, workspace_dir) +} +``` + +Use `take:` for inputs so the workflow is composable. `main:` holds the DAG. + +### 3. Register an entry wrapper in `main.nf` + +```groovy +include { MY_WORKFLOW } from './workflows/my_workflow' + +workflow my_workflow { // lowercase name → -entry my_workflow + if (!params.input_path) { + error "ERROR: --input_path is required for -entry my_workflow" + } + MY_WORKFLOW(params.input_path, params.workspace_dir) +} +``` + +The wrapper has two jobs: validate required params and bridge CLI flags into +the `take:` arguments of the sub-workflow. Use lowercase entry names so they +don't clash with the imported UPPERCASE workflow symbol. + +### 4. Add params to `nextflow.config` + +```groovy +params { + // ... existing params + input_path = null // Required for -entry my_workflow +} +``` + +Resource labels (`cpu`, `cpu_heavy`, `gpu_2d`, `gpu_3d`, `cpu_light`) are +shared across all workflows — don't redefine them per workflow. 
+ +## Conventions + +- **Entry workflow names are lowercase** (`evaluation`, `training_preprocessing`). + Sub-workflow symbols are UPPERCASE (`EVALUATION`, `TRAINING_PREPROCESSING`). +- **Process names are UPPERCASE** (`PREDICT`, `BUILD_CELL_INDEX`). +- **Pass paths as `val`, not `path`** — avoids Nextflow staging large zarrs. +- **Always use `-resume`** — every step re-checks existence on disk. +- **Use `PYTHONNOUSERSITE=1`** (already set in `env { }` block) — prevents + `~/.local/` from shadowing the conda/uv env. +- **Manifest-driven optional steps**: if a workflow generates a JSON manifest + (like `prepare-eval-configs`), gate steps with `.filter { it.containsKey(...) }` + and `.ifEmpty('skip')` so the DAG remains connected when a step is disabled. + See `workflows/evaluation.nf` for the pattern. diff --git a/applications/dynaclr/nextflow/main.nf b/applications/dynaclr/nextflow/main.nf new file mode 100644 index 000000000..dfacbec75 --- /dev/null +++ b/applications/dynaclr/nextflow/main.nf @@ -0,0 +1,70 @@ +#!/usr/bin/env nextflow +// DynaCLR Nextflow Router +// +// Thin entry-point that dispatches to a named sub-workflow via `-entry`. +// Each sub-workflow lives in workflows/<name>.nf and owns its own DAG. +// +// Usage: +// module load nextflow/24.10.5 +// +// # Evaluation +// nextflow run applications/dynaclr/nextflow/main.nf -entry evaluation \ +// --eval_config /path/to/eval_config.yaml \ +// --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +// -resume +// +// # Training preprocessing (collection → parquet) +// nextflow run applications/dynaclr/nextflow/main.nf -entry training_preprocessing \ +// --collection_yaml /path/to/collection.yml \ +// --parquet_out /hpc/.../collections/<name>.parquet \ +// --focus_channel Phase3D \ +// --workspace_dir /hpc/mydata/eduardo.hirata/repos/viscy \ +// -resume +// +// Zarr/parquet files are read/written in place on VAST (no staging). 
+ +nextflow.enable.dsl = 2 + +include { EVALUATION } from './workflows/evaluation' +include { TRAINING_PREPROCESSING } from './workflows/training_preprocessing' + + +// Default (unnamed) workflow — fail loudly if invoked without -entry. +workflow { + error """ + No entry workflow selected. + + Use one of: + -entry evaluation (requires --eval_config) + -entry training_preprocessing (requires --collection_yaml, --parquet_out) + """.stripIndent() +} + + +// Entry workflow names are lowercase to avoid clashing with the imported +// named workflows above (Nextflow treats sub-workflows with `take:` as +// non-directly-invocable, so we wrap each one here). + +workflow evaluation { + if (!params.eval_config) { + error "ERROR: --eval_config is required for -entry evaluation" + } + EVALUATION(file(params.eval_config), params.workspace_dir) +} + + +workflow training_preprocessing { + if (!params.collection_yaml) { + error "ERROR: --collection_yaml is required for -entry training_preprocessing" + } + if (!params.parquet_out) { + error "ERROR: --parquet_out is required for -entry training_preprocessing" + } + TRAINING_PREPROCESSING( + params.collection_yaml, + params.parquet_out, + params.focus_channel, + params.num_workers, + params.workspace_dir + ) +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/append_annotations.nf b/applications/dynaclr/nextflow/modules/evaluation/append_annotations.nf new file mode 100644 index 000000000..1132a4e2f --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/append_annotations.nf @@ -0,0 +1,22 @@ +// Append annotation columns to per-experiment zarr obs. +// Reads per-experiment annotation CSVs and writes task columns (e.g. infection_state) +// directly into each zarr so plots can color by ground truth labels. +// Runs after SPLIT; depends on split_done signal. 
+ +process APPEND_ANNOTATIONS { + executor 'local' + + input: + val split_done // dependency signal from SPLIT (all zarrs exist) + val aa_yaml + val workspace_dir + + output: + val 'done', emit: done + + script: + """ + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr append-annotations -c ${aa_yaml} + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/append_predictions.nf b/applications/dynaclr/nextflow/modules/evaluation/append_predictions.nf new file mode 100644 index 000000000..6d2cd5360 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/append_predictions.nf @@ -0,0 +1,22 @@ +// Apply saved linear classifiers to per-experiment zarrs and write predictions. +// Loads pipelines saved by LINEAR_CLASSIFIERS, predicts on all cells per marker, +// and writes predicted_{task} columns to obs alongside probabilities in obsm. +// Depends on LINEAR_CLASSIFIERS completing (pipelines must exist). + +process APPEND_PREDICTIONS { + executor 'local' + + input: + val lc_done // dependency signal from LINEAR_CLASSIFIERS + val ap_yaml + val workspace_dir + + output: + val 'done', emit: done + + script: + """ + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr append-predictions -c ${ap_yaml} + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/linear_classifiers.nf b/applications/dynaclr/nextflow/modules/evaluation/linear_classifiers.nf new file mode 100644 index 000000000..ed3a73dcf --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/linear_classifiers.nf @@ -0,0 +1,20 @@ +// Linear classifiers on per-experiment embeddings. +// Reads directly from the embeddings directory (all zarrs). 
+ +process LINEAR_CLASSIFIERS { + executor 'local' + + input: + val split_done // dependency signal from SPLIT (embeddings dir is populated) + val lc_yaml + val workspace_dir + + output: + val 'done', emit: done + + script: + """ + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr run-linear-classifiers -c ${lc_yaml} + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/mmd.nf b/applications/dynaclr/nextflow/modules/evaluation/mmd.nf new file mode 100644 index 000000000..f85fc31b5 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/mmd.nf @@ -0,0 +1,28 @@ +// Per-experiment MMD for one (zarr, mmd_block) pair. +// Patches __ZARR_PATH__ in the per-block template YAML. + +process MMD { + label 'cpu' + + input: + tuple val(zarr_path), val(block_name), val(mmd_yaml) + val workspace_dir + + output: + val zarr_path, emit: zarr_path + + script: + def exp_name = new File(zarr_path).name.replaceAll(/\.zarr$/, '') + """ + python3 -c " +import yaml +with open('${mmd_yaml}') as f: + cfg = yaml.safe_load(f) +cfg['input_path'] = '${zarr_path}' +with open('mmd_${block_name}_${exp_name}_patched.yaml', 'w') as f: + yaml.dump(cfg, f, default_flow_style=False, sort_keys=False) +" + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr compute-mmd -c mmd_${block_name}_${exp_name}_patched.yaml + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/mmd_combined.nf b/applications/dynaclr/nextflow/modules/evaluation/mmd_combined.nf new file mode 100644 index 000000000..8e442981a --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/mmd_combined.nf @@ -0,0 +1,28 @@ +// Cross-experiment MMD for one block (with per-experiment batch centering). +// Collects all zarr paths and patches the template YAML's input_paths list. 
+ +process MMD_COMBINED { + label 'cpu' + + input: + tuple val(zarr_paths), val(block_name), val(mmd_combined_yaml) + val workspace_dir + + output: + val block_name, emit: block_name + + script: + def paths_repr = zarr_paths.split('\n').collect { "'${it}'" }.join(', ') + """ + python3 -c " +import yaml +with open('${mmd_combined_yaml}') as f: + cfg = yaml.safe_load(f) +cfg['input_paths'] = [${paths_repr}] +with open('mmd_${block_name}_combined_patched.yaml', 'w') as f: + yaml.dump(cfg, f, default_flow_style=False, sort_keys=False) +" + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr compute-mmd --combined -c mmd_${block_name}_combined_patched.yaml + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/mmd_plot_heatmap.nf b/applications/dynaclr/nextflow/modules/evaluation/mmd_plot_heatmap.nf new file mode 100644 index 000000000..505436e18 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/mmd_plot_heatmap.nf @@ -0,0 +1,15 @@ +// Gather step: plot one combined MMD heatmap (all markers) from all per-experiment CSVs. +// Runs once per block after all MMD scatter jobs complete. + +process MMD_PLOT_HEATMAP { + executor 'local' + + input: + val mmd_dir + + script: + """ + uv run --project=${params.workspace_dir} --package=dynaclr \ + dynaclr plot-mmd-heatmap ${mmd_dir} + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/plot.nf b/applications/dynaclr/nextflow/modules/evaluation/plot.nf new file mode 100644 index 000000000..9dde55925 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/plot.nf @@ -0,0 +1,33 @@ +// Per-experiment embedding scatter plots (X_pca). +// Patches __ZARR_PATH__ and __PLOT_DIR__ placeholders in the template YAML. 
+ +process PLOT { + label 'cpu' + + input: + val zarr_path + val plot_yaml + val plots_dir + val workspace_dir + + output: + val zarr_path, emit: zarr_path + + script: + def exp_name = new File(zarr_path).name.replaceAll(/\.zarr$/, '') + def plot_subdir = "${plots_dir}/${exp_name}" + """ + python3 -c " +import yaml +with open('${plot_yaml}') as f: + cfg = yaml.safe_load(f) +cfg['input_path'] = '${zarr_path}' +cfg['output_dir'] = '${plot_subdir}' +with open('plot_patched.yaml', 'w') as f: + yaml.dump(cfg, f, default_flow_style=False, sort_keys=False) +" + mkdir -p ${plot_subdir} + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr plot-embeddings -c plot_patched.yaml + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/plot_combined.nf b/applications/dynaclr/nextflow/modules/evaluation/plot_combined.nf new file mode 100644 index 000000000..2af6b7dc7 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/plot_combined.nf @@ -0,0 +1,29 @@ +// Combined embedding plots across all experiments (X_pca_combined, X_phate_combined). +// Collects all zarr paths and patches the template YAML's input_paths list. 
+ +process PLOT_COMBINED { + label 'cpu' + + input: + val zarr_paths // list of all per-experiment zarr paths + val plot_combined_yaml + val workspace_dir + + output: + val 'done', emit: done + + script: + def paths_repr = zarr_paths.collect { "'${it}'" }.join(', ') + """ + python3 -c " +import yaml +with open('${plot_combined_yaml}') as f: + cfg = yaml.safe_load(f) +cfg['input_paths'] = [${paths_repr}] +with open('plot_combined_patched.yaml', 'w') as f: + yaml.dump(cfg, f, default_flow_style=False, sort_keys=False) +" + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr plot-embeddings -c plot_combined_patched.yaml + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/predict.nf b/applications/dynaclr/nextflow/modules/evaluation/predict.nf new file mode 100644 index 000000000..170450460 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/predict.nf @@ -0,0 +1,19 @@ +// Run viscy predict to extract embeddings from a checkpoint. +// Writes embeddings/embeddings.zarr in output_dir. + +process PREDICT { + label "${params.gpu_label}" + + input: + val predict_yaml + val workspace_dir + + output: + val 'done', emit: done + + script: + """ + uv run --project=${workspace_dir} --package=viscy-utils \ + viscy predict -c ${predict_yaml} + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/prepare_configs.nf b/applications/dynaclr/nextflow/modules/evaluation/prepare_configs.nf new file mode 100644 index 000000000..98ee8a203 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/prepare_configs.nf @@ -0,0 +1,19 @@ +// Run dynaclr prepare-eval-configs to generate per-step YAML configs. +// Prints a JSON manifest to stdout; we capture it as the process output. 
+ +process PREPARE_CONFIGS { + executor 'local' + + input: + path eval_config + val workspace_dir + + output: + path 'manifest.json', emit: manifest + + script: + """ + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr prepare-eval-configs -c ${eval_config} > manifest.json + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/reduce.nf b/applications/dynaclr/nextflow/modules/evaluation/reduce.nf new file mode 100644 index 000000000..c2db95a5e --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/reduce.nf @@ -0,0 +1,29 @@ +// Per-experiment dimensionality reduction (PCA). +// Patches the __ZARR_PATH__ placeholder in the template YAML, then runs. +// Emits zarr_path for downstream processes. + +process REDUCE { + executor 'local' + + input: + val zarr_path + val reduce_yaml + val workspace_dir + + output: + val zarr_path, emit: zarr_path + + script: + """ + python3 -c " +import yaml +with open('${reduce_yaml}') as f: + cfg = yaml.safe_load(f) +cfg['input_path'] = '${zarr_path}' +with open('reduce_patched.yaml', 'w') as f: + yaml.dump(cfg, f, default_flow_style=False, sort_keys=False) +" + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr reduce-dimensionality -c reduce_patched.yaml + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/reduce_combined.nf b/applications/dynaclr/nextflow/modules/evaluation/reduce_combined.nf new file mode 100644 index 000000000..b54cca06a --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/reduce_combined.nf @@ -0,0 +1,47 @@ +// Joint PCA + PHATE across all experiments. +// Collects all reduced per-experiment zarr paths, patches the template YAML, +// then runs combined-dim-reduction which writes X_pca_combined / X_phate_combined +// back into each per-experiment zarr. 
+ +process REDUCE_COMBINED { + label 'cpu_heavy' + + input: + val zarr_paths // list of all per-experiment zarr paths (after REDUCE.collect()) + val reduce_combined_yaml + val workspace_dir + + output: + val zarr_paths, emit: zarr_paths + + script: + def paths_repr = zarr_paths.collect { "'${it}'" }.join(', ') + """ + # Pin BLAS to 1 thread per process: PHATE's n_jobs spawns one joblib + # worker per allocated CPU (SLURM_CPUS_PER_TASK), each running KNN + # search single-threaded. If BLAS were unbounded, every worker would + # also try to spawn ~cores threads on its own internal matmuls, + # producing ~cores^2 threads and thrashing the node. Standard sklearn + # parallelism pattern: one axis at a time. KNN search dominates wall + # time on PHATE-with-PCA-input; the BLAS-heavy phases (joint PCA, + # diffusion matrix powers) are bounded fast even at 1 thread. + # Also avoids the scipy.lu deadlock that hit when BLAS tried to use + # all 48 cores on cpu_heavy nodes. + export OPENBLAS_NUM_THREADS=1 + export MKL_NUM_THREADS=1 + export OMP_NUM_THREADS=1 + export NUMEXPR_NUM_THREADS=1 + export VECLIB_MAXIMUM_THREADS=1 + + python3 -c " +import yaml +with open('${reduce_combined_yaml}') as f: + cfg = yaml.safe_load(f) +cfg['input_paths'] = [${paths_repr}] +with open('reduce_combined_patched.yaml', 'w') as f: + yaml.dump(cfg, f, default_flow_style=False, sort_keys=False) +" + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr combined-dim-reduction -c reduce_combined_patched.yaml + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/smoothness.nf b/applications/dynaclr/nextflow/modules/evaluation/smoothness.nf new file mode 100644 index 000000000..6e7457b6c --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/smoothness.nf @@ -0,0 +1,28 @@ +// Per-experiment temporal smoothness evaluation. +// Patches the __ZARR_PATH__ placeholder in the template YAML. 
+ +process SMOOTHNESS { + label 'cpu' + + input: + val zarr_path + val smoothness_yaml + val workspace_dir + + output: + val zarr_path, emit: zarr_path + + script: + """ + python3 -c " +import yaml +with open('${smoothness_yaml}') as f: + cfg = yaml.safe_load(f) +cfg['models'][0]['path'] = '${zarr_path}' +with open('smoothness_patched.yaml', 'w') as f: + yaml.dump(cfg, f, default_flow_style=False, sort_keys=False) +" + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr evaluate-smoothness -c smoothness_patched.yaml + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/smoothness_gather.nf b/applications/dynaclr/nextflow/modules/evaluation/smoothness_gather.nf new file mode 100644 index 000000000..1e2c74455 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/smoothness_gather.nf @@ -0,0 +1,40 @@ +// Gather per-experiment smoothness CSVs into one combined file + per-marker summary. +// Runs after all SMOOTHNESS scatter jobs finish. + +process SMOOTHNESS_GATHER { + executor 'local' + + input: + val smoothness_dir + val done // signal that all SMOOTHNESS jobs finished + + output: + val smoothness_dir, emit: smoothness_dir + + script: + """ + uv run --project=${params.workspace_dir} --package=dynaclr python3 -c " +import glob, os +import pandas as pd + +smoothness_dir = '${smoothness_dir}' +csvs = glob.glob(os.path.join(smoothness_dir, '*_per_marker_smoothness.csv')) +if not csvs: + raise RuntimeError(f'No per_marker_smoothness CSVs found in {smoothness_dir}') + +combined = pd.concat([pd.read_csv(f) for f in sorted(csvs)], ignore_index=True) +out = os.path.join(smoothness_dir, 'all_experiments_per_marker_smoothness.csv') +combined.to_csv(out, index=False) +print(f'Wrote {len(combined)} rows from {len(csvs)} experiments to {out}') + +# Per-marker summary: mean +/- std across experiments +metric_cols = [c for c in combined.columns if c not in ('experiment', 'marker')] +agg = 
combined.groupby('marker')[metric_cols].agg(['mean', 'std']) +agg.columns = ['_'.join(c) for c in agg.columns] +agg = agg.reset_index() +summary_out = os.path.join(smoothness_dir, 'per_marker_summary.csv') +agg.to_csv(summary_out, index=False) +print(f'Per-marker summary ({len(agg)} markers) written to {summary_out}') +" + """ +} diff --git a/applications/dynaclr/nextflow/modules/evaluation/split.nf b/applications/dynaclr/nextflow/modules/evaluation/split.nf new file mode 100644 index 000000000..c611d3fe9 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/evaluation/split.nf @@ -0,0 +1,45 @@ +// Split combined embeddings.zarr into per-experiment zarrs. +// Also generates configs/viewer.yaml using the cell index parquet. +// Emits per-experiment zarr paths as a list. + +process SPLIT { + executor 'local' + + input: + val predict_done // dependency signal from PREDICT + val embeddings_dir + val cell_index_path + val output_dir + val workspace_dir + + output: + path 'zarr_paths.txt', emit: zarr_paths_file + + script: + def combined_zarr = "${embeddings_dir}/embeddings.zarr" + def viewer_yaml = "${output_dir}/configs/viewer.yaml" + """ + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr split-embeddings \ + --input ${combined_zarr} \ + --output-dir ${embeddings_dir} + + # Generate viewer YAML from cell index parquet + uv run --project=${workspace_dir} --package=dynaclr python3 -c " +import pandas as pd, yaml, pathlib +embeddings_dir = pathlib.Path('${embeddings_dir}') +df = pd.read_parquet('${cell_index_path}', columns=['experiment', 'store_path']) +exp_to_plate = df.drop_duplicates('experiment').set_index('experiment')['store_path'].to_dict() +datasets = {} +for zarr_path in sorted(embeddings_dir.glob('*.zarr')): + exp_name = zarr_path.stem + datasets[exp_name] = {'hcs_plate': exp_to_plate[exp_name], 'anndata': str(zarr_path)} +with open('${viewer_yaml}', 'w') as f: + yaml.dump({'datasets': datasets}, f, default_flow_style=False, sort_keys=False) 
+print('Viewer YAML written to ${viewer_yaml}') +" + + # Write per-experiment zarr paths to a file for Nextflow to read + ls -d ${embeddings_dir}/*.zarr > zarr_paths.txt + """ +} diff --git a/applications/dynaclr/nextflow/modules/preprocessing/build_cell_index.nf b/applications/dynaclr/nextflow/modules/preprocessing/build_cell_index.nf new file mode 100644 index 000000000..05aceeb55 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/preprocessing/build_cell_index.nf @@ -0,0 +1,24 @@ +// Build the flat cell index parquet from a collection YAML. +// Reads tracking CSVs + zarr shape metadata; writes one row per (cell, timepoint, channel). + +process BUILD_CELL_INDEX { + label 'cpu' + + input: + val collection_yaml + val parquet_out + val num_workers + val workspace_dir + + output: + val parquet_out, emit: parquet + + script: + """ + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr build-cell-index \ + ${collection_yaml} \ + ${parquet_out} \ + --num-workers ${num_workers} + """ +} diff --git a/applications/dynaclr/nextflow/modules/preprocessing/preprocess_cell_index.nf b/applications/dynaclr/nextflow/modules/preprocessing/preprocess_cell_index.nf new file mode 100644 index 000000000..c619fe118 --- /dev/null +++ b/applications/dynaclr/nextflow/modules/preprocessing/preprocess_cell_index.nf @@ -0,0 +1,22 @@ +// Enrich the cell index parquet with norm stats + per-timepoint focus slice z. +// Opens each unique FOV once from zarr zattrs; overwrites parquet in place. 
+ +process PREPROCESS_CELL_INDEX { + label 'cpu' + + input: + val parquet_in + val focus_channel + val workspace_dir + + output: + val parquet_in, emit: parquet + + script: + """ + uv run --project=${workspace_dir} --package=dynaclr \ + dynaclr preprocess-cell-index \ + ${parquet_in} \ + --focus-channel ${focus_channel} + """ +} diff --git a/applications/dynaclr/nextflow/nextflow.config b/applications/dynaclr/nextflow/nextflow.config new file mode 100644 index 000000000..3920fa2ed --- /dev/null +++ b/applications/dynaclr/nextflow/nextflow.config @@ -0,0 +1,125 @@ +// DynaCLR Nextflow configuration — shared across all entry workflows. +// +// Targets HPC with SLURM. Resource labels: +// gpu_2d — GPU predict for 2D / MIP models (smaller RAM footprint) +// gpu_3d — GPU predict for 3D models (z-stacks need ~3x the RAM) +// cpu — CPU partition, memory-intensive (dim reduction, MMD) +// cpu_light — CPU partition, light tasks (split) +// cpu_heavy — CPU partition, high core/mem (build/preprocess parquet) +// +// Select which gpu label PREDICT uses via `--gpu_label gpu_3d` (default: gpu_2d). + +nextflow.enable.dsl = 2 + +params { + // Shared + workspace_dir = "/hpc/mydata/eduardo.hirata/repos/viscy" + gpu_label = "gpu_2d" // PREDICT resource profile: "gpu_2d" or "gpu_3d" + + // evaluation entry + eval_config = null // Required for -entry evaluation + + // training_preprocessing entry + collection_yaml = null // Required for -entry training_preprocessing + parquet_out = null // Required for -entry training_preprocessing + focus_channel = "Phase3D" + num_workers = 8 +} + +process { + executor = 'slurm' + + // Default: 'finish' lets in-flight tasks complete on first failure. + // Per-label withLabel directives below opt-in to time-only retry for + // transient SLURM failures (exit 140 = SIGUSR2 from time-warning, + // 137 = SIGKILL after time limit). Memory stays flat — our jobs are + // bounded by per-experiment cell quotas, so doubling RAM doesn't help. 
+ // Generic Python errors (exit 1) are NOT retried. + errorStrategy = 'finish' + + withLabel: 'gpu_2d' { + queue = 'gpu' + cpus = 4 + memory = '64 GB' + time = { 4.h * task.attempt } + clusterOptions = '--gres=gpu:1' + errorStrategy = { task.exitStatus in [140, 137] ? 'retry' : 'finish' } + maxRetries = 2 + } + + withLabel: 'gpu_3d' { + queue = 'gpu' + cpus = 4 + memory = '192 GB' + time = { 8.h * task.attempt } + clusterOptions = '--gres=gpu:1' + errorStrategy = { task.exitStatus in [140, 137] ? 'retry' : 'finish' } + maxRetries = 2 + } + + withLabel: 'cpu' { + queue = 'cpu' + cpus = 16 + memory = '128 GB' + time = { 2.h * task.attempt } + errorStrategy = { task.exitStatus in [140, 137] ? 'retry' : 'finish' } + maxRetries = 2 + } + + withLabel: 'cpu_light' { + queue = 'cpu' + cpus = 4 + memory = '32 GB' + time = { 30.m * task.attempt } + errorStrategy = { task.exitStatus in [140, 137] ? 'retry' : 'finish' } + maxRetries = 2 + } + + withLabel: 'cpu_heavy' { + queue = 'cpu' + cpus = 48 + memory = '256 GB' + time = { 4.h * task.attempt } + errorStrategy = { task.exitStatus in [140, 137] ? 'retry' : 'finish' } + maxRetries = 2 + } +} + +env { + PYTHONNOUSERSITE = '1' +} + +// Timing reports — written to the launch directory (nextflow_logs/ for our SLURM scripts). 
+// trace.txt — per-task tab-separated: realtime, %cpu, peak_rss, status +// report.html — interactive summary (resource usage histograms, slowest tasks) +// timeline.html — Gantt chart of task execution +trace { + enabled = true + file = "${launchDir}/trace.txt" + overwrite = true + fields = 'task_id,name,status,exit,realtime,%cpu,peak_rss,peak_vmem,rchar,wchar' +} +report { + enabled = true + file = "${launchDir}/report.html" + overwrite = true +} +timeline { + enabled = true + file = "${launchDir}/timeline.html" + overwrite = true +} + +// Local profile for testing (no SLURM) +profiles { + local { + process.executor = 'local' + process { + withLabel: 'cpu' { cpus = 4 } + withLabel: 'cpu_light' { cpus = 2 } + withLabel: 'cpu_heavy' { cpus = 12 } + withLabel: 'gpu_2d' { cpus = 2 } + withLabel: 'gpu_3d' { cpus = 2 } + } + } +} diff --git a/applications/dynaclr/nextflow/workflows/evaluation.nf b/applications/dynaclr/nextflow/workflows/evaluation.nf new file mode 100644 index 000000000..12baeb7a0 --- /dev/null +++ b/applications/dynaclr/nextflow/workflows/evaluation.nf @@ -0,0 +1,271 @@ +// DynaCLR Evaluation Workflow +// +// Named sub-workflow invoked via `-entry EVALUATION` from main.nf. +// Takes an eval_config path + workspace_dir, runs the full embedding DAG: +// prepare-configs → predict → split → (reduce / smoothness / mmd / classifiers / plots). 
+ +include { PREPARE_CONFIGS } from '../modules/evaluation/prepare_configs' +include { PREDICT } from '../modules/evaluation/predict' +include { SPLIT } from '../modules/evaluation/split' +include { REDUCE } from '../modules/evaluation/reduce' +include { REDUCE_COMBINED } from '../modules/evaluation/reduce_combined' +include { PLOT } from '../modules/evaluation/plot' +include { PLOT_COMBINED } from '../modules/evaluation/plot_combined' +include { SMOOTHNESS } from '../modules/evaluation/smoothness' +include { MMD } from '../modules/evaluation/mmd' +include { MMD_COMBINED } from '../modules/evaluation/mmd_combined' +include { MMD_PLOT_HEATMAP } from '../modules/evaluation/mmd_plot_heatmap' +include { SMOOTHNESS_GATHER } from '../modules/evaluation/smoothness_gather' +include { LINEAR_CLASSIFIERS } from '../modules/evaluation/linear_classifiers' +include { APPEND_ANNOTATIONS } from '../modules/evaluation/append_annotations' +include { APPEND_PREDICTIONS } from '../modules/evaluation/append_predictions' + + +workflow EVALUATION { + take: + eval_config + workspace_dir + + main: + // ----------------------------------------------------------------------- + // Step 1: Generate per-step YAML configs → JSON manifest + // ----------------------------------------------------------------------- + PREPARE_CONFIGS(eval_config, workspace_dir) + + manifest_ch = PREPARE_CONFIGS.out.manifest + .map { f -> new groovy.json.JsonSlurper().parse(f) } + + // ----------------------------------------------------------------------- + // Step 2: Predict (GPU) — only if "predict" key is in manifest + // ----------------------------------------------------------------------- + predict_yaml_ch = manifest_ch + .flatMap { manifest -> + manifest.containsKey('predict') ? 
[manifest.predict] : [] + } + + PREDICT(predict_yaml_ch, workspace_dir) + + predict_signal_ch = PREDICT.out.done + .ifEmpty('skip') + .first() + + // ----------------------------------------------------------------------- + // Step 3: Split — runs after predict (or immediately if predict skipped) + // ----------------------------------------------------------------------- + SPLIT( + predict_signal_ch, + manifest_ch.map { it.embeddings_dir }, + manifest_ch.map { it.cell_index_path }, + manifest_ch.map { it.output_dir }, + workspace_dir + ) + + per_exp_zarrs_ch = SPLIT.out.zarr_paths_file + .splitText() + .map { it.trim() } + .filter { it.endsWith('.zarr') } + + split_done_ch = per_exp_zarrs_ch.collect().map { 'done' } + + // ----------------------------------------------------------------------- + // Step 4a: Per-experiment dim reduction (scatter) — after split + // ----------------------------------------------------------------------- + reduce_yaml_ch = manifest_ch + .filter { it.containsKey('reduce') } + .map { it.reduce } + + reduce_inputs_ch = per_exp_zarrs_ch.combine(reduce_yaml_ch) + + REDUCE( + reduce_inputs_ch.map { zarr, yaml -> zarr }, + reduce_inputs_ch.map { zarr, yaml -> yaml }, + workspace_dir + ) + + // ----------------------------------------------------------------------- + // Step 4b: Combined dim reduction (gather) — after all REDUCE finish + // ----------------------------------------------------------------------- + reduce_combined_yaml_ch = manifest_ch + .filter { it.containsKey('reduce_combined') } + .map { it.reduce_combined } + + REDUCE_COMBINED( + REDUCE.out.zarr_path.collect(), + reduce_combined_yaml_ch, + workspace_dir + ) + + // Barrier: per-experiment zarr writes (X_pca_combined / X_phate_combined) + // must finish before APPEND_ANNOTATIONS / APPEND_PREDICTIONS start writing + // to the same zarrs. ifEmpty('skip') keeps the chain alive when + // reduce_combined isn't in steps. 
+ reduce_combined_done_ch = REDUCE_COMBINED.out.zarr_paths + .ifEmpty('skip') + .first() + + // ----------------------------------------------------------------------- + // Step 5: Smoothness (scatter, depends only on split) + // ----------------------------------------------------------------------- + smoothness_yaml_ch = manifest_ch + .filter { it.containsKey('smoothness') } + .map { it.smoothness } + + smoothness_inputs_ch = per_exp_zarrs_ch.combine(smoothness_yaml_ch) + + SMOOTHNESS( + smoothness_inputs_ch.map { zarr, yaml -> zarr }, + smoothness_inputs_ch.map { zarr, yaml -> yaml }, + workspace_dir + ) + + smoothness_dir_ch = manifest_ch.map { "${it.output_dir}/smoothness" } + smoothness_done_ch = SMOOTHNESS.out.zarr_path.collect().map { 'done' } + SMOOTHNESS_GATHER(smoothness_dir_ch, smoothness_done_ch) + + // ----------------------------------------------------------------------- + // Step 6: MMD per-experiment (scatter, depends only on split) + // ----------------------------------------------------------------------- + mmd_block_inputs_ch = manifest_ch + .filter { it.containsKey('mmd_blocks') && it.mmd_blocks.size() > 0 } + .flatMap { manifest -> + manifest.mmd_blocks.collect { block_name -> + [block_name, manifest["mmd_${block_name}"]] + } + } + + mmd_per_exp_ch = per_exp_zarrs_ch + .combine(mmd_block_inputs_ch) + .map { zarr, block_name, mmd_yaml -> tuple(zarr, block_name, mmd_yaml) } + + MMD(mmd_per_exp_ch, workspace_dir) + + mmd_heatmap_dirs_ch = manifest_ch + .filter { it.containsKey('mmd_blocks') && it.mmd_blocks.size() > 0 } + .flatMap { manifest -> + manifest.mmd_blocks.collect { block_name -> manifest["mmd_${block_name}_dir"] } + } + + MMD.out.zarr_path.collect() + .combine(mmd_heatmap_dirs_ch) + .map { items -> items[-1] } + | MMD_PLOT_HEATMAP + + // ----------------------------------------------------------------------- + // Step 6b: MMD combined (gather per block, depends only on split) + // 
----------------------------------------------------------------------- + mmd_combined_inputs_ch = manifest_ch + .filter { it.containsKey('mmd_combined_blocks') && it.mmd_combined_blocks.size() > 0 } + .flatMap { manifest -> + manifest.mmd_combined_blocks.collect { block_name -> + [block_name, manifest["mmd_${block_name}_cross_exp"]] + } + } + + mmd_combined_zarrs_str_ch = per_exp_zarrs_ch.collect().map { zarrs -> zarrs.join('\n') } + + mmd_combined_ch = mmd_combined_zarrs_str_ch + .combine(mmd_combined_inputs_ch) + .map { zarrs_str, block_name, mmd_yaml -> tuple(zarrs_str, block_name, mmd_yaml) } + + MMD_COMBINED(mmd_combined_ch, workspace_dir) + + // ----------------------------------------------------------------------- + // Step 7: Append annotations — must run AFTER reduce_combined (both + // mutate per-experiment zarrs), and after split (zarrs must exist). + // ----------------------------------------------------------------------- + // Concurrency invariant: per-experiment zarrs have one writer at a time. + // Pre-write order: SPLIT -> REDUCE -> REDUCE_COMBINED -> APPEND_ANNOTATIONS + // -> LINEAR_CLASSIFIERS (reads only) -> APPEND_PREDICTIONS -> PLOT (reads). 
+ aa_yaml_ch = manifest_ch + .filter { it.containsKey('append_annotations') } + .map { it.append_annotations } + + aa_ready_ch = split_done_ch.mix(reduce_combined_done_ch).collect().map { 'ready' } + + APPEND_ANNOTATIONS(aa_ready_ch, aa_yaml_ch, workspace_dir) + + aa_done_ch = APPEND_ANNOTATIONS.out.done + .ifEmpty('skip') + .first() + + // ----------------------------------------------------------------------- + // Step 8: Linear classifiers — after append_annotations + // ----------------------------------------------------------------------- + lc_yaml_ch = manifest_ch + .filter { it.containsKey('linear_classifiers') } + .map { it.linear_classifiers } + + LINEAR_CLASSIFIERS(aa_done_ch, lc_yaml_ch, workspace_dir) + + // ----------------------------------------------------------------------- + // Step 9: Append predictions — after linear classifiers AND split + // ----------------------------------------------------------------------- + // append_predictions reads per-experiment zarrs (produced by SPLIT) and + // writes predicted_* columns to obs. It must wait on BOTH: + // - LINEAR_CLASSIFIERS (when present in steps): pipelines must exist + // - SPLIT: the zarrs to predict on must exist + // For Wave-2 evaluations that fetch pipelines from an external registry + // (no LINEAR_CLASSIFIERS in steps), lc_done is 'skip' immediately and + // only the split dependency keeps APPEND_PREDICTIONS gated. + ap_yaml_ch = manifest_ch + .filter { it.containsKey('append_predictions') } + .map { it.append_predictions } + + lc_done_ch = LINEAR_CLASSIFIERS.out.done + .ifEmpty('skip') + .first() + + // Combine the two upstream signals into one barrier value. 
+ ap_ready_ch = lc_done_ch.mix(split_done_ch).collect().map { 'ready' } + + APPEND_PREDICTIONS(ap_ready_ch, ap_yaml_ch, workspace_dir) + + ap_done_ch = APPEND_PREDICTIONS.out.done + .ifEmpty('skip') + .first() + + // ----------------------------------------------------------------------- + // Step 10a: Per-experiment plots — after reduce_combined + enrichment + // ----------------------------------------------------------------------- + plot_yaml_ch = manifest_ch + .filter { it.containsKey('plot') } + .map { it.plot } + + plots_dir_ch = manifest_ch.map { "${it.output_dir}/plots" } + + enrichment_done_ch = aa_done_ch.mix(ap_done_ch).collect().map { 'ready' } + + reduce_zarrs_str_ch = REDUCE_COMBINED.out.zarr_paths + .map { zarrs -> zarrs.join('\n') } + + post_reduce_zarrs_ch = reduce_zarrs_str_ch + .combine(enrichment_done_ch) + .map { zarrs_str, _ready -> zarrs_str.split('\n').toList() } + .flatten() + + plot_inputs_ch = post_reduce_zarrs_ch.combine(plot_yaml_ch).combine(plots_dir_ch) + + PLOT( + plot_inputs_ch.map { zarr, yaml, dir -> zarr }, + plot_inputs_ch.map { zarr, yaml, dir -> yaml }, + plot_inputs_ch.map { zarr, yaml, dir -> dir }, + workspace_dir + ) + + // ----------------------------------------------------------------------- + // Step 10b: Combined plots (gather) — after reduce_combined + enrichment + // ----------------------------------------------------------------------- + plot_combined_yaml_ch = manifest_ch + .filter { it.containsKey('plot_combined') } + .map { it.plot_combined } + + plot_combined_input_ch = reduce_zarrs_str_ch + .combine(enrichment_done_ch) + .map { zarrs_str, _ready -> zarrs_str.split('\n').toList() } + + PLOT_COMBINED( + plot_combined_input_ch, + plot_combined_yaml_ch, + workspace_dir + ) +} diff --git a/applications/dynaclr/nextflow/workflows/training_preprocessing.nf b/applications/dynaclr/nextflow/workflows/training_preprocessing.nf new file mode 100644 index 000000000..517acac3c --- /dev/null +++ 
b/applications/dynaclr/nextflow/workflows/training_preprocessing.nf @@ -0,0 +1,28 @@ +// DynaCLR Training Preprocessing Workflow +// +// Named sub-workflow invoked via `-entry TRAINING_PREPROCESSING` from main.nf. +// Takes a collection YAML and produces a training-ready parquet: +// build-cell-index → preprocess-cell-index (norm stats + focus slice z). +// +// Required params: +// --collection_yaml path to configs/collections/.yml +// --parquet_out output parquet path +// --focus_channel channel used for per-timepoint z (default: Phase3D) +// --num_workers build-cell-index parallelism (default: 8) + +include { BUILD_CELL_INDEX } from '../modules/preprocessing/build_cell_index' +include { PREPROCESS_CELL_INDEX } from '../modules/preprocessing/preprocess_cell_index' + + +workflow TRAINING_PREPROCESSING { + take: + collection_yaml + parquet_out + focus_channel + num_workers + workspace_dir + + main: + BUILD_CELL_INDEX(collection_yaml, parquet_out, num_workers, workspace_dir) + PREPROCESS_CELL_INDEX(BUILD_CELL_INDEX.out.parquet, focus_channel, workspace_dir) +} diff --git a/applications/dynaclr/pyproject.toml b/applications/dynaclr/pyproject.toml index dc31623a5..e3eab2874 100644 --- a/applications/dynaclr/pyproject.toml +++ b/applications/dynaclr/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ optional-dependencies.eval = [ "anndata", + "dtaidistance", "natsort", "phate", "scikit-learn", @@ -51,6 +52,20 @@ optional-dependencies.eval = [ "umap-learn", "wandb", ] +optional-dependencies.tracking = [ + "gurobipy>=12.0.1,<13", + "onnxruntime-gpu", + "py-ctcmetrics", + "tabulate", + "tracksdata", +] +# Visualization extras: PCA-RGB timelapse MP4 export needs the FFmpeg +# plugin for imageio. Without it the timelapse CLI silently falls back +# to writing an animated GIF. 
+optional-dependencies.viz = [ + "imageio-ffmpeg", + "matplotlib", +] urls.Homepage = "https://github.com/mehta-lab/VisCy" urls.Issues = "https://github.com/mehta-lab/VisCy/issues" urls.Repository = "https://github.com/mehta-lab/VisCy" diff --git a/applications/dynaclr/scripts/cellanome/embed_dinov3.py b/applications/dynaclr/scripts/cellanome/embed_dinov3.py new file mode 100644 index 000000000..c0f084c69 --- /dev/null +++ b/applications/dynaclr/scripts/cellanome/embed_dinov3.py @@ -0,0 +1,407 @@ +"""Extract DINOv3 embeddings for cellanome cells → cell-level AnnData. + +Reads primary_analysis.csv from the Cellanome pipeline, crops cell patches +from the OME-Zarr store, runs them through a frozen DINOv3 model, and writes +a new cell-level AnnData zarr where each row is one segmented cell. + +Usage +----- +uv run python embed_dinov3.py config.yaml +""" + +import argparse +import logging +import math +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +import yaml +import zarr +from tqdm import tqdm + +from viscy_models.foundation import DINOv3Model + +CHANNEL_SHORT_NAMES = { + "White": "BF", + "Blue-FITC (520)": "FITC", + "Red-CY5 (700)": "CY5", + "Green-CY3 (605)": "CY3", +} + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +def load_primary_analysis( + analysis_base: str, + scan_ids: list[int] | None = None, + lane_ids: list[int] | None = None, +) -> pd.DataFrame: + """Load and concatenate primary_analysis.csv for all scans/lanes. + + Parameters + ---------- + analysis_base : str + Path to the image_analysis_output directory. + scan_ids : list[int] or None + Scan IDs to include. If None, auto-discover. + lane_ids : list[int] or None + Lane IDs to include. If None, auto-discover. + + Returns + ------- + pd.DataFrame + Concatenated primary analysis with all columns. 
+ """ + base = Path(analysis_base) + if scan_ids is None: + scan_ids = sorted(int(p.name.split("_")[1]) for p in base.glob("scan_*") if p.is_dir()) + if lane_ids is None: + all_lanes = set() + for scan_id in scan_ids: + scan_dir = base / f"scan_{scan_id}" + all_lanes.update(int(p.name.split("_")[1]) for p in scan_dir.glob("lane_*") if p.is_dir()) + lane_ids = sorted(all_lanes) + + frames = [] + for scan_id in scan_ids: + for lane_id in lane_ids: + csv_path = ( + base + / f"scan_{scan_id}" + / f"lane_{lane_id}" + / "processed" + / "CAGE_REGISTRATION" + / "primary_analysis.csv" + ) + if not csv_path.exists(): + logger.warning(f"Missing: {csv_path}") + continue + df = pd.read_csv(csv_path) + frames.append(df) + logger.info(f"scan_{scan_id}/lane_{lane_id}: {len(df)} objects") + + combined = pd.concat(frames, ignore_index=True) + logger.info(f"Total: {len(combined)} objects across {len(frames)} scan/lane combinations") + return combined + + +def derive_zarr_paths(df: pd.DataFrame) -> pd.DataFrame: + """Derive zarr position and path from cage_crop_file_name. + + Parameters + ---------- + df : pd.DataFrame + Must have columns: cage_crop_file_name, lane_id, scan_id. + + Returns + ------- + pd.DataFrame + With added zarr_position and zarr_path columns. + """ + + def _parse_position(cage_crop: str) -> str: + parts = str(cage_crop).split("_") + return f"{parts[4]}{parts[5]}" + + df["zarr_position"] = df["cage_crop_file_name"].apply(_parse_position) + df["zarr_path"] = df["lane_id"].astype(str) + "/" + df["scan_id"].astype(str) + "/" + df["zarr_position"] + return df + + +def build_barcode_lookup(anndata_path: str) -> dict[tuple[int, str], list[str]]: + """Build (global_cage_id_matched, lane) → [barcode_index, ...] lookup. + + Parameters + ---------- + anndata_path : str + Path to the transcriptome AnnData zarr. + + Returns + ------- + dict[tuple[int, str], list[str]] + Mapping from (cage_id, lane_string) to list of barcode obs_names. 
+ """ + adata = ad.read_zarr(anndata_path) + obs = adata.obs.copy() + obs["_lane"] = obs.index.str.extract(r"(lane_\d)")[0].to_numpy() + obs["_cage_id"] = obs["global.cage.id.matched"].astype(int) + + lookup: dict[tuple[int, str], list[str]] = {} + for idx, row in obs.iterrows(): + key = (row["_cage_id"], row["_lane"]) + lookup.setdefault(key, []).append(idx) + + logger.info(f"Barcode lookup: {len(lookup)} unique (cage, lane) pairs from {adata.n_obs} barcodes") + return lookup + + +def join_barcodes(df: pd.DataFrame, lookup: dict[tuple[int, str], list[str]]) -> pd.DataFrame: + """Join barcode indices to cells via (global_cage_id_matched, lane). + + Parameters + ---------- + df : pd.DataFrame + Must have columns: global_cage_id_matched, lane_id. + lookup : dict + From build_barcode_lookup. + + Returns + ------- + pd.DataFrame + With added barcode_index and in_anndata columns. + """ + barcode_indices = [] + for _, row in df.iterrows(): + key = (int(row["global_cage_id_matched"]), f"lane_{int(row['lane_id'])}") + barcodes = lookup.get(key, []) + barcode_indices.append(";".join(barcodes) if barcodes else "") + df["barcode_index"] = barcode_indices + df["in_anndata"] = df["barcode_index"] != "" + return df + + +def apply_filters(df: pd.DataFrame, filters: dict) -> pd.DataFrame: + """Apply column-level filters to the DataFrame. + + Parameters + ---------- + df : pd.DataFrame + Input DataFrame. + filters : dict + Mapping of column_name → {min, max, eq, isin}. + + Returns + ------- + pd.DataFrame + Filtered DataFrame. + """ + for col, conditions in filters.items(): + if col not in df.columns: + raise ValueError(f"Filter column '{col}' not found. 
Available: {list(df.columns)[:10]}...")
+        if "min" in conditions:
+            df = df[df[col] >= conditions["min"]]
+        if "max" in conditions:
+            df = df[df[col] <= conditions["max"]]
+        if "eq" in conditions:
+            df = df[df[col] == conditions["eq"]]
+        if "isin" in conditions:
+            df = df[df[col].isin(conditions["isin"])]
+    return df
+
+
+def resolve_channel_indices(store: zarr.Group, zarr_path: str, channel_names: list[str]) -> list[int]:
+    """Resolve integer indices for named channels in an OME-Zarr FOV.
+
+    Parameters
+    ----------
+    store : zarr.Group
+        Opened zarr store.
+    zarr_path : str
+        Relative path to the FOV group.
+    channel_names : list[str]
+        Channel labels to look up.
+
+    Returns
+    -------
+    list[int]
+        Zero-based channel indices.
+    """
+    fov_group = store[zarr_path]
+    channels = fov_group.attrs["omero"]["channels"]
+    labels = [ch.get("label", ch.get("name", "")) for ch in channels]
+    indices = []
+    for name in channel_names:
+        if name not in labels:
+            raise ValueError(f"Channel '{name}' not found. Available: {labels}")
+        indices.append(labels.index(name))
+    return indices
+
+
+def crop_cell(
+    fov_array: np.ndarray,
+    cy: int,
+    cx: int,
+    half: int,
+    channels: list[int] | None = None,
+) -> np.ndarray | None:
+    """Crop a square patch centered on (cy, cx) from a (C, H, W) FOV array.
+
+    Parameters
+    ----------
+    fov_array : np.ndarray
+        FOV image array of shape ``(C, H, W)``.
+    cy : int
+        Y centroid in FOV pixels.
+    cx : int
+        X centroid in FOV pixels.
+    half : int
+        Half the crop size in pixels.
+    channels : list[int] or None
+        Channel indices to select. If None, use all channels.
+
+    Returns
+    -------
+    np.ndarray or None
+        Cropped patch, or None if out of bounds.
+ """ + _, h, w = fov_array.shape + y0, y1 = cy - half, cy + half + x0, x1 = cx - half, cx + half + if y0 < 0 or x0 < 0 or y1 > h or x1 > w: + return None + patch = fov_array[:, y0:y1, x0:x1] + if channels is not None: + patch = patch[channels] + return patch + + +def main(): + """Extract DINOv3 embeddings for cellanome cells.""" + parser = argparse.ArgumentParser(description="Extract DINOv3 embeddings for cellanome cells.") + parser.add_argument("config", help="Path to YAML config file") + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + zarr_store = cfg["zarr_store"] + analysis_base = cfg["analysis_base"] + transcriptome_anndata = cfg.get("transcriptome_anndata", None) + output_path = cfg["output_path"] + model_name = cfg.get("model_name", "facebook/dinov2-base") + channels = cfg.get("channels", None) + output_key = cfg.get("output_key", None) + patch_size = cfg.get("patch_size", 96) + reference_pixel_size = cfg.get("reference_pixel_size", 1.0) + source_pixel_size = cfg.get("source_pixel_size", 1.0) + batch_size = cfg.get("batch_size", 128) + device_str = cfg.get("device", "cuda") + scan_ids = cfg.get("scan_ids", None) + lane_ids = cfg.get("lane_ids", None) + filters = cfg.get("filters", {}) + + # --- Load and prepare data --- + df = load_primary_analysis(analysis_base, scan_ids, lane_ids) + n_raw = len(df) + df = apply_filters(df, filters) + logger.info(f"After filtering: {len(df)} cells (removed {n_raw - len(df)})") + + df = derive_zarr_paths(df) + if transcriptome_anndata is not None: + lookup = build_barcode_lookup(transcriptome_anndata) + df = join_barcodes(df, lookup) + n_matched = df["in_anndata"].sum() + logger.info(f"Barcode match: {n_matched}/{len(df)} cells ({100 * n_matched / len(df):.1f}%)") + else: + logger.info("No transcriptome_anndata provided; skipping barcode join") + + # --- Pixel size rescaling --- + # raw_crop covers the same physical area as patch_size at reference resolution. 
+ # Larger source pixels → fewer pixels needed. + raw_half = round(patch_size * reference_pixel_size / source_pixel_size) // 2 + raw_crop_size = 2 * raw_half + logger.info(f"Raw crop: {raw_crop_size}x{raw_crop_size} -> model input: {patch_size}x{patch_size}") + + # --- Resolve channels --- + store = zarr.open(zarr_store, mode="r") + first_zarr_path = df["zarr_path"].iloc[0] + if channels is not None: + channel_indices = resolve_channel_indices(store, first_zarr_path, channels) + channel_labels = channels + else: + fov_group = store[first_zarr_path] + omero_channels = fov_group.attrs["omero"]["channels"] + channel_labels = [ch.get("label", ch.get("name", "")) for ch in omero_channels] + channel_indices = list(range(len(channel_labels))) + logger.info(f"Channels: {channel_labels} (indices {channel_indices})") + + short_names = [CHANNEL_SHORT_NAMES.get(ch, ch) for ch in channel_labels] + output_key = output_key or "dinov3_" + "_".join(short_names) + + # --- Load model --- + device = torch.device(device_str if torch.cuda.is_available() else "cpu") + model = DINOv3Model(model_name=model_name, freeze=True) + model = model.to(device) + model.eval() + logger.info(f"Loaded DINOv3 {model_name} on {device}") + + # --- Inference --- + df = df.sort_values("zarr_path").reset_index(drop=True) + current_fov_path: str | None = None + current_fov: np.ndarray | None = None + all_embeddings = [] + valid_indices = [] + skipped_border = 0 + + n_batches = math.ceil(len(df) / batch_size) + pbar = tqdm(range(0, len(df), batch_size), total=n_batches, desc="Embedding", unit="batch") + for batch_start in pbar: + batch_df = df.iloc[batch_start : batch_start + batch_size] + patches = [] + batch_valid = [] + + for idx, row in batch_df.iterrows(): + zarr_path = row["zarr_path"] + cy, cx = int(row["object_y_fov"]), int(row["object_x_fov"]) + + if zarr_path != current_fov_path: + current_fov = store[zarr_path]["0"][0, :, 0] + current_fov_path = zarr_path + + patch = crop_cell(current_fov, cy, cx, 
raw_half, channels=channel_indices) + if patch is None: + skipped_border += 1 + continue + + patches.append(patch) + batch_valid.append(idx) + + if not patches: + continue + + batch_tensor = torch.from_numpy(np.stack(patches)).float() + if raw_crop_size != patch_size: + batch_tensor = F.interpolate( + batch_tensor, size=(patch_size, patch_size), mode="bilinear", align_corners=False + ) + + # Per-sample z-score: zero mean, unit std + mean = batch_tensor.flatten(1).mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1) + std = batch_tensor.flatten(1).std(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1).clamp(min=1e-8) + batch_tensor = (batch_tensor - mean) / std + + batch_tensor = batch_tensor.unsqueeze(2).to(device) + + with torch.inference_mode(): + features, _ = model(batch_tensor) + + all_embeddings.append(features.cpu().numpy()) + valid_indices.extend(batch_valid) + pbar.set_postfix(cells=len(valid_indices), skipped=skipped_border) + + if skipped_border > 0: + logger.warning(f"Skipped {skipped_border} cells too close to FOV border") + + # --- Write cell-level anndata --- + embeddings = np.concatenate(all_embeddings, axis=0) + valid_df = df.iloc[valid_indices].reset_index(drop=True) + logger.info(f"Embeddings: {embeddings.shape}") + + pd.options.future.infer_string = False + obs = valid_df.copy() + for col in obs.select_dtypes(include=["string", "string[pyarrow]"]).columns: + obs[col] = obs[col].astype(object) + obs.index = obs["object_uuid"].astype(str) + + cell_adata = ad.AnnData(X=embeddings.astype(np.float32), obs=obs) + cell_adata.write_zarr(output_path) + logger.info(f"Wrote {output_path}: {cell_adata.n_obs} cells x {cell_adata.n_vars} dims") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/cellanome/embed_dynaclr.py b/applications/dynaclr/scripts/cellanome/embed_dynaclr.py new file mode 100644 index 000000000..32e343856 --- /dev/null +++ b/applications/dynaclr/scripts/cellanome/embed_dynaclr.py @@ -0,0 +1,405 @@ +"""Extract 
DynaCLR embeddings for cellanome cells → cell-level AnnData. + +Reads primary_analysis.csv from the Cellanome pipeline, crops cell patches +(single channel) from the OME-Zarr store, runs them through a DynaCLR +contrastive encoder checkpoint, and writes a new cell-level AnnData zarr. + +Usage +----- +uv run python applications/dynaclr/scripts/cellanome/embed_dynaclr.py config.yaml +""" + +import argparse +import logging +import math +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +import yaml +import zarr +from tqdm import tqdm + +from dynaclr.engine import ContrastiveEncoder + +CHANNEL_SHORT_NAMES = { + "White": "BF", + "Blue-FITC (520)": "FITC", + "Red-CY5 (700)": "CY5", + "Green-CY3 (605)": "CY3", +} + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +def load_primary_analysis( + analysis_base: str, + scan_ids: list[int] | None = None, + lane_ids: list[int] | None = None, +) -> pd.DataFrame: + """Load and concatenate primary_analysis.csv for all scans/lanes. + + Parameters + ---------- + analysis_base : str + Path to the image_analysis_output directory. + scan_ids : list[int] or None + Scan IDs to include. If None, auto-discover. + lane_ids : list[int] or None + Lane IDs to include. If None, auto-discover. + + Returns + ------- + pd.DataFrame + Concatenated primary analysis with all columns. 
+ """ + base = Path(analysis_base) + if scan_ids is None: + scan_ids = sorted(int(p.name.split("_")[1]) for p in base.glob("scan_*") if p.is_dir()) + if lane_ids is None: + all_lanes = set() + for scan_id in scan_ids: + scan_dir = base / f"scan_{scan_id}" + all_lanes.update(int(p.name.split("_")[1]) for p in scan_dir.glob("lane_*") if p.is_dir()) + lane_ids = sorted(all_lanes) + + frames = [] + for scan_id in scan_ids: + for lane_id in lane_ids: + csv_path = ( + base + / f"scan_{scan_id}" + / f"lane_{lane_id}" + / "processed" + / "CAGE_REGISTRATION" + / "primary_analysis.csv" + ) + if not csv_path.exists(): + logger.warning(f"Missing: {csv_path}") + continue + df = pd.read_csv(csv_path) + frames.append(df) + logger.info(f"scan_{scan_id}/lane_{lane_id}: {len(df)} objects") + + combined = pd.concat(frames, ignore_index=True) + logger.info(f"Total: {len(combined)} objects across {len(frames)} scan/lane combinations") + return combined + + +def derive_zarr_paths(df: pd.DataFrame) -> pd.DataFrame: + """Derive zarr position and path from cage_crop_file_name. + + Parameters + ---------- + df : pd.DataFrame + Must have columns: cage_crop_file_name, lane_id, scan_id. + + Returns + ------- + pd.DataFrame + With added zarr_position and zarr_path columns. + """ + + def _parse_position(cage_crop: str) -> str: + parts = str(cage_crop).split("_") + return f"{parts[4]}{parts[5]}" + + df["zarr_position"] = df["cage_crop_file_name"].apply(_parse_position) + df["zarr_path"] = df["lane_id"].astype(str) + "/" + df["scan_id"].astype(str) + "/" + df["zarr_position"] + return df + + +def build_barcode_lookup(anndata_path: str) -> dict[tuple[int, str], list[str]]: + """Build (global_cage_id_matched, lane) → [barcode_index, ...] lookup. + + Parameters + ---------- + anndata_path : str + Path to the transcriptome AnnData zarr. + + Returns + ------- + dict[tuple[int, str], list[str]] + Mapping from (cage_id, lane_string) to list of barcode obs_names. 
+ """ + adata = ad.read_zarr(anndata_path) + obs = adata.obs.copy() + obs["_lane"] = obs.index.str.extract(r"(lane_\d)")[0].to_numpy() + obs["_cage_id"] = obs["global.cage.id.matched"].astype(int) + + lookup: dict[tuple[int, str], list[str]] = {} + for idx, row in obs.iterrows(): + key = (row["_cage_id"], row["_lane"]) + lookup.setdefault(key, []).append(idx) + + logger.info(f"Barcode lookup: {len(lookup)} unique (cage, lane) pairs from {adata.n_obs} barcodes") + return lookup + + +def join_barcodes(df: pd.DataFrame, lookup: dict[tuple[int, str], list[str]]) -> pd.DataFrame: + """Join barcode indices to cells via (global_cage_id_matched, lane). + + Parameters + ---------- + df : pd.DataFrame + Must have columns: global_cage_id_matched, lane_id. + lookup : dict + From build_barcode_lookup. + + Returns + ------- + pd.DataFrame + With added barcode_index and in_anndata columns. + """ + barcode_indices = [] + for _, row in df.iterrows(): + key = (int(row["global_cage_id_matched"]), f"lane_{int(row['lane_id'])}") + barcodes = lookup.get(key, []) + barcode_indices.append(";".join(barcodes) if barcodes else "") + df["barcode_index"] = barcode_indices + df["in_anndata"] = df["barcode_index"] != "" + return df + + +def apply_filters(df: pd.DataFrame, filters: dict) -> pd.DataFrame: + """Apply column-level filters to the DataFrame. + + Parameters + ---------- + df : pd.DataFrame + Input DataFrame. + filters : dict + Mapping of column_name → {min, max, eq, isin}. + + Returns + ------- + pd.DataFrame + Filtered DataFrame. + """ + for col, conditions in filters.items(): + if col not in df.columns: + raise ValueError(f"Filter column '{col}' not found. 
Available: {list(df.columns)[:10]}...")
+        if "min" in conditions:
+            df = df[df[col] >= conditions["min"]]
+        if "max" in conditions:
+            df = df[df[col] <= conditions["max"]]
+        if "eq" in conditions:
+            df = df[df[col] == conditions["eq"]]
+        if "isin" in conditions:
+            df = df[df[col].isin(conditions["isin"])]
+    return df
+
+
+def resolve_channel_index(store: zarr.Group, zarr_path: str, channel_name: str) -> int:
+    """Resolve the integer index of a named channel in an OME-Zarr FOV.
+
+    Parameters
+    ----------
+    store : zarr.Group
+        Opened zarr store.
+    zarr_path : str
+        Relative path to the FOV group.
+    channel_name : str
+        Channel label to look up.
+
+    Returns
+    -------
+    int
+        Zero-based channel index.
+    """
+    fov_group = store[zarr_path]
+    channels = fov_group.attrs["omero"]["channels"]
+    labels = [ch.get("label", ch.get("name", "")) for ch in channels]
+    if channel_name not in labels:
+        raise ValueError(f"Channel '{channel_name}' not found. Available: {labels}")
+    return labels.index(channel_name)
+
+
+def crop_cell(
+    fov_array: np.ndarray,
+    cy: int,
+    cx: int,
+    half: int,
+    channel_idx: int | None = None,
+) -> np.ndarray | None:
+    """Crop a square patch centered on (cy, cx) from a (C, H, W) FOV array.
+
+    Parameters
+    ----------
+    fov_array : np.ndarray
+        FOV image array of shape ``(C, H, W)``.
+    cy : int
+        Y centroid in FOV pixels.
+    cx : int
+        X centroid in FOV pixels.
+    half : int
+        Half the crop size in pixels.
+    channel_idx : int or None
+        Channel index to select. If None, use all channels.
+
+    Returns
+    -------
+    np.ndarray or None
+        Cropped patch, or None if out of bounds.
+ """ + _, h, w = fov_array.shape + y0, y1 = cy - half, cy + half + x0, x1 = cx - half, cx + half + if y0 < 0 or x0 < 0 or y1 > h or x1 > w: + return None + if channel_idx is not None: + return fov_array[channel_idx : channel_idx + 1, y0:y1, x0:x1] + return fov_array[:, y0:y1, x0:x1] + + +def main(): + """Extract DynaCLR embeddings for cellanome cells.""" + parser = argparse.ArgumentParser(description="Extract DynaCLR embeddings for cellanome cells.") + parser.add_argument("config", help="Path to YAML config file") + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + zarr_store = cfg["zarr_store"] + analysis_base = cfg["analysis_base"] + transcriptome_anndata = cfg.get("transcriptome_anndata", None) + output_path = cfg["output_path"] + ckpt_path = cfg["ckpt_path"] + encoder_config = cfg["encoder_config"] + channel_name = cfg.get("channel_name", "White") + output_key = cfg.get("output_key", None) + patch_size = cfg.get("patch_size", 96) + reference_pixel_size = cfg.get("reference_pixel_size", 1.0) + source_pixel_size = cfg.get("source_pixel_size", 1.0) + batch_size = cfg.get("batch_size", 128) + device_str = cfg.get("device", "cuda") + scan_ids = cfg.get("scan_ids", None) + lane_ids = cfg.get("lane_ids", None) + filters = cfg.get("filters", {}) + + # --- Load and prepare data --- + df = load_primary_analysis(analysis_base, scan_ids, lane_ids) + n_raw = len(df) + df = apply_filters(df, filters) + logger.info(f"After filtering: {len(df)} cells (removed {n_raw - len(df)})") + + df = derive_zarr_paths(df) + if transcriptome_anndata is not None: + lookup = build_barcode_lookup(transcriptome_anndata) + df = join_barcodes(df, lookup) + n_matched = df["in_anndata"].sum() + logger.info(f"Barcode match: {n_matched}/{len(df)} cells ({100 * n_matched / len(df):.1f}%)") + else: + logger.info("No transcriptome_anndata provided; skipping barcode join") + + # --- Pixel size rescaling --- + # raw_crop covers the same physical area as patch_size 
at reference resolution. + # Larger source pixels → fewer pixels needed. + raw_half = round(patch_size * reference_pixel_size / source_pixel_size) // 2 + raw_crop_size = 2 * raw_half + logger.info(f"Raw crop: {raw_crop_size}x{raw_crop_size} -> model input: {patch_size}x{patch_size}") + + # --- Resolve channel --- + store = zarr.open(zarr_store, mode="r") + first_zarr_path = df["zarr_path"].iloc[0] + channel_idx = resolve_channel_index(store, first_zarr_path, channel_name) + logger.info(f"Channel '{channel_name}' -> index {channel_idx}") + + short_name = CHANNEL_SHORT_NAMES.get(channel_name, channel_name) + output_key = output_key or f"dynaclr_{short_name}" + + # --- Load model --- + device = torch.device(device_str if torch.cuda.is_available() else "cpu") + encoder_config["stem_kernel_size"] = tuple(encoder_config["stem_kernel_size"]) + encoder_config["stem_stride"] = tuple(encoder_config["stem_stride"]) + encoder = ContrastiveEncoder(**encoder_config) + ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) + sd = {k.replace("model.", "", 1): v for k, v in ckpt["state_dict"].items() if k.startswith("model.")} + encoder.load_state_dict(sd) + encoder = encoder.to(device) + encoder.eval() + logger.info(f"Loaded DynaCLR encoder from {ckpt_path} on {device}") + + # --- Inference --- + df = df.sort_values("zarr_path").reset_index(drop=True) + current_fov_path: str | None = None + current_fov: np.ndarray | None = None + all_embeddings = [] + valid_indices = [] + skipped_border = 0 + + n_batches = math.ceil(len(df) / batch_size) + pbar = tqdm(range(0, len(df), batch_size), total=n_batches, desc="Embedding", unit="batch") + for batch_start in pbar: + batch_df = df.iloc[batch_start : batch_start + batch_size] + patches = [] + batch_valid = [] + + for idx, row in batch_df.iterrows(): + zarr_path = row["zarr_path"] + cy, cx = int(row["object_y_fov"]), int(row["object_x_fov"]) + + if zarr_path != current_fov_path: + current_fov = store[zarr_path]["0"][0, :, 0] + 
current_fov_path = zarr_path + + patch = crop_cell(current_fov, cy, cx, raw_half, channel_idx=channel_idx) + if patch is None: + skipped_border += 1 + continue + + patches.append(patch) + batch_valid.append(idx) + + if not patches: + continue + + batch_tensor = torch.from_numpy(np.stack(patches)).float() + if raw_crop_size != patch_size: + batch_tensor = F.interpolate( + batch_tensor, + size=(patch_size, patch_size), + mode="bilinear", + align_corners=False, + ) + + # Per-sample z-score: zero mean, unit std + mean = batch_tensor.flatten(1).mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1) + std = batch_tensor.flatten(1).std(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1).clamp(min=1e-8) + batch_tensor = (batch_tensor - mean) / std + + batch_tensor = batch_tensor.unsqueeze(2).to(device) + + with torch.inference_mode(): + embedding, _ = encoder(batch_tensor) + + all_embeddings.append(embedding.cpu().numpy()) + valid_indices.extend(batch_valid) + pbar.set_postfix(cells=len(valid_indices), skipped=skipped_border) + + if skipped_border > 0: + logger.warning(f"Skipped {skipped_border} cells too close to FOV border") + + # --- Write cell-level anndata --- + embeddings = np.concatenate(all_embeddings, axis=0) + valid_df = df.iloc[valid_indices].reset_index(drop=True) + logger.info(f"Embeddings: {embeddings.shape}") + + pd.options.future.infer_string = False + obs = valid_df.copy() + for col in obs.select_dtypes(include=["string", "string[pyarrow]"]).columns: + obs[col] = obs[col].astype(object) + obs.index = obs["object_uuid"].astype(str) + + cell_adata = ad.AnnData(X=embeddings.astype(np.float32), obs=obs) + cell_adata.write_zarr(output_path) + logger.info(f"Wrote {output_path}: {cell_adata.n_obs} cells x {cell_adata.n_vars} dims") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/dataloader_inspection/benchmark_setup_time.py b/applications/dynaclr/scripts/dataloader_inspection/benchmark_setup_time.py new file mode 100644 index 
000000000..7668b4aa9 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/benchmark_setup_time.py @@ -0,0 +1,117 @@ +"""Benchmark MultiExperimentDataModule setup time. + +Measures the time for _compute_valid_anchors and _build_match_lookup +on the DynaCLR-2D-MIP-BagOfChannels parquet (3.3M rows) to quantify +the speedup from the vectorized implementations. + +Usage +----- + uv run python applications/dynaclr/scripts/dataloader_inspection/benchmark_setup_time.py +""" + +from __future__ import annotations + +import time + +CELL_INDEX_PARQUET = "applications/dynaclr/configs/cell_index/DynaCLR-2D-MIP-BagOfChannels.parquet" +TAU_RANGE = (0.5, 2.0) +YX_PATCH_SIZE = (256, 256) + + +def _fmt(seconds: float) -> str: + if seconds < 1: + return f"{seconds * 1000:.1f} ms" + if seconds < 60: + return f"{seconds:.2f} s" + return f"{seconds / 60:.1f} min" + + +def main() -> None: + """Run the MultiExperimentDataModule setup benchmark and print a timing summary.""" + from dynaclr.data.experiment import ExperimentRegistry + from dynaclr.data.index import MultiExperimentIndex + from viscy_data.cell_index import read_cell_index + + print("=" * 60) + print("MultiExperimentDataModule setup benchmark") + print(f"Parquet: {CELL_INDEX_PARQUET}") + print("=" * 60) + + # ---------------------------------------------------------------- + # Parquet read (shared cost) + # ---------------------------------------------------------------- + t0 = time.perf_counter() + df = read_cell_index(CELL_INDEX_PARQUET) + parquet_time = time.perf_counter() - t0 + print(f"\nParquet read: {_fmt(parquet_time)} ({len(df):,} rows)") + + # ---------------------------------------------------------------- + # Registry build (shared cost) + # ---------------------------------------------------------------- + t0 = time.perf_counter() + registry, _ = ExperimentRegistry.from_cell_index( + CELL_INDEX_PARQUET, + z_window=1, + z_extraction_window=20, + z_focus_offset=0.3, + focus_channel="Phase3D", + 
reference_pixel_size_xy_um=0.1494, + ) + registry_time = time.perf_counter() - t0 + print(f"Registry build: {_fmt(registry_time)} ({len(registry.experiments)} experiments)") + + # ---------------------------------------------------------------- + # MultiExperimentIndex (includes _compute_valid_anchors) + # ---------------------------------------------------------------- + print("\n--- MultiExperimentIndex (cell_index_df path) ---") + t0 = time.perf_counter() + index = MultiExperimentIndex( + registry=registry, + yx_patch_size=YX_PATCH_SIZE, + tau_range_hours=TAU_RANGE, + cell_index_df=df, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + index_time = time.perf_counter() - t0 + print(f" Total: {_fmt(index_time)}") + print(f" Tracks: {len(index.tracks):,} Valid anchors: {len(index.valid_anchors):,}") + + # ---------------------------------------------------------------- + # _build_match_lookup (MultiExperimentTripletDataset init) + # ---------------------------------------------------------------- + print("\n--- _build_match_lookup (dataset init) ---") + from dynaclr.data.dataset import MultiExperimentTripletDataset + + t0 = time.perf_counter() + MultiExperimentTripletDataset( + index=index, + fit=True, + tau_range_hours=TAU_RANGE, + cache_pool_bytes=0, + channels_per_sample=1, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + dataset_time = time.perf_counter() - t0 + print(f" _build_match_lookup: {_fmt(dataset_time)}") + + # ---------------------------------------------------------------- + # Summary + # ---------------------------------------------------------------- + total = parquet_time + registry_time + index_time + dataset_time + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + print("| Step | Time |") + print("|-------------------------|----------------|") + print(f"| Parquet read | {_fmt(parquet_time):>14} |") + print(f"| Registry build | {_fmt(registry_time):>14} |") + print(f"| Index 
(_valid_anchors) | {_fmt(index_time):>14} |") + print(f"| Dataset (_match_lookup) | {_fmt(dataset_time):>14} |") + print("|-------------------------|----------------|") + print(f"| **Total** | {_fmt(total):>14} |") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py b/applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py index af6a38a8b..766c0551d 100644 --- a/applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py +++ b/applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py @@ -13,7 +13,7 @@ Usage:: - python applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py + uv run python applications/dynaclr/scripts/dataloader_inspection/check_batch_composition.py """ # ruff: noqa: E402, D103 @@ -45,7 +45,7 @@ COLLECTION_PATH = "/home/eduardo.hirata/repos/viscy/applications/dynaclr/configs/collections/example_cell_index.yaml" Z_WINDOW = 1 -YX_PATCH_SIZE = (256, 256) +YX_PATCH_SIZE = (192, 192) FINAL_YX_PATCH_SIZE = (160, 160) BATCH_SIZE = 8 NUM_WORKERS = 4 @@ -164,7 +164,7 @@ def run_scenario( bi, name, checks=checks, - save_path=OUTPUT_DIR / f"{name.lower().replace(' ', '_')}_batch{bi}.png" if OUTPUT_DIR else None, + save_path=(OUTPUT_DIR / f"{name.lower().replace(' ', '_')}_batch{bi}.png" if OUTPUT_DIR else None), ) return batches @@ -183,7 +183,6 @@ def run_scenario( print("Building DataModule...") dm = MultiExperimentDataModule( - collection_path=COLLECTION_PATH, cell_index_path=CELL_INDEX_PATH, z_window=Z_WINDOW, yx_patch_size=YX_PATCH_SIZE, @@ -367,7 +366,6 @@ def run_scenario( # %% dm_simclr = MultiExperimentDataModule( - collection_path=COLLECTION_PATH, cell_index_path=CELL_INDEX_PATH, z_window=Z_WINDOW, yx_patch_size=YX_PATCH_SIZE, @@ -434,7 +432,6 @@ def run_scenario( def run_normalization_scenario(name: str, level: str) -> None: dm_n = MultiExperimentDataModule( - 
collection_path=COLLECTION_PATH, cell_index_path=CELL_INDEX_PATH, z_window=Z_WINDOW, yx_patch_size=YX_PATCH_SIZE, diff --git a/applications/dynaclr/scripts/dataloader_inspection/check_marker_experiment_stratify.py b/applications/dynaclr/scripts/dataloader_inspection/check_marker_experiment_stratify.py new file mode 100644 index 000000000..1659909c5 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/check_marker_experiment_stratify.py @@ -0,0 +1,132 @@ +"""Verify FlexibleBatchSampler composition for batch_group_by=marker + stratify_by=experiment. + +Loads the production cell index, configures a sampler that mirrors the +proposed DynaCLR-2D-MIP single-marker override (marker batches stratified +by experiment), draws a handful of batches, and prints a marker x experiment +cross-tab per batch. + +Run before committing a sampler config change to confirm batches compose +the way the config promises. + +Usage +----- + uv run python applications/dynaclr/scripts/dataloader_inspection/check_marker_experiment_stratify.py +""" + +from __future__ import annotations + +import sys +from collections import Counter +from pathlib import Path + +import pandas as pd + +CELL_INDEX_PARQUET = ( + "/hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-MIP-BagOfChannels-v3.parquet" +) +BATCH_SIZE = 256 +N_BATCHES_TO_SHOW = 16 +SEED = 42 + + +def _config(label: str, batch_group_by, stratify_by, group_weights=None) -> dict: + return { + "label": label, + "batch_group_by": batch_group_by, + "stratify_by": stratify_by, + "group_weights": group_weights, + } + + +# Uniform weights matching the v3 single-marker override (9 markers after BF +# and Retardance are dropped from the v3 collection). 
+UNIFORM_WEIGHTS = { + "Phase3D": 1.0, + "pAL10": 1.0, + "viral_sensor": 1.0, + "G3BP1": 1.0, + "SEC61B": 1.0, + "TOMM20": 1.0, + "CAAX": 1.0, + "HIST2H2BE": 1.0, + "DIC": 1.0, +} + +CONFIGS = [ + _config("current (stratify_by=null)", batch_group_by="marker", stratify_by=None), + _config("proposed (stratify_by=experiment)", batch_group_by="marker", stratify_by="experiment"), + _config( + "proposed + uniform group_weights", + batch_group_by="marker", + stratify_by="experiment", + group_weights=UNIFORM_WEIGHTS, + ), +] + + +def main() -> None: + from viscy_data.sampler import FlexibleBatchSampler + + print(f"Loading parquet: {CELL_INDEX_PARQUET}") + df = pd.read_parquet(CELL_INDEX_PARQUET) + print(f" rows={len(df):,} unique markers={df['marker'].nunique()} unique experiments={df['experiment'].nunique()}") + print() + + # FlexibleBatchSampler expects valid_anchors with the relevant columns; + # for sampler-composition QC we don't need _real_ anchor validity, just + # representative rows. Use the full parquet directly. + valid_anchors = df + + for cfg in CONFIGS: + print("=" * 80) + print( + f"## {cfg['label']}: batch_group_by={cfg['batch_group_by']!r}, " + f"stratify_by={cfg['stratify_by']!r}, " + f"group_weights={'set' if cfg.get('group_weights') else 'None'}" + ) + print("=" * 80) + + sampler = FlexibleBatchSampler( + valid_anchors=valid_anchors, + batch_size=BATCH_SIZE, + batch_group_by=cfg["batch_group_by"], + stratify_by=cfg["stratify_by"], + group_weights=cfg.get("group_weights"), + leaky=0.0, + seed=SEED, + ) + + # Collect first N batches. 
+ marker_counts: Counter = Counter() + for i, batch_indices in enumerate(sampler): + if i >= N_BATCHES_TO_SHOW: + break + batch_rows = valid_anchors.iloc[batch_indices] + markers = batch_rows["marker"].unique() + experiments = batch_rows["experiment"].value_counts() + primary_marker = batch_rows["marker"].mode().iloc[0] + marker_counts[primary_marker] += 1 + + print( + f"batch {i:>2}: marker={primary_marker!s:<14} " + f"unique_markers={len(markers)} " + f"unique_experiments={len(experiments)}" + ) + # If marker integrity holds, len(markers) should be 1. + if len(markers) > 1: + print(f" WARN: batch contains MULTIPLE markers: {sorted(markers)}") + # Show top 3 experiments in the batch + for exp_name, count in experiments.head(3).items(): + print(f" {exp_name:<60s} {count:>4d}") + if len(experiments) > 3: + print(f" ... +{len(experiments) - 3} more experiments") + + print() + print(f"Marker selection across {N_BATCHES_TO_SHOW} batches:") + for m, n in marker_counts.most_common(): + print(f" {m:<20s} {n}") + print() + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/dataloader_inspection/data_patch_resizing.py b/applications/dynaclr/scripts/dataloader_inspection/data_patch_resizing.py index 0a2816438..e21015fab 100644 --- a/applications/dynaclr/scripts/dataloader_inspection/data_patch_resizing.py +++ b/applications/dynaclr/scripts/dataloader_inspection/data_patch_resizing.py @@ -1,29 +1,32 @@ """End-to-end proof that DynaCLR pixel-size normalization works. -Creates a temporary parquet with modified pixel sizes, feeds it through the -real ``MultiExperimentDataModule`` dataloader, and plots the output patches. +Builds the datamodule once to get sample metadata (cell coordinates), +then reads native zarr crops at different pixel-size-derived scales +and rescales them to show how the pipeline normalizes physical extent. -The Mantis experiment (0.1494 um/px) is the reference. 
The Dragonfly experiment -natively has 0.206 um/px — we test with both the real value and an artificial -override to show the dataloader responds correctly. +Row 0: Raw FOV with bounding boxes for each pixel-size variant. +Row 1: Native zarr crop → _rescale_patch → center crop = model input (160×160). Usage:: uv run python applications/dynaclr/scripts/dataloader_inspection/data_patch_resizing.py """ +# %% # ruff: noqa: D103 from __future__ import annotations -import tempfile from pathlib import Path +import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np -import pandas as pd +import torch +from iohub.ngff.nodes import open_ome_zarr from dynaclr.data.datamodule import MultiExperimentDataModule +from dynaclr.data.dataset import _rescale_patch from viscy_transforms._crop import BatchedCenterSpatialCrop # --------------------------------------------------------------------------- @@ -32,7 +35,7 @@ _ROOT = Path(__file__).resolve().parents[4] -CELL_INDEX_PATH = _ROOT / "applications/dynaclr/configs/cell_index/dragonfly_mantis_demo.parquet" +CELL_INDEX_PATH = _ROOT / "applications/dynaclr/configs/cell_index/example_mantis_dragonfly.parquet" OUTPUT_DIR = _ROOT / "applications/dynaclr/scripts/dataloader_inspection/output" OUTPUT_PATH = OUTPUT_DIR / "data_patch_resizing.png" @@ -40,116 +43,194 @@ YX_PATCH_SIZE = (200, 200) FINAL_YX_PATCH_SIZE = (160, 160) REFERENCE_PIXEL_SIZE_XY_UM = 0.1494 -REFERENCE_PIXEL_SIZE_Z_UM = 0.2878 CHANNEL_NAME = "Phase3D" DRAGONFLY_EXP = "2024_08_14_ZIKV_pal17_48h" -MANTIS_EXP = "2025_07_24_A549_SEC61B_ZIKV" +MANTIS_EXP = "2025_07_24_A549_SEC61_ZIKV" -# Pixel sizes to test for Dragonfly (real + artificial overrides) +# Pixel sizes to visualize for Dragonfly DRAGONFLY_PIXEL_SIZES = { "real (0.206)": 0.206, - "override (0.1494)": 0.1494, # same as reference — should be no-op - "override (0.7)": 0.7, # even coarser — should crop fewer pixels + "same as ref (0.1494)": 0.1494, + "coarser (0.7)": 0.7, } +BBOX_COLORS = 
["#e74c3c", "#2ecc71", "#3498db"] +INCLUDE_WELLS = ["A/2", "0/4"] # --------------------------------------------------------------------------- -# Helpers +# Step 1: Build datamodule once to get sample metadata # --------------------------------------------------------------------------- +print("Building datamodule...") +dm = MultiExperimentDataModule( + cell_index_path=str(CELL_INDEX_PATH), + z_window=Z_WINDOW, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + batch_size=8, + num_workers=0, + channels_per_sample=[CHANNEL_NAME], + reference_pixel_size_xy_um=REFERENCE_PIXEL_SIZE_XY_UM, + reference_pixel_size_z_um=None, + positive_cell_source="self", + tau_range=(0.0, 100.0), + stratify_by=None, + include_wells=INCLUDE_WELLS, +) +dm.setup("fit") + +registry = dm.train_dataset.index.registry + +print("Drawing samples for metadata...") +loader = dm.train_dataloader() +per_exp: dict[str, dict] = {} +needed = {e.name for e in registry.experiments} + +MAX_BATCHES = 200 +for batch_idx, batch in enumerate(loader): + anchor = batch["anchor"] + meta = batch["anchor_meta"] + for i in range(len(meta)): + exp_name = meta[i]["experiment"] + if exp_name not in per_exp: + per_exp[exp_name] = {"meta": meta[i], "patch": anchor[i]} + if per_exp.keys() >= needed: + break + if batch_idx >= MAX_BATCHES: + print(f" WARNING: only found experiments {set(per_exp.keys())} after {MAX_BATCHES} batches") + break + +for exp_name, d in per_exp.items(): + m = d["meta"] + print(f" {exp_name}: fov={m['fov_name']}, t={m['t']}, y={m['y_clamp']}, x={m['x_clamp']}") -def make_tmp_parquet(pixel_size_xy: float, pixel_size_z: float = REFERENCE_PIXEL_SIZE_Z_UM) -> str: - """Write a temp parquet with Dragonfly pixel sizes overridden.""" - df = pd.read_parquet(CELL_INDEX_PATH) - mask = df["experiment"] == DRAGONFLY_EXP - df.loc[mask, "pixel_size_xy_um"] = pixel_size_xy - df.loc[mask, "pixel_size_z_um"] = pixel_size_z - tmp = tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) 
- df.to_parquet(tmp.name) - return tmp.name - - -def draw_one_sample(parquet_path: str) -> dict: - """Build a datamodule, draw one batch, return first anchor patch + metadata.""" - dm = MultiExperimentDataModule( - collection_path=None, - cell_index_path=parquet_path, - z_window=Z_WINDOW, - yx_patch_size=YX_PATCH_SIZE, - final_yx_patch_size=FINAL_YX_PATCH_SIZE, - batch_size=8, - num_workers=0, - channels_per_sample=[CHANNEL_NAME], - reference_pixel_size_xy_um=REFERENCE_PIXEL_SIZE_XY_UM, - reference_pixel_size_z_um=REFERENCE_PIXEL_SIZE_Z_UM, - positive_cell_source="self", - tau_range=(0.0, 100.0), - stratify_by=None, - ) - dm.setup("fit") - - registry = dm.train_dataset.index.registry - scale_factors = {e.name: registry.scale_factors[e.name] for e in registry.experiments} - - # Draw batches until we get one from each experiment - loader = dm.train_dataloader() - per_exp: dict[str, dict] = {} - needed = {e.name for e in registry.experiments} - for batch in loader: - anchor = batch["anchor"] - meta = batch["anchor_meta"] - for i in range(anchor.shape[0]): - exp_name = meta[i]["experiment"] - if exp_name not in per_exp: - per_exp[exp_name] = { - "patch": anchor[i], - "meta": meta[i], - "scale": scale_factors[exp_name], - } - if per_exp.keys() >= needed: - break +# --------------------------------------------------------------------------- +# Step 2: Read raw FOV slices and native crops from zarr +# --------------------------------------------------------------------------- - return per_exp +def read_fov_and_crop( + meta: dict, + pixel_size_xy: float, + z_focus: int, + channel_name: str = CHANNEL_NAME, +) -> tuple[np.ndarray, np.ndarray, int, int]: + """Read the focus Z-slice FOV and a native crop at the given pixel size. + + Returns + ------- + fov : np.ndarray + Full FOV 2D image at the focus Z-slice. + crop : np.ndarray + Native crop at the scale implied by pixel_size_xy. + y_half, x_half : int + Half-widths of the native crop in pixels. 
+ """ + store_path = meta["store_path"] + fov_name = meta["fov_name"] + t = int(meta["t"]) + y_center = int(meta["y_clamp"]) + x_center = int(meta["x_clamp"]) + + scale_yx = REFERENCE_PIXEL_SIZE_XY_UM / pixel_size_xy + y_half = round((YX_PATCH_SIZE[0] // 2) * scale_yx) + x_half = round((YX_PATCH_SIZE[1] // 2) * scale_yx) + + fov_path = f"{store_path}/{fov_name}" + with open_ome_zarr(fov_path, mode="r") as pos: + ch_idx = list(pos.channel_names).index(channel_name) + _, _, _, img_h, img_w = pos.data.shape + + fov = pos.data.oindex[t, ch_idx, z_focus, :, :] + + y0 = max(0, y_center - y_half) + y1 = min(img_h, y_center + y_half) + x0 = max(0, x_center - x_half) + x1 = min(img_w, x_center + x_half) + crop = pos.data.oindex[t, ch_idx, z_focus, y0:y1, x0:x1] + + return fov, crop, y_half, x_half -# --------------------------------------------------------------------------- -# Run the dataloader for each Dragonfly pixel size configuration -# --------------------------------------------------------------------------- center_crop = BatchedCenterSpatialCrop(roi_size=(Z_WINDOW, FINAL_YX_PATCH_SIZE[0], FINAL_YX_PATCH_SIZE[1])) -all_results = {} -for label, px_size in DRAGONFLY_PIXEL_SIZES.items(): - print(f"\n--- Dragonfly pixel_size_xy_um = {px_size} ({label}) ---") - tmp_path = make_tmp_parquet(px_size) - per_exp = draw_one_sample(tmp_path) - - for exp_name, data in per_exp.items(): - scale = data["scale"] - patch = data["patch"] # (C, Z, Y, X) at yx_patch_size - final = center_crop(patch[None])[0] - key = f"{exp_name}\n{label}" if exp_name == DRAGONFLY_EXP else exp_name - if exp_name == MANTIS_EXP and label != "real (0.206)": - continue # Mantis is unchanged, only show once - print(f" {exp_name}: scale_yx={scale[1]:.3f}, patch={tuple(patch.shape)}") - all_results[key] = { - "patch_2d": patch[0, 0].numpy(), +z_focuses = {} +for e in registry.experiments: + zr = registry.z_ranges[e.name] + z_focuses[e.name] = (zr[0] + zr[1]) // 2 + print(f" {e.name}: z_range={zr}, 
z_focus={z_focuses[e.name]}") + +print("Reading zarr crops...") + +results: list[dict] = [] + +# Mantis (reference — scale ≈ 1.0) +m_meta = per_exp[MANTIS_EXP]["meta"] +m_fov, m_crop, m_yh, m_xh = read_fov_and_crop(m_meta, REFERENCE_PIXEL_SIZE_XY_UM, z_focuses[MANTIS_EXP]) +m_tensor = torch.from_numpy(m_crop).float().unsqueeze(0).unsqueeze(0) # (1, 1, H, W) +m_rescaled = _rescale_patch(m_tensor, (1.0, 1.0, 1.0), (Z_WINDOW, YX_PATCH_SIZE[0], YX_PATCH_SIZE[1])) +m_final = center_crop(m_rescaled[None])[0] +m_dl_patch = per_exp[MANTIS_EXP]["patch"] +m_dl_final = center_crop(m_dl_patch[None])[0] +results.append( + { + "label": f"{MANTIS_EXP}\nreference ({REFERENCE_PIXEL_SIZE_XY_UM} µm/px)", + "exp": MANTIS_EXP, + "fov": m_fov, + "native_crop": m_crop, + "final_2d": m_final[0, 0].numpy(), + "dl_final_2d": m_dl_final[0, 0].numpy(), + "scale_yx": 1.0, + "pixel_size": REFERENCE_PIXEL_SIZE_XY_UM, + "y_half": m_yh, + "x_half": m_xh, + "meta": m_meta, + } +) + +# Dragonfly — one entry per pixel-size variant +d_meta = per_exp[DRAGONFLY_EXP]["meta"] +d_dl_patch = per_exp[DRAGONFLY_EXP]["patch"] +d_dl_final = center_crop(d_dl_patch[None])[0] +d_fov = None + +for i, (label, px_size) in enumerate(DRAGONFLY_PIXEL_SIZES.items()): + fov, crop, y_half, x_half = read_fov_and_crop(d_meta, px_size, z_focuses[DRAGONFLY_EXP]) + if d_fov is None: + d_fov = fov + + scale_yx = REFERENCE_PIXEL_SIZE_XY_UM / px_size + scale = (1.0, scale_yx, scale_yx) + target = (Z_WINDOW, YX_PATCH_SIZE[0], YX_PATCH_SIZE[1]) + + crop_tensor = torch.from_numpy(crop).float().unsqueeze(0).unsqueeze(0) + rescaled = _rescale_patch(crop_tensor, scale, target) + final = center_crop(rescaled[None])[0] + + print(f" {label}: scale_yx={scale_yx:.3f}, native_crop={crop.shape}, rescaled={tuple(rescaled.shape)}") + + results.append( + { + "label": f"{DRAGONFLY_EXP}\n{label}", + "exp": DRAGONFLY_EXP, + "fov": d_fov, + "native_crop": crop, "final_2d": final[0, 0].numpy(), - "scale": scale, - "pixel_size_label": label if exp_name 
== DRAGONFLY_EXP else "reference", + "dl_final_2d": d_dl_final[0, 0].numpy(), + "scale_yx": scale_yx, + "pixel_size": px_size, + "y_half": y_half, + "x_half": x_half, + "meta": d_meta, } + ) # --------------------------------------------------------------------------- -# Plot +# Step 3: Plot # --------------------------------------------------------------------------- -n = len(all_results) -fig, axes = plt.subplots(2, n, figsize=(5 * n, 10)) -if n == 1: - axes = axes[:, None] - def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): bar_px = bar_um / pixel_size_um @@ -165,7 +246,7 @@ def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): ax.text( x0 + bar_px / 2, y - 8, - f"{bar_um:.0f} um", + f"{bar_um:.0f} µm", color="white", fontsize=9, ha="center", @@ -173,30 +254,73 @@ def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): ) -for col, (key, r) in enumerate(all_results.items()): - scale = r["scale"] +def add_bbox(ax, y_center, x_center, y_half, x_half, color, label, img_shape): + y0 = max(0, y_center - y_half) + x0 = max(0, x_center - x_half) + h = min(y_center + y_half, img_shape[0]) - y0 + w = min(x_center + x_half, img_shape[1]) - x0 + rect = mpatches.Rectangle( + (x0, y0), + w, + h, + linewidth=2, + edgecolor=color, + facecolor="none", + linestyle="-", + label=label, + ) + ax.add_patch(rect) + + +n = len(results) +fig, axes = plt.subplots(3, n, figsize=(5 * n, 14)) +if n == 1: + axes = axes[:, None] + +for col, r in enumerate(results): + meta = r["meta"] + exp_name = r["exp"] + y_center = int(meta["y_clamp"]) + x_center = int(meta["x_clamp"]) - # Row 0: Dataloader output (yx_patch_size, after _rescale_patch) + # Row 0: Raw FOV with bounding box ax = axes[0, col] - patch = r["patch_2d"] - vmin, vmax = np.percentile(patch, (1, 99)) - ax.imshow(patch, cmap="gray", vmin=vmin, vmax=vmax) - add_scalebar(ax, REFERENCE_PIXEL_SIZE_XY_UM, YX_PATCH_SIZE[0]) - ax.set_title( - f"{key}\nscale_yx=({scale[1]:.3f}, {scale[2]:.3f})\nDataloader: 
{YX_PATCH_SIZE[0]}x{YX_PATCH_SIZE[1]} px", - fontsize=9, - fontweight="bold", - ) + fov = r["fov"] + vmin_raw, vmax_raw = np.percentile(fov, (1, 99)) + ax.imshow(fov, cmap="gray", vmin=vmin_raw, vmax=vmax_raw) + + if exp_name == DRAGONFLY_EXP: + for i, (lbl, px_size) in enumerate(DRAGONFLY_PIXEL_SIZES.items()): + s = REFERENCE_PIXEL_SIZE_XY_UM / px_size + yh = round((YX_PATCH_SIZE[0] // 2) * s) + xh = round((YX_PATCH_SIZE[1] // 2) * s) + add_bbox(ax, y_center, x_center, yh, xh, BBOX_COLORS[i], lbl, fov.shape) + ax.legend(loc="upper left", fontsize=7, framealpha=0.7) + else: + add_bbox( + ax, + y_center, + x_center, + r["y_half"], + r["x_half"], + BBOX_COLORS[0], + "reference", + fov.shape, + ) + + ax.set_title(f"{r['label']}\nRaw FOV (mid-Z)", fontsize=9, fontweight="bold") ax.axis("off") - # Row 1: After center crop = MODEL INPUT + # Row 1: Model input (native crop → rescale → center crop) ax = axes[1, col] final = r["final_2d"] + vmin, vmax = np.percentile(final, (1, 99)) ax.imshow(final, cmap="gray", vmin=vmin, vmax=vmax) add_scalebar(ax, REFERENCE_PIXEL_SIZE_XY_UM, FINAL_YX_PATCH_SIZE[0]) phys = FINAL_YX_PATCH_SIZE[0] * REFERENCE_PIXEL_SIZE_XY_UM ax.set_title( - f"Model input: {FINAL_YX_PATCH_SIZE[0]}x{FINAL_YX_PATCH_SIZE[1]} px | {phys:.1f} um", + f"Model input: {FINAL_YX_PATCH_SIZE[0]}×{FINAL_YX_PATCH_SIZE[1]} px | {phys:.1f} µm\n" + f"native crop: {r['native_crop'].shape} → scale_yx={r['scale_yx']:.3f}", fontsize=9, ) ax.axis("off") @@ -205,9 +329,27 @@ def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): spine.set_edgecolor("#2ecc71") spine.set_linewidth(3) + # Row 2: Actual dataloader output (for comparison with "real" variant) + ax = axes[2, col] + dl_final = r["dl_final_2d"] + vmin_dl, vmax_dl = np.percentile(dl_final, (1, 99)) + ax.imshow(dl_final, cmap="gray", vmin=vmin_dl, vmax=vmax_dl) + add_scalebar(ax, REFERENCE_PIXEL_SIZE_XY_UM, FINAL_YX_PATCH_SIZE[0]) + ax.set_title( + f"Dataloader output: {FINAL_YX_PATCH_SIZE[0]}×{FINAL_YX_PATCH_SIZE[1]} 
px\n" + f"(same for all variants — real pixel size)", + fontsize=9, + ) + ax.axis("off") + for spine in ax.spines.values(): + spine.set_visible(True) + spine.set_edgecolor("#e67e22") + spine.set_linewidth(3) + row_labels = [ - "Dataloader output\n(after _rescale_patch)", - "Model input\n(after center crop)", + "Raw FOV + crop region", + "Expected\n(native crop → rescale → crop)", + "Dataloader output\n(real pixel size)", ] for row_idx, label in enumerate(row_labels): axes[row_idx, 0].annotate( @@ -222,8 +364,9 @@ def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): ) fig.suptitle( - f"Pixel-size normalization proof: reference={REFERENCE_PIXEL_SIZE_XY_UM} um/px\n" - f"Same Dragonfly data with different declared pixel sizes -> different scale factors", + f"Pixel-size normalization: reference={REFERENCE_PIXEL_SIZE_XY_UM} µm/px\n" + f"Different pixel sizes → different native crops" + f" → same {FINAL_YX_PATCH_SIZE[0]}×{FINAL_YX_PATCH_SIZE[1]} model input", fontsize=12, fontweight="bold", y=0.99, @@ -233,3 +376,5 @@ def add_scalebar(ax, pixel_size_um, patch_px, bar_um=5.0): fig.savefig(OUTPUT_PATH, dpi=150, bbox_inches="tight") print(f"\nSaved: {OUTPUT_PATH}") plt.close(fig) + +# %% diff --git a/applications/dynaclr/scripts/dataloader_inspection/dataloader_demo.py b/applications/dynaclr/scripts/dataloader_inspection/dataloader_demo.py new file mode 100644 index 000000000..3744014d6 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/dataloader_demo.py @@ -0,0 +1,443 @@ +"""Dataloader demo: visualize raw, normalized, and augmented batches. + +Jupyter-style notebook (use ``# %%`` cells in VS Code or JupyterLab). + +Shows what the DynaCLR model actually receives as input. For each batch: + +- **Row 0 (anchor raw)**: raw patches from zarr (no transforms). +- **Row 1 (anchor aug)**: after normalization + augmentation + crop + (exactly what the model sees during training). +- **Row 2 (positive raw)**: positive pair raw patches. 
+- **Row 3 (positive aug)**: positive after transforms. + +Each column annotation shows experiment, marker, perturbation, timepoint, +and lineage/temporal checks. Batch composition is summarized in the title. + +Usage:: + + uv run python applications/dynaclr/scripts/dataloader_inspection/dataloader_demo.py +""" + +# ruff: noqa: E402, D103 + +# %% [markdown] +# # DynaCLR Dataloader Demo +# +# Visualize anchor/positive pairs with normalization and augmentation. +# All parameters are inline — edit and re-run cells. +# +# ## Augmentation pipeline +# +# The augmentation order matters. The pipeline is: +# +# 1. **Normalize** on full extraction patch ``(45, 256, 256)`` +# 2. **Affine** (rotate/scale/shear) on ``(45, 256, 256)`` +# 3. **RandSpatialCrop** to ``(40, 228, 228)`` — random Z for focus +# invariance + random YX for translation augmentation +# 4. **Flip, contrast, scale, smooth, noise** on ``(40, 228, 228)`` +# 5. **CenterCrop** to ``(32, 160, 160)`` — auto-appended by datamodule, +# removes rotation zero-fill artifacts at the edges + +# %% +from __future__ import annotations + +import copy +from collections import Counter +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +from dynaclr.data.datamodule import MultiExperimentDataModule +from viscy_transforms import ( + BatchedRandAdjustContrastd, + BatchedRandAffined, + BatchedRandFlipd, + BatchedRandGaussianNoised, + BatchedRandGaussianSmoothd, + BatchedRandScaleIntensityd, + BatchedRandSpatialCropd, + NormalizeSampled, +) + +# %% [markdown] +# ## Configuration +# +# Everything is inline — edit and re-run. 
+ +# %% +# --- Data source --- +CELL_INDEX_PATH = "/hpc/projects/organelle_phenotyping/models/collections/DynaCLR-3D-BagOfChannels-v4.parquet" + +# --- Patch extraction --- +Z_WINDOW = 32 +Z_EXTRACTION_WINDOW = 45 +Z_FOCUS_OFFSET = 0.3 +YX_PATCH_SIZE = (256, 256) +FINAL_YX_PATCH_SIZE = (160, 160) + +# --- Channel mode --- +# 1 = bag-of-channels (one random channel per sample, key="channel_0") +# None = all channels; ["Phase3D", "GFP"] = fixed list +CHANNELS_PER_SAMPLE = 1 +CHANNEL_NAMES = ["channel_0"] + +# --- Positive pair sampling --- +POSITIVE_CELL_SOURCE = "lookup" +POSITIVE_MATCH_COLUMNS = ["lineage_id"] +TAU_RANGE = (0.5, 2.0) +TAU_DECAY_RATE = 2.0 + +# --- Batch sampling --- +BATCH_SIZE = 10 +BATCH_GROUP_BY = None +STRATIFY_BY = ["perturbation"] +SEED = 42 + +# --- Pixel size normalization --- +REFERENCE_PIXEL_SIZE_XY_UM = 0.1494 +REFERENCE_PIXEL_SIZE_Z_UM = 0.174 +FOCUS_CHANNEL = "Phase3D" + +# --- Normalization --- +NORMALIZATIONS = [ + NormalizeSampled( + keys=CHANNEL_NAMES, + level="timepoint_statistics", + subtrahend="mean", + divisor="std", + ), +] + +# --- Augmentations --- +# The RandSpatialCrop goes after the affine to trim rotation artifacts +# and provide random Z + XY translation. The datamodule auto-appends +# a CenterCrop to [Z_WINDOW, 160, 160] at the end. 
AUGMENTATIONS = [
    BatchedRandAffined(
        keys=CHANNEL_NAMES,
        prob=1,
        scale_range=[[0.9, 1.1], [0.9, 1.1], [0.9, 1.1]],
        rotate_range=[3.14, 0.0, 0.0],
        shear_range=[0.05, 0.05, 0.0, 0.05, 0.0, 0.05],
    ),
    BatchedRandSpatialCropd(
        keys=CHANNEL_NAMES,
        roi_size=[40, 228, 228],
    ),
    BatchedRandFlipd(keys=CHANNEL_NAMES, spatial_axes=[1, 2], prob=0.5),
    BatchedRandAdjustContrastd(keys=CHANNEL_NAMES, prob=0.5, gamma=(0.6, 1.6)),
    BatchedRandScaleIntensityd(keys=CHANNEL_NAMES, prob=0.5, factors=0.5),
    BatchedRandGaussianSmoothd(
        keys=CHANNEL_NAMES,
        prob=1,
        sigma_x=[0.25, 0.50],
        sigma_y=[0.25, 0.50],
        sigma_z=[0.0, 0.2],
    ),
    BatchedRandGaussianNoised(keys=CHANNEL_NAMES, prob=1, mean=0.0, std=0.1),
]

# --- Display ---
N_BATCHES = 4
N_SHOW = 10
NUM_WORKERS = 1
SHOW_AUGMENTED = True
OUTPUT_DIR = Path("applications/dynaclr/scripts/dataloader_inspection/results/dataloader_demo")


# %% [markdown]
# ## Helpers


# %%
def _img_2d(tensor_5d: np.ndarray, sample_idx: int) -> np.ndarray:
    """Extract a 2D slice from (B, C, Z, Y, X) for display.

    Takes channel 0 and the center Z-slice (or the only plane for
    already-2D data) of sample ``sample_idx``.
    """
    img = tensor_5d[sample_idx]
    if img.ndim == 4:
        # (C, Z, Y, X): channel 0, middle z-plane.
        img = img[0, img.shape[1] // 2]
    elif img.ndim == 3:
        # (C, Y, X): channel 0.
        img = img[0]
    return img


def plot_batch(
    raw_batch: dict,
    aug_batch: dict | None,
    batch_idx: int,
    n_show: int,
    show_augmented: bool = True,
    save_path: Path | None = None,
) -> None:
    """Plot one batch: raw and augmented anchor/positive pairs.

    Parameters
    ----------
    raw_batch : dict
        Batch straight from the dataloader; must contain ``anchor`` and
        ``anchor_meta``, optionally ``positive`` / ``positive_meta``.
    aug_batch : dict | None
        Same batch after ``on_after_batch_transfer`` (None to skip aug rows).
    batch_idx : int
        Index shown in the figure title.
    n_show : int
        Max number of samples (columns) to display.
    show_augmented : bool
        If False, only raw rows are drawn even when ``aug_batch`` is given.
    save_path : Path | None
        If given, the figure is saved (and closed) instead of shown.
    """
    anchor_raw = raw_batch["anchor"].numpy()
    positive_raw = raw_batch.get("positive")
    has_positive = positive_raw is not None
    if has_positive:
        positive_raw = positive_raw.numpy()

    anchor_meta = raw_batch["anchor_meta"]
    positive_meta = raw_batch.get("positive_meta", [{}] * len(anchor_meta))
    n = min(n_show, len(anchor_meta))

    # Row layout mirrors the plotting order below.
    row_labels = ["anchor (raw)"]
    if show_augmented and aug_batch is not None:
        row_labels.append("anchor (aug)")
    if has_positive:
        row_labels.append("positive (raw)")
        if show_augmented and aug_batch is not None:
            row_labels.append("positive (aug)")
    n_rows = len(row_labels)

    fig, axes = plt.subplots(n_rows, n, figsize=(n * 2.0, n_rows * 2.4), squeeze=False)

    # Summarize batch composition in the title.
    markers = Counter(m.get("marker", "?") for m in anchor_meta[:n])
    perts = Counter(m.get("perturbation", "?") for m in anchor_meta[:n])
    m_str = " ".join(f"{k}={v}" for k, v in markers.most_common(5))
    p_str = " ".join(f"{k}={v}" for k, v in perts.most_common(5))
    fig.suptitle(
        f"Batch {batch_idx} | markers: {m_str} | pert: {p_str}",
        fontsize=9,
        fontweight="bold",
    )

    anchor_aug = aug_batch["anchor"].numpy() if (show_augmented and aug_batch is not None) else None
    positive_aug = None
    if has_positive and show_augmented and aug_batch is not None:
        pa = aug_batch.get("positive")
        positive_aug = pa.numpy() if pa is not None else None

    for i in range(n):
        am = anchor_meta[i]
        pm = positive_meta[i] if i < len(positive_meta) else {}

        row = 0
        img = _img_2d(anchor_raw, i)
        # Percentile contrast stretch is robust to hot pixels.
        vmin, vmax = np.percentile(img, [1, 99])
        axes[row, i].imshow(img, cmap="gray", vmin=vmin, vmax=vmax)
        axes[row, i].set_xticks([])
        axes[row, i].set_yticks([])
        lines = [
            f"{am.get('experiment', '?')[:25]}",
            f"fov={am.get('fov_name', '?')}",
            f"track={am.get('global_track_id', '?')[-15:]}",
            f"marker={am.get('marker', '?')}",
            f"pert={am.get('perturbation', '?')}",
            f"t={am.get('t', '?')}",
        ]
        if has_positive:
            # Sanity checks: same lineage, different timepoint.
            lin_ok = am.get("lineage_id") == pm.get("lineage_id")
            dt_ok = am.get("t") != pm.get("t")
            lines.append(f"lineage={'✓' if lin_ok else '✗'} Δt={'✓' if dt_ok else '✗'}")
        axes[row, i].set_title("\n".join(lines), fontsize=5, linespacing=1.1)

        if anchor_aug is not None:
            row += 1
            img_a = _img_2d(anchor_aug, i)
            vmin_a, vmax_a = np.percentile(img_a, [1, 99])
            axes[row, i].imshow(img_a, cmap="gray", vmin=vmin_a, vmax=vmax_a)
            axes[row, i].set_xticks([])
            axes[row, i].set_yticks([])
            axes[row, i].set_title(f"μ={img_a.mean():.2f} σ={img_a.std():.2f}", fontsize=5)

        if has_positive:
            row += 1
            img_p = _img_2d(positive_raw, i)
            vmin_p, vmax_p = np.percentile(img_p, [1, 99])
            axes[row, i].imshow(img_p, cmap="gray", vmin=vmin_p, vmax=vmax_p)
            axes[row, i].set_xticks([])
            axes[row, i].set_yticks([])
            pos_lines = [
                f"fov={pm.get('fov_name', '?')}",
                f"track={pm.get('global_track_id', '?')[-15:]}",
                f"pert={pm.get('perturbation', '?')} t={pm.get('t', '?')}",
            ]
            axes[row, i].set_title("\n".join(pos_lines), fontsize=5, linespacing=1.1)

        if positive_aug is not None:
            row += 1
            img_pa = _img_2d(positive_aug, i)
            vmin_pa, vmax_pa = np.percentile(img_pa, [1, 99])
            axes[row, i].imshow(img_pa, cmap="gray", vmin=vmin_pa, vmax=vmax_pa)
            axes[row, i].set_xticks([])
            axes[row, i].set_yticks([])
            axes[row, i].set_title(f"μ={img_pa.mean():.2f} σ={img_pa.std():.2f}", fontsize=5)

    for r, label in enumerate(row_labels):
        axes[r, 0].set_ylabel(label, fontsize=7, fontweight="bold")

    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"  Saved: {save_path}")
        # Close after saving: this function is called once per batch, and
        # leaving every figure open accumulates memory (matplotlib warns
        # after 20 open figures). Interactive plt.show() keeps the figure.
        plt.close(fig)
    else:
        plt.show()


# %% [markdown]
# ## Build DataModule
#
# Passes normalizations + augmentations directly to the DataModule.
# ``on_after_batch_transfer`` applies: normalizations → augmentations
# (including RandSpatialCrop) → auto-appended CenterCrop to final size.

# %%
# All knobs come from the configuration cell above; collected into a dict
# so the full parameterization is visible (and greppable) in one place.
_dm_kwargs = dict(
    cell_index_path=CELL_INDEX_PATH,
    z_window=Z_WINDOW,
    z_extraction_window=Z_EXTRACTION_WINDOW,
    z_focus_offset=Z_FOCUS_OFFSET,
    yx_patch_size=YX_PATCH_SIZE,
    final_yx_patch_size=FINAL_YX_PATCH_SIZE,
    channels_per_sample=CHANNELS_PER_SAMPLE,
    positive_cell_source=POSITIVE_CELL_SOURCE,
    positive_match_columns=POSITIVE_MATCH_COLUMNS,
    tau_range=TAU_RANGE,
    tau_decay_rate=TAU_DECAY_RATE,
    batch_size=BATCH_SIZE,
    batch_group_by=BATCH_GROUP_BY,
    stratify_by=STRATIFY_BY,
    num_workers=NUM_WORKERS,
    seed=SEED,
    focus_channel=FOCUS_CHANNEL,
    reference_pixel_size_xy_um=REFERENCE_PIXEL_SIZE_XY_UM,
    reference_pixel_size_z_um=REFERENCE_PIXEL_SIZE_Z_UM,
    channel_dropout_prob=0.0,
    normalizations=NORMALIZATIONS,
    augmentations=AUGMENTATIONS,
)
dm = MultiExperimentDataModule(**_dm_kwargs)
dm.setup("fit")


# Fake a minimal trainer so on_after_batch_transfer can check .predicting
class _FakeTrainer:
    predicting = False
    training = True


dm.trainer = _FakeTrainer()
print("DataModule ready.\n")

# Summarize the anchor index: per-experiment counts plus marker and
# perturbation breakdowns (marker column may be absent in some indexes).
anchors_df = dm.train_dataset.index.valid_anchors
print(f"Anchors: {len(anchors_df):,} | Experiments: {anchors_df['experiment'].nunique()}")
for exp, g in anchors_df.groupby("experiment"):
    markers = g["marker"].value_counts().to_dict() if "marker" in g.columns else {}
    perts = g["perturbation"].value_counts().to_dict()
    print(f"  {exp}: {len(g):,} anchors, markers={markers}, perturbations={perts}")

# %% [markdown]
# ## Draw batches
#
# The dataloader returns raw patches ``(B, C, 45, 256, 256)`` (no transforms).
# ``dm.on_after_batch_transfer`` applies the full pipeline:
#
# 1. Normalize ``(45, 256, 256)``
# 2. Affine ``(45, 256, 256)``
# 3. RandSpatialCrop ``(40, 228, 228)``
# 4. Flip / contrast / noise ``(40, 228, 228)``
# 5. CenterCrop ``(32, 160, 160)`` (auto-appended)
#
# We deepcopy each batch so we can show raw vs augmented side by side.

# %%
if OUTPUT_DIR:
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

dl = dm.train_dataloader()
dl_iter = iter(dl)

for batch_idx in range(N_BATCHES):
    print(f"\n--- Batch {batch_idx} ---")
    batch = next(dl_iter)

    # Report batch composition before any transforms are applied.
    meta = batch["anchor_meta"]
    n = len(meta)
    markers = Counter(m.get("marker", "?") for m in meta)
    perts = Counter(m.get("perturbation", "?") for m in meta)
    print(f"  {n} samples, markers={dict(markers)}, perturbations={dict(perts)}")

    # Keep an untouched copy: on_after_batch_transfer mutates the batch.
    raw_batch = copy.deepcopy(batch)
    if SHOW_AUGMENTED:
        aug_batch = dm.on_after_batch_transfer(batch, dataloader_idx=0)
    else:
        aug_batch = None

    if OUTPUT_DIR:
        save_path = OUTPUT_DIR / f"train_batch_{batch_idx}.png"
    else:
        save_path = None
    plot_batch(
        raw_batch=raw_batch,
        aug_batch=aug_batch,
        batch_idx=batch_idx,
        n_show=N_SHOW,
        show_augmented=SHOW_AUGMENTED,
        save_path=save_path,
    )

# %%
print("\nDone.")

# %% [markdown]
# ## Validation dataloader
#
# The val dataloader uses the same dataset class but a different subset
# (train/val FOV split). Worth inspecting because DDP validation-epoch-end
# syncs `loss/val` across ranks — a bad val batch on any rank can stall
# the whole sync, or produce NaN features that poison metrics aggregation.
#
# We also scan the raw val batch for NaN/Inf before and after normalization,
# to catch any rows the preprocess step failed to filter.

# %%
val_dl = dm.val_dataloader()
val_iter = iter(val_dl)


def _has_nonfinite(t) -> bool:
    """True when the tensor contains any NaN or Inf value."""
    return bool(t.isnan().any() or t.isinf().any())


nan_batches_raw = 0
nan_batches_norm = 0
for batch_idx in range(N_BATCHES):
    print(f"\n--- Val batch {batch_idx} ---")
    batch = next(val_iter)

    meta = batch["anchor_meta"]
    n = len(meta)
    markers = Counter(m.get("marker", "?") for m in meta)
    perts = Counter(m.get("perturbation", "?") for m in meta)
    print(f"  {n} samples, markers={dict(markers)}, perturbations={dict(perts)}")

    # Scan raw tensors (anchor + positive, if present) for non-finite values.
    raw_anchor = batch["anchor"]
    raw_pos = batch.get("positive")
    raw_bad = _has_nonfinite(raw_anchor)
    if raw_pos is not None:
        raw_bad = raw_bad or _has_nonfinite(raw_pos)
    if raw_bad:
        nan_batches_raw += 1
        print("  ⚠ raw val batch contains NaN/Inf")

    # Untouched copy for side-by-side display; the transfer hook mutates.
    raw_batch = copy.deepcopy(batch)
    aug_batch = dm.on_after_batch_transfer(batch, dataloader_idx=1) if SHOW_AUGMENTED else None

    if aug_batch is not None:
        aa = aug_batch["anchor"]
        ap = aug_batch.get("positive")
        norm_bad = _has_nonfinite(aa)
        if ap is not None:
            norm_bad = norm_bad or _has_nonfinite(ap)
        # Only count batches where normalization itself introduced NaN/Inf.
        if norm_bad and not raw_bad:
            nan_batches_norm += 1
            print("  ⚠ post-normalize val batch contains NaN/Inf")

    save_path = OUTPUT_DIR / f"val_batch_{batch_idx}.png" if OUTPUT_DIR else None
    plot_batch(
        raw_batch=raw_batch,
        aug_batch=aug_batch,
        batch_idx=batch_idx,
        n_show=N_SHOW,
        show_augmented=SHOW_AUGMENTED,
        save_path=save_path,
    )

print(f"\nVal scan over {N_BATCHES} batches: raw NaN/Inf={nan_batches_raw}, post-norm NaN/Inf={nan_batches_norm}")

# %% [markdown]
# ## Re-run additional batches
#
# Edit ``batch_idx`` and re-run this cell to inspect more batches
# without restarting the dataloader iterator.
+ +# %% diff --git a/applications/dynaclr/scripts/dataloader_inspection/explore_gut_parquet.py b/applications/dynaclr/scripts/dataloader_inspection/explore_gut_parquet.py new file mode 100644 index 000000000..c12fd32d3 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/explore_gut_parquet.py @@ -0,0 +1,238 @@ +"""Minimal exploration of Zuben's gut cell classifier parquet with DynaCLR dataloader. + +Parquet: /hpc/projects/jacobo_group/zuben/proj/gutCellClassifier/data/dynaclr_cell_index.parquet + +Key findings: +- Flat schema: one row per (cell, t, channel). Compatible with MultiExperimentDataModule. +- NOT timelapse: all t=0, no temporal positives. Use positive_cell_source="self" (SimCLR). +- 25 experiments (AAY6/7/8 × day 0/1/2 × gut1-6), 4 channels, 6 perturbation stages. +- Missing: hours_post_perturbation (not needed for self-positive mode). + +Usage:: + + cd /home/eduardo.hirata/repos/viscy + uv run python applications/dynaclr/scripts/dataloader_inspection/explore_gut_parquet.py +""" + +# ruff: noqa: E402, D103 + +# %% [markdown] +# # Gut Cell Parquet Explorer + +# %% +from __future__ import annotations + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import zarr + +# %% [markdown] +# ## 1. 
Parquet Summary + +# %% +PARQUET_PATH = "/hpc/projects/jacobo_group/zuben/proj/gutCellClassifier/data/dynaclr_cell_index.parquet" +OUTPUT_DIR = Path("applications/dynaclr/scripts/dataloader_inspection/output/gut_parquet") +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +df = pd.read_parquet(PARQUET_PATH) +print(f"Shape: {df.shape}") +print(f"Columns: {df.columns.tolist()}\n") + +print(f"Experiments ({df['experiment'].nunique()}): {sorted(df['experiment'].unique())}\n") +print(f"Channels: {df['channel_name'].unique().tolist()}") +print(f"Perturbations: {sorted(df['perturbation'].unique())}") +print(f"t values: {sorted(df['t'].unique())} <- all 0, not timelapse") +print(f"z range: {df['z'].min()} - {df['z'].max()}") + +# %% +# Per-experiment cell counts and stage breakdown +print("\n## Per-experiment cell counts (unique cells × 4 channels = rows)") +for exp, g in df.groupby("experiment"): + n_cells = g["cell_id"].nunique() + stages = g["perturbation"].value_counts().to_dict() + print(f" {exp}: {n_cells} cells | stages={stages}") + +# %% [markdown] +# ## 2. Sample random patches from zarr +# +# Direct zarr read bypasses the iohub channel_names issue. 
# %%
# Array shape: (T, C, Z, Y, X) = (1, 4, ~98, H, W)
# Channel order: nuclear, septate, brush_border, SuH

CHANNEL_NAMES = ["nuclear", "septate", "brush_border", "SuH"]
PATCH_SIZE = 128  # pixels around cell center
N_SAMPLES_PER_CHANNEL = 4
N_STAGES = 3  # show first N stages


def read_patch(row: pd.Series, channel_idx: int, patch: int = PATCH_SIZE) -> np.ndarray:
    """Read a 2D patch around the cell center from zarr.

    Parameters
    ----------
    row : pd.Series
        One parquet row; must provide ``store_path``, ``well``, ``fov``,
        ``t``, ``z``, ``y``, ``x``.
    channel_idx : int
        Index into the zarr channel axis. NOTE(review): assumed to match the
        order of ``CHANNEL_NAMES`` above — confirm against the store metadata.
    patch : int
        Patch edge length in pixels; clipped at image borders, so the
        returned array may be smaller than ``(patch, patch)``.

    Returns
    -------
    np.ndarray
        2D patch at the row's (t, z) plane. (Return type fixed: no code
        path ever returned ``None``, so the former ``| None`` annotation
        was misleading.)
    """
    store = zarr.open(row["store_path"], mode="r")
    pos_path = f"{row['well']}/{row['fov']}"
    arr = store[pos_path]["0"]  # (T, C, Z, Y, X)
    z = int(row["z"])
    y = int(row["y"])
    x = int(row["x"])
    H, W = arr.shape[3], arr.shape[4]
    half = patch // 2
    # Clamp the window to the image bounds instead of padding.
    y0, y1 = max(0, y - half), min(H, y + half)
    x0, x1 = max(0, x - half), min(W, x + half)
    t = int(row["t"])
    return arr[t, channel_idx, z, y0:y1, x0:x1]


# %% [markdown]
# ## 3. Grid: channels × perturbation stages

# %%
stages = sorted(df["perturbation"].unique())[:N_STAGES]
n_cols = N_SAMPLES_PER_CHANNEL
n_rows = len(CHANNEL_NAMES) * len(stages)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), squeeze=False)
fig.suptitle("Gut cell patches: rows=channel×stage, cols=random samples", fontsize=10)

row_idx = 0
for stage in stages:
    stage_df = df[df["perturbation"] == stage]
    for ch_i, ch_name in enumerate(CHANNEL_NAMES):
        ch_df = stage_df[stage_df["channel_name"] == ch_name]
        # Fixed random_state so the same cells are shown on re-runs.
        sampled = ch_df.sample(min(N_SAMPLES_PER_CHANNEL, len(ch_df)), random_state=42)
        ax_row = axes[row_idx]
        for col_i, (_, row) in enumerate(sampled.iterrows()):
            patch = read_patch(row, ch_i)
            ax = ax_row[col_i]
            # Percentile stretch is robust to hot pixels.
            vmin, vmax = np.percentile(patch, [1, 99])
            ax.imshow(patch, cmap="gray", vmin=vmin, vmax=vmax)
            ax.set_xticks([])
            ax.set_yticks([])
            if col_i == 0:
                ax.set_ylabel(f"{ch_name}\n{stage}", fontsize=7)
        row_idx += 1

plt.tight_layout()
save_path = OUTPUT_DIR / "patches_channel_by_stage.png"
fig.savefig(save_path, dpi=120, bbox_inches="tight")
+print(f"Saved: {save_path}") + +# %% [markdown] +# ## 4. Stage distribution per experiment + +# %% +fig, ax = plt.subplots(figsize=(14, 4)) +pivot = ( + df.drop_duplicates(["cell_id", "perturbation"]).groupby(["experiment", "perturbation"]).size().unstack(fill_value=0) # noqa: PD010 +) +pivot.plot.bar(ax=ax, stacked=True, colormap="tab10") +ax.set_title("Cell counts by experiment and stage") +ax.set_xlabel("") +ax.tick_params(axis="x", rotation=45) +ax.legend(title="stage", bbox_to_anchor=(1, 1)) +plt.tight_layout() +save_path = OUTPUT_DIR / "stage_distribution.png" +fig.savefig(save_path, dpi=120, bbox_inches="tight") +print(f"Saved: {save_path}") + +# %% [markdown] +# ## 5. Channel distribution + +# %% +fig, axes = plt.subplots(1, 2, figsize=(10, 4)) +df.drop_duplicates(["cell_id", "channel_name"])["channel_name"].value_counts().plot.bar(ax=axes[0], color="steelblue") +axes[0].set_title("Cells per channel") +axes[0].tick_params(axis="x", rotation=30) + +df.drop_duplicates(["cell_id", "perturbation"])["perturbation"].value_counts().plot.bar(ax=axes[1], color="coral") +axes[1].set_title("Cells per stage") +axes[1].tick_params(axis="x", rotation=30) + +plt.tight_layout() +save_path = OUTPUT_DIR / "distributions.png" +fig.savefig(save_path, dpi=120, bbox_inches="tight") +print(f"Saved: {save_path}") + +# %% [markdown] +# ## 6. DynaCLR DataModule (self-positive / SimCLR) +# +# Not timelapse (t=0 only) so use positive_cell_source="self" — +# augmentation creates two views of the same cell. 
+ +# %% +from dynaclr.data.datamodule import MultiExperimentDataModule + +Z_WINDOW = 1 +YX_PATCH_SIZE = (256, 256) +FINAL_YX_PATCH_SIZE = (224, 224) +BATCH_SIZE = 8 +NUM_WORKERS = 4 +N_BATCHES = 2 + +print("Building DataModule (self-positive, marker-grouped)...") +dm = MultiExperimentDataModule( + cell_index_path=PARQUET_PATH, + z_window=Z_WINDOW, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, + channel_dropout_prob=0.0, + positive_cell_source="self", + channels_per_sample=1, + batch_group_by=["marker"], + stratify_by="perturbation", +) +dm.setup("fit") +print("Done.\n") + +va = dm.train_dataset.index.valid_anchors +print(f"Valid anchors: {len(va):,}") +print(f"Channels: {va['marker'].value_counts().to_dict()}") +print(f"Perturbations: {va['perturbation'].value_counts().to_dict()}") + + +# %% +def plot_batch(batch: dict, batch_idx: int, title: str, save_path: Path | None = None) -> None: + """Grid of anchor images annotated with channel + perturbation.""" + anchor = batch["anchor"].numpy() + meta = batch["anchor_meta"] + n = len(meta) + + fig, axes = plt.subplots(1, n, figsize=(n * 2.2, 2.8), squeeze=False) + channels_in_batch = {m.get("marker", "?") for m in meta} + perts_in_batch = {m.get("perturbation", "?") for m in meta} + fig.suptitle( + f"{title} — Batch {batch_idx}\nchannel={channels_in_batch} | stages={perts_in_batch}", + fontsize=9, + ) + for i, (ax, m) in enumerate(zip(axes[0], meta)): + img = anchor[i] + if img.ndim == 4: + img = img[0, img.shape[1] // 2] + elif img.ndim == 3: + img = img[0] + vmin, vmax = np.percentile(img, [1, 99]) + ax.imshow(img, cmap="gray", vmin=vmin, vmax=vmax) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(f"{m.get('marker', '?')}\n{m.get('perturbation', '?')}", fontsize=6) + plt.tight_layout() + if save_path: + fig.savefig(save_path, dpi=120, bbox_inches="tight") + print(f" Saved: {save_path}") + + +dl = dm.train_dataloader() +for i, batch in 
enumerate(dl): + if i >= N_BATCHES: + break + meta = batch["anchor_meta"] + print(f"Batch {i}: {len(meta)} samples marker={{{meta[0].get('marker')}}} anchor shape={batch['anchor'].shape}") + plot_batch( + batch, i, "Gut: marker-grouped, perturbation-stratified", save_path=OUTPUT_DIR / f"dataloader_batch_{i}.png" + ) + +# %% +plt.show() diff --git a/applications/dynaclr/scripts/dataloader_inspection/test_2d_mip_augmentation.py b/applications/dynaclr/scripts/dataloader_inspection/test_2d_mip_augmentation.py new file mode 100644 index 000000000..5f9687ea5 --- /dev/null +++ b/applications/dynaclr/scripts/dataloader_inspection/test_2d_mip_augmentation.py @@ -0,0 +1,319 @@ +"""2D MIP augmentation demo — inspect and verify the pipeline. + +Jupyter-style notebook (use ``# %%`` cells in VS Code or JupyterLab). + +Shows what the 2D MIP model receives as input and verifies: + +- **Row 0 (anchor raw)**: center z-slice of the 20-slice raw extraction patch. +- **Row 1 (anchor aug)**: after normalize → affine → RandSpatialCrop(10) → MIP/center-slice → CenterCrop(160,160). + +Column annotations show marker, perturbation, and the z-reduction strategy +applied (MIP for fluorescence, center-slice for label-free). + +Pipeline: + extract (20, 192, 192) → normalize → affine → RandSpatialCrop(10, 192, 192) + → flip/contrast/noise → ZReduction (MIP or center-slice) → CenterCrop(1, 160, 160) + +Usage:: + + uv run python applications/dynaclr/scripts/dataloader_inspection/test_2d_mip_augmentation.py +""" + +# ruff: noqa: E402, D103 + +# %% [markdown] +# # 2D MIP Augmentation Demo +# +# Verify the z-reduction strategy per marker and visualize raw vs augmented. +# +# ## Pipeline +# +# 1. **Extract** 20 z-slices around focus +# 2. **Normalize** (subtract mean, divide std) +# 3. **Affine** (rotate/scale/shear) +# 4. **RandSpatialCrop** to (10, 192, 192) — random Z for focus invariance +# 5. **Flip, contrast, scale, smooth, noise** +# 6. 
**ZReduction**: MIP for fluorescence, center-slice for label-free +# 7. **CenterCrop** to (1, 160, 160) — auto-appended by datamodule + +# %% +from __future__ import annotations + +import copy +from collections import Counter +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch + +from dynaclr.data.datamodule import MultiExperimentDataModule +from viscy_data._utils import _transform_channel_wise +from viscy_data.channel_utils import parse_channel_name +from viscy_transforms import ( + BatchedChannelWiseZReductiond, + BatchedRandAdjustContrastd, + BatchedRandAffined, + BatchedRandFlipd, + BatchedRandGaussianNoised, + BatchedRandGaussianSmoothd, + BatchedRandScaleIntensityd, + BatchedRandSpatialCropd, + NormalizeSampled, +) + +# %% [markdown] +# ## Configuration + +# %% +CELL_INDEX_PATH = "/home/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/test_2d_mip_mixed.parquet" + +Z_WINDOW = 1 +Z_EXTRACTION_WINDOW = 20 +Z_FOCUS_OFFSET = 0.5 +YX_PATCH_SIZE = (192, 192) +FINAL_YX_PATCH_SIZE = (160, 160) +CHANNEL_NAMES = ["channel_0"] + +BATCH_SIZE = 16 +N_BATCHES = 4 +N_SHOW = 10 +NUM_WORKERS = 4 +OUTPUT_DIR = Path("/home/eduardo.hirata/repos/viscy/applications/dynaclr/scripts/dataloader_inspection/results") + +# %% [markdown] +# ## Build DataModule + +# %% +normalizations = [ + NormalizeSampled( + keys=CHANNEL_NAMES, + level="timepoint_statistics", + subtrahend="mean", + divisor="std", + ) +] +augmentations = [ + BatchedRandAffined( + keys=CHANNEL_NAMES, + prob=0.8, + scale_range=[[0.8, 1.3], [0.8, 1.3], [0.8, 1.3]], + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.05, 0.05, 0.0, 0.05, 0.0, 0.05], + ), + BatchedRandFlipd(keys=CHANNEL_NAMES, spatial_axes=[1, 2], prob=0.5), + BatchedRandAdjustContrastd(keys=CHANNEL_NAMES, prob=0.5, gamma=(0.6, 1.6)), + BatchedRandScaleIntensityd(keys=CHANNEL_NAMES, prob=0.5, factors=0.5), + BatchedRandGaussianSmoothd( + keys=CHANNEL_NAMES, + prob=0.5, + sigma_x=[0.25, 0.50], + 
sigma_y=[0.25, 0.50], + sigma_z=[0.0, 0.0], + ), + BatchedRandGaussianNoised(keys=CHANNEL_NAMES, prob=0.5, mean=0.0, std=0.1), + # Random Z crop: select 10 of 20 extracted slices for Z-invariance. + BatchedRandSpatialCropd(keys=CHANNEL_NAMES, roi_size=[10, 192, 192]), + # Z-reduction: MIP for fluorescence, center-slice for label-free. + BatchedChannelWiseZReductiond(keys=CHANNEL_NAMES, allow_missing_keys=True), +] + +dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PATH, + z_window=Z_WINDOW, + z_extraction_window=Z_EXTRACTION_WINDOW, + z_focus_offset=Z_FOCUS_OFFSET, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + channels_per_sample=1, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=["perturbation", "marker"], + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, + seed=42, + focus_channel="Phase3D", + reference_pixel_size_xy_um=0.1494, + channel_dropout_prob=0.0, + normalizations=normalizations, + augmentations=augmentations, +) +dm.setup("fit") + +va = dm.train_dataset.index.valid_anchors +print(f"Anchors: {len(va):,} | Experiments: {va['experiment'].nunique()}") +for exp, g in va.groupby("experiment"): + markers = g["marker"].value_counts().to_dict() if "marker" in g.columns else {} + print(f" {exp}: {len(g):,} anchors markers={markers}") + + +# %% [markdown] +# ## Helpers + + +# %% +def _apply_augmentations(batch: dict) -> torch.Tensor: + """Apply the full augmentation pipeline to a raw batch, return (B,C,1,H,W).""" + norm_meta = batch.get("anchor_norm_meta") + is_labelfree = torch.tensor( + [parse_channel_name(m.get("marker", ""))["channel_type"] == "labelfree" for m in batch["anchor_meta"]], + dtype=torch.bool, + ) + return _transform_channel_wise( + transform=dm._augmentation_transform, + channel_names=dm._channel_names, + patch=batch["anchor"], + norm_meta=norm_meta, + extra={"_is_labelfree": is_labelfree}, + ) + + 
+def _img2d_raw(tensor: np.ndarray, sample_idx: int) -> np.ndarray: + """Center z-slice from raw (B, C, Z, Y, X) for display.""" + vol = tensor[sample_idx, 0] # (Z, Y, X) + return vol[vol.shape[0] // 2] + + +def _img2d_aug(tensor: np.ndarray, sample_idx: int) -> np.ndarray: + """2D image from augmented (B, C, 1, Y, X).""" + return tensor[sample_idx, 0, 0] + + +def _strategy(marker: str) -> str: + ct = parse_channel_name(marker)["channel_type"] + return "center-slice" if ct == "labelfree" else "MIP" + + +def plot_batch( + raw_batch: dict, + aug_patch: torch.Tensor, + batch_idx: int, + n_show: int = N_SHOW, + save_path: Path | None = None, +) -> None: + anchor_raw = raw_batch["anchor"].numpy() + anchor_aug = aug_patch.numpy() + meta = raw_batch.get("anchor_meta", []) + n = min(n_show, len(meta)) + + markers = Counter(m.get("marker", "?") for m in meta[:n]) + perts = Counter(m.get("perturbation", "?") for m in meta[:n]) + m_str = " ".join(f"{k}={v}" for k, v in markers.most_common(5)) + p_str = " ".join(f"{k}={v}" for k, v in perts.most_common(5)) + + fig, axes = plt.subplots(2, n, figsize=(n * 2.0, 2 * 2.4), squeeze=False) + fig.suptitle( + f"Batch {batch_idx} | markers: {m_str} | pert: {p_str}\n" + f"raw z-depth={anchor_raw.shape[2]} aug z-depth={anchor_aug.shape[2]}", + fontsize=8, + fontweight="bold", + ) + + for i in range(n): + am = meta[i] if i < len(meta) else {} + marker = am.get("marker", "?") + strategy = _strategy(marker) + + # Row 0: raw center z-slice + img_raw = _img2d_raw(anchor_raw, i) + vmin, vmax = np.percentile(img_raw, [1, 99]) + axes[0, i].imshow(img_raw, cmap="gray", vmin=vmin, vmax=vmax) + axes[0, i].set_xticks([]) + axes[0, i].set_yticks([]) + axes[0, i].set_title( + "\n".join( + [ + f"{am.get('experiment', '?')[:20]}", + f"marker={marker}", + f"pert={am.get('perturbation', '?')}", + f"t={am.get('t', '?')}", + f"z_reduction={strategy}", + ] + ), + fontsize=5, + linespacing=1.1, + ) + + # Row 1: augmented (post ZReduction) + img_aug = 
_img2d_aug(anchor_aug, i) + vmin_a, vmax_a = np.percentile(img_aug, [1, 99]) + axes[1, i].imshow(img_aug, cmap="gray", vmin=vmin_a, vmax=vmax_a) + axes[1, i].set_xticks([]) + axes[1, i].set_yticks([]) + axes[1, i].set_title(f"μ={img_aug.mean():.2f} σ={img_aug.std():.2f}", fontsize=5) + + axes[0, 0].set_ylabel("raw (center z)", fontsize=7, fontweight="bold") + axes[1, 0].set_ylabel("aug (MIP/center)", fontsize=7, fontweight="bold") + + plt.tight_layout() + if save_path: + fig.savefig(save_path, dpi=150, bbox_inches="tight") + print(f" Saved: {save_path}") + else: + plt.show() + + +def check_batch(batch_idx: int, raw_batch: dict, aug_patch: torch.Tensor) -> None: + """Assert shape and z-reduction correctness, print summary.""" + meta = raw_batch.get("anchor_meta", []) + + assert aug_patch.shape[2] == 1, f"Batch {batch_idx}: z should be 1, got {aug_patch.shape}" + assert aug_patch.shape[3] == FINAL_YX_PATCH_SIZE[0], f"Y should be {FINAL_YX_PATCH_SIZE[0]}" + assert aug_patch.shape[4] == FINAL_YX_PATCH_SIZE[1], f"X should be {FINAL_YX_PATCH_SIZE[1]}" + print(f" [PASS] shape: {tuple(aug_patch.shape)}") + + n_lf, n_fl = 0, 0 + for i, m in enumerate(meta): + marker = m.get("marker", "") + ct = parse_channel_name(marker)["channel_type"] + assert not torch.all(aug_patch[i] == 0), f"Sample {i} ({marker}) is all zeros" + if ct == "labelfree": + n_lf += 1 + else: + n_fl += 1 + + raw_z = raw_batch["anchor"].shape[2] + print(f" [PASS] label-free (center-slice)={n_lf} fluorescence (MIP)={n_fl} raw_z={raw_z}") + print(f" [INFO] markers: {dict(Counter(m.get('marker', '?') for m in meta))}") + + +# %% [markdown] +# ## Draw batches + +# %% +if OUTPUT_DIR: + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +dl = dm.train_dataloader() +dl_iter = iter(dl) + +for batch_idx in range(N_BATCHES): + print(f"\n--- Batch {batch_idx} ---") + batch = next(dl_iter) + raw_batch = copy.deepcopy(batch) + aug_patch = _apply_augmentations(batch) + check_batch(batch_idx, raw_batch, aug_patch) + save_path 
= OUTPUT_DIR / f"batch_{batch_idx}.png" if OUTPUT_DIR else None + plot_batch(raw_batch, aug_patch, batch_idx, save_path=save_path) + +# %% +print("\nDone.") + +# %% [markdown] +# ## Re-run additional batches +# +# Edit ``batch_idx`` and re-run this cell to inspect more batches +# without restarting the dataloader iterator. + +# %% +batch_idx = N_BATCHES +batch = next(dl_iter) +raw_batch = copy.deepcopy(batch) +aug_patch = _apply_augmentations(batch) +check_batch(batch_idx, raw_batch, aug_patch) +plot_batch(raw_batch, aug_patch, batch_idx) + +# %% diff --git a/applications/dynaclr/scripts/evaluation/compare_evals.py b/applications/dynaclr/scripts/evaluation/compare_evals.py new file mode 100644 index 000000000..573f98f11 --- /dev/null +++ b/applications/dynaclr/scripts/evaluation/compare_evals.py @@ -0,0 +1,519 @@ +"""Compare evaluation results across multiple model runs. + +Reads outputs produced by ``dynaclr evaluate`` from multiple model eval directories, +compares smoothness, linear classifier AUROC, and MMD activity z-scores side by side, +and writes summary CSVs and plots to a shared output directory. + +Usage +----- +python compare_evals.py -c eval_registry.yml + +Registry YAML format +-------------------- +models: + - name: DynaCLR-v3 + eval_dir: /path/to/eval_v3 + - name: DINOv3-MLP + eval_dir: /path/to/eval_dino +output_dir: /path/to/comparison_output +fdr_threshold: 0.05 # optional, default 0.05 +""" + +from __future__ import annotations + +import re +from pathlib import Path + +import click +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import yaml +from matplotlib.lines import Line2D + +# Per-class metric columns are emitted by the LC trainer as +# ``val__`` (precision, recall, f1). ``val_weighted_*`` and +# ``val__auroc`` are excluded — only point-classification metrics here. 
+_PER_CLASS_METRIC_RE = re.compile(r"^val_(?P.+)_(?Pprecision|recall|f1)$") + +# --------------------------------------------------------------------------- +# Registry loading +# --------------------------------------------------------------------------- + + +def _load_registry(path: Path) -> tuple[list[dict], Path, float]: + with open(path) as f: + raw = yaml.safe_load(f) + output_dir = Path(raw["output_dir"]) + fdr_threshold = float(raw.get("fdr_threshold", 0.05)) + return raw["models"], output_dir, fdr_threshold + + +def _build_model_palette(model_names: list[str]) -> dict[str, tuple[float, float, float, float]]: + """Map model name → RGBA color, stable across all plots in one run. + + Uses ``tab10`` for ≤10 models (10 visually distinct hues) and ``tab20`` + for 11–20 models. Colors are picked from discrete colormap indices so + they don't blur into each other when many models are compared. + """ + n = len(model_names) + cmap = plt.cm.tab10 if n <= 10 else plt.cm.tab20 + return {name: cmap(i % cmap.N) for i, name in enumerate(sorted(model_names))} + + +# --------------------------------------------------------------------------- +# Smoothness +# --------------------------------------------------------------------------- + + +def _load_smoothness(models: list[dict]) -> pd.DataFrame | None: + """Load per-marker smoothness CSVs from all evals and concat. + + Each eval writes one ``*_per_marker_smoothness.csv`` per (experiment, marker) + with columns including ``smoothness_score`` and ``dynamic_range``. We + concat all of them and tag with the model name; the plotting step + aggregates (mean across experiments+markers) for the headline bar chart. 
+ """ + frames = [] + for entry in models: + smoothness_dir = Path(entry["eval_dir"]) / "smoothness" + csvs = list(smoothness_dir.glob("*_per_marker_smoothness.csv")) + if not csvs: + click.echo(f"[smoothness] No smoothness CSV found for {entry['name']}", err=True) + continue + per_csv = [pd.read_csv(c) for c in csvs] + df = pd.concat(per_csv, ignore_index=True) + df["model"] = entry["name"] + frames.append(df) + if not frames: + return None + return pd.concat(frames, ignore_index=True) + + +def _plot_smoothness(df: pd.DataFrame, output_dir: Path, model_color: dict) -> None: + """Plot per-model smoothness as mean ± std across (experiment, marker) rows.""" + metrics = ["smoothness_score", "dynamic_range"] + present = [m for m in metrics if m in df.columns] + if not present: + return + + # Aggregate across (experiment, marker) per model: mean ± std. + agg = df.groupby("model")[present].agg(["mean", "std"]) + + fig, axes = plt.subplots(1, len(present), figsize=(5 * len(present), 4), squeeze=False) + for ax, metric in zip(axes[0], present): + means = agg[(metric, "mean")] + stds = agg[(metric, "std")].fillna(0.0) + bar_colors = [model_color.get(m, "gray") for m in means.index] + ax.bar( + means.index, + means.values, + yerr=stds.values, + capsize=4, + color=bar_colors, + ) + ax.set_title(metric.replace("_", " ").title()) + ax.set_ylabel(metric) + plt.setp(ax.get_xticklabels(), rotation=30, ha="right") + + fig.tight_layout() + out = output_dir / "smoothness_comparison.pdf" + fig.savefig(out, bbox_inches="tight") + plt.close(fig) + click.echo(f"[smoothness] Saved: {out}", err=True) + + +# --------------------------------------------------------------------------- +# Linear classifiers +# --------------------------------------------------------------------------- + + +def _load_linear_classifiers(models: list[dict]) -> pd.DataFrame | None: + frames = [] + for entry in models: + csv = Path(entry["eval_dir"]) / "linear_classifiers" / "metrics_summary.csv" + if not 
csv.exists(): + click.echo(f"[linear_classifiers] Not found for {entry['name']}: {csv}", err=True) + continue + df = pd.read_csv(csv) + df["model"] = entry["name"] + frames.append(df) + if not frames: + return None + return pd.concat(frames, ignore_index=True) + + +def _plot_linear_classifiers(df: pd.DataFrame, output_dir: Path, model_color: dict) -> None: + # The LC writer emits per-split metrics: train_auroc and val_auroc. + # Plot val_auroc (held-out generalization) — that is the headline number + # for cross-model comparison. + auroc_col = "val_auroc" + if auroc_col not in df.columns: + return + + # Marker breakdown lives in `marker_filter` (per-marker LCs), not `marker`. + marker_col = "marker_filter" if "marker_filter" in df.columns else "marker" + + tasks = sorted(df["task"].unique()) if "task" in df.columns else ["all"] + ncols = min(4, len(tasks)) + nrows = int(np.ceil(len(tasks) / ncols)) + fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False) + axes_flat = axes.flatten() + models = sorted(df["model"].unique()) + + for ax_idx, task in enumerate(tasks): + ax = axes_flat[ax_idx] + sub = df[df["task"] == task] if "task" in df.columns else df + pivot = sub.pivot_table( + index=marker_col if marker_col in sub.columns else sub.index, + columns="model", + values=auroc_col, + ) + pivot = pivot.reindex(columns=models) + + x = np.arange(len(pivot)) + width = 0.8 / len(models) + for i, model in enumerate(models): + if model not in pivot.columns: + continue + ax.bar(x + i * width, pivot[model].values, width, label=model, color=model_color[model]) + + ax.set_xticks(x + width * (len(models) - 1) / 2) + ax.set_xticklabels(pivot.index, rotation=45, ha="right", fontsize=8) + ax.set_ylabel("Validation AUROC") + ax.set_title(task, fontsize=9) + ax.axhline(0.5, color="gray", linewidth=0.8, linestyle="--") + ax.set_ylim(0, 1.05) + + for ax in axes_flat[len(tasks) :]: + ax.set_visible(False) + + handles = [plt.Rectangle((0, 0), 1, 1, 
color=model_color[m], label=m) for m in models]
+    fig.legend(handles=handles, loc="lower center", ncol=len(models), fontsize=8, bbox_to_anchor=(0.5, 0))
+    fig.tight_layout(rect=[0, 0.05, 1, 1])
+    out = output_dir / "linear_classifiers_comparison.pdf"
+    fig.savefig(out, bbox_inches="tight")
+    plt.close(fig)
+    click.echo(f"[linear_classifiers] Saved: {out}", err=True)
+
+
+def _discover_per_class_metrics(df: pd.DataFrame) -> dict[str, list[str]]:
+    """Map class name → list of available metrics (precision, recall, f1).
+
+    Columns are auto-detected from ``val_<class>_<metric>`` names so the plot
+    works across tasks with different label sets (infected/uninfected,
+    interphase/mitosis, alive/dead, noremodel/remodel, etc.).
+    """
+    found: dict[str, set[str]] = {}
+    for col in df.columns:
+        m = _PER_CLASS_METRIC_RE.match(col)
+        if m is None:
+            continue
+        cls, metric = m.group("cls"), m.group("metric")
+        if cls in {"weighted", "macro"}:
+            continue
+        # Only include classes that have at least one non-null value in this slice —
+        # avoids polluting a task panel with empty bars for classes from other tasks.
+        if df[col].notna().any():
+            found.setdefault(cls, set()).add(metric)
+    return {cls: sorted(metrics) for cls, metrics in found.items()}
+
+
+def _plot_linear_classifiers_per_class(df: pd.DataFrame, output_dir: Path, model_color: dict) -> None:
+    """Per-class precision/recall/F1 grouped bars per (task, marker_filter).
+
+    AUROC is prevalence-invariant and rewards good ranking, but the
+    infectomics LC tasks are heavily imbalanced (cell_division_state ~99/1,
+    organelle_state ~91/9). Per-class precision and recall expose whether
+    the classifier is actually usable at the chosen decision threshold for
+    the minority class — they are the metrics that move when imbalance bites.
+ """ + marker_col = "marker_filter" if "marker_filter" in df.columns else "marker" + if "task" not in df.columns: + return + + models = sorted(df["model"].unique()) + + # One subplot per (task, marker_filter) so SEC61B vs G3BP1 stay separate. + # Normalize missing marker_filter to None so iteration semantics are + # consistent (pandas yields float NaN which is not None and not equality-comparable). + if marker_col in df.columns: + seen: set[tuple[str, str | None]] = set() + panels: list[tuple[str, str | None]] = [] + for _, row in df[["task", marker_col]].drop_duplicates().iterrows(): + mf = row[marker_col] + mf_norm = None if pd.isna(mf) else mf + key = (row["task"], mf_norm) + if key not in seen: + seen.add(key) + panels.append(key) + panels.sort(key=lambda p: (p[0], "" if p[1] is None else p[1])) + else: + panels = [(t, None) for t in sorted(df["task"].unique())] + + if not panels: + return + + ncols = min(3, len(panels)) + nrows = int(np.ceil(len(panels) / ncols)) + fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4.5 * nrows), squeeze=False) + axes_flat = axes.flatten() + + any_panel_drawn = False + for ax_idx, (task, marker_filter) in enumerate(panels): + ax = axes_flat[ax_idx] + sub = df[df["task"] == task] + if marker_col in sub.columns: + if marker_filter is None: + sub = sub[sub[marker_col].isna()] + else: + sub = sub[sub[marker_col] == marker_filter] + if sub.empty: + ax.set_visible(False) + continue + + per_class = _discover_per_class_metrics(sub) + if not per_class: + ax.set_visible(False) + continue + + # x-axis groups: (class, metric) pairs; bars within each group: models. 
+ groups: list[tuple[str, str]] = [] + for cls in sorted(per_class): + for metric in ["precision", "recall", "f1"]: + if metric in per_class[cls]: + groups.append((cls, metric)) + + x = np.arange(len(groups)) + width = 0.8 / max(len(models), 1) + for i, model in enumerate(models): + row = sub[sub["model"] == model] + if row.empty: + continue + row = row.iloc[0] + values = [row.get(f"val_{cls}_{metric}", np.nan) for cls, metric in groups] + ax.bar(x + i * width, values, width, label=model, color=model_color[model]) + + ax.set_xticks(x + width * (len(models) - 1) / 2) + ax.set_xticklabels([f"{cls}\n{metric}" for cls, metric in groups], fontsize=8) + ax.set_ylim(0, 1.05) + ax.set_ylabel("Score") + title = f"{task}" if marker_filter in (None, "") else f"{task} ({marker_col}={marker_filter})" + + # Annotate panel with minority-class N when val support is available. + # Per-class support is identical across models (same dataset, same seed), + # so we read it from the first non-null row. + support_cols = {cls: f"val_{cls}_support" for cls in per_class} + supports: dict[str, int] = {} + for cls, col in support_cols.items(): + if col in sub.columns: + vals = sub[col].dropna() + if not vals.empty: + supports[cls] = int(vals.iloc[0]) + if supports: + total_n = sum(supports.values()) + minority_cls = min(supports, key=supports.get) + minority_n = supports[minority_cls] + minority_pct = 100 * minority_n / total_n if total_n else 0.0 + title += f"\nval N={total_n} | minority {minority_cls}={minority_n} ({minority_pct:.1f}%)" + ax.set_title(title, fontsize=9) + ax.axhline(0.5, color="gray", linewidth=0.6, linestyle=":") + any_panel_drawn = True + + for ax in axes_flat[len(panels) :]: + ax.set_visible(False) + + if not any_panel_drawn: + plt.close(fig) + click.echo("[linear_classifiers] No per-class metrics found in metrics_summary.csv", err=True) + return + + handles = [plt.Rectangle((0, 0), 1, 1, color=model_color[m], label=m) for m in models] + fig.legend(handles=handles, 
loc="lower center", ncol=len(models), fontsize=8, bbox_to_anchor=(0.5, 0)) + fig.tight_layout(rect=[0, 0.05, 1, 1]) + out = output_dir / "linear_classifiers_per_class.pdf" + fig.savefig(out, bbox_inches="tight") + plt.close(fig) + click.echo(f"[linear_classifiers] Saved: {out}", err=True) + + +# --------------------------------------------------------------------------- +# MMD +# --------------------------------------------------------------------------- + + +def _load_mmd(models: list[dict]) -> pd.DataFrame | None: + frames = [] + for entry in models: + mmd_root = Path(entry["eval_dir"]) / "mmd" + if not mmd_root.exists(): + click.echo(f"[mmd] No mmd directory for {entry['name']}", err=True) + continue + for csv in sorted(mmd_root.rglob("mmd_results.csv")): + block_name = csv.parent.name + df = pd.read_csv(csv) + df["model"] = entry["name"] + df["block"] = block_name + frames.append(df) + if not frames: + return None + return pd.concat(frames, ignore_index=True) + + +def _plot_mmd_kinetics(df: pd.DataFrame, output_dir: Path, fdr_threshold: float, model_color: dict) -> None: + temporal = df.dropna(subset=["hours_bin_start", "hours_bin_end"]).copy() + if temporal.empty: + click.echo("[mmd] No temporal rows — skipping kinetics plot", err=True) + return + + temporal["hours_mid"] = (temporal["hours_bin_start"] + temporal["hours_bin_end"]) / 2 + markers = sorted(temporal["marker"].unique()) + models = sorted(temporal["model"].unique()) + labels = sorted(temporal["label"].unique()) + blocks = sorted(temporal["block"].unique()) + + for block in blocks: + sub_block = temporal[temporal["block"] == block] + ncols = min(4, len(markers)) + nrows = int(np.ceil(len(markers) / ncols)) + fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False) + axes_flat = axes.flatten() + + linestyles = ["-", "--", ":", "-."] + label_ls = dict(zip(labels, linestyles[: len(labels)])) + + for ax_idx, marker in enumerate(markers): + ax = axes_flat[ax_idx] + sub = 
sub_block[sub_block["marker"] == marker] + for model in models: + for label in labels: + grp = sub[(sub["model"] == model) & (sub["label"] == label)].sort_values("hours_mid") + if grp.empty: + continue + ax.plot( + grp["hours_mid"], + grp["activity_zscore"], + color=model_color[model], + linestyle=label_ls[label], + linewidth=1.5, + ) + if "q_value" in grp.columns: + sig = grp[grp["q_value"] < fdr_threshold] + ax.scatter(sig["hours_mid"], sig["activity_zscore"], color=model_color[model], s=30, zorder=5) + ax.axhline(0, color="gray", linewidth=0.8, linestyle="--") + ax.set_title(marker, fontsize=9) + ax.set_xlabel("Hours post perturbation") + ax.set_ylabel("Activity z-score") + + for ax in axes_flat[len(markers) :]: + ax.set_visible(False) + + legend_handles = [Line2D([0], [0], color=model_color[m], linewidth=2, label=m) for m in models] + legend_handles += [ + Line2D([0], [0], color="black", linestyle=label_ls[lb], linewidth=1.5, label=lb) for lb in labels + ] + fig.legend( + handles=legend_handles, + loc="lower center", + ncol=len(models) + len(labels), + fontsize=8, + bbox_to_anchor=(0.5, 0), + ) + fig.tight_layout(rect=[0, 0.05, 1, 1]) + + out = output_dir / f"mmd_kinetics_{block}.pdf" + fig.savefig(out, bbox_inches="tight") + plt.close(fig) + click.echo(f"[mmd] Saved: {out}", err=True) + + +def _plot_mmd_summary_heatmap(summary: pd.DataFrame, output_dir: Path) -> None: + blocks = sorted(summary["block"].unique()) + labels = sorted(summary["label"].unique()) + models = sorted(summary["model"].unique()) + + for block in blocks: + sub_block = summary[summary["block"] == block] + ncols = len(labels) + markers = sorted(sub_block["marker"].unique()) + fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, max(3, len(markers) * 0.5 + 1)), squeeze=False) + for col_idx, label in enumerate(labels): + ax = axes[0, col_idx] + pivot = sub_block[sub_block["label"] == label].pivot_table( + index="marker", columns="model", values="mean_activity_zscore", aggfunc="mean" + ) + 
pivot = pivot.reindex(columns=models) + vmax = np.nanpercentile(np.abs(pivot.values), 95) if pivot.values.size > 0 else 1.0 + im = ax.imshow(pivot.values, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax) + ax.set_xticks(range(len(models))) + ax.set_xticklabels(models, rotation=45, ha="right", fontsize=8) + ax.set_yticks(range(len(pivot.index))) + ax.set_yticklabels(pivot.index, fontsize=8) + ax.set_title(label, fontsize=9) + plt.colorbar(im, ax=ax, label="Mean activity z-score") + + fig.tight_layout() + out = output_dir / f"mmd_summary_heatmap_{block}.pdf" + fig.savefig(out, bbox_inches="tight") + plt.close(fig) + click.echo(f"[mmd] Saved: {out}", err=True) + + +def _build_mmd_summary(df: pd.DataFrame) -> pd.DataFrame: + return ( + df.groupby(["block", "model", "marker", "label"])["activity_zscore"] + .agg(mean_activity_zscore="mean", n_bins="count") + .reset_index() + .sort_values(["block", "label", "marker", "mean_activity_zscore"], ascending=[True, True, True, False]) + ) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +@click.command() +@click.option( + "-c", "--config", required=True, type=click.Path(exists=True, path_type=Path), help="Path to eval_registry.yml" +) +def main(config: Path) -> None: + """Compare evaluation results across model runs.""" + models, output_dir, fdr_threshold = _load_registry(config) + output_dir.mkdir(parents=True, exist_ok=True) + + # Build the model→color palette once from the registry's model list so + # every plot in this run uses the same color for the same model. 
+ model_color = _build_model_palette([m["name"] for m in models]) + + # Smoothness + smoothness_df = _load_smoothness(models) + if smoothness_df is not None: + smoothness_df.to_csv(output_dir / "smoothness_comparison.csv", index=False) + _plot_smoothness(smoothness_df, output_dir, model_color) + click.echo("\n## Smoothness\n") + click.echo(smoothness_df[["model", "smoothness_score", "dynamic_range"]].to_markdown(index=False)) + + # Linear classifiers + lc_df = _load_linear_classifiers(models) + if lc_df is not None: + lc_df.to_csv(output_dir / "linear_classifiers_comparison.csv", index=False) + _plot_linear_classifiers(lc_df, output_dir, model_color) + _plot_linear_classifiers_per_class(lc_df, output_dir, model_color) + summary_cols = [c for c in ["model", "task", "marker", "auroc", "f1"] if c in lc_df.columns] + click.echo("\n## Linear Classifiers\n") + click.echo(lc_df[summary_cols].to_markdown(index=False)) + + # MMD + mmd_df = _load_mmd(models) + if mmd_df is not None: + mmd_summary = _build_mmd_summary(mmd_df) + mmd_summary.to_csv(output_dir / "mmd_comparison.csv", index=False) + _plot_mmd_kinetics(mmd_df, output_dir, fdr_threshold, model_color) + _plot_mmd_summary_heatmap(mmd_summary, output_dir) + click.echo("\n## MMD activity z-score\n") + click.echo(mmd_summary.to_markdown(index=False)) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/evaluation/microglia_alfi_analysis.py b/applications/dynaclr/scripts/evaluation/microglia_alfi_analysis.py new file mode 100644 index 000000000..2bd1e4672 --- /dev/null +++ b/applications/dynaclr/scripts/evaluation/microglia_alfi_analysis.py @@ -0,0 +1,361 @@ +"""Embedding analysis for microglia and ALFI datasets. + +Microglia (unsupervised): + PCA/UMAP colored by perturbation condition and per-track embedding + displacement — proxy for morphological dynamics (Khurana et al. 2022, + https://doi.org/10.1091/mbc.E21-11-0561). 
+ +ALFI HeLa (supervised): + PCA/UMAP colored by cell cycle phase annotations (interphase vs mitosis) + from the ALFI dataset (Dang et al. 2023, + https://doi.org/10.1038/s41597-023-02540-1). + +Usage +----- +python scripts/evaluation/microglia_alfi_analysis.py \\ + --microglia-embeddings /path/to/microglia/embeddings.zarr \\ + --alfi-embeddings /path/to/alfi/embeddings.zarr \\ + --output-dir /path/to/output/ +""" + +import argparse +from pathlib import Path + +import anndata as ad +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from umap import UMAP + +ALFI_ANNOTATIONS = Path("/hpc/projects/organelle_phenotyping/datasets/annotations/ALFI/ALFI_combined_annotations.csv") + +DIVISION_PALETTE = { + "interphase": "cornflowerblue", + "mitosis": "darkorange", +} + + +def compute_track_displacement_metrics(adata: ad.AnnData) -> pd.DataFrame: + """Compute per-track embedding displacement metrics. + + Parameters + ---------- + adata : AnnData + Embeddings with obs columns fov_name, track_id, t. + adata.X contains raw embeddings (N x D). + + Returns + ------- + pd.DataFrame + One row per track with columns: + fov_name, track_id, mean_step_size, total_path_length, + net_displacement, track_length, and any available metadata columns. 
+ """ + embeddings = np.asarray(adata.X) + obs = adata.obs.copy() + obs["_idx"] = np.arange(len(obs)) + + meta_cols = [c for c in ["perturbation", "marker", "experiment"] if c in obs.columns] + records = [] + + for (fov, tid), grp in obs.groupby(["fov_name", "track_id"], sort=False): + grp = grp.sort_values("t") + idxs = grp["_idx"].values + if len(idxs) < 2: + continue + embs = embeddings[idxs] + steps = np.linalg.norm(np.diff(embs, axis=0), axis=1) + record = { + "fov_name": fov, + "track_id": tid, + "mean_step_size": steps.mean(), + "total_path_length": steps.sum(), + "net_displacement": float(np.linalg.norm(embs[-1] - embs[0])), + "track_length": len(idxs), + } + for col in meta_cols: + record[col] = grp[col].iloc[0] + records.append(record) + + return pd.DataFrame(records) + + +def _get_or_compute_pca(adata: ad.AnnData, features_scaled: np.ndarray) -> np.ndarray: + if "X_pca" in adata.obsm: + return adata.obsm["X_pca"] + pca = PCA(n_components=32) + return pca.fit_transform(features_scaled) + + +def _get_or_compute_umap(adata: ad.AnnData, features_scaled: np.ndarray) -> np.ndarray: + if "X_umap" in adata.obsm: + return adata.obsm["X_umap"] + print(" Computing UMAP...") + return UMAP(n_components=2, n_neighbors=15, random_state=42).fit_transform(features_scaled) + + +def analyze_microglia(adata: ad.AnnData, output_dir: Path) -> None: + """Run microglia displacement analysis and save plots.""" + print(f"Microglia: {adata.shape[0]:,} observations") + + features = np.asarray(adata.X) + features_scaled = StandardScaler().fit_transform(features) + pca_emb = _get_or_compute_pca(adata, features_scaled) + umap_emb = _get_or_compute_umap(adata, features_scaled) + + track_metrics = compute_track_displacement_metrics(adata) + print(f" {len(track_metrics):,} tracks") + + obs = adata.obs.copy().merge( + track_metrics[["fov_name", "track_id", "mean_step_size", "net_displacement"]], + on=["fov_name", "track_id"], + how="left", + ) + + perturbations = 
sorted(obs["perturbation"].unique()) if "perturbation" in obs.columns else [] + markers = sorted(obs["marker"].unique()) if "marker" in obs.columns else [] + palette_p = dict(zip(perturbations, sns.color_palette("tab10", len(perturbations)))) + palette_m = dict(zip(markers, sns.color_palette("Set2", len(markers)))) + + plot_df = pd.DataFrame( + { + "PC1": pca_emb[:, 0], + "PC2": pca_emb[:, 1], + "UMAP1": umap_emb[:, 0], + "UMAP2": umap_emb[:, 1], + "perturbation": obs["perturbation"].values if "perturbation" in obs.columns else "unknown", + "marker": obs["marker"].values if "marker" in obs.columns else "unknown", + "mean_step_size": obs["mean_step_size"].values, + "net_displacement": obs["net_displacement"].values, + } + ) + + vmin = np.nanpercentile(plot_df["mean_step_size"], 5) + vmax = np.nanpercentile(plot_df["mean_step_size"], 95) + + for reduction, x_col, y_col in [("pca", "PC1", "PC2"), ("umap", "UMAP1", "UMAP2")]: + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + + sns.scatterplot( + data=plot_df, + x=x_col, + y=y_col, + hue="perturbation", + palette=palette_p, + ax=axes[0], + alpha=0.5, + s=8, + linewidth=0, + ) + axes[0].set_title(f"{reduction.upper()} — perturbation") + + sns.scatterplot( + data=plot_df, + x=x_col, + y=y_col, + hue="marker", + palette=palette_m, + ax=axes[1], + alpha=0.5, + s=8, + linewidth=0, + ) + axes[1].set_title(f"{reduction.upper()} — channel/marker") + + sc = axes[2].scatter( + plot_df[x_col], + plot_df[y_col], + c=plot_df["mean_step_size"], + cmap="plasma", + alpha=0.5, + s=8, + vmin=vmin, + vmax=vmax, + ) + plt.colorbar(sc, ax=axes[2], label="Mean embedding step size") + axes[2].set_title(f"{reduction.upper()} — embedding displacement") + axes[2].set_xlabel(x_col) + axes[2].set_ylabel(y_col) + + plt.tight_layout() + out = output_dir / f"microglia_{reduction}.pdf" + plt.savefig(out, bbox_inches="tight") + plt.close() + print(f" Saved {out}") + + # Displacement by perturbation + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + 
order = sorted(track_metrics["perturbation"].unique()) if "perturbation" in track_metrics.columns else None + + sns.boxplot(data=track_metrics, x="perturbation", y="mean_step_size", ax=axes[0], order=order) + axes[0].set_title("Mean embedding step size by perturbation") + axes[0].set_ylabel("Mean step size in embedding space") + axes[0].tick_params(axis="x", rotation=30) + + sns.boxplot(data=track_metrics, x="perturbation", y="net_displacement", ax=axes[1], order=order) + axes[1].set_title("Net displacement (start→end) by perturbation") + axes[1].set_ylabel("Net displacement in embedding space") + axes[1].tick_params(axis="x", rotation=30) + + plt.tight_layout() + out = output_dir / "microglia_displacement_by_perturbation.pdf" + plt.savefig(out, bbox_inches="tight") + plt.close() + print(f" Saved {out}") + + summary = track_metrics.groupby("perturbation")[["mean_step_size", "net_displacement", "track_length"]].agg( + ["median", "mean", "std", "count"] + ) + print("\n## Microglia track displacement summary\n") + print(summary.to_markdown()) + + +def analyze_alfi(adata: ad.AnnData, output_dir: Path) -> None: + """Run ALFI HeLa cell cycle analysis and save plots.""" + print(f"\nALFI total: {adata.shape[0]:,} observations") + + # Filter to HeLa (MI06) + if "fov_name" in adata.obs.columns: + hela_mask = adata.obs["fov_name"] == "MI06" + elif "experiment" in adata.obs.columns: + hela_mask = adata.obs["experiment"].str.contains("HeLa") + else: + raise RuntimeError("Cannot identify HeLa cells — no fov_name or experiment column in obs") + + adata_hela = adata[hela_mask].copy() + print(f" HeLa (MI06): {adata_hela.shape[0]:,} observations") + + # Join annotations + annotations = pd.read_csv(ALFI_ANNOTATIONS) + ann_indexed = annotations.set_index(["fov_name", "track_id", "t"]) + + obs_hela = adata_hela.obs.copy() + mi = pd.MultiIndex.from_arrays( + [ + obs_hela["fov_name"], + obs_hela["track_id"].astype(int), + obs_hela["t"].astype(int), + ], + names=["fov_name", "track_id", 
"t"], + ) + obs_hela["cell_division_state"] = ann_indexed.reindex(mi)["cell_division_state"].values + obs_hela["cell_cycle_fine_state"] = ann_indexed.reindex(mi)["cell_cycle_fine_state"].values + + n_annotated = obs_hela["cell_division_state"].notna().sum() + print(f" Annotated: {n_annotated:,} / {len(obs_hela):,}") + print(obs_hela["cell_division_state"].value_counts().to_string()) + + features_hela = np.asarray(adata_hela.X) + features_scaled = StandardScaler().fit_transform(features_hela) + pca_emb = _get_or_compute_pca(adata_hela, features_scaled) + umap_emb = _get_or_compute_umap(adata_hela, features_scaled) + + unannotated = obs_hela["cell_division_state"].isna() + + for reduction, emb in [("pca", pca_emb), ("umap", umap_emb)]: + x_col, y_col = ("PC1", "PC2") if reduction == "pca" else ("UMAP1", "UMAP2") + + # Division state plot + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + for ax, fine in zip(axes, [False, True]): + col = "cell_cycle_fine_state" if fine else "cell_division_state" + states = obs_hela[col].dropna().unique() + if fine: + palette = dict(zip(sorted(states), sns.color_palette("tab10", len(states)))) + else: + palette = DIVISION_PALETTE + + for state, color in palette.items(): + mask = obs_hela[col] == state + ax.scatter( + emb[mask, 0], + emb[mask, 1], + c=color, + label=state, + alpha=0.6, + s=10, + linewidth=0, + ) + ax.scatter( + emb[unannotated, 0], + emb[unannotated, 1], + c="lightgray", + label="unannotated", + alpha=0.3, + s=6, + linewidth=0, + ) + title = "fine cell cycle state" if fine else "cell division state" + ax.set_title(f"HeLa {reduction.upper()} — {title}") + ax.set_xlabel(x_col) + ax.set_ylabel(y_col) + ax.legend(markerscale=2, bbox_to_anchor=(1, 1), loc="upper left", fontsize=8) + + plt.tight_layout() + out = output_dir / f"alfi_hela_{reduction}_cell_cycle.pdf" + plt.savefig(out, bbox_inches="tight") + plt.close() + print(f" Saved {out}") + + # Displacement by cell cycle state + track_metrics = 
compute_track_displacement_metrics(adata_hela) + + track_annotations = ( + annotations[annotations["fov_name"] == "MI06"] + .groupby(["fov_name", "track_id"])["cell_division_state"] + .agg(lambda x: x.dropna().mode().iloc[0] if x.dropna().shape[0] > 0 else pd.NA) + .reset_index() + .rename(columns={"cell_division_state": "dominant_state"}) + ) + track_metrics = track_metrics.merge(track_annotations, on=["fov_name", "track_id"], how="left") + + annotated = track_metrics.dropna(subset=["dominant_state"]) + if len(annotated) > 0: + fig, ax = plt.subplots(figsize=(6, 5)) + sns.boxplot( + data=annotated, + x="dominant_state", + y="mean_step_size", + palette=DIVISION_PALETTE, + ax=ax, + order=[s for s in DIVISION_PALETTE if s in annotated["dominant_state"].unique()], + ) + ax.set_title("HeLa: embedding step size by cell cycle state") + ax.set_xlabel("Dominant cell division state (per track)") + ax.set_ylabel("Mean step size in embedding space") + plt.tight_layout() + out = output_dir / "alfi_hela_displacement_by_state.pdf" + plt.savefig(out, bbox_inches="tight") + plt.close() + print(f" Saved {out}") + + summary = annotated.groupby("dominant_state")["mean_step_size"].describe() + print("\n## ALFI HeLa displacement by state\n") + print(summary.to_markdown()) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--microglia-embeddings", type=Path, required=True, help="AnnData zarr from microglia inference" + ) + parser.add_argument("--alfi-embeddings", type=Path, required=True, help="AnnData zarr from ALFI inference") + parser.add_argument("--output-dir", type=Path, required=True, help="Directory to save PDF figures") + args = parser.parse_args() + + args.output_dir.mkdir(parents=True, exist_ok=True) + + print("=== Microglia analysis ===") + adata_micro = ad.read_zarr(args.microglia_embeddings) + analyze_microglia(adata_micro, args.output_dir) + + print("\n=== ALFI 
analysis ===") + adata_alfi = ad.read_zarr(args.alfi_embeddings) + analyze_alfi(adata_alfi, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/README.md b/applications/dynaclr/scripts/profiling/README.md new file mode 100644 index 000000000..b1b44a730 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/README.md @@ -0,0 +1,40 @@ +# DynaCLR I/O profiling scripts + +Scripts that validate data-loading performance on VAST/NFS for the DynaCLR +contrastive training pipeline. + +## Current scripts + +### `benchmark_recheck_cached_data.py` + +Measures the effect of `TensorStoreConfig.recheck_cached_data` on NFS read +latency for the DynaCLR contrastive read pattern. Exercises the iohub +tensorstore implementation directly (no training stack involved) so it can +be run **before** the dynaclr datamodule is ported to iohub 0.3.x. + +**Prerequisite.** Requires an iohub build with the upstream +`recheck_cached_data` knob on `TensorStoreConfig`. Until that lands, either +install iohub from the feature branch locally, or skip this script. + +Run: + +``` +uv run python applications/dynaclr/scripts/profiling/benchmark_recheck_cached_data.py +``` + +Output is a markdown table comparing median/p95 batch latency, patches/s, +and MiB/s across three configurations (`none`, `"open"`, `false`). Run +twice back-to-back and compare: if the `none` vs `"open"` gap shrinks on +the second run, the Linux NFS client page cache is masking the +per-chunk revalidation cost on this node. + +## Planned follow-ups (after iohub 0.3.x merge into dynadtw) + +- **Dataset-level A/B** — same configurations, but driven through + `MultiExperimentDataModule` + `MultiExperimentTripletDataset` so we + exercise `_get_position`/`_get_tensorstore`/`_slice_patches` and the + `ts.stack(...).read().result()` batched read path exactly as training + does. 
+- **SLURM DDP A/B** — 200-step fastdev runs with Lightning's + `SimpleProfiler`, comparing `data_time`/`batch_time` and GETATTR/s + from `nfsiostat` across ranks. diff --git a/applications/dynaclr/scripts/profiling/benchmark_boc2d_real.py b/applications/dynaclr/scripts/profiling/benchmark_boc2d_real.py new file mode 100644 index 000000000..206ddcf2e --- /dev/null +++ b/applications/dynaclr/scripts/profiling/benchmark_boc2d_real.py @@ -0,0 +1,264 @@ +"""Production-config DataLoader benchmark + batch-composition sanity check. + +Exercises the real +``DynaCLR-2D-MIP-BagOfChannels.yml`` training settings against the +committed v2 parquet to measure end-to-end DataLoader throughput and +verify that batch-grouping/stratification actually do what the config +says. + +Two parts +--------- + +**1. Composition check** — forces ``batch_group_by="marker"`` and checks +the first 20 batches: + +- every batch contains exactly one marker (single-marker batches), +- different batches surface different markers (proves the grouping is + shuffled across the epoch, not stuck on one value). + +**2. Throughput A/B** — runs the production config +(``batch_size=256``, ``channels_per_sample=1``, ``stratify_by=[perturbation, marker]``, +``num_workers=2``) under two ``recheck_cached_data`` settings: + +- ``None`` — TensorStore driver default. +- ``"open"`` — validate at open only (our merge's default). + +Reports median/p95 per-iter latency, iter/s, samples/s for each leg. +Because this runs on the real VAST-resident parquet with 7k+ FOVs, the +FOV-open amortisation is representative of real training. 
+ +Usage +----- + uv run python applications/dynaclr/scripts/profiling/benchmark_boc2d_real.py +""" + +from __future__ import annotations + +import statistics +import time +from dataclasses import dataclass + +import numpy as np + +from dynaclr.data.datamodule import MultiExperimentDataModule +from viscy_transforms import ( + BatchedChannelWiseZReductiond, + BatchedRandSpatialCropd, + NormalizeSampled, +) + +CELL_INDEX_PARQUET = "/hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-MIP-BagOfChannels-v2.parquet" + +BATCH_SIZE = 256 +NUM_WORKERS = 2 +WARMUP_BATCHES = 10 +N_BATCHES = 60 +SEED = 42 + +Z_WINDOW = 1 +Z_EXTRACTION_WINDOW = 20 +Z_FOCUS_OFFSET = 0.3 +YX_PATCH_SIZE = (256, 256) +FINAL_YX_PATCH_SIZE = (160, 160) + +COMPOSITION_BATCHES = 20 + +RECHECK_LEGS: list[tuple[str, str | bool | None]] = [ + ("None (driver default)", None), + ("open (our default)", "open"), +] + + +@dataclass +class LegResult: + """Timing outcome for one recheck_cached_data leg on the real parquet.""" + + label: str + iter_latencies_s: list[float] + total_s: float + + @property + def median_ms(self) -> float: + """Return median per-iter latency in milliseconds.""" + return statistics.median(self.iter_latencies_s) * 1000.0 + + @property + def p95_ms(self) -> float: + """Return p95 per-iter latency in milliseconds.""" + return float(np.percentile(self.iter_latencies_s, 95)) * 1000.0 + + @property + def iter_per_s(self) -> float: + """Return sustained iterations per second.""" + return len(self.iter_latencies_s) / self.total_s + + @property + def samples_per_s(self) -> float: + """Return sustained samples per second.""" + return self.iter_per_s * BATCH_SIZE + + +def _build_production_dm( + recheck_cached_data: str | bool | None, + batch_group_by: str | list[str] | None = None, + stratify_by: list[str] | None = None, + num_workers: int = NUM_WORKERS, +) -> MultiExperimentDataModule: + """Build a DataModule matching the production 2D-MIP-BoC training recipe.""" + normalizations 
= [ + NormalizeSampled( + keys=["channel_0"], + level="timepoint_statistics", + subtrahend="mean", + divisor="std", + ), + ] + augmentations = [ + BatchedRandSpatialCropd(keys=["channel_0"], roi_size=(10, 192, 192)), + BatchedChannelWiseZReductiond(keys=["channel_0"], allow_missing_keys=True), + ] + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + focus_channel="Phase3D", + reference_pixel_size_xy_um=0.1494, + z_window=Z_WINDOW, + z_extraction_window=Z_EXTRACTION_WINDOW, + z_focus_offset=Z_FOCUS_OFFSET, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + channels_per_sample=1, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + positive_channel_source="same", + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + batch_group_by=batch_group_by, + stratify_by=stratify_by if stratify_by is not None else ["perturbation", "marker"], + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=num_workers, + seed=SEED, + normalizations=normalizations, + augmentations=augmentations, + ) + dm.tensorstore_config = dm.tensorstore_config.model_copy(update={"recheck_cached_data": recheck_cached_data}) + return dm + + +def _composition_check() -> None: + """Verify batch_group_by='marker' yields single-marker, shuffled batches.""" + print("=" * 72) + print("Composition check: batch_group_by='marker'") + print("=" * 72) + + dm = _build_production_dm( + recheck_cached_data="open", + batch_group_by="marker", + stratify_by=None, + num_workers=0, + ) + dm.setup("fit") + loader = dm.train_dataloader() + it = iter(loader) + + markers_by_batch: list[set[str]] = [] + for i in range(COMPOSITION_BATCHES): + batch = next(it) + metas = batch["anchor_meta"] + batch_markers = {m["marker"] for m in metas} + markers_by_batch.append(batch_markers) + print(f" batch {i:>2}: {len(batch_markers)} unique markers → {sorted(batch_markers)[:4]}") + + non_singleton = [i for i, ms in enumerate(markers_by_batch) if len(ms) != 1] + if 
non_singleton: + print(f"\n FAIL: {len(non_singleton)} of {COMPOSITION_BATCHES} batches had >1 marker") + print(f" offending batches: {non_singleton}") + raise AssertionError("batch_group_by='marker' did not produce single-marker batches") + + unique_markers_seen = set().union(*markers_by_batch) + print(f"\n PASS: all {COMPOSITION_BATCHES} batches are single-marker") + print(f" markers touched across the {COMPOSITION_BATCHES} batches: {len(unique_markers_seen)}") + print(f" → {sorted(unique_markers_seen)}") + + if len(unique_markers_seen) < 2: + print("\n WARNING: only 1 marker touched across all batches — epoch may be stuck on one group") + else: + print(" → grouping is shuffled across markers (good)") + + del it + del loader + + +def _run_throughput_leg(label: str, recheck_cached_data: str | bool | None) -> LegResult: + """Run one throughput leg on the production config.""" + print(f"\n-- Throughput leg: recheck_cached_data = {label} --") + dm = _build_production_dm( + recheck_cached_data=recheck_cached_data, + batch_group_by=None, + stratify_by=["perturbation", "marker"], + num_workers=NUM_WORKERS, + ) + dm.setup("fit") + loader = dm.train_dataloader() + it = iter(loader) + + for _ in range(WARMUP_BATCHES): + _ = next(it) + + latencies_s: list[float] = [] + t_total = time.perf_counter() + t_prev = time.perf_counter() + for _ in range(N_BATCHES): + _ = next(it) + t_now = time.perf_counter() + latencies_s.append(t_now - t_prev) + t_prev = t_now + total_s = time.perf_counter() - t_total + + del it + del loader + + result = LegResult(label=label, iter_latencies_s=latencies_s, total_s=total_s) + print( + f" median {result.median_ms:.1f} ms | p95 {result.p95_ms:.1f} ms | " + f"{result.iter_per_s:.2f} iter/s | {result.samples_per_s:.1f} samples/s" + ) + return result + + +def _print_markdown(results: list[LegResult]) -> None: + """Emit a markdown-formatted throughput table.""" + print() + print("## Throughput (real 2D-MIP-BoC v2 parquet)") + print() + print(f"- 
Parquet: `{CELL_INDEX_PARQUET.split('/')[-1]}`") + print(f"- Batch size: {BATCH_SIZE}, num_workers: {NUM_WORKERS}") + print(f"- Warmup: {WARMUP_BATCHES} batches; timed: {N_BATCHES} batches") + print(f"- Z_extraction={Z_EXTRACTION_WINDOW}, YX={YX_PATCH_SIZE}, final_YX={FINAL_YX_PATCH_SIZE}") + print("- channels_per_sample=1, stratify_by=[perturbation, marker]") + print() + print("| recheck_cached_data | median ms | p95 ms | iter/s | samples/s |") + print("|---|---:|---:|---:|---:|") + for r in results: + print(f"| {r.label} | {r.median_ms:.1f} | {r.p95_ms:.1f} | {r.iter_per_s:.2f} | {r.samples_per_s:.1f} |") + print() + + +def main() -> None: + """Run composition check, then the throughput A/B, and print a summary.""" + _composition_check() + + print() + print("=" * 72) + print("Throughput A/B: production config, real parquet") + print("=" * 72) + + results: list[LegResult] = [] + for label, value in RECHECK_LEGS: + results.append(_run_throughput_leg(label, value)) + + _print_markdown(results) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/benchmark_dataloader_recheck.py b/applications/dynaclr/scripts/profiling/benchmark_dataloader_recheck.py new file mode 100644 index 000000000..bd8626fad --- /dev/null +++ b/applications/dynaclr/scripts/profiling/benchmark_dataloader_recheck.py @@ -0,0 +1,198 @@ +"""Full-pipeline A/B benchmark for TensorStoreConfig.recheck_cached_data. + +Drives :class:`dynaclr.data.datamodule.MultiExperimentDataModule` +end-to-end — ``__getitems__`` + ``collate_fn=lambda x:x`` + +PyTorch DataLoader with ``num_workers`` forked workers — to measure the +effect of ``recheck_cached_data`` on sustained training-loader +throughput, the only number that actually matters for GPU utilization. 
+ +Three legs are compared against the same parquet, in the same process, +with the same FOVs and the same seed so sampling is deterministic: + +- ``"open"`` — validate at open only, trust cache thereafter (our + expected production setting). +- ``None`` — driver default, revalidate cached chunk metadata every + read (one stat/GETATTR per chunk per read on NFS). +- ``False`` — never revalidate (included for completeness). + +Per leg the script: + +1. Constructs a fresh ``MultiExperimentDataModule``, forcibly overriding + ``self.tensorstore_config.recheck_cached_data`` after ``__init__`` so + every Plate opens with the configured setting. +2. Runs ``setup("fit")`` once. +3. Warms the DataLoader with ``WARMUP_BATCHES`` batches (discarded). +4. Times ``N_BATCHES`` steady-state batches by wall-clocking the + iterator yield interval — this is what the training loop sees. +5. Reports median/p95 iteration time and steady-state iter/s. + +Because we use forked DataLoader workers, each config opens its own +Plates inside the worker after fork — matching real DDP training. + +Usage +----- + uv run python applications/dynaclr/scripts/profiling/benchmark_dataloader_recheck.py + +Requires: + +- iohub with ``recheck_cached_data`` on ``TensorStoreConfig`` + (czbiohub-sf/iohub#406 or later). +- A parquet whose ``store_path`` entries are readable on this node. 
+""" + +from __future__ import annotations + +import statistics +import time +from dataclasses import dataclass + +import numpy as np + +from dynaclr.data.datamodule import MultiExperimentDataModule + +CELL_INDEX_PARQUET = "/home/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/benchmark_2exp.parquet" + +BATCH_SIZE = 32 +NUM_WORKERS = 4 +WARMUP_BATCHES = 10 +N_BATCHES = 100 +SEED = 42 + +Z_WINDOW = 8 +YX_PATCH_SIZE = (192, 192) +FINAL_YX_PATCH_SIZE = (160, 160) + +LEGS: list[tuple[str, str | bool | None]] = [ + ("open (recommended)", "open"), + ("None (driver default)", None), + ("False (never revalidate)", False), +] + + +@dataclass +class LegResult: + """Timing outcome for one recheck_cached_data leg.""" + + label: str + iter_latencies_s: list[float] + total_s: float + + @property + def median_ms(self) -> float: + """Return the median inter-batch iteration time in milliseconds.""" + return statistics.median(self.iter_latencies_s) * 1000.0 + + @property + def p95_ms(self) -> float: + """Return the p95 inter-batch iteration time in milliseconds.""" + return float(np.percentile(self.iter_latencies_s, 95)) * 1000.0 + + @property + def iter_per_s(self) -> float: + """Return steady-state iterations per second.""" + return len(self.iter_latencies_s) / self.total_s + + @property + def samples_per_s(self) -> float: + """Return steady-state samples per second.""" + return self.iter_per_s * BATCH_SIZE + + +def _build_datamodule(recheck_cached_data: str | bool | None) -> MultiExperimentDataModule: + """Construct a DataModule and force the recheck_cached_data leg onto its config.""" + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + z_window=Z_WINDOW, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + channels_per_sample=None, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=None, + split_ratio=0.8, + batch_size=BATCH_SIZE, + 
num_workers=NUM_WORKERS, + seed=SEED, + normalizations=[], + augmentations=[], + ) + # The datamodule sets recheck_cached_data="open" by default; override + # it here so every leg can dial the knob independently without editing + # the production code path. + dm.tensorstore_config = dm.tensorstore_config.model_copy(update={"recheck_cached_data": recheck_cached_data}) + return dm + + +def _run_leg(label: str, recheck_cached_data: str | bool | None) -> LegResult: + """Run one A/B leg and return a populated LegResult.""" + print(f"\n-- Leg: recheck_cached_data = {label} --") + dm = _build_datamodule(recheck_cached_data) + dm.setup("fit") + loader = dm.train_dataloader() + + it = iter(loader) + + # Warmup — discard. Forks workers, populates each worker's plate/ts + # caches, amortises Python import cost in the forked child. + for _ in range(WARMUP_BATCHES): + _ = next(it) + + # Steady-state timing. We measure the inter-batch yield interval, + # which is exactly what the training loop observes. + latencies_s: list[float] = [] + t_total = time.perf_counter() + t_prev = time.perf_counter() + for _ in range(N_BATCHES): + _ = next(it) + t_now = time.perf_counter() + latencies_s.append(t_now - t_prev) + t_prev = t_now + total_s = time.perf_counter() - t_total + + # Release workers before the next leg so forked processes do not + # pile up and compete for file descriptors. 
+ del it + del loader + + result = LegResult(label=label, iter_latencies_s=latencies_s, total_s=total_s) + print( + f" median {result.median_ms:.1f} ms | p95 {result.p95_ms:.1f} ms | " + f"{result.iter_per_s:.2f} iter/s | {result.samples_per_s:.1f} samples/s" + ) + return result + + +def _print_markdown(results: list[LegResult]) -> None: + """Emit a markdown-formatted summary for the PR / Confluence.""" + print() + print("## Results (dataloader-level A/B)") + print() + print(f"- Parquet: `{CELL_INDEX_PARQUET.split('/')[-1]}`") + print(f"- Batch size: {BATCH_SIZE}, num_workers: {NUM_WORKERS}") + print(f"- Warmup: {WARMUP_BATCHES} batches; timed: {N_BATCHES} batches") + print(f"- Z={Z_WINDOW}, YX={YX_PATCH_SIZE}, final_YX={FINAL_YX_PATCH_SIZE}") + print() + print("| recheck_cached_data | median ms | p95 ms | iter/s | samples/s |") + print("|---|---:|---:|---:|---:|") + for r in results: + print(f"| {r.label} | {r.median_ms:.1f} | {r.p95_ms:.1f} | {r.iter_per_s:.2f} | {r.samples_per_s:.1f} |") + print() + + +def main() -> None: + """Run all three legs and print a combined markdown summary.""" + print("=" * 72) + print("Dataloader-level recheck_cached_data benchmark — MultiExperimentDataModule") + print("=" * 72) + + results: list[LegResult] = [] + for label, value in LEGS: + results.append(_run_leg(label, value)) + + _print_markdown(results) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/benchmark_dataloader_workers_sweep.py b/applications/dynaclr/scripts/profiling/benchmark_dataloader_workers_sweep.py new file mode 100644 index 000000000..6ff16e14e --- /dev/null +++ b/applications/dynaclr/scripts/profiling/benchmark_dataloader_workers_sweep.py @@ -0,0 +1,188 @@ +"""Sweep num_workers × recheck_cached_data for the DynaCLR dataloader. 
+ +Purpose +------- + +The first pass A/B (``benchmark_dataloader_recheck.py``) showed a counter- +intuitive result on ``MultiExperimentDataModule.train_dataloader()`` with +``num_workers=4``: ``recheck_cached_data="open"`` was slower than the +driver default. The raw ``ts.stack`` benchmark showed the opposite. Most +likely the p95 tails were dominated by first-touch FOV opens while the +ThreadDataLoader prefetch buffer masked them differently per leg. + +This sweep pins down the cause by running every ``recheck_cached_data`` +value across several ``num_workers`` settings with generous warmup, so we +can tell: + +- Does the ordering flip between ``num_workers=0`` (no fork, no thread + buffer) and ``num_workers>0`` (forked workers)? +- Is the ``"open"`` penalty paid only on cold FOV opens? If yes, longer + warmup should close the gap. +- Does the ``p95`` converge once steady-state is reached? + +Usage +----- + uv run python applications/dynaclr/scripts/profiling/benchmark_dataloader_workers_sweep.py +""" + +from __future__ import annotations + +import statistics +import time +from dataclasses import dataclass + +import numpy as np + +from dynaclr.data.datamodule import MultiExperimentDataModule + +CELL_INDEX_PARQUET = "/hpc/projects/organelle_phenotyping/models/collections/DynaCLR-2D-MIP-BagOfChannels-v2.parquet" + +BATCH_SIZE = 256 +WARMUP_BATCHES = 10 +N_BATCHES = 40 +SEED = 42 + +Z_WINDOW = 1 +Z_EXTRACTION_WINDOW = 20 +YX_PATCH_SIZE = (256, 256) +FINAL_YX_PATCH_SIZE = (160, 160) + +WORKER_COUNTS: list[int] = [0, 2, 4, 8] +RECHECK_VALUES: list[tuple[str, str | bool | None]] = [ + ("None", None), + ("open", "open"), + ("False", False), +] + + +@dataclass +class SweepResult: + """One cell of the ``num_workers`` × ``recheck_cached_data`` grid.""" + + num_workers: int + recheck_label: str + iter_latencies_s: list[float] + total_s: float + + @property + def median_ms(self) -> float: + """Return median per-iter latency in milliseconds.""" + return 
statistics.median(self.iter_latencies_s) * 1000.0 + + @property + def p95_ms(self) -> float: + """Return p95 per-iter latency in milliseconds.""" + return float(np.percentile(self.iter_latencies_s, 95)) * 1000.0 + + @property + def iter_per_s(self) -> float: + """Return sustained iterations per second across timed batches.""" + return len(self.iter_latencies_s) / self.total_s + + @property + def samples_per_s(self) -> float: + """Return sustained samples per second (iter/s × batch).""" + return self.iter_per_s * BATCH_SIZE + + +def _build(num_workers: int, recheck_cached_data: str | bool | None) -> MultiExperimentDataModule: + """Build one datamodule with forced num_workers and recheck_cached_data.""" + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + z_window=Z_WINDOW, + z_extraction_window=Z_EXTRACTION_WINDOW, + z_focus_offset=0.3, + yx_patch_size=YX_PATCH_SIZE, + final_yx_patch_size=FINAL_YX_PATCH_SIZE, + channels_per_sample=1, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=["perturbation", "marker"], + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=num_workers, + seed=SEED, + focus_channel="Phase3D", + reference_pixel_size_xy_um=0.1494, + normalizations=[], + augmentations=[], + ) + dm.tensorstore_config = dm.tensorstore_config.model_copy(update={"recheck_cached_data": recheck_cached_data}) + return dm + + +def _run_cell(num_workers: int, label: str, recheck_cached_data: str | bool | None) -> SweepResult: + """Run one cell of the sweep.""" + print(f"\n-- num_workers={num_workers}, recheck_cached_data={label} --") + dm = _build(num_workers, recheck_cached_data) + dm.setup("fit") + loader = dm.train_dataloader() + it = iter(loader) + + for _ in range(WARMUP_BATCHES): + _ = next(it) + + latencies_s: list[float] = [] + t_total = time.perf_counter() + t_prev = time.perf_counter() + for _ in range(N_BATCHES): + _ = next(it) + t_now = 
time.perf_counter() + latencies_s.append(t_now - t_prev) + t_prev = t_now + total_s = time.perf_counter() - t_total + + del it + del loader + + result = SweepResult( + num_workers=num_workers, + recheck_label=label, + iter_latencies_s=latencies_s, + total_s=total_s, + ) + print( + f" median {result.median_ms:.1f} ms | p95 {result.p95_ms:.1f} ms | " + f"{result.iter_per_s:.2f} iter/s | {result.samples_per_s:.1f} samples/s" + ) + return result + + +def _print_markdown(results: list[SweepResult]) -> None: + """Emit a markdown-formatted sweep table for the PR / Confluence.""" + print() + print("## Sweep results") + print() + print(f"- Parquet: `{CELL_INDEX_PARQUET.split('/')[-1]}`") + print(f"- Batch size: {BATCH_SIZE}, warmup: {WARMUP_BATCHES}, timed: {N_BATCHES}") + print(f"- Z={Z_WINDOW}, YX={YX_PATCH_SIZE}, final_YX={FINAL_YX_PATCH_SIZE}") + print() + print("| num_workers | recheck | median ms | p95 ms | iter/s | samples/s |") + print("|---:|---|---:|---:|---:|---:|") + for r in results: + print( + f"| {r.num_workers} | {r.recheck_label} | " + f"{r.median_ms:.1f} | {r.p95_ms:.1f} | " + f"{r.iter_per_s:.2f} | {r.samples_per_s:.1f} |" + ) + print() + + +def main() -> None: + """Run the full sweep and print a combined markdown summary.""" + print("=" * 72) + print("num_workers × recheck_cached_data sweep — MultiExperimentDataModule") + print("=" * 72) + + results: list[SweepResult] = [] + for nw in WORKER_COUNTS: + for label, value in RECHECK_VALUES: + results.append(_run_cell(nw, label, value)) + + _print_markdown(results) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/benchmark_recheck_cached_data.py b/applications/dynaclr/scripts/profiling/benchmark_recheck_cached_data.py new file mode 100644 index 000000000..98f503a11 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/benchmark_recheck_cached_data.py @@ -0,0 +1,215 @@ +"""Measure the impact of ``TensorStoreConfig.recheck_cached_data`` on NFS reads. 
+ +Single-process raw ``ts.stack(...).read().result()`` loop against a +2-experiment parquet for three TensorStoreConfig settings: + +- ``none`` — driver default, revalidate on every read (one stat/GETATTR + per chunk per read). +- ``open`` — validate only at open time, trust the cache thereafter. +- ``false`` — never revalidate. + +The loop issues ``N_BATCHES`` batches of stacked 3D crops sampled from +random FOVs, reports median/p95 read latency and sustained patches/s. +For the DataLoader-driven end-to-end view see +``benchmark_dataloader_workers_sweep.py``. + +Usage +----- + uv run python applications/dynaclr/scripts/profiling/benchmark_recheck_cached_data.py +""" + +from __future__ import annotations + +import statistics +import time +from dataclasses import dataclass +from typing import Any + +import numpy as np +import pandas as pd +import tensorstore as ts +from iohub import open_ome_zarr +from iohub.core.config import TensorStoreConfig + +CELL_INDEX_PARQUET = "/home/eduardo.hirata/repos/viscy/applications/dynaclr/configs/cell_index/benchmark_2exp.parquet" + +BATCH_SIZE = 32 +N_BATCHES = 50 +PATCH_Z = 8 +PATCH_YX = (192, 192) +SEED = 0 + +DATA_COPY_CONCURRENCY = 16 +FILE_IO_CONCURRENCY = 64 +CACHE_POOL_BYTES: int | None = None + +CONFIGS: list[tuple[str, dict[str, Any]]] = [ + ("none (driver default)", {}), + ("open", {"recheck_cached_data": "open"}), + ("false", {"recheck_cached_data": False}), +] + + +@dataclass +class Result: + """Timing results for one ``recheck_cached_data`` configuration.""" + + label: str + batch_latencies_ms: list[float] + total_bytes: int + total_s: float + + @property + def median_ms(self) -> float: + """Return the median per-batch read latency in milliseconds.""" + return statistics.median(self.batch_latencies_ms) + + @property + def p95_ms(self) -> float: + """Return the p95 per-batch read latency in milliseconds.""" + return float(np.percentile(self.batch_latencies_ms, 95)) + + @property + def patches_per_s(self) -> float: + 
"""Return the sustained patch-read throughput.""" + return BATCH_SIZE * len(self.batch_latencies_ms) / self.total_s + + @property + def mib_per_s(self) -> float: + """Return the sustained read throughput in MiB/s.""" + return (self.total_bytes / (1024 * 1024)) / self.total_s + + +def _load_fov_index() -> pd.DataFrame: + """Return unique (store_path, well, fov, shape) rows from the benchmark parquet.""" + df = pd.read_parquet(CELL_INDEX_PARQUET) + unique = df[["store_path", "well", "fov", "C_shape", "Z_shape", "Y_shape", "X_shape"]].drop_duplicates( + subset=["store_path", "well", "fov"] + ) + return unique.reset_index(drop=True) + + +def _open_stores(fov_df: pd.DataFrame, ts_config: TensorStoreConfig) -> dict[str, Any]: + """Open each unique zarr store once with the given TensorStoreConfig.""" + store_paths = fov_df["store_path"].drop_duplicates().tolist() + plates: dict[str, Any] = {} + for sp in store_paths: + plates[sp] = open_ome_zarr( + sp, + mode="r", + implementation="tensorstore", + implementation_config=ts_config, + ) + return plates + + +def _sample_patches( + fov_df: pd.DataFrame, + plates: dict[str, Any], + batch_size: int, + rng: np.random.Generator, +) -> tuple[list[ts.TensorStore], int]: + """Pick ``batch_size`` random (fov, z, y, x) crops and return lazy slices + byte count. + + Returns a list of tensorstore lazy slices (one per crop) plus the + total number of bytes the resulting stacked read will pull. 
+ """ + lazies: list[ts.TensorStore] = [] + total_bytes = 0 + rows = fov_df.sample(n=batch_size, replace=True, random_state=rng.integers(0, 2**31 - 1)) + for _, row in rows.iterrows(): + plate = plates[row["store_path"]] + position_path = f"{row['well']}/{row['fov']}" + arr = plate[position_path]["0"].native + z_start = int(rng.integers(0, max(1, row["Z_shape"] - PATCH_Z + 1))) + y_start = int(rng.integers(0, max(1, row["Y_shape"] - PATCH_YX[0] + 1))) + x_start = int(rng.integers(0, max(1, row["X_shape"] - PATCH_YX[1] + 1))) + lazy = arr[ + 0, # t=0 — keep indexing simple; timepoint is not what we're benchmarking + :, + z_start : z_start + PATCH_Z, + y_start : y_start + PATCH_YX[0], + x_start : x_start + PATCH_YX[1], + ] + lazies.append(lazy) + total_bytes += PATCH_Z * PATCH_YX[0] * PATCH_YX[1] * row["C_shape"] * 4 # assume float32 + return lazies, total_bytes + + +def _run_one_config(label: str, extra_cfg: dict[str, Any], fov_df: pd.DataFrame) -> Result: + """Run the read-loop benchmark for one recheck_cached_data setting.""" + ts_config = TensorStoreConfig( + data_copy_concurrency=DATA_COPY_CONCURRENCY, + file_io_concurrency=FILE_IO_CONCURRENCY, + cache_pool_bytes=CACHE_POOL_BYTES, + **extra_cfg, + ) + plates = _open_stores(fov_df, ts_config) + + def _translate_all(lazies: list[ts.TensorStore]) -> list[ts.TensorStore]: + """Translate each lazy slice to origin so ts.stack can combine them.""" + return [p.translate_to[0] for p in lazies] # noqa: PD013 + + rng_warm = np.random.default_rng(SEED) + warm_lazies, _ = _sample_patches(fov_df, plates, BATCH_SIZE, rng_warm) + _ = ts.stack(_translate_all(warm_lazies)).read().result() + + rng = np.random.default_rng(SEED + 1) + latencies_ms: list[float] = [] + total_bytes = 0 + t_total = time.perf_counter() + for _ in range(N_BATCHES): + lazies, batch_bytes = _sample_patches(fov_df, plates, BATCH_SIZE, rng) + t0 = time.perf_counter() + _ = ts.stack(_translate_all(lazies)).read().result() + 
latencies_ms.append((time.perf_counter() - t0) * 1000.0) + total_bytes += batch_bytes + total_s = time.perf_counter() - t_total + + for plate in plates.values(): + plate.close() + + return Result(label=label, batch_latencies_ms=latencies_ms, total_bytes=total_bytes, total_s=total_s) + + +def _print_markdown_table(results: list[Result]) -> None: + """Print a markdown-formatted results table suitable for Confluence/PR pasting.""" + print() + print("## Results") + print() + print(f"- Parquet: `{CELL_INDEX_PARQUET.split('/')[-1]}`") + print(f"- Batch size: {BATCH_SIZE}, N batches: {N_BATCHES}") + print(f"- Patch shape: (C, Z={PATCH_Z}, Y={PATCH_YX[0]}, X={PATCH_YX[1]})") + print(f"- data_copy_concurrency={DATA_COPY_CONCURRENCY}, file_io_concurrency={FILE_IO_CONCURRENCY}") + print() + print("| recheck_cached_data | median ms | p95 ms | patches/s | MiB/s | total s |") + print("|---|---:|---:|---:|---:|---:|") + for r in results: + print( + f"| {r.label} | {r.median_ms:.1f} | {r.p95_ms:.1f} | " + f"{r.patches_per_s:.1f} | {r.mib_per_s:.1f} | {r.total_s:.2f} |" + ) + print() + + +def main() -> None: + """Run the three configurations back-to-back and print a markdown summary.""" + print("=" * 72) + print("recheck_cached_data benchmark — DynaCLR contrastive read pattern on VAST") + print("=" * 72) + + fov_df = _load_fov_index() + print(f"Loaded {len(fov_df)} unique FOVs across {fov_df['store_path'].nunique()} stores") + + results: list[Result] = [] + for label, extra_cfg in CONFIGS: + print(f"\n-- Running: recheck_cached_data = {label} --") + r = _run_one_config(label, extra_cfg, fov_df) + print(f" median {r.median_ms:.1f} ms | p95 {r.p95_ms:.1f} ms | {r.patches_per_s:.1f} patches/s") + results.append(r) + + _print_markdown_table(results) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/profile_dataloaders.py b/applications/dynaclr/scripts/profiling/profile_dataloaders.py new file mode 100644 index 000000000..956fb201d --- 
/dev/null
+++ b/applications/dynaclr/scripts/profiling/profile_dataloaders.py
@@ -0,0 +1,371 @@
+"""Profile BatchedConcatDataModule + TripletDatasets vs MultiExperimentDataModule.
+
+Benchmarks setup time, raw __getitems__ latency, and full dataloader
+throughput for:
+- Old: BatchedConcatDataModule wrapping 2 TripletDataModules (one per experiment)
+- New: Single MultiExperimentDataModule with flat parquet index
+
+Uses two real datasets:
+- 2025_07_24 G3BP1 (stress granules)
+- 2025_04_15 H2B (chromatin)
+
+Usage
+-----
+    uv run python applications/dynaclr/scripts/profiling/profile_dataloaders.py
+"""
+
+from __future__ import annotations
+
+import time
+
+import numpy as np
+import pandas as pd
+import torch
+
+# ---------------------------------------------------------------------------
+# Dataset paths
+# ---------------------------------------------------------------------------
+
+COLLECTION_YAML = "applications/dynaclr/configs/collections/benchmark_2exp.yml"
+CELL_INDEX_PARQUET = "applications/dynaclr/configs/cell_index/benchmark_2exp.parquet"
+
+DATASETS = {
+    "G3BP1": {
+        "data_path": (
+            "/hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV"
+            "/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr"
+        ),
+        "tracks_path": (
+            "/hpc/projects/organelle_phenotyping/datasets/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/tracking.zarr"
+        ),
+        "source_channel": ["raw GFP EX488 EM525-45"],
+        "include_wells": ["C/1", "C/2"],
+    },
+    "H2B": {
+        "data_path": (
+            "/hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV"
+            "/2025_04_15_A549_H2B_CAAX_ZIKV_DENV.zarr"
+        ),
+        "tracks_path": (
+            "/hpc/projects/organelle_phenotyping/datasets/2025_04_15_A549_H2B_CAAX_ZIKV_DENV/tracking.zarr"
+        ),
+        "source_channel": ["raw Cy5 EX639 EM698-70"],
+        "include_wells": ["B/1", "B/2"],
+    },
+}
+
+# Shared benchmark parameters
+BATCH_SIZES = [8, 32, 64, 128]
+N_BATCHES = 20
+WARMUP_BATCHES = 3
+CACHE_POOL_BYTES = 500_000_000  # 500
MB +Z_RANGE = (30, 46) # 16 z-slices, 3D benchmark + + +def _fmt(seconds: float) -> str: + if seconds < 1: + return f"{seconds * 1000:.1f} ms" + return f"{seconds:.2f} s" + + +# ====================================================================== +# Old: BatchedConcatDataModule wrapping 2 TripletDataModules +# ====================================================================== + + +def setup_old(): + """Set up legacy BatchedConcatDataModule with 2 TripletDataModules.""" + from viscy_data.combined import BatchedConcatDataModule + from viscy_data.triplet import TripletDataModule + + dms = [] + for name, cfg in DATASETS.items(): + dm = TripletDataModule( + data_path=cfg["data_path"], + tracks_path=cfg["tracks_path"], + source_channel=cfg["source_channel"], + z_range=Z_RANGE, + initial_yx_patch_size=(192, 192), + final_yx_patch_size=(160, 160), + split_ratio=0.8, + batch_size=BATCH_SIZES[-1], + num_workers=1, + time_interval=3, + return_negative=False, + cache_pool_bytes=CACHE_POOL_BYTES, + fit_include_wells=cfg["include_wells"], + ) + dms.append(dm) + print(f" Created TripletDataModule for {name}") + + concat_dm = BatchedConcatDataModule(data_modules=dms) + concat_dm.setup("fit") + return concat_dm + + +# ====================================================================== +# New: MultiExperimentDataModule +# ====================================================================== + + +def setup_new(): + """Set up MultiExperimentDataModule with pre-built parquet.""" + from dynaclr.data.datamodule import MultiExperimentDataModule + + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + z_window=Z_RANGE[1] - Z_RANGE[0], # 16 + yx_patch_size=(192, 192), + final_yx_patch_size=(160, 160), + channels_per_sample=None, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=["perturbation"], + split_ratio=0.8, + batch_size=BATCH_SIZES[-1], + num_workers=1, + seed=42, + 
cache_pool_bytes=CACHE_POOL_BYTES, + normalizations=[], + augmentations=[], + ) + dm.setup("fit") + return dm + + +# ====================================================================== +# Benchmark helpers +# ====================================================================== + + +def benchmark_getitems( + dataset: torch.utils.data.Dataset, + batch_size: int, + n_batches: int = N_BATCHES, + warmup: int = WARMUP_BATCHES, +) -> dict: + """Time raw __getitems__ calls. + + Parameters + ---------- + dataset : Dataset + Must implement __getitems__(indices). + batch_size : int + Number of indices per call. + n_batches : int + Total batches to time (excluding warmup). + warmup : int + Batches to discard for cache warmup. + + Returns + ------- + dict + Timing statistics. + """ + n_samples = len(dataset) + rng = np.random.default_rng(42) + total = warmup + n_batches + + times = [] + for i in range(total): + indices = rng.integers(0, n_samples, size=batch_size).tolist() + t0 = time.perf_counter() + _ = dataset.__getitems__(indices) + t1 = time.perf_counter() + if i >= warmup: + times.append(t1 - t0) + + times_arr = np.array(times) + return { + "batch_size": batch_size, + "mean_ms": times_arr.mean() * 1000, + "std_ms": times_arr.std() * 1000, + "median_ms": np.median(times_arr) * 1000, + "p95_ms": np.percentile(times_arr, 95) * 1000, + "throughput_samples_per_sec": batch_size / times_arr.mean(), + } + + +def benchmark_dataloader( + dataloader, + n_batches: int = N_BATCHES, + warmup: int = WARMUP_BATCHES, +) -> dict: + """Time full dataloader iteration. + + Parameters + ---------- + dataloader : DataLoader + Configured dataloader. + n_batches : int + Batches to time after warmup. + warmup : int + Batches to discard. + + Returns + ------- + dict + Timing statistics. 
+ """ + timestamps = [] + total_samples = 0 + + for i, batch in enumerate(dataloader): + if i >= warmup + n_batches: + break + now = time.perf_counter() + if i >= warmup: + timestamps.append(now) + # Count samples in batch + if isinstance(batch, list): + # BatchedConcatDataModule returns list of micro-batches + for mb in batch: + if isinstance(mb, dict) and "anchor" in mb: + total_samples += mb["anchor"].shape[0] + elif isinstance(batch, dict) and "anchor" in batch: + total_samples += batch["anchor"].shape[0] + + if len(timestamps) < 2: + return {"note": "not enough batches"} + + inter_batch = np.diff(timestamps) + return { + "n_batches": len(inter_batch), + "total_samples": total_samples, + "mean_inter_batch_ms": inter_batch.mean() * 1000, + "std_inter_batch_ms": inter_batch.std() * 1000, + "median_inter_batch_ms": np.median(inter_batch) * 1000, + "throughput_samples_per_sec": total_samples / inter_batch.sum() if inter_batch.sum() > 0 else 0, + } + + +# ====================================================================== +# Main +# ====================================================================== + + +def main(): + """Profile and compare dataloader implementations.""" + results = [] + + print("=" * 70) + print("DATALOADER PROFILING") + print("BatchedConcatDataModule + TripletDatasets vs MultiExperimentDataModule") + print("=" * 70) + print("\nDatasets: G3BP1 (2025_07_24) + H2B (2025_04_15)") + print(f"Z range: {Z_RANGE} ({Z_RANGE[1] - Z_RANGE[0]} slices)") + print("Patch: 192x192 -> 160x160") + print(f"Cache: {CACHE_POOL_BYTES / 1e6:.0f} MB") + + # ------------------------------------------------------------------ + # Setup timing + # ------------------------------------------------------------------ + print("\n## Setup: Old (BatchedConcatDataModule + 2x TripletDataModule)") + t0 = time.perf_counter() + old_dm = setup_old() + old_setup_time = time.perf_counter() - t0 + n_old_train = len(old_dm.train_dataset) + n_old_val = len(old_dm.val_dataset) + print(f" 
Setup time: {_fmt(old_setup_time)}") + print(f" Train samples: {n_old_train} | Val samples: {n_old_val}") + + print("\n## Setup: New (MultiExperimentDataModule)") + t0 = time.perf_counter() + new_dm = setup_new() + new_setup_time = time.perf_counter() - t0 + n_new_train = len(new_dm.train_dataset) + n_new_val = len(new_dm.val_dataset) if new_dm.val_dataset else 0 + print(f" Setup time: {_fmt(new_setup_time)}") + print(f" Train samples: {n_new_train} | Val samples: {n_new_val}") + + # ------------------------------------------------------------------ + # Benchmark 1: Raw __getitems__ + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("BENCHMARK 1: Raw __getitems__ (no dataloader, no transforms)") + print("=" * 70) + + for bs in BATCH_SIZES: + print(f"\n### batch_size={bs}") + + stats_old = benchmark_getitems(old_dm.train_dataset, bs) + stats_old["dataset"] = "Old (BatchedConcatDataset)" + results.append(stats_old) + print( + f" Old: {stats_old['mean_ms']:.1f} ± {stats_old['std_ms']:.1f} ms/batch " + f"| p95={stats_old['p95_ms']:.1f} ms " + f"| {stats_old['throughput_samples_per_sec']:.0f} samples/s" + ) + + stats_new = benchmark_getitems(new_dm.train_dataset, bs) + stats_new["dataset"] = "New (MultiExperimentTripletDataset)" + results.append(stats_new) + print( + f" New: {stats_new['mean_ms']:.1f} ± {stats_new['std_ms']:.1f} ms/batch " + f"| p95={stats_new['p95_ms']:.1f} ms " + f"| {stats_new['throughput_samples_per_sec']:.0f} samples/s" + ) + + speedup = stats_old["mean_ms"] / stats_new["mean_ms"] if stats_new["mean_ms"] > 0 else float("inf") + direction = "faster" if speedup > 1 else "slower" + print(f" New is {abs(speedup):.2f}x {direction}") + + # ------------------------------------------------------------------ + # Benchmark 2: Full dataloader + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("BENCHMARK 2: Full ThreadDataLoader iteration") + print("=" * 
70) + + for bs in [32, 64]: + print(f"\n### batch_size={bs}") + + # Old + old_dm.batch_size = bs + for sub_dm in old_dm.data_modules: + sub_dm.batch_size = bs + old_dl = old_dm.train_dataloader() + dl_old = benchmark_dataloader(old_dl) + print( + f" Old: {dl_old.get('mean_inter_batch_ms', 0):.1f} ± " + f"{dl_old.get('std_inter_batch_ms', 0):.1f} ms/batch " + f"| {dl_old.get('throughput_samples_per_sec', 0):.0f} samples/s" + ) + + # New + new_dm.batch_size = bs + new_dl = new_dm.train_dataloader() + dl_new = benchmark_dataloader(new_dl) + print( + f" New: {dl_new.get('mean_inter_batch_ms', 0):.1f} ± " + f"{dl_new.get('std_inter_batch_ms', 0):.1f} ms/batch " + f"| {dl_new.get('throughput_samples_per_sec', 0):.0f} samples/s" + ) + + # ------------------------------------------------------------------ + # Summary + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + + print("\n### __getitems__ throughput (samples/sec)") + summary = pd.DataFrame(results) + pivot = summary.pivot_table( + index="batch_size", + columns="dataset", + values="throughput_samples_per_sec", + ) + print(pivot.to_string(float_format=lambda x: f"{x:.0f}")) + + print("\n### Setup times") + print("| Pipeline | Setup Time |") + print("|----------|-----------|") + print(f"| Old (BatchedConcatDataModule) | {_fmt(old_setup_time)} |") + print(f"| New (MultiExperimentDataModule) | {_fmt(new_setup_time)} |") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/profile_num_workers.py b/applications/dynaclr/scripts/profiling/profile_num_workers.py new file mode 100644 index 000000000..e57279021 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/profile_num_workers.py @@ -0,0 +1,175 @@ +"""Sweep num_workers to find optimal dataloader parallelism. 
+
+Holds all other parameters constant and measures end-to-end ThreadDataLoader
+throughput (samples/sec and inter-batch latency) for num_workers in [1, 2, 4, 8].
+
+Unlike profile_stages.py (which isolates individual pipeline stages) or
+profile_dataloaders.py (which compares two dataloader implementations), this
+script answers: does adding more CPU workers reduce GPU starvation?
+
+Usage
+-----
+    uv run python applications/dynaclr/scripts/profiling/profile_num_workers.py
+"""
+
+from __future__ import annotations
+
+import time
+
+import numpy as np
+
+from dynaclr.data.datamodule import MultiExperimentDataModule
+
+# ---------------------------------------------------------------------------
+# Config
+# ---------------------------------------------------------------------------
+
+CELL_INDEX_PARQUET = "applications/dynaclr/configs/cell_index/benchmark_2exp.parquet"
+
+BATCH_SIZE = 128
+N_BATCHES = 30
+WARMUP = 5
+CACHE_POOL_BYTES = 500_000_000  # 500 MB
+
+Z_WINDOW = 16
+Z_EXTRACTION_WINDOW = 45
+YX_PATCH = (192, 192)
+FINAL_YX_PATCH = (160, 160)
+
+NUM_WORKERS_SWEEP = [1, 2, 4, 8]
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def setup_dm(num_workers: int) -> MultiExperimentDataModule:
+    """Build a MultiExperimentDataModule with the given num_workers."""
+    dm = MultiExperimentDataModule(
+        cell_index_path=CELL_INDEX_PARQUET,
+        z_window=Z_WINDOW,
+        z_extraction_window=Z_EXTRACTION_WINDOW,
+        z_focus_offset=0.3,
+        yx_patch_size=YX_PATCH,
+        final_yx_patch_size=FINAL_YX_PATCH,
+        channels_per_sample=1,
+        positive_cell_source="lookup",
+        positive_match_columns=["lineage_id"],
+        tau_range=(0.5, 2.0),
+        tau_decay_rate=2.0,
+        stratify_by=["perturbation"],
+        split_ratio=0.8,
+        batch_size=BATCH_SIZE,
+        num_workers=num_workers,
+        seed=42,
+        cache_pool_bytes=CACHE_POOL_BYTES,
+        normalizations=[],
+        augmentations=[],
+    )
+    
dm.setup("fit") + return dm + + +def benchmark_dataloader(dataloader, n_batches: int = N_BATCHES, warmup: int = WARMUP) -> dict: + """Measure inter-batch latency and throughput over the dataloader. + + Parameters + ---------- + dataloader : ThreadDataLoader + Configured training dataloader. + n_batches : int + Number of batches to time after warmup. + warmup : int + Batches to discard for cache/thread warmup. + + Returns + ------- + dict + Inter-batch timing stats, throughput in samples/sec, and VAST bandwidth in MB/s. + """ + timestamps = [] + total_samples = 0 + read_mb_per_batch = None + + for i, batch in enumerate(dataloader): + if i >= warmup + n_batches: + break + now = time.perf_counter() + if i >= warmup: + timestamps.append(now) + if isinstance(batch, dict) and "anchor" in batch: + total_samples += batch["anchor"].shape[0] + if read_mb_per_batch is None: + # anchor + positive (fit mode). Lower bound — ignores chunk alignment overhead. + n_tensors = 2 if "positive" in batch else 1 + read_mb_per_batch = batch["anchor"].nelement() * batch["anchor"].element_size() * n_tensors / 1e6 + + if len(timestamps) < 2: + return {"note": "not enough batches"} + + inter_batch = np.diff(timestamps) + mean_s = inter_batch.mean() + bandwidth_mb_s = read_mb_per_batch / mean_s if read_mb_per_batch else 0.0 + return { + "mean_ms": mean_s * 1000, + "std_ms": inter_batch.std() * 1000, + "median_ms": float(np.median(inter_batch) * 1000), + "p95_ms": float(np.percentile(inter_batch, 95) * 1000), + "throughput_samples_per_sec": total_samples / inter_batch.sum(), + "read_mb_per_batch": read_mb_per_batch or 0.0, + "bandwidth_mb_s": bandwidth_mb_s, + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + """Sweep num_workers and report throughput.""" + print("=" * 60) + print("num_workers SWEEP — ThreadDataLoader throughput") + print("=" * 60) + 
print(f"batch_size={BATCH_SIZE}, z={Z_EXTRACTION_WINDOW}→{Z_WINDOW}") + print(f"patch={YX_PATCH}→{FINAL_YX_PATCH}, channels_per_sample=1") + print(f"warmup={WARMUP} batches, measured over {N_BATCHES} batches") + print() + + # Setup is shared across runs — only the dataloader changes. + # Re-setup for each num_workers since ThreadDataLoader is created in train_dataloader(). + results = [] + for nw in NUM_WORKERS_SWEEP: + print(f"## num_workers={nw}") + dm = setup_dm(nw) + dl = dm.train_dataloader() + stats = benchmark_dataloader(dl) + stats["num_workers"] = nw + results.append(stats) + print( + f" {stats['mean_ms']:.1f} ± {stats['std_ms']:.1f} ms/batch" + f" | p95={stats['p95_ms']:.1f} ms" + f" | {stats['throughput_samples_per_sec']:.0f} samples/sec" + f" | {stats['bandwidth_mb_s']:.0f} MB/s" + ) + print() + + print("=" * 60) + print("SUMMARY") + print("=" * 60) + print() + read_mb = results[0]["read_mb_per_batch"] if results else 0.0 + print(f"Read volume per batch (lower bound): {read_mb:.0f} MB") + print() + print("| num_workers | mean ms/batch | p95 ms | samples/sec | MB/s (VAST) |") + print("|-------------|---------------|--------|-------------|-------------|") + for r in results: + print( + f"| {r['num_workers']:11d} | {r['mean_ms']:13.1f} | {r['p95_ms']:6.1f}" + f" | {r['throughput_samples_per_sec']:11.0f} | {r['bandwidth_mb_s']:11.0f} |" + ) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/profile_predict_batch_size.py b/applications/dynaclr/scripts/profiling/profile_predict_batch_size.py new file mode 100644 index 000000000..a5b820164 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/profile_predict_batch_size.py @@ -0,0 +1,219 @@ +"""Sweep batch_size for prediction to find GPU utilization sweet spot. + +Times the full predict pipeline (dataloader I/O + GPU forward) at increasing +batch sizes to find where GPU utilization saturates on the local A40. 
+ +Uses the microglia-eval parquet and the 2D MIP checkpoint. + +Usage +----- + uv run python applications/dynaclr/scripts/dataloader_inspection/profile_predict_batch_size.py +""" + +from __future__ import annotations + +import time + +import numpy as np +import torch + +from dynaclr.data.datamodule import MultiExperimentDataModule +from viscy_data._utils import _transform_channel_wise +from viscy_models.contrastive import ContrastiveEncoder +from viscy_transforms import BatchedChannelWiseZReductiond, NormalizeSampled + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +CELL_INDEX_PARQUET = "/hpc/projects/organelle_phenotyping/models/collections/microglia-eval.parquet" +CKPT_PATH = ( + "/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels" + "/2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11" + "/DynaCLR-2D-MIP-BagOfChannels/20260403-150013/checkpoints/last.ckpt" +) + +BATCH_SIZES = [256, 512, 1024, 2048, 4096] +N_BATCHES = 20 +WARMUP = 3 +NUM_WORKERS = 4 +DEVICE = "cuda" + + +# --------------------------------------------------------------------------- +# Setup +# --------------------------------------------------------------------------- + + +def setup_dm(batch_size: int) -> MultiExperimentDataModule: + """Build a predict-mode MultiExperimentDataModule for the given batch size.""" + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + focus_channel="Phase3D", + reference_pixel_size_xy_um=0.1494, + z_window=1, + z_extraction_window=11, + z_focus_offset=0.5, + yx_patch_size=(192, 192), + final_yx_patch_size=(160, 160), + channels_per_sample=1, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + split_ratio=1.0, + batch_size=batch_size, + num_workers=NUM_WORKERS, + pin_memory=True, + seed=42, + normalizations=[ + NormalizeSampled( + 
keys=["channel_0"], + level="timepoint_statistics", + subtrahend="mean", + divisor="std", + ), + BatchedChannelWiseZReductiond(keys=["channel_0"], allow_missing_keys=True), + ], + augmentations=[], + ) + dm.setup("predict") + return dm + + +def load_model() -> torch.nn.Module: + """Load ConvNeXt-Tiny encoder from the benchmark checkpoint.""" + encoder = ContrastiveEncoder( + backbone="convnext_tiny", + in_channels=1, + in_stack_depth=1, + stem_kernel_size=[1, 4, 4], + stem_stride=[1, 4, 4], + embedding_dim=768, + projection_dim=32, + drop_path_rate=0.0, + ) + ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=True) + # checkpoint keys are prefixed with "model." since ContrastiveModule stores encoder as self.model + state = {k.removeprefix("model."): v for k, v in ckpt["state_dict"].items() if k.startswith("model.")} + encoder.load_state_dict(state) + encoder.eval() + encoder.to(DEVICE) + return encoder + + +# --------------------------------------------------------------------------- +# Benchmark +# --------------------------------------------------------------------------- + + +def benchmark(batch_size: int, model: torch.nn.Module) -> dict: + """Time the predict pipeline (I/O + forward) over N_BATCHES after warmup.""" + dm = setup_dm(batch_size) + dl = dm.predict_dataloader() + + forward_times = [] + samples_processed = 0 + t_start = None + + with torch.inference_mode(): + for i, batch in enumerate(dl): + if i >= WARMUP + N_BATCHES: + break + + # Mirror the predict path: apply _predict_transform then forward + norm_meta = batch.get("anchor_norm_meta") + if isinstance(norm_meta, list) and all(m is None for m in norm_meta): + norm_meta = None + anchor = _transform_channel_wise( + transform=dm._predict_transform, + channel_names=dm._channel_names, + patch=batch["anchor"].to(DEVICE), + norm_meta=norm_meta, + ) + + if i == WARMUP: + torch.cuda.synchronize() + t_start = time.perf_counter() + + torch.cuda.synchronize() + t0 = time.perf_counter() + _ = 
model(anchor) + torch.cuda.synchronize() + t1 = time.perf_counter() + + if i >= WARMUP: + forward_times.append(t1 - t0) + samples_processed += anchor.shape[0] + + torch.cuda.synchronize() + t_end = time.perf_counter() + + wall_s = t_end - t_start if t_start else 1.0 + fwd = np.array(forward_times) * 1000 + + return { + "batch_size": batch_size, + "forward_mean_ms": fwd.mean(), + "forward_std_ms": fwd.std(), + "e2e_samples_per_sec": samples_processed / wall_s, + "gpu_mem_mib": torch.cuda.max_memory_allocated() // (1024**2), + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Sweep batch sizes and print a throughput summary table.""" + if not torch.cuda.is_available(): + print("No GPU available.") + return + + gpu_name = torch.cuda.get_device_name(0) + total_mib = torch.cuda.get_device_properties(0).total_memory // (1024**2) + print("=" * 65) + print(f"Predict batch_size sweep — {gpu_name} ({total_mib} MiB)") + print("=" * 65) + print(f"num_workers={NUM_WORKERS}, warmup={WARMUP}, measured={N_BATCHES} batches") + print("model: ConvNeXt-Tiny 2D MIP, input 1×1×160×160") + print() + + print("Loading model...") + model = load_model() + torch.cuda.reset_peak_memory_stats() + + results = [] + for bs in BATCH_SIZES: + print(f"batch_size={bs} ...", end=" ", flush=True) + try: + torch.cuda.reset_peak_memory_stats() + r = benchmark(bs, model) + results.append(r) + print( + f"{r['forward_mean_ms']:.1f} ms fwd | " + f"{r['e2e_samples_per_sec']:.0f} samples/sec | " + f"{r['gpu_mem_mib']} MiB" + ) + except torch.cuda.OutOfMemoryError: + print("OOM") + break + + print() + print("=" * 65) + print("SUMMARY") + print("=" * 65) + print() + print("| batch_size | fwd ms | samples/sec | GPU MiB |") + print("|------------|--------|-------------|---------|") + for r in results: + print( + f"| {r['batch_size']:10d} | 
{r['forward_mean_ms']:6.1f} | " + f"{r['e2e_samples_per_sec']:11.0f} | {r['gpu_mem_mib']:7d} |" + ) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/profiling/profile_stages.py b/applications/dynaclr/scripts/profiling/profile_stages.py new file mode 100644 index 000000000..e00adc294 --- /dev/null +++ b/applications/dynaclr/scripts/profiling/profile_stages.py @@ -0,0 +1,326 @@ +"""Profile per-stage breakdown: I/O vs normalization vs augmentation vs crop. + +Isolates each stage of the training batch pipeline to find the bottleneck: +1. I/O: __getitems__ (tensorstore read + positive sampling) +2. CPU→GPU: .to(device) transfer +3. Normalization: NormalizeSampled (fov/timepoint stats) +4. Augmentation: affine + flip + contrast + scale + smooth + noise +5. Final crop: BatchedRandSpatialCropd (z_extraction → z_window) + +Uses the new MultiExperimentDataModule with the benchmark_2exp collection. +Requires GPU. + +Usage +----- + uv run python applications/dynaclr/scripts/dataloader_inspection/profile_stages.py +""" + +from __future__ import annotations + +import time + +import numpy as np +import torch +from monai.transforms import Compose + +from dynaclr.data.datamodule import MultiExperimentDataModule +from viscy_transforms import ( + BatchedRandAdjustContrastd, + BatchedRandAffined, + BatchedRandFlipd, + BatchedRandGaussianNoised, + BatchedRandGaussianSmoothd, + BatchedRandScaleIntensityd, + BatchedRandSpatialCropd, + NormalizeSampled, +) + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +COLLECTION_YAML = "applications/dynaclr/configs/collections/benchmark_2exp.yml" +CELL_INDEX_PARQUET = "applications/dynaclr/configs/cell_index/benchmark_2exp.parquet" + +BATCH_SIZE = 256 +N_BATCHES = 15 +WARMUP = 3 +CACHE_POOL_BYTES = 500_000_000 + +Z_WINDOW = 1 +Z_EXTRACTION_WINDOW = 11 +YX_PATCH = (192, 192) +FINAL_YX_PATCH = (160, 
160) + +CHANNEL_KEY = "channel_0" +DEVICE = "cuda" + + +def _fmt(seconds: float) -> str: + if seconds < 1: + return f"{seconds * 1000:.1f} ms" + return f"{seconds:.2f} s" + + +def setup(): + """Set up MultiExperimentDataModule with production-like config.""" + dm = MultiExperimentDataModule( + cell_index_path=CELL_INDEX_PARQUET, + z_window=Z_WINDOW, + z_extraction_window=Z_EXTRACTION_WINDOW, + z_focus_offset=0.3, + yx_patch_size=YX_PATCH, + final_yx_patch_size=FINAL_YX_PATCH, + channels_per_sample=None, + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + tau_range=(0.5, 2.0), + tau_decay_rate=2.0, + stratify_by=["perturbation"], + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=1, + seed=42, + cache_pool_bytes=CACHE_POOL_BYTES, + normalizations=[], + augmentations=[], + ) + dm.setup("fit") + return dm + + +def build_transforms(): + """Build the individual transform stages matching DynaCLR-3D-BagOfChannels-v2.""" + normalization = NormalizeSampled( + keys=[CHANNEL_KEY], + level="fov_statistics", + subtrahend="mean", + divisor="std", + ) + + augmentations = [ + BatchedRandAffined( + keys=[CHANNEL_KEY], + prob=0.8, + scale_range=[[0.9, 1.1], [0.9, 1.1], [0.9, 1.1]], + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.05, 0.05, 0.0, 0.05, 0.0, 0.05], + ), + BatchedRandFlipd( + keys=[CHANNEL_KEY], + spatial_axes=[1, 2], + prob=0.5, + ), + BatchedRandAdjustContrastd( + keys=[CHANNEL_KEY], + prob=0.5, + gamma=(0.6, 1.6), + ), + BatchedRandScaleIntensityd( + keys=[CHANNEL_KEY], + prob=0.5, + factors=0.5, + ), + BatchedRandGaussianSmoothd( + keys=[CHANNEL_KEY], + prob=0.5, + sigma_x=[0.25, 0.50], + sigma_y=[0.25, 0.50], + sigma_z=[0.0, 0.2], + ), + BatchedRandGaussianNoised( + keys=[CHANNEL_KEY], + prob=0.5, + mean=0.0, + std=0.1, + ), + ] + + final_crop = BatchedRandSpatialCropd( + keys=[CHANNEL_KEY], + roi_size=(Z_WINDOW, FINAL_YX_PATCH[0], FINAL_YX_PATCH[1]), + ) + + return normalization, augmentations, final_crop + + +def time_stage(fn, 
n_batches=N_BATCHES, warmup=WARMUP): + """Time a callable over multiple iterations, return stats. + + Parameters + ---------- + fn : callable + Function to time. Called with no arguments. + n_batches : int + Iterations to time after warmup. + warmup : int + Iterations to discard. + + Returns + ------- + dict + mean_ms, std_ms, median_ms. + """ + times = [] + for i in range(warmup + n_batches): + if DEVICE == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + result = fn() + if DEVICE == "cuda": + torch.cuda.synchronize() + t1 = time.perf_counter() + if i >= warmup: + times.append(t1 - t0) + arr = np.array(times) + return { + "mean_ms": arr.mean() * 1000, + "std_ms": arr.std() * 1000, + "median_ms": np.median(arr) * 1000, + }, result + + +def main(): + """Profile individual dataloader pipeline stages.""" + print("=" * 70) + print("STAGE BREAKDOWN: I/O → Transfer → Normalize → Augment → Crop") + print("=" * 70) + print(f"batch_size={BATCH_SIZE}, z_extraction={Z_EXTRACTION_WINDOW}→z_window={Z_WINDOW}") + print(f"patch={YX_PATCH}→{FINAL_YX_PATCH}, device={DEVICE}") + print() + + # Setup + dm = setup() + dataset = dm.train_dataset + normalization, augmentations, final_crop = build_transforms() + rng = np.random.default_rng(42) + n_samples = len(dataset) + + def random_indices(): + return rng.integers(0, n_samples, size=BATCH_SIZE).tolist() + + # Pre-generate index lists so index generation doesn't pollute timing + all_indices = [random_indices() for _ in range(WARMUP + N_BATCHES + 5)] + idx_iter = iter(all_indices) + + # ── Stage 1: I/O (__getitems__) ── + print("## Stage 1: I/O (__getitems__)") + batches = [] + + def io_step(): + indices = next(idx_iter) + batch = dataset.__getitems__(indices) + batches.append(batch) + return batch + + io_stats, _ = time_stage(io_step) + + # Use the last batch for subsequent stages + sample_batch = batches[-1] + anchor = sample_batch["anchor"] + positive = sample_batch.get("positive") + + # Read volume: what was actually 
fetched from VAST (z_extraction_window, not z_window). + # anchor + positive (fit mode reads both). Lower bound — chunk alignment may add overhead. + n_tensors = 2 if positive is not None else 1 + read_bytes = anchor.nelement() * anchor.element_size() * n_tensors + read_mb = read_bytes / 1e6 + bandwidth_mb_s = read_mb / (io_stats["mean_ms"] / 1000) + io_stats["read_mb"] = read_mb + io_stats["bandwidth_mb_s"] = bandwidth_mb_s + + print(f" {io_stats['mean_ms']:.1f} ± {io_stats['std_ms']:.1f} ms") + pos_label = "+ positive" if positive is not None else "" + print(f" read volume: {read_mb:.0f} MB (anchor{pos_label}) | bandwidth: {bandwidth_mb_s:.0f} MB/s") + print(f" anchor shape: {anchor.shape}, dtype: {anchor.dtype}") + + # ── Stage 2: CPU→GPU transfer ── + print("\n## Stage 2: CPU → GPU transfer") + + def transfer_step(): + return anchor.to(DEVICE, non_blocking=True) + + transfer_stats, gpu_anchor = time_stage(transfer_step) + print(f" {transfer_stats['mean_ms']:.1f} ± {transfer_stats['std_ms']:.1f} ms") + print(f" tensor size: {anchor.nelement() * anchor.element_size() / 1e6:.1f} MB") + + # ── Stage 3: Normalization ── + print("\n## Stage 3: Normalization (subtract mean, divide std — manual)") + # NormalizeSampled via _transform_channel_wise requires channel-name + # alignment that depends on the full DataModule context. Time the raw + # arithmetic instead: this is what NormalizeSampled does per channel. 
+ + def norm_step(): + x = gpu_anchor.clone() + mean = x.mean(dim=(-3, -2, -1), keepdim=True) + std = x.std(dim=(-3, -2, -1), keepdim=True) + return (x - mean) / (std + 1e-8) + + norm_stats, normed = time_stage(norm_step) + print(f" {norm_stats['mean_ms']:.1f} ± {norm_stats['std_ms']:.1f} ms") + + # ── Stage 4: Augmentations (individually) ── + print("\n## Stage 4: Augmentations (individual)") + aug_names = [ + "RandAffined", + "RandFlipd", + "RandAdjustContrastd", + "RandScaleIntensityd", + "RandGaussianSmoothd", + "RandGaussianNoised", + ] + aug_total = 0.0 + current_input = normed + + for aug_name, aug_transform in zip(aug_names, augmentations): + t = Compose([aug_transform]) + inp = current_input + + def aug_step(transform=t, data=inp): + d = {CHANNEL_KEY: data.clone()} + return transform(d)[CHANNEL_KEY] + + stats, current_input = time_stage(aug_step) + aug_total += stats["mean_ms"] + print(f" {aug_name:30s} {stats['mean_ms']:8.1f} ± {stats['std_ms']:.1f} ms") + + print(f" {'TOTAL':30s} {aug_total:8.1f} ms") + + # ── Stage 5: Final crop ── + print("\n## Stage 5: Final crop (BatchedRandSpatialCropd)") + crop_input = current_input + + def crop_step(): + d = {CHANNEL_KEY: crop_input.clone()} + return final_crop(d)[CHANNEL_KEY] + + crop_stats, _ = time_stage(crop_step) + print(f" {crop_stats['mean_ms']:.1f} ± {crop_stats['std_ms']:.1f} ms") + + # ── Summary ── + print("\n" + "=" * 70) + print("SUMMARY (mean ms per batch)") + print("=" * 70) + + stages = { + "I/O (__getitems__)": io_stats["mean_ms"], + "CPU→GPU transfer": transfer_stats["mean_ms"], + "Normalization": norm_stats["mean_ms"], + "Augmentations (total)": aug_total, + "Final crop": crop_stats["mean_ms"], + } + total = sum(stages.values()) + + print("\n| Stage | Time (ms) | % of total | Bandwidth |") + print("|-------|-----------|------------|-----------|") + for name, ms in stages.items(): + if name == "I/O (__getitems__)": + bw = f"{io_stats['bandwidth_mb_s']:.0f} MB/s ({io_stats['read_mb']:.0f} MB 
read)" + else: + bw = "—" + print(f"| {name} | {ms:.1f} | {ms / total * 100:.1f}% | {bw} |") + print(f"| **Total** | **{total:.1f}** | **100%** | |") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/0-select_candidates/lineage_utils.py b/applications/dynaclr/scripts/pseudotime/0-select_candidates/lineage_utils.py new file mode 100644 index 000000000..fd4dcb0eb --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/0-select_candidates/lineage_utils.py @@ -0,0 +1,245 @@ +"""Lineage reconnection and cohort tagging for Stage 0. + +Reuses the library function ``identify_lineages`` from +``dynaclr.pseudotime.alignment`` and adds the per-lineage classifiers +needed by the new DAG: ``divides`` (none/pre/during/post relative to +``t_zero``) and ``cohort`` (productive/bystander/abortive/mock). + +The cohort logic is operational and LC-derived per discussion §3.2: +- productive — lineage has manual ``t_key_event`` +- bystander — lineage in infected well, LC says uninfected for ≥80% of frames +- abortive — lineage in infected well, LC shows brief positive run < ``min_run`` then no sustained rise +- mock — lineage from uninfected wells (well_pattern in candidates.yaml) +""" + +from __future__ import annotations + +import logging +from typing import Literal + +import pandas as pd + +from dynaclr.pseudotime.alignment import identify_lineages + +_logger = logging.getLogger(__name__) + + +ORPHAN_LINEAGE_ID = "" + + +def _format_lineage_id(dataset_id: str, fov_name: str, root_tid: int, leaf_tid: int) -> str: + """Build a content-stable lineage id from its branch endpoints. + + The id is unique across the whole pipeline as long as + ``(dataset_id, fov_name, root_track_id, leaf_track_id)`` is unique. + Stable across re-runs regardless of cohort composition or row order. 
+ """ + return f"{dataset_id}|{fov_name}|root{int(root_tid)}|leaf{int(leaf_tid)}" + + +def assign_lineage_ids( + df: pd.DataFrame, + return_both_branches: bool = True, +) -> pd.DataFrame: + """Add a content-stable ``lineage_id`` column. + + Calls ``identify_lineages()`` per ``dataset_id`` group; for each branch + ``[root_track_id, ..., leaf_track_id]`` builds the deterministic id + ``"{dataset_id}|{fov_name}|root{root_tid}|leaf{leaf_tid}"`` and assigns + every track in the branch to it. When ``return_both_branches=True`` and + a mother track appears in multiple branches, the first branch wins (so + the mother shares an id with one branch only — same convention as + before). + + Tracks that do not match any lineage get ``lineage_id = ""`` + (the orphan sentinel). + + Parameters + ---------- + df : pd.DataFrame + Tracking dataframe with columns: dataset_id, fov_name, track_id, + parent_track_id. + return_both_branches : bool + If True (recommended), keep both daughters as separate lineages + sharing the mother. If False, only the first daughter survives. + + Returns + ------- + pd.DataFrame + Input with a string ``lineage_id`` column added. 
+ """ + required = {"dataset_id", "fov_name", "track_id", "parent_track_id"} + missing = required - set(df.columns) + if missing: + raise KeyError(f"assign_lineage_ids missing required columns: {sorted(missing)}") + + if df.empty: + out = df.copy() + out["lineage_id"] = pd.Series(dtype=str) + return out + + track_to_lineage: dict[tuple[str, str, int], str] = {} + + for dataset_id, ds_group in df.groupby("dataset_id", sort=False): + lineages = identify_lineages(ds_group, return_both_branches=return_both_branches) + for fov_name, track_ids in lineages: + if not track_ids: + continue + root_tid = int(track_ids[0]) + leaf_tid = int(track_ids[-1]) + lineage_id = _format_lineage_id(str(dataset_id), str(fov_name), root_tid, leaf_tid) + for tid in track_ids: + key = (str(dataset_id), str(fov_name), int(tid)) + # When return_both_branches=True a mother can appear in multiple + # branches; the first branch wins. + if key not in track_to_lineage: + track_to_lineage[key] = lineage_id + + df = df.copy() + df["lineage_id"] = df.apply( + lambda row: track_to_lineage.get( + (str(row["dataset_id"]), str(row["fov_name"]), int(row["track_id"])), + ORPHAN_LINEAGE_ID, + ), + axis=1, + ) + + n_orphan_tracks = df[df["lineage_id"] == ORPHAN_LINEAGE_ID].groupby(["dataset_id", "fov_name", "track_id"]).ngroups + n_total_tracks = df.groupby(["dataset_id", "fov_name", "track_id"]).ngroups + if n_orphan_tracks: + _logger.info(f"{n_orphan_tracks}/{n_total_tracks} tracks did not match any lineage") + + return df + + +DivideRegime = Literal["none", "pre", "during", "post"] + + +def classify_divides( + df: pd.DataFrame, + t_zero: int, + k_pre_frames: int, + k_post_frames: int, +) -> DivideRegime: + """Classify a lineage's division timing relative to ``t_zero``. + + A division is detected when the lineage contains more than one + track and the daughter's earliest frame falls inside the lineage. + + Parameters + ---------- + df : pd.DataFrame + Rows for a single lineage (one ``lineage_id``). 
+    t_zero : int
+        Anchor frame (manual ``t_key_event`` or LC first-positive).
+    k_pre_frames, k_post_frames : int
+        Transition sub-window half-widths in frames.
+
+    Returns
+    -------
+    str
+        One of ``"none"``, ``"pre"``, ``"during"``, ``"post"``.
+    """
+    track_ids = df["track_id"].unique()
+    if len(track_ids) < 2:
+        return "none"
+
+    # Division frame = earliest start among the non-root tracks, i.e. the
+    # first daughter to appear (iloc[1] of the ascending-sorted starts) —
+    # this anchors on the lineage's *first* division event.
+    track_starts = df.groupby("track_id")["t"].min().sort_values()
+    if len(track_starts) < 2:
+        return "none"
+
+    division_frame = int(track_starts.iloc[1])
+    transition_lo = t_zero - k_pre_frames
+    transition_hi = t_zero + k_post_frames
+
+    if division_frame < transition_lo:
+        return "pre"
+    if division_frame <= transition_hi:
+        return "during"
+    return "post"
+
+
+CohortLabel = Literal["productive", "bystander", "abortive", "unannotated_productive", "mock"]
+
+
+def tag_cohort_for_lineage(
+    df: pd.DataFrame,
+    well_is_uninfected: bool,
+    has_manual_t_zero: bool,
+    lc_predictions: pd.Series | None = None,
+    min_run: int = 3,
+    bystander_uninfected_fraction: float = 0.8,
+) -> CohortLabel:
+    """Assign a cohort label to a lineage based on operational rules.
+
+    Per discussion §3.2 and the locked plan:
+
+    - Productive: lineage has a manual ``t_key_event``.
+    - Bystander: in an infected well, LC says uninfected for at least
+      ``bystander_uninfected_fraction`` of frames AND has no sustained
+      positive run.
+    - Abortive: in an infected well, LC shows at least one positive frame
+      but no sustained run of ``min_run`` consecutive positives.
+    - Unannotated productive: in an infected well, LC shows a sustained
+      positive run but the lineage has no manual anchor. These are the
+      LC's "extra" calls that the manual annotation missed; downstream
+      stages may treat them as a separate cohort or exclude them.
+    - Mock: in an uninfected well.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        Lineage rows.
+ well_is_uninfected : bool + Whether the lineage's well is in the mock cohort. + has_manual_t_zero : bool + Whether the lineage has a manual anchor. + lc_predictions : pd.Series or None + LC ``predicted_infection_state`` per frame for this lineage + (sorted by ``t``). ``None`` falls back to ``bystander`` for any + non-mock, non-productive lineage. + min_run : int + Minimum consecutive positives to count as a sustained rise. + bystander_uninfected_fraction : float + Fraction of frames a bystander must spend negative (default 0.8). + """ + if well_is_uninfected: + return "mock" + if has_manual_t_zero: + return "productive" + + if lc_predictions is None: + return "bystander" + + pos = (lc_predictions == "infected").astype(int).to_numpy() + if pos.sum() == 0: + return "bystander" + + if _has_run(pos, min_run): + # LC found a sustained run but no manual anchor: an unannotated + # productive candidate. Tag distinctly so downstream stages can + # decide to include, exclude, or compare against the manual set. + return "unannotated_productive" + + fraction_negative = 1.0 - (pos.sum() / len(pos)) + if fraction_negative >= bystander_uninfected_fraction: + return "bystander" + + return "abortive" + + +def _has_run(arr, min_run: int) -> bool: + """Return True if ``arr`` contains at least ``min_run`` consecutive 1s.""" + if min_run <= 0: + return arr.sum() > 0 + run = 0 + for v in arr: + if v: + run += 1 + if run >= min_run: + return True + else: + run = 0 + return False diff --git a/applications/dynaclr/scripts/pseudotime/0-select_candidates/manual_candidates.py b/applications/dynaclr/scripts/pseudotime/0-select_candidates/manual_candidates.py new file mode 100644 index 000000000..d837e0e2f --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/0-select_candidates/manual_candidates.py @@ -0,0 +1,274 @@ +"""Write a manual productive-cohort CSV for debugging the pseudotime pipeline. 
+ +Hand-picked tracks organised as a nested dict keyed by +``(dataset_id, fov_name) → [track specs]``. Each track spec carries the +crop window ``[t_before, t_after]`` and a ``labels`` dict mapping each +label column to a list of ``[t_on, t_off]`` intervals (inclusive). + +``t_key_event`` (the DTW anchor frame) is derived from +:data:`ANCHOR_LABEL` — the first positive frame of that label's +interval list. For ``infection_state`` this is the first ``infected`` +frame. + +Output schema matches :mod:`select_candidates`. The manual path is the +``productive`` cohort only; bystander/abortive/mock cohorts come from +:mod:`select_candidates` (auto path). + +Run:: + + uv run python manual_candidates.py +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import anndata as ad +import pandas as pd + +SCRIPT_DIR = Path(__file__).resolve().parent + +sys.path.insert(0, str(SCRIPT_DIR)) +from lineage_utils import assign_lineage_ids, classify_divides # noqa: E402 + +CANDIDATE_SET = "manual_debug_zikv" +ANCHOR_LABEL = "infection_state" +ANCHOR_POSITIVE = "infected" + +# Output: same schema as select_candidates.py (productive cohort only). +ANNOTATIONS_OUTPUT = SCRIPT_DIR / "candidates" / f"{CANDIDATE_SET}_productive.csv" + +# Label columns and their (positive, negative) values. +LABEL_VALUES: dict[str, tuple[str, str]] = { + "infection_state": ("infected", "uninfected"), + "organelle_state": ("remodeled", "noremodeled"), + "cell_division_state": ("mitosis", "interphase"), +} + +# Per-dataset metadata. ``embedding_zarr`` validates each spec against +# the embedding ``.obs`` and supplies ``parent_track_id`` when not given +# in the spec. ``transition_window_*_minutes`` set the divides classifier +# half-widths (must match candidates.yaml lineage_rules for the +# corresponding candidate set). 
+DATASETS: dict[str, dict] = { + "2025_07_24_SEC61": { + "frame_interval_minutes": 30.0, + "embedding_zarr": ( + "/hpc/projects/organelle_phenotyping/models/DynaCLR-2D-MIP-BagOfChannels/" + "2d-mip-ntxent-t0p2-lr2e5-bs256-192to160-zext11/evaluation_lc_v1/embeddings/" + "2025_07_24_A549_viral_sensor_ZIKV.zarr" + ), + "transition_window_k_pre_minutes": 60.0, + "transition_window_k_post_minutes": 120.0, + }, +} + +# (dataset_id, fov_name) -> list of track specs. +# Each spec: track_id, t_before, t_after, labels. +# ``parent_track_id`` is optional; when missing, it's pulled from the +# embedding .obs during validation. +# ``labels`` maps label column -> list of [t_on, t_off] intervals (inclusive). +TRACKS: dict[tuple[str, str], list[dict]] = { + ("2025_07_24_SEC61", "A/2/000000"): [ + { + "track_id": 86, + "t_before": 14, + "t_after": 27, + "labels": {"infection_state": [[20, 27]]}, + }, + { + "track_id": 65, + "t_before": 14, + "t_after": 27, + "labels": {"infection_state": [[19, 27]]}, + }, + { + "track_id": 40, + "t_before": 14, + "t_after": 27, + "labels": {"infection_state": [[19, 27]]}, + }, + { + "track_id": 60, + "t_before": 14, + "t_after": 27, + "labels": {"infection_state": [[19, 27]]}, + }, + ], +} + + +def _t_key_event(spec: dict) -> int: + """First positive frame of ``ANCHOR_LABEL`` — the per-cell anchor.""" + labels = spec.get("labels", {}) + if ANCHOR_LABEL not in labels or not labels[ANCHOR_LABEL]: + raise ValueError(f"Track {spec.get('track_id')!r} is missing an anchor interval for '{ANCHOR_LABEL}'.") + return int(labels[ANCHOR_LABEL][0][0]) + + +def _validate_against_anndata(dataset_id: str, fov_name: str, tracks: list[dict]) -> dict[int, int]: + """Validate each track spec against the embedding ``.obs``. + + Returns a mapping ``track_id → parent_track_id`` from the embedding's + obs so manual specs need not repeat the parent ids. Raises + :class:`ValueError` listing every problem in one edit pass. 
+ """ + ds_meta = DATASETS[dataset_id] + emb_path = ds_meta["embedding_zarr"] + adata = ad.read_zarr(emb_path) + adata.obs_names_make_unique() + obs = adata.obs + fov_obs = obs[obs["fov_name"].astype(str) == fov_name] + + problems: list[str] = [] + parent_lookup: dict[int, int] = {} + + for spec in tracks: + tid = int(spec["track_id"]) + track_obs = fov_obs[fov_obs["track_id"].astype(int) == tid] + if len(track_obs) == 0: + problems.append(f" {dataset_id} {fov_name} track_id={tid}: not found in {emb_path}") + continue + track_tps = set(track_obs["t"].astype(int).tolist()) + + if "parent_track_id" in track_obs.columns: + parent_lookup[tid] = int(track_obs.iloc[0]["parent_track_id"]) + else: + parent_lookup[tid] = int(spec.get("parent_track_id", -1)) + + for col in ("t_before", "t_after"): + if int(spec[col]) not in track_tps: + problems.append( + f" {dataset_id} {fov_name} track_id={tid}: {col}={spec[col]} not in track " + f"(min={min(track_tps)}, max={max(track_tps)})" + ) + + labels = spec.get("labels", {}) + if ANCHOR_LABEL not in labels or not labels[ANCHOR_LABEL]: + problems.append(f" {dataset_id} {fov_name} track_id={tid}: missing anchor label '{ANCHOR_LABEL}'") + for label_col, intervals in labels.items(): + if label_col not in LABEL_VALUES: + problems.append( + f" {dataset_id} {fov_name} track_id={tid}: unknown label column '{label_col}' " + f"(known: {list(LABEL_VALUES)})" + ) + for interval in intervals: + if len(interval) != 2 or interval[0] > interval[1]: + problems.append( + f" {dataset_id} {fov_name} track_id={tid}: {label_col} interval {interval} " + f"is not a valid [t_on, t_off] pair" + ) + continue + if interval[0] not in track_tps or interval[1] not in track_tps: + problems.append( + f" {dataset_id} {fov_name} track_id={tid}: {label_col} interval {interval} " + f"has frame(s) outside the track (min={min(track_tps)}, max={max(track_tps)})" + ) + + if problems: + raise ValueError("Manual candidate validation failed:\n" + "\n".join(problems)) + + 
return parent_lookup + + +def build_annotation_rows() -> pd.DataFrame: + """Expand every track spec into per-timepoint rows in the new schema.""" + rows = [] + parent_by_dataset: dict[str, dict[int, int]] = {} + for (dataset_id, fov_name), tracks in TRACKS.items(): + parent_by_dataset[dataset_id] = _validate_against_anndata(dataset_id, fov_name, tracks) + + for (dataset_id, fov_name), tracks in TRACKS.items(): + parent_lookup = parent_by_dataset[dataset_id] + for spec in tracks: + tid = int(spec["track_id"]) + t_before = int(spec["t_before"]) + t_after = int(spec["t_after"]) + intervals = spec.get("labels", {}) + parent_tid = int(spec.get("parent_track_id", parent_lookup.get(tid, -1))) + + for t in range(t_before, t_after + 1): + row = { + "dataset_id": dataset_id, + "fov_name": fov_name, + "track_id": tid, + "parent_track_id": parent_tid, + "t": t, + } + for label_col, (pos, neg) in LABEL_VALUES.items(): + spans = intervals.get(label_col) + if spans is None: + row[label_col] = "" + elif any(lo <= t <= hi for lo, hi in spans): + row[label_col] = pos + else: + row[label_col] = neg + rows.append(row) + return pd.DataFrame(rows) + + +def main() -> None: + """Validate specs, reconnect lineages, classify divides, write CSV.""" + df = build_annotation_rows() + df = assign_lineage_ids(df, return_both_branches=True) + + # Per-cell t_zero from the manual anchor. + t_zero_lookup: dict[str, int] = {} + for (dataset_id, fov_name), tracks in TRACKS.items(): + for spec in tracks: + tid = int(spec["track_id"]) + mask = (df["dataset_id"] == dataset_id) & (df["fov_name"] == fov_name) & (df["track_id"] == tid) + if not mask.any(): + continue + lineage_id = str(df.loc[mask, "lineage_id"].iloc[0]) + if not lineage_id: + continue + # Set t_zero per lineage to the earliest manual anchor across + # tracks in that lineage. 
+ t_anchor = _t_key_event(spec) + t_zero_lookup[lineage_id] = min(t_zero_lookup.get(lineage_id, t_anchor), t_anchor) + + # Divides classification per lineage using its dataset's frame interval. + divides: dict[str, str] = {} + for lineage_id, lineage_df in df.groupby("lineage_id"): + if not lineage_id: + divides[lineage_id] = "none" + continue + t_zero = t_zero_lookup.get(str(lineage_id)) + if t_zero is None: + divides[lineage_id] = "none" + continue + ds_id = str(lineage_df["dataset_id"].iloc[0]) + ds_meta = DATASETS[ds_id] + frame_interval = float(ds_meta["frame_interval_minutes"]) + k_pre = int(round(float(ds_meta["transition_window_k_pre_minutes"]) / frame_interval)) + k_post = int(round(float(ds_meta["transition_window_k_post_minutes"]) / frame_interval)) + divides[lineage_id] = classify_divides(lineage_df, t_zero, k_pre, k_post) + + df["cohort"] = "productive" + df["divides"] = df["lineage_id"].map(divides).fillna("none") + + output_columns = [ + "dataset_id", + "fov_name", + "lineage_id", + "track_id", + "parent_track_id", + "t", + "cohort", + "divides", + "infection_state", + "organelle_state", + "cell_division_state", + ] + out = df[output_columns] + + ANNOTATIONS_OUTPUT.parent.mkdir(parents=True, exist_ok=True) + out.to_csv(ANNOTATIONS_OUTPUT, index=False) + print(f"Wrote {len(out)} rows to {ANNOTATIONS_OUTPUT}") + print(out.head(10).to_string(index=False)) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/0-select_candidates/reconnect_lineages.py b/applications/dynaclr/scripts/pseudotime/0-select_candidates/reconnect_lineages.py new file mode 100644 index 000000000..d1c41ccb9 --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/0-select_candidates/reconnect_lineages.py @@ -0,0 +1,83 @@ +"""Inspect lineage reconnection on a candidate-set's productive cohort. 
+ +Reads the productive cohort CSV (output of :mod:`select_candidates` or +:mod:`manual_candidates`), reconnects lineages via ``parent_track_id``, +and prints a summary of: + +- lineages with two or more tracks (mother + daughter chains) +- distribution of ``divides ∈ {none, pre, during, post}`` +- per-lineage track lists for visual inspection + +This is read-only — it does not rewrite the cohort CSV. It exists to +verify Phase 1 lineage reconnection is doing what we expect before +running the rest of the pipeline. + +Usage:: + + cd applications/dynaclr/scripts/pseudotime/0-select_candidates + uv run python reconnect_lineages.py \ + --candidate-set zikv_productive_07_24 +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import pandas as pd + +SCRIPT_DIR = Path(__file__).resolve().parent +CANDIDATES_DIR = SCRIPT_DIR / "candidates" + + +def main() -> None: + """Print a per-cohort lineage report (multi-track lineages + divides distribution).""" + parser = argparse.ArgumentParser(description="Inspect lineage reconnection on a candidate set.") + parser.add_argument("--candidate-set", required=True, help="Candidate-set name") + parser.add_argument( + "--cohort", + default="productive", + choices=["productive", "bystander", "abortive", "unannotated_productive", "mock"], + help="Cohort CSV to inspect (default: productive)", + ) + args = parser.parse_args() + + csv_path = CANDIDATES_DIR / f"{args.candidate_set}_{args.cohort}.csv" + if not csv_path.exists(): + raise FileNotFoundError(f"{csv_path} not found. Run select_candidates.py or manual_candidates.py first.") + + df = pd.read_csv(csv_path) + print(f"# Lineage report — {csv_path.name}\n") + print(f"- {len(df):,} rows") + print(f"- {df.groupby(['dataset_id', 'fov_name', 'track_id']).ngroups:,} unique tracks") + print(f"- {df['lineage_id'].nunique():,} unique lineages") + print() + + # Multi-track lineages — these are the mother+daughter chains. 
+ track_counts = df.groupby("lineage_id")["track_id"].nunique() + multi_track = track_counts[track_counts > 1].sort_values(ascending=False) + print(f"## Multi-track lineages ({len(multi_track):,}/{len(track_counts):,})\n") + if len(multi_track): + for lineage_id, n_tracks in multi_track.head(20).items(): + sub = df[df["lineage_id"] == lineage_id] + track_ids = sorted(sub["track_id"].unique()) + divides = sub["divides"].iloc[0] + ds = sub["dataset_id"].iloc[0] + fov = sub["fov_name"].iloc[0] + print(f"- lineage {lineage_id} | {ds} {fov} | tracks {track_ids} | divides={divides}") + if len(multi_track) > 20: + print(f"- ... ({len(multi_track) - 20} more)") + else: + print("(none — every lineage is a single track)") + print() + + # Divides distribution per cohort. + divides_by_lineage = df.drop_duplicates(["lineage_id"])[["lineage_id", "divides"]] + print("## Divides distribution\n") + counts = divides_by_lineage["divides"].value_counts() + for k in ("none", "pre", "during", "post"): + print(f"- {k}: {counts.get(k, 0)}") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/0-select_candidates/select_candidates.py b/applications/dynaclr/scripts/pseudotime/0-select_candidates/select_candidates.py new file mode 100644 index 000000000..11cc21ba6 --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/0-select_candidates/select_candidates.py @@ -0,0 +1,871 @@ +r"""Auto-select candidates and tag cohorts for Stage 0 of the new DAG. + +Reads per-dataset annotation CSVs and produces four cohort-tagged +annotation CSVs per candidate set: + +- ``{candidate_set}_productive.csv`` — transitioning lineages with a + manual ``t_key_event``-equivalent anchor (first ``infected`` frame). +- ``{candidate_set}_bystander.csv`` — lineages in infected wells whose + LC prediction stays mostly uninfected. +- ``{candidate_set}_abortive.csv`` — lineages in infected wells with a + brief positive run shorter than ``min_run`` and no sustained rise. 
+- ``{candidate_set}_mock.csv`` — lineages from uninfected control wells. + +Per discussion §3.2 and the execution plan: the cohort definitions are +operational and LC-derived. ``parent_track_id`` is preserved through the +pipeline so ``identify_lineages()`` can reconnect mother + daughter into +shared ``lineage_id``. + +Usage:: + + cd applications/dynaclr/scripts/pseudotime/0-select_candidates + uv run python select_candidates.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/candidates.yaml \ + --candidate-set zikv_productive_07_24 +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import anndata as ad +import pandas as pd + +SCRIPT_DIR = Path(__file__).resolve().parent +CANDIDATES_DIR = SCRIPT_DIR / "candidates" + +sys.path.insert(0, str(SCRIPT_DIR.parent)) +sys.path.insert(0, str(SCRIPT_DIR)) +from lineage_utils import ( # noqa: E402 + assign_lineage_ids, + classify_divides, + tag_cohort_for_lineage, +) +from utils import load_stage_config # noqa: E402 + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +_logger = logging.getLogger(__name__) + +LABEL_VALUES: dict[str, tuple[str, str]] = { + "infection_state": ("infected", "uninfected"), + "organelle_state": ("remodeled", "noremodeled"), + "cell_division_state": ("mitosis", "interphase"), +} + +# Output schema columns (per DAG §11). +OUTPUT_COLUMNS = [ + "dataset_id", + "fov_name", + "lineage_id", + "track_id", + "parent_track_id", + "t", + "cohort", + "divides", + "infection_state", + "organelle_state", + "cell_division_state", +] + + +def _select_productive_tracks( + ann_df: pd.DataFrame, + dataset_id: str, + fov_pattern: str, + filter_cfg: dict, + frame_interval_minutes: float, +) -> pd.DataFrame: + """Pick transitioning tracks from one dataset's annotations. + + Preserves ``parent_track_id`` (the previous version dropped it). 
+ + Parameters + ---------- + ann_df : pd.DataFrame + Raw annotations CSV. + dataset_id, fov_pattern : str + Dataset id + FOV substring filter. + filter_cfg : dict + Per-candidate-set ``productive_filter`` dict. + frame_interval_minutes : float + Dataset-specific frame interval for minute → frame conversions. + + Returns + ------- + pd.DataFrame + Per-frame rows with productive cohort tag, lineage_id assigned later. + """ + anchor_label = filter_cfg["anchor_label"] + anchor_positive = filter_cfg["anchor_positive"] + anchor_negative = filter_cfg.get("anchor_negative", "uninfected") + min_pre = float(filter_cfg.get("min_pre_minutes", 0)) + min_post = float(filter_cfg.get("min_post_minutes", 0)) + crop_window_minutes = filter_cfg["crop_window_minutes"] + + pre_frames = int(round(min_pre / frame_interval_minutes)) + post_frames = int(round(min_post / frame_interval_minutes)) + crop_half = int(round(float(crop_window_minutes) / frame_interval_minutes)) + + sub = ann_df[ann_df["fov_name"].astype(str).str.contains(fov_pattern, regex=False)].copy() + if sub.empty: + _logger.warning(f"[{dataset_id}] no rows match fov_pattern {fov_pattern!r}") + return pd.DataFrame(columns=OUTPUT_COLUMNS) + + rows: list[dict] = [] + for (fov, tid), g in sub.groupby(["fov_name", "track_id"]): + g = g.sort_values("t") + states = set(g[anchor_label].dropna()) + if anchor_positive not in states or anchor_negative not in states: + continue + + t_onset = int(g[g[anchor_label] == anchor_positive]["t"].min()) + t_min, t_max = int(g["t"].min()), int(g["t"].max()) + if (t_onset - t_min) < pre_frames or (t_max - t_onset) < post_frames: + continue + + t_before = max(t_min, t_onset - crop_half) + t_after = min(t_max, t_onset + crop_half) + + parent_id = int(g.iloc[0].get("parent_track_id", -1)) + + in_window = g[(g["t"] >= t_before) & (g["t"] <= t_after)] + for _, r in in_window.iterrows(): + row = { + "dataset_id": dataset_id, + "fov_name": str(fov), + "track_id": int(tid), + "parent_track_id": 
parent_id, + "t": int(r["t"]), + } + for label_col in LABEL_VALUES: + if label_col in r and pd.notna(r[label_col]) and r[label_col] != "": + row[label_col] = r[label_col] + else: + row[label_col] = "" + rows.append(row) + + return pd.DataFrame(rows) + + +def _select_productive_from_zarr( + dataset_cfg: dict, + dataset_id: str, + fov_pattern: str, + filter_cfg: dict, + frame_interval_minutes: float, + embedding_pattern: str, + pred_column: str = "predicted_infection_state", + min_run: int = 3, +) -> pd.DataFrame: + """Pick productive tracks from the embedding zarr's LC predictions. + + Used for datasets without a track-linked manual annotation CSV (e.g. + 08_26 SEC61). Anchor is the first frame of a sustained run of + ``min_run`` consecutive ``predicted_infection_state == anchor_positive`` + predictions. The pre/post window requirements match + ``_select_productive_tracks``. + + Per-frame rows for each surviving track are emitted with the manual + annotation columns left blank (only the LC-derived anchor is real). + Downstream A-LC alignment will recompute the LC anchor; A-anno will + have no anchor for these cells (NaN ``t_zero``). 
+ """ + anchor_positive = filter_cfg["anchor_positive"] + anchor_negative = filter_cfg.get("anchor_negative", "uninfected") + min_pre = float(filter_cfg.get("min_pre_minutes", 0)) + min_post = float(filter_cfg.get("min_post_minutes", 0)) + crop_window_minutes = filter_cfg["crop_window_minutes"] + + pre_frames = int(round(min_pre / frame_interval_minutes)) + post_frames = int(round(min_post / frame_interval_minutes)) + crop_half = int(round(float(crop_window_minutes) / frame_interval_minutes)) + + pred_dir = Path(dataset_cfg["pred_dir"]) + date_prefix = "_".join(dataset_id.split("_")[:3]) + matches = [m for m in pred_dir.glob(embedding_pattern) if m.name.startswith(date_prefix)] + if not matches: + _logger.warning( + f"[{dataset_id}] no zarr matched {embedding_pattern} with prefix {date_prefix}; " + "productive_source=lc_zarr produces empty cohort" + ) + return pd.DataFrame(columns=OUTPUT_COLUMNS) + + adata = ad.read_zarr(matches[0]) + adata.obs_names_make_unique() + if pred_column not in adata.obs.columns: + _logger.warning(f"[{dataset_id}] {pred_column} not in {matches[0].name}; productive empty") + return pd.DataFrame(columns=OUTPUT_COLUMNS) + + obs = adata.obs.copy() + obs = obs[obs["fov_name"].astype(str).str.contains(fov_pattern, regex=False)] + if obs.empty: + return pd.DataFrame(columns=OUTPUT_COLUMNS) + + rows: list[dict] = [] + n_anchored = 0 + for (fov, tid), g in obs.groupby(["fov_name", "track_id"]): + g = g.sort_values("t") + states = set(g[pred_column].dropna().unique()) - {""} + if anchor_positive not in states or anchor_negative not in states: + continue + + # Find the first sustained run of `min_run` consecutive positives. 
+ positive_mask = (g[pred_column] == anchor_positive).to_numpy() + run_start = None + run = 0 + for i, v in enumerate(positive_mask): + if v: + run += 1 + if run >= min_run: + run_start = i - min_run + 1 + break + else: + run = 0 + if run_start is None: + continue + + t_onset = int(g["t"].iloc[run_start]) + t_min, t_max = int(g["t"].min()), int(g["t"].max()) + if (t_onset - t_min) < pre_frames or (t_max - t_onset) < post_frames: + continue + + t_before = max(t_min, t_onset - crop_half) + t_after = min(t_max, t_onset + crop_half) + + parent_id = int(g["parent_track_id"].iloc[0]) if "parent_track_id" in g.columns else -1 + + in_window = g[(g["t"] >= t_before) & (g["t"] <= t_after)] + n_anchored += 1 + for _, r in in_window.iterrows(): + row = { + "dataset_id": dataset_id, + "fov_name": str(fov), + "track_id": int(tid), + "parent_track_id": parent_id, + "t": int(r["t"]), + } + for label_col in LABEL_VALUES: + row[label_col] = "" + rows.append(row) + + if n_anchored: + _logger.info( + f"[{dataset_id}] productive_source=lc_zarr: {n_anchored} tracks anchored from " + f"{matches[0].name} (fov={fov_pattern}, min_run={min_run})" + ) + return pd.DataFrame(rows) + + +def _organelle_zarr_pattern(dataset_id: str, embeddings: dict[str, str]) -> str: + """Return the organelle-channel zarr pattern for ``dataset_id``. + + Maps the dataset suffix (e.g. ``SEC61``) to ``embeddings.organelle_sec61``. + Falls back to ``embeddings.sensor`` if no per-organelle entry exists. + """ + suffix = dataset_id.split("_")[-1].lower() + key = f"organelle_{suffix}" + return embeddings.get(key, embeddings.get("sensor", "*.zarr")) + + +def _select_mock_from_zarr( + dataset_cfg: dict, + dataset_id: str, + fov_pattern: str, + embedding_pattern: str, + min_track_minutes: float, + frame_interval_minutes: float, +) -> pd.DataFrame: + """Pull mock cells directly from the embedding zarr's .obs. + + Fallback path for control wells that were never manually annotated + (e.g. SEC61's A/1 in 07_24). 
The well is uninfected by experimental + design; we synthesize ``infection_state="uninfected"`` and leave + other annotation columns blank, then run the resulting frame through + the same lineage and cohort machinery as annotation-derived cohorts. + Reads ``parent_track_id`` from the zarr so lineage reconnection works. + """ + pred_dir = Path(dataset_cfg["pred_dir"]) + date_prefix = "_".join(dataset_id.split("_")[:3]) + matches = [m for m in pred_dir.glob(embedding_pattern) if m.name.startswith(date_prefix)] + if not matches: + _logger.warning(f"[{dataset_id}] no zarr matched {embedding_pattern} with prefix {date_prefix}") + return pd.DataFrame(columns=OUTPUT_COLUMNS) + + adata = ad.read_zarr(matches[0]) + adata.obs_names_make_unique() + obs = adata.obs.copy() + obs = obs[obs["fov_name"].astype(str).str.contains(fov_pattern, regex=False)] + if obs.empty: + return pd.DataFrame(columns=OUTPUT_COLUMNS) + + min_frames = int(round(min_track_minutes / frame_interval_minutes)) + rows: list[dict] = [] + for (fov, tid), g in obs.groupby(["fov_name", "track_id"]): + if len(g) < min_frames: + continue + g = g.sort_values("t") + parent_id = int(g["parent_track_id"].iloc[0]) if "parent_track_id" in g.columns else -1 + for _, r in g.iterrows(): + rows.append( + { + "dataset_id": dataset_id, + "fov_name": str(fov), + "track_id": int(tid), + "parent_track_id": parent_id, + "t": int(r["t"]), + "infection_state": "uninfected", + "organelle_state": "", + "cell_division_state": "", + } + ) + df = pd.DataFrame(rows) + if not df.empty: + n_tracks = df.groupby(["fov_name", "track_id"]).ngroups + _logger.info( + f"[{dataset_id}] pulled {n_tracks} mock tracks from zarr {matches[0].name} " + f"(fov={fov_pattern}; annotation CSV had none)" + ) + return df + + +def _select_well_tracks( + ann_df: pd.DataFrame, + dataset_id: str, + fov_pattern: str, + min_track_minutes: float, + frame_interval_minutes: float, +) -> pd.DataFrame: + """Pull every track from a well, no anchor filter. 
+ + Used for bystander/abortive (infected well) and mock (uninfected well) + cohorts where we keep every long-enough track. + + Parameters + ---------- + min_track_minutes : float + Drop tracks shorter than this many real minutes. + """ + min_frames = int(round(min_track_minutes / frame_interval_minutes)) + + sub = ann_df[ann_df["fov_name"].astype(str).str.contains(fov_pattern, regex=False)].copy() + if sub.empty: + return pd.DataFrame(columns=OUTPUT_COLUMNS) + + rows: list[dict] = [] + for (fov, tid), g in sub.groupby(["fov_name", "track_id"]): + if len(g) < min_frames: + continue + g = g.sort_values("t") + parent_id = int(g.iloc[0].get("parent_track_id", -1)) + for _, r in g.iterrows(): + row = { + "dataset_id": dataset_id, + "fov_name": str(fov), + "track_id": int(tid), + "parent_track_id": parent_id, + "t": int(r["t"]), + } + for label_col in LABEL_VALUES: + if label_col in r and pd.notna(r[label_col]) and r[label_col] != "": + row[label_col] = r[label_col] + else: + row[label_col] = "" + rows.append(row) + + return pd.DataFrame(rows) + + +def _load_lc_predictions( + dataset_cfg: dict, + dataset_id: str, + fov_pattern: str, + embedding_pattern: str, + pred_column: str = "predicted_infection_state", +) -> pd.DataFrame: + """Pull LC predictions from the embedding zarr's obs. + + The pred_dir typically contains many zarrs across dates and channels; + we filter by the dataset's date prefix (first three underscore-separated + tokens of ``dataset_id``, e.g. ``2025_07_24``) before glob-matching the + channel pattern. Without the date filter, multiple datasets' zarrs + would all match ``*_viral_sensor_*.zarr`` and we'd pick an arbitrary + (often wrong) one. + + Returns empty DataFrame if no embedding zarr matches. + """ + pred_dir = Path(dataset_cfg["pred_dir"]) + date_prefix = "_".join(dataset_id.split("_")[:3]) # e.g. 
"2025_07_24" + matches = [m for m in pred_dir.glob(embedding_pattern) if m.name.startswith(date_prefix)] + if not matches: + _logger.warning( + f"No embedding zarr matched {embedding_pattern} with date prefix " + f"{date_prefix} in {pred_dir}; LC cohorts fall back to bystander default" + ) + return pd.DataFrame() + if len(matches) > 1: + _logger.warning( + f"Multiple zarrs matched {embedding_pattern} with date prefix " + f"{date_prefix} for {dataset_id}: {[m.name for m in matches]}; using first" + ) + + adata = ad.read_zarr(matches[0]) + adata.obs_names_make_unique() + if pred_column not in adata.obs.columns: + _logger.warning(f"{pred_column} not in {matches[0]} .obs; LC fallback") + return pd.DataFrame() + + obs = adata.obs[["fov_name", "track_id", "t", pred_column]].copy() + obs = obs[obs["fov_name"].astype(str).str.contains(fov_pattern, regex=False)] + obs["fov_name"] = obs["fov_name"].astype(str) + obs["track_id"] = obs["track_id"].astype(int) + obs["t"] = obs["t"].astype(int) + _logger.info(f"[{dataset_id}] loaded LC predictions from {matches[0].name} ({len(obs)} rows in {fov_pattern})") + return obs + + +def _t_zero_per_lineage( + productive_df: pd.DataFrame, + anchor_label: str, + anchor_positive: str, +) -> dict[str, int]: + """First frame where ``anchor_label == anchor_positive`` per lineage.""" + out: dict[str, int] = {} + for lineage_id, g in productive_df.groupby("lineage_id"): + if not lineage_id: + continue + positive = g[g[anchor_label] == anchor_positive] + if positive.empty: + continue + out[str(lineage_id)] = int(positive["t"].min()) + return out + + +def _well_is_uninfected(fov_name: str, mock_well_patterns: list[str]) -> bool: + return any(p in str(fov_name) for p in mock_well_patterns) + + +def _emit_cohort( + df: pd.DataFrame, + cohort: str, + t_zero_lookup: dict[str, int], + window_frames_by_dataset: dict[str, tuple[int, int]], +) -> pd.DataFrame: + """Add cohort + divides columns and order the output schema. 
+ + ``window_frames_by_dataset`` maps ``dataset_id`` to ``(k_pre_frames, + k_post_frames)``. Each lineage lives in exactly one dataset, so we + look up its dataset's frame conversions to classify ``divides``. + """ + df = df.copy() + df["cohort"] = cohort + + divides_per_lineage: dict[str, str] = {} + for lineage_id, g in df.groupby("lineage_id"): + if not lineage_id: + divides_per_lineage[lineage_id] = "none" + continue + t_zero = t_zero_lookup.get(str(lineage_id)) + if t_zero is None: + divides_per_lineage[lineage_id] = "none" + continue + dataset_id = str(g["dataset_id"].iloc[0]) + k_pre_frames, k_post_frames = window_frames_by_dataset[dataset_id] + divides_per_lineage[lineage_id] = classify_divides(g, t_zero, k_pre_frames, k_post_frames) + + df["divides"] = df["lineage_id"].map(divides_per_lineage).fillna("none") + + for col in OUTPUT_COLUMNS: + if col not in df.columns: + df[col] = "" + return df[OUTPUT_COLUMNS] + + +def _build_dataset_cohorts( + dataset_id: str, + dataset_cfg: dict, + cand_cfg: dict, + embedding_pattern: str, + embeddings: dict[str, str] | None = None, +) -> dict[str, pd.DataFrame]: + """Produce productive / bystander / abortive / mock dataframes for one dataset. + + Parameters + ---------- + dataset_id : str + Dataset key from ``datasets.yaml``. + dataset_cfg : dict + Single dataset entry from ``datasets.yaml``. + cand_cfg : dict + Candidate-set entry from ``candidates.yaml``. + embedding_pattern : str + Glob pattern for the LC-prediction embedding zarr (e.g. sensor channel). + embeddings : dict[str, str], optional + Full ``embeddings`` mapping from ``datasets.yaml``. Used to pick the + per-dataset organelle-channel pattern for the mock-from-zarr fallback, + since control wells (e.g. SEC61's A/1) may be absent from the sensor + zarr but present in the organelle zarr. 
+ """ + fov_pattern = dataset_cfg.get("fov_pattern", "") + frame_interval = float(dataset_cfg["frame_interval_minutes"]) + ann_path = Path(dataset_cfg["annotations_path"]) + _logger.info(f"[{dataset_id}] reading {ann_path}") + ann_df = pd.read_csv(ann_path) + + productive_filter = cand_cfg["productive_filter"] + cohort_rules = cand_cfg.get("cohort_rules", {}) + lineage_rules = cand_cfg.get("lineage_rules", {}) + + # Mock wells are per-dataset: each dataset's `control_fov_pattern` + # in datasets.yaml names the well that has the same imaging channel + # as that dataset (e.g. A/1 for SEC61, C/1 for G3BP1). The cohort- + # level `mock_well_patterns` is a fallback that applies the same + # well list to every dataset — keeps backwards-compat with + # candidate-set configs that don't yet split per-dataset. + ds_control_pattern = dataset_cfg.get("control_fov_pattern") + if ds_control_pattern: + mock_patterns: list[str] = [ds_control_pattern] + else: + mock_patterns = list(cohort_rules.get("mock_well_patterns", [])) + bystander_fraction = float(cohort_rules.get("bystander_uninfected_fraction", 0.8)) + abortive_min_run = int(cohort_rules.get("abortive_min_run", 3)) + # Non-productive cohorts (bystander, abortive, mock) only need enough + # frames to compute LC-run statistics for cohort tagging — they don't + # need the productive filter's pre/post window. Decoupling lets short + # tracks (e.g. 07_22 at 10 min/frame) contribute to comparison cohorts + # without weakening the productive definition. + min_non_productive_minutes = float(cohort_rules.get("min_non_productive_minutes", 300.0)) + + k_pre = int(round(float(lineage_rules.get("transition_window_k_pre_minutes", 60)) / frame_interval)) + k_post = int(round(float(lineage_rules.get("transition_window_k_post_minutes", 120)) / frame_interval)) + + # 1) Productive cohort from the infected well. 
``productive_source`` + # in datasets.yaml selects between annotation CSV (default) and + # LC-from-zarr (for datasets without track-linked manual annotations, + # e.g. 08_26_SEC61). + productive_source = dataset_cfg.get("productive_source", "annotation_csv") + if productive_source == "lc_zarr": + productive_df = _select_productive_from_zarr( + dataset_cfg=dataset_cfg, + dataset_id=dataset_id, + fov_pattern=fov_pattern, + filter_cfg=productive_filter, + frame_interval_minutes=frame_interval, + embedding_pattern=embedding_pattern, + min_run=abortive_min_run, + ) + elif productive_source == "annotation_csv": + productive_df = _select_productive_tracks( + ann_df, + dataset_id=dataset_id, + fov_pattern=fov_pattern, + filter_cfg=productive_filter, + frame_interval_minutes=frame_interval, + ) + else: + raise ValueError( + f"[{dataset_id}] unknown productive_source={productive_source!r}; must be 'annotation_csv' or 'lc_zarr'" + ) + + # 2) Mock cohort from uninfected control wells. + # First try the annotation CSV; if it has no rows for the control + # well (common for never-annotated control wells), fall back to the + # embedding zarr's .obs and synthesize the cohort with + # infection_state="uninfected" (true by well design). 
+ mock_parts: list[pd.DataFrame] = [] + min_track_minutes = min_non_productive_minutes + for pat in mock_patterns: + ctrl_df = _select_well_tracks( + ann_df, + dataset_id=dataset_id, + fov_pattern=pat, + min_track_minutes=min_track_minutes, + frame_interval_minutes=frame_interval, + ) + if ctrl_df.empty: + mock_pattern = _organelle_zarr_pattern(dataset_id, embeddings) if embeddings else embedding_pattern + ctrl_df = _select_mock_from_zarr( + dataset_cfg=dataset_cfg, + dataset_id=dataset_id, + fov_pattern=pat, + embedding_pattern=mock_pattern, + min_track_minutes=min_track_minutes, + frame_interval_minutes=frame_interval, + ) + if not ctrl_df.empty: + mock_parts.append(ctrl_df) + mock_df = pd.concat(mock_parts, ignore_index=True) if mock_parts else pd.DataFrame(columns=OUTPUT_COLUMNS) + + # 3) Bystander + abortive: every track in the infected well that's + # not in the productive cohort. LC predictions discriminate. + well_df = _select_well_tracks( + ann_df, + dataset_id=dataset_id, + fov_pattern=fov_pattern, + min_track_minutes=min_non_productive_minutes, + frame_interval_minutes=frame_interval, + ) + + if productive_df.empty: + productive_track_keys: set[tuple[str, int]] = set() + else: + productive_track_keys = set(zip(productive_df["fov_name"].astype(str), productive_df["track_id"].astype(int))) + well_non_productive_df = well_df[ + ~well_df.apply( + lambda r: (str(r["fov_name"]), int(r["track_id"])) in productive_track_keys, + axis=1, + ) + ].copy() + + lc_obs = _load_lc_predictions( + dataset_cfg, + dataset_id=dataset_id, + fov_pattern=fov_pattern, + embedding_pattern=embedding_pattern, + ) + + return { + "productive": productive_df, + "well_non_productive": well_non_productive_df, # to be split into bystander/abortive + "mock": mock_df, + "_meta": { + "fov_pattern": fov_pattern, + "frame_interval": frame_interval, + "mock_patterns": mock_patterns, + "bystander_fraction": bystander_fraction, + "abortive_min_run": abortive_min_run, + "k_pre": k_pre, + "k_post": 
k_post, + "lc_obs": lc_obs, + }, + } + + +def _split_well_non_productive( + well_df: pd.DataFrame, + lc_obs: pd.DataFrame, + bystander_fraction: float, + abortive_min_run: int, +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Split non-productive infected-well tracks into bystander, abortive, unannotated_productive. + + Returns three dataframes in that order. ``unannotated_productive`` + captures lineages with sustained LC positive runs that lack a manual + anchor — they're not in the productive cohort but the LC says they + look infected. Reported separately so downstream stages can decide. + """ + empty = well_df.iloc[0:0].copy() + if well_df.empty: + return empty, empty, empty + + if lc_obs is None or len(lc_obs) == 0: + # No LC: everything falls back to bystander (conservative default). + return well_df.copy(), empty, empty + + bystander_keys: set[tuple[str, int]] = set() + abortive_keys: set[tuple[str, int]] = set() + unannotated_keys: set[tuple[str, int]] = set() + for (fov, tid), g in lc_obs.groupby(["fov_name", "track_id"]): + g = g.sort_values("t") + cohort = tag_cohort_for_lineage( + df=g, + well_is_uninfected=False, + has_manual_t_zero=False, + lc_predictions=g["predicted_infection_state"], + min_run=abortive_min_run, + bystander_uninfected_fraction=bystander_fraction, + ) + key = (str(fov), int(tid)) + if cohort == "bystander": + bystander_keys.add(key) + elif cohort == "abortive": + abortive_keys.add(key) + elif cohort == "unannotated_productive": + unannotated_keys.add(key) + + def _filter(keys: set[tuple[str, int]]) -> pd.DataFrame: + return well_df[well_df.apply(lambda r: (str(r["fov_name"]), int(r["track_id"])) in keys, axis=1)].copy() + + return _filter(bystander_keys), _filter(abortive_keys), _filter(unannotated_keys) + + +def _funnel_report( + candidate_set: str, + cohort_dfs: dict[str, pd.DataFrame], +) -> str: + """Render a markdown funnel report (lineage and frame counts per cohort).""" + lines = [f"# Funnel — {candidate_set}", ""] 
+ lines.append("| cohort | n_lineages | n_tracks | n_frames | divides=none | pre | during | post |") + lines.append("|---|---:|---:|---:|---:|---:|---:|---:|") + for cohort, df in cohort_dfs.items(): + if df.empty: + lines.append(f"| {cohort} | 0 | 0 | 0 | 0 | 0 | 0 | 0 |") + continue + n_lineages = df["lineage_id"].nunique() + n_tracks = df.groupby(["dataset_id", "fov_name", "track_id"]).ngroups + n_frames = len(df) + divides_counts = df.drop_duplicates(["dataset_id", "fov_name", "lineage_id"])["divides"].value_counts() + lines.append( + f"| {cohort} | {n_lineages} | {n_tracks} | {n_frames} | " + f"{divides_counts.get('none', 0)} | {divides_counts.get('pre', 0)} | " + f"{divides_counts.get('during', 0)} | {divides_counts.get('post', 0)} |" + ) + return "\n".join(lines) + + +def main() -> None: + """Write per-cohort annotation CSVs for one candidate set.""" + parser = argparse.ArgumentParser(description="Auto-select candidates and tag cohorts (Stage 0).") + parser.add_argument("--datasets", required=True, help="Path to datasets.yaml") + parser.add_argument("--config", required=True, help="Path to candidates.yaml") + parser.add_argument("--candidate-set", required=True, help="Name under config['candidate_sets']") + args = parser.parse_args() + + config = load_stage_config(args.datasets, args.config) + cand_sets = config.get("candidate_sets", {}) + if args.candidate_set not in cand_sets: + raise KeyError( + f"Candidate set {args.candidate_set!r} not in config['candidate_sets']. Known: {sorted(cand_sets)}" + ) + cand_cfg = cand_sets[args.candidate_set] + + if cand_cfg.get("source") == "manual": + raise SystemExit( + f"Candidate set {args.candidate_set!r} declares source=manual; use manual_candidates.py instead." + ) + + dataset_ids = cand_cfg["datasets"] + dataset_cfgs = {d["dataset_id"]: d for d in config["datasets"]} + + # Embedding pattern for LC predictions (sensor channel by convention). 
+ embeddings_cfg = config.get("embeddings", {}) + embedding_pattern = embeddings_cfg.get("sensor", "*_viral_sensor_*.zarr") + + productive_parts: list[pd.DataFrame] = [] + bystander_parts: list[pd.DataFrame] = [] + abortive_parts: list[pd.DataFrame] = [] + unannotated_parts: list[pd.DataFrame] = [] + mock_parts: list[pd.DataFrame] = [] + + for ds_id in dataset_ids: + if ds_id not in dataset_cfgs: + raise KeyError(f"dataset_id {ds_id!r} not in datasets.yaml") + ds_cfg = dataset_cfgs[ds_id] + results = _build_dataset_cohorts( + dataset_id=ds_id, + dataset_cfg=ds_cfg, + cand_cfg=cand_cfg, + embedding_pattern=embedding_pattern, + embeddings=embeddings_cfg, + ) + meta = results["_meta"] + bystander_df, abortive_df, unannotated_df = _split_well_non_productive( + results["well_non_productive"], + meta["lc_obs"], + meta["bystander_fraction"], + meta["abortive_min_run"], + ) + if not results["productive"].empty: + productive_parts.append(results["productive"]) + if not bystander_df.empty: + bystander_parts.append(bystander_df) + if not abortive_df.empty: + abortive_parts.append(abortive_df) + if not unannotated_df.empty: + unannotated_parts.append(unannotated_df) + if not results["mock"].empty: + mock_parts.append(results["mock"]) + + if not productive_parts: + raise RuntimeError(f"Candidate set {args.candidate_set!r} produced no productive lineages.") + + # Tag cohort on each cohort frame before merging, so we can split + # back after lineage reconnection without losing cohort identity. 
+ def _tag(parts: list[pd.DataFrame], cohort: str) -> pd.DataFrame: + if not parts: + return pd.DataFrame(columns=[*OUTPUT_COLUMNS, "cohort"]) + merged = pd.concat(parts, ignore_index=True) + merged["cohort"] = cohort + return merged + + productive_df = _tag(productive_parts, "productive") + bystander_df = _tag(bystander_parts, "bystander") + abortive_df = _tag(abortive_parts, "abortive") + unannotated_df = _tag(unannotated_parts, "unannotated_productive") + mock_df = _tag(mock_parts, "mock") + + # Lineage reconnection on the combined frame produces globally unique + # lineage_ids across all cohorts. We then split back by cohort. + combined = pd.concat([productive_df, bystander_df, abortive_df, unannotated_df, mock_df], ignore_index=True) + combined = assign_lineage_ids(combined, return_both_branches=True) + + productive_df = combined[combined["cohort"] == "productive"].copy() + bystander_df = combined[combined["cohort"] == "bystander"].copy() + abortive_df = combined[combined["cohort"] == "abortive"].copy() + unannotated_df = combined[combined["cohort"] == "unannotated_productive"].copy() + mock_df = combined[combined["cohort"] == "mock"].copy() + + # Cap productive lineages to max_lineages by length. + max_lineages = cand_cfg.get("max_lineages") + if max_lineages is not None: + lineage_lengths = ( + productive_df.groupby(["dataset_id", "fov_name", "lineage_id"]).size().sort_values(ascending=False) + ) + keep = set(lineage_lengths.head(max_lineages).index) + mask = productive_df.apply(lambda r: (r["dataset_id"], r["fov_name"], str(r["lineage_id"])) in keep, axis=1) + n_before = productive_df.groupby(["dataset_id", "fov_name", "lineage_id"]).ngroups + productive_df = productive_df[mask].reset_index(drop=True) + n_after = productive_df.groupby(["dataset_id", "fov_name", "lineage_id"]).ngroups + _logger.info(f"Capped productive lineages from {n_before} to {n_after} (max {max_lineages})") + + # Compute t_zero per productive lineage and build the divides classifier. 
+ filter_cfg = cand_cfg["productive_filter"] + t_zero_lookup = _t_zero_per_lineage( + productive_df, + anchor_label=filter_cfg["anchor_label"], + anchor_positive=filter_cfg["anchor_positive"], + ) + + # Per-dataset window-frame conversion. Each lineage lives in one + # dataset, so divides classification picks the frame interval from + # its dataset. + lineage_rules = cand_cfg.get("lineage_rules", {}) + k_pre_minutes = float(lineage_rules.get("transition_window_k_pre_minutes", 60)) + k_post_minutes = float(lineage_rules.get("transition_window_k_post_minutes", 120)) + window_frames_by_dataset: dict[str, tuple[int, int]] = {} + for ds_id in dataset_ids: + frame_interval = float(dataset_cfgs[ds_id]["frame_interval_minutes"]) + window_frames_by_dataset[ds_id] = ( + int(round(k_pre_minutes / frame_interval)), + int(round(k_post_minutes / frame_interval)), + ) + + productive_out = _emit_cohort(productive_df, "productive", t_zero_lookup, window_frames_by_dataset) + bystander_out = _emit_cohort(bystander_df, "bystander", t_zero_lookup, window_frames_by_dataset) + abortive_out = _emit_cohort(abortive_df, "abortive", t_zero_lookup, window_frames_by_dataset) + unannotated_out = _emit_cohort(unannotated_df, "unannotated_productive", t_zero_lookup, window_frames_by_dataset) + mock_out = _emit_cohort(mock_df, "mock", t_zero_lookup, window_frames_by_dataset) + + CANDIDATES_DIR.mkdir(parents=True, exist_ok=True) + cohort_dfs = { + "productive": productive_out, + "bystander": bystander_out, + "abortive": abortive_out, + "unannotated_productive": unannotated_out, + "mock": mock_out, + } + for cohort, df in cohort_dfs.items(): + out_path = CANDIDATES_DIR / f"{args.candidate_set}_{cohort}.csv" + df.to_csv(out_path, index=False) + _logger.info(f"Wrote {out_path} ({len(df)} rows)") + + funnel_path = CANDIDATES_DIR / f"{args.candidate_set}_funnel.md" + funnel_path.write_text(_funnel_report(args.candidate_set, cohort_dfs)) + _logger.info(f"Wrote {funnel_path}") + + +if __name__ == 
"__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/2-align_cells/align_anno.py b/applications/dynaclr/scripts/pseudotime/2-align_cells/align_anno.py new file mode 100644 index 000000000..cc9e3dadc --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/2-align_cells/align_anno.py @@ -0,0 +1,198 @@ +"""Path A-anno alignment: real-time shift on human ``infection_state``. + +Reads the productive-cohort CSV from Stage 0, anchors each lineage at +``t_zero = first frame where infection_state == "infected"``, and writes +a per-frame alignment parquet with the unified schema. No DTW, no +template, no warping. + +For bystander, abortive, and mock cohorts: no per-cell anchor — these +cohorts pass through with ``t_zero = NaN`` and ``t_rel_minutes = NaN``, +to be used as null distributions in Stage 3 readouts. + +Path A reuses :func:`dynaclr.pseudotime.alignment.assign_t_perturb` for +the per-lineage anchor logic; the rest of this script is parquet +plumbing. + +Usage:: + + cd applications/dynaclr/scripts/pseudotime/2-align_cells + uv run python align_anno.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/candidates.yaml \ + --candidate-set zikv_productive_07_24 +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import numpy as np +import pandas as pd + +SCRIPT_DIR = Path(__file__).resolve().parent +CANDIDATES_DIR = SCRIPT_DIR.parent / "0-select_candidates" / "candidates" +OUTPUT_DIR = SCRIPT_DIR / "A-anno" / "alignments" + +sys.path.insert(0, str(SCRIPT_DIR.parent)) +from utils import load_stage_config # noqa: E402 + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +_logger = logging.getLogger(__name__) + +# Unified Stage 2 parquet schema (per DAG §7.4). Path A populates the +# real-time columns; Path B-only columns are NaN/empty here. 
+PARQUET_COLUMNS = [ + # ids + "dataset_id", + "fov_name", + "lineage_id", + "track_id", + "t", + # cohort + lineage tags from Stage 0 + "cohort", + "divides", + # anchor + real-time + "t_zero", + "t_rel_minutes", + "track_path", + # Path B-only (left empty for Path A) + "pseudotime", + "alignment_region", + "t_rel_minutes_warped", + "dtw_cost", + "length_normalized_cost", + "path_skew", + "match_q_start", + "match_q_end", + "template_id", +] + + +def _frame_interval_lookup(config: dict) -> dict[str, float]: + """Map dataset_id → frame_interval_minutes from the merged config.""" + return {d["dataset_id"]: float(d["frame_interval_minutes"]) for d in config["datasets"]} + + +def _t_zero_from_annotations( + cohort_df: pd.DataFrame, + anchor_label: str, + anchor_positive: str, +) -> dict[str, int]: + """First frame per lineage where ``anchor_label == anchor_positive``.""" + out: dict[str, int] = {} + if cohort_df.empty: + return out + positive_rows = cohort_df[cohort_df[anchor_label] == anchor_positive] + for lineage_id, g in positive_rows.groupby("lineage_id"): + if not lineage_id: + continue + out[str(lineage_id)] = int(g["t"].min()) + return out + + +def _align_cohort( + cohort_df: pd.DataFrame, + cohort: str, + t_zero_lookup: dict[int, int], + frame_intervals: dict[str, float], +) -> pd.DataFrame: + """Add ``t_zero``, ``t_rel_minutes``, ``track_path`` to a cohort frame. + + Lineages without an anchor in ``t_zero_lookup`` (bystander, abortive, + mock) get NaN for ``t_zero`` and ``t_rel_minutes``. 
+ """ + if cohort_df.empty: + out = pd.DataFrame(columns=PARQUET_COLUMNS) + return out + + out = cohort_df.copy() + + out["t_zero"] = out["lineage_id"].map(t_zero_lookup) + frame_interval_per_row = out["dataset_id"].map(frame_intervals) + if frame_interval_per_row.isna().any(): + unknown = sorted(set(out["dataset_id"][frame_interval_per_row.isna()])) + raise KeyError(f"Frame interval missing for dataset(s) {unknown}; check datasets.yaml entries.") + + has_anchor = out["t_zero"].notna() + out["t_rel_minutes"] = np.where( + has_anchor, + (out["t"] - out["t_zero"].fillna(0)) * frame_interval_per_row, + np.nan, + ) + out["track_path"] = "A-anno" + + # Path B-only columns left empty. + out["pseudotime"] = np.nan + out["alignment_region"] = "" + out["t_rel_minutes_warped"] = np.nan + out["dtw_cost"] = np.nan + out["length_normalized_cost"] = np.nan + out["path_skew"] = np.nan + out["match_q_start"] = pd.NA + out["match_q_end"] = pd.NA + out["template_id"] = "" + + for col in PARQUET_COLUMNS: + if col not in out.columns: + out[col] = "" + return out[PARQUET_COLUMNS] + + +def main() -> None: + """Write Path A-anno alignment parquet for one candidate set.""" + parser = argparse.ArgumentParser(description="Path A-anno alignment (annotation-anchored real-time shift).") + parser.add_argument("--datasets", required=True, help="Path to datasets.yaml") + parser.add_argument("--config", required=True, help="Path to candidates.yaml") + parser.add_argument("--candidate-set", required=True, help="Candidate-set name") + parser.add_argument( + "--anchor-label", + default="infection_state", + help="Annotation column to anchor on (default: infection_state)", + ) + parser.add_argument( + "--anchor-positive", + default="infected", + help="Positive value of anchor column (default: infected)", + ) + args = parser.parse_args() + + config = load_stage_config(args.datasets, args.config) + frame_intervals = _frame_interval_lookup(config) + + cohorts = ["productive", "bystander", "abortive", 
"unannotated_productive", "mock"] + cohort_dfs: dict[str, pd.DataFrame] = {} + for cohort in cohorts: + path = CANDIDATES_DIR / f"{args.candidate_set}_{cohort}.csv" + if path.exists(): + cohort_dfs[cohort] = pd.read_csv(path) + _logger.info(f"Read {path} ({len(cohort_dfs[cohort])} rows)") + else: + _logger.warning(f"{path} not found; cohort '{cohort}' will be empty") + cohort_dfs[cohort] = pd.DataFrame() + + productive_df = cohort_dfs["productive"] + if productive_df.empty: + raise RuntimeError(f"Productive cohort empty for {args.candidate_set!r}. Run select_candidates.py first.") + + t_zero_lookup = _t_zero_from_annotations( + productive_df, + anchor_label=args.anchor_label, + anchor_positive=args.anchor_positive, + ) + _logger.info(f"Computed t_zero for {len(t_zero_lookup)} productive lineages") + + aligned_parts = [_align_cohort(cohort_dfs[c], c, t_zero_lookup, frame_intervals) for c in cohorts] + aligned = pd.concat(aligned_parts, ignore_index=True) + + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + out_path = OUTPUT_DIR / f"{args.candidate_set}.parquet" + aligned.to_parquet(out_path, index=False) + n_with_anchor = aligned["t_zero"].notna().sum() + _logger.info(f"Wrote {out_path} ({len(aligned)} rows, {n_with_anchor} anchored)") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/2-align_cells/align_embedding.py b/applications/dynaclr/scripts/pseudotime/2-align_cells/align_embedding.py new file mode 100644 index 000000000..fcb912cc0 --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/2-align_cells/align_embedding.py @@ -0,0 +1,593 @@ +"""Path B alignment: subsequence DTW on the NS3 channel embedding. + +Stage 2 of the pseudotime pipeline (Path B track). Takes a template +built by ``1-build_template/build_template.py`` and scans it across +every track in a *query set*. Subsequence DTW finds, per query track, +the time window where the template best matches — i.e. 
when that cell +traversed the event encoded by the template. Frames inside the matched +window are mapped to template-relative minutes via the template's +``time_calibration``; frames outside are labelled ``pre`` / ``post``. + +Preprocessing re-uses the build-time z-score + PCA + L2 stored in the +template zarr. Never refit at alignment time. + +Output parquet matches the unified Stage 2 schema (per DAG §7.4) so +Path A and Path B parquets can be compared directly in Stage 4. +``length_normalized_cost`` and ``path_skew`` are surfaced as columns +(not just used as filters) so downstream sweeps can re-gate without +re-running DTW. Path-skew is the primary gate (rejects degenerate +non-diagonal warps); cost is the secondary gate (stereotypy filter). + +Usage:: + + cd applications/dynaclr/scripts/pseudotime/2-align_cells + uv run python align_embedding.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/align_cells.yaml \ + --template infection_nondividing_sensor \ + --flavor raw \ + --query-set sensor_all_transitioning +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd + +from dynaclr.pseudotime import ( + AlignmentResult, + alignment_results_to_dataframe, + date_prefix_from_dataset_id, + dtw_align_tracks, + find_embedding_zarr, + load_template_flavor, + read_template_attrs, +) +from dynaclr.pseudotime.alignment import filter_tracks + +SCRIPT_DIR = Path(__file__).resolve().parent +TEMPLATES_DIR = SCRIPT_DIR.parent / "1-build_template" / "templates" +CANDIDATES_DIR = SCRIPT_DIR.parent / "0-select_candidates" / "candidates" +OUTPUT_ALIGNMENTS_DIR = SCRIPT_DIR / "B" / "alignments" + +sys.path.insert(0, str(SCRIPT_DIR.parent)) +from utils import load_stage_config # noqa: E402 + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +_logger 
= logging.getLogger(__name__) + + +def _load_query_embeddings( + query_cfg: dict, + dataset_cfgs: dict[str, dict], + embedding_pattern: str, + min_track_timepoints: int, + template_len_frames: int, +) -> tuple[dict[str, ad.AnnData], dict[str, pd.DataFrame]]: + """Load query embedding zarrs and build per-dataset filtered track dfs. + + Parameters + ---------- + query_cfg : dict + Query-set config entry from ``config['query_sets'][name]``. Must + include ``datasets`` (list) and may include ``track_filter`` + (dict of ``obs``-column → required value) and + ``min_track_minutes`` (float). + dataset_cfgs : dict[str, dict] + Map ``{dataset_id: dataset_cfg}`` from ``config['datasets']``. + embedding_pattern : str + Glob pattern for the channel's zarr file + (e.g. ``"*_viral_sensor_*.zarr"``). + min_track_timepoints : int + CLI-level floor on track length, passed to :func:`filter_tracks`. + + Returns + ------- + tuple[dict[str, ad.AnnData], dict[str, pd.DataFrame]] + ``(adata_dict, df_dict)`` keyed by ``dataset_id``. The df has + at least ``fov_name``, ``track_id``, ``t`` — columns used by + ``dtw_align_tracks`` for valid-track selection. + """ + adata_dict: dict[str, ad.AnnData] = {} + df_dict: dict[str, pd.DataFrame] = {} + + track_filter = query_cfg.get("track_filter", {}) or {} + min_track_minutes = query_cfg.get("min_track_minutes") + min_pre_minutes = float(query_cfg.get("min_pre_minutes", 0)) + min_post_minutes = float(query_cfg.get("min_post_minutes", 0)) + + for ds_entry in query_cfg["datasets"]: + dataset_id = ds_entry["dataset_id"] + ds_cfg = dataset_cfgs[dataset_id] + fov_pattern = ds_cfg.get("fov_pattern") + frame_interval = ds_cfg["frame_interval_minutes"] + + prefix = date_prefix_from_dataset_id(dataset_id) + zarr_path = find_embedding_zarr(ds_cfg["pred_dir"], prefix + embedding_pattern) + + adata = ad.read_zarr(zarr_path) + adata.obs_names_make_unique() + + # FOV restriction from the dataset config (e.g. 
"C/2") — keeps us + # out of control wells unless the user explicitly wants them. + if fov_pattern is not None: + fov_mask = adata.obs["fov_name"].astype(str).str.contains(fov_pattern, regex=False) + adata = adata[fov_mask.to_numpy()].copy() + + # Track-level filters from query_cfg['track_filter']: each entry + # requires the obs column to equal the provided value on every + # frame of the track. Applied in adata space so the dropped rows + # don't make it into the DTW solver. + for col, required_value in track_filter.items(): + if col not in adata.obs.columns: + raise KeyError(f"track_filter column {col!r} not in adata.obs for {dataset_id}") + mask = adata.obs[col].astype(str) == str(required_value) + adata = adata[mask.to_numpy()].copy() + + # Build a tracking df for the stage-1 filter helper. + df = adata.obs[["fov_name", "track_id", "t"]].copy() + df["fov_name"] = df["fov_name"].astype(str) + df["track_id"] = df["track_id"].astype(int) + df["t"] = df["t"].astype(int) + + # Pass 1 (necessary condition): track must be long enough to hold at least + # the template itself plus any required pre/post headroom. The sufficient + # condition — that the match actually lands inside the track with that + # headroom on either side — is enforced post-DTW in main(). + if min_track_minutes is not None: + min_frames = int(np.ceil(float(min_track_minutes) / frame_interval)) + else: + min_frames = min_track_timepoints + headroom_frames = int(np.ceil((min_pre_minutes + min_post_minutes) / frame_interval)) + min_frames = max(min_frames, min_track_timepoints, template_len_frames + headroom_frames) + df = filter_tracks(df, min_timepoints=min_frames) + + # Subset adata to surviving (fov, track, t) rows. 
+ keep_keys = set(zip(df["fov_name"], df["track_id"], df["t"])) + keep_mask = [ + (str(f), int(tid), int(t)) in keep_keys + for f, tid, t in zip(adata.obs["fov_name"], adata.obs["track_id"], adata.obs["t"]) + ] + adata = adata[np.asarray(keep_mask)].copy() + + if len(adata) == 0: + _logger.warning(f"[{dataset_id}] no tracks survived filters; skipping") + continue + + adata_dict[dataset_id] = adata + df_dict[dataset_id] = df + _logger.info( + f"[{dataset_id}] {adata.n_obs} rows, " + f"{df.groupby(['fov_name', 'track_id']).ngroups} tracks " + f"(min {min_frames} frames = {min_frames * frame_interval} min)" + ) + + return adata_dict, df_dict + + +def _enrich_with_cohort_metadata( + flat: pd.DataFrame, + candidate_set: str, + frame_interval_minutes: dict[str, float], +) -> pd.DataFrame: + """Join Path B alignment frame with Stage 0 cohort metadata. + + Reads the productive-cohort CSV produced by :mod:`select_candidates` + and merges in ``lineage_id``, ``cohort``, ``divides``, and + ``t_zero`` per ``(dataset_id, fov_name, track_id)`` so the Path B + parquet matches the unified Stage 2 schema (per DAG §7.4). Computes + real-time ``t_rel_minutes`` per row. + + Path B alignment runs only on the productive cohort. Cells not + matched in the cohort CSV get ``cohort="productive"`` (default for + Path B input), ``lineage_id=""`` (orphan sentinel), ``t_zero=NaN``. + """ + flat = flat.copy() + cand_csv = CANDIDATES_DIR / f"{candidate_set}_productive.csv" + if not cand_csv.exists(): + _logger.warning( + f"Productive cohort CSV {cand_csv} not found; lineage_id, cohort, divides, t_zero will be missing." + ) + flat["lineage_id"] = "" + flat["cohort"] = "productive" + flat["divides"] = "none" + flat["t_zero"] = pd.NA + flat["t_rel_minutes"] = np.nan + return flat + + productive = pd.read_csv(cand_csv) + # Reduce to per-track metadata: lineage_id, divides are per-lineage. 
+ per_track = ( + productive.groupby(["dataset_id", "fov_name", "track_id"]) + .agg( + lineage_id=("lineage_id", "first"), + divides=("divides", "first"), + ) + .reset_index() + ) + flat["fov_name"] = flat["fov_name"].astype(str) + flat["track_id"] = flat["track_id"].astype(int) + per_track["fov_name"] = per_track["fov_name"].astype(str) + per_track["track_id"] = per_track["track_id"].astype(int) + flat = flat.merge(per_track, on=["dataset_id", "fov_name", "track_id"], how="left") + flat["lineage_id"] = flat["lineage_id"].fillna("").astype(str) + flat["divides"] = flat["divides"].fillna("none") + flat["cohort"] = "productive" + + # t_zero: per-lineage first frame where infection_state == "infected". + productive_pos = productive[productive["infection_state"] == "infected"] + t_zero_lookup = productive_pos.groupby("lineage_id")["t"].min().to_dict() if not productive_pos.empty else {} + flat["t_zero"] = flat["lineage_id"].map(t_zero_lookup) + + # t_rel_minutes = (t - t_zero) * frame_interval. NaN when no anchor. + fi = flat["dataset_id"].map(frame_interval_minutes) + has_anchor = flat["t_zero"].notna() + flat["t_rel_minutes"] = np.where( + has_anchor, + (flat["t"] - flat["t_zero"].fillna(0)) * fi, + np.nan, + ) + return flat + + +def _per_track_match_metadata( + results: list[AlignmentResult], + frame_interval_minutes: dict[str, float], +) -> pd.DataFrame: + """Derive per-track subsequence-match bounds from alignment results. + + ``match_q_start`` / ``match_q_end`` are the absolute query frames + bounding the first / last template-matched frame (``alignment_region + == "aligned"``). ``match_duration_minutes`` is the real-time span of + the match using the dataset's frame interval. + + Parameters + ---------- + results : list[AlignmentResult] + Output of :func:`dtw_align_tracks`. + frame_interval_minutes : dict[str, float] + Map ``{dataset_id: frame_interval_minutes}``. 
+ + Returns + ------- + pd.DataFrame + One row per cell with ``dataset_id``, ``fov_name``, + ``track_id``, ``match_q_start``, ``match_q_end``, + ``match_duration_minutes``. + """ + rows = [] + for r in results: + aligned_mask = r.alignment_region == "aligned" + if not aligned_mask.any(): + # Shouldn't happen when subsequence=True with any match, but + # keep the row so downstream joins don't drop the cell. + q_start, q_end, duration = np.nan, np.nan, np.nan + else: + aligned_times = r.timepoints[aligned_mask] + q_start = int(aligned_times.min()) + q_end = int(aligned_times.max()) + fi = frame_interval_minutes[r.dataset_id] + duration = float(q_end - q_start) * fi + rows.append( + { + "dataset_id": r.dataset_id, + "fov_name": r.fov_name, + "track_id": r.track_id, + "match_q_start": q_start, + "match_q_end": q_end, + "match_duration_minutes": duration, + } + ) + return pd.DataFrame(rows) + + +def main() -> None: + """CLI entry point for Stage 2 subsequence alignment.""" + parser = argparse.ArgumentParser(description="Subsequence-DTW-align query tracks to a template (Stage 2)") + parser.add_argument("--datasets", required=True, help="Path to datasets.yaml (shared infra config)") + parser.add_argument("--config", required=True, help="Path to align_cells.yaml") + parser.add_argument("--template", required=True, help="Template name under config['templates']") + parser.add_argument( + "--flavor", + choices=["raw", "pca"], + default="raw", + help="Which template flavor to align against (default: raw)", + ) + parser.add_argument( + "--candidate-set", + default=None, + help=( + "Candidate-set name from candidates.yaml; used to join cohort + lineage " + "metadata into the unified Stage 2 parquet. Defaults to --query-set if omitted." 
+ ), + ) + parser.add_argument( + "--query-set", + required=True, + help="Query-set name under config['query_sets']", + ) + parser.add_argument( + "--min-track-timepoints", + type=int, + default=3, + help="Minimum timepoints per track (default: 3). Overridden by query_set.min_track_minutes when larger.", + ) + parser.add_argument( + "--min-match-ratio", + type=float, + default=0.5, + help=( + "Minimum fraction of template length that the matched window must cover " + "(default: 0.5). Rejects degenerate subsequence DTW matches where the solver " + "collapses the template onto a few query frames. Set to 0 to disable." + ), + ) + parser.add_argument( + "--max-skew", + type=float, + default=0.8, + help=( + "Maximum allowed path skewness in [0, 1], where 0 = perfectly diagonal. " + "Rejects L-shaped or heavily non-diagonal warps that slip past the psi cap. " + "Default mirrors the old find_best_match_dtw_bernd_clifford default. Set to 1 to disable." + ), + ) + parser.add_argument( + "--min-match-minutes", + type=float, + default=None, + help=( + "Minimum real-time duration of the matched window. Supersedes --min-match-ratio " + "when set. Per-track frame threshold = ceil(min_match_minutes / frame_interval_minutes). " + "Use this to apply a single wall-clock bar across query datasets with different " + "frame intervals." + ), + ) + args = parser.parse_args() + + config = load_stage_config(args.datasets, args.config) + + query_cfg = config.get("query_sets", {}).get(args.query_set) + if query_cfg is None: + raise KeyError(f"Query set {args.query_set!r} not in config['query_sets']") + + template_path = TEMPLATES_DIR / f"template_{args.template}.zarr" + if not template_path.exists(): + raise FileNotFoundError(f"Template zarr not found: {template_path}") + + # Template channel is recorded in the zarr's config_snapshot attrs by build_template.py. 
+ # Reading from the zarr (rather than requiring the build-template config) keeps Stage 2 + # self-contained: a template zarr is all you need to know what it aligns to. + snapshot = read_template_attrs(template_path).get("config_snapshot", {}) + template_entry = snapshot.get("templates", {}).get(args.template, {}) + template_channel = template_entry.get("channel") + if template_channel is None: + raise ValueError( + f"Template zarr {template_path} has no recorded channel in config_snapshot; " + f"was it built by the current build_template.py?" + ) + + query_channel = query_cfg.get("channel", template_channel) + if query_channel != template_channel: + raise ValueError( + f"Query set channel {query_channel!r} does not match template channel " + f"{template_channel!r}. Alignment must happen in the template's embedding space." + ) + + _logger.info(f"Loading template {template_path} (flavor={args.flavor})") + template_result, _attrs = load_template_flavor(template_path, args.flavor) + _logger.info(f" template shape {template_result.template.shape}, {template_result.n_input_tracks} input tracks") + + embedding_pattern = config["embeddings"][query_channel] + dataset_cfgs = {d["dataset_id"]: d for d in config["datasets"]} + + _logger.info(f"Loading query set {args.query_set!r} ({len(query_cfg['datasets'])} datasets)") + adata_dict, df_dict = _load_query_embeddings( + query_cfg, + dataset_cfgs, + embedding_pattern, + args.min_track_timepoints, + template_len_frames=template_result.template.shape[0], + ) + if not adata_dict: + raise RuntimeError(f"Query set {args.query_set!r} produced no usable tracks") + + all_results: list[AlignmentResult] = [] + frame_interval_by_ds = {d["dataset_id"]: float(d["frame_interval_minutes"]) for d in config["datasets"]} + + # psi is a TEMPLATE-axis budget (frames of the cost matrix), not a query-time + # budget. The frame-unit default (t_template // 2) inside dtw_align_tracks is + # the right value regardless of query frame rate. 
We don't scale psi by the + # query's frame interval — see dtw_align_tracks' inline note. + for dataset_id, adata in adata_dict.items(): + _logger.info(f"Aligning {dataset_id} (subsequence DTW)") + results = dtw_align_tracks( + adata=adata, + df=df_dict[dataset_id], + template_result=template_result, + dataset_id=dataset_id, + min_track_timepoints=args.min_track_timepoints, + subsequence=True, + ) + all_results.extend(results) + + if not all_results: + raise RuntimeError("No alignment results produced") + + drop_log: dict[str, int] = {"n_input_tracks": len(all_results)} + + # Drop tracks whose DTW solver could not find a valid path — these show up + # as `length_normalized_cost == inf` (dtaidistance returns an unreachable + # endpoint when psi overflows the cost band on very short tracks). They + # carry no ranking signal and only pollute downstream plots. + n_before = len(all_results) + all_results = [r for r in all_results if np.isfinite(r.length_normalized_cost)] + n_dropped = n_before - len(all_results) + drop_log["n_dropped_non_finite_cost"] = n_dropped + if n_dropped: + _logger.warning( + f"Dropped {n_dropped}/{n_before} tracks with non-finite DTW cost " + f"(likely too short relative to template length {template_result.template.shape[0]})" + ) + if not all_results: + raise RuntimeError("No tracks produced a finite DTW cost; check min_track_minutes vs template length") + + # Skew filter — primary gate per discussion §3.8 #2. Rejects degenerate + # non-diagonal warps (L-shape, cliff, etc.) without rejecting biological + # rate variance. Run BEFORE the cost / min-match filters so cost gating + # operates on a population of valid warps only. 
+ if args.max_skew < 1.0: + n_before = len(all_results) + all_results = [r for r in all_results if r.path_skew <= args.max_skew] + n_skewed = n_before - len(all_results) + drop_log["n_dropped_max_skew"] = n_skewed + if n_skewed: + _logger.warning( + f"Dropped {n_skewed}/{n_before} tracks with path_skew > {args.max_skew:.2f} " + "(non-diagonal warps; relax --max-skew to keep them)" + ) + if not all_results: + raise RuntimeError( + f"No tracks survived skew filter (max_skew={args.max_skew}); relax or disable with --max-skew 1" + ) + + # Drop tracks whose matched window is shorter than --min-match-ratio of the + # template length. Subsequence DTW with psi relaxation can collapse the + # template onto a 1-frame query window (every template position warped to + # the same query frame) — cost is near-zero but the match is meaningless. + # A window shorter than ~half the template can't plausibly represent the + # event the template encodes. + t_template = template_result.template.shape[0] + + # Minute-based filter takes precedence — frame-rate invariant across query datasets. + # Threshold is computed per-track using that dataset's frame_interval_minutes so + # a 10 min/frame track needs more frames than a 30 min/frame track for the same + # real-time match duration. 
+ if args.min_match_minutes is not None and args.min_match_minutes > 0: + n_before = len(all_results) + kept = [] + for r in all_results: + fi = frame_interval_by_ds[r.dataset_id] + min_aligned_track = max(2, int(np.ceil(args.min_match_minutes / fi))) + if int((r.alignment_region == "aligned").sum()) >= min_aligned_track: + kept.append(r) + all_results = kept + n_short = n_before - len(all_results) + drop_log["n_dropped_min_match_minutes"] = n_short + if n_short: + _logger.warning( + f"Dropped {n_short}/{n_before} tracks with matched window shorter than " + f"{args.min_match_minutes:.0f} min (threshold is per-track: ceil(minutes / frame_interval))" + ) + if not all_results: + raise RuntimeError( + f"No tracks matched at least {args.min_match_minutes:.0f} min; " + "lower --min-match-minutes or rebuild the template" + ) + elif args.min_match_ratio > 0: + # Legacy frame-based path; fine when all query datasets share a frame interval + # with the template build set, but not cross-dataset safe. + min_aligned = max(2, int(np.ceil(args.min_match_ratio * t_template))) + n_before = len(all_results) + all_results = [r for r in all_results if int((r.alignment_region == "aligned").sum()) >= min_aligned] + n_short = n_before - len(all_results) + drop_log["n_dropped_min_match_ratio"] = n_short + if n_short: + _logger.warning( + f"Dropped {n_short}/{n_before} tracks with matched window < {min_aligned} frames " + f"({args.min_match_ratio:.0%} of template length {t_template}) — likely degenerate collapses" + ) + if not all_results: + raise RuntimeError( + f"No tracks matched at least {min_aligned} frames; lower --min-match-ratio or rebuild the template" + ) + + # Pre/post headroom filter: the match must leave at least min_pre_minutes + # before match_q_start and min_post_minutes after match_q_end, both relative + # to the cell's own track start/end. This is the sufficient condition that + # pairs with pass 1's necessary condition in _load_query_embeddings. 
+ min_pre_minutes = float(query_cfg.get("min_pre_minutes", 0)) + min_post_minutes = float(query_cfg.get("min_post_minutes", 0)) + if min_pre_minutes > 0 or min_post_minutes > 0: + filtered = [] + for r in all_results: + fi = frame_interval_by_ds[r.dataset_id] + pre_needed = int(np.ceil(min_pre_minutes / fi)) + post_needed = int(np.ceil(min_post_minutes / fi)) + aligned_mask = r.alignment_region == "aligned" + if not aligned_mask.any(): + continue + aligned_times = r.timepoints[aligned_mask] + q_start, q_end = int(aligned_times.min()), int(aligned_times.max()) + track_min, track_max = int(r.timepoints.min()), int(r.timepoints.max()) + if (q_start - track_min) >= pre_needed and (track_max - q_end) >= post_needed: + filtered.append(r) + n_before = len(all_results) + all_results = filtered + n_cut = n_before - len(all_results) + drop_log["n_dropped_pre_post_headroom"] = n_cut + if n_cut: + _logger.warning( + f"Dropped {n_cut}/{n_before} tracks without required pre/post headroom " + f"(need ≥ {min_pre_minutes:.0f} min before + {min_post_minutes:.0f} min after the match)" + ) + if not all_results: + raise RuntimeError( + "No tracks have the required pre/post headroom; " + "loosen min_pre_minutes/min_post_minutes in the query-set YAML" + ) + + flat = alignment_results_to_dataframe( + all_results, + template_id=template_result.template_id, + time_calibration=template_result.time_calibration, + ) + + match_meta = _per_track_match_metadata(all_results, frame_interval_by_ds) + flat = flat.merge(match_meta, on=["dataset_id", "fov_name", "track_id"], how="left") + + # estimated_t_rel_minutes (now renamed t_rel_minutes_warped per DAG §7.4) + # is only meaningful inside the aligned window. Pre/post frames get NaN + # here; their real-time t_rel_minutes is computed below from t_zero. 
+ if "estimated_t_rel_minutes" in flat.columns: + outside = flat["alignment_region"].isin(["pre", "post"]) + flat.loc[outside, "estimated_t_rel_minutes"] = np.nan + flat = flat.rename(columns={"estimated_t_rel_minutes": "t_rel_minutes_warped"}) + else: + flat["t_rel_minutes_warped"] = np.nan + + # Unified Stage 2 schema (per DAG §7.4): join Path A's per-lineage + # cohort + lineage_id + t_zero by reading the candidate-set CSVs + # from Stage 0. The candidate-set name is taken from --candidate-set + # if provided, else inferred from the query_set name. + candidate_set = getattr(args, "candidate_set", None) or args.query_set + flat = _enrich_with_cohort_metadata(flat, candidate_set, frame_interval_by_ds) + flat["track_path"] = "B" + + OUTPUT_ALIGNMENTS_DIR.mkdir(parents=True, exist_ok=True) + out_path = OUTPUT_ALIGNMENTS_DIR / f"{args.template}_{args.flavor}_on_{args.query_set}.parquet" + flat.to_parquet(out_path, index=False) + n_tracks = flat.groupby(["dataset_id", "fov_name", "track_id"]).ngroups + _logger.info( + f"Wrote {out_path} ({len(flat)} rows, {n_tracks} tracks, " + f"{(flat['alignment_region'] == 'aligned').sum()} aligned frames)" + ) + + # Sidecar JSON capturing per-filter drop counts so a reviewer can see how many + # tracks made it through each guard without grepping stderr. + drop_log["n_kept"] = n_tracks + drop_log_path = out_path.with_suffix(".drop_log.json") + with open(drop_log_path, "w") as f: + json.dump(drop_log, f, indent=2) + _logger.info(f"Wrote drop log {drop_log_path}: {drop_log}") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/2-align_cells/align_lc.py b/applications/dynaclr/scripts/pseudotime/2-align_cells/align_lc.py new file mode 100644 index 000000000..6a14e096b --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/2-align_cells/align_lc.py @@ -0,0 +1,280 @@ +"""Path A-LC alignment: real-time shift on linear-classifier predictions. 
+ +Reads the productive-cohort CSV from Stage 0, pulls +``predicted_infection_state`` from the NS3 channel embedding zarr, +anchors each lineage at the first frame of a sustained positive run +(at least ``--min-run`` consecutive positives), and writes a per-frame +alignment parquet with the unified schema. No DTW, no template, no +warping. + +The ``--min-run`` parameter (default 3 frames) prevents single-frame +LC flickers from defining ``t_zero``. + +Bystander, abortive, and mock cohorts pass through with +``t_zero = NaN`` for use as null distributions in Stage 3. + +Usage:: + + cd applications/dynaclr/scripts/pseudotime/2-align_cells + uv run python align_lc.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/candidates.yaml \ + --candidate-set zikv_productive_07_24 \ + --min-run 3 +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd + +SCRIPT_DIR = Path(__file__).resolve().parent +CANDIDATES_DIR = SCRIPT_DIR.parent / "0-select_candidates" / "candidates" +OUTPUT_DIR = SCRIPT_DIR / "A-LC" / "alignments" + +sys.path.insert(0, str(SCRIPT_DIR.parent)) +from utils import load_stage_config # noqa: E402 + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +_logger = logging.getLogger(__name__) + +# Same parquet schema as align_anno.py — track_path differs. 
+PARQUET_COLUMNS = [ + "dataset_id", + "fov_name", + "lineage_id", + "track_id", + "t", + "cohort", + "divides", + "t_zero", + "t_rel_minutes", + "track_path", + "pseudotime", + "alignment_region", + "t_rel_minutes_warped", + "dtw_cost", + "length_normalized_cost", + "path_skew", + "match_q_start", + "match_q_end", + "template_id", +] + + +def _frame_interval_lookup(config: dict) -> dict[str, float]: + """Map dataset_id → frame_interval_minutes.""" + return {d["dataset_id"]: float(d["frame_interval_minutes"]) for d in config["datasets"]} + + +def _embedding_dir_lookup(config: dict) -> dict[str, Path]: + """Map dataset_id → directory containing the NS3 channel embedding zarrs.""" + return {d["dataset_id"]: Path(d["pred_dir"]) for d in config["datasets"]} + + +def _load_lc_predictions( + pred_dir: Path, + dataset_id: str, + pattern: str, + pred_column: str, +) -> pd.DataFrame: + """Read LC predictions per (fov_name, track_id, t) from a date-matched zarr. + + Filters glob matches to the dataset's date prefix (first three + underscore-separated tokens) so multi-date pred_dirs pick the + correct zarr. 
+ """ + date_prefix = "_".join(dataset_id.split("_")[:3]) + matches = [m for m in pred_dir.glob(pattern) if m.name.startswith(date_prefix)] + if not matches: + _logger.warning(f"[{dataset_id}] no embedding zarr matched {pattern} with prefix {date_prefix} in {pred_dir}") + return pd.DataFrame(columns=["fov_name", "track_id", "t", pred_column]) + if len(matches) > 1: + _logger.warning( + f"[{dataset_id}] multiple zarrs matched {pattern} with prefix {date_prefix}: " + f"{[m.name for m in matches]}; using first" + ) + adata = ad.read_zarr(matches[0]) + adata.obs_names_make_unique() + if pred_column not in adata.obs.columns: + _logger.warning(f"{pred_column} not in {matches[0]} .obs") + return pd.DataFrame(columns=["fov_name", "track_id", "t", pred_column]) + obs = adata.obs[["fov_name", "track_id", "t", pred_column]].copy() + obs["fov_name"] = obs["fov_name"].astype(str) + obs["track_id"] = obs["track_id"].astype(int) + obs["t"] = obs["t"].astype(int) + return obs + + +def _first_run_start(positive_mask: np.ndarray, min_run: int) -> int | None: + """Return the index of the first frame entering a run of ≥ ``min_run`` 1s.""" + run = 0 + run_start = -1 + for i, v in enumerate(positive_mask): + if v: + if run == 0: + run_start = i + run += 1 + if run >= min_run: + return run_start + else: + run = 0 + run_start = -1 + return None + + +def _t_zero_from_lc( + productive_df: pd.DataFrame, + lc_obs_by_dataset: dict[str, pd.DataFrame], + pred_column: str, + positive_value: str, + min_run: int, +) -> dict[str, int]: + """For each productive lineage, find the LC anchor frame. + + Joins productive cohort rows with LC predictions on + ``(dataset_id, fov_name, track_id, t)``, then per lineage finds the + first frame entering a run of at least ``min_run`` consecutive + positive predictions. 
+ """ + out: dict[str, int] = {} + if productive_df.empty: + return out + + for lineage_id, g in productive_df.groupby("lineage_id"): + if not lineage_id: + continue + ds_id = str(g["dataset_id"].iloc[0]) + if ds_id not in lc_obs_by_dataset or lc_obs_by_dataset[ds_id].empty: + continue + # Pull the LC predictions for this lineage's tracks. + lc_df = lc_obs_by_dataset[ds_id] + track_ids = set(g["track_id"].astype(int).unique()) + sub = lc_df[lc_df["track_id"].isin(track_ids)] + if sub.empty: + continue + # Sort by t and find the first sustained positive run. + sub = sub.sort_values("t") + positive_mask = (sub[pred_column] == positive_value).to_numpy() + run_start_idx = _first_run_start(positive_mask, min_run) + if run_start_idx is None: + continue + out[str(lineage_id)] = int(sub["t"].iloc[run_start_idx]) + + return out + + +def _align_cohort( + cohort_df: pd.DataFrame, + cohort: str, + t_zero_lookup: dict[str, int], + frame_intervals: dict[str, float], +) -> pd.DataFrame: + """Add ``t_zero``, ``t_rel_minutes``, ``track_path`` columns to a cohort frame.""" + if cohort_df.empty: + return pd.DataFrame(columns=PARQUET_COLUMNS) + + out = cohort_df.copy() + out["t_zero"] = out["lineage_id"].map(t_zero_lookup) + frame_interval_per_row = out["dataset_id"].map(frame_intervals) + if frame_interval_per_row.isna().any(): + unknown = sorted(set(out["dataset_id"][frame_interval_per_row.isna()])) + raise KeyError(f"Frame interval missing for dataset(s) {unknown}") + + has_anchor = out["t_zero"].notna() + out["t_rel_minutes"] = np.where( + has_anchor, + (out["t"] - out["t_zero"].fillna(0)) * frame_interval_per_row, + np.nan, + ) + out["track_path"] = "A-LC" + + out["pseudotime"] = np.nan + out["alignment_region"] = "" + out["t_rel_minutes_warped"] = np.nan + out["dtw_cost"] = np.nan + out["length_normalized_cost"] = np.nan + out["path_skew"] = np.nan + out["match_q_start"] = pd.NA + out["match_q_end"] = pd.NA + out["template_id"] = "" + + for col in PARQUET_COLUMNS: + if col not 
in out.columns: + out[col] = "" + return out[PARQUET_COLUMNS] + + +def main() -> None: + """Write Path A-LC alignment parquet for one candidate set.""" + parser = argparse.ArgumentParser(description="Path A-LC alignment (LC-anchored real-time shift).") + parser.add_argument("--datasets", required=True, help="Path to datasets.yaml") + parser.add_argument("--config", required=True, help="Path to candidates.yaml") + parser.add_argument("--candidate-set", required=True, help="Candidate-set name") + parser.add_argument("--min-run", type=int, default=3, help="Min consecutive positives for t_zero (default: 3)") + parser.add_argument( + "--pred-column", + default="predicted_infection_state", + help="LC prediction column in the embedding zarr's .obs", + ) + parser.add_argument("--positive-value", default="infected", help="LC positive class label (default: infected)") + args = parser.parse_args() + + config = load_stage_config(args.datasets, args.config) + frame_intervals = _frame_interval_lookup(config) + pred_dirs = _embedding_dir_lookup(config) + sensor_pattern = config.get("embeddings", {}).get("sensor", "*_viral_sensor_*.zarr") + + cohorts = ["productive", "bystander", "abortive", "unannotated_productive", "mock"] + cohort_dfs: dict[str, pd.DataFrame] = {} + for cohort in cohorts: + path = CANDIDATES_DIR / f"{args.candidate_set}_{cohort}.csv" + if path.exists(): + cohort_dfs[cohort] = pd.read_csv(path) + _logger.info(f"Read {path} ({len(cohort_dfs[cohort])} rows)") + else: + _logger.warning(f"{path} not found; cohort '{cohort}' empty") + cohort_dfs[cohort] = pd.DataFrame() + + productive_df = cohort_dfs["productive"] + if productive_df.empty: + raise RuntimeError(f"Productive cohort empty for {args.candidate_set!r}.") + + # Load LC predictions per dataset present in the productive cohort. 
+ datasets_in_use = sorted(productive_df["dataset_id"].unique()) + lc_obs_by_dataset: dict[str, pd.DataFrame] = {} + for ds_id in datasets_in_use: + if ds_id not in pred_dirs: + _logger.warning(f"dataset_id {ds_id!r} missing from datasets.yaml; LC anchor unavailable") + continue + lc_obs_by_dataset[ds_id] = _load_lc_predictions( + pred_dirs[ds_id], dataset_id=ds_id, pattern=sensor_pattern, pred_column=args.pred_column + ) + + t_zero_lookup = _t_zero_from_lc( + productive_df, + lc_obs_by_dataset, + pred_column=args.pred_column, + positive_value=args.positive_value, + min_run=args.min_run, + ) + _logger.info(f"Computed LC t_zero for {len(t_zero_lookup)} productive lineages (min_run={args.min_run})") + + aligned_parts = [_align_cohort(cohort_dfs[c], c, t_zero_lookup, frame_intervals) for c in cohorts] + aligned = pd.concat(aligned_parts, ignore_index=True) + + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + out_path = OUTPUT_DIR / f"{args.candidate_set}.parquet" + aligned.to_parquet(out_path, index=False) + n_with_anchor = aligned["t_zero"].notna().sum() + _logger.info(f"Wrote {out_path} ({len(aligned)} rows, {n_with_anchor} anchored)") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/readout_common.py b/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/readout_common.py new file mode 100644 index 000000000..4930292c5 --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/readout_common.py @@ -0,0 +1,332 @@ +"""Shared helpers for per-organelle Stage 3 readouts. + +Each per-organelle script (readout_sec61, readout_g3bp1, readout_phase) +is a thin wrapper that: + +1. Loads its alignment parquet (Path A-anno, A-LC, or B output). +2. Loads the matching organelle channel embedding zarr per dataset. +3. Computes per-cell cosine distance from a per-cell pre-baseline. +4. (G3BP1) Computes oscillation-aware metrics on the post-window. +5. 
Aggregates across cells per cohort with FOV-stratified mock as null. + +The FOV-stratified mock null follows discussion §3.7 and the round 2 +ML-engineer critique: pooled mock distributions inflate the 95th +percentile with FOV-to-FOV variance, not per-cell variance. We match +each productive cell to mocks from the same FOV and use that local +distribution. +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd + +from dynaclr.pseudotime import date_prefix_from_dataset_id, find_embedding_zarr + +_logger = logging.getLogger(__name__) + +ALIGNMENT_DIRS = { + "A-anno": "A-anno", + "A-LC": "A-LC", + "B": "B", +} + + +def load_alignment_parquet( + align_root: Path, + track: str, + candidate_set: str, +) -> pd.DataFrame: + """Load a Stage 2 alignment parquet for one track. + + Path B parquets carry a ``{template}_{flavor}_on_{candidate_set}.parquet`` + name; Path A parquets are ``{candidate_set}.parquet``. The caller can + pass the full filename via ``--alignment-parquet`` to bypass this + lookup. + """ + if track not in ALIGNMENT_DIRS: + raise KeyError(f"Unknown track {track!r}; expected one of {list(ALIGNMENT_DIRS)}") + track_dir = align_root / ALIGNMENT_DIRS[track] / "alignments" + if track in ("A-anno", "A-LC"): + path = track_dir / f"{candidate_set}.parquet" + else: + # Path B: glob for any template/flavor matching this candidate set. 
+ matches = list(track_dir.glob(f"*_on_{candidate_set}.parquet")) + if not matches: + raise FileNotFoundError(f"No Path B parquet under {track_dir} for {candidate_set!r}") + if len(matches) > 1: + _logger.warning(f"Multiple Path B parquets for {candidate_set!r}; using {matches[0].name}") + path = matches[0] + if not path.exists(): + raise FileNotFoundError(f"Alignment parquet not found: {path}") + return pd.read_parquet(path) + + +def load_organelle_embeddings( + dataset_cfgs: dict[str, dict], + datasets_in_use: list[str], + embedding_pattern: str, +) -> dict[str, ad.AnnData]: + """Load the organelle channel zarr per dataset, date-matched. + + Re-uses :func:`find_embedding_zarr` from the library. Returns one + AnnData per dataset; keys missing here will skip downstream cells. + """ + out: dict[str, ad.AnnData] = {} + for ds_id in datasets_in_use: + if ds_id not in dataset_cfgs: + _logger.warning(f"dataset_id {ds_id!r} missing from datasets.yaml; skipping") + continue + ds_cfg = dataset_cfgs[ds_id] + prefix = date_prefix_from_dataset_id(ds_id) + try: + zarr_path = find_embedding_zarr(ds_cfg["pred_dir"], prefix + embedding_pattern) + except FileNotFoundError as exc: + _logger.warning(f"[{ds_id}] no embedding zarr matched {prefix + embedding_pattern}: {exc}") + continue + adata = ad.read_zarr(zarr_path) + adata.obs_names_make_unique() + out[ds_id] = adata + _logger.info(f"[{ds_id}] loaded {Path(zarr_path).name} ({adata.n_obs} cells)") + return out + + +def per_cell_baseline_distance( + align_df: pd.DataFrame, + adata_by_dataset: dict[str, ad.AnnData], + baseline_window_minutes: tuple[float, float] = (-240, -60), + metric: str = "cosine", + min_baseline_frames: int = 2, +) -> pd.DataFrame: + """Compute per-cell cosine distance from a per-cell pre-baseline. 
+ + For each ``(dataset_id, fov_name, track_id)`` in the alignment + parquet, fetches the matching frames from the organelle AnnData, + computes a per-track baseline as the mean of pre-window embeddings, + then writes ``signal = cosine_distance(embedding, baseline)`` per + frame. + + Cells missing from the embedding zarr or with fewer than + ``min_baseline_frames`` baseline frames produce NaN signal. + + Returns a copy of ``align_df`` with a new ``signal`` column. + """ + from scipy.spatial.distance import cdist + + out = align_df.copy() + out["signal"] = np.nan + + for ds_id, ds_group in out.groupby("dataset_id"): + if ds_id not in adata_by_dataset: + continue + adata = adata_by_dataset[ds_id] + obs = adata.obs[["fov_name", "track_id", "t"]].copy() + obs["fov_name"] = obs["fov_name"].astype(str) + obs["track_id"] = obs["track_id"].astype(int) + obs["t"] = obs["t"].astype(int) + obs["_idx"] = np.arange(len(obs)) + + # Per-track loop. Vectorising would require careful indexing for + # the per-track baseline computation; the explicit loop is + # readable and fast enough for our cohort sizes. + for (fov, tid), track_rows in ds_group.groupby(["fov_name", "track_id"]): + track_obs = obs[(obs["fov_name"] == str(fov)) & (obs["track_id"] == int(tid))] + if track_obs.empty: + continue + t_to_idx = dict(zip(track_obs["t"].astype(int), track_obs["_idx"].astype(int))) + + # Baseline frames: pre-window in t_rel_minutes for cells with + # an anchor (productive cohort + sibling daughters); for + # cohorts without an anchor (mock, bystander, abortive, + # unannotated_productive) fall back to the whole-track mean + # so they contribute a per-frame signal usable as a null + # distribution per discussion §3.7. 
+ anchored = track_rows["t_rel_minutes"].notna().any() + if anchored: + bl_mask = (track_rows["t_rel_minutes"] >= baseline_window_minutes[0]) & ( + track_rows["t_rel_minutes"] <= baseline_window_minutes[1] + ) + bl_t = track_rows.loc[bl_mask, "t"].astype(int).tolist() + bl_indices = [t_to_idx[t] for t in bl_t if t in t_to_idx] + if len(bl_indices) < min_baseline_frames: + continue + else: + bl_indices = list(t_to_idx.values()) + if len(bl_indices) < min_baseline_frames: + continue + X = adata.X + baseline = np.asarray(X[bl_indices]).mean(axis=0, keepdims=True) + + # Compute distance for every frame in track_rows. + row_t = track_rows["t"].astype(int).tolist() + row_indices = [t_to_idx.get(t, -1) for t in row_t] + valid = np.array([idx >= 0 for idx in row_indices]) + if not valid.any(): + continue + present = np.array([idx for idx, ok in zip(row_indices, valid) if ok]) + embeddings = np.asarray(X[present]) + distances = cdist(embeddings, baseline, metric=metric).flatten() + + row_idx_array = track_rows.index.to_numpy()[valid] + out.loc[row_idx_array, "signal"] = distances + + return out + + +def fov_stratified_threshold( + productive_signal: pd.DataFrame, + mock_signal: pd.DataFrame, + percentile: float = 95.0, +) -> pd.DataFrame: + """Per-FOV threshold from mock cells with well-row fallback. + + For each FOV present in the productive set, compute the + ``percentile``th of mock-cell ``signal`` values from the same FOV. + Returns a frame with ``(fov_name, threshold, n_mock, n_mock_source, + fallback)``. + + Fallback ladder: + + 1. Same FOV (e.g. ``A/2/000000``) — exact stratification. + 2. Same well-row (e.g. ``A/*`` for productive ``A/2/000000``) — + channel-matched stratification when mock lives in a sibling well + on the same row (e.g. SEC61 productive=A/2, SEC61 mock=A/1). + 3. Global mock percentile. + + Each level requires ≥ 30 frames; otherwise descends to the next. 
+ Per discussion §3.8 #8 the per-FOV check is mandatory; this records + which level supplied the threshold. + """ + if mock_signal.empty: + raise RuntimeError("Mock cohort empty; cannot compute FOV-stratified threshold") + mock_valid = mock_signal[mock_signal["signal"].notna()] + if mock_valid.empty: + raise RuntimeError("Mock cohort has no finite signal values") + + mock_valid = mock_valid.copy() + mock_valid["fov_name"] = mock_valid["fov_name"].astype(str) + mock_valid["_row"] = mock_valid["fov_name"].str.split("/").str[0] + + global_threshold = float(np.percentile(mock_valid["signal"].to_numpy(), percentile)) + + rows = [] + productive_fovs = sorted(productive_signal["fov_name"].astype(str).unique()) + for fov in productive_fovs: + row_label = fov.split("/")[0] + same_fov = mock_valid[mock_valid["fov_name"] == fov] + same_row = mock_valid[mock_valid["_row"] == row_label] + if len(same_fov) >= 30: + threshold = float(np.percentile(same_fov["signal"].to_numpy(), percentile)) + rows.append( + { + "fov_name": fov, + "threshold": threshold, + "n_mock": len(same_fov), + "n_mock_source": "same_fov", + "fallback": False, + } + ) + elif len(same_row) >= 30: + threshold = float(np.percentile(same_row["signal"].to_numpy(), percentile)) + rows.append( + { + "fov_name": fov, + "threshold": threshold, + "n_mock": len(same_row), + "n_mock_source": "same_row", + "fallback": True, + } + ) + else: + rows.append( + { + "fov_name": fov, + "threshold": global_threshold, + "n_mock": len(mock_valid), + "n_mock_source": "global", + "fallback": True, + } + ) + return pd.DataFrame(rows) + + +def oscillation_metrics_per_cell( + cohort_signal: pd.DataFrame, + threshold_by_fov: pd.DataFrame, + post_window_minutes: tuple[float, float] = (180, 540), +) -> pd.DataFrame: + """Per-cell oscillation statistics for G3BP1 (post-window only). + + Real-time post-window only — never warped (per discussion §3.6). + Stress granule kinetics are minute-scale; warping by NS3 is meaningless. 
+ + Per-cell scalars: + - excursion_count: number of zero-crossings of (signal > threshold). + - dwell_time_minutes: total time above threshold. + - largest_excursion_amplitude: peak signal − threshold across the window. + - largest_excursion_duration_minutes: longest contiguous span above threshold. + + Cells with no post-window frames or no FOV threshold get NaN. + """ + threshold_lookup = dict(zip(threshold_by_fov["fov_name"].astype(str), threshold_by_fov["threshold"])) + rows = [] + for (ds_id, fov, tid), g in cohort_signal.groupby(["dataset_id", "fov_name", "track_id"]): + g = g.sort_values("t") + post = g[(g["t_rel_minutes"] >= post_window_minutes[0]) & (g["t_rel_minutes"] <= post_window_minutes[1])] + if post.empty or post["signal"].notna().sum() == 0: + continue + threshold = threshold_lookup.get(str(fov)) + if threshold is None: + continue + signal = post["signal"].to_numpy() + t_minutes = post["t_rel_minutes"].to_numpy() + above = signal > threshold + # Excursion count: count rising edges (False → True) in the above mask. + rising = np.diff(above.astype(int)) == 1 + excursion_count = int(rising.sum()) + (1 if above[0] else 0) + # Dwell time: estimate via successive frame intervals where above is True. + if len(t_minutes) >= 2: + frame_intervals = np.diff(t_minutes) + dwell_intervals = frame_intervals * above[:-1].astype(float) + dwell_time = float(dwell_intervals.sum()) + else: + dwell_time = 0.0 if not above.any() else float(t_minutes[-1] - t_minutes[0]) + # Largest excursion amplitude. + amp_above = signal[above] + largest_amp = float(amp_above.max() - threshold) if amp_above.size else float("nan") + # Largest excursion duration: longest contiguous run of above==True. 
+ if above.any(): + run_lens = [] + cur_start = None + for i, v in enumerate(above): + if v and cur_start is None: + cur_start = i + elif not v and cur_start is not None: + run_lens.append((cur_start, i - 1)) + cur_start = None + if cur_start is not None: + run_lens.append((cur_start, len(above) - 1)) + run_durations = [float(t_minutes[end] - t_minutes[start]) for start, end in run_lens] + largest_dur = float(max(run_durations)) if run_durations else 0.0 + else: + largest_dur = 0.0 + + rows.append( + { + "dataset_id": ds_id, + "fov_name": str(fov), + "track_id": int(tid), + "lineage_id": str(g["lineage_id"].iloc[0]), + "cohort": str(g["cohort"].iloc[0]), + "threshold": float(threshold), + "excursion_count": excursion_count, + "dwell_time_minutes": dwell_time, + "largest_excursion_amplitude": largest_amp, + "largest_excursion_duration_minutes": largest_dur, + } + ) + return pd.DataFrame(rows) diff --git a/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/readout_g3bp1.py b/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/readout_g3bp1.py new file mode 100644 index 000000000..07d832a1a --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/3-organelle-remodeling/readout_g3bp1.py @@ -0,0 +1,126 @@ +"""Stage 3 G3BP1 readout: oscillation-aware metrics on the post-window. + +Stress granules are transient phase-separated condensates that +assemble and disassemble on minute timescales (per discussion §3.6). +Distance-from-baseline is computed everywhere, but the headline metrics +are oscillation-aware: excursion count, dwell time above mock 95th +percentile, largest excursion amplitude and duration. Computed on the +real-time post-window only (not warped) — warping by NS3 dynamics +would average phase-shifted oscillations to a flat line. 
def main() -> None:
    """Compute G3BP1 cosine-distance signal + per-cell oscillation metrics."""
    parser = argparse.ArgumentParser(description="Stage 3 G3BP1 readout (oscillation-aware metrics).")
    parser.add_argument("--datasets", required=True)
    parser.add_argument("--config", required=True)
    parser.add_argument("--candidate-set", required=True)
    parser.add_argument("--track", required=True, choices=["A-anno", "A-LC", "B"])
    parser.add_argument("--baseline-pre-min", type=float, default=-240.0)
    parser.add_argument("--baseline-pre-max", type=float, default=-60.0)
    parser.add_argument("--post-min", type=float, default=180.0, help="Post-window start in minutes (default: 180)")
    parser.add_argument(
        "--post-max",
        type=float,
        default=540.0,
        help=(
            "Post-window end in minutes (default: 540 — extended for G3BP1 plateau "
            "per discussion §3.6 / biologist round 2)"
        ),
    )
    args = parser.parse_args()

    # Resolve configs and the embedding path pattern for this organelle.
    cfg = load_stage_config(args.datasets, args.config)
    cfg_by_id = {entry["dataset_id"]: entry for entry in cfg["datasets"]}
    pattern = cfg.get("embeddings", {}).get(EMBEDDING_KEY)
    if pattern is None:
        raise KeyError(f"datasets.yaml embeddings.{EMBEDDING_KEY} not set")

    alignment = load_alignment_parquet(ALIGN_ROOT, args.track, args.candidate_set)
    _logger.info(f"Loaded {len(alignment)} alignment rows for track {args.track}")

    # Only load embeddings for datasets actually present in the alignment.
    wanted = sorted(alignment["dataset_id"].unique())
    embeddings = load_organelle_embeddings(cfg_by_id, wanted, pattern)
    if not embeddings:
        raise RuntimeError(f"No {ORGANELLE} embedding zarrs loaded; check pattern + pred_dir")

    # Per-cell cosine distance from the pre-infection baseline window.
    signal = per_cell_baseline_distance(
        alignment,
        embeddings,
        baseline_window_minutes=(args.baseline_pre_min, args.baseline_pre_max),
    )

    productive = signal[signal["cohort"] == "productive"]
    mock = signal[signal["cohort"] == "mock"]
    # Oscillation metrics are meaningless without a mock null, so fail hard.
    if mock["signal"].notna().sum() == 0:
        raise RuntimeError("Mock cohort has no G3BP1 signal; cannot compute oscillation thresholds")
    thresholds = fov_stratified_threshold(productive, mock, percentile=95.0)

    # Oscillation-aware metrics on the real-time post-window (not warped).
    metrics = oscillation_metrics_per_cell(
        productive,
        thresholds,
        post_window_minutes=(args.post_min, args.post_max),
    )
    if not metrics.empty:
        summary_cols = [
            "excursion_count",
            "dwell_time_minutes",
            "largest_excursion_amplitude",
            "largest_excursion_duration_minutes",
        ]
        _logger.info(
            f"Per-cell oscillation summary (n={len(metrics)}):\n{metrics[summary_cols].describe().to_string()}"
        )

    # Persist all three artifacts under <track>/<organelle>/.
    out_dir = SCRIPT_DIR / args.track / ORGANELLE
    out_dir.mkdir(parents=True, exist_ok=True)
    signal_path = out_dir / f"{args.candidate_set}_signal.parquet"
    threshold_path = out_dir / f"{args.candidate_set}_threshold.csv"
    osc_path = out_dir / f"{args.candidate_set}_oscillation_metrics.parquet"
    signal.to_parquet(signal_path, index=False)
    thresholds.to_csv(threshold_path, index=False)
    metrics.to_parquet(osc_path, index=False)
    _logger.info(f"Wrote {signal_path} ({len(signal)} rows)")
    _logger.info(f"Wrote {threshold_path} ({len(thresholds)} FOVs)")
    _logger.info(f"Wrote {osc_path} ({len(metrics)} cells)")


if __name__ == "__main__":
    main()
def main() -> None:
    """Compute per-cell phase cosine-distance signal and per-FOV mock thresholds.

    Mirrors readout_sec61.main: loads the Stage 2 alignment, the phase
    embeddings, computes cosine distance from the per-cell pre-window
    baseline, derives FOV-stratified mock 95th-percentile thresholds,
    and writes signal parquet + threshold CSV under <track>/phase/.
    """
    parser = argparse.ArgumentParser(description="Stage 3 phase readout (cosine distance from baseline).")
    parser.add_argument("--datasets", required=True)
    parser.add_argument("--config", required=True)
    parser.add_argument("--candidate-set", required=True)
    parser.add_argument("--track", required=True, choices=["A-anno", "A-LC", "B"])
    parser.add_argument("--baseline-pre-min", type=float, default=-240.0)
    parser.add_argument("--baseline-pre-max", type=float, default=-60.0)
    args = parser.parse_args()

    config = load_stage_config(args.datasets, args.config)
    dataset_cfgs = {d["dataset_id"]: d for d in config["datasets"]}
    embedding_pattern = config.get("embeddings", {}).get(EMBEDDING_KEY)
    if embedding_pattern is None:
        raise KeyError(f"datasets.yaml embeddings.{EMBEDDING_KEY} not set")

    align_df = load_alignment_parquet(ALIGN_ROOT, args.track, args.candidate_set)
    _logger.info(f"Loaded {len(align_df)} alignment rows for track {args.track}")

    datasets_in_use = sorted(align_df["dataset_id"].unique())
    adata_by_dataset = load_organelle_embeddings(dataset_cfgs, datasets_in_use, embedding_pattern)
    if not adata_by_dataset:
        raise RuntimeError(f"No {ORGANELLE} embedding zarrs loaded for any dataset; check pattern + pred_dir")

    signal_df = per_cell_baseline_distance(
        align_df,
        adata_by_dataset,
        baseline_window_minutes=(args.baseline_pre_min, args.baseline_pre_max),
    )
    n_with_signal = int(signal_df["signal"].notna().sum())
    _logger.info(f"Computed cosine distance for {n_with_signal}/{len(signal_df)} rows")

    productive_signal = signal_df[signal_df["cohort"] == "productive"]
    mock_signal = signal_df[signal_df["cohort"] == "mock"]
    # Consistency with readout_sec61: surface an empty mock cohort instead
    # of silently producing thresholds from no null distribution.
    # (presumably fov_stratified_threshold falls back to a global
    # threshold in that case, as the SEC61 script states — TODO confirm)
    if mock_signal["signal"].notna().sum() == 0:
        _logger.warning("Mock cohort has no phase signal; thresholds will fall back to global")
    threshold_df = fov_stratified_threshold(productive_signal, mock_signal, percentile=95.0)

    out_dir = SCRIPT_DIR / args.track / ORGANELLE
    out_dir.mkdir(parents=True, exist_ok=True)
    signal_path = out_dir / f"{args.candidate_set}_signal.parquet"
    threshold_path = out_dir / f"{args.candidate_set}_threshold.csv"
    signal_df.to_parquet(signal_path, index=False)
    threshold_df.to_csv(threshold_path, index=False)
    _logger.info(f"Wrote {signal_path} ({len(signal_df)} rows)")
    _logger.info(f"Wrote {threshold_path} ({len(threshold_df)} FOVs)")


if __name__ == "__main__":
    main()
def main() -> None:
    """Compute per-cell SEC61 cosine-distance signal and per-FOV mock thresholds."""
    parser = argparse.ArgumentParser(description="Stage 3 SEC61 readout (cosine distance from baseline).")
    parser.add_argument("--datasets", required=True)
    parser.add_argument("--config", required=True)
    parser.add_argument("--candidate-set", required=True)
    parser.add_argument("--track", required=True, choices=["A-anno", "A-LC", "B"])
    parser.add_argument("--baseline-pre-min", type=float, default=-240.0)
    parser.add_argument("--baseline-pre-max", type=float, default=-60.0)
    args = parser.parse_args()

    # Resolve configs and the embedding path pattern for this organelle.
    cfg = load_stage_config(args.datasets, args.config)
    cfg_by_id = {entry["dataset_id"]: entry for entry in cfg["datasets"]}
    pattern = cfg.get("embeddings", {}).get(EMBEDDING_KEY)
    if pattern is None:
        raise KeyError(f"datasets.yaml embeddings.{EMBEDDING_KEY} not set")

    alignment = load_alignment_parquet(ALIGN_ROOT, args.track, args.candidate_set)
    _logger.info(f"Loaded {len(alignment)} alignment rows for track {args.track}")

    # Only load embeddings for datasets actually present in the alignment.
    wanted = sorted(alignment["dataset_id"].unique())
    embeddings = load_organelle_embeddings(cfg_by_id, wanted, pattern)
    if not embeddings:
        raise RuntimeError(f"No {ORGANELLE} embedding zarrs loaded for any dataset; check pattern + pred_dir")

    # Per-cell cosine distance from the pre-infection baseline window.
    signal = per_cell_baseline_distance(
        alignment,
        embeddings,
        baseline_window_minutes=(args.baseline_pre_min, args.baseline_pre_max),
    )
    n_ok = int(signal["signal"].notna().sum())
    _logger.info(f"Computed cosine distance for {n_ok}/{len(signal)} rows")

    productive = signal[signal["cohort"] == "productive"]
    mock = signal[signal["cohort"] == "mock"]
    # Best-effort: an empty mock cohort is surfaced but not fatal here.
    if mock["signal"].notna().sum() == 0:
        _logger.warning("Mock cohort has no SEC61 signal; thresholds will fall back to global")
    thresholds = fov_stratified_threshold(productive, mock, percentile=95.0)
    _logger.info(f"FOV-stratified mock 95th-percentile thresholds: {thresholds['threshold'].describe().to_dict()}")

    # Persist signal + thresholds under <track>/<organelle>/.
    out_dir = SCRIPT_DIR / args.track / ORGANELLE
    out_dir.mkdir(parents=True, exist_ok=True)
    signal_path = out_dir / f"{args.candidate_set}_signal.parquet"
    threshold_path = out_dir / f"{args.candidate_set}_threshold.csv"
    signal.to_parquet(signal_path, index=False)
    thresholds.to_csv(threshold_path, index=False)
    _logger.info(f"Wrote {signal_path} ({len(signal)} rows)")
    _logger.info(f"Wrote {threshold_path} ({len(thresholds)} FOVs)")


if __name__ == "__main__":
    main()
+ """ + if len(values) < 30: + return float("nan"), float("nan"), False + X = values.reshape(-1, 1) + g1 = GaussianMixture(n_components=1, random_state=0).fit(X) + g2 = GaussianMixture(n_components=2, random_state=0).fit(X) + bic1 = float(g1.bic(X)) + bic2 = float(g2.bic(X)) + return bic1, bic2, bool(bic1 - bic2 >= 10.0) + + +def main() -> None: + """Run dip-test (BIC GMM) on back-projected real-time per (track, organelle, cohort).""" + parser = argparse.ArgumentParser(description="Stage 4 bimodality check.") + parser.add_argument("--datasets", required=True) + parser.add_argument("--config", required=True) + parser.add_argument("--comparison", required=True) + args = parser.parse_args() + + config = load_stage_config(args.datasets, args.config) + cmp_cfg = config["comparisons"][args.comparison] + candidate_set = cmp_cfg["candidate_set"] + organelles = cmp_cfg.get("organelles", ["sec61", "g3bp1", "phase"]) + tracks = cmp_cfg.get("tracks", ["A-anno", "A-LC", "B"]) + cohorts = cmp_cfg.get("cohorts", ["productive"]) + + rows = [] + for track in tracks: + for organelle in organelles: + path = READOUT_ROOT / track / organelle / f"{candidate_set}_signal.parquet" + if not path.exists(): + continue + df = pd.read_parquet(path) + for cohort in cohorts: + # Path B: back-projected real-time at the aligned region. + # Path A: t_rel_minutes is itself real-time. Both reduce + # to "where in real-time is the signal at threshold." 
+ col = ( + "t_rel_minutes_warped" if track == "B" and "t_rel_minutes_warped" in df.columns else "t_rel_minutes" + ) + sub = df[(df["cohort"] == cohort) & df[col].notna() & df["signal"].notna()] + if sub.empty: + continue + values = sub[col].to_numpy() + bic1, bic2, multimodal = _bic_ratio(values) + rows.append( + { + "track": track, + "organelle": organelle, + "cohort": cohort, + "n": int(len(values)), + "bic_1comp": bic1, + "bic_2comp": bic2, + "multimodal": multimodal, + } + ) + + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + out_path = OUTPUT_DIR / f"{args.comparison}_bimodality.csv" + pd.DataFrame(rows).to_csv(out_path, index=False) + _logger.info(f"Wrote {out_path}") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/4-compare_tracks/compare_onsets.py b/applications/dynaclr/scripts/pseudotime/4-compare_tracks/compare_onsets.py new file mode 100644 index 000000000..dd9232783 --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/4-compare_tracks/compare_onsets.py @@ -0,0 +1,216 @@ +"""Stage 4 cross-track comparison: side-by-side onset population curves. + +Reads per-organelle signal parquets from Stage 3 (one per track), bins +by ``t_rel_minutes``, computes binned median + IQR per cohort, and +emits a 9-panel figure (3 organelles × 3 tracks). Computes the +methodological-claim verdict: does Path B's IQR at the headline metric +beat the better of A-anno and A-LC by ≥ 25% per DAG §9.1? + +Per discussion §2.2: the methodological claim succeeds if Path B's +population-curve IQR at the headline metric is at least 25% tighter +than the better of the two Path A baselines. Same metric reported in +real-time minutes for all three tracks so the comparison is fair. 
+ +Usage:: + + cd applications/dynaclr/scripts/pseudotime/4-compare_tracks + uv run python compare_onsets.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/compare_tracks.yaml \ + --comparison zikv_07_24_full +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +SCRIPT_DIR = Path(__file__).resolve().parent +READOUT_ROOT = SCRIPT_DIR.parent / "3-organelle-remodeling" +OUTPUT_DIR = SCRIPT_DIR / "comparisons" + +sys.path.insert(0, str(SCRIPT_DIR.parent)) +from utils import load_stage_config # noqa: E402 + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +_logger = logging.getLogger(__name__) + + +def _load_signal(track: str, organelle: str, candidate_set: str) -> pd.DataFrame: + """Load a Stage 3 signal parquet.""" + path = READOUT_ROOT / track / organelle / f"{candidate_set}_signal.parquet" + if not path.exists(): + raise FileNotFoundError(f"Signal parquet not found: {path}") + return pd.read_parquet(path) + + +def _binned_summary( + df: pd.DataFrame, + bin_edges: np.ndarray, + cohort: str, +) -> pd.DataFrame: + """Per-bin median + IQR of ``signal`` for a cohort.""" + sub = df[(df["cohort"] == cohort) & df["signal"].notna() & df["t_rel_minutes"].notna()] + if sub.empty: + return pd.DataFrame(columns=["t_rel_bin_center", "median", "q25", "q75", "n"]) + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + bins = pd.cut(sub["t_rel_minutes"], bins=bin_edges, include_lowest=True, labels=False) + sub = sub.assign(_bin=bins) + rows = [] + for b, g in sub.groupby("_bin"): + if g["signal"].size < 3: + continue + rows.append( + { + "t_rel_bin_center": float(bin_centers[int(b)]), + "median": float(g["signal"].median()), + "q25": float(g["signal"].quantile(0.25)), + "q75": float(g["signal"].quantile(0.75)), + "n": int(g["signal"].size), + } 
+ ) + return pd.DataFrame(rows) + + +def _iqr_at_zero(summary: pd.DataFrame, target_t_rel: float = 0.0) -> float: + """IQR width at the bin closest to ``target_t_rel``.""" + if summary.empty: + return float("nan") + idx = (summary["t_rel_bin_center"] - target_t_rel).abs().idxmin() + return float(summary.loc[idx, "q75"] - summary.loc[idx, "q25"]) + + +def _plot_grid( + summaries: dict[tuple[str, str, str], pd.DataFrame], + organelles: list[str], + tracks: list[str], + cohorts: list[str], + out_path: Path, + title: str, +) -> None: + """3×3 grid: rows = tracks, cols = organelles, cohorts overlaid per panel.""" + fig, axes = plt.subplots(len(tracks), len(organelles), figsize=(4 * len(organelles), 3 * len(tracks)), sharex=True) + if len(tracks) == 1: + axes = np.atleast_2d(axes) + cohort_colors = {"productive": "C3", "mock": "C7", "bystander": "C0"} + for i, track in enumerate(tracks): + for j, organelle in enumerate(organelles): + ax = axes[i, j] + for cohort in cohorts: + key = (track, organelle, cohort) + summary = summaries.get(key) + if summary is None or summary.empty: + continue + color = cohort_colors.get(cohort, "C2") + ax.plot(summary["t_rel_bin_center"], summary["median"], color=color, label=cohort) + ax.fill_between(summary["t_rel_bin_center"], summary["q25"], summary["q75"], color=color, alpha=0.2) + ax.axvline(0, color="k", linestyle="--", linewidth=0.5) + if i == 0: + ax.set_title(organelle) + if j == 0: + ax.set_ylabel(f"{track}\nsignal") + if i == len(tracks) - 1: + ax.set_xlabel("t_rel_minutes") + handles, labels = axes[0, 0].get_legend_handles_labels() + if handles: + fig.legend(handles, labels, loc="upper right") + fig.suptitle(title) + fig.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def main() -> None: + """Run cross-track comparison for one comparison config entry.""" + parser = argparse.ArgumentParser(description="Stage 4 cross-track comparison.") + parser.add_argument("--datasets", required=True) + 
parser.add_argument("--config", required=True) + parser.add_argument("--comparison", required=True, help="Name under config['comparisons']") + args = parser.parse_args() + + config = load_stage_config(args.datasets, args.config) + comparisons = config.get("comparisons", {}) + if args.comparison not in comparisons: + raise KeyError(f"Comparison {args.comparison!r} not in {sorted(comparisons)}") + cmp_cfg = comparisons[args.comparison] + candidate_set = cmp_cfg["candidate_set"] + organelles = cmp_cfg.get("organelles", ["sec61", "g3bp1", "phase"]) + tracks = cmp_cfg.get("tracks", ["A-anno", "A-LC", "B"]) + cohorts = cmp_cfg.get("cohorts", ["productive", "mock", "bystander"]) + bin_minutes = float(cmp_cfg.get("bin_minutes", 30.0)) + bin_range = cmp_cfg.get("bin_range_minutes", [-360, 540]) + methodological_fraction = float(cmp_cfg.get("methodological_threshold_fraction", 0.25)) + bin_edges = np.arange(bin_range[0], bin_range[1] + bin_minutes, bin_minutes) + + # Load all signal parquets and compute per-(track, organelle, cohort) summaries. + summaries: dict[tuple[str, str, str], pd.DataFrame] = {} + for track in tracks: + for organelle in organelles: + try: + df = _load_signal(track, organelle, candidate_set) + except FileNotFoundError as exc: + _logger.warning(f"Skipping {track}/{organelle}: {exc}") + continue + for cohort in cohorts: + summaries[(track, organelle, cohort)] = _binned_summary(df, bin_edges, cohort) + + # Methodological-claim verdict per organelle. 
+ verdict_rows = [] + for organelle in organelles: + b_iqr = _iqr_at_zero(summaries.get(("B", organelle, "productive"), pd.DataFrame())) + a_anno_iqr = _iqr_at_zero(summaries.get(("A-anno", organelle, "productive"), pd.DataFrame())) + a_lc_iqr = _iqr_at_zero(summaries.get(("A-LC", organelle, "productive"), pd.DataFrame())) + better_a = ( + min(v for v in (a_anno_iqr, a_lc_iqr) if not np.isnan(v)) + if any(not np.isnan(v) for v in (a_anno_iqr, a_lc_iqr)) + else float("nan") + ) + ratio = ( + float(b_iqr / better_a) if (better_a and not np.isnan(better_a) and not np.isnan(b_iqr)) else float("nan") + ) + success = (not np.isnan(ratio)) and (ratio <= 1.0 - methodological_fraction) + verdict_rows.append( + { + "organelle": organelle, + "B_iqr_at_zero": b_iqr, + "A_anno_iqr_at_zero": a_anno_iqr, + "A_LC_iqr_at_zero": a_lc_iqr, + "better_A_iqr": better_a, + "B_over_A_ratio": ratio, + "methodological_success": bool(success), + } + ) + verdict_df = pd.DataFrame(verdict_rows) + _logger.info(f"Methodological-claim verdicts:\n{verdict_df.to_string(index=False)}") + + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + fig_path = OUTPUT_DIR / f"{args.comparison}.png" + summary_path = OUTPUT_DIR / f"{args.comparison}_summary.csv" + verdict_path = OUTPUT_DIR / f"{args.comparison}_verdict.csv" + + long_summary = [] + for (track, organelle, cohort), df in summaries.items(): + if df.empty: + continue + df = df.copy() + df["track"] = track + df["organelle"] = organelle + df["cohort"] = cohort + long_summary.append(df) + if long_summary: + pd.concat(long_summary, ignore_index=True).to_csv(summary_path, index=False) + verdict_df.to_csv(verdict_path, index=False) + _plot_grid(summaries, organelles, tracks, cohorts, fig_path, title=args.comparison) + _logger.info(f"Wrote {fig_path}") + _logger.info(f"Wrote {summary_path}") + _logger.info(f"Wrote {verdict_path}") + + +if __name__ == "__main__": + main() diff --git 
a/applications/dynaclr/scripts/pseudotime/4-compare_tracks/compare_phase_to_fluor.py b/applications/dynaclr/scripts/pseudotime/4-compare_tracks/compare_phase_to_fluor.py new file mode 100644 index 000000000..6b228fda0 --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/4-compare_tracks/compare_phase_to_fluor.py @@ -0,0 +1,144 @@ +"""Stage 4 claim (a'): Spearman ρ of phase onset vs fluorescent-marker onset. + +Per discussion §2.1 / DAG §9.2: claim (a') succeeds iff per-cell phase +onset times correlate (ρ ≥ 0.20) with the matched fluorescent-marker +onset times across the productive cohort. Falsifier is per-organelle: +SEC61 carries weight; G3BP1 expected null is positive evidence for +fluorescence-and-phase complementarity. + +Per-cell onset time = first frame where signal exceeds the per-FOV +mock 95th percentile. ρ computed via scipy.stats.spearmanr; p-value +via 1000-shuffle permutation null. + +Usage:: + + cd applications/dynaclr/scripts/pseudotime/4-compare_tracks + uv run python compare_phase_to_fluor.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/compare_tracks.yaml \ + --comparison zikv_07_24_full \ + --organelle sec61 \ + --track A-anno +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import numpy as np +import pandas as pd +from scipy.stats import spearmanr + +SCRIPT_DIR = Path(__file__).resolve().parent +READOUT_ROOT = SCRIPT_DIR.parent / "3-organelle-remodeling" +OUTPUT_DIR = SCRIPT_DIR / "comparisons" + +sys.path.insert(0, str(SCRIPT_DIR.parent)) +from utils import load_stage_config # noqa: E402 + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +_logger = logging.getLogger(__name__) + + +def _per_cell_onset(signal_df: pd.DataFrame, threshold_df: pd.DataFrame) -> pd.DataFrame: + """First t_rel_minutes where signal exceeds the per-FOV threshold. 
+ + Cells without an anchor (``t_rel_minutes`` all NaN — e.g. 08_26's + LC-anchored productive cells under the A-anno track) are skipped. + """ + threshold_lookup = dict(zip(threshold_df["fov_name"].astype(str), threshold_df["threshold"])) + rows = [] + productive = signal_df[ + (signal_df["cohort"] == "productive") & signal_df["signal"].notna() & signal_df["t_rel_minutes"].notna() + ] + for (ds, fov, tid), g in productive.groupby(["dataset_id", "fov_name", "track_id"]): + threshold = threshold_lookup.get(str(fov)) + if threshold is None: + continue + g = g.sort_values("t_rel_minutes") + above = g["signal"].to_numpy() > threshold + if not above.any(): + continue + first_idx = int(np.argmax(above)) + rows.append( + { + "dataset_id": ds, + "fov_name": str(fov), + "track_id": int(tid), + "lineage_id": str(g["lineage_id"].iloc[0]), + "onset_t_rel_minutes": float(g["t_rel_minutes"].iloc[first_idx]), + } + ) + return pd.DataFrame(rows) + + +def main() -> None: + """Compute Spearman ρ between phase and matched-fluor onset times.""" + parser = argparse.ArgumentParser(description="Stage 4 claim (a') Spearman ρ.") + parser.add_argument("--datasets", required=True) + parser.add_argument("--config", required=True) + parser.add_argument("--comparison", required=True) + parser.add_argument("--organelle", required=True, choices=["sec61", "g3bp1"]) + parser.add_argument("--track", required=True, choices=["A-anno", "A-LC", "B"]) + parser.add_argument("--n-permutations", type=int, default=1000) + args = parser.parse_args() + + config = load_stage_config(args.datasets, args.config) + cmp_cfg = config["comparisons"][args.comparison] + candidate_set = cmp_cfg["candidate_set"] + + fluor_signal = pd.read_parquet(READOUT_ROOT / args.track / args.organelle / f"{candidate_set}_signal.parquet") + fluor_threshold = pd.read_csv(READOUT_ROOT / args.track / args.organelle / f"{candidate_set}_threshold.csv") + phase_signal = pd.read_parquet(READOUT_ROOT / args.track / "phase" / 
f"{candidate_set}_signal.parquet") + phase_threshold = pd.read_csv(READOUT_ROOT / args.track / "phase" / f"{candidate_set}_threshold.csv") + + fluor_onset = _per_cell_onset(fluor_signal, fluor_threshold) + phase_onset = _per_cell_onset(phase_signal, phase_threshold) + + paired = fluor_onset.merge( + phase_onset, + on=["dataset_id", "fov_name", "track_id", "lineage_id"], + suffixes=("_fluor", "_phase"), + ) + n_paired = len(paired) + if n_paired < 5: + _logger.warning(f"Only {n_paired} paired cells; ρ not informative") + rho, _p = spearmanr(paired["onset_t_rel_minutes_fluor"], paired["onset_t_rel_minutes_phase"]) + + # Permutation null. + rng = np.random.default_rng(seed=0) + null_rhos = [] + fluor_arr = paired["onset_t_rel_minutes_fluor"].to_numpy() + phase_arr = paired["onset_t_rel_minutes_phase"].to_numpy() + for _ in range(args.n_permutations): + shuffled = rng.permutation(phase_arr) + r, _ = spearmanr(fluor_arr, shuffled) + null_rhos.append(r) + null_rhos = np.asarray(null_rhos) + perm_p = float(np.mean(np.abs(null_rhos) >= abs(rho))) + + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + out_path = OUTPUT_DIR / f"{args.comparison}_phase_vs_{args.organelle}_{args.track}.csv" + pd.DataFrame( + [ + { + "comparison": args.comparison, + "organelle": args.organelle, + "track": args.track, + "n_paired": int(n_paired), + "spearman_rho": float(rho) if not np.isnan(rho) else np.nan, + "permutation_p_value": perm_p, + "n_permutations": int(args.n_permutations), + "claim_succeeds": bool((not np.isnan(rho)) and abs(rho) >= 0.20 and perm_p < 0.05), + } + ] + ).to_csv(out_path, index=False) + _logger.info(f"ρ = {rho:.3f}, perm p = {perm_p:.4f} (n={n_paired})") + _logger.info(f"Wrote {out_path}") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/4-compare_tracks/plot_paired_traces.py b/applications/dynaclr/scripts/pseudotime/4-compare_tracks/plot_paired_traces.py new file mode 100644 index 000000000..4f3d34c19 --- /dev/null +++ 
b/applications/dynaclr/scripts/pseudotime/4-compare_tracks/plot_paired_traces.py @@ -0,0 +1,154 @@ +r"""Diagnostic: plot per-cell phase + fluor signal traces for paired cells. + +Visual companion to ``compare_phase_to_fluor.py``. For every productive +cell that contributes to the Spearman ρ (i.e. crosses both the phase +and the fluorescence threshold), draws a 2-panel row showing: + +- Top: phase cosine distance vs ``t_rel_minutes`` with the per-FOV + threshold and the detected phase onset. +- Bottom: fluor cosine distance vs ``t_rel_minutes`` with the per-FOV + threshold and the detected fluor onset. + +Used to inspect *why* g3bp1 ρ is high (sharp threshold crossings) and +sec61 ρ is unstable (slow / gradual). + +Usage:: + + cd applications/dynaclr/scripts/pseudotime/4-compare_tracks + uv run python plot_paired_traces.py \\ + --comparison zikv_07_24_full \\ + --organelle sec61 \\ + --track A-anno +""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +SCRIPT_DIR = Path(__file__).resolve().parent +READOUT_ROOT = SCRIPT_DIR.parent / "3-organelle-remodeling" +OUTPUT_DIR = SCRIPT_DIR / "comparisons" + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +_logger = logging.getLogger(__name__) + +# Mapping from --comparison to candidate set (matches compare_tracks.yaml). 
+COMPARISON_TO_CANDIDATE_SET = { + "zikv_07_24_full": "zikv_productive_07_24", + "zikv_pooled_full": "zikv_productive_pooled", +} + + +def _per_cell_onset(signal_df: pd.DataFrame, threshold_df: pd.DataFrame) -> pd.DataFrame: + """Productive-cell onset = first ``t_rel_minutes`` where signal > FOV threshold.""" + threshold_lookup = dict(zip(threshold_df["fov_name"].astype(str), threshold_df["threshold"])) + rows = [] + productive = signal_df[ + (signal_df["cohort"] == "productive") & signal_df["signal"].notna() & signal_df["t_rel_minutes"].notna() + ] + for (ds, fov, tid), g in productive.groupby(["dataset_id", "fov_name", "track_id"]): + threshold = threshold_lookup.get(str(fov)) + if threshold is None: + continue + g = g.sort_values("t_rel_minutes") + above = g["signal"].to_numpy() > threshold + if not above.any(): + continue + first_idx = int(np.argmax(above)) + rows.append( + { + "dataset_id": str(ds), + "fov_name": str(fov), + "track_id": int(tid), + "lineage_id": str(g["lineage_id"].iloc[0]), + "onset_t_rel_minutes": float(g["t_rel_minutes"].iloc[first_idx]), + "threshold": float(threshold), + } + ) + return pd.DataFrame(rows) + + +def main() -> None: + """Render per-cell phase + fluor diagnostic traces.""" + parser = argparse.ArgumentParser(description="Plot paired phase/fluor traces.") + parser.add_argument("--comparison", required=True) + parser.add_argument("--organelle", required=True, choices=["sec61", "g3bp1"]) + parser.add_argument("--track", required=True, choices=["A-anno", "A-LC", "B"]) + args = parser.parse_args() + + candidate_set = COMPARISON_TO_CANDIDATE_SET[args.comparison] + + fluor_signal = pd.read_parquet(READOUT_ROOT / args.track / args.organelle / f"{candidate_set}_signal.parquet") + fluor_threshold = pd.read_csv(READOUT_ROOT / args.track / args.organelle / f"{candidate_set}_threshold.csv") + phase_signal = pd.read_parquet(READOUT_ROOT / args.track / "phase" / f"{candidate_set}_signal.parquet") + phase_threshold = pd.read_csv(READOUT_ROOT / 
args.track / "phase" / f"{candidate_set}_threshold.csv") + + fluor_onset = _per_cell_onset(fluor_signal, fluor_threshold) + phase_onset = _per_cell_onset(phase_signal, phase_threshold) + paired = fluor_onset.merge( + phase_onset, + on=["dataset_id", "fov_name", "track_id", "lineage_id"], + suffixes=("_fluor", "_phase"), + ) + n = len(paired) + if n == 0: + _logger.warning("No paired cells; nothing to plot") + return + _logger.info(f"Plotting {n} paired cells") + + paired = paired.sort_values("onset_t_rel_minutes_fluor").reset_index(drop=True) + + fig, axes = plt.subplots(n, 2, figsize=(11, max(2.5 * n, 3.0)), sharex=True, squeeze=False) + for i, row in paired.iterrows(): + key = (row["dataset_id"], row["fov_name"], int(row["track_id"])) + phase_track = phase_signal[ + (phase_signal["dataset_id"] == key[0]) + & (phase_signal["fov_name"] == key[1]) + & (phase_signal["track_id"] == key[2]) + & phase_signal["t_rel_minutes"].notna() + ].sort_values("t_rel_minutes") + fluor_track = fluor_signal[ + (fluor_signal["dataset_id"] == key[0]) + & (fluor_signal["fov_name"] == key[1]) + & (fluor_signal["track_id"] == key[2]) + & fluor_signal["t_rel_minutes"].notna() + ].sort_values("t_rel_minutes") + + ax_phase, ax_fluor = axes[i, 0], axes[i, 1] + ax_phase.plot(phase_track["t_rel_minutes"], phase_track["signal"], color="0.2", lw=1.0) + ax_phase.axhline(row["threshold_phase"], color="grey", ls=":", lw=0.8) + ax_phase.axvline(row["onset_t_rel_minutes_phase"], color="C0", ls="--", lw=1.0) + ax_phase.axvline(0, color="red", ls="-", lw=0.6, alpha=0.6) + ax_phase.set_ylabel(f"{key[1]}\nt={key[2]}", fontsize=8) + if i == 0: + ax_phase.set_title("phase cosine distance", fontsize=9) + + ax_fluor.plot(fluor_track["t_rel_minutes"], fluor_track["signal"], color="0.2", lw=1.0) + ax_fluor.axhline(row["threshold_fluor"], color="grey", ls=":", lw=0.8) + ax_fluor.axvline(row["onset_t_rel_minutes_fluor"], color="C1", ls="--", lw=1.0) + ax_fluor.axvline(0, color="red", ls="-", lw=0.6, alpha=0.6) + 
if i == 0: + ax_fluor.set_title(f"{args.organelle} cosine distance", fontsize=9) + + axes[-1, 0].set_xlabel("t_rel_minutes (anchor=0)") + axes[-1, 1].set_xlabel("t_rel_minutes (anchor=0)") + fig.suptitle( + f"{args.comparison} | {args.organelle} | track={args.track} | n={n} paired", + fontsize=10, + ) + fig.tight_layout() + + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + out_path = OUTPUT_DIR / f"{args.comparison}_traces_phase_vs_{args.organelle}_{args.track}.png" + fig.savefig(out_path, dpi=130, bbox_inches="tight") + _logger.info(f"Wrote {out_path}") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/4-compare_tracks/warp_vs_no_warp.py b/applications/dynaclr/scripts/pseudotime/4-compare_tracks/warp_vs_no_warp.py new file mode 100644 index 000000000..5c87dde49 --- /dev/null +++ b/applications/dynaclr/scripts/pseudotime/4-compare_tracks/warp_vs_no_warp.py @@ -0,0 +1,131 @@ +"""Stage 4 warp-vs-no-warp comparator (Path B only). + +Per discussion §3.8 #10 / §4.5: this is the mandatory comparator that +forces the data to answer whether Path B's transition-window warp +sharpens organelle timing relative to using real-time alone. For each +(organelle, cohort) we compute the per-cell onset distribution under +two readout axes derived from Path B's parquet: + +- ``t_rel_minutes`` (real-time relative to per-cell t_zero) — equivalent + to Path A on the same cohort. +- ``t_rel_minutes_warped`` (back-projected real-time from DTW warp). + +If the median onset times agree to within 25%, the warp is neutral and +Path B's warp propagation is kept for sharpness without distortion. If +they diverge by more than 25%, the warp is masking real organelle +timing — flag and investigate. 
+ +Usage:: + + cd applications/dynaclr/scripts/pseudotime/4-compare_tracks + uv run python warp_vs_no_warp.py \ + --datasets ../../../configs/pseudotime/datasets.yaml \ + --config ../../../configs/pseudotime/compare_tracks.yaml \ + --comparison zikv_07_24_full +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import numpy as np +import pandas as pd + +SCRIPT_DIR = Path(__file__).resolve().parent +READOUT_ROOT = SCRIPT_DIR.parent / "3-organelle-remodeling" +OUTPUT_DIR = SCRIPT_DIR / "comparisons" + +sys.path.insert(0, str(SCRIPT_DIR.parent)) +from utils import load_stage_config # noqa: E402 + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") +_logger = logging.getLogger(__name__) + + +def _per_cell_onset(signal_df: pd.DataFrame, threshold_df: pd.DataFrame, time_col: str) -> pd.DataFrame: + """First time (real-time or warped) where signal exceeds per-FOV threshold.""" + threshold_lookup = dict(zip(threshold_df["fov_name"].astype(str), threshold_df["threshold"])) + rows = [] + productive = signal_df[ + (signal_df["cohort"] == "productive") & signal_df["signal"].notna() & signal_df[time_col].notna() + ] + for (ds, fov, tid), g in productive.groupby(["dataset_id", "fov_name", "track_id"]): + threshold = threshold_lookup.get(str(fov)) + if threshold is None: + continue + g = g.sort_values(time_col) + above = g["signal"].to_numpy() > threshold + if not above.any(): + continue + first_idx = int(np.argmax(above)) + rows.append( + { + "dataset_id": ds, + "fov_name": str(fov), + "track_id": int(tid), + "onset_minutes": float(g[time_col].iloc[first_idx]), + } + ) + return pd.DataFrame(rows) + + +def main() -> None: + """Compare warped vs unwarped onset times under Path B.""" + parser = argparse.ArgumentParser(description="Stage 4 warp-vs-no-warp comparator (Path B only).") + parser.add_argument("--datasets", required=True) + parser.add_argument("--config", 
required=True) + parser.add_argument("--comparison", required=True) + args = parser.parse_args() + + config = load_stage_config(args.datasets, args.config) + cmp_cfg = config["comparisons"][args.comparison] + candidate_set = cmp_cfg["candidate_set"] + organelles = cmp_cfg.get("organelles", ["sec61", "g3bp1", "phase"]) + threshold_fraction = float(cmp_cfg.get("methodological_threshold_fraction", 0.25)) + + rows = [] + for organelle in organelles: + signal_path = READOUT_ROOT / "B" / organelle / f"{candidate_set}_signal.parquet" + threshold_path = READOUT_ROOT / "B" / organelle / f"{candidate_set}_threshold.csv" + if not signal_path.exists(): + _logger.warning(f"Missing Path B signal for {organelle}; skipping") + continue + signal_df = pd.read_parquet(signal_path) + threshold_df = pd.read_csv(threshold_path) + + unwarped = _per_cell_onset(signal_df, threshold_df, "t_rel_minutes") + warped_col = "t_rel_minutes_warped" if "t_rel_minutes_warped" in signal_df.columns else "t_rel_minutes" + warped = _per_cell_onset(signal_df, threshold_df, warped_col) + + unwarped_median = float(unwarped["onset_minutes"].median()) if not unwarped.empty else float("nan") + warped_median = float(warped["onset_minutes"].median()) if not warped.empty else float("nan") + diff = ( + float(abs(warped_median - unwarped_median)) + if not (np.isnan(warped_median) or np.isnan(unwarped_median)) + else float("nan") + ) + ref = float(abs(unwarped_median)) if not np.isnan(unwarped_median) else float("nan") + diverges = bool(not np.isnan(diff) and ref > 0 and diff / ref > threshold_fraction) + rows.append( + { + "organelle": organelle, + "n_unwarped": int(len(unwarped)), + "n_warped": int(len(warped)), + "median_unwarped_minutes": unwarped_median, + "median_warped_minutes": warped_median, + "absolute_diff_minutes": diff, + "diverges": diverges, + } + ) + + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + out_path = OUTPUT_DIR / f"{args.comparison}_warp_vs_no_warp.csv" + pd.DataFrame(rows).to_csv(out_path, 
index=False) + _logger.info(f"Wrote {out_path}") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/scripts/pseudotime/README.md b/applications/dynaclr/scripts/pseudotime/README.md deleted file mode 100644 index 4b86214aa..000000000 --- a/applications/dynaclr/scripts/pseudotime/README.md +++ /dev/null @@ -1,146 +0,0 @@ -# Pseudotime Remodeling Analysis - -Measure organelle remodeling timing relative to viral infection onset using lineage-aware alignment and multiple signal extraction methods. - -## Overview - -This directory is organized into `src/` (importable library modules) and `analysis/` (HPC scripts): - -``` -pseudotime/ -├── README.md -├── src/ -│ ├── __init__.py -│ ├── alignment.py -│ ├── signals.py -│ ├── metrics.py -│ └── plotting.py -└── analysis/ - ├── annotation_remodeling.py - ├── prediction_remodeling.py - └── embedding_distance.py -``` - -The pipeline follows: - -``` -alignment → signal extraction → aggregation → metrics → plotting -``` - -### Library Modules (`src/`) - -| Module | Description | -|--------|-------------| -| `src/alignment.py` | Lineage detection, FOV/track filtering, T_perturb assignment | -| `src/signals.py` | Signal extraction: annotation binary, classifier prediction, embedding distance | -| `src/metrics.py` | Population aggregation, onset/T50/peak detection, per-track timing, statistical tests | -| `src/plotting.py` | Response curves, per-track heatmaps, timing distributions, onset comparison | - -### Analysis Scripts (`analysis/`) - -Each script runs the full pipeline with a different signal source. They are Jupyter-compatible (`# %%` cell markers) and designed for HPC execution. 
- -| Script | Signal Source | Requires | -|--------|--------------|----------| -| `analysis/annotation_remodeling.py` | Human annotations (`organelle_state` column) | Tracking CSV + annotation CSV | -| `analysis/prediction_remodeling.py` | Classifier predictions (`predicted_organelle_state` in AnnData) | Tracking CSV + predicted AnnData zarr | -| `analysis/embedding_distance.py` | Cosine distance from baseline embeddings | Tracking CSV + embedding AnnData zarr | - -## Prerequisites - -Install DynaCLR with the eval extras and statsmodels: - -```bash -cd applications/dynaclr -uv pip install -e ".[eval]" statsmodels -``` - -## Running Tests - -Unit tests cover all four library modules using synthetic data (no HPC paths required): - -```bash -cd applications/dynaclr -uv run pytest tests/test_pseudotime.py -v -``` - -### Test Structure - -| Test Class | Tests | Module Covered | -|------------|-------|----------------| -| `TestAlignment` | 7 | `src/alignment.py` — lineage detection, FOV filtering, T_perturb assignment | -| `TestSignals` | 5 | `src/signals.py` — annotation/prediction/embedding-distance signal extraction | -| `TestMetrics` | 8 | `src/metrics.py` — population aggregation, onset/T50/peak, track timing, stats | -| `TestPlotting` | 4 | `src/plotting.py` — file output (pdf+png) and Figure return for all plot types | - -### Synthetic Data - -Tests use a self-contained tracking DataFrame with: -- **C/2/000**: 3 tracks with parent-child lineage, infected at t=5 -- **C/2/001**: 1 orphan track, infected at t=7 -- **B/1/000**: 2 control tracks (no infection) - -Plus a matching AnnData with 16-dim random embeddings and classifier predictions. - -## Pipeline Details - -### 1. Alignment - -Tracks are filtered by FOV pattern and minimum length, then aligned to infection onset (T_perturb). Lineage-aware logic ensures all tracks in a parent-child lineage share the same T_perturb. 
- -```python -from src.alignment import align_tracks - -aligned_df = align_tracks( - tracking_df, - frame_interval_minutes=30.0, - fov_pattern="C/2", - min_track_timepoints=3, -) -# Adds columns: t_perturb, t_relative_minutes -``` - -### 2. Signal Extraction - -Three modes producing a common `signal` column: - -```python -from src.signals import ( - extract_annotation_signal, - extract_prediction_signal, - extract_embedding_distance, -) - -# Binary from annotations -df = extract_annotation_signal(aligned_df, state_col="organelle_state") - -# Binary or continuous from classifier predictions -df = extract_prediction_signal(adata, aligned_df, task="organelle_state") - -# Cosine distance from baseline embeddings -df = extract_embedding_distance(adata, aligned_df, baseline_method="per_track") -``` - -### 3. Aggregation and Metrics - -```python -from src.metrics import aggregate_population, find_onset_time - -time_bins = np.arange(-600, 901, 30) -pop_df = aggregate_population(df, time_bins, signal_type="fraction") -onset, threshold, bl_mean, bl_std = find_onset_time(pop_df) -``` - -### 4. Plotting - -All plot functions save pdf+png and return the matplotlib Figure: - -```python -from src.plotting import plot_response_curves - -fig = plot_response_curves( - organelle_curves={"SEC61": pop_df}, - organelle_configs={"SEC61": {"label": "SEC61", "color": "#1f77b4"}}, - output_dir=Path("figures/"), -) -``` diff --git a/applications/dynaclr/scripts/pseudotime/annotation_remodeling.py b/applications/dynaclr/scripts/pseudotime/annotation_remodeling.py deleted file mode 100644 index 96b446045..000000000 --- a/applications/dynaclr/scripts/pseudotime/annotation_remodeling.py +++ /dev/null @@ -1,338 +0,0 @@ -# %% -""" -Annotation-based organelle remodeling analysis. - -Measures remodeling timing using human annotations (organelle_state column) -directly from annotation CSVs — no model predictions required. 
- -Pipeline: alignment → annotation signal → aggregation → metrics → plotting - -Usage: Run as a Jupyter-compatible script (# %% cell markers). -""" - -from pathlib import Path - -import numpy as np -import pandas as pd - -from dynaclr.evaluation.pseudotime.alignment import align_tracks -from dynaclr.evaluation.pseudotime.metrics import ( - aggregate_population, - compute_track_timing, - find_half_max_time, - find_onset_time, - find_peak_metrics, - run_statistical_tests, -) -from dynaclr.evaluation.pseudotime.plotting import ( - plot_cell_heatmap, - plot_onset_comparison, - plot_response_curves, - plot_timing_distributions, -) -from dynaclr.evaluation.pseudotime.signals import ( - extract_annotation_signal, -) - -# %% -# =========================================================================== -# Dataset configuration -# =========================================================================== - -ANNOTATIONS_ROOT = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") - -ORGANELLE_CONFIG = { - "G3BP1_ZIKV": { - "experiments": [ - { - "csv_path": ANNOTATIONS_ROOT - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/2", - "frame_interval_minutes": 10, - "label": "2025_07_22 ZIKV", - }, - { - "csv_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/2", - "frame_interval_minutes": 30, - "label": "2025_07_24 ZIKV", - }, - ], - "controls": [ - { - "csv_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/1", - "frame_interval_minutes": 30, - "label": "2025_07_24 control (C/1)", - }, - ], - "label": "G3BP1 ZIKV (Stress Granule)", - "color": "#1f77b4", - }, - "SEC61B_ZIKV": { - "experiments": [ - { - "csv_path": ANNOTATIONS_ROOT - / 
"2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "A/2", - "frame_interval_minutes": 30, - "label": "2025_07_24 ZIKV (SEC61B)", - }, - ], - "controls": [], - "label": "SEC61B ZIKV (ER)", - "color": "#ff7f0e", - }, - "G3BP1_DENV": { - "experiments": [ - { - "csv_path": ANNOTATIONS_ROOT - / "2025_01_24_A549_G3BP1_DENV" - / "2025_01_24_A549_G3BP1_DENV_combined_annotations.csv", - "fov_pattern": "C/2", - "frame_interval_minutes": 10, - "label": "2025_01_24 DENV", - }, - { - "csv_path": ANNOTATIONS_ROOT - / "2025_01_28_A549_G3BP1_ZIKV_DENV" - / "2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv", - "fov_pattern": "C/4", - "frame_interval_minutes": 30, - "label": "2025_01_28 DENV", - }, - ], - "controls": [ - { - "csv_path": ANNOTATIONS_ROOT - / "2025_01_28_A549_G3BP1_ZIKV_DENV" - / "2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv", - "fov_pattern": "B/4", - "frame_interval_minutes": 30, - "label": "2025_01_28 control (B/4)", - }, - ], - "label": "G3BP1 DENV (Stress Granule)", - "color": "#2ca02c", - }, -} - -# Analysis parameters -T_PERTURB_SOURCE = "annotation" -TIME_BINS_MINUTES = np.arange(-600, 901, 30) -MIN_CELLS_PER_BIN = 5 -MIN_TRACK_TIMEPOINTS = 3 -ONSET_THRESHOLD_SIGMA = 2 - -RESULTS_DIR = Path(__file__).parent / "results" / "annotation_remodeling" - -# %% -# =========================================================================== -# Step 1 + 2: Load data, alignment, and signal extraction -# =========================================================================== - -marker_results = {} - -for marker, config in ORGANELLE_CONFIG.items(): - print(f"\n{'=' * 60}") - print(f"Processing {marker}") - print(f"{'=' * 60}") - - all_experiment_dfs = [] - - for exp in config["experiments"]: - print(f"\n Experiment: {exp['label']}") - df = pd.read_csv(exp["csv_path"]) - print(f" Loaded {len(df):,} annotations, t range: {df['t'].min()}-{df['t'].max()}") - - # Ensure 
parent_track_id exists - if "parent_track_id" not in df.columns: - df["parent_track_id"] = -1 - - # Step 1: Alignment - aligned = align_tracks( - df, - frame_interval_minutes=exp["frame_interval_minutes"], - source=T_PERTURB_SOURCE, - fov_pattern=exp["fov_pattern"], - min_track_timepoints=MIN_TRACK_TIMEPOINTS, - ) - - # Step 2: Signal extraction (annotation-based) - aligned = extract_annotation_signal(aligned, state_col="organelle_state", positive_value="remodel") - aligned["experiment"] = exp["label"] - aligned["marker"] = marker - all_experiment_dfs.append(aligned) - - if not all_experiment_dfs: - print(f" No data for {marker}, skipping") - continue - - combined = pd.concat(all_experiment_dfs, ignore_index=True) - - # Step 3: Aggregate - fraction_df = aggregate_population(combined, TIME_BINS_MINUTES, signal_type="fraction") - - n_tracks = combined.groupby(["fov_name", "track_id", "experiment"]).ngroups - marker_results[marker] = { - "combined_df": combined, - "fraction_df": fraction_df, - "config": config, - "n_tracks": n_tracks, - "n_experiments": len(config["experiments"]), - "n_frames": len(combined), - } - - print( - f"\n **{marker} summary**: {n_tracks} tracks, " - f"{len(config['experiments'])} experiments, {len(combined):,} total frames" - ) - -# %% -# =========================================================================== -# Process controls -# =========================================================================== - -control_results = {} -for marker, config in ORGANELLE_CONFIG.items(): - if not config.get("controls"): - continue - ctrl_dfs = [] - for ctrl in config["controls"]: - df = pd.read_csv(ctrl["csv_path"]) - df = df[df["fov_name"].str.startswith(ctrl["fov_pattern"])].copy() - ctrl_dfs.append(df) - if ctrl_dfs: - control_combined = pd.concat(ctrl_dfs, ignore_index=True) - n_total = len(control_combined.dropna(subset=["organelle_state"])) - n_remodel = (control_combined["organelle_state"] == "remodel").sum() - fraction = n_remodel / n_total 
if n_total > 0 else 0 - control_results[marker] = { - "n_total": n_total, - "n_remodel": n_remodel, - "fraction": fraction, - } - print(f" {marker} control: {n_remodel}/{n_total} = {fraction:.4f}") - -# %% -# =========================================================================== -# Step 4: Timing metrics -# =========================================================================== - -timing_rows = [] -for marker, res in marker_results.items(): - frac_df = res["fraction_df"] - - t_onset, threshold, bl_mean, bl_std = find_onset_time( - frac_df, - sigma_threshold=ONSET_THRESHOLD_SIGMA, - min_cells_per_bin=MIN_CELLS_PER_BIN, - ) - t_50 = find_half_max_time(frac_df) - peak = find_peak_metrics(frac_df) - - timing_rows.append( - { - "marker": marker, - "T_onset_minutes": t_onset, - "T_50_minutes": t_50, - "T_peak_minutes": peak["T_peak_minutes"], - "peak_amplitude": peak["peak_amplitude"], - "T_return_minutes": peak["T_return_minutes"], - "pulse_duration_minutes": peak["pulse_duration_minutes"], - "auc": peak["auc"], - "baseline_mean": bl_mean, - "baseline_std": bl_std, - "n_tracks": res["n_tracks"], - "n_experiments": res["n_experiments"], - } - ) - -timing_df = pd.DataFrame(timing_rows) -print("\n## Remodeling Timing Metrics\n") -print(timing_df.to_string(index=False)) - -# Per-track timing -all_track_timing = [] -for marker, res in marker_results.items(): - track_timing = compute_track_timing(res["combined_df"], signal_type="fraction") - track_timing["marker"] = marker - all_track_timing.append(track_timing) - -track_timing_df = pd.concat(all_track_timing, ignore_index=True) - -# %% -# =========================================================================== -# Step 5: Plotting -# =========================================================================== - -marker_curves = {m: res["fraction_df"] for m, res in marker_results.items()} -marker_configs = {m: res["config"] for m, res in marker_results.items()} - -plot_response_curves( - marker_curves, - 
marker_configs, - RESULTS_DIR, - signal_type="fraction", - min_cells_per_bin=MIN_CELLS_PER_BIN, - title="Annotation-based organelle remodeling after sensor translocation", - filename_prefix="annotation_remodeling_comparison", -) - -for marker, res in marker_results.items(): - plot_cell_heatmap( - res["combined_df"], - TIME_BINS_MINUTES, - signal_type="fraction", - organelle_label=res["config"]["label"], - output_dir=RESULTS_DIR, - filename_prefix=f"{marker}_annotation_heatmap", - ) - -plot_timing_distributions( - track_timing_df, - marker_configs, - RESULTS_DIR, - filename_prefix="per_track_onset_duration", -) - -plot_onset_comparison( - timing_df, - RESULTS_DIR, - filename_prefix="onset_comparison", -) - -# %% -# =========================================================================== -# Step 6: Statistical tests -# =========================================================================== - -if len(marker_results) > 1: - stats_df = run_statistical_tests(marker_results, track_timing_df, control_results or None) - print("\n## Statistical Tests\n") - print(stats_df.to_string(index=False)) - stats_df.to_csv(RESULTS_DIR / "statistical_tests.csv", index=False) - -# %% -# =========================================================================== -# Step 7: Save CSVs -# =========================================================================== - -RESULTS_DIR.mkdir(parents=True, exist_ok=True) - -timing_df.to_csv(RESULTS_DIR / "timing_metrics.csv", index=False) -track_timing_df.to_csv(RESULTS_DIR / "per_track_timing.csv", index=False) - -for marker, res in marker_results.items(): - frac_path = RESULTS_DIR / f"{marker}_fraction_curve.csv" - res["fraction_df"].to_csv(frac_path, index=False) - -print(f"\nResults saved to {RESULTS_DIR}") - -# %% diff --git a/applications/dynaclr/scripts/pseudotime/embedding_distance.py b/applications/dynaclr/scripts/pseudotime/embedding_distance.py deleted file mode 100644 index e9311e3c0..000000000 --- 
a/applications/dynaclr/scripts/pseudotime/embedding_distance.py +++ /dev/null @@ -1,301 +0,0 @@ -# %% -""" -Embedding distance-based organelle remodeling analysis. - -Measures remodeling timing using cosine distance from pre-infection -baseline embeddings. Supports per-track and control-well baselines, -with optional PCA projection. - -Pipeline: alignment → embedding distance → aggregation → metrics → plotting - -Usage: Run as a Jupyter-compatible script (# %% cell markers). -""" - -import glob -from pathlib import Path - -import anndata as ad -import numpy as np -import pandas as pd - -from dynaclr.evaluation.pseudotime.alignment import align_tracks -from dynaclr.evaluation.pseudotime.metrics import ( - aggregate_population, - compute_track_timing, - find_half_max_time, - find_onset_time, - find_peak_metrics, - run_statistical_tests, -) -from dynaclr.evaluation.pseudotime.plotting import ( - plot_cell_heatmap, - plot_onset_comparison, - plot_response_curves, - plot_timing_distributions, -) -from dynaclr.evaluation.pseudotime.signals import ( - extract_embedding_distance, -) - -# %% -# =========================================================================== -# Dataset configuration -# =========================================================================== - -ANNOTATIONS_ROOT = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") -EMBEDDINGS_ROOT = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics") - -ORGANELLE_CONFIG = { - "G3BP1": { - "experiments": [ - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/2", - "control_fov_pattern": "C/1", - "frame_interval_minutes": 30, - "label": "2025_07_22 ZIKV", - }, - ], - 
"label": "G3BP1 (Stress Granule)", - "color": "#1f77b4", - }, - "SEC61B": { - "experiments": [ - { - "embeddings_path": EMBEDDINGS_ROOT - / "2024_11_07_A549_SEC61_DENV" - / "4-phenotyping/2-predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2024_11_07_A549_SEC61B_DENV" - / "2024_11_07_A549_SEC61B_DENV_combined_annotations.csv", - "fov_pattern": "C/2", - "control_fov_pattern": "B/3", - "frame_interval_minutes": 10, - "label": "2024_11_07 DENV", - }, - ], - "label": "SEC61B (ER)", - "color": "#2ca02c", - }, -} - -# Analysis parameters -T_PERTURB_SOURCE = "annotation" -BASELINE_METHOD = "per_track" # "per_track" or "control_well" -BASELINE_WINDOW_MINUTES = (-240, -180) -DISTANCE_METRIC = "cosine" -PCA_N_COMPONENTS = 20 # Set to None to use full embedding space -MIN_BASELINE_FRAMES = 2 -TIME_BINS_MINUTES = np.arange(-600, 901, 30) -MIN_CELLS_PER_BIN = 10 -MIN_TRACK_TIMEPOINTS = 3 -ONSET_THRESHOLD_SIGMA = 2 - -RESULTS_DIR = Path(__file__).parent / "results" / "embedding_distance" - -# %% -# =========================================================================== -# Step 1 + 2: Load data, alignment, and signal extraction -# =========================================================================== - -marker_results = {} - -for marker, config in ORGANELLE_CONFIG.items(): - print(f"\n{'=' * 60}") - print(f"Processing {marker}") - print(f"{'=' * 60}") - - all_experiment_dfs = [] - - for exp in config["experiments"]: - print(f"\n Experiment: {exp['label']}") - - # Load embeddings - emb_files = glob.glob(str(exp["embeddings_path"] / exp["embeddings_pattern"])) - if not emb_files: - print(f" No embeddings found matching: {exp['embeddings_pattern']}") - continue - - adata = ad.read_zarr(emb_files[0]) - print(f" Loaded {adata.shape[0]:,} embeddings") - - # Load annotations for infection state alignment - ann_df = pd.read_csv(exp["annotations_path"]) - if "parent_track_id" not in 
ann_df.columns: - ann_df["parent_track_id"] = -1 - - # Step 1: Alignment - aligned = align_tracks( - ann_df, - frame_interval_minutes=exp["frame_interval_minutes"], - source=T_PERTURB_SOURCE, - fov_pattern=exp["fov_pattern"], - min_track_timepoints=MIN_TRACK_TIMEPOINTS, - ) - - # Step 2: Signal extraction (embedding distance) - aligned = extract_embedding_distance( - adata, - aligned, - baseline_method=BASELINE_METHOD, - baseline_window_minutes=BASELINE_WINDOW_MINUTES, - control_fov_pattern=exp.get("control_fov_pattern"), - distance_metric=DISTANCE_METRIC, - pca_n_components=PCA_N_COMPONENTS, - min_baseline_frames=MIN_BASELINE_FRAMES, - ) - aligned["experiment"] = exp["label"] - aligned["marker"] = marker - all_experiment_dfs.append(aligned) - - if not all_experiment_dfs: - print(f" No data for {marker}, skipping") - continue - - combined = pd.concat(all_experiment_dfs, ignore_index=True) - - # Step 3: Aggregate - population_df = aggregate_population(combined, TIME_BINS_MINUTES, signal_type="continuous") - - n_tracks = combined.groupby(["fov_name", "track_id", "experiment"]).ngroups - marker_results[marker] = { - "combined_df": combined, - "population_df": population_df, - "config": config, - "n_tracks": n_tracks, - "n_experiments": len(config["experiments"]), - "n_frames": len(combined), - } - - print( - f"\n **{marker} summary**: {n_tracks} tracks, " - f"{len(config['experiments'])} experiments, {len(combined):,} total frames" - ) - -# %% -# =========================================================================== -# Step 4: Timing metrics -# =========================================================================== - -timing_rows = [] -for marker, res in marker_results.items(): - pop_df = res["population_df"] - - t_onset, threshold, bl_mean, bl_std = find_onset_time( - pop_df, - sigma_threshold=ONSET_THRESHOLD_SIGMA, - min_cells_per_bin=MIN_CELLS_PER_BIN, - ) - t_50 = find_half_max_time(pop_df) - peak = find_peak_metrics(pop_df) - - timing_rows.append( - { - 
"marker": marker, - "T_onset_minutes": t_onset, - "T_50_minutes": t_50, - "T_peak_minutes": peak["T_peak_minutes"], - "peak_amplitude": peak["peak_amplitude"], - "T_return_minutes": peak["T_return_minutes"], - "pulse_duration_minutes": peak["pulse_duration_minutes"], - "auc": peak["auc"], - "baseline_mean": bl_mean, - "baseline_std": bl_std, - "baseline_method": BASELINE_METHOD, - "distance_metric": DISTANCE_METRIC, - "pca_components": PCA_N_COMPONENTS, - "n_tracks": res["n_tracks"], - "n_experiments": res["n_experiments"], - } - ) - -timing_df = pd.DataFrame(timing_rows) -print("\n## Embedding Distance Timing Metrics\n") -print(timing_df.to_string(index=False)) - -# Per-track timing -all_track_timing = [] -for marker, res in marker_results.items(): - track_timing = compute_track_timing(res["combined_df"], signal_type="continuous") - track_timing["marker"] = marker - all_track_timing.append(track_timing) - -track_timing_df = pd.concat(all_track_timing, ignore_index=True) - -# %% -# =========================================================================== -# Step 5: Plotting -# =========================================================================== - -marker_curves = {m: res["population_df"] for m, res in marker_results.items()} -marker_configs = {m: res["config"] for m, res in marker_results.items()} - -plot_response_curves( - marker_curves, - marker_configs, - RESULTS_DIR, - signal_type="continuous", - min_cells_per_bin=MIN_CELLS_PER_BIN, - title=f"Embedding distance remodeling ({BASELINE_METHOD}, {DISTANCE_METRIC})", - filename_prefix="embedding_distance_comparison", -) - -for marker, res in marker_results.items(): - plot_cell_heatmap( - res["combined_df"], - TIME_BINS_MINUTES, - signal_type="continuous", - organelle_label=res["config"]["label"], - output_dir=RESULTS_DIR, - filename_prefix=f"{marker}_distance_heatmap", - ) - -if len(track_timing_df) > 0: - plot_timing_distributions( - track_timing_df, - marker_configs, - RESULTS_DIR, - 
filename_prefix="per_track_onset_duration", - ) - - plot_onset_comparison( - timing_df, - RESULTS_DIR, - filename_prefix="onset_comparison", - ) - -# %% -# =========================================================================== -# Step 6: Statistical tests -# =========================================================================== - -if len(marker_results) > 1 and len(track_timing_df) > 0: - stats_df = run_statistical_tests(marker_results, track_timing_df) - print("\n## Statistical Tests\n") - print(stats_df.to_string(index=False)) - stats_df.to_csv(RESULTS_DIR / "statistical_tests.csv", index=False) - -# %% -# =========================================================================== -# Step 7: Save CSVs -# =========================================================================== - -RESULTS_DIR.mkdir(parents=True, exist_ok=True) - -timing_df.to_csv(RESULTS_DIR / "timing_metrics.csv", index=False) -track_timing_df.to_csv(RESULTS_DIR / "per_track_timing.csv", index=False) - -for marker, res in marker_results.items(): - curve_path = RESULTS_DIR / f"{marker}_distance_curve.csv" - res["population_df"].to_csv(curve_path, index=False) - -print(f"\nResults saved to {RESULTS_DIR}") - -# %% diff --git a/applications/dynaclr/scripts/pseudotime/infection_death_remodeling.py b/applications/dynaclr/scripts/pseudotime/infection_death_remodeling.py deleted file mode 100644 index 890b6c83d..000000000 --- a/applications/dynaclr/scripts/pseudotime/infection_death_remodeling.py +++ /dev/null @@ -1,386 +0,0 @@ -# %% -""" -Multi-channel correlation: infection, death, and organelle remodeling. - -Uses classifier predictions from different channels to ask: -- Do cells that get infected earlier also die faster? -- Is faster death correlated with faster organelle remodeling? - -Pipeline: -1. Load sensor zarr → T_perturb (infection onset), T_death (cell death onset) -2. Load organelle zarr → T_remodel (organelle remodeling onset) -3. Merge per-track event timings -4. 
Correlate and visualize - -Usage: Run as a Jupyter-compatible script (# %% cell markers). -""" - -from pathlib import Path - -import anndata as ad -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -from scipy import stats - -# %% -# =========================================================================== -# Configuration -# =========================================================================== - -DATASET_ROOT = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics" - "/2025_01_24_A549_G3BP1_DENV/4-phenotyping/predictions" - "/DynaCLR-2D-BagOfChannels-timeaware/v3" -) - -SENSOR_ZARR = DATASET_ROOT / "timeaware_sensor_160patch_104ckpt.zarr" -ORGANELLE_ZARR = DATASET_ROOT / "timeaware_organelle_160patch_104ckpt.zarr" - -FOV_PATTERN = "C/2" # infected wells -FRAME_INTERVAL_MINUTES = 10 -MIN_TRACK_TIMEPOINTS = 3 - -RESULTS_DIR = Path(__file__).parent / "results" / "infection_death_remodeling" - -# %% -# =========================================================================== -# Step 1: Load data and filter to infected wells -# =========================================================================== - -sensor = ad.read_zarr(SENSOR_ZARR) -organelle = ad.read_zarr(ORGANELLE_ZARR) - -print(f"Sensor: {sensor.shape[0]:,} cells") -print(f"Organelle: {organelle.shape[0]:,} cells") - -# Filter to infected FOVs -sensor_obs = sensor.obs[sensor.obs["fov_name"].astype(str).str.startswith(FOV_PATTERN)].copy() -organelle_obs = organelle.obs[organelle.obs["fov_name"].astype(str).str.startswith(FOV_PATTERN)].copy() - -print(f"\nAfter FOV filter ({FOV_PATTERN}):") -print(f" Sensor: {len(sensor_obs):,} cells") -print(f" Organelle: {len(organelle_obs):,} cells") - -# %% -# =========================================================================== -# Step 2: Build per-cell merged dataframe -# =========================================================================== - -merge_keys = ["fov_name", "track_id", "t"] - -sensor_cols = merge_keys 
+ [ - "predicted_infection_state", - "predicted_cell_death_state", -] -organelle_cols = merge_keys + [ - "predicted_organelle_state_g3bp1", -] - -merged = sensor_obs[sensor_cols].merge( - organelle_obs[organelle_cols], - on=merge_keys, - how="inner", -) - -merged["t_minutes"] = merged["t"] * FRAME_INTERVAL_MINUTES - -print(f"\nMerged: {len(merged):,} cells across {merged.groupby(['fov_name', 'track_id']).ngroups} tracks") -print(f" Infection: {merged['predicted_infection_state'].value_counts().to_dict()}") -print(f" Death: {merged['predicted_cell_death_state'].value_counts().to_dict()}") -print(f" Remodel: {merged['predicted_organelle_state_g3bp1'].value_counts().to_dict()}") - -# %% -# =========================================================================== -# Step 3: Compute per-track event timings -# =========================================================================== - - -def find_first_event(group: pd.DataFrame, col: str, value: str) -> float | None: - """Return t_minutes of the first frame matching value, or None.""" - hits = group.loc[group[col] == value, "t_minutes"] - if len(hits) > 0: - return hits.min() - return None - - -track_events = [] -for (fov, tid), group in merged.groupby(["fov_name", "track_id"]): - group = group.sort_values("t") - n_frames = len(group) - if n_frames < MIN_TRACK_TIMEPOINTS: - continue - - t_start = group["t_minutes"].min() - t_end = group["t_minutes"].max() - track_duration = t_end - t_start - - t_infection = find_first_event(group, "predicted_infection_state", "infected") - t_death = find_first_event(group, "predicted_cell_death_state", "dead") - t_remodel = find_first_event(group, "predicted_organelle_state_g3bp1", "remodel") - - # Was cell ever infected, dead, remodeled? 
- ever_infected = t_infection is not None - ever_dead = t_death is not None - ever_remodeled = t_remodel is not None - - # Time from infection to death / remodeling - infection_to_death = (t_death - t_infection) if (ever_infected and ever_dead) else None - infection_to_remodel = (t_remodel - t_infection) if (ever_infected and ever_remodeled) else None - remodel_to_death = (t_death - t_remodel) if (ever_remodeled and ever_dead) else None - - track_events.append( - { - "fov_name": fov, - "track_id": tid, - "n_frames": n_frames, - "track_duration_min": track_duration, - "t_infection_min": t_infection, - "t_death_min": t_death, - "t_remodel_min": t_remodel, - "ever_infected": ever_infected, - "ever_dead": ever_dead, - "ever_remodeled": ever_remodeled, - "infection_to_death_min": infection_to_death, - "infection_to_remodel_min": infection_to_remodel, - "remodel_to_death_min": remodel_to_death, - } - ) - -events_df = pd.DataFrame(track_events) - -print(f"\n## Track Event Summary ({len(events_df)} tracks)") -print(f" Ever infected: {events_df['ever_infected'].sum()}") -print(f" Ever dead: {events_df['ever_dead'].sum()}") -print(f" Ever remodeled: {events_df['ever_remodeled'].sum()}") -print(f" Infected & dead: {(events_df['ever_infected'] & events_df['ever_dead']).sum()}") -print(f" Infected & remodeled: {(events_df['ever_infected'] & events_df['ever_remodeled']).sum()}") -print(f" All three: {(events_df['ever_infected'] & events_df['ever_dead'] & events_df['ever_remodeled']).sum()}") - -# %% -# =========================================================================== -# Step 4: Descriptive statistics -# =========================================================================== - -infected_tracks = events_df[events_df["ever_infected"]].copy() - -print("\n## Timing distributions (infected tracks only)") -for col_label, col in [ - ("Infection → Death", "infection_to_death_min"), - ("Infection → Remodel", "infection_to_remodel_min"), - ("Remodel → Death", 
"remodel_to_death_min"), -]: - valid = infected_tracks[col].dropna() - if len(valid) > 0: - print(f"\n **{col_label}** (n={len(valid)})") - print(f" median: {valid.median():.0f} min, mean: {valid.mean():.0f} min, std: {valid.std():.0f} min") - print(f" range: [{valid.min():.0f}, {valid.max():.0f}] min") - -# Compare death rates: infected vs uninfected -infected_dead = events_df["ever_infected"] & events_df["ever_dead"] -uninfected_dead = ~events_df["ever_infected"] & events_df["ever_dead"] -n_infected = events_df["ever_infected"].sum() -n_uninfected = (~events_df["ever_infected"]).sum() - -print("\n## Death rates") -print(f" Infected tracks: {infected_dead.sum()}/{n_infected} = {infected_dead.sum() / max(n_infected, 1):.1%}") -print( - f" Uninfected tracks: {uninfected_dead.sum()}/{n_uninfected} = {uninfected_dead.sum() / max(n_uninfected, 1):.1%}" -) - -if n_infected > 0 and n_uninfected > 0: - table = np.array( - [ - [infected_dead.sum(), n_infected - infected_dead.sum()], - [uninfected_dead.sum(), n_uninfected - uninfected_dead.sum()], - ] - ) - chi2, p_val, _, _ = stats.chi2_contingency(table) - print(f" Chi-squared: {chi2:.2f}, p={p_val:.4g}") - -# %% -# =========================================================================== -# Step 5: Correlation — infection_to_death vs infection_to_remodel -# =========================================================================== - -both = infected_tracks.dropna(subset=["infection_to_death_min", "infection_to_remodel_min"]).copy() - -print(f"\n## Correlation: Infection→Death vs Infection→Remodel (n={len(both)})") - -if len(both) >= 5: - r_pearson, p_pearson = stats.pearsonr(both["infection_to_remodel_min"], both["infection_to_death_min"]) - r_spearman, p_spearman = stats.spearmanr(both["infection_to_remodel_min"], both["infection_to_death_min"]) - print(f" Pearson r={r_pearson:.3f}, p={p_pearson:.4g}") - print(f" Spearman rho={r_spearman:.3f}, p={p_spearman:.4g}") - - # Bin tracks into early/late remodelers (median 
split) - median_remodel = both["infection_to_remodel_min"].median() - both["remodel_speed"] = np.where( - both["infection_to_remodel_min"] <= median_remodel, "early_remodel", "late_remodel" - ) - - for label, subdf in both.groupby("remodel_speed"): - death_times = subdf["infection_to_death_min"] - print( - f"\n {label} (n={len(subdf)}): death at median {death_times.median():.0f} min," - f" mean {death_times.mean():.0f} min" - ) - - early = both.loc[both["remodel_speed"] == "early_remodel", "infection_to_death_min"] - late = both.loc[both["remodel_speed"] == "late_remodel", "infection_to_death_min"] - if len(early) >= 3 and len(late) >= 3: - u_stat, u_p = stats.mannwhitneyu(early, late, alternative="two-sided") - print(f"\n Mann-Whitney U test (early vs late remodelers death time): U={u_stat:.0f}, p={u_p:.4g}") - -# %% -# =========================================================================== -# Step 6: Plots -# =========================================================================== - -RESULTS_DIR.mkdir(parents=True, exist_ok=True) - -fig, axes = plt.subplots(2, 2, figsize=(14, 12)) - -# --- Panel A: Scatter of infection→remodel vs infection→death --- -ax = axes[0, 0] -if len(both) >= 5: - ax.scatter( - both["infection_to_remodel_min"], - both["infection_to_death_min"], - alpha=0.4, - s=15, - edgecolors="none", - ) - # Regression line - slope, intercept, _, _, _ = stats.linregress(both["infection_to_remodel_min"], both["infection_to_death_min"]) - x_fit = np.linspace(both["infection_to_remodel_min"].min(), both["infection_to_remodel_min"].max(), 100) - ax.plot(x_fit, slope * x_fit + intercept, "r--", label=f"r={r_pearson:.2f}, p={p_pearson:.2g}") - ax.legend() -ax.set_xlabel("Infection → Remodel (min)") -ax.set_ylabel("Infection → Death (min)") -ax.set_title("A. 
Remodeling vs Death timing") - -# --- Panel B: Distribution of infection→death for infected vs all tracks --- -ax = axes[0, 1] -infected_death_times = infected_tracks["infection_to_death_min"].dropna() -if len(infected_death_times) > 0: - ax.hist(infected_death_times, bins=30, alpha=0.7, color="#d62728", edgecolor="white") -ax.set_xlabel("Infection → Death (min)") -ax.set_ylabel("Number of tracks") -ax.set_title("B. Time from infection to death") - -# --- Panel C: Death rate comparison --- -ax = axes[1, 0] -categories = ["Infected", "Uninfected"] -dead_counts = [infected_dead.sum(), uninfected_dead.sum()] -alive_counts = [n_infected - infected_dead.sum(), n_uninfected - uninfected_dead.sum()] -x = np.arange(len(categories)) -width = 0.35 -ax.bar(x - width / 2, dead_counts, width, label="Dead", color="#d62728") -ax.bar(x + width / 2, alive_counts, width, label="Alive", color="#2ca02c") -ax.set_xticks(x) -ax.set_xticklabels(categories) -ax.set_ylabel("Number of tracks") -ax.set_title("C. Death rates by infection status") -ax.legend() - -# --- Panel D: Boxplot of death timing by remodel speed --- -ax = axes[1, 1] -if len(both) >= 5: - early_vals = both.loc[both["remodel_speed"] == "early_remodel", "infection_to_death_min"].to_numpy() - late_vals = both.loc[both["remodel_speed"] == "late_remodel", "infection_to_death_min"].to_numpy() - bp = ax.boxplot( - [early_vals, late_vals], - labels=["Early remodelers", "Late remodelers"], - patch_artist=True, - ) - bp["boxes"][0].set_facecolor("#1f77b4") - bp["boxes"][1].set_facecolor("#ff7f0e") - ax.set_ylabel("Infection → Death (min)") - ax.set_title("D. 
Death timing by remodel speed") - -plt.tight_layout() -fig.savefig(RESULTS_DIR / "infection_death_remodeling.png", dpi=150, bbox_inches="tight") -fig.savefig(RESULTS_DIR / "infection_death_remodeling.pdf", bbox_inches="tight") -plt.show() -print(f"Saved to {RESULTS_DIR}") - -# %% -# =========================================================================== -# Step 7: Timeline heatmap — per-track state over time -# =========================================================================== - -# Show a sample of infected tracks with all 3 states over time -infected_tids = infected_tracks.sort_values("t_infection_min").head(50) -sample_keys = set(zip(infected_tids["fov_name"], infected_tids["track_id"])) - -sample = merged[merged.apply(lambda r: (r["fov_name"], r["track_id"]) in sample_keys, axis=1)].copy() - -if len(sample) > 0: - # Align to infection time - sample = sample.merge( - infected_tids[["fov_name", "track_id", "t_infection_min"]], - on=["fov_name", "track_id"], - ) - sample["t_rel"] = sample["t_minutes"] - sample["t_infection_min"] - - # Encode states as numeric for heatmap - sample["infection_num"] = (sample["predicted_infection_state"] == "infected").astype(int) - sample["death_num"] = (sample["predicted_cell_death_state"] == "dead").astype(int) - sample["remodel_num"] = (sample["predicted_organelle_state_g3bp1"] == "remodel").astype(int) - - fig, axes = plt.subplots(1, 3, figsize=(18, 8), sharey=True) - time_bins = np.arange(sample["t_rel"].min(), sample["t_rel"].max() + FRAME_INTERVAL_MINUTES, FRAME_INTERVAL_MINUTES) - - track_labels = [] - for i, ((fov, tid), _) in enumerate(infected_tids.iterrows()): - track_labels.append(f"{fov}:{tid}") - - for ax, (title, col) in zip( - axes, - [ - ("Infection", "infection_num"), - ("Death", "death_num"), - ("Remodeling", "remodel_num"), - ], - ): - # Pivot: rows=tracks, cols=time bins - track_list = list(zip(infected_tids["fov_name"], infected_tids["track_id"])) - matrix = np.full((len(track_list), len(time_bins) 
- 1), np.nan) - - for i, (fov, tid) in enumerate(track_list): - track_data = sample[(sample["fov_name"] == fov) & (sample["track_id"] == tid)] - for _, row in track_data.iterrows(): - bin_idx = np.searchsorted(time_bins, row["t_rel"]) - 1 - if 0 <= bin_idx < matrix.shape[1]: - matrix[i, bin_idx] = row[col] - - im = ax.imshow(matrix, aspect="auto", cmap="RdYlBu_r", vmin=0, vmax=1, interpolation="nearest") - ax.set_xlabel("Time relative to infection (min)") - ax.set_title(title) - - # Set x tick labels - n_ticks = min(10, len(time_bins)) - tick_positions = np.linspace(0, len(time_bins) - 2, n_ticks, dtype=int) - ax.set_xticks(tick_positions) - ax.set_xticklabels([f"{time_bins[t]:.0f}" for t in tick_positions], rotation=45) - - axes[0].set_ylabel("Tracks (sorted by infection time)") - plt.colorbar(im, ax=axes[-1], label="State (0=no, 1=yes)") - plt.tight_layout() - fig.savefig(RESULTS_DIR / "track_timeline_heatmap.png", dpi=150, bbox_inches="tight") - fig.savefig(RESULTS_DIR / "track_timeline_heatmap.pdf", bbox_inches="tight") - plt.show() - -# %% -# =========================================================================== -# Step 8: Save results -# =========================================================================== - -events_df.to_csv(RESULTS_DIR / "track_events.csv", index=False) -if len(both) > 0: - both.to_csv(RESULTS_DIR / "infected_remodeled_dead_tracks.csv", index=False) - -print(f"\nAll results saved to {RESULTS_DIR}") - -# %% diff --git a/applications/dynaclr/scripts/pseudotime/infection_onset_distribution.py b/applications/dynaclr/scripts/pseudotime/infection_onset_distribution.py deleted file mode 100644 index 276f3e99c..000000000 --- a/applications/dynaclr/scripts/pseudotime/infection_onset_distribution.py +++ /dev/null @@ -1,1028 +0,0 @@ -# %% -""" -Infection onset timing distribution and phenotype binning. 
- -Measures the absolute time from experiment start to first infection -(T_perturbation) per track, then bins cells by early/mid/late infection -to compare downstream phenotype responses (death, remodeling). - -Supports both annotation-based and prediction-based infection timing. - -Usage: Run as a Jupyter-compatible script (# %% cell markers). -""" - -from pathlib import Path - -import anndata as ad -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -from scipy import stats - -# %% -# =========================================================================== -# Configuration -# =========================================================================== - -ANNOTATIONS_ROOT = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") -EMBEDDINGS_ROOT = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics") - -# All experiments start at 3 HPI (hours post-infection). -# t=0 in the data corresponds to 3 HPI, so absolute HPI = t_minutes/60 + T_OFFSET_HPI. 
-T_OFFSET_HPI = 3.0 - -EXPERIMENTS = { - "G3BP1 (Stress Granule)": { - "datasets": [ - { - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "fov_pattern": "C/2", - "frame_interval_minutes": 30, - "label": "2025_07_24 ZIKV", - }, - { - "annotations_path": ANNOTATIONS_ROOT - / "2025_01_24_A549_G3BP1_DENV" - / "2025_01_24_A549_G3BP1_DENV_combined_annotations.csv", - "embeddings_path": EMBEDDINGS_ROOT - / "2025_01_24_A549_G3BP1_DENV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "fov_pattern": "C/2", - "frame_interval_minutes": 10, - "label": "2025_01_24 DENV", - }, - ], - "remodel_task": "organelle_state_g3bp1", - "remodel_ann_col": "organelle_state", - "remodel_positive": "remodel", - }, - "SEC61B (ER)": { - "datasets": [ - { - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "fov_pattern": "A/2", - "frame_interval_minutes": 30, - "label": "2025_07_24 ZIKV", - }, - ], - "remodel_task": "organelle_state_sec61b", - "remodel_ann_col": "organelle_state", - "remodel_positive": "remodel", - }, -} - -MIN_TRACK_TIMEPOINTS = 10 - -# Smoothing: require N consecutive frames of a state before calling it a true event. -# Set to 1 to disable (raw first-frame detection). 
-MIN_CONSECUTIVE_FRAMES = 3 - -# Binning strategy: terciles by default, or custom edges -N_BINS = 3 - -RESULTS_DIR = Path(__file__).parent / "results" / "infection_onset_distribution" - -SAVE_FIGURES = False - -# %% -# =========================================================================== -# Step 1: Helper — extract per-track events from annotations -# =========================================================================== - - -def extract_annotation_events( - ann_df: pd.DataFrame, - fov_pattern: str, - frame_interval: float, - remodel_col: str = "organelle_state", - remodel_positive: str = "remodel", -) -> pd.DataFrame: - """Extract per-track first-event timings from annotation CSV.""" - filtered = ann_df[ann_df["fov_name"].astype(str).str.startswith(fov_pattern)].copy() - has_division = "cell_division_state" in filtered.columns - rows = [] - for (fov, tid), g in filtered.groupby(["fov_name", "track_id"]): - if len(g) < MIN_TRACK_TIMEPOINTS: - continue - t_start, t_end = g["t"].min(), g["t"].max() - inf = g[g["infection_state"] == "infected"] - dead = g[g["cell_death_state"] == "dead"] - remodel = g[g[remodel_col] == remodel_positive] - - t_infection = inf["t"].min() if len(inf) > 0 else None - t_death = dead["t"].min() if len(dead) > 0 else None - t_remodel = remodel["t"].min() if len(remodel) > 0 else None - - t_division = None - if has_division: - mitosis = g[g["cell_division_state"] == "mitosis"] - t_division = mitosis["t"].min() if len(mitosis) > 0 else None - - rows.append( - { - "fov_name": fov, - "track_id": tid, - "source": "annotation", - "t_track_start": t_start * frame_interval, - "t_track_end": t_end * frame_interval, - "track_duration_min": (t_end - t_start) * frame_interval, - "t_infection_min": (t_infection * frame_interval if t_infection is not None else None), - "t_death_min": (t_death * frame_interval if t_death is not None else None), - "t_remodel_min": (t_remodel * frame_interval if t_remodel is not None else None), - 
"t_division_min": (t_division * frame_interval if t_division is not None else None), - "ever_infected": t_infection is not None, - "ever_dead": t_death is not None, - "ever_remodeled": t_remodel is not None, - "ever_divided": t_division is not None, - } - ) - return pd.DataFrame(rows) - - -# %% -# =========================================================================== -# Step 2: Helper — extract per-track events from predictions -# =========================================================================== - - -def _first_consecutive_event( - sorted_t: np.ndarray, - is_positive: np.ndarray, - min_consecutive: int, -) -> float | None: - """Return the t value where min_consecutive consecutive positive frames first occur.""" - if min_consecutive <= 1: - positives = sorted_t[is_positive] - return float(positives[0]) if len(positives) > 0 else None - - run = 0 - for i, pos in enumerate(is_positive): - if pos: - run += 1 - if run >= min_consecutive: - return float(sorted_t[i - min_consecutive + 1]) - else: - run = 0 - return None - - -def extract_prediction_events( - embeddings_path: Path, - fov_pattern: str, - frame_interval: float, - remodel_task: str = "organelle_state_g3bp1", - remodel_positive: str = "remodel", -) -> pd.DataFrame: - """Extract per-track first-event timings from sensor + organelle + phase zarrs.""" - sensor = ad.read_zarr(embeddings_path / "timeaware_sensor_160patch_104ckpt.zarr") - organelle = ad.read_zarr(embeddings_path / "timeaware_organelle_160patch_104ckpt.zarr") - phase = ad.read_zarr(embeddings_path / "timeaware_phase_160patch_104ckpt.zarr") - - sensor_obs = sensor.obs[sensor.obs["fov_name"].astype(str).str.startswith(fov_pattern)].copy() - organelle_obs = organelle.obs[organelle.obs["fov_name"].astype(str).str.startswith(fov_pattern)].copy() - phase_obs = phase.obs[phase.obs["fov_name"].astype(str).str.startswith(fov_pattern)].copy() - - merge_keys = ["fov_name", "track_id", "t"] - pred_remodel_col = f"predicted_{remodel_task}" - - # 
Check if phase has division predictions - has_division = "predicted_cell_division_state" in phase_obs.columns - - merged = sensor_obs[merge_keys + ["predicted_infection_state", "predicted_cell_death_state"]].merge( - organelle_obs[merge_keys + [pred_remodel_col]], - on=merge_keys, - how="inner", - ) - if has_division: - merged = merged.merge( - phase_obs[merge_keys + ["predicted_cell_division_state"]], - on=merge_keys, - how="inner", - ) - - rows = [] - for (fov, tid), g in merged.groupby(["fov_name", "track_id"]): - if len(g) < MIN_TRACK_TIMEPOINTS: - continue - g = g.sort_values("t") - t_start, t_end = g["t"].min(), g["t"].max() - - sorted_t = g["t"].to_numpy() - t_infection = _first_consecutive_event( - sorted_t, - (g["predicted_infection_state"] == "infected").values, - MIN_CONSECUTIVE_FRAMES, - ) - t_death = _first_consecutive_event( - sorted_t, - (g["predicted_cell_death_state"] == "dead").values, - MIN_CONSECUTIVE_FRAMES, - ) - t_remodel = _first_consecutive_event( - sorted_t, - (g[pred_remodel_col] == remodel_positive).values, - MIN_CONSECUTIVE_FRAMES, - ) - t_division = None - if has_division: - t_division = _first_consecutive_event( - sorted_t, - (g["predicted_cell_division_state"] == "mitosis").values, - MIN_CONSECUTIVE_FRAMES, - ) - - rows.append( - { - "fov_name": fov, - "track_id": tid, - "source": "prediction", - "t_track_start": t_start * frame_interval, - "t_track_end": t_end * frame_interval, - "track_duration_min": (t_end - t_start) * frame_interval, - "t_infection_min": (t_infection * frame_interval if t_infection is not None else None), - "t_death_min": (t_death * frame_interval if t_death is not None else None), - "t_remodel_min": (t_remodel * frame_interval if t_remodel is not None else None), - "t_division_min": (t_division * frame_interval if t_division is not None else None), - "ever_infected": t_infection is not None, - "ever_dead": t_death is not None, - "ever_remodeled": t_remodel is not None, - "ever_divided": t_division is not None, - 
} - ) - return pd.DataFrame(rows) - - -# %% -# =========================================================================== -# Step 3: Process all experiments (multiple datasets per organelle) -# =========================================================================== - -all_results = {} - -for exp_name, cfg in EXPERIMENTS.items(): - print(f"\n{'=' * 60}") - print(f" {exp_name}") - print(f"{'=' * 60}") - - all_ann_events = [] - all_pred_events = [] - - for ds in cfg["datasets"]: - print(f"\n Dataset: {ds['label']}") - - ann_df = pd.read_csv(ds["annotations_path"]) - ann_ev = extract_annotation_events( - ann_df, - fov_pattern=ds["fov_pattern"], - frame_interval=ds["frame_interval_minutes"], - remodel_col=cfg["remodel_ann_col"], - remodel_positive=cfg["remodel_positive"], - ) - ann_ev["dataset"] = ds["label"] - all_ann_events.append(ann_ev) - print(f" Annotation: {len(ann_ev)} tracks, {ann_ev['ever_infected'].sum()} infected") - - pred_ev = extract_prediction_events( - embeddings_path=ds["embeddings_path"], - fov_pattern=ds["fov_pattern"], - frame_interval=ds["frame_interval_minutes"], - remodel_task=cfg["remodel_task"], - remodel_positive=cfg["remodel_positive"], - ) - pred_ev["dataset"] = ds["label"] - all_pred_events.append(pred_ev) - print(f" Prediction: {len(pred_ev)} tracks, {pred_ev['ever_infected'].sum()} infected") - - ann_events_df = pd.concat(all_ann_events, ignore_index=True) - pred_events_df = pd.concat(all_pred_events, ignore_index=True) - - # Convert to HPI (hours post-inoculation) - for df in [ann_events_df, pred_events_df]: - df["t_infection_hpi"] = df["t_infection_min"] / 60 + T_OFFSET_HPI - df["t_death_hpi"] = df["t_death_min"] / 60 + T_OFFSET_HPI - df["t_remodel_hpi"] = df["t_remodel_min"] / 60 + T_OFFSET_HPI - df["t_division_hpi"] = df["t_division_min"] / 60 + T_OFFSET_HPI - - print(f"\n Combined annotation: {len(ann_events_df)} tracks, {ann_events_df['ever_infected'].sum()} infected") - print(f" Combined prediction: {len(pred_events_df)} 
tracks, {pred_events_df['ever_infected'].sum()} infected") - - all_results[exp_name] = { - "cfg": cfg, - "ann_events_df": ann_events_df, - "pred_events_df": pred_events_df, - } - -# %% -# =========================================================================== -# Step 4: Bin infected tracks by infection onset time -# =========================================================================== - - -def bin_and_analyze(events_df: pd.DataFrame, source_label: str) -> pd.DataFrame: - """Bin infected tracks by T_infection terciles and summarize phenotypes.""" - infected = events_df[events_df["ever_infected"]].copy() - if len(infected) < N_BINS: - print(f" Too few infected tracks ({len(infected)}) for {N_BINS} bins") - return infected - - # Tercile binning — labels in HPI (hours post-inoculation) - _, bin_edges = pd.qcut(infected["t_infection_hpi"], q=N_BINS, retbins=True) - bin_labels = [f"{bin_edges[i]:.1f}–{bin_edges[i + 1]:.1f} HPI" for i in range(len(bin_edges) - 1)] - infected["infection_bin"] = pd.qcut( - infected["t_infection_hpi"], - q=N_BINS, - labels=bin_labels, - ) - - print(f"\n## {source_label}: Translocation onset bins") - print(f" Bin edges (HPI): {[f'{e:.1f}' for e in bin_edges]}") - print(f" Labels: {bin_labels}") - - has_division = "ever_divided" in infected.columns - - for bin_label in bin_labels: - subset = infected[infected["infection_bin"] == bin_label] - n = len(subset) - n_dead = subset["ever_dead"].sum() - n_remodel = subset["ever_remodeled"].sum() - - print( - f"\n **{bin_label}** (n={n}, T_inf range: " - f"{subset['t_infection_min'].min():.0f}-{subset['t_infection_min'].max():.0f} min)" - ) - print(f" Death rate: {n_dead}/{n} = {n_dead / max(n, 1):.1%}") - print(f" Remodel rate: {n_remodel}/{n} = {n_remodel / max(n, 1):.1%}") - - if has_division: - n_divided = subset["ever_divided"].sum() - print(f" Division rate: {n_divided}/{n} = {n_divided / max(n, 1):.1%}") - - # Time from infection to death/remodel for those that have it - both_dead = 
subset[subset["ever_dead"]].copy() - if len(both_dead) > 0: - dt = both_dead["t_death_min"] - both_dead["t_infection_min"] - print( - f" Translocation→Death: median={dt.median():.0f} min, mean={dt.mean():.0f} min (n={len(both_dead)})" - ) - - both_remodel = subset[subset["ever_remodeled"]].copy() - if len(both_remodel) > 0: - dt = both_remodel["t_remodel_min"] - both_remodel["t_infection_min"] - print( - f" Translocation→Remodel: median={dt.median():.0f} min," - f" mean={dt.mean():.0f} min (n={len(both_remodel)})" - ) - - if has_division: - both_divided = subset[subset["ever_divided"]].copy() - if len(both_divided) > 0: - dt = both_divided["t_division_min"] - both_divided["t_infection_min"] - print( - f" Translocation→Division: median={dt.median():.0f} min," - f" mean={dt.mean():.0f} min (n={len(both_divided)})" - ) - - # Kruskal-Wallis across bins for infection→death, infection→remodel, infection→division - event_tests = [ - ("Translocation→Death", "t_death_min"), - ("Translocation→Remodel", "t_remodel_min"), - ] - if has_division: - event_tests.append(("Translocation→Division", "t_division_min")) - for event_label, event_col in event_tests: - infected_with_event = infected.dropna(subset=[event_col]).copy() - infected_with_event["delta"] = infected_with_event[event_col] - infected_with_event["t_infection_min"] - groups = [g["delta"].to_numpy() for _, g in infected_with_event.groupby("infection_bin") if len(g) >= 2] - if len(groups) >= 2: - h_stat, h_p = stats.kruskal(*groups) - print(f"\n Kruskal-Wallis ({event_label} across bins): H={h_stat:.2f}, p={h_p:.4g}") - - return infected - - -for exp_name, res in all_results.items(): - ann_binned = bin_and_analyze(res["ann_events_df"], f"{exp_name} (Annotation)") - pred_binned = bin_and_analyze(res["pred_events_df"], f"{exp_name} (Prediction)") - res["ann_binned"] = ann_binned - res["pred_binned"] = pred_binned - -# %% -# =========================================================================== -# Step 5: Plots — per 
experiment: onset distribution + response histograms -# =========================================================================== - -if SAVE_FIGURES: - RESULTS_DIR.mkdir(parents=True, exist_ok=True) - -BIN_COLORS = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"] - - -def _plot_kde_by_bin(ax, binned_df, event_col, delta_label): - """Plot KDE curves of response time per infection bin.""" - if "infection_bin" not in binned_df.columns: - return - categories = binned_df["infection_bin"].cat.categories - for i, bin_label in enumerate(categories): - subset = binned_df[binned_df["infection_bin"] == bin_label] - dt = (subset[event_col] - subset["t_infection_min"]).dropna() - if len(dt) >= 3: - from scipy.stats import gaussian_kde - - kde = gaussian_kde(dt, bw_method="scott") - x_grid = np.linspace(dt.min() - 30, dt.max() + 30, 200) - ax.plot(x_grid, kde(x_grid), color=BIN_COLORS[i % len(BIN_COLORS)], linewidth=2) - ax.fill_between( - x_grid, - kde(x_grid), - alpha=0.15, - color=BIN_COLORS[i % len(BIN_COLORS)], - label=f"{bin_label} (n={len(dt)})", - ) - elif len(dt) > 0: - ax.axvline( - dt.median(), - color=BIN_COLORS[i % len(BIN_COLORS)], - linestyle="--", - label=f"{bin_label} (n={len(dt)})", - ) - ax.legend(fontsize=8) - ax.set_xlabel(f"{delta_label} (min)") - ax.set_ylabel("Density") - - -for exp_name, res in all_results.items(): - ann_infected = res["ann_events_df"][res["ann_events_df"]["ever_infected"]] - pred_infected = res["pred_events_df"][res["pred_events_df"]["ever_infected"]] - ann_binned = res["ann_binned"] - pred_binned = res["pred_binned"] - - fig, axes = plt.subplots(2, 4, figsize=(24, 10)) - fig.suptitle(exp_name, fontsize=14, fontweight="bold") - - # --- Row 1: Annotation-based --- - ax = axes[0, 0] - if len(ann_infected) > 0: - ax.hist( - ann_infected["t_infection_hpi"], - bins=20, - alpha=0.7, - color="#1f77b4", - edgecolor="white", - ) - ax.set_xlabel("T_infection (HPI)") - ax.set_ylabel("Number of tracks") - ax.set_title("A. 
Annotation: infection onset") - - for ax, (delta_label, event_col, panel) in zip( - [axes[0, 1], axes[0, 2], axes[0, 3]], - [ - ("Translocation → Death", "t_death_min", "B"), - ("Translocation → Remodel", "t_remodel_min", "C"), - ("Translocation → Division", "t_division_min", "D"), - ], - ): - _plot_kde_by_bin(ax, ann_binned, event_col, delta_label) - ax.set_title(f"{panel}. Annotation: {delta_label}") - - # --- Row 2: Prediction-based --- - ax = axes[1, 0] - if len(pred_infected) > 0: - ax.hist( - pred_infected["t_infection_hpi"], - bins=30, - alpha=0.7, - color="#ff7f0e", - edgecolor="white", - ) - ax.set_xlabel("T_infection (HPI)") - ax.set_ylabel("Number of tracks") - ax.set_title("E. Prediction: infection onset") - - for ax, (delta_label, event_col, panel) in zip( - [axes[1, 1], axes[1, 2], axes[1, 3]], - [ - ("Translocation → Death", "t_death_min", "F"), - ("Translocation → Remodel", "t_remodel_min", "G"), - ("Translocation → Division", "t_division_min", "H"), - ], - ): - _plot_kde_by_bin(ax, pred_binned, event_col, delta_label) - ax.set_title(f"{panel}. Prediction: {delta_label}") - - plt.tight_layout() - if SAVE_FIGURES: - prefix = exp_name.replace(" ", "_").replace("(", "").replace(")", "") - fig.savefig(RESULTS_DIR / f"{prefix}_onset_binning.png", dpi=150, bbox_inches="tight") - fig.savefig(RESULTS_DIR / f"{prefix}_onset_binning.pdf", bbox_inches="tight") - plt.show() - -# %% -# =========================================================================== -# Step 7: Response time comparison — are elapsed times the same across bins? 
-# =========================================================================== - - -def plot_response_time_comparison( - binned_df: pd.DataFrame, - source_label: str, - output_dir: Path, -) -> None: - """Boxplot + swarm of response times per infection bin with pairwise tests.""" - if "infection_bin" not in binned_df.columns: - return - - # Compute deltas - binned_df = binned_df.copy() - binned_df["infection_to_death"] = binned_df["t_death_min"] - binned_df["t_infection_min"] - binned_df["infection_to_remodel"] = binned_df["t_remodel_min"] - binned_df["t_infection_min"] - has_division = "t_division_min" in binned_df.columns - if has_division: - binned_df["infection_to_division"] = binned_df["t_division_min"] - binned_df["t_infection_min"] - - n_panels = 4 if has_division else 3 - fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6)) - - bin_categories = list(binned_df["infection_bin"].cat.categories) - - # --- Response time boxplots --- - boxplot_items = [ - ("infection_to_death", "Translocation → Death (min)", "Death"), - ("infection_to_remodel", "Translocation → Remodel (min)", "Remodel"), - ] - if has_division: - boxplot_items.append(("infection_to_division", "Translocation → Division (min)", "Division")) - for ax, (delta_col, ylabel, title_suffix) in zip( - axes[: len(boxplot_items)], - boxplot_items, - ): - plot_data = [] - positions = [] - tick_labels = [] - bin_names = [] - for i, bin_label in enumerate(bin_categories): - vals = binned_df.loc[binned_df["infection_bin"] == bin_label, delta_col].dropna() - if len(vals) > 0: - plot_data.append(vals.values) - positions.append(i) - tick_labels.append(f"{bin_label}\n(n={len(vals)})") - bin_names.append(bin_label) - - if len(plot_data) == 0: - ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes) - ax.set_title(f"{source_label}: {title_suffix}") - continue - - bp = ax.boxplot(plot_data, positions=positions, patch_artist=True, widths=0.5) - colors = ["#1f77b4", "#ff7f0e", 
"#2ca02c", "#d62728", "#9467bd"] - for patch, color in zip(bp["boxes"], colors[: len(plot_data)]): - patch.set_facecolor(color) - patch.set_alpha(0.6) - - # Overlay individual points - for pos, vals in zip(positions, plot_data): - jitter = np.random.default_rng(42).uniform(-0.12, 0.12, len(vals)) - ax.scatter(pos + jitter, vals, alpha=0.4, s=12, color="black", zorder=3) - - ax.set_xticks(positions) - ax.set_xticklabels(tick_labels) - ax.set_ylabel(ylabel) - ax.set_title(f"{source_label}: {title_suffix} response time") - ax.set_xlabel("Translocation onset bin") - - # Pairwise Mann-Whitney U tests - test_results = [] - for i in range(len(plot_data)): - for j in range(i + 1, len(plot_data)): - if len(plot_data[i]) >= 3 and len(plot_data[j]) >= 3: - u_stat, u_p = stats.mannwhitneyu(plot_data[i], plot_data[j], alternative="two-sided") - test_results.append(f"{bin_names[i]} vs {bin_names[j]}: p={u_p:.4g}") - - if test_results: - test_text = "\n".join(test_results) - ax.text( - 0.98, - 0.98, - test_text, - transform=ax.transAxes, - ha="right", - va="top", - fontsize=8, - family="monospace", - bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.5), - ) - - # --- Phenotype rates per bin --- - ax = axes[-1] - rates = [] - for bin_label in bin_categories: - subset = binned_df[binned_df["infection_bin"] == bin_label] - n = len(subset) - row_dict = { - "bin": bin_label, - "death_rate": subset["ever_dead"].sum() / max(n, 1), - "remodel_rate": subset["ever_remodeled"].sum() / max(n, 1), - "n": n, - } - if has_division: - row_dict["division_rate"] = subset["ever_divided"].sum() / max(n, 1) - rates.append(row_dict) - rates_df = pd.DataFrame(rates) - - x = np.arange(len(bin_categories)) - n_bars = 3 if has_division else 2 - width = 0.8 / n_bars - ax.bar( - x - width, - rates_df["death_rate"], - width, - label="Death rate", - color="#d62728", - alpha=0.7, - ) - ax.bar( - x, - rates_df["remodel_rate"], - width, - label="Remodel rate", - color="#1f77b4", - alpha=0.7, - ) - 
if has_division: - ax.bar( - x + width, - rates_df["division_rate"], - width, - label="Division rate", - color="#2ca02c", - alpha=0.7, - ) - for i, row in rates_df.iterrows(): - max_rate = max(row["death_rate"], row["remodel_rate"]) - if has_division: - max_rate = max(max_rate, row["division_rate"]) - ax.text( - i, - max_rate + 0.02, - f"n={row['n']}", - ha="center", - fontsize=9, - ) - ax.set_xticks(x) - ax.set_xticklabels(bin_categories, rotation=15, ha="right") - ax.set_ylabel("Fraction of tracks") - ax.set_title(f"{source_label}: phenotype rates by bin") - ax.legend() - ax.set_ylim(0, 1.1) - - plt.tight_layout() - if SAVE_FIGURES: - prefix = source_label.lower().replace(" ", "_") - fig.savefig( - output_dir / f"{prefix}_response_time_comparison.png", - dpi=150, - bbox_inches="tight", - ) - fig.savefig(output_dir / f"{prefix}_response_time_comparison.pdf", bbox_inches="tight") - plt.show() - - # Print summary table - print(f"\n## {source_label}: Response time summary (median min)") - summary_rows = [] - for bin_label in bin_categories: - subset = binned_df[binned_df["infection_bin"] == bin_label] - death_dt = subset["infection_to_death"].dropna() - remodel_dt = subset["infection_to_remodel"].dropna() - row_dict = { - "bin": bin_label, - "n_tracks": len(subset), - "transloc→death median": (f"{death_dt.median():.0f}" if len(death_dt) > 0 else "—"), - "transloc→death n": len(death_dt), - "transloc→remodel median": (f"{remodel_dt.median():.0f}" if len(remodel_dt) > 0 else "—"), - "transloc→remodel n": len(remodel_dt), - } - if has_division: - division_dt = subset["infection_to_division"].dropna() - row_dict["transloc→division median"] = f"{division_dt.median():.0f}" if len(division_dt) > 0 else "—" - row_dict["transloc→division n"] = len(division_dt) - summary_rows.append(row_dict) - print(pd.DataFrame(summary_rows).to_string(index=False)) - - -for exp_name, res in all_results.items(): - plot_response_time_comparison(res["pred_binned"], f"{exp_name} (Prediction)", 
RESULTS_DIR) - plot_response_time_comparison(res["ann_binned"], f"{exp_name} (Annotation)", RESULTS_DIR) - -# %% -# =========================================================================== -# Step 7a: Continuous scatter — HPI vs response time (no binning) -# =========================================================================== - - -def plot_hpi_vs_response( - events_df: pd.DataFrame, - source_label: str, - output_dir: Path, -) -> None: - """Scatter plot of translocation onset (HPI) vs response time with regression.""" - infected = events_df[events_df["ever_infected"]].copy() - if len(infected) < 5: - print(f" {source_label}: too few infected tracks ({len(infected)}) for scatter") - return - - infected["infection_to_death"] = infected["t_death_min"] - infected["t_infection_min"] - infected["infection_to_remodel"] = infected["t_remodel_min"] - infected["t_infection_min"] - - response_items = [ - ("infection_to_death", "Transloc → Death (min)"), - ("infection_to_remodel", "Transloc → Remodel (min)"), - ] - has_division = "t_division_min" in infected.columns - if has_division: - infected["infection_to_division"] = infected["t_division_min"] - infected["t_infection_min"] - response_items.append(("infection_to_division", "Transloc → Division (min)")) - - n_panels = len(response_items) - fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5)) - if n_panels == 1: - axes = [axes] - fig.suptitle( - f"{source_label}: T_translocation vs response time", - fontsize=14, - fontweight="bold", - ) - - for ax, (delta_col, xlabel) in zip(axes, response_items): - valid = infected.dropna(subset=[delta_col]) - x = valid[delta_col].to_numpy() - y = valid["t_infection_hpi"].to_numpy() - - if len(x) < 3: - ax.text( - 0.5, - 0.5, - f"n={len(x)}", - ha="center", - va="center", - transform=ax.transAxes, - ) - ax.set_xlabel(xlabel) - ax.set_ylabel("T_translocation (HPI)") - continue - - # Color by division status if available - if has_division and "ever_divided" in 
valid.columns: - divided_mask = valid["ever_divided"].to_numpy() - ax.scatter( - x[~divided_mask], - y[~divided_mask], - alpha=0.5, - s=20, - color="#1f77b4", - label="No division", - zorder=2, - ) - ax.scatter( - x[divided_mask], - y[divided_mask], - alpha=0.7, - s=30, - color="#2ca02c", - marker="^", - label="Divided", - zorder=3, - ) - ax.legend(fontsize=8) - else: - ax.scatter(x, y, alpha=0.5, s=20, color="#1f77b4", zorder=2) - - ax.text( - 0.03, - 0.97, - f"n={len(x)}", - transform=ax.transAxes, - ha="left", - va="top", - fontsize=9, - family="monospace", - bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.5), - ) - - ax.set_xlabel(xlabel) - ax.set_ylabel("T_translocation (HPI)") - - plt.tight_layout() - if SAVE_FIGURES: - prefix = source_label.lower().replace(" ", "_") - fig.savefig( - output_dir / f"{prefix}_hpi_vs_response.png", - dpi=150, - bbox_inches="tight", - ) - fig.savefig( - output_dir / f"{prefix}_hpi_vs_response.pdf", - bbox_inches="tight", - ) - plt.show() - - -for exp_name, res in all_results.items(): - plot_hpi_vs_response(res["pred_events_df"], f"{exp_name} (Prediction)", RESULTS_DIR) - plot_hpi_vs_response(res["ann_events_df"], f"{exp_name} (Annotation)", RESULTS_DIR) - -# %% -# =========================================================================== -# Step 7b: Division confound analysis — do divided cells respond faster? -# =========================================================================== - - -def plot_division_confound( - binned_df: pd.DataFrame, - source_label: str, - output_dir: Path, -) -> None: - """Compare response times between divided and non-divided cells. - - Tests whether cells that underwent mitosis have shorter - translocation→death or translocation→remodel times, which would - indicate division is a confound for the observed phenotype timing. 
- """ - if "ever_divided" not in binned_df.columns: - return - if "infection_bin" not in binned_df.columns: - return - - binned_df = binned_df.copy() - binned_df["infection_to_death"] = binned_df["t_death_min"] - binned_df["t_infection_min"] - binned_df["infection_to_remodel"] = binned_df["t_remodel_min"] - binned_df["t_infection_min"] - binned_df["division_label"] = binned_df["ever_divided"].map({True: "Divided", False: "No division"}) - - bin_categories = list(binned_df["infection_bin"].cat.categories) - response_cols = [ - ("infection_to_death", "Transloc → Death (min)"), - ("infection_to_remodel", "Transloc → Remodel (min)"), - ] - - # --- Figure 1: Boxplots stratified by division within each bin --- - fig, axes = plt.subplots( - len(response_cols), - len(bin_categories), - figsize=(6 * len(bin_categories), 5 * len(response_cols)), - squeeze=False, - ) - fig.suptitle( - f"{source_label}: Response times — Divided vs Not divided", - fontsize=14, - fontweight="bold", - ) - - for row_idx, (delta_col, ylabel) in enumerate(response_cols): - for col_idx, bin_label in enumerate(bin_categories): - ax = axes[row_idx, col_idx] - subset = binned_df[binned_df["infection_bin"] == bin_label].dropna(subset=[delta_col]) - divided = subset[subset["ever_divided"]][delta_col] - not_divided = subset[~subset["ever_divided"]][delta_col] - - plot_data = [] - labels = [] - colors_box = [] - if len(not_divided) > 0: - plot_data.append(not_divided.values) - labels.append(f"No div\n(n={len(not_divided)})") - colors_box.append("#1f77b4") - if len(divided) > 0: - plot_data.append(divided.values) - labels.append(f"Divided\n(n={len(divided)})") - colors_box.append("#2ca02c") - - if len(plot_data) == 0: - ax.text( - 0.5, - 0.5, - "No data", - ha="center", - va="center", - transform=ax.transAxes, - ) - else: - bp = ax.boxplot( - plot_data, - patch_artist=True, - widths=0.5, - ) - for patch, c in zip(bp["boxes"], colors_box): - patch.set_facecolor(c) - patch.set_alpha(0.6) - for pos, vals in 
enumerate(plot_data, 1): - jitter = np.random.default_rng(42).uniform(-0.1, 0.1, len(vals)) - ax.scatter( - pos + jitter, - vals, - alpha=0.4, - s=12, - color="black", - zorder=3, - ) - ax.set_xticklabels(labels) - - # Mann-Whitney if both groups have enough data - if len(divided) >= 3 and len(not_divided) >= 3: - _, p = stats.mannwhitneyu(not_divided, divided, alternative="two-sided") - ax.set_title(f"{bin_label}\np={p:.4g}", fontsize=10) - else: - ax.set_title(bin_label, fontsize=10) - - if col_idx == 0: - ax.set_ylabel(ylabel) - - plt.tight_layout() - if SAVE_FIGURES: - prefix = source_label.lower().replace(" ", "_") - fig.savefig( - output_dir / f"{prefix}_division_confound.png", - dpi=150, - bbox_inches="tight", - ) - fig.savefig( - output_dir / f"{prefix}_division_confound.pdf", - bbox_inches="tight", - ) - plt.show() - - # --- Figure 2: Was division before or after translocation? --- - infected_divided = binned_df[binned_df["ever_divided"]].dropna(subset=["t_division_min"]) - if len(infected_divided) > 0: - infected_divided = infected_divided.copy() - infected_divided["division_relative_to_transloc"] = ( - infected_divided["t_division_min"] - infected_divided["t_infection_min"] - ) - n_before = (infected_divided["division_relative_to_transloc"] < 0).sum() - n_after = (infected_divided["division_relative_to_transloc"] >= 0).sum() - median_dt = infected_divided["division_relative_to_transloc"].median() - - print(f"\n## {source_label}: Division timing relative to translocation") - print(f" Divided before translocation: {n_before}/{len(infected_divided)}") - print(f" Divided after translocation: {n_after}/{len(infected_divided)}") - print(f" Median division–translocation gap: {median_dt:.0f} min") - - # Per-bin breakdown - for bin_label in bin_categories: - sub = infected_divided[infected_divided["infection_bin"] == bin_label] - if len(sub) > 0: - n_b = (sub["division_relative_to_transloc"] < 0).sum() - n_a = (sub["division_relative_to_transloc"] >= 0).sum() - 
print( - f" {bin_label}: {n_b} before, {n_a} after transloc " - f"(median gap: {sub['division_relative_to_transloc'].median():.0f} min)" - ) - - # --- Summary: overall Mann-Whitney (pooled across bins) --- - print(f"\n## {source_label}: Pooled divided vs not-divided response times") - for delta_col, label in response_cols: - valid = binned_df.dropna(subset=[delta_col]) - div_vals = valid[valid["ever_divided"]][delta_col] - nodiv_vals = valid[~valid["ever_divided"]][delta_col] - if len(div_vals) >= 3 and len(nodiv_vals) >= 3: - _, p = stats.mannwhitneyu(nodiv_vals, div_vals, alternative="two-sided") - print( - f" {label}: no-div median={nodiv_vals.median():.0f} min (n={len(nodiv_vals)}), " - f"div median={div_vals.median():.0f} min (n={len(div_vals)}), " - f"p={p:.4g}" - ) - else: - print(f" {label}: no-div n={len(nodiv_vals)}, div n={len(div_vals)} — too few for test") - - -for exp_name, res in all_results.items(): - plot_division_confound(res["pred_binned"], f"{exp_name} (Prediction)", RESULTS_DIR) - plot_division_confound(res["ann_binned"], f"{exp_name} (Annotation)", RESULTS_DIR) - -# %% -# =========================================================================== -# Step 8: Save CSVs -# =========================================================================== - -if SAVE_FIGURES: - RESULTS_DIR.mkdir(parents=True, exist_ok=True) - for exp_name, res in all_results.items(): - prefix = exp_name.replace(" ", "_").replace("(", "").replace(")", "") - res["ann_events_df"].to_csv(RESULTS_DIR / f"{prefix}_annotation_events.csv", index=False) - res["pred_events_df"].to_csv(RESULTS_DIR / f"{prefix}_prediction_events.csv", index=False) - - if "infection_bin" in res["ann_binned"].columns: - res["ann_binned"].to_csv(RESULTS_DIR / f"{prefix}_annotation_binned.csv", index=False) - if "infection_bin" in res["pred_binned"].columns: - res["pred_binned"].to_csv(RESULTS_DIR / f"{prefix}_prediction_binned.csv", index=False) - - print(f"\nAll results saved to {RESULTS_DIR}") - -# %% 
diff --git a/applications/dynaclr/scripts/pseudotime/prediction_remodeling.py b/applications/dynaclr/scripts/pseudotime/prediction_remodeling.py deleted file mode 100644 index 0f7a426e1..000000000 --- a/applications/dynaclr/scripts/pseudotime/prediction_remodeling.py +++ /dev/null @@ -1,355 +0,0 @@ -# %% -""" -Prediction-based organelle remodeling analysis. - -Measures remodeling timing using classifier predictions -(predicted_organelle_state in AnnData) instead of human annotations. - -Pipeline: alignment → prediction signal → aggregation → metrics → plotting - -Usage: Run as a Jupyter-compatible script (# %% cell markers). -""" - -import glob -from pathlib import Path - -import anndata as ad -import numpy as np -import pandas as pd - -from dynaclr.evaluation.pseudotime.alignment import align_tracks -from dynaclr.evaluation.pseudotime.metrics import ( - aggregate_population, - compute_track_timing, - find_half_max_time, - find_onset_time, - find_peak_metrics, - run_statistical_tests, -) -from dynaclr.evaluation.pseudotime.plotting import ( - plot_cell_heatmap, - plot_onset_comparison, - plot_response_curves, - plot_timing_distributions, -) -from dynaclr.evaluation.pseudotime.signals import ( - extract_prediction_signal, -) - -# %% -# =========================================================================== -# Dataset configuration -# =========================================================================== - -ANNOTATIONS_ROOT = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") -EMBEDDINGS_ROOT = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics") - -ORGANELLE_CONFIG = { - "G3BP1": { - "experiments": [ - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV" - / 
"2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/2", # uninf c/1, inf c/2 - "frame_interval_minutes": 10, - "task": "organelle_state_g3bp1", - "label": "2025_07_22 ZIKV", - }, - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_01_24_A549_G3BP1_DENV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_01_24_A549_G3BP1_DENV" - / "2025_01_24_A549_G3BP1_DENV_combined_annotations.csv", - "fov_pattern": "C/2", # ZIKV uninf B/3, inf C/2 - "frame_interval_minutes": 10, - "task": "organelle_state_g3bp1", - "label": "2025_01_24 DENV", - }, - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_01_28_A549_G3BP1_ZIKV_DENV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_01_28_A549_G3BP1_ZIKV_DENV" - / "2025_01_28_A549_G3BP1_ZIKV_DENV_combined_annotations.csv", - "fov_pattern": "C/4", # DENV uninf B/4 and inf C/4 - "frame_interval_minutes": 30, - "task": "organelle_state_g3bp1", - "label": "2025_01_28 ZIKV", - }, - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": "*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "C/2", # ZIKV uinf C/1 and inf C/2 - "frame_interval_minutes": 30, - "task": "organelle_state_g3bp1", - "label": "2025_07_24 ZIKV", - }, - ], - "controls": [], - "label": "G3BP1 (Stress Granule)", - "color": "#1f77b4", - }, - "SEC61B": { - "experiments": [ - { - "embeddings_path": EMBEDDINGS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3", - "embeddings_pattern": 
"*organelle*.zarr", - "annotations_path": ANNOTATIONS_ROOT - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" - / "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_combined_annotations.csv", - "fov_pattern": "A/2", - "frame_interval_minutes": 30, - "task": "organelle_state_sec61b", - "label": "2025_07_24 ZIKV", - }, - ], - "controls": [], - "label": "SEC61B (ER)", - "color": "#ff7f0e", - }, -} - -# Analysis parameters -T_PERTURB_SOURCE = "annotation" # Default: use human annotations for T_perturb -USE_PROBABILITY = False # Set True to use continuous probability instead of binary -TIME_BINS_MINUTES = np.arange(-600, 901, 30) -MIN_CELLS_PER_BIN = 5 -MIN_TRACK_TIMEPOINTS = 3 -ONSET_THRESHOLD_SIGMA = 2 - -RESULTS_DIR = Path(__file__).parent / "results" / "prediction_remodeling" - -# %% -# =========================================================================== -# Step 1 + 2: Load data, alignment, and signal extraction -# =========================================================================== - -marker_results = {} - -for marker, config in ORGANELLE_CONFIG.items(): - print(f"\n{'=' * 60}") - print(f"Processing {marker}") - print(f"{'=' * 60}") - - all_experiment_dfs = [] - - for exp in config["experiments"]: - print(f"\n Experiment: {exp['label']}") - - # Load embeddings (AnnData with predictions) - emb_files = glob.glob(str(Path(exp["embeddings_path"]) / exp["embeddings_pattern"])) - if not emb_files: - print(f" No embeddings found matching: {exp['embeddings_pattern']}") - continue - - adata = ad.read_zarr(emb_files[0]) - print(f" Loaded {adata.shape[0]:,} embeddings") - - # Check predictions exist - task = exp.get("task", "organelle_state") - pred_col = f"predicted_{task}" - if pred_col not in adata.obs.columns: - print(f" WARNING: '{pred_col}' not in adata.obs — skipping") - continue - - # Load annotations for infection state alignment - ann_df = pd.read_csv(exp["annotations_path"]) - if "parent_track_id" not in ann_df.columns: - ann_df["parent_track_id"] = -1 - - # Step 1: 
Alignment (using annotations for T_perturb) - aligned = align_tracks( - ann_df, - frame_interval_minutes=exp["frame_interval_minutes"], - source=T_PERTURB_SOURCE, - fov_pattern=exp["fov_pattern"], - min_track_timepoints=MIN_TRACK_TIMEPOINTS, - ) - - # Step 2: Signal extraction (prediction-based) - aligned = extract_prediction_signal( - adata, - aligned, - task=task, - positive_value="remodel", - use_probability=USE_PROBABILITY, - ) - aligned["experiment"] = exp["label"] - aligned["marker"] = marker - all_experiment_dfs.append(aligned) - - if not all_experiment_dfs: - print(f" No data for {marker}, skipping") - continue - - combined = pd.concat(all_experiment_dfs, ignore_index=True) - - # Step 3: Aggregate - signal_type = "continuous" if USE_PROBABILITY else "fraction" - population_df = aggregate_population(combined, TIME_BINS_MINUTES, signal_type=signal_type) - - n_tracks = combined.groupby(["fov_name", "track_id", "experiment"]).ngroups - marker_results[marker] = { - "combined_df": combined, - "population_df": population_df, - "config": config, - "n_tracks": n_tracks, - "n_experiments": len(config["experiments"]), - "n_frames": len(combined), - } - - print( - f"\n **{marker} summary**: {n_tracks} tracks, " - f"{len(config['experiments'])} experiments, {len(combined):,} total frames" - ) - -# %% -# =========================================================================== -# Step 4: Timing metrics -# =========================================================================== - -timing_rows = [] -for marker, res in marker_results.items(): - pop_df = res["population_df"] - - t_onset, threshold, bl_mean, bl_std = find_onset_time( - pop_df, - sigma_threshold=ONSET_THRESHOLD_SIGMA, - min_cells_per_bin=MIN_CELLS_PER_BIN, - ) - t_50 = find_half_max_time(pop_df) - peak = find_peak_metrics(pop_df) - - timing_rows.append( - { - "marker": marker, - "T_onset_minutes": t_onset, - "T_50_minutes": t_50, - "T_peak_minutes": peak["T_peak_minutes"], - "peak_amplitude": 
peak["peak_amplitude"], - "T_return_minutes": peak["T_return_minutes"], - "pulse_duration_minutes": peak["pulse_duration_minutes"], - "auc": peak["auc"], - "baseline_mean": bl_mean, - "baseline_std": bl_std, - "n_tracks": res["n_tracks"], - "n_experiments": res["n_experiments"], - } - ) - -timing_df = pd.DataFrame(timing_rows) -print("\n## Prediction-based Timing Metrics\n") -print(timing_df.to_string(index=False)) - -# Per-track timing -signal_type = "continuous" if USE_PROBABILITY else "fraction" -all_track_timing = [] -for marker, res in marker_results.items(): - track_timing = compute_track_timing(res["combined_df"], signal_type=signal_type) - track_timing["marker"] = marker - all_track_timing.append(track_timing) - -if all_track_timing: - track_timing_df = pd.concat(all_track_timing, ignore_index=True) -else: - track_timing_df = pd.DataFrame( - columns=[ - "fov_name", - "track_id", - "onset_minutes", - "total_positive_minutes", - "span_minutes", - "n_positive_frames", - "n_total_frames", - "marker", - ] - ) - print("WARNING: No tracks with positive signal detected across any marker.") - -# %% -# =========================================================================== -# Step 5: Plotting -# =========================================================================== - -marker_curves = {m: res["population_df"] for m, res in marker_results.items()} -marker_configs = {m: res["config"] for m, res in marker_results.items()} - -plot_response_curves( - marker_curves, - marker_configs, - RESULTS_DIR, - signal_type=signal_type, - min_cells_per_bin=MIN_CELLS_PER_BIN, - title="Prediction-based organelle remodeling after infection", - filename_prefix="prediction_remodeling_comparison", -) - -for marker, res in marker_results.items(): - plot_cell_heatmap( - res["combined_df"], - TIME_BINS_MINUTES, - signal_type=signal_type, - organelle_label=res["config"]["label"], - output_dir=RESULTS_DIR, - filename_prefix=f"{marker}_prediction_heatmap", - ) - -if len(track_timing_df) > 
0: - plot_timing_distributions( - track_timing_df, - marker_configs, - RESULTS_DIR, - filename_prefix="per_track_onset_duration", - ) - - plot_onset_comparison( - timing_df, - RESULTS_DIR, - filename_prefix="onset_comparison", - ) - -# %% -# =========================================================================== -# Step 6: Statistical tests -# =========================================================================== - -if len(marker_results) > 1 and len(track_timing_df) > 0: - stats_df = run_statistical_tests(marker_results, track_timing_df) - print("\n## Statistical Tests\n") - print(stats_df.to_string(index=False)) - stats_df.to_csv(RESULTS_DIR / "statistical_tests.csv", index=False) - -# %% -# =========================================================================== -# Step 7: Save CSVs -# =========================================================================== - -RESULTS_DIR.mkdir(parents=True, exist_ok=True) - -timing_df.to_csv(RESULTS_DIR / "timing_metrics.csv", index=False) -track_timing_df.to_csv(RESULTS_DIR / "per_track_timing.csv", index=False) - -for marker, res in marker_results.items(): - curve_path = RESULTS_DIR / f"{marker}_population_curve.csv" - res["population_df"].to_csv(curve_path, index=False) - -print(f"\nResults saved to {RESULTS_DIR}") - -# %% diff --git a/applications/dynaclr/src/dynaclr/cli.py b/applications/dynaclr/src/dynaclr/cli.py index ade202d7c..cf79d25c1 100644 --- a/applications/dynaclr/src/dynaclr/cli.py +++ b/applications/dynaclr/src/dynaclr/cli.py @@ -85,6 +85,14 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="evaluate-tracking-accuracy", + import_path="dynaclr.evaluation.benchmarking.tracking_accuracy.evaluate_tracking.main", + short_help="Evaluate CTC tracking accuracy with DynaCLR ONNX embeddings", + ) +) + dynaclr.add_command( LazyCommand( name="append-obs", @@ -101,6 +109,14 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="combined-dim-reduction", + 
import_path="dynaclr.evaluation.dimensionality_reduction.reduce_combined.main", + short_help="Joint PCA/PHATE across multiple AnnData stores", + ) +) + dynaclr.add_command( LazyCommand( name="cross-validate", @@ -109,6 +125,22 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="run-linear-classifiers", + import_path="dynaclr.evaluation.linear_classifiers.orchestrated.main", + short_help="Run linear classifiers on orchestrator embeddings (batch, CSV metrics)", + ) +) + +dynaclr.add_command( + LazyCommand( + name="split-embeddings", + import_path="dynaclr.evaluation.split_embeddings.main", + short_help="Split combined embeddings zarr into one zarr per experiment", + ) +) + dynaclr.add_command( LazyCommand( name="info", @@ -125,6 +157,14 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="preprocess-cell-index", + import_path="dynaclr.data.preprocess_cell_index.main", + short_help="Remove empty-frame rows from a cell index parquet", + ) +) + dynaclr.add_command( LazyCommand( name="convert-ops-parquet", @@ -157,6 +197,63 @@ def dynaclr(): ) ) +dynaclr.add_command( + LazyCommand( + name="compute-mmd", + import_path="dynaclr.evaluation.mmd.compute_mmd.main", + short_help="Compute MMD between perturbation groups in cell embeddings", + ) +) + +dynaclr.add_command( + LazyCommand( + name="plot-mmd-heatmap", + import_path="dynaclr.evaluation.mmd.compute_mmd.plot_mmd_heatmap_cmd", + short_help="Plot combined MMD heatmap (all markers) from per-experiment CSVs", + ) +) + +dynaclr.add_command( + LazyCommand( + name="prepare-eval-configs", + import_path="dynaclr.evaluation.evaluate.main", + short_help="Generate evaluation YAML configs and print JSON manifest (Nextflow entry point)", + ) +) + +dynaclr.add_command( + LazyCommand( + name="check-evals", + import_path="dynaclr.evaluation.check_evals.main", + short_help="Check eval completion status for all models in the registry", + ) +) + +dynaclr.add_command( + LazyCommand( + name="append-annotations", + 
import_path="dynaclr.evaluation.append_annotations.main", + short_help="Append annotation columns to per-experiment zarrs", + ) +) + +dynaclr.add_command( + LazyCommand( + name="append-predictions", + import_path="dynaclr.evaluation.append_predictions.main", + short_help="Apply saved classifiers and write predictions to per-experiment zarrs", + ) +) + + +dynaclr.add_command( + LazyCommand( + name="plot-embeddings", + import_path="dynaclr.evaluation.plot_embeddings.main", + short_help="Generate scatter plots from an AnnData embedding store", + ) +) + def main(): """Run the DynaCLR CLI.""" diff --git a/applications/dynaclr/src/dynaclr/data/datamodule.py b/applications/dynaclr/src/dynaclr/data/datamodule.py index cd702508b..1c49aaf82 100644 --- a/applications/dynaclr/src/dynaclr/data/datamodule.py +++ b/applications/dynaclr/src/dynaclr/data/datamodule.py @@ -12,10 +12,10 @@ from __future__ import annotations import logging -import os import numpy as np import pandas as pd +import torch from iohub.core.config import TensorStoreConfig from lightning.pytorch import LightningDataModule from monai.data.thread_buffer import ThreadDataLoader @@ -27,8 +27,9 @@ from dynaclr.data.index import MultiExperimentIndex from viscy_data._utils import BatchedCenterSpatialCropd, _transform_channel_wise from viscy_data.channel_dropout import ChannelDropout +from viscy_data.channel_utils import parse_channel_name from viscy_data.sampler import FlexibleBatchSampler -from viscy_transforms import BatchedRandSpatialCropd +from viscy_utils.mp_utils import available_cpus _logger = logging.getLogger(__name__) @@ -51,11 +52,10 @@ class MultiExperimentDataModule(LightningDataModule): Parameters ---------- - collection_path : str or None - Path to collection YAML for ExperimentRegistry.from_collection(). - Optional when ``cell_index_path`` is provided — the registry is - built directly from parquet + zarr metadata via - ExperimentRegistry.from_cell_index(). 
+ cell_index_path : str + Path to preprocessed cell index parquet (from ``build-cell-index`` + + ``preprocess-cell-index``). Contains all metadata needed for + training: TCZYX shape, normalization stats, focus slice. z_window : int Number of Z slices the model consumes (final crop size). z_extraction_window : int or None @@ -85,7 +85,7 @@ class MultiExperimentDataModule(LightningDataModule): batch_size : int Batch size. Default: 128. num_workers : int - Thread workers for ThreadDataLoader. Default: 1. + Thread workers for ThreadDataLoader. Default: 4. batch_group_by : str or list[str] or None Column(s) to group batches by (e.g. ``"experiment"``). Default: None. stratify_by : str | list[str] | None @@ -122,17 +122,9 @@ class MultiExperimentDataModule(LightningDataModule): Only include these wells. Default: None. exclude_fovs : list[str] | None Exclude these FOVs. Default: None. - cell_index_path : str | None - Optional path to a pre-built cell index parquet for faster startup. - When provided, both train and val indices load from this parquet - (filtered by their respective registries). Default: None. focus_channel : str | None Channel name for ``focus_slice`` lookup when auto-resolving z_range. Default: None (uses first source_channel). - num_workers_index : int - Number of parallel processes for building the cell index. Default: 1 - (sequential). When > 1, one process is spawned per experiment. - Ignored when ``cell_index_path`` is provided. reference_pixel_size_xy_um : float or None Reference pixel size in XY (micrometers) for physical-scale normalization. None = no rescaling. Default: None. 
@@ -157,7 +149,7 @@ class MultiExperimentDataModule(LightningDataModule): def __init__( self, - collection_path: str | None, + cell_index_path: str, z_window: int, z_extraction_window: int | None = None, z_focus_offset: float = 0.5, @@ -168,7 +160,7 @@ def __init__( tau_range: tuple[float, float] = (0.5, 2.0), tau_decay_rate: float = 2.0, batch_size: int = 128, - num_workers: int = 1, + num_workers: int = 4, # Sampling hyperparameters (passed to FlexibleBatchSampler) batch_group_by: str | list[str] | None = None, stratify_by: str | list[str] | None = "perturbation", @@ -185,13 +177,13 @@ def __init__( normalizations: list[MapTransform] | None = None, augmentations: list[MapTransform] | None = None, # Other - cache_pool_bytes: int = 0, + cache_pool_bytes: int = 500_000_000, + recheck_cached_data: str | bool | None = None, + file_io_concurrency: int | None = None, seed: int = 0, include_wells: list[str] | None = None, exclude_fovs: list[str] | None = None, - cell_index_path: str | None = None, focus_channel: str | None = None, - num_workers_index: int = 1, reference_pixel_size_xy_um: float | None = None, reference_pixel_size_z_um: float | None = None, positive_cell_source: str = "lookup", @@ -200,11 +192,14 @@ def __init__( label_columns: dict[str, str] | None = None, max_border_shift: int = -1, shuffle_val: bool = False, + pin_memory: bool = True, + prefetch_factor: int | None = None, + buffer_size: int = 4, ) -> None: super().__init__() # Core parameters - self.collection_path = collection_path + self.cell_index_path = cell_index_path self.z_window = z_window self.z_extraction_window = z_extraction_window self.z_focus_offset = z_focus_offset @@ -240,18 +235,17 @@ def __init__( # Loss hyperparameters (informational) # Other self.cache_pool_bytes = cache_pool_bytes - cpus = os.environ.get("SLURM_CPUS_PER_TASK") - cpus = int(cpus) if cpus is not None else (os.cpu_count() or 4) + cpus = available_cpus(default=4) self.tensorstore_config = TensorStoreConfig( 
data_copy_concurrency=cpus, cache_pool_bytes=cache_pool_bytes or None, + recheck_cached_data=recheck_cached_data, + file_io_concurrency=file_io_concurrency, ) self.seed = seed self.include_wells = include_wells self.exclude_fovs = exclude_fovs - self.cell_index_path = cell_index_path self.focus_channel = focus_channel - self.num_workers_index = num_workers_index self.reference_pixel_size_xy_um = reference_pixel_size_xy_um self.reference_pixel_size_z_um = reference_pixel_size_z_um self.positive_cell_source = positive_cell_source @@ -260,6 +254,9 @@ def __init__( self.label_columns = label_columns self.max_border_shift = max_border_shift self.shuffle_val = shuffle_val + self.pin_memory = pin_memory + self.prefetch_factor = prefetch_factor + self.buffer_size = buffer_size # Create ChannelDropout module self.channel_dropout = ChannelDropout( @@ -270,6 +267,7 @@ def __init__( # Datasets (populated in setup) self.train_dataset: MultiExperimentTripletDataset | None = None self.val_dataset: MultiExperimentTripletDataset | None = None + self.predict_dataset: MultiExperimentTripletDataset | None = None # ------------------------------------------------------------------ # Setup @@ -292,33 +290,20 @@ def setup(self, stage: str | None = None) -> None: Lightning stage: ``"fit"``, ``"predict"``, etc. 
""" if stage == "fit" or stage is None: - if self.collection_path is not None: - registry = ExperimentRegistry.from_collection( - self.collection_path, - z_window=self.z_window, - z_extraction_window=self.z_extraction_window, - z_focus_offset=self.z_focus_offset, - focus_channel=getattr(self, "focus_channel", None), - reference_pixel_size_xy_um=self.reference_pixel_size_xy_um, - reference_pixel_size_z_um=self.reference_pixel_size_z_um, - ) - elif self.cell_index_path is not None: - registry = ExperimentRegistry.from_cell_index( - self.cell_index_path, - z_window=self.z_window, - z_extraction_window=self.z_extraction_window, - z_focus_offset=self.z_focus_offset, - focus_channel=getattr(self, "focus_channel", None), - reference_pixel_size_xy_um=self.reference_pixel_size_xy_um, - reference_pixel_size_z_um=self.reference_pixel_size_z_um, - ) - else: - raise ValueError("Either collection_path or cell_index_path must be provided.") + registry, cell_index_df = ExperimentRegistry.from_cell_index( + self.cell_index_path, + z_window=self.z_window, + z_extraction_window=self.z_extraction_window, + z_focus_offset=self.z_focus_offset, + focus_channel=self.focus_channel, + reference_pixel_size_xy_um=self.reference_pixel_size_xy_um, + reference_pixel_size_z_um=self.reference_pixel_size_z_um, + ) if self.val_experiments: - self._setup_experiment_split(registry) + self._setup_experiment_split(registry, cell_index_df) else: - self._setup_fov_split(registry) + self._setup_fov_split(registry, cell_index_df) if self.channels_per_sample is None: self._channel_names = registry.source_channel_labels @@ -333,7 +318,6 @@ def setup(self, stage: str | None = None) -> None: self._augmentation_transform = Compose( self.normalizations + self.augmentations + [self._train_final_crop()] ) - self._no_augmentation_transform = Compose(self.normalizations + [self._val_final_crop()]) _logger.info( "MultiExperimentDataModule setup: %d train anchors, %d val anchors", @@ -341,7 +325,62 @@ def setup(self, 
stage: str | None = None) -> None: len(self.val_dataset) if self.val_dataset else 0, ) - def _setup_experiment_split(self, registry: ExperimentRegistry) -> None: + elif stage == "predict": + self._setup_predict() + _logger.info( + "MultiExperimentDataModule predict setup: %d anchors", + len(self.predict_dataset) if self.predict_dataset else 0, + ) + + def _setup_predict(self) -> None: + """Set up predict dataset over the full cell index (no train/val split).""" + registry, cell_index_df = ExperimentRegistry.from_cell_index( + self.cell_index_path, + z_window=self.z_window, + z_extraction_window=self.z_extraction_window, + z_focus_offset=self.z_focus_offset, + focus_channel=self.focus_channel, + reference_pixel_size_xy_um=self.reference_pixel_size_xy_um, + reference_pixel_size_z_um=self.reference_pixel_size_z_um, + ) + + if self.channels_per_sample is None: + self._channel_names = registry.source_channel_labels + elif isinstance(self.channels_per_sample, int): + self._channel_names = [f"channel_{i}" for i in range(self.channels_per_sample)] + else: + self._channel_names = list(self.channels_per_sample) + + predict_index = MultiExperimentIndex( + registry=registry, + yx_patch_size=self.yx_patch_size, + tau_range_hours=self.tau_range, + include_wells=self.include_wells, + exclude_fovs=self.exclude_fovs, + cell_index_df=cell_index_df, + positive_cell_source=self.positive_cell_source, + positive_match_columns=self.positive_match_columns, + fit=False, + ) + self.predict_dataset = MultiExperimentTripletDataset( + index=predict_index, + fit=False, + tau_range_hours=self.tau_range, + tau_decay_rate=self.tau_decay_rate, + channels_per_sample=self.channels_per_sample, + positive_cell_source=self.positive_cell_source, + positive_match_columns=self.positive_match_columns, + positive_channel_source=self.positive_channel_source, + label_columns=self.label_columns, + ) + + # Predict transform: normalizations + final center crop only (no augmentations). 
+ # BatchedChannelWiseZReductiond is kept if present in self.augmentations + # since it is architecturally required to produce the 2D model input. + z_reduction = [t for t in self.augmentations if type(t).__name__ == "BatchedChannelWiseZReductiond"] + self._predict_transform = Compose(self.normalizations + z_reduction + [self._train_final_crop()]) + + def _setup_experiment_split(self, registry: ExperimentRegistry, cell_index_df: pd.DataFrame) -> None: """Split by whole experiments into train/val.""" train_names = [e.name for e in registry.experiments if e.name not in self.val_experiments] val_names = [e.name for e in registry.experiments if e.name in self.val_experiments] @@ -364,8 +403,7 @@ def _setup_experiment_split(self, registry: ExperimentRegistry) -> None: tau_range_hours=self.tau_range, include_wells=self.include_wells, exclude_fovs=self.exclude_fovs, - cell_index_path=self.cell_index_path, - num_workers=self.num_workers_index, + cell_index_df=cell_index_df, positive_cell_source=self.positive_cell_source, positive_match_columns=self.positive_match_columns, max_border_shift=self.max_border_shift, @@ -391,8 +429,7 @@ def _setup_experiment_split(self, registry: ExperimentRegistry) -> None: tau_range_hours=self.tau_range, include_wells=self.include_wells, exclude_fovs=self.exclude_fovs, - cell_index_path=self.cell_index_path, - num_workers=self.num_workers_index, + cell_index_df=cell_index_df, positive_cell_source=self.positive_cell_source, positive_match_columns=self.positive_match_columns, max_border_shift=self.max_border_shift, @@ -410,7 +447,7 @@ def _setup_experiment_split(self, registry: ExperimentRegistry) -> None: label_columns=self.label_columns, ) - def _setup_fov_split(self, registry: ExperimentRegistry) -> None: + def _setup_fov_split(self, registry: ExperimentRegistry, cell_index_df: pd.DataFrame) -> None: """Split FOVs within each experiment by split_ratio. 
Uses experiment-qualified keys ``(experiment, fov_name)`` so that @@ -423,45 +460,119 @@ def _setup_fov_split(self, registry: ExperimentRegistry) -> None: tau_range_hours=self.tau_range, include_wells=self.include_wells, exclude_fovs=self.exclude_fovs, - cell_index_path=self.cell_index_path, - num_workers=self.num_workers_index, + cell_index_df=cell_index_df, positive_cell_source=self.positive_cell_source, positive_match_columns=self.positive_match_columns, tensorstore_config=self.tensorstore_config, ) rng = np.random.default_rng(self.seed) - train_keys: set[tuple[str, str]] = set() - val_keys: set[tuple[str, str]] = set() + + # Build per-row boolean masks directly during the per-experiment + # groupby walk. The previous implementation built + # pd.MultiIndex.from_arrays over every row of tracks + valid_anchors + # (81M+ rows for OPS), which hashes a Python tuple per row and + # dominates setup-time memory. Per-group isin against a small + # Python-set of FOV names is O(group_size) with no object index. 
+ train_fovs_per_exp: dict[str, set[str]] = {} + val_fovs_per_exp: dict[str, set[str]] = {} for exp_name, group in full_index.tracks.groupby("experiment"): fovs = sorted(group["fov_name"].unique()) n_train = max(1, int(len(fovs) * self.split_ratio)) rng.shuffle(fovs) - for f in fovs[:n_train]: - train_keys.add((exp_name, f)) - for f in fovs[n_train:]: - val_keys.add((exp_name, f)) + train_fovs_per_exp[exp_name] = set(fovs[:n_train]) + val_fovs_per_exp[exp_name] = set(fovs[n_train:]) + n_train_fovs = sum(len(s) for s in train_fovs_per_exp.values()) + n_val_fovs = sum(len(s) for s in val_fovs_per_exp.values()) _logger.info( "FOV split (ratio=%.2f): %d train FOVs, %d val FOVs", self.split_ratio, - len(train_keys), - len(val_keys), + n_train_fovs, + n_val_fovs, ) - full_qual = list(zip(full_index.tracks["experiment"], full_index.tracks["fov_name"])) - train_mask = pd.Series([k in train_keys for k in full_qual], index=full_index.tracks.index) + def _build_train_mask(df: pd.DataFrame) -> np.ndarray: + """Row-wise boolean mask: True if (experiment, fov_name) is train.""" + mask = np.zeros(len(df), dtype=bool) + # groupby("experiment") returns integer positions in ``df`` via + # group.index after reset_index; we rely on the caller passing + # reset-indexed frames (which is what MultiExperimentIndex produces). + for exp_name, group in df.groupby("experiment", sort=False): + train_fovs = train_fovs_per_exp.get(exp_name, set()) + if not train_fovs: + continue + sub_mask = group["fov_name"].isin(train_fovs).to_numpy() + mask[group.index.to_numpy()] = sub_mask + return mask + + def _split_by_mask(df: pd.DataFrame, mask: np.ndarray) -> tuple[pd.DataFrame, pd.DataFrame]: + """Partition ``df`` by a boolean mask using integer row indices. + + ``df[bool_mask]`` on an Arrow-backed DataFrame routes through + ``pyarrow.compute.take`` which allocates a fresh buffer per + string column and scales badly with row count × column count. 
+ On a 16M-row × 15-string-col frame this can take 7-8 minutes + per call on a contended node. + + Using ``df.take(int_indices)`` on a frame whose Arrow string + columns have been cast to ``object`` upfront is ~20× faster + because pandas uses plain NumPy fancy indexing on the + materialized object arrays. + """ + train_rows = np.flatnonzero(mask) + val_rows = np.flatnonzero(~mask) + return ( + df.take(train_rows).reset_index(drop=True), + df.take(val_rows).reset_index(drop=True), + ) - train_tracks = full_index.tracks[train_mask].reset_index(drop=True) - val_tracks = full_index.tracks[~train_mask].reset_index(drop=True) + def _materialize_strings(df: pd.DataFrame) -> pd.DataFrame: + """In-place cast remaining ArrowStringArray columns to Categorical. + + ArrowStringArray routes every ``df[mask]`` through + ``pyarrow.compute.take`` which allocates a fresh per-column + buffer and scales catastrophically (7-8 min per call on 16M rows + with 15 string columns on a contended node). Casting to pandas + Categorical uses int codes + a single categories dict, so + slicing is pure NumPy fancy indexing on the codes. + + Low-cardinality columns (``experiment``, ``marker``, etc.) are + already Categorical from ``read_cell_index``/``_align_parquet_columns`` + — those are skipped. High-cardinality columns like ``cell_id`` + become effectively int32-indexed even at ~80M unique values, + since the dict overhead is one-time and the row-aligned codes + are cheap. NumPy-object casts were tried first but allocate + ~5-10 GB of Python string objects per frame, which on 4-rank DDP + OOMs the node. 
+ """ + for col in df.columns: + s = df[col] + if isinstance(s.dtype, pd.CategoricalDtype): + continue + if pd.api.types.is_string_dtype(s) or str(s.dtype).startswith(("string", "Arrow")): + df[col] = s.astype("category") + return df + + _materialize_strings(full_index.tracks) + _materialize_strings(full_index.valid_anchors) + + train_mask = _build_train_mask(full_index.tracks) + train_tracks, val_tracks = _split_by_mask(full_index.tracks, train_mask) + + va = full_index.valid_anchors + train_va_mask = _build_train_mask(va) + train_va, val_va = _split_by_mask(va, train_va_mask) train_index = full_index.clone_with_subset( train_tracks, positive_cell_source=self.positive_cell_source, positive_match_columns=self.positive_match_columns, max_border_shift=self.max_border_shift, + precomputed_valid_anchors=train_va, ) + self.train_dataset = MultiExperimentTripletDataset( index=train_index, fit=True, @@ -474,11 +585,12 @@ def _setup_fov_split(self, registry: ExperimentRegistry) -> None: label_columns=self.label_columns, ) - if val_keys: + if not val_tracks.empty: val_index = full_index.clone_with_subset( val_tracks, positive_cell_source=self.positive_cell_source, positive_match_columns=self.positive_match_columns, + precomputed_valid_anchors=val_va, ) self.val_dataset = MultiExperimentTripletDataset( index=val_index, @@ -496,8 +608,29 @@ def _setup_fov_split(self, registry: ExperimentRegistry) -> None: # Dataloaders # ------------------------------------------------------------------ + def _ddp_topology(self) -> tuple[int, int]: + """Return ``(num_replicas, rank)`` for the current trainer. + + Lightning's auto-wrap hook only passes ``world_size``/``rank`` to + ``sampler``, not ``batch_sampler``. With ``use_distributed_sampler: + false`` and a batch sampler, the datamodule must read them from the + trainer itself and forward them; otherwise every rank iterates the + full sequence and yields identical batches. + + Returns ``(1, 0)`` when no trainer is attached (e.g. 
bare + dataloader construction in tests) or when the trainer stub lacks + DDP attributes (e.g. the ``_FakeTrainer`` in demo scripts). + """ + trainer = getattr(self, "trainer", None) + world_size = getattr(trainer, "world_size", None) + global_rank = getattr(trainer, "global_rank", None) + if world_size is None or global_rank is None: + return 1, 0 + return world_size, global_rank + def train_dataloader(self) -> ThreadDataLoader: """Return training data loader with FlexibleBatchSampler.""" + num_replicas, rank = self._ddp_topology() sampler = FlexibleBatchSampler( valid_anchors=self.train_dataset.index.valid_anchors, batch_size=self.batch_size, @@ -508,27 +641,76 @@ def train_dataloader(self) -> ThreadDataLoader: temporal_enrichment=self.temporal_enrichment, temporal_window_hours=self.temporal_window_hours, temporal_global_fraction=self.temporal_global_fraction, + num_replicas=num_replicas, + rank=rank, seed=self.seed, ) return ThreadDataLoader( self.train_dataset, use_thread_workers=True, + buffer_size=self.buffer_size, batch_sampler=sampler, num_workers=self.num_workers, + pin_memory=self.pin_memory, + prefetch_factor=self.prefetch_factor, collate_fn=lambda x: x, ) def val_dataloader(self) -> ThreadDataLoader | None: - """Return validation data loader.""" + """Return validation data loader. + + Uses the same ``FlexibleBatchSampler`` as training so ``loss/val`` + is measured on batches whose composition matches the training + regime — e.g. single-marker batches when ``batch_group_by="marker"``, + or perturbation-stratified batches when ``stratify_by`` is set. + + Without this, val was a plain sequential DataLoader that served + one experiment/marker at a time (all 4 example batches end up as + the same marker), and DDP sync of ``loss/val`` silently desynced + across ranks because each rank's shard had a different set of + markers. 
+ + Temporal enrichment is disabled for val (we want a deterministic + representative sample, not oversampled biology-of-interest windows). + """ if self.val_dataset is None: return None + num_replicas, rank = self._ddp_topology() + sampler = FlexibleBatchSampler( + valid_anchors=self.val_dataset.index.valid_anchors, + batch_size=self.batch_size, + batch_group_by=self.batch_group_by, + leaky=self.leaky, + group_weights=self.group_weights, + stratify_by=self.stratify_by, + temporal_enrichment=False, + num_replicas=num_replicas, + rank=rank, + seed=self.seed, + ) return ThreadDataLoader( self.val_dataset, use_thread_workers=True, + buffer_size=self.buffer_size, + batch_sampler=sampler, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + prefetch_factor=self.prefetch_factor, + collate_fn=lambda x: x, + ) + + def predict_dataloader(self) -> ThreadDataLoader: + """Return predict data loader (no shuffling, no dropping).""" + return ThreadDataLoader( + self.predict_dataset, + use_thread_workers=True, + buffer_size=self.buffer_size, batch_size=self.batch_size, num_workers=self.num_workers, - shuffle=self.shuffle_val, + shuffle=False, drop_last=False, + pin_memory=self.pin_memory, + prefetch_factor=self.prefetch_factor, collate_fn=lambda x: x, ) @@ -536,18 +718,15 @@ def val_dataloader(self) -> ThreadDataLoader | None: # Transforms # ------------------------------------------------------------------ - def _train_final_crop(self) -> BatchedRandSpatialCropd: - """Random crop from extraction size to model input size (training).""" - return BatchedRandSpatialCropd( - keys=self._channel_names, - roi_size=(self.z_window, self.final_yx_patch_size[0], self.final_yx_patch_size[1]), - ) - - def _val_final_crop(self) -> BatchedCenterSpatialCropd: - """Center crop from extraction size to model input size (validation).""" + def _train_final_crop(self) -> BatchedCenterSpatialCropd: + """Center crop from extraction size to model input size (training).""" return 
BatchedCenterSpatialCropd( keys=self._channel_names, - roi_size=(self.z_window, self.final_yx_patch_size[0], self.final_yx_patch_size[1]), + roi_size=( + self.z_window, + self.final_yx_patch_size[0], + self.final_yx_patch_size[1], + ), ) def on_after_batch_transfer(self, batch, dataloader_idx: int): @@ -568,11 +747,38 @@ def on_after_batch_transfer(self, batch, dataloader_idx: int): if isinstance(batch, Tensor): return batch - # Determine transform: augmentation for training, no-aug for val - if self.trainer and self.trainer.validating: - transform = self._no_augmentation_transform - else: - transform = self._augmentation_transform + # During predict: normalizations + z_reduction only (no augmentations, no channel dropout). + if self.trainer.predicting: + norm_meta = batch.get("anchor_norm_meta") + if isinstance(norm_meta, list): + non_none = [m for m in norm_meta if m is not None] + if len(non_none) == 0: + norm_meta = None + elif len(non_none) != len(norm_meta): + raise ValueError("Mixed None/non-None norm_meta in predict batch.") + extra = None + if isinstance(self.channels_per_sample, int): + meta = batch.get("anchor_meta") + if meta is not None: + extra = { + "_is_labelfree": torch.tensor( + [parse_channel_name(m.get("marker", ""))["channel_type"] == "labelfree" for m in meta], + dtype=torch.bool, + device=batch["anchor"].device, + ) + } + batch["anchor"] = _transform_channel_wise( + transform=self._predict_transform, + channel_names=self._channel_names, + patch=batch["anchor"], + norm_meta=norm_meta, + extra=extra, + ) + batch.pop("anchor_norm_meta", None) + batch.pop("anchor_meta", None) + return batch + + transform = self._augmentation_transform for key in ["anchor", "positive", "negative"]: if key in batch: @@ -588,20 +794,31 @@ def on_after_batch_transfer(self, batch, dataloader_idx: int): "All FOVs must have normalization metadata or none of them." 
) # else: all non-None, pass through as list + extra = None + if isinstance(self.channels_per_sample, int): + meta = batch.get(f"{key}_meta") + if meta is not None: + extra = { + "_is_labelfree": torch.tensor( + [parse_channel_name(m.get("marker", ""))["channel_type"] == "labelfree" for m in meta], + dtype=torch.bool, + device=batch[key].device, + ) + } transformed = _transform_channel_wise( transform=transform, channel_names=self._channel_names, patch=batch[key], norm_meta=norm_meta, + extra=extra, ) batch[key] = transformed if norm_meta_key in batch: del batch[norm_meta_key] - # Apply ChannelDropout to anchor and positive (training only) - if not (self.trainer and self.trainer.validating): - for key in ["anchor", "positive"]: - if key in batch: - batch[key] = self.channel_dropout(batch[key]) + # Apply ChannelDropout to anchor and positive + for key in ["anchor", "positive"]: + if key in batch: + batch[key] = self.channel_dropout(batch[key]) return batch diff --git a/applications/dynaclr/src/dynaclr/data/dataset.py b/applications/dynaclr/src/dynaclr/data/dataset.py index a2313fe1a..a682b744e 100644 --- a/applications/dynaclr/src/dynaclr/data/dataset.py +++ b/applications/dynaclr/src/dynaclr/data/dataset.py @@ -31,21 +31,70 @@ except ImportError: ts = None +from iohub.ngff import open_ome_zarr + from dynaclr.data.index import MultiExperimentIndex from dynaclr.data.tau_sampling import sample_tau from viscy_data._typing import ULTRACK_INDEX_COLUMNS, NormMeta, SampleMeta from viscy_data._utils import _read_norm_meta + +def _pick_temporal_candidate( + timepoints: dict[int, list[int]], + anchor_t: int, + tau_min: int, + tau_max: int, + tau_decay_rate: float, + rng: np.random.Generator, + tr_marker_arr: np.ndarray | None, + anchor_marker: object | None, +) -> int | None: + """Pick one positive tracks-index for a temporal anchor. + + Mirrors the legacy ``_find_temporal_positive._pick`` logic but + operates on pre-computed NumPy arrays. 
Returns ``None`` if no + candidate is found in the ``[tau_min, tau_max]`` window. + """ + + def _filter_and_pick(cand_indices: list[int]) -> int | None: + if not cand_indices: + return None + if tr_marker_arr is not None: + # NumPy fancy-index filter: O(n) with n = number of candidates, + # single vectorized array op. + idx_arr = np.asarray(cand_indices, dtype=np.int64) + mask = tr_marker_arr[idx_arr] == anchor_marker + filtered = idx_arr[mask] + if len(filtered) > 0: + return int(filtered[rng.integers(len(filtered))]) + return int(cand_indices[rng.integers(len(cand_indices))]) + + sampled_tau = sample_tau(tau_min, tau_max, rng, tau_decay_rate) + result = _filter_and_pick(timepoints.get(anchor_t + sampled_tau, [])) + if result is not None: + return result + for tau in range(tau_min, tau_max + 1): + if tau == 0: + continue + result = _filter_and_pick(timepoints.get(anchor_t + tau, [])) + if result is not None: + return result + return None + + _META_COLUMNS = [ "experiment", "perturbation", "microscope", "fov_name", + "store_path", "global_track_id", "t", "hours_post_perturbation", "lineage_id", "marker", + "y_clamp", + "x_clamp", ] _logger = logging.getLogger(__name__) @@ -202,8 +251,16 @@ def __init__( self._rng = np.random.default_rng() self._tensorstores: dict[str, ts.TensorStore] = {} + self._store_cache: dict[str, object] = {} # store_path -> Plate + self._position_cache: dict[str, object] = {} # fov_name -> Position self._norm_meta_cache: dict[str, NormMeta | None] = {} - self._build_match_lookup() + if self.fit: + self._build_match_lookup() + self._build_anchor_cache() + + # ------------------------------------------------------------------ + # Initialization helpers + # ------------------------------------------------------------------ def _build_match_lookup(self) -> None: """Build lookup structures for O(1) positive candidate lookup. 
@@ -222,21 +279,107 @@ def _build_match_lookup(self) -> None: tracks = self.index.tracks if "lineage_id" in self.positive_match_columns: + # observed=True skips unobserved Categorical cross-products; + # without it groupby yields empty groups for every Categorical + # combination, exploding memory and time. Keys are coerced to + # str so the lookup works regardless of dtype (Categorical vs + # object vs ArrowString). + grouped = tracks.groupby(["experiment", "lineage_id", "t"], observed=True).indices self._lineage_timepoints: dict[tuple[str, str], dict[int, list[int]]] = defaultdict( lambda: defaultdict(list) ) - experiments = tracks["experiment"].to_numpy() - lineage_ids = tracks["lineage_id"].to_numpy() - t_values = tracks["t"].to_numpy() - for idx in range(len(tracks)): - self._lineage_timepoints[(experiments[idx], lineage_ids[idx])][t_values[idx]].append(idx) + for (exp, lid, t), row_indices in grouped.items(): + self._lineage_timepoints[(str(exp), str(lid))][int(t)] = row_indices.tolist() else: cols = self.positive_match_columns - self._match_lookup: dict[tuple, list[int]] = defaultdict(list) - col_arrays = [tracks[c].to_numpy() for c in cols] - for idx in range(len(tracks)): - key = tuple(arr[idx] for arr in col_arrays) - self._match_lookup[key].append(idx) + grouped = tracks.groupby(cols).indices + # Store candidate indices as ndarray for O(1) random choice without list copy. + self._match_lookup: dict[tuple, np.ndarray] = { + (k if isinstance(k, tuple) else (k,)): v for k, v in grouped.items() + } + + def _build_anchor_cache(self) -> None: + """Cache valid_anchors/tracks columns as NumPy arrays for fast per-sample access. + + Avoids pandas ``.iloc[idx][col]`` in the hot path, which constructs a + Series per call (~9 ms per anchor on 81M-row indices). NumPy indexing + is ~20 ns. Measured end-to-end speedup: ~3000× on positive-lookup. 
+ + Both ``_va_arrays`` (for anchors) and ``_tr_arrays`` (for positives) + cache the full set of columns needed by ``_slice_patch`` and + ``_build_norm_meta``: ``store_path``, ``fov_name``, ``experiment``, + ``t``, ``y_clamp``, ``x_clamp``, plus ``norm_*`` columns for the + parquet-norm fast path. + + Cache is in-process RAM only — rebuilt on every dataset instantiation + from ``self.index.valid_anchors`` / ``self.index.tracks``. Parquet + remains the source of truth. + + Also precomputes per-experiment tau range (frames) to avoid a registry + lookup per anchor inside ``_sample_positives_temporal``. + """ + + # High-cardinality string columns (store_path, fov_name, experiment, + # marker, channel_name, lineage_id) have few unique values relative to + # row count, so cache them as category codes + categories lookup instead + # of object arrays. Object arrays of strings are ~40-80 bytes/entry; a + # categorical code is 4-8 bytes. On 81M rows this is the difference + # between an OOM and a healthy init. + # + # Access pattern: array[idx] still works if array is a pandas Categorical + # (returns the underlying string); downstream code doesn't care. + def _cache_columns(df: pd.DataFrame, columns: list[str]) -> dict: + out = {} + for col in columns: + if col not in df.columns: + continue + s = df[col] + if s.dtype == object or pd.api.types.is_string_dtype(s): + out[col] = s.astype("category").array # pd.Categorical + else: + out[col] = s.to_numpy() + return out + + # Whitelist columns actually read in the hot path. Caching every + # column of valid_anchors (81M+ rows × ~20 cols × 4 DDP ranks) blows + # the node memory budget; holding only the read set keeps per-rank + # RSS in the low tens of GiB. `positive_match_columns` (user-defined) + # and label column values must also be cached because they drive the + # SupCon key construction and per-sample label lookup respectively. 
+ hot_cols: set[str] = { + "channel_name", + "experiment", + "lineage_id", + "t", + "marker", + "store_path", + "fov_name", + "y_clamp", + "x_clamp", + "norm_mean", + "norm_std", + "norm_median", + "norm_iqr", + } + if self.positive_match_columns: + hot_cols.update(self.positive_match_columns) + if getattr(self, "_label_encoders", None): + for col, _encoder in self._label_encoders.values(): + hot_cols.add(col) + + self._va_arrays: dict = _cache_columns(self.index.valid_anchors, sorted(hot_cols)) + self._tr_arrays: dict = _cache_columns(self.index.tracks, sorted(hot_cols)) + + # Precompute per-experiment tau range in frames to avoid a per-anchor + # registry call inside _sample_positive_indices_temporal. Skip + # experiments with interval_minutes == 0 (static/snapshot datasets like + # OPS) — they never go through the temporal path (positive_match_columns + # wouldn't include lineage_id), so missing entries are harmless and + # computing tau_range_frames for them would ZeroDivisionError. + self._tau_range_frames_cache: dict[str, tuple[int, int]] = {} + for name, exp in self.index.registry._name_map.items(): + if getattr(exp, "interval_minutes", 0): + self._tau_range_frames_cache[name] = self.index.registry.tau_range_frames(name, self.tau_range_hours) # ------------------------------------------------------------------ # Dataset protocol @@ -271,14 +414,16 @@ def __getitems__(self, indices: list[int]) -> dict: anchor_rows = self.index.valid_anchors.iloc[indices] # Pre-compute per-sample channel names based on channel_mode. + # Use the NumPy cache to avoid a pandas Series construction per row. 
if self._channel_mode == "from_index": - forced_channel_names = [[row["channel_name"]] for _, row in anchor_rows.iterrows()] + chan_arr = self._va_arrays["channel_name"] + forced_channel_names = [[chan_arr[i]] for i in indices] elif self._channel_mode == "fixed": forced_channel_names = [self._fixed_channel_names] * len(indices) else: forced_channel_names = None - anchor_patches, anchor_norms = self._slice_patches(anchor_rows, forced_channel_names) + anchor_patches, anchor_norms = self._slice_patches(self._va_arrays, indices, forced_channel_names) sample: dict = { "anchor": anchor_patches, "anchor_norm_meta": anchor_norms, @@ -286,27 +431,45 @@ def __getitems__(self, indices: list[int]) -> dict: } if self.fit: - positive_rows = self._sample_positives(anchor_rows) - if self._channel_mode == "from_index": - pos_forced_channel_names = [[row["channel_name"]] for _, row in positive_rows.iterrows()] + if self.positive_cell_source == "self": + # SimCLR: anchor and positive share the same patch pre-augmentation. + # Skip the second zarr read + meta extraction entirely — augmentation + # (applied independently downstream in on_after_batch_transfer) is + # what creates the two views. This roughly halves per-batch wall + # time for SimCLR baselines. + # clone the tensor so augmentation has an independent buffer to + # mutate without leaking into the anchor. 
+ sample["positive"] = sample["anchor"].clone() + sample["positive_norm_meta"] = sample["anchor_norm_meta"] + sample["positive_meta"] = sample["anchor_meta"] else: - pos_forced_channel_names = forced_channel_names - positive_patches, positive_norms = self._slice_patches(positive_rows, pos_forced_channel_names) - sample["positive"] = positive_patches - sample["positive_norm_meta"] = positive_norms - sample["positive_meta"] = self._extract_meta(positive_rows) + pos_track_indices = self._sample_positive_indices(anchor_positions=indices) + if self._channel_mode == "from_index": + tr_chan_arr = self._tr_arrays["channel_name"] + pos_forced_channel_names = [[tr_chan_arr[i]] for i in pos_track_indices] + else: + pos_forced_channel_names = forced_channel_names + positive_patches, positive_norms = self._slice_patches( + self._tr_arrays, pos_track_indices, pos_forced_channel_names + ) + positive_rows = self.index.tracks.iloc[pos_track_indices].reset_index(drop=True) + sample["positive"] = positive_patches + sample["positive_norm_meta"] = positive_norms + sample["positive_meta"] = self._extract_meta(positive_rows) else: - indices_list = [] - for _, anchor_row in anchor_rows.iterrows(): - idx_dict: dict = {} - for col in ULTRACK_INDEX_COLUMNS: - if col in anchor_row.index: - idx_dict[col] = anchor_row[col] - elif col not in ["y", "x", "z"]: - # optional columns - pass - indices_list.append(idx_dict) - sample["index"] = indices_list + # Build per-sample index dicts via NumPy column arrays (no .iterrows). 
+ all_cols = list(ULTRACK_INDEX_COLUMNS) + [ + "experiment", + "marker", + "perturbation", + "hours_post_perturbation", + "organelle", + "well", + "microscope", + ] + present_cols = [c for c in all_cols if c in anchor_rows.columns] + col_arrays = {c: anchor_rows[c].to_numpy() for c in present_cols} + sample["index"] = [{c: col_arrays[c][i] for c in present_cols} for i in range(len(anchor_rows))] return sample @@ -328,10 +491,18 @@ def _extract_meta(self, rows: pd.DataFrame) -> list[SampleMeta]: cols = [c for c in _META_COLUMNS if c in rows.columns] records = rows[cols].to_dict(orient="records") if self._label_encoders: - for i, (_, row) in enumerate(rows.iterrows()): + # Pre-extract label columns as NumPy arrays once (avoids per-row + # Series construction in .iterrows()). + label_arrays = { + batch_key: (encoder, rows[col].to_numpy() if col in rows.columns else None) + for batch_key, (col, encoder) in self._label_encoders.items() + } + for i in range(len(records)): labels = {} - for batch_key, (col, encoder) in self._label_encoders.items(): - val = row.get(col) + for batch_key, (encoder, arr) in label_arrays.items(): + if arr is None: + continue + val = arr[i] if val is not None and val in encoder: labels[batch_key] = encoder[val] records[i]["labels"] = labels @@ -341,181 +512,264 @@ def _extract_meta(self, rows: pd.DataFrame) -> list[SampleMeta]: # Positive sampling # ------------------------------------------------------------------ - def _sample_positives(self, anchor_rows: pd.DataFrame) -> pd.DataFrame: - """Sample one positive for each anchor. + def _sample_positive_indices( + self, + anchor_positions: list[int], + ) -> np.ndarray: + """Sample one positive tracks-index for each anchor. - When ``positive_cell_source="self"``, returns a copy of ``anchor_rows`` - (same crop; augmentation creates two views). Otherwise delegates to - :meth:`_find_positive`. 
+ Returns positional indices into ``self.index.tracks`` / ``self._tr_arrays`` + — callers can slice patches directly from the cached NumPy arrays without + materializing a DataFrame. The DataFrame is still constructed downstream + for metadata extraction. Parameters ---------- - anchor_rows : pd.DataFrame - Rows from ``valid_anchors`` for the current batch. + anchor_positions : list[int] + Positional indices into ``valid_anchors`` (same as the sampler output). Returns ------- - pd.DataFrame - One row per anchor from ``self.index.tracks``. + np.ndarray + One tracks-positional-index per anchor, shape ``(len(anchor_positions),)``. """ - if self.positive_cell_source == "self": - return anchor_rows.copy().reset_index(drop=True) + # Temporal lineage mode — vectorized NumPy fast path + # (used by DynaCLR-2D-MIP, DynaCLR-3D-BagOfChannels). + if "lineage_id" in self.positive_match_columns: + return self._sample_positive_indices_temporal(anchor_positions) - pos_rows = [] - for _, row in anchor_rows.iterrows(): - pos = self._find_positive(row, self._rng) - if pos is None: + # Column-match mode (SupCon) — vectorized NumPy fast path. + cols = self.positive_match_columns + va_col_arrs = [self._va_arrays[c] for c in cols] + + pos_track_indices = np.empty(len(anchor_positions), dtype=np.int64) + match_lookup = self._match_lookup + rng = self._rng + for i, ai in enumerate(anchor_positions): + key = tuple(arr[ai] for arr in va_col_arrs) + cands = match_lookup.get(key) + if cands is None or len(cands) == 0: raise RuntimeError( - f"No positive found for anchor (experiment={row.get('experiment')}, " - f"match_key={tuple(row.get(c) for c in self.positive_match_columns)}, " - f"t={row.get('t')}). " + f"No positive found for anchor at position {ai} key={key}. " "This anchor should have been filtered out by valid_anchors." ) - pos_rows.append(pos) - return pd.DataFrame(pos_rows).reset_index(drop=True) + # Random pick from candidates. 
Note: the anchor's own tracks-index + # may be in `cands`; we don't filter it out explicitly because the + # anchor's valid_anchors-position and its tracks-index are in + # independent index spaces after reset_index(drop=True), and the + # original per-row implementation made the same loose comparison. + # For typical group sizes (>100), the self-as-positive probability + # is <1% — functionally equivalent to `positive_cell_source="self"`. + pos_track_indices[i] = cands[rng.integers(len(cands))] - def _find_positive( - self, - anchor_row: pd.Series, - rng: np.random.Generator, - ) -> pd.Series | None: - """Find a positive sample for a given anchor. + return pos_track_indices + + def _sample_positive_indices_temporal(self, anchor_positions: list[int]) -> np.ndarray: + """Vectorized temporal positive lookup (lineage + tau range). - Dispatches to temporal or generic column-match lookup based on - ``positive_match_columns``. + Uses pre-computed NumPy caches instead of per-row pandas ``.iloc``. + Uses ``self._tau_range_frames_cache`` to avoid a registry call per anchor. Parameters ---------- - anchor_row : pd.Series - A single row from ``valid_anchors``. - rng : numpy.random.Generator - Random number generator for tau sampling and tie-breaking. + anchor_positions : list[int] + Positional indices into ``valid_anchors`` for the batch. Returns ------- - pd.Series or None - A track row for the positive, or ``None`` if no positive found. + np.ndarray + Positional indices into ``self.index.tracks``, one per anchor. """ - if "lineage_id" in self.positive_match_columns: - return self._find_temporal_positive(anchor_row, rng) - return self._find_column_match_positive(anchor_row, rng) + rng = self._rng + exp_arr = self._va_arrays["experiment"] + lid_arr = self._va_arrays["lineage_id"] + t_arr = self._va_arrays["t"] + tau_cache = self._tau_range_frames_cache + + # In from_index mode (flat parquet), we filter candidates to same marker. 
+ marker_filter = self._channel_mode == "from_index" + if marker_filter: + anchor_marker_arr = self._va_arrays["marker"] + tr_marker_arr = self._tr_arrays["marker"] + + pos_track_indices = np.empty(len(anchor_positions), dtype=np.int64) + lt_map = self._lineage_timepoints + + for i, ai in enumerate(anchor_positions): + # Coerce to str: _va_arrays columns come back as Categorical + # scalars after _materialize_strings, which hash differently + # from the str keys in _lineage_timepoints / _tau_range_frames_cache. + exp_name = str(exp_arr[ai]) + lineage_id = str(lid_arr[ai]) + anchor_t = int(t_arr[ai]) + + tau_min, tau_max = tau_cache[exp_name] + timepoints = lt_map.get((exp_name, lineage_id)) + if timepoints is None: + raise RuntimeError( + f"No positive found for anchor at position {ai} " + f"(experiment={exp_name}, lineage_id={lineage_id}, t={anchor_t}). " + "This anchor should have been filtered out by valid_anchors." + ) - def _find_temporal_positive( - self, - anchor_row: pd.Series, - rng: np.random.Generator, - ) -> pd.Series | None: - """Find a temporal positive: same lineage at ``t + tau``. + anchor_marker = anchor_marker_arr[ai] if marker_filter else None + chosen = _pick_temporal_candidate( + timepoints, + anchor_t, + tau_min, + tau_max, + self.tau_decay_rate, + rng, + tr_marker_arr if marker_filter else None, + anchor_marker, + ) + if chosen is None: + raise RuntimeError( + f"No positive found for anchor at position {ai} " + f"(experiment={exp_name}, lineage_id={lineage_id}, t={anchor_t}). " + "This anchor should have been filtered out by valid_anchors." + ) + pos_track_indices[i] = chosen + + return pos_track_indices + + # ------------------------------------------------------------------ + # Patch extraction (tensorstore I/O) + # ------------------------------------------------------------------ + + def _get_position(self, store_path: str, fov_name: str): + """Get or create a cached Position object for the given FOV. 
+ + Cache is keyed by ``(store_path, fov_name)`` — critical for OPS + where the same FOV name (e.g. ``"A/3/0"``) appears across multiple + experiments. Parameters ---------- - anchor_row : pd.Series - A single row from ``valid_anchors``. - rng : numpy.random.Generator - Random number generator for tau sampling and tie-breaking. + store_path : str + Path to the OME-Zarr plate store. + fov_name : str + FOV name (e.g. ``"A/1/0"``). Returns ------- - pd.Series or None - A track row for the positive, or ``None`` if no positive found. + iohub.ngff.Position """ - exp_name = anchor_row["experiment"] - lineage_id = anchor_row["lineage_id"] - anchor_t = anchor_row["t"] - - tau_min, tau_max = self.index.registry.tau_range_frames(exp_name, self.tau_range_hours) - - lt_key = (exp_name, lineage_id) - lt_map = self._lineage_timepoints.get(lt_key) - if lt_map is None: - return None - - # In from_index mode (flat parquet), filter candidates to same marker. - # NOTE:The parquet SHOULD guarantee one channel_name per marker per experiment, - # so marker filtering is equivalent to channel_name filtering. 
- anchor_marker = anchor_row.get("marker") if self._channel_mode == "from_index" else None - - def _pick(candidate_indices: list[int]) -> pd.Series | None: - if not candidate_indices: - return None - if anchor_marker is not None: - filtered = [ - idx for idx in candidate_indices if self.index.tracks.iloc[idx].get("marker") == anchor_marker - ] - if filtered: - candidate_indices = filtered - chosen_idx = candidate_indices[rng.integers(len(candidate_indices))] - return self.index.tracks.iloc[chosen_idx] - - # Try sampled tau first, then scan full range as fallback - sampled_tau = sample_tau(tau_min, tau_max, rng, self.tau_decay_rate) - target_t = anchor_t + sampled_tau - result = _pick(lt_map.get(target_t, [])) - if result is not None: - return result - - for tau in range(tau_min, tau_max + 1): - if tau == 0: - continue - result = _pick(lt_map.get(anchor_t + tau, [])) - if result is not None: - return result + key = (store_path, fov_name) + if key not in self._position_cache: + if store_path not in self._store_cache: + self._store_cache[store_path] = open_ome_zarr( + store_path, + mode="r", + implementation="tensorstore", + implementation_config=self.index.tensorstore_config, + ) + plate = self._store_cache[store_path] + self._position_cache[key] = plate[fov_name] + return self._position_cache[key] - return None + def _get_tensorstore(self, store_path: str, fov_name: str) -> "ts.TensorStore": + """Get or create a cached tensorstore object for the given FOV. - def _find_column_match_positive( - self, - anchor_row: pd.Series, - rng: np.random.Generator, - ) -> pd.Series | None: - """Find a positive by matching column values, excluding the anchor itself. + Cache is keyed by ``(store_path, fov_name)`` — critical for OPS + where the same FOV name appears across multiple experiments. Parameters ---------- - anchor_row : pd.Series - A single row from ``valid_anchors``. - rng : numpy.random.Generator - Random number generator for tie-breaking. 
+ store_path : str + Path to the OME-Zarr plate store. + fov_name : str + FOV name used together with ``store_path`` as cache key. Returns ------- - pd.Series or None - A track row for the positive, or ``None`` if no candidates found. + ts.TensorStore """ - cols = self.positive_match_columns - key = tuple(anchor_row[c] for c in cols) - all_candidates = self._match_lookup.get(key, []) - # Exclude the anchor row itself by integer index - candidates = [i for i in all_candidates if i != anchor_row.name] - if not candidates: - return None - chosen_idx = candidates[rng.integers(len(candidates))] - return self.index.tracks.iloc[chosen_idx] + key = (store_path, fov_name) + if key not in self._tensorstores: + position = self._get_position(store_path, fov_name) + self._tensorstores[key] = position["0"].native + return self._tensorstores[key] - # ------------------------------------------------------------------ - # Patch extraction (tensorstore I/O) - # ------------------------------------------------------------------ + def _build_norm_meta( + self, + arrays: dict[str, np.ndarray], + idx: int, + forced_channel_names: list[str] | None, + ) -> NormMeta | None: + """Build per-sample normalization metadata from parquet columns. - def _get_tensorstore(self, position, fov_name: str) -> "ts.TensorStore": - """Get or create a cached tensorstore object for the given FOV. + When the parquet has ``norm_mean`` / ``norm_std`` columns (written by + ``preprocess-cell-index``), reads stats directly from the cached + NumPy arrays — no zarr zattrs access and no pandas Series construction. + Falls back to zarr zattrs for old parquets. Parameters ---------- - position : iohub.ngff.Position - Position object from the OME-Zarr store. - fov_name : str - FOV name used as cache key. + arrays : dict[str, np.ndarray] + Pre-cached NumPy column arrays (``_va_arrays`` or ``_tr_arrays``). + idx : int + Positional row index into ``arrays``. 
+ forced_channel_names : list[str] or None + Zarr channel names being read for this sample. Returns ------- - ts.TensorStore + NormMeta or None """ - if fov_name not in self._tensorstores: - self._tensorstores[fov_name] = position["0"].native - return self._tensorstores[fov_name] + # Parquet path: norm columns present and value is not NA + norm_mean_arr = arrays.get("norm_mean") + if norm_mean_arr is not None: + norm_mean = norm_mean_arr[idx] + if norm_mean is not None and not (isinstance(norm_mean, float) and np.isnan(norm_mean)): + tp_stats = { + "mean": torch.tensor(norm_mean, dtype=torch.float32), + "std": torch.tensor(arrays["norm_std"][idx], dtype=torch.float32), + "median": torch.tensor(arrays["norm_median"][idx], dtype=torch.float32), + "iqr": torch.tensor(arrays["norm_iqr"][idx], dtype=torch.float32), + } + if self._channel_mode == "from_index": + return {"channel_0": {"timepoint_statistics": tp_stats}} + else: + ch_arr = arrays.get("channel_name") + ch_name = ch_arr[idx] if ch_arr is not None else "channel_0" + return {ch_name: {"timepoint_statistics": tp_stats}} + + # Fallback: read from zarr zattrs (old parquets without norm columns) + store_path = arrays["store_path"][idx] + fov_name = arrays["fov_name"][idx] + t = arrays["t"][idx] + cache_key = (store_path, fov_name) + if cache_key not in self._norm_meta_cache: + position = self._get_position(store_path, fov_name) + self._norm_meta_cache[cache_key] = _read_norm_meta(position) + cached = self._norm_meta_cache[cache_key] + if cached is None: + return None + raw_norm_meta = {} + for ch, ch_meta in cached.items(): + resolved = {} + for level, level_stats in ch_meta.items(): + if level == "timepoint_statistics" and isinstance(level_stats, dict): + resolved[level] = level_stats.get(str(t)) + else: + resolved[level] = level_stats + raw_norm_meta[ch] = resolved + if forced_channel_names is not None and self._channel_mode == "from_index": + ch = forced_channel_names[0] + if ch in raw_norm_meta: + return 
{"channel_0": raw_norm_meta[ch]} + return None + if forced_channel_names is not None and self._channel_mode == "fixed": + raw_norm_meta = {name: raw_norm_meta[name] for name in forced_channel_names if name in raw_norm_meta} + return raw_norm_meta or None + return raw_norm_meta def _slice_patch( - self, track_row: pd.Series, forced_channel_names: list[str] | None = None + self, + arrays: dict[str, np.ndarray], + idx: int, + forced_channel_names: list[str] | None = None, ) -> tuple[ "ts.TensorStore", NormMeta | None, @@ -530,8 +784,10 @@ def _slice_patch( Parameters ---------- - track_row : pd.Series - A single row from ``tracks`` or ``valid_anchors``. + arrays : dict[str, np.ndarray] + Pre-cached NumPy column arrays (``_va_arrays`` or ``_tr_arrays``). + idx : int + Positional row index into ``arrays``. forced_channel_names : list[str] or None Zarr channel names to read. When provided, only these channels are sliced from the zarr. None reads all channels. @@ -543,15 +799,15 @@ def _slice_patch( scale factors ``(scale_z, scale_y, scale_x)``, and target size ``(z_window, patch_h, patch_w)``. """ - position = track_row["position"] - fov_name = track_row["fov_name"] - exp_name = track_row["experiment"] + store_path = arrays["store_path"][idx] + fov_name = arrays["fov_name"][idx] + exp_name = arrays["experiment"][idx] - image = self._get_tensorstore(position, fov_name) + image = self._get_tensorstore(store_path, fov_name) - t = track_row["t"] - y_center = int(track_row["y_clamp"]) - x_center = int(track_row["x_clamp"]) + t = int(arrays["t"][idx]) + y_center = int(arrays["y_clamp"][idx]) + x_center = int(arrays["x_clamp"][idx]) # Per-experiment scale factors for physical-space normalization scale_z, scale_y, scale_x = self.index.registry.scale_factors[exp_name] @@ -581,37 +837,8 @@ def _slice_patch( slice(x_center - x_half, x_center + x_half), ] - # Look up norm_meta by zarr channel name directly - # and pre-resolve timepoint_statistics for this sample's timepoint. 
- # Cache the tensor-converted norm_meta per FOV to avoid repeated - # zattrs reads. Build a shallow per-sample copy (dict structure only, - # tensors shared) since we only replace dict entries, not tensor values. - cache_key = (track_row["store_path"], fov_name) - if cache_key not in self._norm_meta_cache: - self._norm_meta_cache[cache_key] = _read_norm_meta(position) - cached = self._norm_meta_cache[cache_key] - if cached is not None: - raw_norm_meta = {ch: {level: stats for level, stats in ch_meta.items()} for ch, ch_meta in cached.items()} - # Pre-resolve timepoint_statistics for all channels - for ch_name, ch_meta in raw_norm_meta.items(): - if "timepoint_statistics" in ch_meta: - tp_stats = ch_meta["timepoint_statistics"].get(str(t)) - ch_meta["timepoint_statistics"] = tp_stats - else: - raw_norm_meta = None - if raw_norm_meta is not None: - # Filter to requested channels - if forced_channel_names is not None and self._channel_mode == "from_index": - ch = forced_channel_names[0] - if ch in raw_norm_meta: - raw_norm_meta = {"channel_0": raw_norm_meta[ch]} - else: - raw_norm_meta = None - elif forced_channel_names is not None and self._channel_mode == "fixed": - raw_norm_meta = {name: raw_norm_meta[name] for name in forced_channel_names if name in raw_norm_meta} - if not raw_norm_meta: - raw_norm_meta = None - # else: "all" mode — keep full raw_norm_meta + # Build norm_meta from parquet columns (preferred) or zarr zattrs (fallback). + raw_norm_meta = self._build_norm_meta(arrays, idx, forced_channel_names) # Use the configured extraction window as uniform target Z, # not the per-experiment capped range. This ensures all patches @@ -628,15 +855,18 @@ def _slice_patch( def _slice_patches( self, - track_rows: pd.DataFrame, + arrays: dict[str, np.ndarray], + indices: list[int] | np.ndarray, forced_channel_names: list[list[str]] | None = None, ) -> tuple[torch.Tensor, list[NormMeta | None]]: """Slice and stack patches for multiple track rows. 
Parameters ---------- - track_rows : pd.DataFrame - Multiple rows from ``tracks`` / ``valid_anchors``. + arrays : dict[str, np.ndarray] + Pre-cached NumPy column arrays (``_va_arrays`` or ``_tr_arrays``). + indices : list[int] or np.ndarray + Positional row indices into ``arrays``. forced_channel_names : list[list[str]] or None Per-sample zarr channel names to read. Each inner list contains the channel names for that sample. @@ -651,9 +881,9 @@ def _slice_patches( norms = [] scales = [] targets = [] - for i, (_, row) in enumerate(track_rows.iterrows()): + for i, idx in enumerate(indices): forced = forced_channel_names[i] if forced_channel_names is not None else None - patch, norm, scale, target = self._slice_patch(row, forced_channel_names=forced) + patch, norm, scale, target = self._slice_patch(arrays, int(idx), forced_channel_names=forced) patches.append(patch) norms.append(norm) scales.append(scale) @@ -665,13 +895,31 @@ def _slice_patches( for i, p in enumerate(patches): shape_groups[tuple(p.shape)].append(i) read_tensors: list[Tensor | None] = [None] * len(patches) - for idxs in shape_groups.values(): - group_patches = [patches[i] for i in idxs] - group_result = ts.stack([p.translate_to[0] for p in group_patches]).read().result() # noqa: PD013 + # Issue every shape group's read inside one ts.Batch() so the C++ + # executor can overlap them; only block on .result() after all are + # dispatched. With multiple shape groups (mixed-experiment batches), + # this lets tensorstore schedule reads concurrently instead of one + # group at a time. 
+ pending: list[tuple[list[int], "ts.Future"]] = [] + with ts.Batch(): + for idxs in shape_groups.values(): + group_patches = [patches[i] for i in idxs] + fut = ts.stack([p.translate_to[0] for p in group_patches]).read() # noqa: PD013 + pending.append((idxs, fut)) + for idxs, fut in pending: + group_result = fut.result() for j, idx in enumerate(idxs): read_tensors[idx] = torch.from_numpy(group_result[j]) # Rescale each patch to the uniform target size rescaled = [] for i in range(len(patches)): rescaled.append(_rescale_patch(read_tensors[i], scales[i], targets[i])) + channel_counts = {t.shape[0] for t in rescaled} + if len(channel_counts) > 1: + raise RuntimeError( + f"Batch mixes samples with different channel counts: {sorted(channel_counts)}. " + "This happens with channels_per_sample=None across experiments that have " + "different channel counts. Set channels_per_sample=1 (bag-of-channels) " + "or channels_per_sample=[...] (fixed channel list)." + ) return torch.stack(rescaled), norms diff --git a/applications/dynaclr/src/dynaclr/data/experiment.py b/applications/dynaclr/src/dynaclr/data/experiment.py index 96187cafa..8134f7d54 100644 --- a/applications/dynaclr/src/dynaclr/data/experiment.py +++ b/applications/dynaclr/src/dynaclr/data/experiment.py @@ -12,6 +12,7 @@ from dataclasses import dataclass, field from pathlib import Path +import pandas as pd from iohub.ngff import open_ome_zarr from viscy_data.cell_index import read_cell_index @@ -96,27 +97,26 @@ def __post_init__(self) -> None: # noqa: D105 # Build name -> config map self._name_map = {e.name: e for e in experiments} - # Per-experiment validations + # Per-experiment validation + z-range resolution (single zarr open each) + z_extract = self.z_extraction_window or self.z_window + z_ranges: dict[str, tuple[int, int]] = {} + for exp in experiments: - # 4. 
Negative interval if exp.interval_minutes < 0: raise ValueError( f"Experiment '{exp.name}': interval_minutes must be non-negative, got {exp.interval_minutes}." ) - - # 5. Empty perturbation_wells if not exp.perturbation_wells: raise ValueError(f"Experiment '{exp.name}': perturbation_wells must not be empty.") - - # 6. data_path existence if not Path(exp.data_path).exists(): raise ValueError(f"Experiment '{exp.name}': data_path does not exist: {exp.data_path}") - # 7. Zarr channel validation — selected channels must exist in zarr with open_ome_zarr(exp.data_path, mode="r") as plate: first_position = next(iter(plate.positions()))[1] zarr_channels = list(first_position.channel_names) - # Store the full zarr channel list for index resolution + z_total = first_position["0"].shape[2] + focus_data = plate.zattrs.get("focus_slice", {}) + exp.channel_names = zarr_channels missing_channels = [ch.name for ch in exp.channels if ch.name not in zarr_channels] if missing_channels: @@ -125,16 +125,52 @@ def __post_init__(self) -> None: # noqa: D105 f"not found in zarr. Available: {zarr_channels}." 
) - # Resolve per-experiment z_ranges - self.z_ranges = self._resolve_z_ranges() + # Z-range resolution + if z_extract is None: + z_ranges[exp.name] = (0, z_total) + else: + focus_ch = self.focus_channel or (exp.channels[0].name if exp.channels else None) + ch_focus = focus_data.get(focus_ch, {}) if focus_ch else {} + ds_stats = ch_focus.get("dataset_statistics", {}) + z_focus_mean = ds_stats.get("z_focus_mean") + + z_center = int(round(z_focus_mean)) if z_focus_mean is not None else z_total // 2 + effective_extract = min(z_extract, z_total) + z_below = int(effective_extract * self.z_focus_offset) + z_start = max(0, z_center - z_below) + z_end = min(z_total, z_start + effective_extract) + z_start = max(0, z_end - effective_extract) + + z_ranges[exp.name] = (z_start, z_end) + _logger.info( + "Experiment '%s': z_range=(%d, %d), z_total=%d, z_extraction_window=%d", + exp.name, + z_start, + z_end, + z_total, + effective_extract, + ) + + # Validate extraction windows >= z_window + if self.z_window is not None and z_ranges: + for name, (z_s, z_e) in z_ranges.items(): + if z_e - z_s < self.z_window: + raise ValueError( + f"Experiment '{name}': extraction range ({z_e - z_s}) " + f"< z_window ({self.z_window}). Increase z_extraction_window " + f"or reduce z_window." 
+ ) + self.z_ranges = z_ranges # Validate pixel sizes and compute scale factors - if self.reference_pixel_size_xy_um is not None or self.reference_pixel_size_z_um is not None: - missing = [e.name for e in experiments if e.pixel_size_xy_um is None or e.pixel_size_z_um is None] + if self.reference_pixel_size_xy_um is not None: + missing = [e.name for e in experiments if e.pixel_size_xy_um is None] if missing: - raise ValueError( - f"reference_pixel_size set but experiments are missing pixel_size_xy_um/z_um: {missing}" - ) + raise ValueError(f"reference_pixel_size_xy_um set but experiments missing pixel_size_xy_um: {missing}") + if self.reference_pixel_size_z_um is not None: + missing = [e.name for e in experiments if e.pixel_size_z_um is None] + if missing: + raise ValueError(f"reference_pixel_size_z_um set but experiments missing pixel_size_z_um: {missing}") self.scale_factors = self._compute_scale_factors() @property @@ -158,72 +194,6 @@ def source_channel_labels(self) -> list[str]: # Internal helpers # ------------------------------------------------------------------ - def _resolve_z_ranges(self) -> dict[str, tuple[int, int]]: - """Resolve per-experiment Z extraction ranges. - - When ``z_extraction_window`` is set, extracts a larger Z range - centered on ``z_focus_mean`` (capped by the available Z depth). - The random crop from extraction size to ``z_window`` happens later - in ``on_after_batch_transfer``. - - Falls back to ``z_window`` when ``z_extraction_window`` is None. 
- """ - experiments = self.collection.experiments - z_ranges: dict[str, tuple[int, int]] = {} - z_extract = self.z_extraction_window or self.z_window - - for exp in experiments: - focus_ch = self.focus_channel or (exp.channels[0].name if exp.channels else None) - - with open_ome_zarr(exp.data_path, mode="r") as plate: - first_pos = next(iter(plate.positions()))[1] - z_total = first_pos["0"].shape[2] - - if z_extract is None: - z_ranges[exp.name] = (0, z_total) - continue - - focus_data = plate.zattrs.get("focus_slice", {}) - ch_focus = focus_data.get(focus_ch, {}) if focus_ch else {} - ds_stats = ch_focus.get("dataset_statistics", {}) - z_focus_mean = ds_stats.get("z_focus_mean") - - if z_focus_mean is None: - z_center = z_total // 2 - else: - z_center = int(round(z_focus_mean)) - - # Cap extraction window by available Z depth. - # z_focus_offset controls asymmetry: 0.5 = symmetric, - # 0.3 = 30% below focus, 70% above (cells on coverslip). - effective_extract = min(z_extract, z_total) - z_below = int(effective_extract * self.z_focus_offset) - z_start = max(0, z_center - z_below) - z_end = min(z_total, z_start + effective_extract) - z_start = max(0, z_end - effective_extract) - - z_ranges[exp.name] = (z_start, z_end) - _logger.info( - "Experiment '%s': z_range=(%d, %d), z_total=%d, z_extraction_window=%d", - exp.name, - z_start, - z_end, - z_total, - effective_extract, - ) - - # Validate: all extraction windows must be >= z_window - if self.z_window is not None and z_ranges: - for name, (z_s, z_e) in z_ranges.items(): - if z_e - z_s < self.z_window: - raise ValueError( - f"Experiment '{name}': extraction range ({z_e - z_s}) " - f"< z_window ({self.z_window}). Increase z_extraction_window " - f"or reduce z_window." - ) - - return z_ranges - def _compute_scale_factors(self) -> dict[str, tuple[float, float, float]]: """Compute per-experiment scale factors for physical-space normalization. 
@@ -237,18 +207,15 @@ def _compute_scale_factors(self) -> dict[str, tuple[float, float, float]]: """ scale_factors: dict[str, tuple[float, float, float]] = {} for exp in self.collection.experiments: - if ( - self.reference_pixel_size_xy_um is not None - and self.reference_pixel_size_z_um is not None - and exp.pixel_size_xy_um is not None - and exp.pixel_size_z_um is not None - ): + if self.reference_pixel_size_xy_um is not None and exp.pixel_size_xy_um is not None: scale_y = self.reference_pixel_size_xy_um / exp.pixel_size_xy_um scale_x = self.reference_pixel_size_xy_um / exp.pixel_size_xy_um - scale_z = self.reference_pixel_size_z_um / exp.pixel_size_z_um else: scale_y = 1.0 scale_x = 1.0 + if self.reference_pixel_size_z_um is not None and exp.pixel_size_z_um is not None: + scale_z = self.reference_pixel_size_z_um / exp.pixel_size_z_um + else: scale_z = 1.0 scale_factors[exp.name] = (scale_z, scale_y, scale_x) return scale_factors @@ -313,7 +280,7 @@ def from_cell_index( focus_channel: str | None = None, reference_pixel_size_xy_um: float | None = None, reference_pixel_size_z_um: float | None = None, - ) -> ExperimentRegistry: + ) -> tuple["ExperimentRegistry", "pd.DataFrame"]: """Build a registry from a flat cell index parquet and zarr metadata. Derives per-experiment channels from the parquet's ``marker`` and @@ -339,32 +306,24 @@ def from_cell_index( Returns ------- - ExperimentRegistry - Validated registry of experiments. + tuple[ExperimentRegistry, pd.DataFrame] + Validated registry of experiments and the raw cell index DataFrame. """ df = read_cell_index(cell_index_path) if df.empty: raise ValueError(f"Cell index is empty: {cell_index_path}") - # Step 1: Read channel names per (store_path, well) from zarr. - channel_names_cache: dict[tuple[str, str], list[str]] = {} - store_cache: dict[str, object] = {} + # Step 1: Read channel names per store from a single FOV. 
+ # Channel names are uniform across all positions in a plate, + # so we open one FOV directly (store_path/well/fov) instead of + # iterating all positions. + channel_names_cache: dict[str, list[str]] = {} for store_path, group in df.groupby("store_path"): - plate = open_ome_zarr(str(store_path), mode="r") - store_cache[str(store_path)] = plate - for well in group["well"].unique(): - # Find one position in this well - well_str = str(well) - for pos_path, pos in plate.positions(): - if pos_path.startswith(well_str + "/"): - channel_names_cache[(str(store_path), well_str)] = list(pos.channel_names) - break - - # Close all opened stores - for plate in store_cache.values(): - if hasattr(plate, "close"): - plate.close() + first = group.iloc[0] + fov_path = f"{store_path}/{first['well']}/{first['fov']}" + with open_ome_zarr(fov_path, mode="r") as pos: + channel_names_cache[str(store_path)] = list(pos.channel_names) # Step 2: Derive per-experiment channels from flat (marker, channel_name) columns. exp_channels: dict[str, list[ChannelEntry]] = defaultdict(list) @@ -381,14 +340,7 @@ def from_cell_index( for exp_name, exp_group in df.groupby("experiment"): exp_name = str(exp_name) store_path = str(exp_group["store_path"].iloc[0]) - first_well = str(exp_group["well"].iloc[0]) - - channel_names = channel_names_cache.get((store_path, first_well)) - if channel_names is None: - raise ValueError( - f"Experiment '{exp_name}': could not read channel names from zarr " - f"(store_path={store_path}, well={first_well})." 
- ) + channel_names = channel_names_cache[store_path] # Derive perturbation_wells from parquet perturbation_wells: dict[str, list[str]] = defaultdict(list) @@ -453,7 +405,7 @@ def from_cell_index( experiments=experiments, ) - return cls( + registry = cls( collection=collection, z_window=z_window, z_extraction_window=z_extraction_window, @@ -462,6 +414,7 @@ def from_cell_index( reference_pixel_size_xy_um=reference_pixel_size_xy_um, reference_pixel_size_z_um=reference_pixel_size_z_um, ) + return registry, df def subset(self, experiment_names: list[str]) -> ExperimentRegistry: """Create a new registry with a subset of experiments. diff --git a/applications/dynaclr/src/dynaclr/data/index.py b/applications/dynaclr/src/dynaclr/data/index.py index 177d747ba..ddff9168a 100644 --- a/applications/dynaclr/src/dynaclr/data/index.py +++ b/applications/dynaclr/src/dynaclr/data/index.py @@ -15,7 +15,7 @@ import numpy as np import pandas as pd from iohub.core.config import TensorStoreConfig -from iohub.ngff import Plate, Position, open_ome_zarr +from iohub.ngff import Plate, open_ome_zarr from dynaclr.data.experiment import ExperimentRegistry from viscy_data.cell_index import read_cell_index @@ -185,10 +185,12 @@ def __init__( include_wells: list[str] | None = None, exclude_fovs: list[str] | None = None, cell_index_path: str | Path | None = None, + cell_index_df: pd.DataFrame | None = None, num_workers: int = 1, positive_cell_source: str = "lookup", positive_match_columns: list[str] | None = None, max_border_shift: int = -1, + fit: bool = True, tensorstore_config: TensorStoreConfig | None = None, ) -> None: self.registry = registry @@ -217,44 +219,53 @@ def __init__( else: all_exclude_fovs = None - if cell_index_path is not None: - _logger.info("Loading cell index from parquet: %s", cell_index_path) - tracks = read_cell_index(cell_index_path) - tracks = self._align_parquet_columns(tracks) + if cell_index_df is not None or cell_index_path is not None: + if cell_index_df is not 
None: + _logger.info( + "Using pre-loaded cell index DataFrame (%d rows)", + len(cell_index_df), + ) + tracks = self._align_parquet_columns(cell_index_df.copy()) + else: + _logger.info("Loading cell index from parquet: %s", cell_index_path) + tracks = read_cell_index(cell_index_path) + tracks = self._align_parquet_columns(tracks) if include_wells is not None: tracks = tracks[tracks["well_name"].isin(include_wells)].copy() if all_exclude_fovs is not None: tracks = tracks[~tracks["fov_name"].isin(all_exclude_fovs)].copy() tracks = self._filter_to_registry_experiments(tracks) - positions, tracks = self._resolve_positions_and_dims(tracks) - self.positions = positions + tracks = self._resolve_dims(tracks) # lineage_id already present from build step — skip _reconstruct_lineage - tracks = self._filter_empty_frames(tracks) + # Empty frames already filtered at parquet build time — skip _filter_empty_frames else: all_tracks = self._load_all_experiments( - include_wells=include_wells, exclude_fovs=all_exclude_fovs, num_workers=num_workers + include_wells=include_wells, + exclude_fovs=all_exclude_fovs, + num_workers=num_workers, ) tracks = pd.concat(all_tracks, ignore_index=True) if all_tracks else pd.DataFrame() tracks = self._reconstruct_lineage(tracks) - positions, tracks = self._resolve_positions_and_dims(tracks) - self.positions = positions - tracks = self._filter_empty_frames(tracks) + tracks = self._resolve_dims(tracks) tracks = self._clamp_borders(tracks) self.tracks = tracks.reset_index(drop=True) - self.valid_anchors = self._compute_valid_anchors( - tau_range_hours, - positive_cell_source=positive_cell_source, - positive_match_columns=positive_match_columns, - ) - if self.valid_anchors.empty and not self.tracks.empty: - raise ValueError( - f"No valid anchors found from {len(self.tracks)} tracks. " - f"positive_cell_source={positive_cell_source!r}, " - f"positive_match_columns={positive_match_columns!r}, " - f"tau_range_hours={tau_range_hours}. 
" - "Check that tracks have matching positives under these settings." + if fit: + self.valid_anchors = self._compute_valid_anchors( + tau_range_hours, + positive_cell_source=positive_cell_source, + positive_match_columns=positive_match_columns, ) + if self.valid_anchors.empty and not self.tracks.empty: + raise ValueError( + f"No valid anchors found from {len(self.tracks)} tracks. " + f"positive_cell_source={positive_cell_source!r}, " + f"positive_match_columns={positive_match_columns!r}, " + f"tau_range_hours={tau_range_hours}. " + "Check that tracks have matching positives under these settings." + ) + else: + self.valid_anchors = self.tracks # ------- internal methods ------- @@ -344,6 +355,15 @@ def _align_parquet_columns(tracks: pd.DataFrame) -> pd.DataFrame: ) if "microscope" not in tracks.columns: tracks["microscope"] = "" + # Cast low-cardinality string columns to Categorical to make + # downstream boolean-mask slicing (train/val split) a fast int-code + # gather instead of a pyarrow.compute.take over Arrow string buffers. + # Deferred from read_cell_index because ``fov_name`` is rewritten by + # the prefix logic above and Categorical columns don't support string + # concatenation. + for col in ("fov_name", "well_name"): + if col in tracks.columns and tracks[col].dtype == object: + tracks[col] = tracks[col].astype("category") return tracks def _filter_to_registry_experiments(self, tracks: pd.DataFrame) -> pd.DataFrame: @@ -351,22 +371,30 @@ def _filter_to_registry_experiments(self, tracks: pd.DataFrame) -> pd.DataFrame: registry_names = {exp.name for exp in self.registry.experiments} return tracks[tracks["experiment"].isin(registry_names)].copy() - def _resolve_positions_and_dims(self, tracks: pd.DataFrame) -> tuple[list[Position], pd.DataFrame]: - """Open zarr stores for unique (store_path, fov_name) pairs. + def _resolve_dims(self, tracks: pd.DataFrame) -> pd.DataFrame: + """Attach image dimensions to tracks for border clamping. 
- Attaches ``position``, ``_img_height``, ``_img_width`` columns to - *tracks* and returns the list of resolved Position objects. + When the parquet has ``Y_shape`` / ``X_shape`` columns (built with the + latest ``build_timelapse_cell_index``), reads dimensions directly — no + zarr opens needed. Falls back to opening stores when the columns are + missing (old parquets). """ - all_positions: list[Position] = [] - pos_lookup: dict[tuple[str, str], Position] = {} - dim_lookup: dict[tuple[str, str], tuple[int, int]] = {} - if tracks.empty: - tracks["position"] = pd.Series(dtype=object) tracks["_img_height"] = pd.Series(dtype=int) tracks["_img_width"] = pd.Series(dtype=int) - return all_positions, tracks + return tracks + if "Y_shape" in tracks.columns and "X_shape" in tracks.columns: + tracks["_img_height"] = tracks["Y_shape"] + tracks["_img_width"] = tracks["X_shape"] + return tracks + + _logger.warning( + "Parquet missing Y_shape/X_shape columns. Falling back to opening " + "zarr stores for image dimensions. Rebuild the parquet with " + "`build-cell-index` for faster startup." + ) + dim_lookup: dict[tuple[str, str], tuple[int, int]] = {} for (store_path, well_name, fov_name), _group in tracks.groupby(["store_path", "well_name", "fov_name"]): if store_path not in self._store_cache: self._store_cache[store_path] = open_ome_zarr( @@ -376,60 +404,17 @@ def _resolve_positions_and_dims(self, tracks: pd.DataFrame) -> tuple[list[Positi implementation_config=self.tensorstore_config, ) plate = self._store_cache[store_path] - # fov_name may be just the FOV id (e.g. "000000") or the full - # position path (e.g. "C/1/000000"). Prepend well_name when needed. 
if "/" in fov_name: position_path = fov_name else: position_path = f"{well_name}/{fov_name}" position = plate[position_path] - pos_lookup[(store_path, fov_name)] = position image = position["0"] dim_lookup[(store_path, fov_name)] = (image.height, image.width) - all_positions.append(position) - tracks["position"] = [pos_lookup[(sp, fn)] for sp, fn in zip(tracks["store_path"], tracks["fov_name"])] tracks["_img_height"] = [dim_lookup[(sp, fn)][0] for sp, fn in zip(tracks["store_path"], tracks["fov_name"])] tracks["_img_width"] = [dim_lookup[(sp, fn)][1] for sp, fn in zip(tracks["store_path"], tracks["fov_name"])] - - return all_positions, tracks - - @staticmethod - def _filter_empty_frames(tracks: pd.DataFrame) -> pd.DataFrame: - """Remove rows whose image frame is all zeros (missing acquisition). - - For each unique (store_path, fov_name, t) combination, reads a small - center crop of channel 0 to detect empty frames. Rows with an all-zero - frame are dropped. - """ - if tracks.empty or "t" not in tracks.columns: - return tracks - - valid_mask = pd.Series(True, index=tracks.index) - - for (store_path, fov_name), group in tracks.groupby(["store_path", "fov_name"]): - pos = group["position"].iloc[0] - image = pos["0"] - h, w = image.shape[-2], image.shape[-1] - cy, cx = h // 2, w // 2 - crop = 16 # 32x32 center crop is enough to detect empty frames - - for t in group["t"].unique(): - try: - patch = np.asarray(image[int(t), 0, :, cy - crop : cy + crop, cx - crop : cx + crop]) - if patch.max() == 0: - row_mask = ( - (tracks["store_path"] == store_path) & (tracks["fov_name"] == fov_name) & (tracks["t"] == t) - ) - valid_mask[row_mask] = False - except Exception: - pass # if we can't read, keep the row - - n_dropped = (~valid_mask).sum() - if n_dropped > 0: - _logger.info("Excluded %d observations from empty frames", n_dropped) - - return tracks[valid_mask].copy() + return tracks @staticmethod def _reconstruct_lineage(tracks: pd.DataFrame) -> pd.DataFrame: @@ -538,7 
+523,11 @@ def _clamp_borders(self, tracks: pd.DataFrame) -> pd.DataFrame: n_dropped = n_before - len(tracks) if n_dropped > 0: - _logger.info("Excluded %d border cells (%.1f%%)", n_dropped, 100 * n_dropped / n_before) + _logger.info( + "Excluded %d border cells (%.1f%%)", + n_dropped, + 100 * n_dropped / n_before, + ) tracks = tracks.drop(columns=["_img_height", "_img_width"]) @@ -591,33 +580,43 @@ def _compute_valid_anchors( # Temporal mode: keep only anchors that have a positive at t+tau. # For each experiment, check whether (lineage_id, t+tau) exists - # for any tau in [min_f, max_f] (excluding 0). + # for any tau in [min_f, max_f] (excluding 0). In flat-parquet + # mode (one row per cell × channel), the dataset restricts + # candidates to the same marker at t+tau, so ``marker`` must be + # part of the match key here. Otherwise an anchor at (lid, marker=A, t) + # could pass validation because (lid, marker=B, t+1) exists, but + # fail at sample time because no (lid, marker=A, t+1) exists. + filter_by_marker = "marker" in self.tracks.columns + key_cols = ["lineage_id", "marker", "t"] if filter_by_marker else ["lineage_id", "t"] valid_mask = np.zeros(len(self.tracks), dtype=bool) for exp in self.registry.experiments: min_f, max_f = self.registry.tau_range_frames(exp.name, tau_range_hours) - exp_mask = self.tracks["experiment"].to_numpy() == exp.name - exp_indices = np.where(exp_mask)[0] - if len(exp_indices) == 0: + exp_mask = self.tracks["experiment"] == exp.name + exp_df = self.tracks.loc[exp_mask, key_cols] + if exp_df.empty: continue - lineage_ids = self.tracks["lineage_id"].to_numpy()[exp_indices] - t_values = self.tracks["t"].to_numpy()[exp_indices] - existing_pairs: set[tuple] = set(zip(lineage_ids, t_values)) + taus = [tau for tau in range(min_f, max_f + 1) if tau != 0] + + # Unique key tuples as a MultiIndex for O(1) isin checks. 
+ existing = exp_df.drop_duplicates() + existing_mi = pd.MultiIndex.from_frame(existing) - # Collect all anchor (lineage_id, t) that have any valid positive - valid_anchors: set[tuple] = set() - for tau in range(min_f, max_f + 1): - if tau == 0: - continue - for lid, t in existing_pairs: - if (lid, t + tau) in existing_pairs: - valid_anchors.add((lid, t)) + # For each unique anchor key, check if the shifted key (same + # lineage_id/marker, t+tau) exists for any tau. + found_any = np.zeros(len(existing), dtype=bool) + t_vals = existing["t"].to_numpy() + non_t_arrays = [existing[c].to_numpy() for c in key_cols if c != "t"] + for tau in taus: + shifted_arrays = non_t_arrays + [t_vals + tau] + targets = pd.MultiIndex.from_arrays(shifted_arrays) + found_any |= targets.isin(existing_mi) - # Mark matching rows - for i, idx in enumerate(exp_indices): - if (lineage_ids[i], t_values[i]) in valid_anchors: - valid_mask[idx] = True + # Map valid unique pairs back to all rows in the experiment. + valid_pairs_mi = pd.MultiIndex.from_frame(existing[found_any]) + row_keys = pd.MultiIndex.from_frame(exp_df) + valid_mask[exp_mask.to_numpy()] = row_keys.isin(valid_pairs_mi) return self.tracks[valid_mask].reset_index(drop=True) @@ -651,11 +650,13 @@ def clone_with_subset( positive_cell_source: str = "lookup", positive_match_columns: list[str] | None = None, max_border_shift: int = -1, + precomputed_valid_anchors: pd.DataFrame | None = None, ) -> "MultiExperimentIndex": """Create a shallow copy with a different tracks DataFrame. Reuses the parent's registry, positions, and store cache so no - zarr stores are re-opened. Recomputes ``valid_anchors``. + zarr stores are re-opened. Recomputes ``valid_anchors`` unless + ``precomputed_valid_anchors`` is provided. Parameters ---------- @@ -667,20 +668,27 @@ def clone_with_subset( Forwarded to ``_compute_valid_anchors``. max_border_shift : int Forwarded to ``self.max_border_shift``. -1 inherits from parent. 
+ precomputed_valid_anchors : pd.DataFrame | None + When provided, skip recomputing valid anchors. Pass the already- + filtered valid_anchors subset for this tracks_subset. Avoids + redundant O(N * tau_range) computation in FOV split mode. """ clone = object.__new__(MultiExperimentIndex) clone.registry = self.registry clone.yx_patch_size = self.yx_patch_size clone.tau_range_hours = self.tau_range_hours clone._store_cache = self._store_cache - clone.positions = self.positions + clone.tensorstore_config = self.tensorstore_config clone.max_border_shift = self.max_border_shift if max_border_shift < 0 else max_border_shift clone.tracks = tracks_subset.reset_index(drop=True) - clone.valid_anchors = clone._compute_valid_anchors( - tau_range_hours=self.tau_range_hours, - positive_cell_source=positive_cell_source, - positive_match_columns=positive_match_columns, - ) + if precomputed_valid_anchors is not None: + clone.valid_anchors = precomputed_valid_anchors.reset_index(drop=True) + else: + clone.valid_anchors = clone._compute_valid_anchors( + tau_range_hours=self.tau_range_hours, + positive_cell_source=positive_cell_source, + positive_match_columns=positive_match_columns, + ) if clone.valid_anchors.empty and not clone.tracks.empty: raise ValueError( f"No valid anchors found from {len(clone.tracks)} tracks in subset. " diff --git a/applications/dynaclr/src/dynaclr/data/preprocess_cell_index.py b/applications/dynaclr/src/dynaclr/data/preprocess_cell_index.py new file mode 100644 index 000000000..e49bc1c74 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/data/preprocess_cell_index.py @@ -0,0 +1,30 @@ +"""CLI command for preprocessing a cell index parquet (add norm stats, focus slice, remove empties).""" + +import click + +from viscy_data.cell_index import preprocess_cell_index + + +@click.command() +@click.argument("parquet_path") +@click.option( + "--output", + default=None, + help="Output path. 
Default: overwrite in place.", +) +@click.option( + "--focus-channel", + default=None, + help="Channel name for focus_slice lookup (e.g. Phase3D). Default: first channel per FOV.", +) +def main(parquet_path, output, focus_channel): + """Preprocess a cell index parquet: add normalization stats, focus slice, remove empty frames. + + Reads precomputed metadata from zarr zattrs and writes them as parquet + columns. Requires `viscy preprocess` to have been run on the zarr stores. + """ + preprocess_cell_index( + parquet_path=parquet_path, + output_path=output, + focus_channel=focus_channel, + ) diff --git a/applications/dynaclr/src/dynaclr/engine.py b/applications/dynaclr/src/dynaclr/engine.py index ead52a89b..924ee9b5c 100644 --- a/applications/dynaclr/src/dynaclr/engine.py +++ b/applications/dynaclr/src/dynaclr/engine.py @@ -225,11 +225,15 @@ def log_embedding_pca(self, embeddings: Tensor, meta: list[dict], tag: str, n_co plt.close(fig) def _get_labels(self, batch: TripletSample, batch_key: str) -> Tensor | None: - """Extract integer labels for a head from the batch. + """Extract labels (scalar or vector) for a head from the batch. Checks top-level batch keys first, then falls back to ``anchor_meta[i]["labels"][batch_key]`` for metadata-carried labels. Returns ``None`` if the key is not found in either location. + + Vector-valued entries (e.g. paired transcriptomic embeddings) are + stacked into a ``(B, D)`` float tensor; scalar entries are stacked + into a ``(B,)`` long tensor. 
""" if batch_key in batch: return batch[batch_key] @@ -237,6 +241,10 @@ def _get_labels(self, batch: TripletSample, batch_key: str) -> Tensor | None: if not meta or "labels" not in meta[0] or batch_key not in meta[0]["labels"]: return None vals = [m["labels"][batch_key] for m in meta] + first = vals[0] + if isinstance(first, (list, tuple, np.ndarray, Tensor)): + arr = np.asarray(vals, dtype=np.float32) + return torch.from_numpy(arr).to(self.device) return torch.tensor(vals, dtype=torch.long, device=self.device) def _run_auxiliary_heads(self, anchor_features: Tensor, batch: TripletSample, stage: str) -> Tensor: diff --git a/applications/dynaclr/src/dynaclr/evaluation/append_annotations.py b/applications/dynaclr/src/dynaclr/evaluation/append_annotations.py new file mode 100644 index 000000000..3729b649a --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/append_annotations.py @@ -0,0 +1,138 @@ +"""CLI for appending annotation columns to per-experiment AnnData zarr stores. + +Reads per-experiment annotation CSVs and writes task columns (e.g. infection_state, +organelle_state) directly into each zarr's obs. This persists ground truth labels +alongside the embeddings so downstream plots can color by annotation. + +Called as a step in the Nextflow evaluation pipeline after split-embeddings. +Annotation sources are shared with the linear_classifiers step config. 
+ +Usage +----- +dynaclr append-annotations -c append_annotations.yaml +""" + +from __future__ import annotations + +from pathlib import Path + +import anndata as ad +import click + +from dynaclr.evaluation.evaluate_config import AnnotationSource, TaskSpec +from viscy_utils.cli_utils import load_config +from viscy_utils.evaluation.annotation import load_annotation_anndata +from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr + + +def append_annotations( + embeddings_path: Path, + annotations: list[AnnotationSource], + tasks: list[TaskSpec], +) -> None: + """Append annotation columns to per-experiment zarr obs. + + For each experiment in ``annotations``, loads the matching per-experiment + zarr, joins task columns from the annotation CSV, and persists the + updated obs back to zarr. + + When ``tasks`` is empty, auto-discovers task columns from the + annotation CSV (every column except the join keys ``fov_name``, ``t``, + ``track_id``, ``id``). This supports Wave-2 datasets that publish + annotations independently of any LC training task list. + + Parameters + ---------- + embeddings_path : Path + Directory containing per-experiment zarrs named ``{experiment}.zarr``. + annotations : list[AnnotationSource] + Per-experiment annotation CSV sources. Each entry maps an experiment + name to a CSV path with task columns. + tasks : list[TaskSpec] + Tasks to join (e.g. infection_state, organelle_state). Empty list → + auto-discover from the CSV. 
+ """ + import pandas as pd + + explicit_tasks = [t.task for t in tasks] + join_keys = {"fov_name", "t", "track_id", "id"} + + if explicit_tasks: + click.echo(f"Appending annotations for {len(annotations)} experiments, tasks: {explicit_tasks}") + else: + click.echo( + f"Appending annotations for {len(annotations)} experiments, " + "tasks auto-discovered per-CSV (all non-join-key columns)" + ) + + for ann_src in annotations: + experiment = ann_src.experiment + zarr_path = embeddings_path / f"{experiment}.zarr" + + if not zarr_path.exists(): + click.echo(f" [{experiment}] zarr not found, skipping: {zarr_path}", err=True) + continue + + ann_path = Path(ann_src.path) + if not ann_path.exists(): + raise FileNotFoundError(f"Annotation CSV not found: {ann_src.path}") + + # Resolve task list: explicit if provided, else discover from this CSV. + if explicit_tasks: + task_names = explicit_tasks + else: + csv_cols = pd.read_csv(ann_path, nrows=0).columns.tolist() + task_names = [c for c in csv_cols if c not in join_keys] + click.echo(f" [{experiment}] discovered tasks from CSV: {task_names}") + + click.echo(f"\n [{experiment}]") + adata = ad.read_zarr(zarr_path) + click.echo(f" Loaded {adata.n_obs} cells") + + n_joined = 0 + for task_name in task_names: + try: + adata = load_annotation_anndata(adata, str(ann_path), task_name) + n_valid = int(adata.obs[task_name].notna().sum()) + click.echo(f" {task_name}: {n_valid}/{adata.n_obs} labeled") + n_joined += 1 + except KeyError: + click.echo(f" {task_name}: not in {ann_path.name}, skipping") + + if n_joined == 0: + click.echo(f" No tasks found in {ann_path.name}, skipping zarr write") + continue + + append_to_anndata_zarr(zarr_path, obs=adata.obs) + click.echo(f" Saved obs to {zarr_path}") + + click.echo("\nDone.") + + +class _AppendAnnotationsConfig: + def __init__(self, raw: dict): + self.embeddings_path = Path(raw["embeddings_path"]) + self.annotations = [AnnotationSource(**a) for a in raw["annotations"]] + self.tasks = 
[TaskSpec(**t) for t in raw["tasks"]] + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file", +) +def main(config: Path) -> None: + """Append annotation columns to per-experiment AnnData zarr stores.""" + click.echo("=" * 60) + click.echo("APPEND ANNOTATIONS") + click.echo("=" * 60) + raw = load_config(config) + cfg = _AppendAnnotationsConfig(raw) + append_annotations(cfg.embeddings_path, cfg.annotations, cfg.tasks) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/append_predictions.py b/applications/dynaclr/src/dynaclr/evaluation/append_predictions.py new file mode 100644 index 000000000..0d45560c6 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/append_predictions.py @@ -0,0 +1,198 @@ +"""CLI for applying saved linear classifiers to per-experiment AnnData zarr stores. + +Reads the pipelines manifest written by ``dynaclr run-linear-classifiers``, +applies each saved classifier to ALL cells with the matching marker in each +per-experiment zarr, and writes predictions back to obs/obsm/uns. + +This enables plots colored by predicted labels (e.g. predicted_infection_state) +for every cell, including unannotated ones. + +Called as a step in the Nextflow evaluation pipeline after linear classifiers +have been trained (LINEAR_CLASSIFIERS step). + +Usage +----- +dynaclr append-predictions -c append_predictions.yaml +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import anndata as ad +import click +import joblib +import numpy as np + +from viscy_utils.cli_utils import load_config +from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr + + +def append_predictions( + embeddings_path: Path, + pipelines_dir: Path, +) -> None: + """Apply saved classifiers to all cells and write predictions to zarrs. 
+ + ``pipelines_dir`` may be a ``latest`` symlink into the central LC registry + (e.g. ``/hpc/.../linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/latest``). + The symlink is resolved **once** at startup so the whole run is consistent + even if a new version is published mid-run. + + For each per-experiment zarr, loads all saved classifier pipelines and + applies each one to cells with the matching marker. Results are merged + per task (one ``predicted_{task}`` column per task regardless of how + many marker-specific classifiers contributed), then persisted to zarr. + + Parameters + ---------- + embeddings_path : Path + Directory containing per-experiment zarrs named ``{experiment}.zarr``. + pipelines_dir : Path + Directory containing ``manifest.json`` and ``{task}_{marker}.joblib`` + pipeline files. If this is the ``latest`` symlink, it is resolved + to a ``vN/`` target before loading. + """ + resolved = pipelines_dir.resolve() + version_tag = resolved.name + # Registry layout: {registry_root}/{model_name}/vN/. Two levels up from + # vN is the registry root; one level up is the per-model dir (== model + # name). This is the feature_space identifier. + feature_space = resolved.parent.name if resolved.parent != resolved else "" + click.echo(f"LC pipelines: {pipelines_dir} -> {resolved}") + click.echo(f" feature_space={feature_space} version={version_tag}") + + manifest_path = resolved / "manifest.json" + if not manifest_path.exists(): + raise FileNotFoundError( + f"Pipeline manifest not found: {manifest_path}. Run dynaclr run-linear-classifiers first." + ) + + with open(manifest_path) as f: + manifest_data = json.load(f) + + # New-format manifest: dict with {trained_at, pipelines: [...]}. + if not isinstance(manifest_data, dict) or "pipelines" not in manifest_data: + raise ValueError( + f"Manifest at {manifest_path} is not in the expected format " + "(dict with 'pipelines' key). Re-train with the current " + "run-linear-classifiers to produce a compatible bundle." 
+ ) + manifest_entries = manifest_data["pipelines"] + trained_at = manifest_data.get("trained_at", "") + click.echo(f" trained_at={trained_at}") + + if not manifest_entries: + click.echo("No pipelines in manifest, nothing to do.") + return + + click.echo(f" {len(manifest_entries)} pipeline(s):") + for entry in manifest_entries: + click.echo(f" {entry['task']} / marker={entry['marker_filter']}") + + manifest_markers = {e["marker_filter"] for e in manifest_entries} + + zarr_paths = sorted(embeddings_path.glob("*.zarr")) + if not zarr_paths: + raise FileNotFoundError(f"No .zarr files found in {embeddings_path}") + + click.echo(f"\nProcessing {len(zarr_paths)} per-experiment zarr(s)...") + + for zarr_path in zarr_paths: + click.echo(f"\n {zarr_path.stem}") + adata = ad.read_zarr(zarr_path) + zarr_markers = set(adata.obs["marker"].unique().tolist()) + click.echo(f" {adata.n_obs} cells, markers: {sorted(zarr_markers)}") + + # Coverage report: which zarr markers are predictable from this bundle? 
+ covered = sorted(zarr_markers & manifest_markers) + missing = sorted(zarr_markers - manifest_markers) + click.echo( + f" LC coverage: {len(covered)}/{len(zarr_markers)} markers predictable" + + (f"; missing: {missing}" if missing else "") + ) + + # Group manifest entries by task + tasks_seen: set[str] = {entry["task"] for entry in manifest_entries} + + new_obsm: dict[str, np.ndarray] = {} + + for task in sorted(tasks_seen): + task_entries = [e for e in manifest_entries if e["task"] == task] + + first_pipeline = joblib.load(resolved / task_entries[0]["path"]) + n_classes = len(first_pipeline.classifier.classes_) + classes = first_pipeline.classifier.classes_.tolist() + + all_pred = np.full(adata.n_obs, np.nan, dtype=object) + all_proba = np.full((adata.n_obs, n_classes), np.nan) + + for entry in task_entries: + marker_filter = entry["marker_filter"] + pipeline_path = resolved / entry["path"] + + if not pipeline_path.exists(): + click.echo(f" Pipeline not found: {pipeline_path}, skipping", err=True) + continue + + marker_mask = (adata.obs["marker"] == marker_filter).to_numpy() + n_matching = int(marker_mask.sum()) + if n_matching == 0: + click.echo(f" {task}/{marker_filter}: no matching cells, skipping") + continue + + pipeline = joblib.load(pipeline_path) + adata_subset = adata[marker_mask] + + X_subset = adata_subset.X if isinstance(adata_subset.X, np.ndarray) else adata_subset.X.toarray() + preds = pipeline.predict(X_subset) + probas = pipeline.predict_proba(X_subset) + + all_pred[marker_mask] = preds + all_proba[marker_mask] = probas + click.echo(f" {task}/{marker_filter}: predicted {n_matching} cells") + + adata.obs[f"predicted_{task}"] = all_pred + adata.uns[f"predicted_{task}_classes"] = classes + adata.uns[f"predicted_{task}_lc_version"] = version_tag + adata.uns[f"predicted_{task}_lc_feature_space"] = feature_space + adata.uns[f"predicted_{task}_lc_path"] = str(resolved) + new_obsm[f"predicted_{task}_proba"] = all_proba + + if not new_obsm: + click.echo(" 
No predictions written (no matching markers)") + continue + + append_to_anndata_zarr(zarr_path, obs=adata.obs, obsm=new_obsm, uns=adata.uns) + click.echo(f" Saved predictions to {zarr_path}") + + click.echo("\nDone.") + + +class _AppendPredictionsConfig: + def __init__(self, raw: dict): + self.embeddings_path = Path(raw["embeddings_path"]) + self.pipelines_dir = Path(raw["pipelines_dir"]) + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file", +) +def main(config: Path) -> None: + """Apply saved linear classifiers to per-experiment zarrs and write predictions.""" + click.echo("=" * 60) + click.echo("APPEND PREDICTIONS") + click.echo("=" * 60) + raw = load_config(config) + cfg = _AppendPredictionsConfig(raw) + append_predictions(cfg.embeddings_path, cfg.pipelines_dir) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/config.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/config.py index 77af8cf07..20e028f05 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/config.py +++ b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/config.py @@ -34,6 +34,10 @@ class SmoothnessEvalConfig(BaseModel): Whether to use memory-optimized computation. verbose : bool Print verbose progress messages. + group_by : str or None + obs column to group by before computing smoothness (e.g. "marker"). + Smoothness is computed per group; the reported aggregate stats are + mean ± std across groups. Set to null to compute on the whole embedding. 
""" models: list[ModelEntry] = Field(..., min_length=1) @@ -44,6 +48,7 @@ class SmoothnessEvalConfig(BaseModel): save_distributions: bool = False use_optimized: bool = True verbose: bool = False + group_by: Optional[str] = "marker" @model_validator(mode="after") def validate_paths(self): diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/evaluate_smoothness.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/evaluate_smoothness.py index 91a2e6db7..ae9e7c650 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/evaluate_smoothness.py +++ b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/smoothness/evaluate_smoothness.py @@ -50,6 +50,7 @@ def main(config: Path): for i, model_entry in enumerate(config.models, 1): model_path = Path(model_entry.path) model_label = model_entry.label + experiment_name = model_path.stem click.echo(f"\nProcessing {i}/{len(config.models)}: {model_label}...") @@ -60,28 +61,87 @@ def main(config: Path): if config.verbose: click.echo(f" Loaded {features_ad.shape[0]:,} samples with {features_ad.shape[1]} features") - stats, distributions, _ = compute_embeddings_smoothness( - features_ad, - distance_metric=config.distance_metric, - verbose=config.verbose, - ) + group_col = config.group_by + if group_col and group_col in features_ad.obs.columns: + groups = features_ad.obs[group_col].unique().tolist() + click.echo(f" Computing smoothness per {group_col}: {groups}") + + per_group_rows = [] + group_stats_list = [] + group_distributions = {} + + for group_val in groups: + mask = features_ad.obs[group_col] == group_val + group_ad = features_ad[mask].copy() + + if config.verbose: + click.echo(f" {group_col}={group_val}: {group_ad.shape[0]:,} cells") + + g_stats, g_dists, _ = compute_embeddings_smoothness( + group_ad, + distance_metric=config.distance_metric, + verbose=config.verbose, + ) + per_group_rows.append({group_col: group_val, **g_stats}) + 
group_stats_list.append(g_stats) + group_distributions[group_val] = g_dists + + if config.save_plots: + _create_smoothness_plot( + g_dists, + g_stats, + f"{model_label}_{experiment_name}_{group_val}", + config.distance_metric, + output_dir, + ) + + per_group_df = pd.DataFrame(per_group_rows) + per_group_df.insert(0, "experiment", experiment_name) + per_group_df.to_csv( + output_dir / f"{model_label}_{experiment_name}_per_{group_col}_smoothness.csv", index=False + ) + click.echo(f" Per-{group_col} stats saved.") + + # Aggregate: mean ± std across groups + metric_cols = [c for c in per_group_df.columns if c != group_col] + agg_means = per_group_df[metric_cols].mean() + agg_stds = per_group_df[metric_cols].std() + stats = agg_means.to_dict() + stats_std = {f"{k}_std": v for k, v in agg_stds.to_dict().items()} + stats.update(stats_std) + + # Concatenate distributions across groups for the combined plot + distributions = { + "adjacent_frame_distribution": np.concatenate( + [d["adjacent_frame_distribution"] for d in group_distributions.values()] + ), + "random_frame_distribution": np.concatenate( + [d["random_frame_distribution"] for d in group_distributions.values()] + ), + } + else: + stats, distributions, _ = compute_embeddings_smoothness( + features_ad, + distance_metric=config.distance_metric, + verbose=config.verbose, + ) all_results[model_label] = stats all_distributions[model_label] = distributions save_results( stats, - output_dir / f"{model_label}_smoothness_stats.csv", + output_dir / f"{model_label}_{experiment_name}_smoothness_stats.csv", format="csv", ) if config.save_distributions: np.save( - output_dir / f"{model_label}_adjacent_distribution.npy", + output_dir / f"{model_label}_{experiment_name}_adjacent_distribution.npy", distributions["adjacent_frame_distribution"], ) np.save( - output_dir / f"{model_label}_random_distribution.npy", + output_dir / f"{model_label}_{experiment_name}_random_distribution.npy", distributions["random_frame_distribution"], ) @@ 
-91,7 +151,7 @@ def main(config: Path): _create_smoothness_plot( distributions, stats, - model_label, + f"{model_label}_{experiment_name}", config.distance_metric, output_dir, ) diff --git a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/__init__.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/__init__.py similarity index 100% rename from applications/dynaclr/src/dynaclr/evaluation/pseudotime/__init__.py rename to applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/__init__.py diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/config.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/config.py new file mode 100644 index 000000000..aae0b9a9d --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/config.py @@ -0,0 +1,107 @@ +"""Configuration models for CTC tracking accuracy evaluation.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class ONNXModelEntry(BaseModel): + """One model to benchmark. + + Parameters + ---------- + path : str or None + Path to the ONNX model file. None runs the baseline (IoU + spatial edges only, + no embedding model). + label : str + Display name for this model in results. + pixel_size_um : float or None + Pixel size (µm/px) the model was trained at. Used to rescale input crops + when the dataset pixel size differs. None disables rescaling. + """ + + path: str | None + label: str + pixel_size_um: float | None = None + + +class CTCDatasetEntry(BaseModel): + """One CTC dataset directory. + + Parameters + ---------- + path : str + Path to the dataset root (e.g. /hpc/reference/group.royer/CTC/training/BF-C2DL-HSC). + Must contain ``{seq}_ERR_SEG/``, ``{seq}/`` (raw images), and ``{seq}_GT/TRA/`` + subdirectories for each sequence. + sequences : list[str] + Sequence numbers to evaluate (e.g. ["01", "02"]). 
+ pixel_size_um : float or None + Pixel size (µm/px) of the raw images. Used with ``ONNXModelEntry.pixel_size_um`` + to rescale crops before ONNX inference. If None, looked up from + ``TrackingAccuracyConfig.ctc_metadata_path`` by dataset name, then + falls back to reading TIFF XResolution metadata. + """ + + path: str + sequences: list[str] = Field(default=["01", "02"]) + pixel_size_um: float | None = None + + +class TrackingAccuracyConfig(BaseModel): + """Configuration for CTC tracking accuracy evaluation. + + Parameters + ---------- + models : list[ONNXModelEntry] + Models to benchmark. Include an entry with ``path: null`` for the IoU baseline. + datasets : list[CTCDatasetEntry] + CTC datasets to evaluate. + model_input_shape : tuple[int, int] + Height x width of the ONNX model input (must match what the model was exported with). + Default (160, 160) matches the DynaCLR-2D-MIP training resolution. + distance_threshold : float + Maximum spatial distance (pixels) for candidate edges in DistanceEdges. + n_neighbors : int + Maximum candidate edges per cell. + delta_t : int + Maximum frame gap for candidate edges. + division_weight : float + ILP solver weight for cell division events. + appearance_weight : float + ILP solver weight for cell appearance. + disappearance_weight : float + ILP solver weight for cell disappearance. + node_weight : float + ILP solver weight per node (negative = prefer more detections). + output_dir : str + Directory for results CSV. + ctc_metrics : list[str] or None + CTC metric names to include in output. None = all available metrics. + batch_size : int + Number of cell crops per ONNX inference call. + ctc_metadata_path : str or None + Path to a CTC metadata YAML mapping dataset names to + ``[interval_min, y_um, x_um]``. Used to look up pixel size when + ``CTCDatasetEntry.pixel_size_um`` is not set. Falls back to reading + TIFF XResolution tags if the dataset is not in the file. 
+ show_napari : bool + Open a napari viewer after tracking each sequence. Only use when running + interactively on a partition with a display. Default: False. + """ + + models: list[ONNXModelEntry] = Field(..., min_length=1) + datasets: list[CTCDatasetEntry] = Field(..., min_length=1) + ctc_metadata_path: str | None = None + model_input_shape: tuple[int, int] = (160, 160) + distance_threshold: float = 325.0 + n_neighbors: int = 10 + delta_t: int = 5 + division_weight: float = 0.5 + appearance_weight: float = 0.0 + disappearance_weight: float = 0.0 + node_weight: float = -10.0 + output_dir: str + ctc_metrics: list[str] | None = None + batch_size: int = 128 + show_napari: bool = False diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/evaluate_tracking.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/evaluate_tracking.py new file mode 100644 index 000000000..c4005e068 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/evaluate_tracking.py @@ -0,0 +1,484 @@ +"""CLI tool for CTC tracking accuracy benchmarking with DynaCLR embeddings. + +Evaluates how well DynaCLR embedding similarity, used as an additional edge cost, +improves cell tracking accuracy on CTC (Cell Tracking Challenge) benchmark datasets. + +For each (ONNX model, CTC dataset, sequence) combination: +1. Load segmentation masks and raw images. +2. Build a tracksdata graph (nodes from masks, candidate edges via DistanceEdges). +3. If a model is provided, run ONNX inference on cell crops and weight edges by + embedding cosine similarity * spatial distance weight. +4. If no model is provided, use IoU + spatial distance (baseline). +5. Solve the tracking with ILP and evaluate against CTC ground truth. 
+ +Usage +----- +dynaclr evaluate-tracking-accuracy -c tracking_accuracy_config.yaml +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +import click +import numpy as np +import polars as pl +import tracksdata as td +from dask.array.image import imread +from numpy.typing import NDArray +from rich import print as rprint +from skimage.transform import resize + +from dynaclr.evaluation.benchmarking.tracking_accuracy.config import ( + CTCDatasetEntry, + ONNXModelEntry, + TrackingAccuracyConfig, +) +from dynaclr.evaluation.benchmarking.tracking_accuracy.utils import ( + normalize_crop, + pad_to_shape, + seg_dir, +) +from viscy_utils.cli_utils import load_config + +_logger = logging.getLogger(__name__) + + +def _load_ctc_metadata(path: Path) -> dict[str, float]: + """Load dataset name → x pixel size (µm) from Jordao's CTC metadata YAML. + + Format: ``dataset_name: [interval_min, y_um, x_um]`` + + Parameters + ---------- + path : Path + Path to the metadata YAML file. + + Returns + ------- + dict[str, float] + Mapping from dataset name to x pixel size in µm. + """ + import yaml + + with open(path) as f: + raw = yaml.safe_load(f) + # value is [interval_min, y_um, x_um] — take x (index 2) + return {name: values[2] for name, values in raw.items() if isinstance(values, list)} + + +def _crop_embedding( + frame: NDArray, + mask: list, + source_shape: tuple[int, int], + final_shape: tuple[int, int], + session: Any, + input_name: str, +) -> list[NDArray]: + """Crop cells from a frame and compute DynaCLR embeddings via ONNX. + + Parameters + ---------- + frame : NDArray + Raw image frame (2-D or 3-D with a single z-slice). + mask : list[td.nodes.Mask] + Cell masks for this frame. The parameter name must match the graph + attribute key (``"mask"`` in ``attr_keys``). + source_shape : tuple[int, int] + (height, width) to extract from the image in dataset pixels. 
+ If different from ``final_shape``, the crop is resized to ``final_shape`` + to correct for pixel size differences between dataset and training data. + final_shape : tuple[int, int] + (height, width) of the model input (must match ONNX input size). + session : ort.InferenceSession + ONNX runtime inference session. + input_name : str + Name of the ONNX model's input tensor. + + Returns + ------- + list[NDArray] + L2-normalized embedding vector for each mask (same order). + """ + # Compute frame-level stats once — matches timepoint_statistics normalization used in training + frame_f32 = frame.astype(np.float32) + frame_mean = float(np.mean(frame_f32)) + frame_std = float(np.std(frame_f32)) + + label_img = np.zeros_like(frame, dtype=np.int16) + crops = [] + + for i, m in enumerate(mask, start=1): + if frame.ndim == 3: + extract_shape = (1, *source_shape) + else: + extract_shape = source_shape + + label_img[m.mask_indices()] = i + + crop = m.crop(frame, shape=extract_shape).astype(np.float32) + + if crop.ndim == 3: + if crop.shape[0] != 1: + raise ValueError(f"Expected 1 z-slice in 3D crop, got {crop.shape[0]}") + crop = crop[0] + + crop = pad_to_shape(crop, source_shape, mode="reflect") + + if source_shape != final_shape: + crop = resize(crop, final_shape, order=1, anti_aliasing=True, preserve_range=True).astype(np.float32) + + crop = normalize_crop(crop, frame_mean, frame_std) + + if crop.shape != final_shape: + raise ValueError(f"Crop shape {crop.shape} != final_shape {final_shape}") + + crops.append(crop) + + # shape: (batch, channel, z, h, w) + batch = np.stack(crops, axis=0)[:, np.newaxis, np.newaxis, ...] + output = session.run(None, {input_name: batch}) + + embeddings = output[0] # backbone features (e.g. 
768-dim) + embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) + return list(embeddings) + + +def _add_dynaclr_attrs( + model_path: Path, + graph: td.graph.InMemoryGraph, + images: NDArray, + model_input_shape: tuple[int, int], + batch_size: int, + pixel_size_scale: float, +) -> None: + """Add DynaCLR embedding node attributes and cosine similarity edge attributes. + + Parameters + ---------- + model_path : Path + Path to the exported ONNX model. + graph : td.graph.InMemoryGraph + Graph with nodes already added (must have ``mask`` attribute). + images : NDArray + Raw image stack, shape (T, H, W) or (T, Z, H, W). + model_input_shape : tuple[int, int] + (height, width) of the ONNX model input (e.g. (160, 160)). + batch_size : int + Number of crops per ONNX inference call. + pixel_size_scale : float + Ratio of dataset pixel size to model training pixel size + (dataset_um / model_um). Crops are extracted at + ``model_input_shape * pixel_size_scale`` and resized to ``model_input_shape``. + Use 1.0 when no rescaling is needed. 
+ """ + import onnxruntime as ort + + session_options = ort.SessionOptions() + session_options.intra_op_num_threads = 1 + session_options.inter_op_num_threads = 1 + session = ort.InferenceSession( + str(model_path), + sess_options=session_options, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + input_name = session.get_inputs()[0].name + _logger.info( + "ONNX model: input='%s' shape=%s type=%s", + input_name, + session.get_inputs()[0].shape, + session.get_inputs()[0].type, + ) + + source_shape = ( + round(model_input_shape[0] * pixel_size_scale), + round(model_input_shape[1] * pixel_size_scale), + ) + _logger.info( + "Crop pipeline: extract %s px -> resize to %s px (scale=%.3f)", + source_shape, + model_input_shape, + pixel_size_scale, + ) + + from toolz import curry + + crop_fn = curry(_crop_embedding)( + source_shape=source_shape, + final_shape=model_input_shape, + session=session, + input_name=input_name, + ) + + graph.add_node_attr_key("dynaclr_embedding", dtype=pl.List(pl.Float32)) + + td.nodes.GenericFuncNodeAttrs( + func=crop_fn, + output_key="dynaclr_embedding", + attr_keys=["mask"], + batch_size=batch_size, + ).add_node_attrs(graph, frames=images) + + td.edges.GenericFuncEdgeAttrs( + func=np.dot, + output_key="dynaclr_similarity", + attr_keys="dynaclr_embedding", + ).add_edge_attrs(graph) + + +def _build_and_solve( + model_path: Path | None, + images: NDArray, + labels: NDArray, + config: TrackingAccuracyConfig, + pixel_size_scale: float = 1.0, +) -> tuple[td.graph.InMemoryGraph, td.graph.InMemoryGraph]: + """Build a tracksdata graph and solve tracking. + + Parameters + ---------- + model_path : Path or None + ONNX model path. None uses the IoU + spatial baseline. + images : NDArray + Raw image stack (T, H, W). + labels : NDArray + Segmentation label stack (T, H, W). + config : TrackingAccuracyConfig + Evaluation configuration. 
+ pixel_size_scale : float + Ratio of dataset pixel size to model training pixel size + (dataset_um / model_um). Passed to ``_add_dynaclr_attrs``. Default 1.0. + + Returns + ------- + graph : td.graph.InMemoryGraph + Full candidate graph (all nodes + candidate edges). + solution_graph : td.graph.InMemoryGraph + ILP-solved tracking result. + """ + graph = td.graph.InMemoryGraph() + + td.nodes.RegionPropsNodes().add_nodes(graph, labels=labels) + _logger.info("Nodes: %d", graph.num_nodes()) + + dist_op = td.edges.DistanceEdges( + distance_threshold=config.distance_threshold, + n_neighbors=config.n_neighbors, + delta_t=config.delta_t, + ) + dist_op.add_edges(graph) + _logger.info("Candidate edges: %d", graph.num_edges()) + + td.edges.GenericFuncEdgeAttrs( + func=lambda x, y: abs(x - y), + output_key="delta_t", + attr_keys="t", + ).add_edge_attrs(graph) + + dist_weight = (-td.EdgeAttr(td.DEFAULT_ATTR_KEYS.EDGE_DIST) / config.distance_threshold).exp() + + if model_path is not None: + _add_dynaclr_attrs(model_path, graph, images, config.model_input_shape, config.batch_size, pixel_size_scale) + edge_weight = -td.EdgeAttr("dynaclr_similarity") * dist_weight + else: + td.edges.IoUEdgeAttr(output_key="iou").add_edge_attrs(graph) + edge_weight = -(td.EdgeAttr("iou") + 0.1) * dist_weight + + edge_weight = edge_weight / td.EdgeAttr("delta_t").clip(lower_bound=1) + + solver = td.solvers.ILPSolver( + edge_weight=edge_weight, + appearance_weight=config.appearance_weight, + disappearance_weight=config.disappearance_weight, + division_weight=config.division_weight, + node_weight=config.node_weight, + ) + solution_graph = solver.solve(graph) + + return graph, solution_graph + + +def _show_napari_viewer( + graph: td.graph.InMemoryGraph, + images: NDArray, + labels: NDArray, +) -> None: + """Open a napari viewer with the tracking result overlaid on the raw images. + + Parameters + ---------- + graph : td.graph.InMemoryGraph + Full candidate graph (used to derive napari tracks format). 
+ images : NDArray + Raw image stack (T, H, W). + labels : NDArray + Segmentation label stack (T, H, W). + """ + import napari + + tracks_df, track_graph, label_stack = td.functional.to_napari_format( + graph, labels.shape, mask_key=td.DEFAULT_ATTR_KEYS.MASK + ) + viewer = napari.Viewer() + viewer.add_image(images) + viewer.add_labels(label_stack) + viewer.add_tracks(tracks_df, graph=track_graph) + napari.run() + + +def track_single_dataset( + dataset_entry: CTCDatasetEntry, + sequence: str, + model_entry: ONNXModelEntry, + config: TrackingAccuracyConfig, +) -> dict: + """Track one CTC sequence and evaluate metrics. + + Parameters + ---------- + dataset_dir : Path + CTC dataset root. + sequence : str + Sequence number (e.g. "01"). + model_entry : ONNXModelEntry + Model to use (path=None for baseline). + config : TrackingAccuracyConfig + Evaluation configuration. + + Returns + ------- + dict + CTC metrics dict plus ``model``, ``dataset``, ``sequence`` keys. + """ + dataset_dir = Path(dataset_entry.path) + _seg_dir = seg_dir(dataset_dir, sequence) + if not _seg_dir.exists(): + raise FileNotFoundError(f"Segmentation directory not found: {_seg_dir}") + + model_path = Path(model_entry.path) if model_entry.path is not None else None + + _logger.info("Loading labels from %s", _seg_dir) + labels = imread(str(_seg_dir / "*.tif")).compute() + images = imread(str(dataset_dir / sequence / "*.tif")).compute() + + gt_graph = td.graph.InMemoryGraph.from_ctc(dataset_dir / f"{sequence}_GT" / "TRA") + + _logger.info( + "Tracking: model=%s dataset=%s seq=%s", + model_entry.label, + dataset_dir.name, + sequence, + ) + dataset_pixel_size = dataset_entry.pixel_size_um + if dataset_pixel_size is None and config.ctc_metadata_path is not None: + ctc_meta = _load_ctc_metadata(Path(config.ctc_metadata_path)) + dataset_pixel_size = ctc_meta.get(dataset_dir.name) + if dataset_pixel_size is not None: + _logger.info("Pixel size from metadata: %.4f µm/px (%s)", dataset_pixel_size, 
dataset_dir.name) + else: + _logger.warning( + "Dataset %s not found in %s; no rescaling applied", dataset_dir.name, config.ctc_metadata_path + ) + + if model_entry.pixel_size_um is not None and dataset_pixel_size is not None: + pixel_size_scale = dataset_pixel_size / model_entry.pixel_size_um + else: + pixel_size_scale = 1.0 + + graph, solution_graph = _build_and_solve(model_path, images, labels, config, pixel_size_scale) + + if config.show_napari: + _show_napari_viewer(graph, images, labels) + + _logger.info("Evaluating CTC metrics ...") + metrics = td.metrics.evaluate_ctc_metrics( + solution_graph, + gt_graph, + input_reset=False, + reference_reset=False, + metrics=config.ctc_metrics, + ) + + metrics["model"] = model_entry.label + metrics["dataset"] = dataset_dir.name + metrics["sequence"] = sequence + return metrics + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to tracking accuracy YAML configuration file", +) +def main(config: Path) -> None: + """Evaluate CTC tracking accuracy with DynaCLR ONNX embeddings. + + Runs ILP-based tracking on CTC benchmark datasets, comparing a spatial+IoU + baseline against models that use DynaCLR embedding similarity as an additional + edge cost. Writes results.csv to the configured output directory. 
+ """ + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s %(levelname)s %(message)s", + ) + + raw = load_config(config) + cfg = TrackingAccuracyConfig(**raw) + + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + results: list[dict] = [] + + for model_entry in cfg.models: + for dataset_entry in cfg.datasets: + dataset_dir = Path(dataset_entry.path) + for sequence in dataset_entry.sequences: + _seg = seg_dir(dataset_dir, sequence) + if not _seg.exists(): + click.echo( + f"Skipping {dataset_dir.name}/{sequence}: no segmentation at {_seg}", + err=True, + ) + continue + + try: + row = track_single_dataset(dataset_entry, sequence, model_entry, cfg) + except Exception as exc: + click.echo( + f"Error {model_entry.label} / {dataset_dir.name} / {sequence}: {exc}", + err=True, + ) + _logger.exception("Tracking failed") + continue + + rprint(row) + results.append(row) + + # Write incrementally so partial results are never lost + df = pl.DataFrame(results) + df.write_csv(output_dir / "results.csv") + + if not results: + click.echo("No results produced.", err=True) + return + + df = pl.DataFrame(results) + df.write_csv(output_dir / "results.csv") + click.echo(f"\nResults written to {output_dir / 'results.csv'}") + + # Summary: mean across sequences, grouped by model x dataset + key_metrics = [c for c in ["LNK", "BIO(0)", "OP_CLB(0)", "CHOTA", "TRA", "DET"] if c in df.columns] + if key_metrics: + summary = df.group_by("model", "dataset").agg([pl.col(m).mean() for m in key_metrics]).sort("model", "dataset") + click.echo("\n## Tracking Accuracy Summary (mean over sequences)\n") + click.echo(summary.to_pandas().to_markdown(index=False, floatfmt=".3f")) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/utils.py b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/utils.py new file mode 100644 index 000000000..8fc998465 --- 
/dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/benchmarking/tracking_accuracy/utils.py @@ -0,0 +1,66 @@ +"""Utilities for CTC tracking accuracy evaluation.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +from numpy.typing import NDArray + + +def seg_dir(dataset_dir: Path, sequence: str) -> Path: + """Return path to the error-segmentation directory for a CTC sequence. + + Parameters + ---------- + dataset_dir : Path + CTC dataset root (e.g. .../BF-C2DL-HSC). + sequence : str + Sequence number (e.g. "01"). + """ + return dataset_dir / f"{sequence}_ERR_SEG" + + +def pad_to_shape(image: NDArray, shape: tuple[int, int], mode: str) -> NDArray: + """Pad image symmetrically to at least the given spatial shape. + + Parameters + ---------- + image : NDArray + 2-D array to pad. + shape : tuple[int, int] + Target (height, width). No-op if image is already large enough. + mode : str + Padding mode passed to ``np.pad``. + """ + diff = np.asarray(shape) - np.asarray(image.shape) + if diff.sum() == 0: + return image + left = diff // 2 + right = diff - left + return np.pad(image, tuple(zip(left, right)), mode=mode) + + +def normalize_crop(crop: NDArray, frame_mean: float, frame_std: float) -> NDArray: + """Z-score normalize a cell crop using whole-frame statistics. + + Matches the training normalization (``NormalizeSampled`` with + ``level=timepoint_statistics``): mean/std are computed over the full + frame, not the cell foreground, so the model sees the same intensity + distribution it was trained on. + + Parameters + ---------- + crop : NDArray + Float32 2-D cell image. + frame_mean : float + Mean pixel intensity of the full frame at this timepoint. + frame_std : float + Std pixel intensity of the full frame at this timepoint. + + Returns + ------- + NDArray + Z-score normalized crop. 
+ """ + return (crop - frame_mean) / max(frame_std, 1e-8) diff --git a/applications/dynaclr/src/dynaclr/evaluation/check_evals.py b/applications/dynaclr/src/dynaclr/evaluation/check_evals.py new file mode 100644 index 000000000..83b6e3ee2 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/check_evals.py @@ -0,0 +1,161 @@ +"""Check completion status of eval runs defined in an eval registry YAML. + +Derives status from filesystem sentinels rather than stored state, so it +is always ground-truth. + +Usage +----- +dynaclr check-evals -r eval_registry.yaml +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Literal + +import click +import yaml + +from dynaclr.evaluation.evaluate_config import EvaluationConfig +from viscy_utils.cli_utils import load_config + +_STEP_SENTINELS: dict[str, str] = { + "predict": "embeddings/embeddings.zarr", + "split": "configs/viewer.yaml", + "reduce_dimensionality": "configs/reduce.yaml", + "reduce_combined": "configs/reduce_combined.yaml", + "smoothness": "smoothness/combined_smoothness_stats.csv", + "plot": "plots", + "linear_classifiers": "linear_classifiers/metrics_summary.csv", +} + +Status = Literal["done", "partial", "pending"] + + +def _check_mmd_step(output_dir: Path, eval_cfg: EvaluationConfig) -> bool: + """Return True if all MMD blocks have at least one result CSV.""" + if not eval_cfg.mmd: + return True # no MMD configured — not a blocking step + for i, block in enumerate(eval_cfg.mmd): + block_name = block.name if block.name else f"mmd_{i}" + block_dir = output_dir / "mmd" / block_name + if not any(block_dir.glob("*.csv")): + return False + return True + + +def _check_plot_step(output_dir: Path) -> bool: + """Return True if the plots directory has at least one PDF.""" + plots_dir = output_dir / "plots" + return any(plots_dir.rglob("*.pdf")) + + +def _missing_steps(eval_cfg: EvaluationConfig) -> list[str]: + """Return steps from eval_cfg.steps that have not yet produced their 
sentinel output.""" + output_dir = Path(eval_cfg.output_dir) + missing = [] + for step in eval_cfg.steps: + if step == "mmd": + if not _check_mmd_step(output_dir, eval_cfg): + missing.append(step) + elif step == "plot": + if not _check_plot_step(output_dir): + missing.append(step) + elif step in _STEP_SENTINELS: + sentinel = output_dir / _STEP_SENTINELS[step] + if not sentinel.exists(): + missing.append(step) + # unknown steps: skip silently + return missing + + +def _model_status(eval_cfg: EvaluationConfig, force_rerun: bool) -> tuple[Status, list[str]]: + """Return (status, missing_steps) for one model entry.""" + if force_rerun: + return "pending", ["(force_rerun=true)"] + missing = _missing_steps(eval_cfg) + if not missing: + return "done", [] + if len(missing) < len(eval_cfg.steps): + return "partial", missing + return "pending", missing + + +def _load_registry(registry_path: Path) -> list[dict]: + with open(registry_path) as f: + data = yaml.safe_load(f) + return data["models"] + + +def check_evals(registry: Path, workspace_dir: Path | None) -> None: + """Print a markdown table showing completion status for each registered model.""" + models = _load_registry(registry) + + rows = [] + for entry in models: + name = entry["name"] + force_rerun = entry.get("force_rerun", False) + eval_config_path = Path(entry["eval_config"]) + + # Resolve relative paths against workspace_dir (if provided) or registry location + if not eval_config_path.is_absolute(): + base = workspace_dir if workspace_dir else registry.parent.parent.parent.parent + eval_config_path = base / eval_config_path + + try: + raw = load_config(eval_config_path) + eval_cfg = EvaluationConfig(**raw) + status, missing = _model_status(eval_cfg, force_rerun) + missing_str = ", ".join(missing) if missing else "—" + except FileNotFoundError as e: + status = "pending" + missing_str = f"config not found: {e}" + except Exception as e: # noqa: BLE001 + status = "pending" + missing_str = f"error: {e}" + + 
rows.append((name, status, missing_str)) + + # Print markdown table + col_name = max(len(r[0]) for r in rows) + col_status = max(len(r[1]) for r in rows) + col_missing = max(len(r[2]) for r in rows) + + col_name = max(col_name, len("Model")) + col_status = max(col_status, len("Status")) + col_missing = max(col_missing, len("Missing Steps")) + + header = f"| {'Model':<{col_name}} | {'Status':<{col_status}} | {'Missing Steps':<{col_missing}} |" + sep = f"| {'-' * col_name} | {'-' * col_status} | {'-' * col_missing} |" + click.echo(header) + click.echo(sep) + for name, status, missing_str in rows: + click.echo(f"| {name:<{col_name}} | {status:<{col_status}} | {missing_str:<{col_missing}} |") + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-r", + "--registry", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to eval_registry.yaml", +) +@click.option( + "--workspace-dir", + type=click.Path(exists=True, path_type=Path), + default=None, + help="Workspace root for resolving relative eval_config paths. Defaults to four levels above the registry file.", +) +def main(registry: Path, workspace_dir: Path | None) -> None: + """Print a markdown table showing eval completion status for each registered model. + + Status is derived from filesystem sentinels — never stored manually. + Set force_rerun: true in the registry to mark a model for re-execution + regardless of existing outputs. 
+ """ + check_evals(registry, workspace_dir) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/config.py b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/config.py index a5448b261..6d40c192b 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/config.py +++ b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/config.py @@ -30,6 +30,9 @@ class PHATEConfig(BaseModel): knn_dist: str = "cosine" scale_embeddings: bool = False random_state: int = 42 + n_pca: Optional[int] = 50 + subsample: Optional[int] = 50_000 + n_jobs: int = 1 class DimensionalityReductionConfig(BaseModel): @@ -65,3 +68,64 @@ def validate_config(self): if self.pca is None and self.umap is None and self.phate is None: raise ValueError("At least one reduction method must be specified (pca, umap, or phate)") return self + + +class CombinedDatasetConfig(BaseModel): + """Input dataset spec for combined reductions. + + Parameters + ---------- + anndata : str + Path to AnnData zarr store with features in ``.X``. + hcs_plate : str, optional + Path to the raw HCS plate zarr (not used for reductions, but useful for reuse). + """ + + anndata: str = Field(...) + hcs_plate: Optional[str] = None + + +class CombinedDimensionalityReductionConfig(BaseModel): + """Configuration for computing joint dimensionality reductions across multiple AnnData stores. + + Parameters + ---------- + input_paths : list[str], optional + Paths to AnnData zarr stores. Embeddings from all stores are concatenated before fitting + reductions, then per-store slices are written back with a ``_combined`` suffix. + datasets : dict[str, CombinedDatasetConfig], optional + Alternative to ``input_paths``. When provided, ``input_paths`` is derived from + ``datasets[*].anndata``. This matches the multi-dataset YAML used in organelle dynamics. + pca : PCAConfig, optional + PCA parameters. 
Results stored as ``X_pca_combined``. + umap : UMAPConfig, optional + UMAP parameters. Results stored as ``X_umap_combined``. + phate : PHATEConfig, optional + PHATE parameters. Results stored as ``X_phate_combined``. + overwrite_keys : bool + If True, overwrite existing ``.obsm`` keys. Otherwise raise on conflict. + """ + + input_paths: Optional[list[str]] = None + datasets: Optional[dict[str, CombinedDatasetConfig]] = None + pca: Optional[PCAConfig] = None + umap: Optional[UMAPConfig] = None + phate: Optional[PHATEConfig] = None + overwrite_keys: bool = False + + @model_validator(mode="after") + def validate_config(self): + if self.input_paths is None: + if not self.datasets: + raise ValueError("Either input_paths or datasets must be provided") + self.input_paths = [d.anndata for d in self.datasets.values()] + + if len(self.input_paths) < 1: + raise ValueError("At least one input path must be provided") + + for p in self.input_paths: + if not Path(p).exists(): + raise ValueError(f"Input path not found: {p}") + if self.pca is None and self.umap is None and self.phate is None: + raise ValueError("At least one reduction method must be specified (pca, umap, or phate)") + return self diff --git a/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_combined.py b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_combined.py new file mode 100644 index 000000000..4352ea267 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_combined.py @@ -0,0 +1,280 @@ +""" +Joint dimensionality reduction (PCA, UMAP, PHATE) across multiple AnnData zarr stores. + +Concatenates embeddings from all stores, fits joint reductions, +then writes per-store slices back as X_*_combined. 
+ +Usage +----- +dynaclr reduce-combined -c multi-dataset-dim-reduction.yml +""" + +import anndata as ad +import click +import numpy as np + +from viscy_utils.cli_utils import format_markdown_table, load_config_section +from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr + +from .config import CombinedDimensionalityReductionConfig, PHATEConfig +from .reduce_dimensionality import _run_pca, _run_phate, _run_umap + + +def _phate_per_store_fit_idx( + sample_counts: list[int], + lineage_ids: np.ndarray | None, + cap: int | None, + random_state: int, +) -> tuple[np.ndarray, list[dict]]: + """Build the PHATE fit-set with a per-store lineage cap. + + For each store, draw up to ``cap`` whole lineages (random sample without + replacement). Stores with fewer lineages contribute all of theirs. The + returned indices are global row indices into the concatenated feature + matrix; PHATE is later transformed on the full matrix. + + Parameters + ---------- + sample_counts : list[int] + Row count contributed by each store, in concatenation order. + lineage_ids : np.ndarray or None + Per-row lineage identifier (already store-prefixed so namespaces are + disjoint). When None, falls back to per-store random row sampling. + cap : int or None + Maximum lineages drawn per store. ``None`` keeps every row. + random_state : int + Seed for reproducibility. + + Returns + ------- + fit_idx : np.ndarray + Global row indices used to fit PHATE. + per_store_stats : list[dict] + One row per store with ``store_idx``, ``n_lineages_total``, + ``n_lineages_kept``, ``n_rows_total``, ``n_rows_kept``. + """ + rng = np.random.default_rng(random_state) + fit_indices: list[np.ndarray] = [] + per_store_stats: list[dict] = [] + offset = 0 + for store_idx, n_rows in enumerate(sample_counts): + store_slice = slice(offset, offset + n_rows) + if lineage_ids is None: + # No lineage info: cap rows directly. 
+ if cap is None or n_rows <= cap: + idx = np.arange(offset, offset + n_rows) + kept_lineages = -1 # sentinel: unknown + total_lineages = -1 + else: + local = rng.choice(n_rows, size=cap, replace=False) + idx = local + offset + kept_lineages = -1 + total_lineages = -1 + else: + store_lineages = lineage_ids[store_slice] + unique_lineages = np.unique(store_lineages) + total_lineages = len(unique_lineages) + if cap is None or total_lineages <= cap: + idx = np.arange(offset, offset + n_rows) + kept_lineages = total_lineages + else: + chosen = rng.choice(unique_lineages, size=cap, replace=False) + local_mask = np.isin(store_lineages, chosen) + idx = np.where(local_mask)[0] + offset + kept_lineages = cap + + fit_indices.append(idx) + per_store_stats.append( + { + "store_idx": store_idx, + "n_lineages_total": total_lineages, + "n_lineages_kept": kept_lineages, + "n_rows_total": n_rows, + "n_rows_kept": int(idx.size), + } + ) + offset += n_rows + + return np.concatenate(fit_indices), per_store_stats + + +def _select_phate_input( + combined: np.ndarray, + results: dict[str, np.ndarray], + phate_cfg: PHATEConfig, +) -> tuple[np.ndarray, str]: + """Choose what PHATE fits and transforms on. + + When ``phate.n_pca`` is None, the recipe is signalling that PHATE should + *not* run its own internal PCA — feed the already-PCA-reduced input + directly. We require ``X_pca_combined`` to exist in ``results`` (PCA + runs first when both are configured) and use it. + + When ``phate.n_pca`` is an int, fall back to the raw concatenated + embeddings — PHATE will run its internal PCA on them. + """ + if phate_cfg.n_pca is None: + if "X_pca_combined" not in results: + raise click.ClickException( + "PHATE is configured with n_pca=null (skip internal PCA), " + "but X_pca_combined is not available. Add `pca:` to the " + "reduce_combined recipe so PCA runs before PHATE, or set " + "phate.n_pca to an int to use PHATE's internal PCA " + "(warning: hangs on scipy 1.17.1)." 
+ ) + return results["X_pca_combined"], "X_pca_combined" + return combined, "raw .X (PHATE will run internal PCA)" + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=str), + required=True, + help="Path to YAML configuration file", +) +def main(config: str): + """Compute joint PCA, UMAP, and/or PHATE across multiple AnnData zarr stores.""" + click.echo("Loading configuration...") + raw_config = load_config_section(config, None, default_section="reduce_combined") + cfg = CombinedDimensionalityReductionConfig(**raw_config) + + if hasattr(ad, "settings") and hasattr(ad.settings, "allow_write_nullable_strings"): + ad.settings.allow_write_nullable_strings = True + + resolved_paths = [str(p) for p in cfg.input_paths] + dataset_names = list(cfg.datasets.keys()) if cfg.datasets else None + + # Determine which keys will be written + methods_to_run: list[tuple[str, object]] = [] + if cfg.pca is not None: + methods_to_run.append(("pca", cfg.pca)) + if cfg.umap is not None: + methods_to_run.append(("umap", cfg.umap)) + if cfg.phate is not None: + methods_to_run.append(("phate", cfg.phate)) + + key_map = {"pca": "X_pca_combined", "umap": "X_umap_combined", "phate": "X_phate_combined"} + keys_to_write = [key_map[name] for name, _ in methods_to_run] + + # Check for existing keys before loading data + if not cfg.overwrite_keys: + for path in resolved_paths: + adata = ad.read_zarr(path) + for key in keys_to_write: + if key in adata.obsm: + raise click.ClickException( + f"Key '{key}' already exists in {path}. Use overwrite_keys: true to replace." + ) + + # Load embeddings from all stores. Derive lineage IDs for PHATE + # subsampling: a lineage is (path, fov_name, track_id), prefixed + # with the path index so track IDs from different stores don't + # collide. 
+ all_features = [] + all_lineage_ids: list[np.ndarray] = [] + sample_counts = [] + have_lineage_cols = True + for store_idx, path in enumerate(resolved_paths): + click.echo(f"Reading {path}...") + adata = ad.read_zarr(path) + features = np.asarray(adata.X) + all_features.append(features) + sample_counts.append(features.shape[0]) + if "lineage_id" in adata.obs.columns: + all_lineage_ids.append(adata.obs["lineage_id"].to_numpy()) + elif {"fov_name", "track_id"}.issubset(adata.obs.columns): + fov = adata.obs["fov_name"].astype(str).to_numpy() + tid = adata.obs["track_id"].astype(str).to_numpy() + # Prefix with store_idx to keep lineage namespaces disjoint + # across stores in the concatenated array. + lineage = np.array([f"{store_idx}|{f}|{t}" for f, t in zip(fov, tid)]) + all_lineage_ids.append(lineage) + else: + have_lineage_cols = False + click.echo(f" {features.shape[0]:,} samples x {features.shape[1]} features") + + combined = np.concatenate(all_features, axis=0) + if have_lineage_cols and all_lineage_ids: + combined_lineage_ids = np.concatenate(all_lineage_ids) + n_lineages = int(np.unique(combined_lineage_ids).size) + click.echo(f"Combined: {combined.shape[0]:,} samples x {combined.shape[1]} features, {n_lineages:,} lineages") + else: + combined_lineage_ids = None + click.echo( + f"Combined: {combined.shape[0]:,} samples x {combined.shape[1]} features " + "(no lineage_id / fov_name+track_id; PHATE will use random subsampling)" + ) + + # Compute reductions on joint data + # + # Order matters: PCA runs first when both PCA and PHATE are requested, + # so PHATE can be fit on the already-PCA-reduced X_pca_combined. This + # avoids PHATE's internal PCA pre-reduction (sklearn -> scipy.linalg.lu), + # which deadlocks silently on scipy 1.17.1 + sklearn 1.8.0. Pass + # n_pca=None in the recipe to skip the internal PCA when feeding + # PCA-reduced input. 
+ results: dict[str, np.ndarray] = {} + + runner_map = {"pca": _run_pca, "umap": _run_umap, "phate": _run_phate} + for method_name, method_cfg in methods_to_run: + if method_name == "phate": + assert isinstance(method_cfg, PHATEConfig) + phate_input, source_label = _select_phate_input(combined, results, method_cfg) + click.echo(f" PHATE fitting on {source_label} ({phate_input.shape[1]} dims)") + + fit_idx, per_store_stats = _phate_per_store_fit_idx( + sample_counts=sample_counts, + lineage_ids=combined_lineage_ids, + cap=method_cfg.subsample, + random_state=method_cfg.random_state, + ) + click.echo( + "\n" + + format_markdown_table(per_store_stats, title=f"PHATE per-store fit cap (cap={method_cfg.subsample})") + ) + _, embedding = _run_phate( + phate_input, + method_cfg, + lineage_ids=combined_lineage_ids, + fit_idx=fit_idx, + ) + else: + _, embedding = runner_map[method_name](combined, method_cfg) + out_key = key_map[method_name] + results[out_key] = embedding + click.echo(f" {method_name.upper()} done -> {out_key} ({embedding.shape[1]} components)") + + # Slice and write back to each store + offset = 0 + for i, path in enumerate(resolved_paths): + n = sample_counts[i] + store_obsm = {key: emb[offset : offset + n] for key, emb in results.items()} + store_uns = {} + for method_name, _ in methods_to_run: + store_uns[f"{method_name}_combined_datasets"] = resolved_paths + if dataset_names is not None: + store_uns[f"{method_name}_combined_dataset_names"] = dataset_names + offset += n + + click.echo(f"Writing to {path} ({n:,} rows)...") + append_to_anndata_zarr(path, obsm=store_obsm, uns=store_uns) + + # Summary + summary_data = [] + for key, embedding in sorted(results.items()): + summary_data.append( + { + "method": key, + "components": embedding.shape[1], + "total_samples": embedding.shape[0], + "stores": len(resolved_paths), + } + ) + click.echo("\n" + format_markdown_table(summary_data, title="Combined Dimensionality Reduction")) + click.echo(f"Results written to 
{len(resolved_paths)} store(s)") + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_dimensionality.py b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_dimensionality.py index ed2b47aa2..2fdb76e8e 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_dimensionality.py +++ b/applications/dynaclr/src/dynaclr/evaluation/dimensionality_reduction/reduce_dimensionality.py @@ -17,8 +17,9 @@ import numpy as np from numpy.typing import NDArray -from viscy_utils.cli_utils import format_markdown_table, load_config +from viscy_utils.cli_utils import format_markdown_table, load_config_section from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr +from viscy_utils.mp_utils import available_cpus from .config import ( DimensionalityReductionConfig, @@ -51,9 +52,32 @@ def _run_umap(features: NDArray, cfg: UMAPConfig) -> tuple[str, NDArray]: return "X_umap", umap_embedding -def _run_phate(features: NDArray, cfg: PHATEConfig) -> tuple[str, NDArray]: +def _run_phate( + features: NDArray, + cfg: PHATEConfig, + lineage_ids: NDArray | None = None, + fit_idx: NDArray | None = None, +) -> tuple[str, NDArray]: from viscy_utils.evaluation.dimensionality_reduction import compute_phate + # n_jobs == -1 follows the sklearn convention "use all CPUs", but is resolved + # in a SLURM-aware way: it respects SLURM_CPUS_PER_TASK so we don't oversubscribe + # a node when the job was allocated only a subset of cores. 
+ n_jobs = available_cpus(default=1) if cfg.n_jobs == -1 else cfg.n_jobs + + set_fields = cfg.model_fields_set + resolved = [ + { + "param": name, + "value": "None (no subsample)" + if name == "subsample" and getattr(cfg, name) is None + else getattr(cfg, name), + "source": "yaml" if name in set_fields else "default", + } + for name in ("subsample", "n_pca", "knn", "decay", "knn_dist", "scale_embeddings", "random_state", "n_jobs") + ] + click.echo("\n" + format_markdown_table(resolved, title="PHATE resolved parameters")) + _, phate_embedding = compute_phate( features, n_components=cfg.n_components, @@ -62,6 +86,11 @@ def _run_phate(features: NDArray, cfg: PHATEConfig) -> tuple[str, NDArray]: knn_dist=cfg.knn_dist, scale_embeddings=cfg.scale_embeddings, random_state=cfg.random_state, + n_pca=cfg.n_pca, + subsample=cfg.subsample, + lineage_ids=lineage_ids, + fit_idx=fit_idx, + n_jobs=n_jobs, ) return "X_phate", phate_embedding @@ -77,7 +106,7 @@ def _run_phate(features: NDArray, cfg: PHATEConfig) -> tuple[str, NDArray]: def main(config: Path): """Compute PCA, UMAP, and/or PHATE on saved embeddings.""" click.echo("Loading configuration...") - raw_config = load_config(config) + raw_config = load_config_section(config, None, default_section="reduce_dimensionality") cfg = DimensionalityReductionConfig(**raw_config) click.echo(f"Reading embeddings from {cfg.input_path}...") @@ -103,10 +132,15 @@ def main(config: Path): click.echo(f"Computing {len(methods_to_run)} reduction(s): {', '.join(name for name, _, _ in methods_to_run)}") + lineage_ids = adata.obs["lineage_id"].to_numpy() if "lineage_id" in adata.obs.columns else None + results = {} for method_name, method_cfg, obsm_key in methods_to_run: try: - key, embedding = runner_map[method_name](features, method_cfg) + if method_name == "phate": + key, embedding = _run_phate(features, method_cfg, lineage_ids=lineage_ids) + else: + key, embedding = runner_map[method_name](features, method_cfg) results[key] = embedding 
 click.echo(f" {method_name.upper()} done -> {key} ({embedding.shape[1]} components)") except Exception as e: diff --git a/applications/dynaclr/src/dynaclr/evaluation/evaluate.py b/applications/dynaclr/src/dynaclr/evaluation/evaluate.py new file mode 100644 index 000000000..360092f5a --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/evaluate.py @@ -0,0 +1,584 @@ +"""Evaluation config generator for DynaCLR trained models. + +Generates per-step YAML configs from a single eval YAML and prints a JSON manifest +mapping step names to config paths. Called internally by the Nextflow PREPARE_CONFIGS step. + +Usage +----- +dynaclr prepare-eval-configs -c eval_config.yaml +""" + +from __future__ import annotations + +import json +import shutil +from pathlib import Path +from typing import Any + +import click +import yaml + +from dynaclr.evaluation.evaluate_config import EvaluationConfig +from viscy_utils.cli_utils import load_config + +_Z_REDUCTION_CLASS = "viscy_transforms.BatchedChannelWiseZReductiond" + +# Placeholders used in template YAMLs that operate per-experiment zarr. +# Nextflow processes substitute these at runtime when handling per-experiment channels. +_ZARR_PLACEHOLDER = "__ZARR_PATH__" +_PLOT_DIR_PLACEHOLDER = "__PLOT_DIR__" + + +def _load_training_config(path: str) -> dict: + with open(path) as f: + return yaml.safe_load(f) + + +def _extract_predict_data_config(training_cfg: dict, eval_cfg: EvaluationConfig) -> dict: + """Extract data init_args for the predict YAML from the training config. + + Strips augmentations (except BatchedChannelWiseZReductiond which is + architecturally required), overrides batch_size, num_workers, and split_ratio. 
+ """ + data_init = dict(training_cfg["data"]["init_args"]) + + # Override cell_index_path if user supplied one + if eval_cfg.cell_index_path is not None: + data_init["cell_index_path"] = eval_cfg.cell_index_path + + # Move z-reduction transform from augmentations to end of normalizations + augmentations = data_init.pop("augmentations", []) or [] + z_reduction = [t for t in augmentations if _is_z_reduction(t)] + normalizations = list(data_init.get("normalizations") or []) + data_init["normalizations"] = normalizations + z_reduction + data_init["augmentations"] = [] + + # Predict-specific overrides + data_init["batch_size"] = eval_cfg.predict.batch_size + data_init["num_workers"] = eval_cfg.predict.num_workers + data_init["split_ratio"] = 1.0 + + # Remove training-only keys that are irrelevant for predict + for key in ["stratify_by", "batch_group_by", "temporal_enrichment", "leaky", "group_weights"]: + data_init.pop(key, None) + + return data_init + + +def _is_z_reduction(transform: Any) -> bool: + """Check if a transform config is BatchedChannelWiseZReductiond.""" + if isinstance(transform, dict): + return transform.get("class_path", "") == _Z_REDUCTION_CLASS + return False + + +def _extract_model_config(training_cfg: dict) -> dict: + """Extract model config, setting drop_path_rate=0 for inference. + + Only sets drop_path_rate if the encoder already declares it (e.g. ContrastiveEncoder). + Encoders like DINOv3Model do not accept this parameter and must not receive it. 
+ """ + model = dict(training_cfg["model"]) + init_args = dict(model.get("init_args", {})) + encoder = dict(init_args.get("encoder", {})) + encoder_init = dict(encoder.get("init_args", {})) + if "drop_path_rate" in encoder_init: + encoder_init["drop_path_rate"] = 0.0 + encoder["init_args"] = encoder_init + init_args["encoder"] = encoder + model["init_args"] = init_args + return model + + +# --------------------------------------------------------------------------- +# YAML config generators +# --------------------------------------------------------------------------- + + +def _generate_predict_yaml(eval_cfg: EvaluationConfig, training_cfg: dict, output_dir: Path) -> Path: + """Generate the Lightning predict YAML config.""" + embeddings_path = str(output_dir / "embeddings" / "embeddings.zarr") + data_init = _extract_predict_data_config(training_cfg, eval_cfg) + model_cfg = _extract_model_config(training_cfg) + + embedding_writer: dict = { + "class_path": "viscy_utils.callbacks.embedding_writer.EmbeddingWriter", + "init_args": { + "output_path": embeddings_path, + "overwrite": True, + "embedding_key": eval_cfg.predict.embedding_key, + }, + } + + predict_cfg: dict = { + "seed_everything": 42, + "trainer": { + "accelerator": "gpu", + "devices": eval_cfg.predict.devices, + "num_nodes": 1, + "precision": eval_cfg.predict.precision, + "inference_mode": True, + "logger": False, + "callbacks": [embedding_writer], + }, + "model": model_cfg, + "data": { + "class_path": training_cfg["data"]["class_path"], + "init_args": data_init, + }, + } + # Only emit ckpt_path when set. Foundation-model baselines (e.g. DINOv3-frozen) + # load weights from HuggingFace inside the model __init__ and have no + # Lightning checkpoint; passing ``ckpt_path: null`` to ``viscy predict`` + # would still try to restore from a None path. 
+ if eval_cfg.ckpt_path is not None: + predict_cfg["ckpt_path"] = eval_cfg.ckpt_path + + out_path = output_dir / "configs" / "predict.yml" + with open(out_path, "w") as f: + yaml.dump(predict_cfg, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + return out_path + + +def _generate_reduce_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate dim reduction template config YAML. + + Uses a placeholder for ``input_path`` because the actual per-experiment + zarr paths are only known after the split step runs. + """ + cfg_dict: dict = { + "input_path": _ZARR_PLACEHOLDER, + "overwrite_keys": eval_cfg.reduce_dimensionality.overwrite_keys, + } + if eval_cfg.reduce_dimensionality.pca: + cfg_dict["pca"] = eval_cfg.reduce_dimensionality.pca.model_dump() + if eval_cfg.reduce_dimensionality.umap: + cfg_dict["umap"] = eval_cfg.reduce_dimensionality.umap.model_dump() + if eval_cfg.reduce_dimensionality.phate: + cfg_dict["phate"] = eval_cfg.reduce_dimensionality.phate.model_dump() + + out_path = output_dir / "configs" / "reduce.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_reduce_combined_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate joint dimensionality reduction config YAML. + + ``input_paths`` is populated at runtime by Nextflow (collecting per-experiment zarrs). 
+ """ + rc = eval_cfg.reduce_combined + cfg_dict: dict = { + "input_paths": [_ZARR_PLACEHOLDER], + "overwrite_keys": rc.overwrite_keys, + } + if rc.pca: + cfg_dict["pca"] = rc.pca.model_dump() + if rc.umap: + cfg_dict["umap"] = rc.umap.model_dump() + if rc.phate: + cfg_dict["phate"] = rc.phate.model_dump() + + out_path = output_dir / "configs" / "reduce_combined.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_smoothness_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate smoothness evaluation config YAML.""" + model_name = Path(eval_cfg.training_config).stem + + cfg_dict = { + "models": [{"path": _ZARR_PLACEHOLDER, "label": model_name}], + "evaluation": { + "distance_metric": eval_cfg.smoothness.distance_metric, + "output_dir": str(output_dir / "smoothness"), + "save_plots": eval_cfg.smoothness.save_plots, + "save_distributions": eval_cfg.smoothness.save_distributions, + "verbose": eval_cfg.smoothness.verbose, + }, + } + + out_path = output_dir / "configs" / "smoothness.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_plot_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate per-experiment plot config YAML (template with placeholders).""" + cfg_dict = { + "input_path": _ZARR_PLACEHOLDER, + "output_dir": _PLOT_DIR_PLACEHOLDER, + "embedding_keys": eval_cfg.plot.embedding_keys, + "color_by": eval_cfg.plot.color_by, + "point_size": eval_cfg.plot.point_size, + "components": list(eval_cfg.plot.components), + "pairplot_components": eval_cfg.plot.pairplot_components, + "format": eval_cfg.plot.format, + } + + out_path = output_dir / "configs" / "plot.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_plot_combined_yaml(eval_cfg: EvaluationConfig, 
output_dir: Path) -> Path: + """Generate combined plot config YAML. + + The input_paths list is patched at runtime by Nextflow. + """ + cfg_dict = { + "input_paths": [_ZARR_PLACEHOLDER], + "output_dir": str(output_dir / "plots" / "combined"), + "embedding_keys": eval_cfg.plot.combined_embedding_keys, + "color_by": eval_cfg.plot.combined_color_by, + "point_size": eval_cfg.plot.point_size, + "components": list(eval_cfg.plot.components), + "pairplot_components": eval_cfg.plot.pairplot_components, + "format": eval_cfg.plot.format, + } + + out_path = output_dir / "configs" / "plot_combined.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_append_annotations_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate append-annotations config YAML. + + Sources annotations from ``eval_cfg.append_annotations`` when present + (Wave-2 datasets that have ground truth but do not train LCs), else + falls back to ``eval_cfg.linear_classifiers.annotations`` (Wave-1 + legacy path where annotations live alongside LC training config). + """ + if eval_cfg.append_annotations is not None and eval_cfg.append_annotations.annotations: + annotations = eval_cfg.append_annotations.annotations + # Tasks list is informational for the writer; when running standalone + # we emit an empty list (annotation columns are inferred from CSV). 
+ tasks: list[dict] = [] + else: + lc = eval_cfg.linear_classifiers + annotations = lc.annotations + tasks = [{"task": t.task, "marker_filters": t.marker_filters} for t in lc.tasks] + + cfg_dict = { + "embeddings_path": str(output_dir / "embeddings"), + "annotations": [{"experiment": a.experiment, "path": a.path} for a in annotations], + "tasks": tasks, + } + out_path = output_dir / "configs" / "append_annotations.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + return out_path + + +def _generate_append_predictions_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate append-predictions config YAML. + + Honors ``eval_cfg.append_predictions.pipelines_dir`` when set (Wave-2 + evaluations fetching from the central LC registry, typically + ``{registry_root}/{model_name}/latest``). Otherwise falls back to the + legacy in-run location ``output_dir/linear_classifiers/pipelines/``. + """ + ap = eval_cfg.append_predictions + if ap is not None and ap.pipelines_dir: + pipelines_dir = ap.pipelines_dir + else: + pipelines_dir = str(output_dir / "linear_classifiers" / "pipelines") + + cfg_dict = { + "embeddings_path": str(output_dir / "embeddings"), + "pipelines_dir": pipelines_dir, + } + out_path = output_dir / "configs" / "append_predictions.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _generate_linear_classifiers_yaml(eval_cfg: EvaluationConfig, output_dir: Path) -> Path: + """Generate linear classifiers config YAML for dynaclr run-linear-classifiers. + + Propagates ``publish_dir`` (central LC registry root) when set — the writer + atomically promotes the trained bundle to ``{publish_dir}/vN/`` and updates + the ``latest`` symlink. 
+ """ + lc = eval_cfg.linear_classifiers + embeddings_dir = str(output_dir / "embeddings") + lc_output_dir = str(output_dir / "linear_classifiers") + + cfg_dict: dict = { + "embeddings_path": embeddings_dir, + "output_dir": lc_output_dir, + "annotations": [{"experiment": a.experiment, "path": a.path} for a in lc.annotations], + "tasks": [{"task": t.task, "marker_filters": t.marker_filters} for t in lc.tasks], + "use_scaling": lc.use_scaling, + "use_pca": lc.use_pca, + "n_pca_components": lc.n_pca_components, + "max_iter": lc.max_iter, + "class_weight": lc.class_weight, + "solver": lc.solver, + "split_train_data": lc.split_train_data, + "random_seed": lc.random_seed, + } + if lc.publish_dir: + cfg_dict["publish_dir"] = lc.publish_dir + + out_path = output_dir / "configs" / "linear_classifiers.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + return out_path + + +def _mmd_block_name(mmd: "MMDStepConfig", idx: int) -> str: # noqa: F821 + """Derive a filesystem-safe name for an MMD block.""" + if mmd.name: + return mmd.name + return f"mmd_{idx}" + + +def _generate_mmd_yaml(mmd: "MMDStepConfig", output_dir: Path, block_name: str) -> Path: # noqa: F821 + """Generate per-experiment MMD config YAML template (uses __ZARR_PATH__ placeholder).""" + cfg_dict = { + "input_path": _ZARR_PLACEHOLDER, + "output_dir": str(output_dir / "mmd" / block_name), + "comparisons": [{"cond_a": c.cond_a, "cond_b": c.cond_b, "label": c.label} for c in mmd.comparisons], + "group_by": mmd.group_by, + "obs_filter": mmd.obs_filter, + "embedding_key": mmd.embedding_key, + "mmd": mmd.mmd.model_dump(), + "map_settings": mmd.map_settings.model_dump(), + "temporal_bin_size": mmd.temporal_bin_size, + "save_plots": mmd.save_plots, + } + out_path = output_dir / "configs" / f"{block_name}.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def 
_generate_mmd_combined_yaml(mmd: "MMDStepConfig", output_dir: Path, block_name: str) -> Path: # noqa: F821 + """Generate cross-experiment MMD config YAML template (input_paths patched at runtime).""" + combined_name = f"{block_name}_cross_exp" + combined_bin_size = ( + mmd.combined_temporal_bin_size if mmd.combined_temporal_bin_size is not None else mmd.temporal_bin_size + ) + cfg_dict = { + "input_paths": [_ZARR_PLACEHOLDER], + "output_dir": str(output_dir / "mmd" / combined_name), + "group_by": mmd.group_by, + "obs_filter": mmd.obs_filter, + "embedding_key": mmd.embedding_key, + "mmd": mmd.mmd.model_dump(), + "map_settings": mmd.map_settings.model_dump(), + "temporal_bin_size": combined_bin_size, + "save_plots": mmd.save_plots, + } + out_path = output_dir / "configs" / f"{combined_name}.yaml" + with open(out_path, "w") as f: + yaml.dump(cfg_dict, f, default_flow_style=False, sort_keys=False) + return out_path + + +def _resolve_cell_index_path(eval_cfg: EvaluationConfig, training_cfg: dict) -> str: + """Resolve the cell index parquet path from eval config or training config fallback.""" + if eval_cfg.cell_index_path is not None: + return eval_cfg.cell_index_path + return training_cfg["data"]["init_args"]["cell_index_path"] + + +# --------------------------------------------------------------------------- +# Main prepare_configs function +# --------------------------------------------------------------------------- + + +def prepare_configs(config: Path) -> None: + """Generate all per-step YAML configs and print a JSON manifest to stdout. + + The manifest maps step names to generated config paths and includes paths + needed by Nextflow to wire the pipeline (embeddings_dir, output_dir, + cell_index_path, mmd_blocks). 
+ """ + raw = load_config(config) + eval_cfg = EvaluationConfig(**raw) + + training_cfg = _load_training_config(eval_cfg.training_config) + output_dir = Path(eval_cfg.output_dir) + + # Create output directories for active steps + subdirs = ["configs", "embeddings"] + step_subdirs = { + "smoothness": "smoothness", + "mmd": "mmd", + "plot": "plots", + "linear_classifiers": "linear_classifiers", + } + for step in eval_cfg.steps: + if step in step_subdirs: + subdirs.append(step_subdirs[step]) + for subdir in subdirs: + (output_dir / subdir).mkdir(parents=True, exist_ok=True) + + # Save a copy of the input eval config for reproducibility and re-runs + shutil.copy(config, output_dir / "configs" / "eval.yaml") + + manifest: dict = { + "output_dir": str(output_dir), + "embeddings_dir": str(output_dir / "embeddings"), + "cell_index_path": _resolve_cell_index_path(eval_cfg, training_cfg), + "mmd_blocks": [], + "mmd_combined_blocks": [], + } + + for step in eval_cfg.steps: + if step == "predict": + predict_yml = _generate_predict_yaml(eval_cfg, training_cfg, output_dir) + manifest["predict"] = str(predict_yml) + click.echo(f"[predict] {predict_yml}", err=True) + + elif step == "split": + click.echo( + f"[split] viewer.yaml will be written to {output_dir / 'configs' / 'viewer.yaml'} after split runs", + err=True, + ) + + elif step == "reduce_dimensionality": + reduce_yaml = _generate_reduce_yaml(eval_cfg, output_dir) + manifest["reduce"] = str(reduce_yaml) + click.echo(f"[reduce] {reduce_yaml}", err=True) + + elif step == "reduce_combined": + reduce_combined_yaml = _generate_reduce_combined_yaml(eval_cfg, output_dir) + manifest["reduce_combined"] = str(reduce_combined_yaml) + click.echo(f"[combined] {reduce_combined_yaml}", err=True) + + elif step == "smoothness": + smoothness_yaml = _generate_smoothness_yaml(eval_cfg, output_dir) + manifest["smoothness"] = str(smoothness_yaml) + click.echo(f"[smooth] {smoothness_yaml}", err=True) + + elif step == "plot": + plot_yaml = 
_generate_plot_yaml(eval_cfg, output_dir) + manifest["plot"] = str(plot_yaml) + click.echo(f"[plot] {plot_yaml}", err=True) + + elif step == "plot_combined": + plot_combined_yaml = _generate_plot_combined_yaml(eval_cfg, output_dir) + manifest["plot_combined"] = str(plot_combined_yaml) + click.echo(f"[plot] {plot_combined_yaml}", err=True) + + elif step == "mmd": + if not eval_cfg.mmd: + click.echo("[mmd] skipped: no blocks configured", err=True) + continue + for i, mmd_block in enumerate(eval_cfg.mmd): + block_name = _mmd_block_name(mmd_block, i) + mmd_yaml = _generate_mmd_yaml(mmd_block, output_dir, block_name) + manifest[f"mmd_{block_name}"] = str(mmd_yaml) + manifest[f"mmd_{block_name}_dir"] = str(output_dir / "mmd" / block_name) + manifest["mmd_blocks"].append(block_name) + click.echo(f"[mmd] {mmd_yaml}", err=True) + if mmd_block.combined_mode: + mmd_combined_yaml = _generate_mmd_combined_yaml(mmd_block, output_dir, block_name) + combined_name = f"{block_name}_cross_exp" + manifest[f"mmd_{combined_name}"] = str(mmd_combined_yaml) + manifest["mmd_combined_blocks"].append(block_name) + click.echo(f"[mmd] {mmd_combined_yaml}", err=True) + + elif step == "linear_classifiers": + if eval_cfg.linear_classifiers is None: + click.echo("[linear_classifiers] skipped: no config provided", err=True) + continue + if not eval_cfg.linear_classifiers.annotations: + click.echo( + "[linear_classifiers] Warning: annotations is empty. " + "Add experiment + annotation CSV paths before running.", + err=True, + ) + if not eval_cfg.linear_classifiers.tasks: + click.echo( + "[linear_classifiers] Warning: tasks is empty. 
" + "Add task specs (task + optional marker_filters) before running.", + err=True, + ) + lc_yaml = _generate_linear_classifiers_yaml(eval_cfg, output_dir) + manifest["linear_classifiers"] = str(lc_yaml) + click.echo(f"[lc] {lc_yaml}", err=True) + + elif step == "append_annotations": + # Annotations may live in either: + # (a) eval_cfg.append_annotations.annotations — Wave-2 datasets + # that have ground truth but do not train LCs (alfi). + # (b) eval_cfg.linear_classifiers.annotations — Wave-1 legacy + # path where annotations are colocated with LC training. + has_aa = eval_cfg.append_annotations is not None and eval_cfg.append_annotations.annotations + has_lc = eval_cfg.linear_classifiers is not None and eval_cfg.linear_classifiers.annotations + if not (has_aa or has_lc): + click.echo( + "[append_annotations] skipped: no annotations configured " + "(set append_annotations.annotations or linear_classifiers.annotations)", + err=True, + ) + continue + aa_yaml = _generate_append_annotations_yaml(eval_cfg, output_dir) + manifest["append_annotations"] = str(aa_yaml) + click.echo(f"[append_ann] {aa_yaml}", err=True) + + elif step == "append_predictions": + # Two ways to satisfy append_predictions: + # (a) in-run: the same eval also trains LCs (linear_classifiers in + # steps + LinearClassifiersStepConfig present), and we fetch + # from output_dir/linear_classifiers/pipelines/. + # (b) external: eval_cfg.append_predictions.pipelines_dir points + # at a central registry directory (typically the `latest` + # symlink under a model's registry root). Wave-2 runs. 
+ has_external = eval_cfg.append_predictions is not None and eval_cfg.append_predictions.pipelines_dir + has_in_run = eval_cfg.linear_classifiers is not None and "linear_classifiers" in eval_cfg.steps + if not (has_external or has_in_run): + raise ValueError( + "'append_predictions' requires either:\n" + " (a) 'linear_classifiers' in steps (train LCs in this run), or\n" + " (b) append_predictions.pipelines_dir set to an existing LC bundle\n" + " (fetch pipelines from a separate run / central registry)." + ) + ap_yaml = _generate_append_predictions_yaml(eval_cfg, output_dir) + manifest["append_predictions"] = str(ap_yaml) + click.echo(f"[append_pred] {ap_yaml}", err=True) + + else: + click.echo(f"Unknown step '{step}', skipping", err=True) + + # Print JSON manifest to stdout for Nextflow to consume + click.echo(json.dumps(manifest, indent=2)) + + +# --------------------------------------------------------------------------- +# CLI entry points +# --------------------------------------------------------------------------- + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to evaluation YAML configuration file", +) +def main(config: Path) -> None: + """Generate evaluation configs for a trained DynaCLR model. + + Writes per-step YAML configs to output_dir/configs/ and prints a JSON manifest + to stdout mapping step names to config paths. Used as the entry point for the + Nextflow evaluation pipeline. 
+ """ + prepare_configs(config) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/evaluate_config.py b/applications/dynaclr/src/dynaclr/evaluation/evaluate_config.py new file mode 100644 index 000000000..cc3a81e30 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/evaluate_config.py @@ -0,0 +1,383 @@ +"""Pydantic configuration models for the DynaCLR evaluation orchestrator.""" + +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import BaseModel + +from dynaclr.evaluation.dimensionality_reduction.config import PCAConfig, PHATEConfig, UMAPConfig +from dynaclr.evaluation.mmd.config import ComparisonSpec, MAPSettings, MMDSettings + + +class PredictStepConfig(BaseModel): + """Configuration for the embedding extraction (predict) step. + + Parameters + ---------- + batch_size : int + Batch size for inference. Default: 128. + num_workers : int + DataLoader thread workers. Default: 2. + precision : str + Mixed-precision setting for Lightning Trainer. Default: "bf16-mixed". + devices : int + Number of GPUs. Default: 1. + embedding_key : {"features", "projections"} + Which array the EmbeddingWriter stores as the primary embedding in + ``adata.X``. ``"features"`` (default) writes the encoder backbone + output. ``"projections"`` writes the trained projection-head output — + required when the projection head is the only finetuned component + (e.g. DINOv3-temporal-MLP, where the DINOv3 backbone is frozen and the + MLP head carries all the learned task signal). The unselected array + is still saved to ``obsm["X_projections"]`` / ``obsm["X_backbone"]`` + as a sidecar. + """ + + batch_size: int = 128 + num_workers: int = 2 + precision: str = "32-true" + devices: int = 1 + embedding_key: Literal["features", "projections"] = "features" + + +class ReduceCombinedStepConfig(BaseModel): + """Configuration for the joint dimensionality reduction step across experiments. 
+ + Parameters + ---------- + overwrite_keys : bool + Whether to overwrite existing obsm keys. Default: True. + pca : PCAConfig or None + PCA parameters for joint fit. Results stored as X_pca_combined. + umap : UMAPConfig or None + UMAP parameters for joint fit. Results stored as X_umap_combined. + phate : PHATEConfig or None + PHATE parameters for joint fit. Results stored as X_phate_combined. + """ + + overwrite_keys: bool = True + pca: Optional[PCAConfig] = PCAConfig(n_components=32, normalize_features=True) + umap: Optional[UMAPConfig] = None + phate: Optional[PHATEConfig] = PHATEConfig(n_components=2, knn=5, decay=40, scale_embeddings=False) + + +class ReduceStepConfig(BaseModel): + """Configuration for the dimensionality reduction step. + + Parameters + ---------- + overwrite_keys : bool + Whether to overwrite existing obsm keys. Default: True. + pca : PCAConfig or None + PCA parameters. None skips PCA. + umap : UMAPConfig or None + UMAP parameters. None skips UMAP. + phate : PHATEConfig or None + PHATE parameters. None skips PHATE. + """ + + overwrite_keys: bool = True + pca: Optional[PCAConfig] = PCAConfig(n_components=32, normalize_features=True) + umap: Optional[UMAPConfig] = None + phate: Optional[PHATEConfig] = None # PHATE runs jointly in reduce_combined, not per-experiment + + +class SmoothnessStepConfig(BaseModel): + """Configuration for the temporal smoothness evaluation step. + + Parameters + ---------- + distance_metric : str + Distance metric. "cosine" or "euclidean". Default: "cosine". + save_plots : bool + Save distribution plots. Default: True. + save_distributions : bool + Save raw distribution arrays. Default: False. + verbose : bool + Print verbose progress. Default: True. + """ + + distance_metric: Literal["cosine", "euclidean"] = "cosine" + save_plots: bool = True + save_distributions: bool = False + verbose: bool = True + + +class PlotStepConfig(BaseModel): + """Configuration for the embedding visualization step. 
+ + Parameters + ---------- + embedding_keys : list[str] + Per-experiment obsm keys to plot (looped over each split zarr). + Default: ["X_pca"]. + combined_embedding_keys : list[str] + Cross-experiment obsm keys to plot once across all zarrs concatenated. + Default: ["X_pca_combined", "X_phate_combined"]. + color_by : list[str] + obs columns for per-experiment plots. Default: perturbation, hours, marker. + combined_color_by : list[str] + obs columns for combined (cross-experiment) plots. Adds "experiment" to color_by. + point_size : float + Scatter plot point size. Default: 1.0. + components : tuple[int, int] + Which components to use as X/Y axes (0-indexed). Default: (0, 1). + pairplot_components : int + Number of leading PCs to render in the pairplot grid (NxN panels). + Smaller is faster: rendering scales with N^2. Default: 8. + format : str + Output format. "pdf" or "png". Default: "pdf". + """ + + embedding_keys: list[str] = ["X_pca"] + combined_embedding_keys: list[str] = ["X_pca_combined", "X_phate_combined"] + color_by: list[str] = ["perturbation", "hours_post_perturbation", "marker"] + combined_color_by: list[str] = ["perturbation", "hours_post_perturbation", "experiment", "marker"] + point_size: float = 1.0 + components: tuple[int, int] = (0, 1) + pairplot_components: int = 8 + format: str = "pdf" + + +class AnnotationSource(BaseModel): + """Annotation CSV for one experiment. + + Parameters + ---------- + experiment : str + Experiment name matching obs["experiment"] in the embeddings zarr. + path : str + Absolute path to the annotation CSV. Must have fov_name, id, and + at least one task column (e.g. infection_state, organelle_state). + """ + + experiment: str + path: str + + +class TaskSpec(BaseModel): + """One classification task to evaluate. + + Parameters + ---------- + task : str + Task column name in annotation CSVs (e.g. infection_state, organelle_state). + marker_filters : list[str] or None + If set, run one classifier per listed marker. 
None (default) runs one + classifier per marker discovered in the data (all unique obs["marker"] values). + """ + + task: str + marker_filters: Optional[list[str]] = None + + +class MMDStepConfig(BaseModel): + """Configuration for one MMD evaluation block. + + Comparisons are explicit ``(cond_a, cond_b, label)`` pairs — no auto-discovery. + Include a null comparison (e.g. uninfected1 vs uninfected2) to establish + a baseline false-positive rate. + + Parameters + ---------- + comparisons : list[ComparisonSpec] + Explicit pairwise comparisons to run. + group_by : str + obs column whose values are referenced by ``cond_a``/``cond_b``. + Default: "perturbation". + obs_filter : dict[str, str] or None + Subset adata to rows where obs[key] == value before running MMD. + Example: ``{perturbation: uninfected}`` to restrict batch-QC + comparisons to control cells only. None = use all cells. + embedding_key : str or None + obsm key to use. None = raw .X. Default: None. + mmd : MMDSettings + Kernel MMD algorithm settings (permutations, cell caps, seed, etc.). + map_settings : MAPSettings + copairs-based mAP settings. Default: disabled. + temporal_bin_size : float or None + Width of each temporal bin in hours. Edges derived from data max. + None = aggregate MMD. + combined_temporal_bin_size : float or None + Override temporal_bin_size for the combined (cross-experiment) run only. + If not set, falls back to temporal_bin_size. Use None to aggregate across + all time in the combined run while keeping per-experiment binning. + save_plots : bool + Generate kinetics and heatmap plots. Default: True. + combined_mode : bool + Also run cross-experiment MMD with per-experiment batch centering. + Default: False. + name : str or None + Short name used in output filenames (e.g. "perturbation", "batch_qc"). + Auto-derived from group_by if None. 
+ """ + + comparisons: list[ComparisonSpec] + group_by: str = "perturbation" + obs_filter: Optional[dict[str, str]] = None + embedding_key: Optional[str] = None + mmd: MMDSettings = MMDSettings() + map_settings: MAPSettings = MAPSettings() + temporal_bin_size: Optional[float] = None + combined_temporal_bin_size: Optional[float] = None + save_plots: bool = True + combined_mode: bool = False + name: Optional[str] = None + + +class LinearClassifiersStepConfig(BaseModel): + """Configuration for the orchestrated linear classifiers step. + + Parameters + ---------- + annotations : list[AnnotationSource] + Per-experiment annotation CSVs. Each entry maps an experiment name + (matching obs["experiment"] in embeddings.zarr) to a CSV path. + tasks : list[TaskSpec] + Tasks to evaluate. Each task can optionally filter by marker. + publish_dir : str or None + Central LC registry root for this model (e.g., + ``/hpc/projects/.../linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/``). + When set, pipelines are published as a new versioned bundle + (``vN/``) with a ``latest`` symlink update. When None, legacy + behavior: write to ``output_dir/linear_classifiers/pipelines/``. + use_scaling : bool + Apply StandardScaler. Default: True. + use_pca : bool + Apply PCA before classifier. Default: False. + n_pca_components : int or None + Number of PCA components (required if use_pca is True). + max_iter : int + Max iterations for solver. Default: 1000. + class_weight : str or None + Class weighting. "balanced" or None. Default: "balanced". + solver : str + Optimization algorithm. Default: "liblinear". + split_train_data : float + Fraction for training. Default: 0.8. + random_seed : int + Random seed for reproducibility. Default: 42. 
+ """ + + annotations: list[AnnotationSource] + tasks: list[TaskSpec] + publish_dir: Optional[str] = None + use_scaling: bool = True + use_pca: bool = False + n_pca_components: Optional[int] = None + max_iter: int = 1000 + class_weight: Optional[str] = "balanced" + solver: str = "liblinear" + split_train_data: float = 0.8 + random_seed: int = 42 + + +class AppendPredictionsStepConfig(BaseModel): + """Configuration for the append-predictions step. + + Parameters + ---------- + pipelines_dir : str or None + Directory (or ``latest`` symlink) holding a published LC bundle + with ``manifest.json`` and ``{task}_{marker}.joblib`` files. + When None, defaults to ``output_dir/linear_classifiers/pipelines/`` + (legacy layout for runs that both train and apply LCs in the same + eval). Set this explicitly for Wave-2 evaluations that apply + pipelines trained by a separate Wave-1 run. + """ + + pipelines_dir: Optional[str] = None + + +class AppendAnnotationsStepConfig(BaseModel): + """Configuration for the append-annotations step. + + Used by Wave-2 evaluations that have annotation CSVs but do not train + linear classifiers (e.g., alfi). Wave-1 evaluations historically + sourced annotations from ``linear_classifiers.annotations``; this + field lets datasets carry annotations independently of LC training. + When both are set, this field takes precedence. + + Parameters + ---------- + annotations : list[AnnotationSource] + Per-experiment annotation CSVs to merge into per-experiment zarrs. + """ + + annotations: list[AnnotationSource] = [] + + +class EvaluationConfig(BaseModel): + """Top-level configuration for the DynaCLR evaluation orchestrator. + + Parameters + ---------- + training_config : str + Path to the training YAML config (Lightning CLI format). Model + architecture, normalizations, and data parameters are auto-extracted. + ckpt_path : str + Path to the model checkpoint (.ckpt). 
+ cell_index_path : str or None + Override the cell index parquet path from the training config. + None = use the path from the training config. + output_dir : str + Root directory for all evaluation outputs. + steps : list[str] + Ordered list of steps to generate configs for. + Valid values: predict, split, reduce_dimensionality, reduce_combined, + plot, plot_combined, smoothness, mmd, linear_classifiers, + append_annotations, append_predictions. + ``plot`` emits per-experiment scatter plots (fans out one job per + experiment). ``plot_combined`` emits the joint cross-experiment + plot only. List both to get both; list neither to skip plotting. + predict : PredictStepConfig + Predict step configuration. + reduce_dimensionality : ReduceStepConfig + Per-experiment dimensionality reduction step configuration. + reduce_combined : ReduceCombinedStepConfig + Joint dimensionality reduction across all experiments. + smoothness : SmoothnessStepConfig + Smoothness evaluation configuration. + plot : PlotStepConfig + Embedding visualization configuration. + linear_classifiers : LinearClassifiersStepConfig or None + Linear classifier configuration. None disables this step. + append_predictions : AppendPredictionsStepConfig or None + Append-predictions configuration. Set ``pipelines_dir`` to apply + pipelines from a separate eval run (e.g., Wave 2 fetching from the + central LC registry). None keeps legacy behavior. + mmd : list[MMDStepConfig] + MMD evaluation blocks. Each block is an independent run with its own + group_by, comparisons, and optional obs_filter. Empty list disables MMD. + """ + + training_config: str + # ckpt_path is None for foundation-model baselines (e.g. DINOv3-frozen) where + # weights are loaded from HuggingFace inside the model __init__ and there is + # no Lightning checkpoint to restore from. 
+ ckpt_path: Optional[str] = None + cell_index_path: Optional[str] = None + output_dir: str + steps: list[str] = ["predict", "split", "reduce_dimensionality", "reduce_combined", "plot", "smoothness"] + predict: PredictStepConfig = PredictStepConfig() + reduce_dimensionality: ReduceStepConfig = ReduceStepConfig() + reduce_combined: ReduceCombinedStepConfig = ReduceCombinedStepConfig() + smoothness: SmoothnessStepConfig = SmoothnessStepConfig() + plot: PlotStepConfig = PlotStepConfig() + linear_classifiers: Optional[LinearClassifiersStepConfig] = None + append_annotations: Optional[AppendAnnotationsStepConfig] = None + append_predictions: Optional[AppendPredictionsStepConfig] = None + mmd: list[MMDStepConfig] = [] + + @property + def model_name(self) -> str: + """Derive the model identifier from the training config filename stem. + + Example: ``DynaCLR-2D-MIP-BagOfChannels.yml`` → ``"DynaCLR-2D-MIP-BagOfChannels"``. + Used as the ``feature_space`` tag in LC manifests and as the + namespace prefix for predicted columns in output zarrs. 
+ """ + from pathlib import Path as _Path + + return _Path(self.training_config).stem diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/apply_linear_classifier.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/apply_linear_classifier.py index 3c98029b5..34e738ff1 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/apply_linear_classifier.py +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/apply_linear_classifier.py @@ -10,7 +10,7 @@ from anndata import read_zarr from pydantic import ValidationError -from viscy_utils.cli_utils import format_markdown_table, load_config +from viscy_utils.cli_utils import format_markdown_table, load_config_section from viscy_utils.evaluation.linear_classifier import ( load_pipeline_from_wandb, predict_with_classifier, @@ -92,7 +92,7 @@ def main(config: Path): click.echo("=" * 60) try: - config_dict = load_config(config) + config_dict = load_config_section(config, None, default_section="apply_linear_classifier") inference_config = LinearClassifierInferenceConfig(**config_dict) except ValidationError as e: click.echo(f"\n Configuration validation failed:\n{e}", err=True) diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/cross_validation.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/cross_validation.py index 3d4a33d80..47bb3172e 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/cross_validation.py +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/cross_validation.py @@ -37,7 +37,7 @@ get_available_tasks, resolve_task_channels, ) -from viscy_utils.cli_utils import format_markdown_table, load_config +from viscy_utils.cli_utils import format_markdown_table, load_config_section from viscy_utils.evaluation.annotation import load_annotation_anndata from viscy_utils.evaluation.linear_classifier import ( load_and_combine_datasets, @@ -137,17 +137,6 @@ def 
_get_class_counts(datasets_for_combo: list[dict], task: str) -> dict[str, in return dict(pd.Series(all_labels).value_counts()) -def _detect_n_features(datasets: list[dict], channel: str) -> int | None: - """Detect embedding dimensionality from the first available zarr.""" - for ds in datasets: - embeddings_dir = Path(ds["embeddings_dir"]) - channel_zarrs = find_channel_zarrs(embeddings_dir, [channel]) - if channel in channel_zarrs: - adata = ad.read_zarr(channel_zarrs[channel]) - return adata.shape[1] - return None - - # --------------------------------------------------------------------------- # Core rotating CV unit # --------------------------------------------------------------------------- @@ -234,7 +223,7 @@ def _train_and_evaluate( "random_state": seed, } - pipeline, metrics = train_linear_classifier( + pipeline, metrics, _ = train_linear_classifier( adata=combined_adata, task=task, use_scaling=use_scaling, @@ -828,7 +817,7 @@ def _get_recommended_subsets(summary_df: pd.DataFrame) -> pd.DataFrame: ) def main(config: Path, task: str | None, report: bool): """Run rotating test-set leave-one-dataset-out cross-validation.""" - config_dict = load_config(config) + config_dict = load_config_section(config, None, default_section="cross_validate") if report: config_dict["report"] = True diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/evaluate_dataset.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/evaluate_dataset.py deleted file mode 100644 index ad615758f..000000000 --- a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/evaluate_dataset.py +++ /dev/null @@ -1,456 +0,0 @@ -"""Evaluation pipeline comparing embedding models on a held-out test dataset. - -Trains linear classifiers on cross-dataset embeddings, applies them to a -held-out test set, evaluates predictions, and optionally generates a PDF -comparison report. 
- -Usage:: - - python scripts/evaluate_dataset.py -c configs/evaluate_dataset_example.yaml - python scripts/evaluate_dataset.py -c config.yaml --report -""" - -from __future__ import annotations - -import argparse -from pathlib import Path -from typing import Any - -import anndata as ad -import joblib -import pandas as pd -from sklearn.metrics import classification_report - -from dynaclr.evaluation.linear_classifiers.utils import ( - find_channel_zarrs, - get_available_tasks, - resolve_task_channels, -) -from viscy_utils.cli_utils import format_markdown_table, load_config -from viscy_utils.evaluation.annotation import load_annotation_anndata -from viscy_utils.evaluation.linear_classifier import ( - load_and_combine_datasets, - predict_with_classifier, - save_pipeline_to_wandb, - train_linear_classifier, -) - -# --------------------------------------------------------------------------- -# Main evaluation function -# --------------------------------------------------------------------------- - - -def run_evaluation(config: dict) -> None: - """Run the full evaluation pipeline: train, infer, evaluate, report. - - Parameters - ---------- - config : dict - Evaluation config parsed from YAML. 
Expected keys: - - dataset_name: str - - test_annotations_csv: str path - - output_dir: str path - - models: dict of model specs - - task_channels: dict or None (auto-detect from test CSV) - - use_scaling, n_pca_components, max_iter, class_weight, solver, - split_train_data, random_seed - - wandb_logging: bool (default True) - """ - output_dir = Path(config["output_dir"]) - output_dir.mkdir(parents=True, exist_ok=True) - - test_csv = Path(config["test_annotations_csv"]) - tc = resolve_task_channels(config.get("task_channels"), [test_csv]) - if not tc: - raise ValueError("No valid tasks found in test annotations CSV.") - - model_labels = list(config["models"].keys()) - - print("## Evaluation Pipeline") - print(f" Test dataset: {config['dataset_name']}") - print(f" Task-channels: {tc}") - print(f" Models: {model_labels}") - - use_scaling = config.get("use_scaling", True) - n_pca = config.get("n_pca_components") - use_pca = n_pca is not None - split_train_data = config.get("split_train_data", 0.8) - random_seed = config.get("random_seed", 42) - wandb_logging = config.get("wandb_logging", True) - - classifier_params = { - "max_iter": config.get("max_iter", 1000), - "class_weight": config.get("class_weight", "balanced"), - "solver": config.get("solver", "liblinear"), - "random_state": random_seed, - } - - train_results: dict[str, dict[tuple[str, str], dict[str, Any]]] = {} - eval_results: dict[str, dict[tuple[str, str], dict[str, Any]]] = {} - - for model_label, model_spec in config["models"].items(): - print(f"\n### Model: {model_label} ({model_spec.get('name', model_label)})") - model_train: dict[tuple[str, str], dict[str, Any]] = {} - model_eval: dict[tuple[str, str], dict[str, Any]] = {} - model_output_dir = output_dir / model_label - model_output_dir.mkdir(parents=True, exist_ok=True) - - test_embeddings_dir = Path(model_spec["test_embeddings_dir"]) - - for task, channels in tc.items(): - test_channel_zarrs = find_channel_zarrs(test_embeddings_dir, channels) - - 
for channel in channels: - combo_key = (task, channel) - print(f"\n {task} / {channel}:") - - # --- Train --- - try: - datasets_for_combo = _build_train_datasets(model_spec["train_datasets"], task, channel) - if not datasets_for_combo: - print(" No training datasets available, skipping.") - continue - - print(f" Training on {len(datasets_for_combo)} dataset(s)") - combined_adata = load_and_combine_datasets(datasets_for_combo, task) - - pipeline, metrics = train_linear_classifier( - adata=combined_adata, - task=task, - use_scaling=use_scaling, - use_pca=use_pca, - n_pca_components=n_pca, - classifier_params=classifier_params, - split_train_data=split_train_data, - random_seed=random_seed, - ) - - pipeline_path = model_output_dir / f"{task}_{channel}_pipeline.joblib" - joblib.dump(pipeline, pipeline_path) - print(f" Pipeline saved: {pipeline_path.name}") - - artifact_name = f"{model_spec.get('name', model_label)}_{task}_{channel}_local" - if wandb_logging and "wandb_project" in model_spec: - wandb_config = { - "task": task, - "input_channel": channel, - "marker": config.get("marker"), - "embedding_model": f"{model_spec['name']}-{model_spec['version']}", - "test_dataset": config["dataset_name"], - "use_scaling": use_scaling, - "use_pca": use_pca, - "n_pca_components": n_pca, - "max_iter": classifier_params["max_iter"], - "class_weight": classifier_params["class_weight"], - "solver": classifier_params["solver"], - "split_train_data": split_train_data, - "random_seed": random_seed, - } - wandb_tags = [ - config["dataset_name"], - model_spec["name"], - model_spec["version"], - channel, - task, - "cross-dataset", - ] - artifact_name = save_pipeline_to_wandb( - pipeline=pipeline, - metrics=metrics, - config=wandb_config, - wandb_project=model_spec["wandb_project"], - tags=wandb_tags, - ) - - model_train[combo_key] = { - "pipeline": pipeline, - "metrics": metrics, - "artifact_name": artifact_name, - } - - val_acc = metrics.get("val_accuracy") - val_f1 = 
metrics.get("val_weighted_f1") - if val_acc is not None: - print(f" Val accuracy: {val_acc:.3f} Val F1: {val_f1:.3f}") - - except Exception as e: - print(f" TRAIN FAILED: {e}") - continue - - # --- Infer + Evaluate --- - if channel not in test_channel_zarrs: - print(f" No test zarr for {channel}, skipping inference.") - continue - - try: - print(" Loading test embeddings...") - test_adata = ad.read_zarr(test_channel_zarrs[channel]) - - artifact_metadata = { - "artifact_name": artifact_name, - "artifact_id": artifact_name, - "artifact_version": "local", - } - test_adata = predict_with_classifier( - test_adata, - pipeline, - task, - artifact_metadata=artifact_metadata, - ) - - pred_path = model_output_dir / f"{task}_{channel}_predictions.zarr" - test_adata.write_zarr(pred_path) - print(f" Saved predictions: {pred_path.name}") - - # Evaluate against ground truth - annotated = load_annotation_anndata(test_adata, str(test_csv), task) - mask = annotated.obs[task].notna() & (annotated.obs[task] != "unknown") - eval_subset = annotated[mask] - - if len(eval_subset) == 0: - print(" No annotated test cells after filtering.") - continue - - pred_col = f"predicted_{task}" - y_true = eval_subset.obs[task].values - y_pred = eval_subset.obs[pred_col].values - - report = classification_report(y_true, y_pred, digits=3, output_dict=True) - - test_metrics = { - "test_accuracy": report["accuracy"], - "test_weighted_precision": report["weighted avg"]["precision"], - "test_weighted_recall": report["weighted avg"]["recall"], - "test_weighted_f1": report["weighted avg"]["f1-score"], - "test_n_samples": len(eval_subset), - } - - for class_name in sorted(set(y_true) | set(y_pred)): - if class_name in report: - test_metrics[f"test_{class_name}_precision"] = report[class_name]["precision"] - test_metrics[f"test_{class_name}_recall"] = report[class_name]["recall"] - test_metrics[f"test_{class_name}_f1"] = report[class_name]["f1-score"] - - annotated_path = model_output_dir / 
f"{task}_{channel}_annotated.zarr" - annotated.write_zarr(annotated_path) - - model_eval[combo_key] = { - "metrics": test_metrics, - "annotated_adata": annotated, - } - - acc = test_metrics["test_accuracy"] - f1 = test_metrics["test_weighted_f1"] - n = test_metrics["test_n_samples"] - print(f" Test: acc={acc:.3f} F1={f1:.3f} (n={n})") - - except Exception as e: - print(f" EVAL FAILED: {e}") - continue - - train_results[model_label] = model_train - eval_results[model_label] = model_eval - - # Save per-model metrics CSV - _save_metrics_csv( - model_train, - model_eval, - model_output_dir / "metrics_summary.csv", - ) - - # Save combined comparison CSVs - _save_comparison_csv(train_results, output_dir / "train_metrics_comparison.csv") - _save_eval_comparison_csv(eval_results, output_dir / "test_metrics_comparison.csv") - - # Print markdown summary - _print_summary(train_results, eval_results, tc) - - return train_results, eval_results - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _build_train_datasets(train_datasets: list[dict], task: str, channel: str) -> list[dict]: - """Filter and build training dataset dicts for a (task, channel) combo. - - Parameters - ---------- - train_datasets : list[dict] - Raw dataset entries from config, each with 'embeddings_dir' and 'annotations'. - task : str - Classification task to check for. - channel : str - Channel to look for in embeddings_dir. - - Returns - ------- - list[dict] - Filtered list with 'embeddings' and 'annotations' keys. 
- """ - result = [] - for ds in train_datasets: - embeddings_dir = Path(ds["embeddings_dir"]) - annotations_path = Path(ds["annotations"]) - - channel_zarrs = find_channel_zarrs(embeddings_dir, [channel]) - if channel not in channel_zarrs: - print(f" Skipping {embeddings_dir.parent.name} - no {channel} zarr") - continue - - available_tasks = get_available_tasks(annotations_path) - if task not in available_tasks: - print(f" Skipping {embeddings_dir.parent.name} - no {task} column") - continue - - training_dict = { - "embeddings": str(channel_zarrs[channel]), - "annotations": str(annotations_path), - } - if "include_wells" in ds: - training_dict["include_wells"] = ds["include_wells"] - result.append(training_dict) - return result - - -def _save_metrics_csv( - train_results: dict[tuple[str, str], dict[str, Any]], - eval_results: dict[tuple[str, str], dict[str, Any]], - output_path: Path, -) -> None: - """Save combined train + eval metrics for one model.""" - rows = [] - all_keys = set(train_results.keys()) | set(eval_results.keys()) - for combo_key in sorted(all_keys): - task, channel = combo_key - row = {"task": task, "channel": channel} - if combo_key in train_results: - row.update(train_results[combo_key]["metrics"]) - if combo_key in eval_results: - row.update(eval_results[combo_key]["metrics"]) - rows.append(row) - - if rows: - pd.DataFrame(rows).to_csv(output_path, index=False) - - -def _save_comparison_csv( - all_results: dict[str, dict[tuple[str, str], dict[str, Any]]], - output_path: Path, -) -> None: - """Save combined train metrics comparison across models.""" - rows = [] - for model_label, model_results in all_results.items(): - for (task, channel), result in model_results.items(): - row = {"model": model_label, "task": task, "channel": channel} - row.update(result["metrics"]) - rows.append(row) - if rows: - pd.DataFrame(rows).to_csv(output_path, index=False) - - -def _save_eval_comparison_csv( - all_results: dict[str, dict[tuple[str, str], dict[str, 
Any]]], - output_path: Path, -) -> None: - """Save combined test metrics comparison across models.""" - rows = [] - for model_label, model_results in all_results.items(): - for (task, channel), result in model_results.items(): - row = {"model": model_label, "task": task, "channel": channel} - row.update(result["metrics"]) - rows.append(row) - if rows: - pd.DataFrame(rows).to_csv(output_path, index=False) - - -def _print_summary( - train_results: dict[str, dict[tuple[str, str], dict[str, Any]]], - eval_results: dict[str, dict[tuple[str, str], dict[str, Any]]], - task_channels: dict[str, list[str]], -) -> None: - """Print markdown summary table of all results.""" - headers = ["Task", "Channel"] - model_labels = list(train_results.keys()) - for label in model_labels: - headers += [ - f"{label} Val Acc", - f"{label} Val F1", - f"{label} Test Acc", - f"{label} Test F1", - ] - - rows = [] - for task, channels in task_channels.items(): - for channel in channels: - row_dict = {"Task": task, "Channel": channel} - for label in model_labels: - tr = train_results.get(label, {}).get((task, channel)) - ev = eval_results.get(label, {}).get((task, channel)) - if tr: - row_dict[f"{label} Val Acc"] = f"{tr['metrics'].get('val_accuracy', float('nan')):.3f}" - row_dict[f"{label} Val F1"] = f"{tr['metrics'].get('val_weighted_f1', float('nan')):.3f}" - else: - row_dict[f"{label} Val Acc"] = "-" - row_dict[f"{label} Val F1"] = "-" - if ev: - row_dict[f"{label} Test Acc"] = f"{ev['metrics'].get('test_accuracy', float('nan')):.3f}" - row_dict[f"{label} Test F1"] = f"{ev['metrics'].get('test_weighted_f1', float('nan')):.3f}" - else: - row_dict[f"{label} Test Acc"] = "-" - row_dict[f"{label} Test F1"] = "-" - rows.append(row_dict) - - print(format_markdown_table(rows, title="Evaluation Summary", headers=headers)) - - -# --------------------------------------------------------------------------- -# Entry point -# --------------------------------------------------------------------------- - - 
-if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Evaluate embedding models on a held-out test dataset") - parser.add_argument( - "-c", - "--config", - type=str, - required=True, - help="Path to YAML config file", - ) - parser.add_argument( - "--report", - action="store_true", - help="Generate PDF comparison report", - ) - args = parser.parse_args() - - config = load_config(args.config) - - print(f"Dataset: {config['dataset_name']}") - print(f"Output: {config['output_dir']}") - for label, spec in config["models"].items(): - n_train = len(spec["train_datasets"]) - print(f" {label}: {n_train} training dataset(s)") - - train_results, eval_results = run_evaluation(config) - - if args.report: - from dynaclr.evaluation.linear_classifiers.report import generate_comparison_report - - test_csv = Path(config["test_annotations_csv"]) - tc = resolve_task_channels(config.get("task_channels"), [test_csv]) - tasks = list(tc.keys()) - channels = sorted({ch for chs in tc.values() for ch in chs}) - - generate_comparison_report( - output_dir=Path(config["output_dir"]), - dataset_name=config["dataset_name"], - model_labels=list(config["models"].keys()), - tasks=tasks, - channels=channels, - train_results=train_results, - eval_results=eval_results, - ) diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated.py new file mode 100644 index 000000000..8dea5e275 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated.py @@ -0,0 +1,530 @@ +"""Orchestrated linear classifiers evaluation from a single embeddings zarr. + +Reads the combined embeddings.zarr produced by the predict step, filters by +experiment and marker, joins per-experiment annotation CSVs, and trains one +logistic regression classifier per (task, marker_filter) combination. + +Outputs a metrics_summary.csv and a summary PDF to the output directory. 
+No W&B logging. For standalone training with W&B use ``dynaclr train-linear-classifier``. + +Usage +----- +dynaclr run-linear-classifiers -c linear_classifiers.yaml +""" + +from __future__ import annotations + +import json +import os +import tempfile +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import click +import joblib +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.backends.backend_pdf import PdfPages +from sklearn.model_selection import train_test_split + +from viscy_utils.cli_utils import format_markdown_table, load_config +from viscy_utils.evaluation.annotation import load_annotation_anndata +from viscy_utils.evaluation.linear_classifier import train_linear_classifier + +matplotlib.use("Agg") + +if TYPE_CHECKING: + import anndata as ad + + from dynaclr.evaluation.evaluate_config import LinearClassifiersStepConfig + + +def run_linear_classifiers( + embeddings_path: Path, + config: LinearClassifiersStepConfig, + output_dir: Path, +) -> pd.DataFrame: + """Train linear classifiers for each (task, marker_filter) combination. + + Parameters + ---------- + embeddings_path : Path + Path to the combined embeddings zarr (AnnData format). Must have + experiment and marker columns in obs (added by the predict step). + config : LinearClassifiersStepConfig + Configuration with annotations list and task specs. + output_dir : Path + Directory to write metrics_summary.csv. + + Returns + ------- + pd.DataFrame + One row per (task, marker_filter) with accuracy, F1, AUROC, etc. 
+ """ + import anndata as ad + + click.echo(f"Loading embeddings from {embeddings_path}") + if embeddings_path.is_dir() and not str(embeddings_path).endswith(".zarr"): + zarr_paths = sorted(embeddings_path.glob("*.zarr")) + if not zarr_paths: + raise FileNotFoundError(f"No .zarr files found in {embeddings_path}") + parts = [ad.read_zarr(p) for p in zarr_paths] + adata = ad.concat(parts, join="outer") + adata.obs_names_make_unique() + click.echo(f" Loaded {len(zarr_paths)} per-experiment zarrs") + else: + adata = ad.read_zarr(embeddings_path) + click.echo(f" {adata.n_obs} cells, {adata.n_vars} features") + + missing = [col for col in ["experiment", "marker"] if col not in adata.obs.columns] + if missing: + raise ValueError( + f"embeddings.zarr obs is missing columns: {missing}. " + "Re-run the predict step with the updated pipeline to include metadata." + ) + + all_metrics: list[dict] = [] + # val_outputs_by_task: task → list of per-marker dicts for plotting + val_outputs_by_task: dict[str, list[dict[str, Any]]] = {} + # Saved pipelines for append-predictions step. When publish_dir is set, + # we stage here and atomically promote to a versioned registry dir at + # the end of training. Otherwise legacy behavior: write in place under + # output_dir/pipelines/. + pipelines_dir = output_dir / "pipelines" + pipelines_dir.mkdir(parents=True, exist_ok=True) + pipeline_manifest: list[dict] = [] + # Collect trained (task, marker, pipeline) tuples for publish_dir promotion. 
+ trained_pipelines: list[tuple[str, str, Any]] = [] + + for task_spec in config.tasks: + task = task_spec.task + # Expand marker_filters: None → all unique markers; list → one run per specified marker + runs: list[str] = ( + task_spec.marker_filters + if task_spec.marker_filters is not None + else sorted(adata.obs["marker"].unique().tolist()) + ) + val_outputs_by_task[task] = [] + + for marker_filter in runs: + label = f"{task}" + (f" (marker={marker_filter})" if marker_filter else " (all markers)") + click.echo(f"\n{'=' * 60}") + click.echo(f"Task: {label}") + click.echo("=" * 60) + + # Filter by marker if specified + if marker_filter is not None: + adata_task = adata[adata.obs["marker"] == marker_filter] + click.echo(f" Filtered to {adata_task.n_obs} cells with marker={marker_filter}") + else: + adata_task = adata + + if adata_task.n_obs == 0: + click.echo(f" No cells found for marker_filter={marker_filter!r}, skipping.") + continue + + # Join annotation CSVs per experiment and collect annotated subsets + annotated_parts: list[ad.AnnData] = [] + for ann_src in config.annotations: + exp_mask = adata_task.obs["experiment"] == ann_src.experiment + n_exp = int(exp_mask.sum()) + if n_exp == 0: + click.echo(f" Experiment {ann_src.experiment!r}: no matching cells, skipping.") + continue + + adata_exp = adata_task[exp_mask].copy() + ann_path = Path(ann_src.path) + if not ann_path.exists(): + raise FileNotFoundError(f"Annotation CSV not found: {ann_src.path}") + + try: + adata_exp = load_annotation_anndata(adata_exp, str(ann_path), task) + except KeyError: + click.echo(f" Experiment {ann_src.experiment!r}: task {task!r} not in {ann_path.name}, skipping.") + continue + + valid_mask = adata_exp.obs[task].notna() & (adata_exp.obs[task] != "unknown") + n_valid = int(valid_mask.sum()) + if n_valid == 0: + click.echo(f" Experiment {ann_src.experiment!r}: no valid labels for {task!r}, skipping.") + continue + + annotated_parts.append(adata_exp[valid_mask]) + click.echo(f" 
Experiment {ann_src.experiment!r}: {n_valid}/{n_exp} labeled cells") + + if not annotated_parts: + click.echo(f" No annotated data found for task {task!r}, skipping.") + continue + + combined = annotated_parts[0] if len(annotated_parts) == 1 else ad.concat(annotated_parts, join="outer") + class_dist = combined.obs[task].value_counts().to_dict() + click.echo(f" Total: {combined.n_obs} cells, class distribution: {class_dist}") + + classifier_params = { + "max_iter": config.max_iter, + "class_weight": config.class_weight, + "solver": config.solver, + "random_state": config.random_seed, + } + + try: + pipeline, metrics, val_outputs = train_linear_classifier( + adata=combined, + task=task, + use_scaling=config.use_scaling, + use_pca=config.use_pca, + n_pca_components=config.n_pca_components, + classifier_params=classifier_params, + split_train_data=config.split_train_data, + random_seed=config.random_seed, + ) + except ValueError as exc: + click.echo(f" Skipping {label}: {exc}") + continue + + # Save pipeline for append-predictions step. Always write to the + # local staging dir; promotion to publish_dir (if configured) happens + # atomically after all classifiers finish training. 
+ pipeline_filename = f"{task}_{marker_filter}.joblib" + joblib.dump(pipeline, pipelines_dir / pipeline_filename) + pipeline_manifest.append({"task": task, "marker_filter": marker_filter, "path": pipeline_filename}) + trained_pipelines.append((task, marker_filter, pipeline)) + click.echo(f" Pipeline saved: {pipeline_filename}") + + # Replay the same split to recover val obs (hours_post_perturbation) + y_full = combined.obs[task].to_numpy(dtype=object) + val_hours: np.ndarray | None = None + if config.split_train_data < 1.0 and "hours_post_perturbation" in combined.obs.columns: + try: + idx = np.arange(len(combined)) + _, idx_val = train_test_split( + idx, + train_size=config.split_train_data, + random_state=config.random_seed, + stratify=y_full, + shuffle=True, + ) + val_hours = combined.obs["hours_post_perturbation"].to_numpy()[idx_val] + except ValueError: + click.echo(" Could not replay stratified split for val_hours; F1-over-time plot skipped.") + + row = { + "task": task, + "marker_filter": marker_filter, + "n_samples": combined.n_obs, + **metrics, + } + all_metrics.append(row) + val_outputs_by_task[task].append( + { + "marker_filter": marker_filter, + "val_hours": val_hours, + **val_outputs, + } + ) + + if not all_metrics: + click.echo("\nNo classifiers trained — check annotations and marker filters.") + return pd.DataFrame() + + results_df = pd.DataFrame(all_metrics) + output_dir.mkdir(parents=True, exist_ok=True) + summary_path = output_dir / "metrics_summary.csv" + results_df.to_csv(summary_path, index=False) + click.echo(f"\nMetrics summary written to {summary_path}") + + # New-format manifest: dict with trained_at + pipelines list. + # Model identity (feature_space) and version are carried by the directory + # structure: {registry_root}/{model_name}/v{N}/. No need to duplicate here. 
+ manifest_dict = { + "trained_at": datetime.now(timezone.utc).isoformat(), + "pipelines": pipeline_manifest, + } + manifest_path = pipelines_dir / "manifest.json" + with open(manifest_path, "w") as f: + json.dump(manifest_dict, f, indent=2) + click.echo(f"Pipeline manifest written to {manifest_path}") + + # Promote to central LC registry if publish_dir is configured. + publish_dir_str = getattr(config, "publish_dir", None) + if publish_dir_str: + new_dir = _publish_atomically( + publish_dir=Path(publish_dir_str), + trained=trained_pipelines, + manifest_dict=manifest_dict, + ) + click.echo(f"Published LC bundle to {new_dir} (latest -> {new_dir.name})") + + _print_summary(results_df) + for task, task_val_outputs in val_outputs_by_task.items(): + task_df = results_df[results_df["task"] == task] + _save_task_plots(task, task_df, task_val_outputs, output_dir) + return results_df + + +def _publish_atomically( + publish_dir: Path, + trained: list[tuple[str, str, Any]], + manifest_dict: dict, +) -> Path: + """Atomically publish a new versioned LC bundle under ``publish_dir``. + + Writes pipelines + manifest.json to a staging directory, renames it to + ``vN/`` (where N is max existing version + 1), then swaps the ``latest`` + symlink to point at the new version. Crash-safe: partial bundles never + appear as ``vN/`` because the rename is atomic. + + Parameters + ---------- + publish_dir : Path + Model registry root (e.g., + ``/hpc/projects/.../linear_classifiers/DynaCLR-2D-MIP-BagOfChannels/``). + Created if it does not exist. + trained : list of (task, marker_filter, pipeline) + Fitted pipelines to persist. + manifest_dict : dict + Manifest content to write as ``manifest.json`` inside the new + version directory. + + Returns + ------- + Path + Absolute path of the newly published ``vN/`` directory. + """ + publish_dir.mkdir(parents=True, exist_ok=True) + + # Pick next version number by scanning existing v* dirs. 
+ existing = sorted(int(p.name[1:]) for p in publish_dir.glob("v*") if p.is_dir() and p.name[1:].isdigit()) + next_v = (max(existing) + 1) if existing else 1 + new_dir = publish_dir / f"v{next_v}" + + # Stage everything in a temp dir under publish_dir (same filesystem for + # atomic rename). If we crash here, nothing named vN/ appears. + staging = Path(tempfile.mkdtemp(prefix=f".v{next_v}.stage.", dir=publish_dir)) + for task, marker_filter, pipeline in trained: + joblib.dump(pipeline, staging / f"{task}_{marker_filter}.joblib") + with open(staging / "manifest.json", "w") as f: + json.dump(manifest_dict, f, indent=2) + + # Atomic rename: staging -> vN. + os.rename(staging, new_dir) + + # Atomic symlink swap: write latest.new, then rename over latest. + # Relative target ("vN") so the symlink stays valid if the registry + # root is ever moved. + latest = publish_dir / "latest" + latest_new = publish_dir / "latest.new" + if latest_new.is_symlink() or latest_new.exists(): + latest_new.unlink() + os.symlink(new_dir.name, latest_new) + os.replace(latest_new, latest) + + return new_dir + + +def _print_summary(results_df: pd.DataFrame) -> None: + """Print a markdown summary table of key metrics.""" + click.echo("\n## Linear Classifier Results\n") + + per_class_f1_cols = sorted(c for c in results_df.columns if c.startswith("val_") and c.endswith("_f1")) + summary_cols = [ + "task", + "marker_filter", + "n_samples", + "val_accuracy", + "val_weighted_f1", + "val_auroc", + ] + per_class_f1_cols + display = results_df[[c for c in summary_cols if c in results_df.columns]].copy() + + float_cols = [c for c in display.columns if c not in ("task", "marker_filter")] + for col in float_cols: + if pd.api.types.is_float_dtype(display[col]): + display[col] = display[col].map(lambda v: f"{v:.3f}" if pd.notna(v) else "N/A") + + rows = display.to_dict(orient="records") + click.echo(format_markdown_table(rows, headers=list(display.columns))) + + +def _save_task_plots( + task: str, + 
task_df: pd.DataFrame, + task_val_outputs: list[dict[str, Any]], + output_dir: Path, +) -> None: + """Save one PDF per task with bar chart, ROC curves, and F1-over-time plots. + + Parameters + ---------- + task : str + Task name (used in filename and titles). + task_df : pd.DataFrame + Rows from metrics_summary.csv for this task (one row per marker). + task_val_outputs : list[dict] + Per-marker val outputs. Each entry has keys ``marker_filter``, + ``y_val``, ``y_val_proba``, ``classes``, ``val_hours``. + output_dir : Path + Directory to write ``{task}_summary.pdf``. + """ + pdf_path = output_dir / f"{task}_summary.pdf" + + with PdfPages(pdf_path) as pdf: + _plot_metrics_bar(pdf, task, task_df) + for vo in task_val_outputs: + if vo["y_val"] is None or vo["y_val_proba"] is None: + continue + _plot_roc_curves(pdf, task, vo["marker_filter"], vo["y_val"], vo["y_val_proba"], vo["classes"]) + if vo["val_hours"] is not None: + _plot_f1_over_time( + pdf, task, vo["marker_filter"], vo["y_val"], vo["y_val_proba"], vo["classes"], vo["val_hours"] + ) + + click.echo(f"Plots written to {pdf_path}") + + +def _plot_metrics_bar(pdf: PdfPages, task: str, task_df: pd.DataFrame) -> None: + """Bar chart of AUROC, accuracy, and weighted F1 per marker for one task.""" + metric_cols = ["val_auroc", "val_accuracy", "val_weighted_f1"] + present = [c for c in metric_cols if c in task_df.columns] + if not present: + return + + labels = task_df["marker_filter"].fillna("all").tolist() + x = np.arange(len(labels)) + n_metrics = len(present) + width = 0.8 / n_metrics + + metric_display = {"val_auroc": "AUROC", "val_accuracy": "Accuracy", "val_weighted_f1": "Weighted F1"} + colors = ["#0072B2", "#E69F00", "#009E73"] + + fig, ax = plt.subplots(figsize=(max(6, len(labels) * 1.5), 5)) + for i, col in enumerate(present): + vals = task_df[col].fillna(0).values + ax.bar(x + i * width, vals, width, label=metric_display.get(col, col), color=colors[i], alpha=0.85) + + ax.set_xticks(x + width * (n_metrics - 
1) / 2) + ax.set_xticklabels(labels, fontsize=9) + ax.set_ylim(0, 1.05) + ax.axhline(0.5, color="gray", linewidth=0.8, linestyle="--", label="Random (0.5)") + ax.set_ylabel("Score") + ax.set_title(f"{task} — classifier performance per marker") + ax.legend(fontsize=9) + fig.tight_layout() + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def _plot_roc_curves( + pdf: PdfPages, + task: str, + marker_filter: str | None, + y_val: np.ndarray, + y_val_proba: np.ndarray, + classes: list[str], +) -> None: + """One-vs-rest ROC curves for a single (task, marker) classifier.""" + from sklearn.metrics import roc_curve + from sklearn.preprocessing import label_binarize + + # Colorblind-friendly palette (Wong 2011) + palette = ["#0072B2", "#E69F00", "#009E73", "#CC79A7", "#D55E00", "#56B4E9", "#F0E442"] + + fig, ax = plt.subplots(figsize=(6, 5)) + ax.set_title(f"ROC — {task} ({marker_filter})", fontsize=11) + + if len(classes) == 2: + fpr, tpr, _ = roc_curve(y_val, y_val_proba[:, 1], pos_label=classes[1]) + auroc = float(np.trapezoid(tpr, fpr)) + ax.plot(fpr, tpr, color=palette[0], linewidth=2, label=f"{classes[1]} (AUROC={auroc:.3f})") + else: + y_bin = label_binarize(y_val, classes=classes) + for i, cls in enumerate(classes): + fpr, tpr, _ = roc_curve(y_bin[:, i], y_val_proba[:, i]) + auroc = float(np.trapezoid(tpr, fpr)) + ax.plot(fpr, tpr, color=palette[i % len(palette)], linewidth=1.5, label=f"{cls} (AUROC={auroc:.3f})") + + ax.plot([0, 1], [0, 1], "k--", linewidth=0.8) + ax.set_xlabel("False Positive Rate") + ax.set_ylabel("True Positive Rate") + ax.set_xlim([0, 1]) + ax.set_ylim([0, 1.05]) + ax.legend(fontsize=8, loc="lower right") + fig.tight_layout() + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def _plot_f1_over_time( + pdf: PdfPages, + task: str, + marker_filter: str | None, + y_val: np.ndarray, + y_val_proba: np.ndarray, + classes: list[str], + val_hours: np.ndarray, +) -> None: + """Per-class F1 at each unique timepoint for a single (task, 
marker) classifier.""" + from sklearn.metrics import f1_score + + palette = ["#0072B2", "#E69F00", "#009E73", "#CC79A7", "#D55E00", "#56B4E9", "#F0E442"] + + y_pred = np.array(classes)[np.argmax(y_val_proba, axis=1)] + timepoints = sorted(np.unique(val_hours[~np.isnan(val_hours)])) + + # (n_timepoints, n_classes) + f1_per_time = np.full((len(timepoints), len(classes)), np.nan) + for ti, t in enumerate(timepoints): + mask = val_hours == t + if mask.sum() < 2: + continue + f1s = f1_score(y_val[mask], y_pred[mask], labels=classes, average=None, zero_division=0) + f1_per_time[ti] = f1s + + fig, ax = plt.subplots(figsize=(8, 5)) + for ci, cls in enumerate(classes): + ax.plot(timepoints, f1_per_time[:, ci], marker="o", color=palette[ci % len(palette)], linewidth=2, label=cls) + + ax.set_xlabel("Hours post perturbation") + ax.set_ylabel("F1 score") + ax.set_ylim(0, 1.05) + ax.axhline(0.5, color="gray", linewidth=0.8, linestyle="--") + ax.set_title(f"F1 over time — {task} ({marker_filter})") + ax.legend(fontsize=9) + fig.tight_layout() + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +class _RunLinearClassifiersConfig: + """Config container for the run-linear-classifiers CLI.""" + + def __init__(self, raw: dict): + from dynaclr.evaluation.evaluate_config import LinearClassifiersStepConfig + + self.embeddings_path = Path(raw["embeddings_path"]) + self.output_dir = Path(raw["output_dir"]) + self.lc_config = LinearClassifiersStepConfig( + **{k: v for k, v in raw.items() if k not in ("embeddings_path", "output_dir")} + ) + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file", +) +def main(config: Path) -> None: + """Run linear classifiers on a combined embeddings zarr from the evaluation orchestrator.""" + raw = load_config(config) + cfg = _RunLinearClassifiersConfig(raw) + 
run_linear_classifiers(cfg.embeddings_path, cfg.lc_config, cfg.output_dir) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated_test.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated_test.py new file mode 100644 index 000000000..a28db7569 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/orchestrated_test.py @@ -0,0 +1,336 @@ +"""Tests for the orchestrated linear classifiers evaluation.""" + +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd +import pytest + +from dynaclr.evaluation.evaluate_config import AnnotationSource, LinearClassifiersStepConfig, TaskSpec +from dynaclr.evaluation.linear_classifiers.orchestrated import run_linear_classifiers + + +def _make_embeddings_zarr( + path: Path, + n_cells: int = 200, + n_features: int = 16, + experiment: str = "exp_A", + use_id_col: bool = True, + extra_markers: list[tuple[str, int]] | None = None, +) -> ad.AnnData: + """Write a synthetic embeddings zarr and return the AnnData. + + Parameters + ---------- + extra_markers : list of (marker_name, n_cells) tuples, optional + Additional markers appended after the default Phase3D/TOMM20 split. 
+ """ + half = n_cells // 2 + markers = ["Phase3D"] * half + ["TOMM20"] * half + extra_cells: list[dict] = [] + if extra_markers: + for marker_name, m_count in extra_markers: + markers += [marker_name] * m_count + extra_cells += [{}] * m_count + + total = n_cells + len(extra_cells) + rng = np.random.default_rng(42) + X = rng.standard_normal((total, n_features)).astype(np.float32) + + obs: dict = { + "fov_name": [f"A/1/FOV{i % 5}" for i in range(total)], + "t": [i % 10 for i in range(total)], + "track_id": list(range(total)), + "experiment": [experiment] * total, + "marker": markers, + "perturbation": ["uninfected"] * (total // 2) + ["ZIKV"] * (total - total // 2), + "hours_post_perturbation": [float(i % 5) * 24.0 for i in range(total)], + } + if use_id_col: + obs["id"] = list(range(total)) + + df = pd.DataFrame(obs) + # Convert string columns to object dtype — pandas 3 defaults to ArrowStringArray + # which anndata's zarr writer does not support. + for col in df.select_dtypes("string").columns: + df[col] = df[col].astype(object) + df.index = pd.Index([str(i) for i in range(total)], dtype=object) + var = pd.DataFrame(index=pd.Index([str(i) for i in range(n_features)], dtype=object)) + adata = ad.AnnData(X=X, obs=df, var=var) + adata.write_zarr(path) + return adata + + +def _make_embeddings_dir(tmp_path: Path, n_cells: int = 200, n_features: int = 16) -> Path: + """Write two per-experiment zarrs to a directory; return the directory path.""" + emb_dir = tmp_path / "embeddings" + emb_dir.mkdir() + _make_embeddings_zarr(emb_dir / "exp_A.zarr", n_cells=n_cells, n_features=n_features, experiment="exp_A") + _make_embeddings_zarr(emb_dir / "exp_B.zarr", n_cells=n_cells, n_features=n_features, experiment="exp_B") + return emb_dir + + +def _make_annotations( + tmp_path: Path, experiment: str, fov_names: list, ts: list, track_ids: list, hours: list | None = None +) -> Path: + """Create a synthetic annotation CSV with infection_state and organelle_state labels. 
+ + fov_name is stored as the first path component only (e.g. "A/1/FOV0" → "A"), + matching what load_annotation_anndata extracts from obs via .str.split("/").str[0]. + """ + labels = ["uninfected" if i % 3 != 0 else "infected" for i in range(len(fov_names))] + # Extract first path component to match the join key in load_annotation_anndata + fov_first = [str(f).split("/")[0] for f in fov_names] + data: dict = { + "fov_name": fov_first, + "t": ts, + "track_id": track_ids, + "infection_state": labels, + "organelle_state": ["normal" if i % 4 != 0 else "abnormal" for i in range(len(fov_names))], + } + if hours is not None: + data["hours_post_perturbation"] = hours + df = pd.DataFrame(data) + csv_path = tmp_path / f"{experiment}_annotations.csv" + df.to_csv(csv_path, index=False) + return csv_path + + +def _setup_dir_with_annotations(tmp_path: Path) -> tuple[Path, Path, Path]: + """Create embeddings directory + annotation CSVs for exp_A and exp_B.""" + emb_dir = _make_embeddings_dir(tmp_path) + ann_paths = {} + for exp in ["exp_A", "exp_B"]: + adata = ad.read_zarr(emb_dir / f"{exp}.zarr") + ann_paths[exp] = _make_annotations( + tmp_path, + exp, + adata.obs["fov_name"].tolist(), + adata.obs["t"].tolist(), + adata.obs["track_id"].tolist(), + hours=adata.obs["hours_post_perturbation"].tolist(), + ) + return emb_dir, ann_paths["exp_A"], ann_paths["exp_B"] + + +def test_run_linear_classifiers_directory_mode(tmp_path): + """Embeddings directory (post-split) is loaded and concatenated correctly.""" + emb_dir, ann_a, ann_b = _setup_dir_with_annotations(tmp_path) + + config = LinearClassifiersStepConfig( + annotations=[ + AnnotationSource(experiment="exp_A", path=str(ann_a)), + AnnotationSource(experiment="exp_B", path=str(ann_b)), + ], + tasks=[TaskSpec(task="infection_state")], + use_scaling=True, + split_train_data=0.8, + ) + + results = run_linear_classifiers(emb_dir, config, tmp_path / "out") + + # auto-expand to Phase3D and TOMM20 → 2 rows + assert len(results) == 2 + 
assert set(results["marker_filter"].tolist()) == {"Phase3D", "TOMM20"} + assert results.iloc[0]["task"] == "infection_state" + assert results.iloc[0]["n_samples"] == 200 # 100 per experiment × 2 + assert (tmp_path / "out" / "metrics_summary.csv").exists() + # one summary PDF per task + assert (tmp_path / "out" / "infection_state_summary.pdf").exists() + + +def test_run_linear_classifiers_single_zarr_mode(tmp_path): + """Single combined zarr (pre-split) is still accepted.""" + zarr_path = tmp_path / "embeddings.zarr" + adata = _make_embeddings_zarr(zarr_path, experiment="exp_A") + ann = _make_annotations( + tmp_path, + "exp_A", + adata.obs["fov_name"].tolist(), + adata.obs["t"].tolist(), + adata.obs["track_id"].tolist(), + hours=adata.obs["hours_post_perturbation"].tolist(), + ) + + config = LinearClassifiersStepConfig( + annotations=[AnnotationSource(experiment="exp_A", path=str(ann))], + tasks=[TaskSpec(task="infection_state")], + use_scaling=True, + split_train_data=0.8, + ) + + results = run_linear_classifiers(zarr_path, config, tmp_path / "out") + # auto-expand to Phase3D and TOMM20 → 2 rows + assert len(results) == 2 + assert set(results["marker_filter"].tolist()) == {"Phase3D", "TOMM20"} + + +def test_run_linear_classifiers_fallback_join_no_id(tmp_path): + """Annotation join falls back to (fov_name, t, track_id) when id column is absent.""" + zarr_path = tmp_path / "embeddings.zarr" + adata = _make_embeddings_zarr(zarr_path, experiment="exp_A", use_id_col=False) + + assert "id" not in adata.obs.columns + + ann = _make_annotations( + tmp_path, + "exp_A", + adata.obs["fov_name"].tolist(), + adata.obs["t"].tolist(), + adata.obs["track_id"].tolist(), + ) + + config = LinearClassifiersStepConfig( + annotations=[AnnotationSource(experiment="exp_A", path=str(ann))], + tasks=[TaskSpec(task="infection_state")], + use_scaling=True, + split_train_data=0.8, + ) + + results = run_linear_classifiers(zarr_path, config, tmp_path / "out") + # auto-expand to Phase3D and TOMM20 
→ 2 rows, 100 cells each + assert len(results) == 2 + assert set(results["marker_filter"].tolist()) == {"Phase3D", "TOMM20"} + assert (results["n_samples"] == 100).all() + + +def test_run_linear_classifiers_multiple_tasks(tmp_path): + """Multiple tasks produce one row each in results.""" + emb_dir, ann_a, ann_b = _setup_dir_with_annotations(tmp_path) + + config = LinearClassifiersStepConfig( + annotations=[ + AnnotationSource(experiment="exp_A", path=str(ann_a)), + AnnotationSource(experiment="exp_B", path=str(ann_b)), + ], + tasks=[ + TaskSpec(task="infection_state"), + TaskSpec(task="organelle_state"), + ], + use_scaling=True, + split_train_data=0.8, + ) + + results = run_linear_classifiers(emb_dir, config, tmp_path / "out") + + # auto-expand to Phase3D and TOMM20 → 2 tasks × 2 markers = 4 rows + assert len(results) == 4 + assert set(results["task"].tolist()) == {"infection_state", "organelle_state"} + + +def test_run_linear_classifiers_marker_filter(tmp_path): + """marker_filters restricts cells to those with matching marker.""" + emb_dir, ann_a, ann_b = _setup_dir_with_annotations(tmp_path) + + config = LinearClassifiersStepConfig( + annotations=[ + AnnotationSource(experiment="exp_A", path=str(ann_a)), + AnnotationSource(experiment="exp_B", path=str(ann_b)), + ], + tasks=[TaskSpec(task="infection_state", marker_filters=["Phase3D"])], + use_scaling=True, + split_train_data=0.8, + ) + + results = run_linear_classifiers(emb_dir, config, tmp_path / "out") + + assert not results.empty + # Phase3D is half of each experiment → 100 per exp × 2 = 200 + assert results.iloc[0]["n_samples"] == 200 + + +def test_run_linear_classifiers_missing_metadata_raises(tmp_path): + """Raises ValueError when embeddings zarr lacks experiment/marker columns.""" + X = np.random.standard_normal((50, 8)).astype(np.float32) + obs = pd.DataFrame({"fov_name": [f"A/1/FOV{i}" for i in range(50)]}) + obs["fov_name"] = obs["fov_name"].astype(object) + obs.index = pd.Index([str(i) for i in 
range(50)], dtype=object) + var = pd.DataFrame(index=pd.Index([str(i) for i in range(8)], dtype=object)) + zarr_path = tmp_path / "embeddings.zarr" + ad.AnnData(X=X, obs=obs, var=var).write_zarr(zarr_path) + + config = LinearClassifiersStepConfig( + annotations=[AnnotationSource(experiment="exp_A", path=str(tmp_path / "ann.csv"))], + tasks=[TaskSpec(task="infection_state")], + ) + + with pytest.raises(ValueError, match="missing columns"): + run_linear_classifiers(zarr_path, config, tmp_path / "out") + + +def test_run_linear_classifiers_unknown_marker_skipped(tmp_path): + """If marker_filters matches no rows, task is skipped and result is empty.""" + emb_dir, ann_a, _ = _setup_dir_with_annotations(tmp_path) + + config = LinearClassifiersStepConfig( + annotations=[AnnotationSource(experiment="exp_A", path=str(ann_a))], + tasks=[TaskSpec(task="infection_state", marker_filters=["NonExistentMarker"])], + ) + + results = run_linear_classifiers(emb_dir, config, tmp_path / "out") + assert results.empty + + +def test_run_linear_classifiers_sparse_marker_skipped(tmp_path): + """Sparse marker with too few samples for stratified split is skipped without crashing.""" + emb_dir = tmp_path / "embeddings" + emb_dir.mkdir() + + # exp_A: 200 cells (Phase3D/TOMM20) + 4 RARE cells (1 infected, 3 uninfected) + adata_a = _make_embeddings_zarr( + emb_dir / "exp_A.zarr", + n_cells=200, + experiment="exp_A", + extra_markers=[("RARE", 4)], + ) + ann_a = _make_annotations( + tmp_path, + "exp_A", + adata_a.obs["fov_name"].tolist(), + adata_a.obs["t"].tolist(), + adata_a.obs["track_id"].tolist(), + hours=adata_a.obs["hours_post_perturbation"].tolist(), + ) + # Override RARE annotation so only 1 sample is "infected" (too few for stratified split) + df = pd.read_csv(ann_a) + rare_idx = adata_a.obs.index[adata_a.obs["marker"] == "RARE"].tolist() + rare_rows = df[df["track_id"].isin([int(i) for i in rare_idx])] + df.loc[rare_rows.index, "infection_state"] = ["infected"] + ["uninfected"] * 
(len(rare_rows) - 1) + df.to_csv(ann_a, index=False) + + config = LinearClassifiersStepConfig( + annotations=[AnnotationSource(experiment="exp_A", path=str(ann_a))], + tasks=[TaskSpec(task="infection_state")], + use_scaling=True, + split_train_data=0.8, + ) + + # Must not crash; RARE is skipped due to insufficient samples + results = run_linear_classifiers(emb_dir, config, tmp_path / "out") + assert not results.empty + assert "RARE" not in results["marker_filter"].tolist() + assert set(results["marker_filter"].tolist()) == {"Phase3D", "TOMM20"} + + +def test_run_linear_classifiers_f1_over_time_plots_written(tmp_path): + """F1-over-time plots are written when hours_post_perturbation is present.""" + emb_dir, ann_a, ann_b = _setup_dir_with_annotations(tmp_path) + + config = LinearClassifiersStepConfig( + annotations=[ + AnnotationSource(experiment="exp_A", path=str(ann_a)), + AnnotationSource(experiment="exp_B", path=str(ann_b)), + ], + tasks=[TaskSpec(task="infection_state", marker_filters=["Phase3D"])], + use_scaling=True, + split_train_data=0.8, + ) + + out_dir = tmp_path / "out" + results = run_linear_classifiers(emb_dir, config, out_dir) + + assert not results.empty + pdf_path = out_dir / "infection_state_summary.pdf" + assert pdf_path.exists() + assert pdf_path.stat().st_size > 0 diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/report.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/report.py index a55b68e33..e63af9086 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/report.py +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/report.py @@ -1,10 +1,7 @@ -"""PDF report generation for linear classifier evaluation and cross-validation. +"""PDF report generation for linear classifier cross-validation. -Provides two report generators: -- ``generate_comparison_report``: Evaluation report comparing models on a test set. 
-- ``generate_cv_report``: Cross-validation report with impact analysis. - -Both are optional and gated behind the ``--report`` flag in the respective scripts. +Provides ``generate_cv_report`` for cross-validation reports with impact analysis. +This is optional and gated behind the ``--report`` flag in the cross-validation script. """ from __future__ import annotations @@ -20,7 +17,6 @@ import pandas as pd from matplotlib.backends.backend_pdf import PdfPages from matplotlib.patches import Patch -from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix matplotlib.use("Agg") @@ -39,9 +35,6 @@ "baseline": _COLOR_BASELINE, } -_MODEL_COLORS = {"2D": "#1f77b4", "3D": "#ff7f0e"} -_EXTRA_COLORS = ["#2ca02c", "#9467bd", "#8c564b", "#e377c2"] - _TEMPORAL_PALETTE = [ "#0072B2", "#E69F00", @@ -54,281 +47,6 @@ ] -def _get_model_color(label: str, idx: int = 0) -> str: - return _MODEL_COLORS.get(label, _EXTRA_COLORS[idx % len(_EXTRA_COLORS)]) - - -# --------------------------------------------------------------------------- -# Evaluation report -# --------------------------------------------------------------------------- - - -def generate_comparison_report( - output_dir: Path, - dataset_name: str, - model_labels: list[str], - tasks: list[str], - channels: list[str], - train_results: dict[str, dict[tuple[str, str], dict[str, Any]]], - eval_results: dict[str, dict[tuple[str, str], dict[str, Any]]], -) -> Path: - """Generate a PDF comparing model performance on a held-out test set. - - Parameters - ---------- - output_dir : Path - Directory to save the report. - dataset_name : str - Name of the test dataset. - model_labels : list[str] - Model labels (e.g. ``["2D", "3D"]``). - tasks : list[str] - Classification tasks evaluated. - channels : list[str] - Input channels evaluated. - train_results : dict - ``model_label -> (task, channel) -> {"metrics": {...}, ...}``. - eval_results : dict - ``model_label -> (task, channel) -> {"metrics": {...}, "annotated_adata": ...}``. 
- - Returns - ------- - Path - Path to the generated PDF. - """ - report_path = output_dir / f"{dataset_name}_comparison_report.pdf" - output_dir.mkdir(parents=True, exist_ok=True) - - with PdfPages(report_path) as pdf: - _eval_page_title(pdf, dataset_name, model_labels, tasks, channels, train_results) - _eval_page_global_metrics(pdf, model_labels, tasks, channels, train_results, eval_results) - for task in tasks: - _eval_page_task_comparison(pdf, task, model_labels, channels, eval_results) - for channel in channels: - _eval_page_channel_comparison(pdf, channel, model_labels, tasks, train_results, eval_results) - - print(f"\nReport saved: {report_path}") - return report_path - - -def _eval_page_title(pdf, dataset_name, model_labels, tasks, channels, train_results): - fig, ax = plt.subplots(figsize=(11, 8.5)) - ax.axis("off") - - lines = [ - "Linear Classifier Comparison Report", - "", - f"Test Dataset: {dataset_name}", - "", - ] - for label in model_labels: - n_combos = len(train_results.get(label, {})) - lines.append(f"Model {label}: {n_combos} classifiers trained") - lines.append("") - lines.append(f"Channels: {', '.join(channels)}") - lines.append(f"Tasks: {', '.join(tasks)}") - - ax.text( - 0.5, - 0.5, - "\n".join(lines), - transform=ax.transAxes, - fontsize=12, - verticalalignment="center", - horizontalalignment="center", - fontfamily="monospace", - ) - fig.suptitle("Model Comparison", fontsize=16, fontweight="bold") - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - -def _eval_page_global_metrics(pdf, model_labels, tasks, channels, train_results, eval_results): - fig, ax = plt.subplots(figsize=(11, 8.5)) - ax.axis("off") - fig.suptitle("Global Metrics Summary", fontsize=14, fontweight="bold") - - col_labels = ["Task", "Channel"] - for label in model_labels: - col_labels.extend([f"{label}\nVal Acc", f"{label}\nVal F1", f"{label}\nTest Acc", f"{label}\nTest F1"]) - - table_data = [] - for task in tasks: - for channel in channels: - row = [task, 
channel] - for label in model_labels: - train_r = train_results.get(label, {}).get((task, channel)) - eval_r = eval_results.get(label, {}).get((task, channel)) - val_acc = f"{train_r['metrics']['val_accuracy']:.3f}" if train_r else "-" - val_f1 = f"{train_r['metrics']['val_weighted_f1']:.3f}" if train_r else "-" - test_acc = f"{eval_r['metrics']['test_accuracy']:.3f}" if eval_r else "-" - test_f1 = f"{eval_r['metrics']['test_weighted_f1']:.3f}" if eval_r else "-" - row.extend([val_acc, val_f1, test_acc, test_f1]) - table_data.append(row) - - if table_data: - table = ax.table(cellText=table_data, colLabels=col_labels, loc="center", cellLoc="center") - table.auto_set_font_size(False) - table.set_fontsize(8) - table.scale(1.0, 1.4) - - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - -def _eval_page_task_comparison(pdf, task, model_labels, channels, eval_results): - n_models = len(model_labels) - - all_classes: set[str] = set() - for label in model_labels: - for ch in channels: - r = eval_results.get(label, {}).get((task, ch)) - if r and "annotated_adata" in r: - adata = r["annotated_adata"] - if task in adata.obs.columns: - all_classes.update(adata.obs[task].dropna().unique()) - all_classes_sorted = sorted(all_classes) - - # F1 bar chart - fig, ax_bar = plt.subplots(figsize=(11, 5)) - fig.suptitle(f"Task: {task} - Per-Class F1", fontsize=14, fontweight="bold") - - if all_classes_sorted: - x = np.arange(len(all_classes_sorted)) - width = 0.8 / max(n_models, 1) - for i, label in enumerate(model_labels): - f1_values = [] - for cls in all_classes_sorted: - f1s = [] - for ch in channels: - r = eval_results.get(label, {}).get((task, ch)) - if r: - f1 = r["metrics"].get(f"test_{cls}_f1") - if f1 is not None: - f1s.append(f1) - f1_values.append(np.mean(f1s) if f1s else 0) - ax_bar.bar( - x + i * width, - f1_values, - width, - label=label, - color=_get_model_color(label, i), - ) - ax_bar.set_xticks(x + width * (n_models - 1) / 2) - 
ax_bar.set_xticklabels(all_classes_sorted) - ax_bar.set_ylabel("Test F1 (avg across channels)") - ax_bar.legend() - ax_bar.set_ylim(0, 1.05) - - fig.tight_layout() - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - # Confusion matrices - n_cols = len(channels) - n_rows = n_models - if n_cols == 0 or n_rows == 0: - return - - fig_cm, cm_axes = plt.subplots(n_rows, max(n_cols, 1), figsize=(4 * max(n_cols, 1), 3.5 * n_rows)) - fig_cm.suptitle(f"Confusion Matrices: {task}", fontsize=14, fontweight="bold") - - if n_rows == 1 and n_cols == 1: - cm_axes = [[cm_axes]] - elif n_rows == 1: - cm_axes = [cm_axes] - elif n_cols == 1: - cm_axes = [[row] for row in cm_axes] - - for i, label in enumerate(model_labels): - for j, ch in enumerate(channels): - ax = cm_axes[i][j] - r = eval_results.get(label, {}).get((task, ch)) - if r and "annotated_adata" in r: - adata = r["annotated_adata"] - pred_col = f"predicted_{task}" - mask = adata.obs[task].notna() & (adata.obs[task] != "unknown") - subset = adata[mask] - if len(subset) > 0 and pred_col in subset.obs.columns: - y_true = subset.obs[task].values - y_pred = subset.obs[pred_col].values - labels = sorted(set(y_true) | set(y_pred)) - cm = confusion_matrix(y_true, y_pred, labels=labels) - ConfusionMatrixDisplay(cm, display_labels=labels).plot(ax=ax, cmap="Blues", colorbar=False) - ax.set_title(f"{label} / {ch}", fontsize=10) - - fig_cm.tight_layout() - pdf.savefig(fig_cm, bbox_inches="tight") - plt.close(fig_cm) - - -def _eval_page_channel_comparison(pdf, channel, model_labels, tasks, train_results, eval_results): - fig, axes = plt.subplots(1, 2, figsize=(11, 5)) - fig.suptitle(f"Channel: {channel}", fontsize=14, fontweight="bold") - - n_models = len(model_labels) - x = np.arange(len(tasks)) - width = 0.8 / max(n_models, 1) - - ax = axes[0] - for i, label in enumerate(model_labels): - accs = [] - for task in tasks: - r = eval_results.get(label, {}).get((task, channel)) - accs.append(r["metrics"]["test_accuracy"] if r else 
0) - ax.bar( - x + i * width, - accs, - width, - label=label, - color=_get_model_color(label, i), - ) - ax.set_xticks(x + width * (n_models - 1) / 2) - ax.set_xticklabels(tasks, rotation=30, ha="right", fontsize=8) - ax.set_ylabel("Test Accuracy") - ax.set_ylim(0, 1.05) - ax.legend() - ax.set_title("Test Accuracy") - - ax2 = axes[1] - for i, label in enumerate(model_labels): - val_accs, test_accs = [], [] - for task in tasks: - tr = train_results.get(label, {}).get((task, channel)) - ev = eval_results.get(label, {}).get((task, channel)) - val_accs.append(tr["metrics"]["val_accuracy"] if tr else 0) - test_accs.append(ev["metrics"]["test_accuracy"] if ev else 0) - - color = _get_model_color(label, i) - ax2.bar( - x + i * width - width / 4, - val_accs, - width / 2, - label=f"{label} Val", - color=color, - alpha=0.5, - ) - ax2.bar( - x + i * width + width / 4, - test_accs, - width / 2, - label=f"{label} Test", - color=color, - alpha=1.0, - ) - - ax2.set_xticks(x + width * (n_models - 1) / 2) - ax2.set_xticklabels(tasks, rotation=30, ha="right", fontsize=8) - ax2.set_ylabel("Accuracy") - ax2.set_ylim(0, 1.05) - ax2.legend(fontsize=7) - ax2.set_title("Val vs Test (Generalization)") - - fig.tight_layout() - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - # --------------------------------------------------------------------------- # Cross-validation report # --------------------------------------------------------------------------- diff --git a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/train_linear_classifier.py b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/train_linear_classifier.py index 00e62aa41..d79ff4e8a 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/train_linear_classifier.py +++ b/applications/dynaclr/src/dynaclr/evaluation/linear_classifiers/train_linear_classifier.py @@ -9,7 +9,7 @@ import click from pydantic import ValidationError -from viscy_utils.cli_utils import 
format_markdown_table, load_config +from viscy_utils.cli_utils import format_markdown_table, load_config_section from viscy_utils.evaluation.linear_classifier import ( load_and_combine_datasets, save_pipeline_to_wandb, @@ -68,7 +68,7 @@ def main(config: Path): click.echo("=" * 60) try: - config_dict = load_config(config) + config_dict = load_config_section(config, None, default_section="train_linear_classifier") train_config = LinearClassifierTrainConfig(**config_dict) except ValidationError as e: click.echo(f"\n Configuration validation failed:\n{e}", err=True) @@ -103,7 +103,7 @@ def main(config: Path): "random_state": train_config.random_seed, } - pipeline, metrics = train_linear_classifier( + pipeline, metrics, _ = train_linear_classifier( adata=combined_adata, task=train_config.task, use_scaling=train_config.use_scaling, diff --git a/applications/dynaclr/src/dynaclr/evaluation/mmd/__init__.py b/applications/dynaclr/src/dynaclr/evaluation/mmd/__init__.py new file mode 100644 index 000000000..1419a3501 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/mmd/__init__.py @@ -0,0 +1 @@ +"""MMD-based evaluation of perturbation effects in cell embedding space.""" diff --git a/applications/dynaclr/src/dynaclr/evaluation/mmd/compute_mmd.py b/applications/dynaclr/src/dynaclr/evaluation/mmd/compute_mmd.py new file mode 100644 index 000000000..c08fdc40b --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/mmd/compute_mmd.py @@ -0,0 +1,924 @@ +"""CLI and analysis logic for MMD-based perturbation effect evaluation.""" + +from __future__ import annotations + +from pathlib import Path + +import anndata as ad +import click +import numpy as np +import pandas as pd + +from dynaclr.evaluation.mmd.config import ( + ComparisonSpec, + MMDCombinedConfig, + MMDEvalConfig, + MMDPooledConfig, + MMDSettings, + _resolve_bin_edges, +) +from viscy_utils.compose import load_composed_config +from viscy_utils.evaluation.mmd import median_heuristic, mmd_permutation_test + 
+ +def _extract_embeddings(adata: ad.AnnData, embedding_key: str | None) -> np.ndarray: + """Extract embedding matrix from AnnData. + + Parameters + ---------- + adata : AnnData + AnnData store with ``.X`` or ``.obsm``. + embedding_key : str or None + obsm key, or None to use ``.X``. + + Returns + ------- + np.ndarray + Embedding matrix, shape (n_cells, n_features). + """ + if embedding_key is None: + X = adata.X + else: + X = adata.obsm[embedding_key] + if hasattr(X, "toarray"): + return X.toarray() + return np.asarray(X) + + +def _subsample(X: np.ndarray, max_n: int | None, rng: np.random.Generator) -> np.ndarray: + if max_n is None or len(X) <= max_n: + return X + idx = rng.choice(len(X), max_n, replace=False) + return X[idx] + + +def _run_one_comparison( + emb_a: np.ndarray, + emb_b: np.ndarray, + settings: MMDSettings, + bandwidth: float | None = None, +) -> tuple[float, float, float, float, float, int, int]: + """Run MMD permutation test for one (cond_a, cond_b) pair. + + Parameters + ---------- + emb_a : np.ndarray + Embeddings for group A. + emb_b : np.ndarray + Embeddings for group B. + settings : MMDSettings + Algorithm settings. + bandwidth : float or None + Pre-computed bandwidth to use. If None, computed via median heuristic. + Pass a value to share bandwidth across comparisons within the same group. + + Returns + ------- + mmd2 : float + p_value : float + bandwidth : float + effect_size : float + mmd2 / bandwidth + activity_zscore : float + (mmd2 - null_mean) / null_std — normalizes observed MMD relative to + the permutation null, comparable across markers and datasets. + n_a_used : int + Actual number of cells used from group A after subsampling/balancing. + n_b_used : int + Actual number of cells used from group B after subsampling/balancing. + All metric floats are NaN if fewer than min_cells cells in either group. 
+ """ + rng = np.random.default_rng(settings.seed) + emb_a = _subsample(emb_a, settings.max_cells, rng) + emb_b = _subsample(emb_b, settings.max_cells, rng) + if settings.balance_samples: + min_n = min(len(emb_a), len(emb_b)) + emb_a = _subsample(emb_a, min_n, rng) + emb_b = _subsample(emb_b, min_n, rng) + n_a_used = len(emb_a) + n_b_used = len(emb_b) + if n_a_used < settings.min_cells or n_b_used < settings.min_cells: + return float("nan"), float("nan"), float("nan"), float("nan"), float("nan"), n_a_used, n_b_used + if bandwidth is None: + bandwidth = median_heuristic(emb_a, emb_b) + mmd2, p_value, null_dist = mmd_permutation_test( + emb_a, emb_b, n_permutations=settings.n_permutations, bandwidth=bandwidth, seed=settings.seed + ) + effect_size = mmd2 / bandwidth if bandwidth > 0 else float("nan") + activity_zscore = float((mmd2 - null_dist.mean()) / (null_dist.std() + 1e-12)) + return mmd2, p_value, bandwidth, effect_size, activity_zscore, n_a_used, n_b_used + + +def _run_map_comparison( + meta: pd.DataFrame, + features: np.ndarray, + comp: ComparisonSpec, + group_by: str, + marker: str, + map_settings, +) -> tuple[float, float]: + """Run copairs mAP for one comparison. + + Returns + ------- + map_value : float + map_p_value : float + Both NaN on failure or if copairs is unavailable. 
+ """ + try: + from viscy_utils.evaluation.embedding_map import compute_embedding_map + except ImportError: + return float("nan"), float("nan") + result = compute_embedding_map( + meta=meta, + features=features, + reference_condition=comp.cond_a, + target_condition=comp.cond_b, + condition_col=group_by, + group_col="marker", + distance=map_settings.distance, + null_size=map_settings.null_size, + seed=map_settings.seed, + ) + if result is None: + return float("nan"), float("nan") + return result["mean_average_precision"], result["p_value"] + + +def run_mmd_analysis(adata: ad.AnnData, config: MMDEvalConfig) -> pd.DataFrame: + """Run per-experiment MMD analysis for explicit comparison pairs across all markers. + + Each comparison is an explicit ``(cond_a, cond_b)`` pair with a label. + The analysis is always faceted by ``obs["marker"]`` and ``obs["experiment"]``. + Each experiment is processed independently to avoid cross-experiment pooling. + + Parameters + ---------- + adata : AnnData + AnnData (single- or multi-experiment) after split-embeddings step. + config : MMDEvalConfig + Analysis configuration. + + Returns + ------- + pd.DataFrame + Results with columns: experiment, marker, cond_a, cond_b, label, + hours_bin_start, hours_bin_end, n_a, n_b, mmd2, p_value, bandwidth, + effect_size, activity_zscore, embedding_key, and optionally map_value, + map_p_value. + """ + if config.obs_filter: + mask = pd.Series([True] * len(adata), index=adata.obs.index) + for col, val in config.obs_filter.items(): + if col not in adata.obs.columns: + raise KeyError(f"obs_filter column '{col}' not found. Available: {list(adata.obs.columns)}") + mask &= adata.obs[col] == val + adata = adata[mask].copy() + + obs = adata.obs + if config.group_by not in obs.columns: + raise KeyError(f"obs column '{config.group_by}' not found. 
Available: {list(obs.columns)}") + + emb_key_label = config.embedding_key if config.embedding_key is not None else "X" + all_emb = _extract_embeddings(adata, config.embedding_key) + experiments = obs["experiment"].unique() if "experiment" in obs.columns else ["unknown"] + + records: list[dict] = [] + for experiment in experiments: + exp_mask = ( + obs["experiment"] == experiment + if "experiment" in obs.columns + else pd.Series([True] * len(obs), index=obs.index) + ) + for marker in sorted(obs["marker"].unique()): + marker_mask = exp_mask & (obs["marker"] == marker) + + if config.temporal_bin_size is None and config.temporal_bins is None: + # Aggregate mode + shared_bw = _compute_shared_bandwidth( + all_emb, obs, marker_mask, config.comparisons, config.mmd, config.group_by + ) + for comp in config.comparisons: + mask_a = marker_mask & (obs[config.group_by] == comp.cond_a) + mask_b = marker_mask & (obs[config.group_by] == comp.cond_b) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[mask_b.values] + bw = shared_bw if shared_bw is not None else None + mmd2, p_value, bw_out, es, az, na, nb = _run_one_comparison(emb_a, emb_b, config.mmd, bandwidth=bw) + map_val, map_pval = _maybe_map( + obs[marker_mask.values], + all_emb[marker_mask.values], + comp, + config.group_by, + marker, + config.map_settings, + ) + records.append( + _record( + experiment, + marker, + comp, + float("nan"), + float("nan"), + na, + nb, + mmd2, + p_value, + bw_out, + es, + az, + map_val, + map_pval, + emb_key_label, + ) + ) + else: + if "hours_post_perturbation" not in obs.columns: + raise KeyError("temporal binning requires obs column 'hours_post_perturbation'") + max_hours = obs["hours_post_perturbation"].max() + bin_pairs = _resolve_bin_edges(config.temporal_bin_size, config.temporal_bins, max_hours) + for b_start, b_end in bin_pairs: + shared_bw = _compute_shared_bandwidth_temporal( + all_emb, obs, marker_mask, config.comparisons, config.mmd, config.group_by, b_start, b_end + ) + for comp in 
config.comparisons: + mask_a = marker_mask & (obs[config.group_by] == comp.cond_a) + bin_mask_b = ( + marker_mask + & (obs[config.group_by] == comp.cond_b) + & (obs["hours_post_perturbation"] >= b_start) + & (obs["hours_post_perturbation"] < b_end) + ) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[bin_mask_b.values] + bw = shared_bw if shared_bw is not None else None + mmd2, p_value, bw_out, es, az, na, nb = _run_one_comparison( + emb_a, emb_b, config.mmd, bandwidth=bw + ) + map_val, map_pval = _maybe_map( + obs[marker_mask.values], + all_emb[marker_mask.values], + comp, + config.group_by, + marker, + config.map_settings, + ) + records.append( + _record( + experiment, + marker, + comp, + b_start, + b_end, + na, + nb, + mmd2, + p_value, + bw_out, + es, + az, + map_val, + map_pval, + emb_key_label, + ) + ) + return pd.DataFrame(records) + + +def _compute_shared_bandwidth( + all_emb: np.ndarray, + obs: pd.DataFrame, + marker_mask: pd.Series, + comparisons: list[ComparisonSpec], + settings: MMDSettings, + group_by: str, +) -> float | None: + """Compute bandwidth from the share_bandwidth_from comparison, if configured.""" + if settings.share_bandwidth_from is None: + return None + for comp in comparisons: + if comp.label == settings.share_bandwidth_from: + mask_a = marker_mask & (obs[group_by] == comp.cond_a) + mask_b = marker_mask & (obs[group_by] == comp.cond_b) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[mask_b.values] + if len(emb_a) >= settings.min_cells and len(emb_b) >= settings.min_cells: + return median_heuristic(emb_a, emb_b) + return None + return None + + +def _compute_shared_bandwidth_temporal( + all_emb: np.ndarray, + obs: pd.DataFrame, + marker_mask: pd.Series, + comparisons: list[ComparisonSpec], + settings: MMDSettings, + group_by: str, + b_start: float, + b_end: float, +) -> float | None: + """Compute shared bandwidth from the share_bandwidth_from comparison for a temporal bin.""" + if settings.share_bandwidth_from is None: + return None + 
for comp in comparisons: + if comp.label == settings.share_bandwidth_from: + mask_a = ( + marker_mask + & (obs[group_by] == comp.cond_a) + & (obs["hours_post_perturbation"] >= b_start) + & (obs["hours_post_perturbation"] < b_end) + ) + mask_b = ( + marker_mask + & (obs[group_by] == comp.cond_b) + & (obs["hours_post_perturbation"] >= b_start) + & (obs["hours_post_perturbation"] < b_end) + ) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[mask_b.values] + if len(emb_a) >= settings.min_cells and len(emb_b) >= settings.min_cells: + return median_heuristic(emb_a, emb_b) + return None + return None + + +def _maybe_map( + obs_sub: pd.DataFrame, + emb_sub: np.ndarray, + comp: ComparisonSpec, + group_by: str, + marker: str, + map_settings, +) -> tuple[float, float]: + """Run mAP if enabled, otherwise return NaN pair.""" + if not map_settings.enabled: + return float("nan"), float("nan") + return _run_map_comparison(obs_sub, emb_sub, comp, group_by, marker, map_settings) + + +def _record( + experiment: str, + marker: str, + comp: ComparisonSpec, + hours_bin_start: float, + hours_bin_end: float, + n_a: int, + n_b: int, + mmd2: float, + p_value: float, + bandwidth: float, + effect_size: float, + activity_zscore: float, + map_value: float, + map_p_value: float, + embedding_key: str, +) -> dict: + return { + "experiment": experiment, + "marker": marker, + "cond_a": comp.cond_a, + "cond_b": comp.cond_b, + "label": comp.label, + "hours_bin_start": hours_bin_start, + "hours_bin_end": hours_bin_end, + "n_a": n_a, + "n_b": n_b, + "mmd2": mmd2, + "p_value": p_value, + "bandwidth": bandwidth, + "effect_size": effect_size, + "activity_zscore": activity_zscore, + "map_value": map_value, + "map_p_value": map_p_value, + "embedding_key": embedding_key, + } + + +def run_mmd_combined(config: MMDCombinedConfig) -> pd.DataFrame: + """Run pairwise cross-experiment MMD, faceted by marker and condition+time bin. 
+ + For each marker, finds all experiments that share it, then for each pair + of those experiments runs MMD per (condition, time_bin) after centering + within that pair only. This measures batch effects between experiments + at matched biological states. + + Parameters + ---------- + config : MMDCombinedConfig + Combined analysis configuration. + + Returns + ------- + pd.DataFrame + Results with columns: marker, exp_a, exp_b, condition, hours_bin_start, + hours_bin_end, n_a, n_b, mmd2, p_value, bandwidth, effect_size, + activity_zscore, embedding_key. + """ + from itertools import combinations + + adatas = {ad.read_zarr(p).obs["experiment"].iloc[0]: ad.read_zarr(p) for p in config.input_paths} + + if config.obs_filter: + filtered = {} + for exp_name, adata in adatas.items(): + mask = pd.Series([True] * len(adata), index=adata.obs.index) + for col, val in config.obs_filter.items(): + if col not in adata.obs.columns: + raise KeyError( + f"obs_filter column '{col}' not found in {exp_name}. 
Available: {list(adata.obs.columns)}" + ) + mask &= adata.obs[col] == val + filtered[exp_name] = adata[mask].copy() + adatas = filtered + + marker_to_exps: dict[str, list[str]] = {} + for exp_name, adata in adatas.items(): + for marker in adata.obs["marker"].unique(): + marker_to_exps.setdefault(marker, []).append(exp_name) + + emb_key_label = config.embedding_key if config.embedding_key is not None else "X" + records: list[dict] = [] + + for marker, exp_names in sorted(marker_to_exps.items()): + if len(exp_names) < 2: + continue + for exp_a, exp_b in combinations(exp_names, 2): + adata_a = adatas[exp_a][adatas[exp_a].obs["marker"] == marker] + adata_b = adatas[exp_b][adatas[exp_b].obs["marker"] == marker] + emb_a_full = _extract_embeddings(adata_a, config.embedding_key).astype(np.float32) + emb_b_full = _extract_embeddings(adata_b, config.embedding_key).astype(np.float32) + obs_a = adata_a.obs + obs_b = adata_b.obs + + emb_a_full = emb_a_full - emb_a_full.mean(axis=0) + emb_b_full = emb_b_full - emb_b_full.mean(axis=0) + + conditions = sorted(set(obs_a[config.group_by].unique()) & set(obs_b[config.group_by].unique())) + for condition in conditions: + cond_mask_a = obs_a[config.group_by] == condition + cond_mask_b = obs_b[config.group_by] == condition + emb_ca = emb_a_full[cond_mask_a.values] + emb_cb = emb_b_full[cond_mask_b.values] + + if config.temporal_bin_size is None and config.temporal_bins is None: + mmd2, p_value, bw, es, az, na, nb = _run_one_comparison(emb_ca, emb_cb, config.mmd) + records.append( + _combined_record( + marker, + exp_a, + exp_b, + condition, + float("nan"), + float("nan"), + na, + nb, + mmd2, + p_value, + bw, + es, + az, + emb_key_label, + ) + ) + else: + if "hours_post_perturbation" not in obs_a.columns: + raise KeyError("temporal binning requires obs column 'hours_post_perturbation'") + max_hours = min(obs_a["hours_post_perturbation"].max(), obs_b["hours_post_perturbation"].max()) + bin_pairs = 
_resolve_bin_edges(config.temporal_bin_size, config.temporal_bins, max_hours) + for b_start, b_end in bin_pairs: + bin_mask_a = ( + cond_mask_a + & (obs_a["hours_post_perturbation"] >= b_start) + & (obs_a["hours_post_perturbation"] < b_end) + ) + bin_mask_b = ( + cond_mask_b + & (obs_b["hours_post_perturbation"] >= b_start) + & (obs_b["hours_post_perturbation"] < b_end) + ) + bin_emb_a = emb_a_full[bin_mask_a.values] + bin_emb_b = emb_b_full[bin_mask_b.values] + mmd2, p_value, bw, es, az, na, nb = _run_one_comparison(bin_emb_a, bin_emb_b, config.mmd) + records.append( + _combined_record( + marker, + exp_a, + exp_b, + condition, + b_start, + b_end, + na, + nb, + mmd2, + p_value, + bw, + es, + az, + emb_key_label, + ) + ) + + return pd.DataFrame(records) + + +def _combined_record( + marker: str, + exp_a: str, + exp_b: str, + condition: str, + hours_bin_start: float, + hours_bin_end: float, + n_a: int, + n_b: int, + mmd2: float, + p_value: float, + bandwidth: float, + effect_size: float, + activity_zscore: float, + embedding_key: str, +) -> dict: + return { + "marker": marker, + "exp_a": exp_a, + "exp_b": exp_b, + "condition": condition, + "hours_bin_start": hours_bin_start, + "hours_bin_end": hours_bin_end, + "n_a": n_a, + "n_b": n_b, + "mmd2": mmd2, + "p_value": p_value, + "bandwidth": bandwidth, + "effect_size": effect_size, + "activity_zscore": activity_zscore, + "embedding_key": embedding_key, + } + + +def run_mmd_pooled(config: MMDPooledConfig) -> pd.DataFrame: + """Run pooled multi-experiment MMD/mAP analysis. + + Concatenates cells from all input experiments into a single pool, then + computes MMD (and optionally mAP) per (marker, time_bin, comparison). + Unlike the combined mode (pairwise batch-effect detection), this pools all + experiments together for phenotypic profiling. + + Parameters + ---------- + config : MMDPooledConfig + Pooled analysis configuration. 
+ + Returns + ------- + pd.DataFrame + Results with columns: marker, cond_a, cond_b, label, hours_bin_start, + hours_bin_end, n_a, n_b, mmd2, p_value, bandwidth, effect_size, + activity_zscore, map_value, map_p_value, embedding_key. + FDR-corrected q_value column is also included. + """ + from statsmodels.stats.multitest import multipletests + + adatas = [ad.read_zarr(p) for p in config.input_paths] + combined = ad.concat(adatas, join="outer", label="source_experiment") + combined.obs_names_make_unique() + + if config.obs_filter: + mask = pd.Series([True] * len(combined), index=combined.obs.index) + for col, val in config.obs_filter.items(): + if col not in combined.obs.columns: + raise KeyError(f"obs_filter column '{col}' not found. Available: {list(combined.obs.columns)}") + mask &= combined.obs[col] == val + combined = combined[mask].copy() + + if config.condition_aliases: + alias_map: dict[str, str] = {} + for canonical, variants in config.condition_aliases.items(): + for v in variants: + alias_map[v] = canonical + combined.obs[config.group_by] = combined.obs[config.group_by].map(lambda x: alias_map.get(x, x)) + + obs = combined.obs + if config.group_by not in obs.columns: + raise KeyError(f"obs column '{config.group_by}' not found. 
Available: {list(obs.columns)}") + + emb_key_label = config.embedding_key if config.embedding_key is not None else "X" + all_emb = _extract_embeddings(combined, config.embedding_key) + + records: list[dict] = [] + for marker in sorted(obs["marker"].unique()): + marker_mask = obs["marker"] == marker + + if config.temporal_bin_size is None and config.temporal_bins is None: + shared_bw = _compute_shared_bandwidth( + all_emb, obs, marker_mask, config.comparisons, config.mmd, config.group_by + ) + for comp in config.comparisons: + mask_a = marker_mask & (obs[config.group_by] == comp.cond_a) + mask_b = marker_mask & (obs[config.group_by] == comp.cond_b) + emb_a = all_emb[mask_a.values] + emb_b = all_emb[mask_b.values] + bw = shared_bw if shared_bw is not None else None + mmd2, p_value, bw_out, es, az, na, nb = _run_one_comparison(emb_a, emb_b, config.mmd, bandwidth=bw) + map_val, map_pval = _maybe_map( + obs[marker_mask.values], + all_emb[marker_mask.values], + comp, + config.group_by, + marker, + config.map_settings, + ) + records.append( + _pooled_record( + marker, + comp, + float("nan"), + float("nan"), + na, + nb, + mmd2, + p_value, + bw_out, + es, + az, + map_val, + map_pval, + emb_key_label, + ) + ) + else: + if "hours_post_perturbation" not in obs.columns: + raise KeyError("temporal binning requires obs column 'hours_post_perturbation'") + max_hours = obs["hours_post_perturbation"].max() + bin_pairs = _resolve_bin_edges(config.temporal_bin_size, config.temporal_bins, max_hours) + for b_start, b_end in bin_pairs: + shared_bw = _compute_shared_bandwidth_temporal( + all_emb, obs, marker_mask, config.comparisons, config.mmd, config.group_by, b_start, b_end + ) + for comp in config.comparisons: + # NOTE(review): group A previously pooled ALL timepoints while group B was + # restricted to the temporal bin, making the per-bin MMD asymmetric. The + # combined mode bins both sides and the shared bandwidth is bin-specific, + # so A is now binned identically — confirm this matches the intended design. + bin_mask_a = ( + marker_mask + & (obs[config.group_by] == comp.cond_a) + & (obs["hours_post_perturbation"] >= b_start) + & (obs["hours_post_perturbation"] < b_end) + ) + bin_mask_b = ( + marker_mask + & (obs[config.group_by] == comp.cond_b) + & (obs["hours_post_perturbation"] >= b_start) + & (obs["hours_post_perturbation"] < b_end) + ) + emb_a = all_emb[bin_mask_a.values] + emb_b = 
all_emb[bin_mask_b.values] + bw = shared_bw if shared_bw is not None else None + mmd2, p_value, bw_out, es, az, na, nb = _run_one_comparison(emb_a, emb_b, config.mmd, bandwidth=bw) + map_val, map_pval = _maybe_map( + obs[marker_mask.values], + all_emb[marker_mask.values], + comp, + config.group_by, + marker, + config.map_settings, + ) + records.append( + _pooled_record( + marker, + comp, + b_start, + b_end, + na, + nb, + mmd2, + p_value, + bw_out, + es, + az, + map_val, + map_pval, + emb_key_label, + ) + ) + + df = pd.DataFrame(records) + if not df.empty: + valid_p = df["p_value"].dropna() + if len(valid_p) > 0: + _, q_values, _, _ = multipletests(df["p_value"].fillna(1.0), alpha=0.05, method="fdr_bh") + df["q_value"] = q_values + df.loc[df["p_value"].isna(), "q_value"] = float("nan") + else: + df["q_value"] = float("nan") + return df + + +def _pooled_record( + marker: str, + comp: ComparisonSpec, + hours_bin_start: float, + hours_bin_end: float, + n_a: int, + n_b: int, + mmd2: float, + p_value: float, + bandwidth: float, + effect_size: float, + activity_zscore: float, + map_value: float, + map_p_value: float, + embedding_key: str, +) -> dict: + return { + "marker": marker, + "cond_a": comp.cond_a, + "cond_b": comp.cond_b, + "label": comp.label, + "hours_bin_start": hours_bin_start, + "hours_bin_end": hours_bin_end, + "n_a": n_a, + "n_b": n_b, + "mmd2": mmd2, + "p_value": p_value, + "bandwidth": bandwidth, + "effect_size": effect_size, + "activity_zscore": activity_zscore, + "map_value": map_value, + "map_p_value": map_p_value, + "embedding_key": embedding_key, + } + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.argument("mmd_dir", type=click.Path(exists=True, path_type=Path)) +@click.option( + "--output-dir", type=click.Path(path_type=Path), default=None, help="Output directory. Default: same as mmd_dir." 
+) +def plot_mmd_heatmap_cmd(mmd_dir: Path, output_dir: Path | None) -> None: + """Plot a combined MMD heatmap (all markers) from per-experiment CSVs in MMD_DIR.""" + from dynaclr.evaluation.mmd.plotting import plot_mmd_heatmap + + csvs = sorted(mmd_dir.glob("*_mmd_results.csv")) + if not csvs: + raise click.ClickException(f"No *_mmd_results.csv files found in {mmd_dir}") + + df = pd.concat([pd.read_csv(f) for f in csvs], ignore_index=True) + click.echo(f"Loaded {len(df)} rows from {len(csvs)} CSV(s)") + + out = output_dir or mmd_dir + out.mkdir(parents=True, exist_ok=True) + + for comp_label in df["label"].unique(): + sub = df[df["label"] == comp_label] + safe = comp_label.replace(" ", "_").replace("/", "-") + for fmt in ("pdf", "png"): + plot_mmd_heatmap(sub, out / f"all_markers_{safe}_heatmap.{fmt}") + click.echo(f"Saved heatmap for: {comp_label}") + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to MMD evaluation YAML config", +) +@click.option( + "--combined", + is_flag=True, + default=False, + help="Run cross-experiment combined mode (config must have input_paths list)", +) +@click.option( + "--pooled", + is_flag=True, + default=False, + help="Run pooled multi-experiment phenotypic analysis (config must have input_paths list)", +) +def main(config: Path, combined: bool, pooled: bool) -> None: + """Compute MMD between explicit condition pairs in cell embeddings. + + Comparisons are defined as explicit (cond_a, cond_b, label) pairs. + The analysis is always faceted by obs["marker"]. 
+ """ + if combined and pooled: + raise click.UsageError("--combined and --pooled are mutually exclusive") + raw = load_composed_config(config) + output_dir = Path(raw["output_dir"]) + output_dir.mkdir(parents=True, exist_ok=True) + + if combined: + cfg = MMDCombinedConfig(**raw) + df = run_mmd_combined(cfg) + out_csv = output_dir / "combined_mmd_results.csv" + df.to_csv(out_csv, index=False) + click.echo(f"Saved: {out_csv}") + if cfg.save_plots: + _save_plots_combined(df, output_dir, cfg.temporal_bin_size) + _print_summary(df, mode="combined") + elif pooled: + cfg = MMDPooledConfig(**raw) + df = run_mmd_pooled(cfg) + out_csv = output_dir / "pooled_mmd_results.csv" + df.to_csv(out_csv, index=False) + click.echo(f"Saved: {out_csv}") + if cfg.save_plots and len(df): + _save_plots_pooled(df, output_dir) + _print_summary(df, mode="pooled") + else: + cfg = MMDEvalConfig(**raw) + adata = ad.read_zarr(cfg.input_path) + df = run_mmd_analysis(adata, cfg) + experiment = df["experiment"].iloc[0] if len(df) else "unknown" + out_csv = output_dir / f"{experiment}_mmd_results.csv" + df.to_csv(out_csv, index=False) + click.echo(f"Saved: {out_csv}") + if cfg.save_plots and len(df): + _save_plots(df, output_dir, experiment, cfg.temporal_bin_size or cfg.temporal_bins) + _print_summary(df, mode="per_experiment") + + +def _save_plots(df: pd.DataFrame, output_dir: Path, label: str, temporal_config) -> None: + from dynaclr.evaluation.mmd.plotting import plot_mmd_kinetics, plot_mmd_multi_panel_kinetics + + has_bins = temporal_config is not None and len(df) and not df["hours_bin_start"].isna().all() + if not has_bins: + return + for comp_label in df["label"].unique(): + sub = df[df["label"] == comp_label] + safe = comp_label.replace(" ", "_").replace("/", "-") + for fmt in ("pdf", "png"): + plot_mmd_kinetics(sub, output_dir / f"{label}_{safe}_kinetics.{fmt}") + for fmt in ("pdf", "png"): + plot_mmd_multi_panel_kinetics(df, output_dir / f"{label}_multi_panel_kinetics.{fmt}") + if 
"activity_zscore" in df.columns and not df["activity_zscore"].isna().all(): + from dynaclr.evaluation.mmd.plotting import plot_activity_heatmap, plot_paired_heatmaps + + for fmt in ("pdf", "png"): + plot_activity_heatmap(df, output_dir / f"{label}_activity_heatmap.{fmt}") + labels = [c for c in df["label"].unique() if c] + if len(labels) >= 2: + for fmt in ("pdf", "png"): + plot_paired_heatmaps(df, labels[:2], "activity_zscore", output_dir / f"{label}_paired_activity.{fmt}") + + +def _save_plots_combined(df: pd.DataFrame, output_dir: Path, temporal_bin_size: float | None) -> None: + from dynaclr.evaluation.mmd.plotting import plot_mmd_combined_heatmap, plot_mmd_kinetics + + has_bins = temporal_bin_size is not None and len(df) and not df["hours_bin_start"].isna().all() + for fmt in ("pdf", "png"): + if has_bins: + for marker in df["marker"].unique(): + sub = df[df["marker"] == marker] + safe = marker.replace(" ", "_").replace("/", "-") + plot_mmd_kinetics(sub, output_dir / f"combined_{safe}_kinetics.{fmt}") + plot_mmd_combined_heatmap(df, output_dir / f"combined_heatmap.{fmt}") + + +def _save_plots_pooled(df: pd.DataFrame, output_dir: Path) -> None: + from dynaclr.evaluation.mmd.plotting import ( + plot_activity_heatmap, + plot_mmd_heatmap, + plot_mmd_multi_panel_kinetics, + plot_paired_heatmaps, + ) + + has_bins = not df["hours_bin_start"].isna().all() + for fmt in ("pdf", "png"): + for comp_label in df["label"].unique(): + sub = df[df["label"] == comp_label] + safe = comp_label.replace(" ", "_").replace("/", "-") + plot_mmd_heatmap(sub, output_dir / f"pooled_{safe}_heatmap.{fmt}") + if has_bins: + plot_mmd_multi_panel_kinetics(df, output_dir / f"pooled_multi_panel_kinetics.{fmt}") + if "activity_zscore" in df.columns and not df["activity_zscore"].isna().all(): + plot_activity_heatmap(df, output_dir / f"pooled_activity_heatmap.{fmt}") + labels = [c for c in df["label"].unique() if c] + if len(labels) >= 2: + plot_paired_heatmaps(df, labels[:2], "activity_zscore", 
output_dir / f"pooled_paired_activity.{fmt}") + + +def _print_summary(df: pd.DataFrame, mode: str = "per_experiment") -> None: + if df.empty: + click.echo("No results.") + return + click.echo("\n## MMD Results Summary\n") + if mode == "combined": + summary = ( + df.dropna(subset=["mmd2"]) + .groupby(["marker", "condition"])[["mmd2", "p_value", "effect_size"]] + .agg({"mmd2": "mean", "p_value": "min", "effect_size": "mean"}) + .round(4) + .reset_index() + ) + elif mode == "pooled": + summary = ( + df.dropna(subset=["mmd2"]) + .groupby(["marker", "label"])[["mmd2", "p_value", "effect_size", "activity_zscore"]] + .agg({"mmd2": "mean", "p_value": "min", "effect_size": "mean", "activity_zscore": "mean"}) + .round(4) + .reset_index() + ) + else: + summary = ( + df.dropna(subset=["mmd2"]) + .groupby(["marker", "label"])[["mmd2", "p_value", "effect_size"]] + .agg({"mmd2": "mean", "p_value": "min", "effect_size": "mean"}) + .round(4) + .reset_index() + ) + click.echo(summary.to_string(index=False)) diff --git a/applications/dynaclr/src/dynaclr/evaluation/mmd/config.py b/applications/dynaclr/src/dynaclr/evaluation/mmd/config.py new file mode 100644 index 000000000..e80463bd4 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/mmd/config.py @@ -0,0 +1,224 @@ +"""Pydantic configuration for the MMD perturbation evaluation step.""" + +from __future__ import annotations + +from typing import Optional + +import numpy as np +from pydantic import BaseModel, model_validator + + +class ComparisonSpec(BaseModel): + """One pairwise comparison to run MMD on. + + Parameters + ---------- + cond_a : str + Value of ``obs[group_by]`` for group A (typically the reference/control). + cond_b : str + Value of ``obs[group_by]`` for group B (typically the treatment). + label : str + Human-readable label for this comparison (used in output filenames and plots). 
+ """ + + cond_a: str + cond_b: str + label: str + + +class MMDSettings(BaseModel): + """Kernel MMD algorithm settings, shared across per-experiment and combined modes. + + Parameters + ---------- + n_permutations : int + Number of permutations for the significance test. Default: 1000. + max_cells : int or None + Subsample each group to at most this many cells before computing MMD. + Controls memory and compute cost. Default: 2000. + min_cells : int + Minimum cells required per group. Groups below this produce NaN. Default: 20. + seed : int + Random seed for subsampling and permutations. Default: 42. + balance_samples : bool + Subsample the larger group to match the smaller group's size before + computing MMD. Prevents sample-size imbalance from inflating test statistics. + Applied after the ``max_cells`` cap. Default: False. + share_bandwidth_from : str or None + Label of a comparison whose bandwidth should be reused for all other + comparisons within the same (marker, time_bin) group. Typically the + baseline comparison (e.g. ``"uninf1 vs uninf2"``). If None, each + comparison computes its own bandwidth independently. Default: None. + """ + + n_permutations: int = 1000 + max_cells: Optional[int] = 2000 + min_cells: int = 20 + seed: int = 42 + balance_samples: bool = False + share_bandwidth_from: Optional[str] = None + + +class MAPSettings(BaseModel): + """Settings for the copairs-based mean Average Precision metric. + + Parameters + ---------- + enabled : bool + Compute mAP alongside MMD. Requires the ``copairs`` package. Default: False. + distance : str + Distance metric passed to copairs (e.g. ``"cosine"``). Default: ``"cosine"``. + null_size : int + Number of null pairs for the mAP permutation test. Default: 10000. + seed : int + Random seed. Default: 0. + """ + + enabled: bool = False + distance: str = "cosine" + null_size: int = 10000 + seed: int = 0 + + +class _MMDBaseConfig(BaseModel): + """Shared fields for all MMD analysis modes. 
+ + Parameters + ---------- + output_dir : str + Directory for CSV results and plots. + group_by : str + obs column used to select condition groups. Default: ``"perturbation"``. + obs_filter : dict[str, str] or None + Restrict analysis to rows where ``obs[key] == value``. Default: None. + embedding_key : str or None + obsm key to use. None = raw ``.X`` backbone embeddings. Default: None. + mmd : MMDSettings + Kernel MMD algorithm settings. + map_settings : MAPSettings + copairs-based mAP settings. Default: disabled. + temporal_bin_size : float or None + Width of each temporal bin in hours, starting from 0. + Bin edges: ``[0, size, 2*size, ..., max_hours]``. + Mutually exclusive with ``temporal_bins``. Default: None (aggregate). + temporal_bins : list[float] or None + Explicit bin edges in hours (e.g. ``[0, 6, 12, 24]``). Takes precedence + over ``temporal_bin_size``. Default: None (aggregate). + save_plots : bool + Generate plots after computing metrics. Default: True. + """ + + output_dir: str + group_by: str = "perturbation" + obs_filter: Optional[dict[str, str]] = None + embedding_key: Optional[str] = None + mmd: MMDSettings = MMDSettings() + map_settings: MAPSettings = MAPSettings() + temporal_bin_size: Optional[float] = None + temporal_bins: Optional[list[float]] = None + save_plots: bool = True + + @model_validator(mode="after") + def _validate_temporal(self) -> "_MMDBaseConfig": + if self.temporal_bin_size is not None and self.temporal_bins is not None: + raise ValueError("temporal_bin_size and temporal_bins are mutually exclusive") + return self + + +def _resolve_bin_edges( + temporal_bin_size: Optional[float], + temporal_bins: Optional[list[float]], + max_hours: float, +) -> Optional[list[tuple[float, float]]]: + """Return a list of (start, end) bin edge pairs, or None if no temporal binning. + + Parameters + ---------- + temporal_bin_size : float or None + Uniform bin width. Generates edges ``[0, size, 2*size, ..., max_hours]``. 
+ temporal_bins : list[float] or None + Explicit bin edges (e.g. ``[0, 6, 12, 24]``). Takes precedence over + ``temporal_bin_size``. + max_hours : float + Maximum hours value in the data, used only when ``temporal_bin_size`` is set. + + Returns + ------- + list[tuple[float, float]] or None + Ordered list of ``(bin_start, bin_end)`` pairs, or ``None`` for aggregate mode. + """ + if temporal_bins is not None: + edges = temporal_bins + elif temporal_bin_size is not None: + edges = list(np.arange(0, max_hours + temporal_bin_size, temporal_bin_size)) + else: + return None + return list(zip(edges[:-1], edges[1:])) + + +class MMDEvalConfig(_MMDBaseConfig): + """Per-experiment MMD analysis with explicit pairwise comparisons. + + Parameters + ---------- + input_path : str + Path to a single per-experiment AnnData zarr store. + comparisons : list[ComparisonSpec] + Explicit list of pairwise comparisons to run (required). + """ + + input_path: str + comparisons: list[ComparisonSpec] + + @model_validator(mode="after") + def _validate(self) -> "MMDEvalConfig": + if not self.comparisons: + raise ValueError("comparisons must not be empty") + return self + + +class MMDCombinedConfig(_MMDBaseConfig): + """Pairwise cross-experiment MMD for batch-effect detection. + + Conditions are auto-discovered from the data intersection — no explicit + comparisons needed. For each marker shared between a pair of experiments, + runs MMD per (condition, time_bin) after per-experiment mean centering. + + Parameters + ---------- + input_paths : list[str] + Paths to per-experiment AnnData zarr stores. + """ + + input_paths: list[str] + + +class MMDPooledConfig(_MMDBaseConfig): + """Pooled multi-experiment phenotypic analysis. + + Concatenates cells from all input experiments before computing MMD/mAP, + faceted by marker and temporal bin. Unlike ``MMDCombinedConfig`` (pairwise + batch-effect detection), this pools all experiments for a single biological + comparison. 
+ + Parameters + ---------- + input_paths : list[str] + Paths to per-experiment AnnData zarr stores to pool. + comparisons : list[ComparisonSpec] + Explicit list of pairwise comparisons to run (required). + condition_aliases : dict[str, list[str]] or None + Mapping from canonical condition name to variant strings found in the + data. E.g. ``{"uninfected": ["uninfected", "uninfected1", "uninfected2"]}``. + Applied to ``obs[group_by]`` before comparisons are evaluated. + """ + + input_paths: list[str] + comparisons: list[ComparisonSpec] + condition_aliases: Optional[dict[str, list[str]]] = None + + @model_validator(mode="after") + def _validate(self) -> "MMDPooledConfig": + if not self.comparisons: + raise ValueError("comparisons must not be empty") + return self diff --git a/applications/dynaclr/src/dynaclr/evaluation/mmd/plotting.py b/applications/dynaclr/src/dynaclr/evaluation/mmd/plotting.py new file mode 100644 index 000000000..9828f0711 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/mmd/plotting.py @@ -0,0 +1,438 @@ +"""Plots for MMD perturbation evaluation: kinetics curves and heatmaps.""" + +from __future__ import annotations + +import math +from pathlib import Path + +import matplotlib +import numpy as np +import pandas as pd + +matplotlib.use("Agg") +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import seaborn as sns +from statsmodels.stats.multitest import multipletests + + +def _bh_significance(p_values: np.ndarray, alpha: float = 0.05) -> np.ndarray: + """Return boolean mask of BH-corrected significant p-values.""" + p_values = np.asarray(p_values, dtype=float) + valid = ~np.isnan(p_values) + sig = np.zeros(len(p_values), dtype=bool) + if valid.sum() == 0: + return sig + # FIX(review): multipletests returns (reject, pvals_corrected, alphacSidak, alphacBonf). + # The previous code unpacked pvals_corrected (floats) and assigned it into the bool + # mask, which cast every nonzero corrected p-value to True — i.e. everything was + # flagged significant. Use the boolean `reject` array instead. + reject, _, _, _ = multipletests(p_values[valid], alpha=alpha, method="fdr_bh") + sig[valid] = reject + return sig + + +def plot_mmd_kinetics(df: pd.DataFrame, output_path: Path) -> None: + """Plot MMD kinetics curves (one line per marker 
over temporal bins). + + Parameters + ---------- + df : pd.DataFrame + MMD results for a single treatment group, with columns: + marker, hours_bin_start, hours_bin_end, mmd2, p_value. + output_path : Path + Output file path. Format inferred from suffix (.pdf or .png). + """ + df = df.copy().dropna(subset=["hours_bin_start", "hours_bin_end"]) + if df.empty: + return + df["bin_mid"] = (df["hours_bin_start"] + df["hours_bin_end"]) / 2 + + markers = sorted(df["marker"].unique()) + fig, ax = plt.subplots(figsize=(8, 4)) + palette = sns.color_palette("tab10", n_colors=len(markers)) + + for marker, color in zip(markers, palette): + sub = df[df["marker"] == marker].sort_values("bin_mid") + ax.plot(sub["bin_mid"], sub["mmd2"], marker="o", label=marker, color=color) + # Stars for BH-significant bins + sig = _bh_significance(sub["p_value"]) + for _, row, s in zip(range(len(sub)), sub.itertuples(), sig): + if s: + ax.text(row.bin_mid, row.mmd2, "*", ha="center", va="bottom", color=color, fontsize=12) + + ax.set_xlabel("Hours post perturbation (bin midpoint)") + ax.set_ylabel("MMD²") + ax.set_title(df["label"].iloc[0] if "label" in df.columns else "") + ax.legend(title="Marker", bbox_to_anchor=(1.01, 1), loc="upper left", fontsize=10, title_fontsize=11) + ax.axhline(0, color="gray", linewidth=0.8, linestyle="--") + sns.despine(ax=ax) + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_mmd_combined_heatmap(df: pd.DataFrame, output_path: Path) -> None: + """Plot combined cross-experiment MMD heatmap: markers × experiment pairs. + + One subplot per condition. Rows = markers, columns = exp_a vs exp_b pairs + (averaged over temporal bins if present). + + Parameters + ---------- + df : pd.DataFrame + Combined MMD results with columns: marker, exp_a, exp_b, condition, + hours_bin_start, hours_bin_end, mmd2, p_value. + output_path : Path + Output file path. 
+ """ + df = df.copy() + df["exp_pair"] = ( + df["exp_a"].str.split("_").str[:3].str.join("_") + "\nvs\n" + df["exp_b"].str.split("_").str[:3].str.join("_") + ) + conditions = sorted(df["condition"].unique()) + n_conds = len(conditions) + + fig, axes = plt.subplots( + 1, n_conds, figsize=(max(5 * n_conds, 6), max(4, df["marker"].nunique() * 0.7)), squeeze=False + ) + + for ax, condition in zip(axes[0], conditions): + sub = df[df["condition"] == condition] + pivot_mmd = sub.pivot_table(index="marker", columns="exp_pair", values="mmd2", aggfunc="mean") + pivot_pval = sub.pivot_table(index="marker", columns="exp_pair", values="p_value", aggfunc="min") + + if pivot_mmd.empty or pivot_mmd.isna().all().all(): + ax.set_visible(False) + continue + + sns.heatmap(pivot_mmd, ax=ax, cmap="viridis", linewidths=0.5, cbar_kws={"label": "MMD²"}) + + sig = _bh_significance(pivot_pval.values.ravel()) + sig_matrix = sig.reshape(pivot_pval.shape) + for r in range(sig_matrix.shape[0]): + for c in range(sig_matrix.shape[1]): + if sig_matrix[r, c]: + ax.text( + c + 0.5, r + 0.5, "*", ha="center", va="center", color="white", fontsize=10, fontweight="bold" + ) + + ax.set_title(f"condition: {condition}") + ax.set_xlabel("Experiment pair") + ax.set_ylabel("Marker") + ax.tick_params(axis="x", labelsize=7) + + fig.suptitle("Cross-experiment MMD — all markers", y=1.01) + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_mmd_multi_panel_kinetics( + df: pd.DataFrame, + output_path: Path, + baseline_label: str | None = None, + ncols: int = 4, +) -> None: + """Plot per-marker MMD kinetics in a multi-panel grid with optional baseline band. + + One subplot per marker. Treatment comparisons are plotted as colored lines; + if ``baseline_label`` is given, that comparison is shown as a gray dashed + line with a shaded ±1 std band instead of a treatment line. 
+ + Parameters + ---------- + df : pd.DataFrame + MMD results with columns: marker, label, hours_bin_start, hours_bin_end, + mmd2, p_value. + output_path : Path + Output file path (.pdf or .png). + baseline_label : str or None + Label of the baseline comparison to render as a band. Default: None. + ncols : int + Number of columns in the panel grid. Default: 4. + """ + df = df.copy().dropna(subset=["hours_bin_start", "hours_bin_end"]) + if df.empty: + return + df["bin_mid"] = (df["hours_bin_start"] + df["hours_bin_end"]) / 2 + + markers = sorted(df["marker"].unique()) + treatment_labels = [lbl for lbl in df["label"].unique() if lbl != baseline_label] + nrows = math.ceil(len(markers) / ncols) + palette = sns.color_palette("tab10", n_colors=max(len(treatment_labels), 1)) + + # Shared y-axis range + treat_vals = df[df["label"].isin(treatment_labels)]["mmd2"].dropna() + y_min = float(treat_vals.min()) if len(treat_vals) else 0.0 + y_max = float(treat_vals.max()) if len(treat_vals) else 1.0 + y_pad = (y_max - y_min) * 0.1 + 1e-6 + + fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3.5, nrows * 2.8), squeeze=False) + + for ax_idx, marker in enumerate(markers): + ax = axes[ax_idx // ncols][ax_idx % ncols] + sub = df[df["marker"] == marker] + + # Baseline band + if baseline_label is not None: + base = sub[sub["label"] == baseline_label].sort_values("bin_mid") + if not base.empty: + ax.axhline(base["mmd2"].mean(), color="gray", linewidth=1.0, linestyle="--", zorder=1) + ax.fill_between( + base["bin_mid"], + base["mmd2"] - base["mmd2"].std(), + base["mmd2"] + base["mmd2"].std(), + color="gray", + alpha=0.2, + zorder=1, + ) + + # Treatment lines + for lbl, color in zip(treatment_labels, palette): + treat = sub[sub["label"] == lbl].sort_values("bin_mid") + if treat.empty: + continue + sig = _bh_significance(treat["p_value"]) + ax.plot(treat["bin_mid"], treat["mmd2"], color=color, linewidth=1.2, label=lbl, zorder=2) + sig_rows = treat[sig] + if not sig_rows.empty: + 
ax.scatter( + sig_rows["bin_mid"], + sig_rows["mmd2"], + color=color, + edgecolors="black", + linewidths=0.8, + s=40, + zorder=3, + ) + + ax.set_title(marker, fontsize=9) + ax.set_ylim(y_min - y_pad, y_max + y_pad) + ax.axhline(0, color="lightgray", linewidth=0.5, linestyle="--") + sns.despine(ax=ax) + + # Hide unused axes + for ax_idx in range(len(markers), nrows * ncols): + axes[ax_idx // ncols][ax_idx % ncols].set_visible(False) + + # Shared legend + handles, lbls = axes[0][0].get_legend_handles_labels() + if handles: + fig.legend( + handles, lbls, loc="lower center", ncol=len(treatment_labels), fontsize=9, bbox_to_anchor=(0.5, -0.02) + ) + + fig.supxlabel("Hours post perturbation (bin midpoint)", fontsize=10) + fig.supylabel("MMD²", fontsize=10) + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_activity_heatmap( + df: pd.DataFrame, + output_path: Path, + linthresh: float = 1.0, +) -> None: + """Plot activity z-score heatmap (markers × temporal bins). + + Uses symmetric log normalization so both small and large z-scores are + visible. Significance stars mark FDR-corrected significant cells. + + Parameters + ---------- + df : pd.DataFrame + MMD results with columns: marker, label, hours_bin_start, hours_bin_end, + activity_zscore, p_value. + output_path : Path + Output file path (.pdf or .png). + linthresh : float + Linear threshold for ``SymLogNorm``. Values within ``[-linthresh, + linthresh]`` are rendered linearly; outside is log-scaled. Default: 1.0. 
+ """ + if "activity_zscore" not in df.columns or df["activity_zscore"].isna().all(): + return + df = df.copy().dropna(subset=["hours_bin_start", "hours_bin_end", "activity_zscore"]) + if df.empty: + return + df["bin_label"] = df.apply(lambda r: f"{r.hours_bin_start:.0f}–{r.hours_bin_end:.0f}h", axis=1) + + labels = [lbl for lbl in df["label"].unique() if lbl] + n_labels = len(labels) + fig, axes = plt.subplots( + 1, + n_labels, + figsize=(max(5, len(df["bin_label"].unique()) * 1.0 * n_labels), max(4, df["marker"].nunique() * 0.6)), + squeeze=False, + ) + + for ax, lbl in zip(axes[0], labels): + sub = df[df["label"] == lbl] + pivot_z = sub.pivot_table(index="marker", columns="bin_label", values="activity_zscore", aggfunc="mean") + pivot_pval = sub.pivot_table(index="marker", columns="bin_label", values="p_value", aggfunc="min") + bin_order = sub.drop_duplicates("bin_label").sort_values("hours_bin_start")["bin_label"].tolist() + pivot_z = pivot_z.reindex(columns=bin_order) + pivot_pval = pivot_pval.reindex(columns=bin_order) + + if pivot_z.empty or pivot_z.isna().all().all(): + ax.set_visible(False) + continue + + vmax = float(np.nanmax(np.abs(pivot_z.values))) + norm = mcolors.SymLogNorm(linthresh=linthresh, vmin=-vmax, vmax=vmax) + sns.heatmap(pivot_z, ax=ax, cmap="RdBu_r", norm=norm, linewidths=0.3, cbar_kws={"label": "Activity z-score"}) + + sig = _bh_significance(pivot_pval.values.ravel()) + sig_matrix = sig.reshape(pivot_pval.shape) + for r in range(sig_matrix.shape[0]): + for c in range(sig_matrix.shape[1]): + if sig_matrix[r, c]: + ax.text( + c + 0.5, r + 0.5, "*", ha="center", va="center", color="black", fontsize=10, fontweight="bold" + ) + + ax.set_title(lbl) + ax.set_xlabel("Temporal bin") + ax.set_ylabel("Marker") + + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_paired_heatmaps( + df: pd.DataFrame, + condition_labels: list[str], + value_col: str, + output_path: Path, + linthresh: float = 1.0, 
+) -> None: + """Plot side-by-side heatmaps for two conditions sharing a colorbar. + + Parameters + ---------- + df : pd.DataFrame + MMD results. Must have columns: marker, label, hours_bin_start, + hours_bin_end, ``value_col``, p_value. + condition_labels : list[str] + Exactly two comparison labels to plot side-by-side. + value_col : str + Column to use as heatmap values (e.g. ``"activity_zscore"``). + output_path : Path + Output file path. + linthresh : float + Linear threshold for ``SymLogNorm``. Default: 1.0. + """ + if value_col not in df.columns or len(condition_labels) < 2: + return + df = df.copy().dropna(subset=["hours_bin_start", "hours_bin_end", value_col]) + if df.empty: + return + df["bin_label"] = df.apply(lambda r: f"{r.hours_bin_start:.0f}–{r.hours_bin_end:.0f}h", axis=1) + bin_order = df.drop_duplicates("bin_label").sort_values("hours_bin_start")["bin_label"].tolist() + + all_vals = df[df["label"].isin(condition_labels)][value_col].dropna() + if all_vals.empty: + return + vmax = float(np.nanmax(np.abs(all_vals))) + norm = mcolors.SymLogNorm(linthresh=linthresh, vmin=-vmax, vmax=vmax) + + fig, axes = plt.subplots( + 1, 2, figsize=(max(10, len(bin_order) * 2), max(4, df["marker"].nunique() * 0.6)), squeeze=False + ) + + for ax, lbl in zip(axes[0], condition_labels[:2]): + sub = df[df["label"] == lbl] + pivot_val = sub.pivot_table(index="marker", columns="bin_label", values=value_col, aggfunc="mean") + pivot_pval = sub.pivot_table(index="marker", columns="bin_label", values="p_value", aggfunc="min") + pivot_val = pivot_val.reindex(columns=bin_order) + pivot_pval = pivot_pval.reindex(columns=bin_order) + + if pivot_val.empty or pivot_val.isna().all().all(): + ax.set_visible(False) + continue + + im = ax.imshow( + pivot_val.values, + aspect="auto", + norm=norm, + cmap="YlOrRd", + origin="upper", + ) + ax.set_xticks(range(len(pivot_val.columns))) + ax.set_xticklabels(pivot_val.columns, rotation=45, ha="right", fontsize=8) + 
ax.set_yticks(range(len(pivot_val.index))) + ax.set_yticklabels(pivot_val.index, fontsize=8) + ax.set_title(lbl) + + sig = _bh_significance(pivot_pval.values.ravel()) + sig_matrix = sig.reshape(pivot_pval.shape) + for r in range(sig_matrix.shape[0]): + for c in range(sig_matrix.shape[1]): + val = pivot_val.values[r, c] + if np.isfinite(val): + txt = f"{int(val)}" if abs(val) >= 1 else f"{val:.1f}" + if sig_matrix[r, c]: + txt += "*" + ax.text(c, r, txt, ha="center", va="center", fontsize=7, color="black") + + plt.colorbar(im, ax=axes[0], label=value_col) + fig.suptitle(f"{' vs '.join(condition_labels[:2])}", y=1.01) + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def plot_mmd_heatmap(df: pd.DataFrame, output_path: Path) -> None: + """Plot MMD heatmap (markers x temporal bins or aggregate). + + Parameters + ---------- + df : pd.DataFrame + MMD results for a single treatment group. + output_path : Path + Output file path. + """ + df = df.copy() + has_bins = not df["hours_bin_start"].isna().all() + + if has_bins: + df["bin_label"] = df.apply(lambda r: f"{r.hours_bin_start:.0f}–{r.hours_bin_end:.0f}h", axis=1) + pivot_mmd = df.pivot_table(index="marker", columns="bin_label", values="mmd2", aggfunc="mean") + pivot_pval = df.pivot_table(index="marker", columns="bin_label", values="p_value", aggfunc="min") + # Order columns by bin start + bin_order = df.drop_duplicates("bin_label").sort_values("hours_bin_start")["bin_label"].tolist() + pivot_mmd = pivot_mmd.reindex(columns=bin_order) + pivot_pval = pivot_pval.reindex(columns=bin_order) + xlabel = "Temporal bin" + figsize = (max(6, len(bin_order) * 0.8), max(4, len(pivot_mmd) * 0.6)) + else: + pivot_mmd = df.set_index("marker")[["mmd2"]].rename(columns={"mmd2": "aggregate"}) + pivot_pval = df.set_index("marker")[["p_value"]].rename(columns={"p_value": "aggregate"}) + xlabel = "" + figsize = (3, max(4, len(pivot_mmd) * 0.6)) + + if pivot_mmd.empty or 
pivot_mmd.isna().all().all(): + return + + fig, ax = plt.subplots(figsize=figsize) + sns.heatmap( + pivot_mmd, + ax=ax, + cmap="viridis", + annot=False, + linewidths=0.5, + cbar_kws={"label": "MMD²"}, + ) + + # Add significance stars + sig = _bh_significance(pivot_pval.values.ravel()) + sig_matrix = sig.reshape(pivot_pval.shape) + for r in range(sig_matrix.shape[0]): + for c in range(sig_matrix.shape[1]): + if sig_matrix[r, c]: + ax.text(c + 0.5, r + 0.5, "*", ha="center", va="center", color="white", fontsize=10, fontweight="bold") + + ax.set_title(f"MMD heatmap — {df['label'].iloc[0] if 'label' in df.columns else ''}") + ax.set_xlabel(xlabel) + ax.set_ylabel("Marker") + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) diff --git a/applications/dynaclr/src/dynaclr/evaluation/plot_embeddings.py b/applications/dynaclr/src/dynaclr/evaluation/plot_embeddings.py new file mode 100644 index 000000000..344825dc7 --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/plot_embeddings.py @@ -0,0 +1,294 @@ +"""CLI tool for generating scatter plots from AnnData embedding stores. + +For high-dimensional embeddings (PCA): generates a seaborn pairplot of the +first N components, one figure per color variable. +For low-dimensional embeddings (PHATE, UMAP): generates a simple scatter +colored by each metadata column. + +Usage +----- +dynaclr plot-embeddings -c plot_config.yaml +""" + +from pathlib import Path +from typing import Optional + +import anndata as ad +import click +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from pydantic import BaseModel, Field, model_validator + +from viscy_utils.cli_utils import load_config +from viscy_utils.mp_utils import available_cpus + + +class PlotEmbeddingsConfig(BaseModel): + """Configuration for plot-embeddings command. + + Parameters + ---------- + input_path : str, optional + Path to a single AnnData zarr store. 
Mutually exclusive with input_paths. + input_paths : list[str], optional + Paths to multiple AnnData zarr stores. All are concatenated before plotting. + Use for combined embeddings (X_pca_combined, X_phate_combined) to get one + figure across all experiments. Mutually exclusive with input_path. + output_dir : str + Directory to save plots. + embedding_keys : list[str] + obsm keys to plot (e.g. X_phate, X_pca). + color_by : list[str] + obs columns to use as hue in pairplots / color in scatter plots. + pairplot_components : int + Number of leading components to include in pairplots. Default: 10. + point_size : float + Scatter plot point size (passed as ``s`` to matplotlib and + ``plot_kws`` to seaborn). Default: 1.0. + format : str + Output format: "pdf", "png", or "both". Default: "pdf". + low_dim_threshold : int + Embeddings with <= this many components use the simple scatter path + instead of pairplot. Default: 4. + """ + + input_path: Optional[str] = None + input_paths: Optional[list[str]] = None + output_dir: str = Field(...) 
+ embedding_keys: list[str] = ["X_pca_combined", "X_phate_combined"] + color_by: list[str] = ["perturbation", "hours_post_perturbation", "experiment", "marker"] + pairplot_components: int = 10 + point_size: float = 1.0 + format: str = "pdf" + low_dim_threshold: int = 4 + + @model_validator(mode="after") + def validate_input(self): + if self.input_path is None and self.input_paths is None: + raise ValueError("Either input_path or input_paths must be provided") + if self.input_path is not None and self.input_paths is not None: + raise ValueError("Provide either input_path or input_paths, not both") + return self + + +_PALETTE = [ + "#1b69a1", + "#d9534f", + "#5cb85c", + "#f0ad4e", + "#9b59b6", + "#1abc9c", + "#e74c3c", + "#3498db", + "#2ecc71", + "#e67e22", +] + + +def _save_fig(fig: plt.Figure, output_dir: Path, stem: str, fmt: str) -> None: + if fmt in ("pdf", "both"): + fig.savefig(output_dir / f"{stem}.pdf", dpi=150, bbox_inches="tight") + if fmt in ("png", "both"): + fig.savefig(output_dir / f"{stem}.png", dpi=150, bbox_inches="tight") + plt.close(fig) + click.echo(f" Saved {stem}.{fmt}") + + +def _pairplot( + emb: np.ndarray, + obs: pd.DataFrame, + color_col: str, + n_components: int, + point_size: float, + emb_key: str, +) -> plt.Figure: + """Build a seaborn pairplot of the first n_components.""" + import seaborn as sns + + n = min(n_components, emb.shape[1]) + cols = [f"{emb_key}_{i}" for i in range(n)] + df = pd.DataFrame(emb[:, :n], columns=cols) + + values = obs[color_col].to_numpy() + is_categorical = values.dtype.kind in ("U", "O", "S") or hasattr(values, "cat") + + if is_categorical: + cats = sorted(str(v) for v in np.unique(values)) + palette = {cat: _PALETTE[i % len(_PALETTE)] for i, cat in enumerate(cats)} + df[color_col] = [str(v) for v in values] + pg = sns.pairplot( + df, + hue=color_col, + palette=palette, + plot_kws={"s": point_size, "alpha": 0.4, "rasterized": True, "zorder": 0}, + diag_kind="hist", + corner=True, + ) + 
pg.legend.set(title=color_col) + for lh in pg.legend.legend_handles: + lh.set_alpha(1.0) + if hasattr(lh, "set_sizes"): + lh.set_sizes([40]) + else: + lh.set_markersize(8) + for ax_row in pg.axes: + for ax in ax_row: + if ax is not None: + ax.set_rasterization_zorder(1) + else: + # Continuous: no hue support in pairplot — use a custom scatter matrix + df[color_col] = values.astype(float) + pg = sns.pairplot( + df, + plot_kws={"s": point_size, "alpha": 0.4, "rasterized": True, "color": "#888888", "zorder": 0}, + diag_kind="hist", + corner=True, + ) + # Overlay color on lower-triangle axes + norm = plt.Normalize(df[color_col].min(), df[color_col].max()) + cmap = plt.cm.viridis + for i in range(1, n): + for j in range(i): + ax = pg.axes[i][j] + if ax is None: + continue + ax.collections[0].set_visible(False) + sc = ax.scatter( + df.iloc[:, j], + df.iloc[:, i], + c=df[color_col], + cmap=cmap, + norm=norm, + s=point_size, + alpha=0.4, + rasterized=True, + zorder=0, + ) + pg.figure.colorbar(sc, ax=pg.axes[-1][-1], label=color_col) + for ax_row in pg.axes: + for ax in ax_row: + if ax is not None: + ax.set_rasterization_zorder(1) + + pg.figure.suptitle(f"{emb_key} — {color_col}", y=1.01, fontsize=11, fontweight="bold") + return pg.figure + + +def _scatter_2d( + emb: np.ndarray, + obs: pd.DataFrame, + color_cols: list[str], + point_size: float, + emb_key: str, +) -> plt.Figure: + """Simple scatter for low-dimensional embeddings (PHATE, UMAP).""" + ncols = min(4, len(color_cols)) + nrows = (len(color_cols) + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False) + rng = np.random.default_rng(42) + shuffle = rng.permutation(len(emb)) + x, y = emb[shuffle, 0], emb[shuffle, 1] + + for ax_idx, col in enumerate(color_cols): + ax = axes[ax_idx // ncols][ax_idx % ncols] + values = obs[col].to_numpy()[shuffle] + is_categorical = values.dtype.kind in ("U", "O", "S") or hasattr(values, "cat") + + if is_categorical: + cats = 
sorted(str(v) for v in np.unique(values)) + for i, cat in enumerate(cats): + mask = np.array([str(v) == cat for v in values]) + ax.scatter( + x[mask], y[mask], s=point_size, c=_PALETTE[i % len(_PALETTE)], label=cat, alpha=0.5, rasterized=True + ) + ax.legend( + markerscale=6, fontsize=10, loc="best", framealpha=1.0, edgecolor="black", ncol=max(1, len(cats) // 8) + ) + else: + sc = ax.scatter(x, y, s=point_size, c=values.astype(float), cmap="viridis", alpha=0.5, rasterized=True) + plt.colorbar(sc, ax=ax, shrink=0.8) + + ax.set_title(col.replace("_", " ").title(), fontsize=10) + ax.set_xlabel(f"{emb_key} 0") + ax.set_ylabel(f"{emb_key} 1") + + for ax_idx in range(len(color_cols), nrows * ncols): + axes[ax_idx // ncols][ax_idx % ncols].set_visible(False) + + fig.suptitle(f"Embeddings: {emb_key}", fontsize=13, fontweight="bold") + plt.tight_layout() + return fig + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file", +) +def main(config: Path) -> None: + """Generate pairplots (PCA) and scatter plots (PHATE/UMAP) from an AnnData store.""" + matplotlib.use("Agg") + + raw = load_config(config) + cfg = PlotEmbeddingsConfig(**raw) + + if cfg.input_paths is not None: + click.echo(f"Concatenating {len(cfg.input_paths)} zarr stores...") + adata = ad.concat([ad.read_zarr(p) for p in cfg.input_paths], join="outer") + else: + adata = ad.read_zarr(cfg.input_path) + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + valid_color_cols = [c for c in cfg.color_by if c in adata.obs.columns] + missing = set(cfg.color_by) - set(valid_color_cols) + if missing: + click.echo(f"Warning: obs columns not found, skipping: {sorted(missing)}", err=True) + if not valid_color_cols: + click.echo("No valid color columns found, nothing to plot.", err=True) + return + + for emb_key in cfg.embedding_keys: 
+ if emb_key not in adata.obsm: + click.echo(f"Warning: {emb_key} not in obsm, skipping", err=True) + continue + + emb = np.asarray(adata.obsm[emb_key]) + click.echo(f"Plotting {emb_key} ({emb.shape[1]} components)...") + + if emb.shape[1] <= cfg.low_dim_threshold: + # Simple scatter (PHATE, UMAP) + fig = _scatter_2d(emb, adata.obs, valid_color_cols, cfg.point_size, emb_key) + _save_fig(fig, output_dir, f"scatter_{emb_key}", cfg.format) + else: + # Pairplot per color variable (PCA) — render colorings in parallel. + # matplotlib isn't thread-safe but is fine across processes; loky + # spawns one worker per coloring, each with its own figure context. + # Workers re-import matplotlib + seaborn (~1s overhead) so this only + # pays off when the per-pairplot render time exceeds that cost, + # which it does for any pairplot_components >= 4 on >100k cells. + from joblib import Parallel, delayed + + n_jobs = min(len(valid_color_cols), available_cpus(default=1)) + + def _render_one(col): + try: + fig = _pairplot(emb, adata.obs, col, cfg.pairplot_components, cfg.point_size, emb_key) + _save_fig(fig, output_dir, f"pairplot_{emb_key}_{col}", cfg.format) + except Exception as e: + return f" Warning: pairplot {emb_key}/{col} failed: {e}" + return None + + messages = Parallel(n_jobs=n_jobs, backend="loky")(delayed(_render_one)(col) for col in valid_color_cols) + for msg in messages: + if msg is not None: + click.echo(msg, err=True) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/evaluation/split_embeddings.py b/applications/dynaclr/src/dynaclr/evaluation/split_embeddings.py new file mode 100644 index 000000000..c55aedbcd --- /dev/null +++ b/applications/dynaclr/src/dynaclr/evaluation/split_embeddings.py @@ -0,0 +1,101 @@ +"""Split a combined embeddings zarr into one zarr per experiment. 
+ +Reads the combined embeddings.zarr produced by the predict step, groups rows +by obs["experiment"], and writes one AnnData zarr per experiment under +output_dir/{experiment}.zarr. The combined zarr is removed after splitting. + +Usage +----- +dynaclr split-embeddings -c split.yaml + +Or with inline arguments: + +dynaclr split-embeddings --input /path/to/embeddings.zarr --output-dir /path/to/embeddings/ +""" + +from __future__ import annotations + +from pathlib import Path + +import click + + +def split_embeddings(input_path: Path, output_dir: Path) -> list[Path]: + """Split combined embeddings zarr into one zarr per experiment. + + Parameters + ---------- + input_path : Path + Path to the combined embeddings zarr (AnnData format). + Must have obs["experiment"] column. + output_dir : Path + Directory to write per-experiment zarrs. + Each experiment is written to output_dir/{experiment}.zarr. + + Returns + ------- + list[Path] + Paths to the written per-experiment zarrs. + """ + import anndata as ad + + if hasattr(ad, "settings") and hasattr(ad.settings, "allow_write_nullable_strings"): + ad.settings.allow_write_nullable_strings = True + import pandas as pd + + pd.options.future.infer_string = False + + click.echo(f"Loading embeddings from {input_path}") + adata = ad.read_zarr(input_path) + click.echo(f" {adata.n_obs} cells, {adata.n_vars} features") + + if "experiment" not in adata.obs.columns: + raise ValueError( + "embeddings zarr obs is missing 'experiment' column. " + "Re-run the predict step with the updated pipeline to include metadata." 
+ ) + + experiments = adata.obs["experiment"].unique().tolist() + click.echo(f" {len(experiments)} experiments: {experiments}") + + output_dir.mkdir(parents=True, exist_ok=True) + written: list[Path] = [] + + for exp in experiments: + mask = adata.obs["experiment"] == exp + adata_exp = adata[mask].copy() + out_path = output_dir / f"{exp}.zarr" + click.echo(f" Writing {exp}: {adata_exp.n_obs} cells → {out_path}") + adata_exp.write_zarr(out_path) + written.append(out_path) + + click.echo(f"\nRemoving combined zarr: {input_path}") + import shutil + + shutil.rmtree(input_path) + + click.echo(f"\nWrote {len(written)} per-experiment zarrs to {output_dir}") + return written + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "--input", + "input_path", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to combined embeddings zarr", +) +@click.option( + "--output-dir", + type=click.Path(path_type=Path), + required=True, + help="Directory to write per-experiment zarrs", +) +def main(input_path: Path, output_dir: Path) -> None: + """Split a combined embeddings zarr into one zarr per experiment.""" + split_embeddings(input_path, output_dir) + + +if __name__ == "__main__": + main() diff --git a/applications/dynaclr/src/dynaclr/info.py b/applications/dynaclr/src/dynaclr/info.py index fb8523aeb..624ab629e 100644 --- a/applications/dynaclr/src/dynaclr/info.py +++ b/applications/dynaclr/src/dynaclr/info.py @@ -12,6 +12,7 @@ def main(path: Path): """Print summary of an AnnData zarr store.""" import anndata as ad + import scipy.sparse as sp with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -19,7 +20,14 @@ def main(path: Path): click.echo(f"Path: {path}") click.echo(f"Shape: {adata.n_obs:,} obs × {adata.n_vars:,} vars") - click.echo(f"X: dtype={adata.X.dtype}, range=[{np.nanmin(adata.X):.4f}, {np.nanmax(adata.X):.4f}]") + X = adata.X + if sp.issparse(X): + X_dense = X.toarray() + else: + X_dense = 
X + sparse = sp.issparse(adata.X) + xmin, xmax = np.nanmin(X_dense), np.nanmax(X_dense) + click.echo(f"X: dtype={X_dense.dtype}, sparse={sparse}, range=[{xmin:.4f}, {xmax:.4f}]") if len(adata.obs.columns): click.echo("\nobs columns:") @@ -27,7 +35,7 @@ def main(path: Path): s = adata.obs[col] nuniq = s.nunique() if nuniq <= 10: - vals = ", ".join(str(v) for v in sorted(s.unique()[:10])) + vals = ", ".join(str(v) for v in sorted(s.dropna().unique()[:10])) click.echo(f" {col}: {s.dtype}, {nuniq} unique — [{vals}]") else: click.echo(f" {col}: {s.dtype}, {nuniq} unique") diff --git a/applications/dynaclr/src/dynaclr/pseudotime/__init__.py b/applications/dynaclr/src/dynaclr/pseudotime/__init__.py new file mode 100644 index 000000000..87e8e314d --- /dev/null +++ b/applications/dynaclr/src/dynaclr/pseudotime/__init__.py @@ -0,0 +1,69 @@ +"""DTW-based pseudotime alignment for cellular dynamics from DynaCLR embeddings. + +Public API: + +- :func:`build_template` — fit a DBA template from annotated trajectories. +- :func:`dtw_align_tracks` — warp query tracks onto a template. +- :func:`alignment_results_to_dataframe` — flatten alignment results to a dataframe. +- :func:`extract_dtw_pseudotime` — pull per-track pseudotime from results. +- :func:`classify_response_groups` — k-means cluster cells by alignment summary. +- :class:`TemplateResult` — template + PCA + z-score params bundle. +- :class:`AlignmentResult` — per-track DTW alignment bundle. +- :data:`DEFAULT_POSITIVE_CLASSES` — default infection-label mapping. + +IO helpers: + +- :func:`save_template_zarr` — write a two-flavor template zarr with provenance. +- :func:`load_template_flavor` — read one flavor (raw or pca) from a template. +- :func:`read_template_attrs` — read just the top-level attrs. +- :func:`read_time_calibration` — read the per-position time calibration array. +- :func:`find_embedding_zarr` — glob the embedding zarr under a dataset's pred_dir. 
+- :func:`date_prefix_from_dataset_id` — extract the ``YYYY_MM_DD_`` prefix. +- :func:`get_dynaclr_versions` — capture viscy/library versions for provenance. + +Lower-level helpers and legacy modules (``alignment``, ``signals``, +``metrics``, ``plotting``, ``evaluation``) are still importable but are +not part of the curated public surface. +""" + +from dynaclr.pseudotime.dtw_alignment import ( + DEFAULT_POSITIVE_CLASSES, + AlignmentResult, + TemplateResult, + alignment_results_to_dataframe, + build_template, + classify_response_groups, + dtw_align_tracks, + extract_dtw_pseudotime, +) +from dynaclr.pseudotime.io import ( + compute_tau_event_band, + date_prefix_from_dataset_id, + find_embedding_zarr, + get_dynaclr_versions, + load_template_flavor, + read_tau_event_band, + read_template_attrs, + read_time_calibration, + save_template_zarr, +) + +__all__ = [ + "DEFAULT_POSITIVE_CLASSES", + "AlignmentResult", + "TemplateResult", + "alignment_results_to_dataframe", + "build_template", + "classify_response_groups", + "compute_tau_event_band", + "date_prefix_from_dataset_id", + "dtw_align_tracks", + "extract_dtw_pseudotime", + "find_embedding_zarr", + "get_dynaclr_versions", + "load_template_flavor", + "read_tau_event_band", + "read_template_attrs", + "read_time_calibration", + "save_template_zarr", +] diff --git a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/alignment.py b/applications/dynaclr/src/dynaclr/pseudotime/alignment.py similarity index 99% rename from applications/dynaclr/src/dynaclr/evaluation/pseudotime/alignment.py rename to applications/dynaclr/src/dynaclr/pseudotime/alignment.py index f4d358c16..13ba9af4c 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/alignment.py +++ b/applications/dynaclr/src/dynaclr/pseudotime/alignment.py @@ -244,7 +244,7 @@ def align_tracks( min_track_timepoints: int = 3, fov_pattern: str | list[str] | None = None, ) -> pd.DataFrame: - """Convenience wrapper: filter_tracks + assign_t_perturb in one call. 
+    """Run filter_tracks + assign_t_perturb in one call (convenience wrapper).
 
     Parameters
     ----------
diff --git a/applications/dynaclr/src/dynaclr/pseudotime/dtw_alignment.py b/applications/dynaclr/src/dynaclr/pseudotime/dtw_alignment.py
new file mode 100644
index 000000000..aac210fe3
--- /dev/null
+++ b/applications/dynaclr/src/dynaclr/pseudotime/dtw_alignment.py
@@ -0,0 +1,940 @@
+"""DTW-based pseudotime alignment for cellular dynamics.
+
+Aligns cell trajectories to a template infection response using Dynamic
+Time Warping (DTW). The template is built from annotated transitioning
+cells via DBA (DTW Barycenter Averaging), then all cells are warped
+onto it to produce pseudotime values in [0, 1].
+
+Preprocessing pipeline: per-experiment z-score -> PCA -> L2-normalize -> DTW.
+"""
+
+from __future__ import annotations
+
+import logging
+import uuid
+from typing import NamedTuple
+
+import anndata as ad
+import numpy as np
+import pandas as pd
+from dtaidistance import dtw, dtw_ndim
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import normalize
+
+_logger = logging.getLogger(__name__)
+
+DEFAULT_POSITIVE_CLASSES: dict[str, str] = {
+    "infection_state": "infected",
+    "organelle_state": "remodel",
+}
+"""Default mapping of label column to positive class.
+
+Used when the caller does not pass ``positive_classes`` explicitly. Keys
+are ``obs`` columns whose entries are categorical strings; values are
+the entry that should be treated as the positive class for downstream
+binarization. Override per call when working with non-infection labels
+(e.g. 
``{"cell_division_state": "mitosis"}`` for ALFI). +""" + + +class TemplateResult(NamedTuple): + """Result of building an infection response template.""" + + template: np.ndarray + template_id: str + pca: PCA | None + zscore_params: dict[str, tuple[np.ndarray, np.ndarray]] + template_cell_ids: list[tuple[str, str, int]] + n_input_tracks: int + explained_variance: float | None + template_labels: dict[str, np.ndarray] | None # {col: (T,) fraction} per label column + time_calibration: np.ndarray | None = None # (T,) mean t_relative_minutes per template position + + +class AlignmentResult(NamedTuple): + """DTW alignment result for a single cell track. + + ``length_normalized_cost`` and ``path_skew`` are scalar gating + signals computed from the warp path. Per discussion §3.8, path_skew + is the primary gate (rejects degenerate non-diagonal warps) and + length_normalized_cost is the secondary gate (stereotypy filter). + Both are surfaced in the alignment parquet so robustness checks + can sweep gate thresholds without re-running DTW. + """ + + cell_uid: str + dataset_id: str + fov_name: str + track_id: int + timepoints: np.ndarray + pseudotime: np.ndarray + dtw_cost: float + length_normalized_cost: float + path_skew: float + warping_path: np.ndarray + warping_speed: np.ndarray + propagated_labels: dict[str, np.ndarray] | None # {col: (T,) fraction} per label column + alignment_region: np.ndarray # per-frame: "pre", "aligned", or "post" + + +def _zscore_embeddings( + embeddings_dict: dict[str, np.ndarray], +) -> tuple[dict[str, np.ndarray], dict[str, tuple[np.ndarray, np.ndarray]]]: + """Per-experiment z-score normalization. + + Parameters + ---------- + embeddings_dict : dict[str, np.ndarray] + {dataset_id: (N, D) embedding array}. + + Returns + ------- + tuple[dict[str, np.ndarray], dict[str, tuple[np.ndarray, np.ndarray]]] + Z-scored embeddings and per-experiment (mean, std) params. 
+ """ + zscored = {} + params = {} + for dataset_id, emb in embeddings_dict.items(): + mean = emb.mean(axis=0) + std = emb.std(axis=0) + std = np.where(std < 1e-10, 1.0, std) + zscored[dataset_id] = (emb - mean) / std + params[dataset_id] = (mean, std) + return zscored, params + + +def _preprocess_embeddings( + embeddings: np.ndarray, + pca: PCA | None = None, +) -> np.ndarray: + """PCA transform + L2 normalize. + + Parameters + ---------- + embeddings : np.ndarray + (N, D) array, already z-scored. + pca : PCA or None + Fitted PCA model. If None, skip dimensionality reduction. + + Returns + ------- + np.ndarray + (N, D') L2-normalized embeddings. + """ + if pca is not None: + embeddings = pca.transform(embeddings) + return normalize(embeddings, norm="l2", axis=1) + + +def _extract_track_trajectories( + adata: ad.AnnData, + df: pd.DataFrame, + min_track_timepoints: int = 3, + crop_window: int | None = None, + label_cols: list[str] | None = None, + positive_classes: dict[str, str] | None = None, +) -> list[tuple[str, int, np.ndarray, np.ndarray, dict[str, np.ndarray] | None]]: + """Extract per-track embedding trajectories from AnnData. + + Parameters + ---------- + adata : ad.AnnData + Embeddings with obs containing fov_name, track_id, t. + df : pd.DataFrame + Filtered tracking DataFrame (used for valid track selection). + Must have t_perturb column if crop_window is set. + min_track_timepoints : int + Minimum timepoints per track (applied after cropping). + crop_window : int or None + If set, crop each track to [t_perturb - crop_window, t_perturb + crop_window]. + Requires t_perturb column in df. None = use full track. + label_cols : list[str] or None + Label columns to extract (e.g., ["infection_state", "organelle_state"]). + Each is binarized using ``positive_classes``. + positive_classes : dict[str, str] or None + Mapping from label column name to its positive class value. + Required when ``label_cols`` is provided. 
+ + Returns + ------- + list[tuple[str, int, np.ndarray, np.ndarray, dict[str, np.ndarray] | None]] + Each element: (fov_name, track_id, embeddings (T, D), timepoints (T,), + labels {col: (T,)} or None). + """ + valid_tracks = df.groupby(["fov_name", "track_id"]).filter(lambda x: len(x) >= min_track_timepoints) + valid_keys = set(zip(valid_tracks["fov_name"], valid_tracks["track_id"])) + + # Build t_perturb lookup if cropping + t_perturb_lookup: dict[tuple[str, int], int] = {} + if crop_window is not None: + if "t_perturb" not in df.columns: + raise ValueError("crop_window requires t_perturb column in df") + for (fov, tid), grp in df.groupby(["fov_name", "track_id"]): + t_perturb_lookup[(fov, tid)] = int(grp["t_perturb"].iloc[0]) + + # Build label lookups per column + label_lookups: dict[str, dict[tuple, int]] = {} + if label_cols: + if positive_classes is None: + raise ValueError("positive_classes is required when label_cols is set") + for col in label_cols: + if col not in df.columns: + continue + if col not in positive_classes: + raise KeyError(f"positive_classes is missing entry for label column {col!r}") + positive_val = positive_classes[col] + lookup: dict[tuple, int] = {} + for _, row in df.iterrows(): + val = row[col] + if pd.notna(val) and val != "": + lookup[(row["fov_name"], row["track_id"], int(row["t"]))] = 1 if val == positive_val else 0 + label_lookups[col] = lookup + + obs = adata.obs.copy() + obs["_iloc"] = np.arange(len(obs)) + trajectories = [] + for (fov_name, track_id), group in obs.groupby(["fov_name", "track_id"]): + if (fov_name, track_id) not in valid_keys: + continue + sorted_group = group.sort_values("t") + + # Crop around t_perturb if requested + if crop_window is not None and (fov_name, track_id) in t_perturb_lookup: + tp = t_perturb_lookup[(fov_name, track_id)] + t_vals = sorted_group["t"].to_numpy() + mask = (t_vals >= tp - crop_window) & (t_vals <= tp + crop_window) + sorted_group = sorted_group.iloc[mask] + + if len(sorted_group) < 
min_track_timepoints: + continue + + iloc_indices = sorted_group["_iloc"].to_numpy() + emb = adata.X[iloc_indices] + if hasattr(emb, "toarray"): + emb = emb.toarray() + timepoints = sorted_group["t"].to_numpy().astype(int) + + labels = None + if label_lookups: + labels = {} + for col, lookup in label_lookups.items(): + labels[col] = np.array( + [lookup.get((fov_name, track_id, int(t)), 0) for t in timepoints], dtype=np.float64 + ) + + trajectories.append((str(fov_name), int(track_id), np.asarray(emb, dtype=np.float64), timepoints, labels)) + + return trajectories + + +def _dba( + sequences: list[np.ndarray], + max_iter: int = 30, + tol: float = 1e-5, + init: str = "medoid", + random_state: int = 42, +) -> np.ndarray: + """DTW Barycenter Averaging (DBA). + + Parameters + ---------- + sequences : list[np.ndarray] + List of (T_i, D) sequences. + max_iter : int + Maximum iterations. + tol : float + Convergence tolerance on mean absolute change. + init : str + Initialization method. "medoid" selects the sequence with + lowest total DTW cost to all others. + random_state : int + Seed for medoid candidate subsampling. Default 42. + + Returns + ------- + np.ndarray + (T_avg, D) template sequence. 
+ """ + if len(sequences) == 0: + raise ValueError("No sequences provided for DBA.") + + if init == "medoid": + n = len(sequences) + # Subsample for medoid if too many sequences (O(n²) DTW calls) + max_medoid_candidates = 50 + if n > max_medoid_candidates: + rng = np.random.default_rng(random_state) + candidate_idx = rng.choice(n, max_medoid_candidates, replace=False) + _logger.info("DBA medoid init: subsampling %d/%d candidates", max_medoid_candidates, n) + else: + candidate_idx = np.arange(n) + costs = np.zeros(len(candidate_idx)) + for ci, i in enumerate(candidate_idx): + for j in range(n): + if i != j: + costs[ci] += dtw_ndim.distance(sequences[i], sequences[j]) + avg = sequences[int(candidate_idx[np.argmin(costs)])].copy() + else: + avg = sequences[0].copy() + + for iteration in range(max_iter): + n_frames = avg.shape[0] + n_dims = avg.shape[1] + accum = np.zeros((n_frames, n_dims)) + counts = np.zeros(n_frames) + + for seq in sequences: + _, paths = dtw_ndim.warping_paths(avg, seq) + path = dtw.best_path(paths) + for idx_avg, idx_seq in path: + accum[idx_avg] += seq[idx_seq] + counts[idx_avg] += 1 + + counts = np.maximum(counts, 1) + new_avg = accum / counts[:, np.newaxis] + change = np.mean(np.abs(new_avg - avg)) + + _logger.debug(f"DBA iteration {iteration + 1}: mean change = {change:.6f}") + avg = new_avg + + if change < tol: + _logger.info(f"DBA converged at iteration {iteration + 1} (change={change:.2e})") + break + + return avg + + +def build_template( + adata_dict: dict[str, ad.AnnData], + aligned_df_dict: dict[str, pd.DataFrame], + pca_n_components: int | None = 20, + pca_variance_threshold: float | None = None, + dba_max_iter: int = 30, + dba_tol: float = 1e-5, + dba_init: str = "medoid", + control_adata_dict: dict[str, ad.AnnData] | None = None, + crop_window: int | dict[str, int] | None = None, + positive_classes: dict[str, str] | None = None, + random_state: int = 42, +) -> TemplateResult: + """Build a DTW pseudotime template from annotated 
single-cell trajectories. + + Generic over the underlying biology — what was previously called + ``build_infection_template`` works for any anchored event (infection + onset, mitotic entry, immune activation) provided the caller supplies + the appropriate label-to-positive-class mapping via ``positive_classes``. + + Parameters + ---------- + adata_dict : dict[str, ad.AnnData] + {dataset_id: adata} with embeddings for the cells used to build + the template. + aligned_df_dict : dict[str, pd.DataFrame] + {dataset_id: aligned_df} with t_perturb assigned. + pca_n_components : int or None + Number of PCA components. Ignored if pca_variance_threshold is set. + pca_variance_threshold : float or None + If set, auto-select components to explain this variance fraction. + dba_max_iter : int + Max DBA iterations. + dba_tol : float + DBA convergence tolerance. + dba_init : str + DBA initialization ("medoid"). + control_adata_dict : dict[str, ad.AnnData] | None + Control embeddings per dataset, included in PCA fitting. + crop_window : int or dict[str, int] or None + If set, crop each track to [t_perturb - crop_window, t_perturb + crop_window] + before DBA. Produces a shorter template centered on the anchored + event. Pass a dict to use per-dataset crop windows (e.g. when + datasets have different frame intervals and crop_window was + derived from a fixed duration in minutes). None = use full + tracks (variable length). + positive_classes : dict[str, str] or None + Mapping from label column name to its positive-class value. + Defaults to :data:`DEFAULT_POSITIVE_CLASSES` (infection labels). + Pass ``{"cell_division_state": "mitosis"}`` for ALFI division + templates. + random_state : int + Seed for reproducible PCA / medoid subsampling. Default 42. + + Returns + ------- + TemplateResult + Template array, PCA model, z-score params, and metadata. 
+ """ + if positive_classes is None: + positive_classes = DEFAULT_POSITIVE_CLASSES + raw_embeddings = {} + for dataset_id, adata in adata_dict.items(): + emb = adata.X + if hasattr(emb, "toarray"): + emb = emb.toarray() + raw_embeddings[dataset_id] = np.asarray(emb, dtype=np.float64) + + if control_adata_dict is not None: + for dataset_id, adata in control_adata_dict.items(): + ctrl_key = f"{dataset_id}__control" + emb = adata.X + if hasattr(emb, "toarray"): + emb = emb.toarray() + raw_embeddings[ctrl_key] = np.asarray(emb, dtype=np.float64) + + zscored, zscore_params = _zscore_embeddings(raw_embeddings) + + all_zscored = np.concatenate(list(zscored.values()), axis=0) + use_pca = pca_n_components is not None or pca_variance_threshold is not None + pca = None + explained_variance = None + + if use_pca: + if pca_variance_threshold is not None: + pca = PCA(n_components=pca_variance_threshold, svd_solver="full", random_state=random_state) + else: + n_comp = min(pca_n_components, all_zscored.shape[1], all_zscored.shape[0]) + pca = PCA(n_components=n_comp, random_state=random_state) + pca.fit(all_zscored) + explained_variance = float(np.sum(pca.explained_variance_ratio_)) + _logger.info(f"PCA: {pca.n_components_} components explain {explained_variance:.1%} variance") + + clean_zscore_params = {k: v for k, v in zscore_params.items() if "__control" not in k} + + trajectories = [] + track_labels: list[dict[str, np.ndarray] | None] = [] + track_t_rels: list[np.ndarray] = [] + cell_ids: list[tuple[str, str, int]] = [] + + # Detect which label columns are available across all datasets + label_cols = [col for col in positive_classes if any(col in df.columns for df in aligned_df_dict.values())] + label_cols_or_none = label_cols if label_cols else None + + for dataset_id, adata in adata_dict.items(): + df = aligned_df_dict[dataset_id] + ds_zscored_emb = zscored[dataset_id] + + zscored_adata = ad.AnnData(X=ds_zscored_emb, obs=adata.obs.copy()) + zscored_adata.obs.index = 
adata.obs.index + + # Build t_relative_minutes lookup for this dataset + t_rel_lookup: dict[tuple[str, int, int], float] = {} + if "t_relative_minutes" in df.columns: + for _, row in df.iterrows(): + t_rel_lookup[(str(row["fov_name"]), int(row["track_id"]), int(row["t"]))] = float( + row["t_relative_minutes"] + ) + + ds_crop_window = crop_window[dataset_id] if isinstance(crop_window, dict) else crop_window + tracks = _extract_track_trajectories( + zscored_adata, + df, + min_track_timepoints=1, + crop_window=ds_crop_window, + label_cols=label_cols_or_none, + positive_classes=positive_classes, + ) + for fov_name, track_id, emb, timepoints, labels in tracks: + processed = _preprocess_embeddings(emb, pca=pca) + trajectories.append(processed) + track_labels.append(labels) + cell_ids.append((dataset_id, fov_name, track_id)) + t_rel = np.array([t_rel_lookup.get((fov_name, track_id, int(t)), np.nan) for t in timepoints]) + track_t_rels.append(t_rel) + + if len(trajectories) == 0: + raise ValueError("No valid trajectories found for template building.") + + _logger.info(f"Building template from {len(trajectories)} trajectories") + template = _dba(trajectories, max_iter=dba_max_iter, tol=dba_tol, init=dba_init, random_state=random_state) + template = normalize(template, norm="l2", axis=1) + + # Compute template labels and time calibration via DTW alignment back to template. + # One DTW path per track; labels and t_relative_minutes mapped through the same path. 
+ n_template = template.shape[0] + template_labels = None + time_calibration = None + + has_labels = label_cols and all(lb is not None for lb in track_labels) + has_t_rel = any(np.any(np.isfinite(t)) for t in track_t_rels) + + if has_labels or has_t_rel: + label_sums = {col: np.zeros(n_template) for col in label_cols} if has_labels else {} + label_counts = {col: np.zeros(n_template) for col in label_cols} if has_labels else {} + time_sums = np.zeros(n_template) + time_counts = np.zeros(n_template) + + for seq, labels_dict, t_rel_arr in zip(trajectories, track_labels, track_t_rels): + _, paths = dtw_ndim.warping_paths(template, seq) + path = dtw.best_path(paths) + if has_labels and labels_dict is not None: + for col in label_cols: + if col not in labels_dict: + continue + col_labels = labels_dict[col] + for idx_template, idx_seq in path: + if idx_seq < len(col_labels): + label_sums[col][idx_template] += col_labels[idx_seq] + label_counts[col][idx_template] += 1 + for idx_template, idx_seq in path: + if idx_seq < len(t_rel_arr) and np.isfinite(t_rel_arr[idx_seq]): + time_sums[idx_template] += t_rel_arr[idx_seq] + time_counts[idx_template] += 1 + + if has_labels: + template_labels = {} + for col in label_cols: + counts = np.maximum(label_counts[col], 1) + template_labels[col] = label_sums[col] / counts + _logger.info( + "Template labels [%s]: %d positions, fraction range [%.2f, %.2f]", + col, + n_template, + template_labels[col].min(), + template_labels[col].max(), + ) + + if has_t_rel and time_counts.sum() > 0: + raw_cal = np.where(time_counts > 0, time_sums / np.maximum(time_counts, 1), np.nan) + # Interpolate any gaps linearly + positions = np.arange(n_template) + valid_mask = np.isfinite(raw_cal) + if valid_mask.sum() >= 2: + time_calibration = np.interp(positions, positions[valid_mask], raw_cal[valid_mask]) + elif valid_mask.sum() == 1: + time_calibration = np.full(n_template, raw_cal[valid_mask][0]) + _logger.info( + "Time calibration: %d positions, range [%.1f, 
def dtw_align_tracks(
    adata: ad.AnnData,
    df: pd.DataFrame,
    template_result: TemplateResult,
    dataset_id: str,
    min_track_timepoints: int = 3,
    psi: int | None = None,
    subsequence: bool = False,
) -> list[AlignmentResult]:
    """Align cell tracks to a template using DTW.

    Each track is z-scored (using the template's stored per-dataset
    parameters when available, otherwise its own statistics), projected
    through the template's PCA (if any), and warped onto the template
    with dtaidistance DTW. Pseudotime per frame is the matched template
    position normalized to [0, 1].

    Parameters
    ----------
    adata : ad.AnnData
        Embeddings with obs containing fov_name, track_id, t.
    df : pd.DataFrame
        Tracking DataFrame (optionally with t_perturb).
    template_result : TemplateResult
        Template from :func:`build_template`.
    dataset_id : str
        Identifier for this dataset.
    min_track_timepoints : int
        Minimum timepoints per track.
    psi : int or None
        Psi relaxation for DTW. If None, auto-computed:
        - subsequence=True: psi = max(track_len - template_len, 0)
        - subsequence=False: psi = template_len // 2
    subsequence : bool
        If True, use subsequence DTW: sweep the (short) template across
        the (long) cell track to find the best-matching segment.
        Frames before the matched region get pseudotime=0,
        frames after get pseudotime=1.
        Use this when the template was built with crop_window.

    Returns
    -------
    list[AlignmentResult]
        One result per aligned track.
    """
    emb = adata.X
    # AnnData.X may be a sparse matrix; densify before numeric work.
    if hasattr(emb, "toarray"):
        emb = emb.toarray()
    emb = np.asarray(emb, dtype=np.float64)

    # Prefer the z-score parameters frozen into the template so new data
    # lands in the same normalized space; fall back to self-statistics
    # for datasets the template has never seen.
    if dataset_id in template_result.zscore_params:
        mean, std = template_result.zscore_params[dataset_id]
    else:
        mean = emb.mean(axis=0)
        std = emb.std(axis=0)
        # Guard against zero-variance features to avoid division by ~0.
        std = np.where(std < 1e-10, 1.0, std)
    emb_zscored = (emb - mean) / std

    zscored_adata = ad.AnnData(X=emb_zscored, obs=adata.obs.copy())
    zscored_adata.obs.index = adata.obs.index

    tracks = _extract_track_trajectories(zscored_adata, df, min_track_timepoints)
    template = template_result.template
    t_template = template.shape[0]

    results = []
    for fov_name, track_id, track_emb, timepoints, _labels in tracks:
        processed = _preprocess_embeddings(track_emb, pca=template_result.pca)
        n_track = len(processed)

        # Compute psi (must be < min(template_len, track_len))
        max_psi = min(n_track - 1, t_template - 1)
        if psi is not None:
            track_psi = min(psi, max_psi)
        elif subsequence:
            # Allow template to float anywhere within the track
            track_psi = max_psi
        else:
            track_psi = min(t_template // 2, max_psi)

        _, paths = dtw_ndim.warping_paths(template, processed, psi=track_psi)
        path = dtw.best_path(paths)
        path_arr = np.array(path)

        # Raw accumulated DTW cost at the end of the best path.
        cost = paths[path_arr[-1, 0], path_arr[-1, 1]]

        # length-normalized cost: divide raw DTW cost by the warp path
        # length so longer matches don't accumulate more cost simply by
        # having more steps. This is the standard ranking signal for
        # subsequence DTW.
        if len(path_arr) > 0 and np.isfinite(cost):
            length_normalized_cost = float(cost) / float(len(path_arr))
        else:
            length_normalized_cost = float("inf")

        # Path skew: mean per-step deviation from the ideal diagonal in
        # the warp path's own coordinates. The ideal warp from
        # (0, 0) to (T-1, n-1) has slope (n-1)/(T-1); the diagonal at
        # warp-path step k is (template_step, query_step) =
        # (k * (T-1)/(K-1), k * (n-1)/(K-1)). Skew is the mean L1
        # normalized distance from each warp-path point to that ideal
        # diagonal point, divided by max(T, n) for [0, 1] scaling.
        if len(path_arr) >= 2 and t_template > 1 and n_track > 1:
            K = len(path_arr)
            ideal_t = np.linspace(path_arr[0, 0], path_arr[-1, 0], K)
            ideal_q = np.linspace(path_arr[0, 1], path_arr[-1, 1], K)
            dev = np.abs(path_arr[:, 0] - ideal_t) + np.abs(path_arr[:, 1] - ideal_q)
            denom = max(t_template, n_track)
            path_skew = float(dev.mean() / denom)
        else:
            path_skew = 0.0

        pseudotime = np.zeros(n_track)
        speed = np.zeros(n_track)
        alignment_region = np.full(n_track, "aligned", dtype=object)

        # Map each query frame to its template position
        # DTW path: (idx_template, idx_query) pairs
        # A query frame may appear multiple times; keep the last (highest) template position
        matched_template_pos = np.full(n_track, -1.0)
        for idx_template, idx_query in path:
            if idx_query < n_track:
                matched_template_pos[idx_query] = idx_template

        if subsequence and t_template > 1:
            # Find the matched region (query frames that got a template assignment)
            matched_mask = matched_template_pos >= 0
            if matched_mask.any():
                first_matched = np.argmax(matched_mask)
                last_matched = n_track - 1 - np.argmax(matched_mask[::-1])

                # Within matched region: pseudotime from template position
                for i in range(first_matched, last_matched + 1):
                    if matched_template_pos[i] >= 0:
                        pseudotime[i] = matched_template_pos[i] / (t_template - 1)

                # Forward-fill any gaps within the matched region
                for i in range(first_matched + 1, last_matched + 1):
                    if matched_template_pos[i] < 0:
                        pseudotime[i] = pseudotime[i - 1]

                # Before matched region: pseudotime = 0
                pseudotime[:first_matched] = 0.0
                # After matched region: pseudotime = 1
                pseudotime[last_matched + 1 :] = 1.0
                alignment_region[:first_matched] = "pre"
                alignment_region[last_matched + 1 :] = "post"
            else:
                pseudotime[:] = 0.0
                alignment_region[:] = "pre"
        elif t_template > 1:
            # Standard DTW: template position / (template_length - 1)
            template_positions = np.zeros(n_track)
            for idx_template, idx_query in path:
                if idx_query < n_track:
                    template_positions[idx_query] = idx_template
            pseudotime = template_positions / (t_template - 1)

        # Propagate template labels to cell frames via warping path
        propagated_labels = None
        if template_result.template_labels is not None:
            propagated_labels = {}
            for col, tl in template_result.template_labels.items():
                col_propagated = np.full(n_track, np.nan)
                for idx_template, idx_query in path:
                    if idx_query < n_track and idx_template < len(tl):
                        col_propagated[idx_query] = tl[idx_template]

                if subsequence:
                    # Outside the matched region the template provides no
                    # label: pre-region frames get 0.0, post-region 1.0.
                    matched_mask_lbl = matched_template_pos >= 0
                    if matched_mask_lbl.any():
                        first_m = np.argmax(matched_mask_lbl)
                        last_m = n_track - 1 - np.argmax(matched_mask_lbl[::-1])
                        for i in range(first_m + 1, last_m + 1):
                            if np.isnan(col_propagated[i]):
                                col_propagated[i] = col_propagated[i - 1]
                        col_propagated[:first_m] = 0.0
                        col_propagated[last_m + 1 :] = 1.0

                propagated_labels[col] = col_propagated

        # Compute warping speed (discrete derivative of pseudotime)
        for i in range(n_track):
            if i == 0:
                speed[i] = pseudotime[1] - pseudotime[0] if n_track > 1 else 0.0
            elif i == n_track - 1:
                speed[i] = pseudotime[i] - pseudotime[i - 1]
            else:
                speed[i] = (pseudotime[i + 1] - pseudotime[i - 1]) / 2

        cell_uid = f"{dataset_id}/{fov_name}/{track_id}"
        results.append(
            AlignmentResult(
                cell_uid=cell_uid,
                dataset_id=dataset_id,
                fov_name=fov_name,
                track_id=track_id,
                timepoints=timepoints,
                pseudotime=pseudotime,
                length_normalized_cost=length_normalized_cost,
                path_skew=path_skew,
                dtw_cost=float(cost),
                warping_path=path_arr,
                warping_speed=speed,
                propagated_labels=propagated_labels,
                alignment_region=alignment_region,
            )
        )

    _logger.info(f"Aligned {len(results)} tracks for dataset {dataset_id}")
    return results
def classify_response_groups(
    alignment_results: list[AlignmentResult] | pd.DataFrame,
    cost_percentile_threshold: float = 75.0,
    speed_clustering_method: str = "quantile",
    speed_quantile: float = 0.5,
) -> pd.DataFrame:
    """Classify aligned cells into response groups.

    Groups:
    - non_responder: DTW cost above percentile threshold
    - early_responder: responders with above-median mean warping speed
    - late_responder: responders with below-median mean warping speed

    Parameters
    ----------
    alignment_results : list[AlignmentResult] or pd.DataFrame
        Alignment results. If DataFrame, must have columns:
        cell_uid, dtw_cost, mean_warping_speed (or warping_speed).
    cost_percentile_threshold : float
        Percentile of DTW cost above which cells are non-responders.
    speed_clustering_method : str
        "quantile" or "kmeans" for splitting early/late.
    speed_quantile : float
        Quantile threshold for speed split (used when method="quantile").

    Returns
    -------
    pd.DataFrame
        One row per cell with columns: cell_uid, dataset_id,
        response_group, dtw_cost, mean_warping_speed. Empty input yields
        an empty DataFrame with the same columns.
    """
    # Canonical output schema, also used for the empty-input case.
    out_columns = ["cell_uid", "dataset_id", "response_group", "dtw_cost", "mean_warping_speed"]

    if isinstance(alignment_results, pd.DataFrame):
        df = alignment_results.copy()
        if "mean_warping_speed" not in df.columns and "warping_speed" in df.columns:
            df["mean_warping_speed"] = df.groupby("cell_uid")["warping_speed"].transform("mean")
        per_cell = df.groupby("cell_uid").first().reset_index()
        records = []
        for _, row in per_cell.iterrows():
            records.append(
                {
                    "cell_uid": row["cell_uid"],
                    "dataset_id": row.get("dataset_id", ""),
                    "dtw_cost": row["dtw_cost"],
                    "mean_warping_speed": row["mean_warping_speed"],
                }
            )
    else:
        records = []
        for r in alignment_results:
            records.append(
                {
                    "cell_uid": r.cell_uid,
                    "dataset_id": r.dataset_id,
                    "dtw_cost": r.dtw_cost,
                    "mean_warping_speed": float(np.mean(np.abs(r.warping_speed))),
                }
            )

    df = pd.DataFrame(records)
    if len(df) == 0:
        # Fix: return the documented schema so downstream merges on
        # cell_uid / response_group do not KeyError on empty input.
        return pd.DataFrame(columns=out_columns)

    cost_threshold = np.percentile(df["dtw_cost"], cost_percentile_threshold)
    df["response_group"] = "non_responder"

    responder_mask = df["dtw_cost"] <= cost_threshold
    responders = df[responder_mask]

    if len(responders) > 0:
        if speed_clustering_method == "quantile":
            speed_threshold = responders["mean_warping_speed"].quantile(speed_quantile)
            df.loc[responder_mask & (df["mean_warping_speed"] >= speed_threshold), "response_group"] = "early_responder"
            df.loc[responder_mask & (df["mean_warping_speed"] < speed_threshold), "response_group"] = "late_responder"
        elif speed_clustering_method == "kmeans":
            # Local import: sklearn clustering is only needed on this path.
            from sklearn.cluster import KMeans

            speeds = responders["mean_warping_speed"].to_numpy().reshape(-1, 1)
            if len(speeds) >= 2:
                km = KMeans(n_clusters=2, random_state=42, n_init=10)
                labels = km.fit_predict(speeds)
                # The cluster with the larger mean speed is "early".
                cluster_means = [speeds[labels == c].mean() for c in range(2)]
                fast_cluster = int(np.argmax(cluster_means))
                resp_indices = responders.index
                for idx, label in zip(resp_indices, labels):
                    if label == fast_cluster:
                        df.loc[idx, "response_group"] = "early_responder"
                    else:
                        df.loc[idx, "response_group"] = "late_responder"
            else:
                # A single responder cannot be clustered; call it early.
                df.loc[responder_mask, "response_group"] = "early_responder"

    _logger.info(
        f"Classification: {(df['response_group'] == 'early_responder').sum()} early, "
        f"{(df['response_group'] == 'late_responder').sum()} late, "
        f"{(df['response_group'] == 'non_responder').sum()} non-responder"
    )

    return df[out_columns]
"early_responder" + else: + df.loc[idx, "response_group"] = "late_responder" + else: + df.loc[responder_mask, "response_group"] = "early_responder" + + _logger.info( + f"Classification: {(df['response_group'] == 'early_responder').sum()} early, " + f"{(df['response_group'] == 'late_responder').sum()} late, " + f"{(df['response_group'] == 'non_responder').sum()} non-responder" + ) + + return df[["cell_uid", "dataset_id", "response_group", "dtw_cost", "mean_warping_speed"]] + + +def alignment_results_to_dataframe( + results: list[AlignmentResult], + template_id: str, + time_calibration: np.ndarray | None = None, +) -> pd.DataFrame: + """Flatten alignment results into a DataFrame (one row per timepoint). + + Parameters + ---------- + results : list[AlignmentResult] + Output of dtw_align_tracks. + template_id : str + Template UUID to attach. + time_calibration : np.ndarray or None + (T_template,) array mapping template position to mean t_relative_minutes. + If provided, adds an ``estimated_t_rel_minutes`` column. + + Returns + ------- + pd.DataFrame + Columns: cell_uid, dataset_id, fov_name, track_id, t, + pseudotime, dtw_cost, warping_speed, template_id, + plus propagated_{label}_label for each label column, + plus estimated_t_rel_minutes if time_calibration is provided. 
+ """ + rows = [] + for r in results: + for i, t in enumerate(r.timepoints): + row = { + "cell_uid": r.cell_uid, + "dataset_id": r.dataset_id, + "fov_name": r.fov_name, + "track_id": r.track_id, + "t": int(t), + "pseudotime": float(r.pseudotime[i]), + "dtw_cost": r.dtw_cost, + "length_normalized_cost": float(r.length_normalized_cost), + "path_skew": float(r.path_skew), + "warping_speed": float(r.warping_speed[i]), + "alignment_region": r.alignment_region[i], + "template_id": template_id, + } + if r.propagated_labels is not None: + for col, arr in r.propagated_labels.items(): + col_clean = col.replace("_state", "") + row[f"propagated_{col_clean}_label"] = float(arr[i]) + rows.append(row) + df = pd.DataFrame(rows) + if time_calibration is not None and len(df) > 0: + T = len(time_calibration) + df["estimated_t_rel_minutes"] = np.interp( + df["pseudotime"].to_numpy() * (T - 1), + np.arange(T), + time_calibration, + ) + return df + + +def extract_dtw_pseudotime( + adata: ad.AnnData, + df: pd.DataFrame, + template_result: TemplateResult, + dataset_id: str, + min_track_timepoints: int = 3, + cost_percentile_threshold: float = 75.0, + speed_clustering_method: str = "quantile", + speed_quantile: float = 0.5, + psi: int | None = None, +) -> pd.DataFrame: + """Run align + classify + flatten in one call (convenience wrapper). + + Parameters + ---------- + adata : ad.AnnData + Embeddings AnnData. + df : pd.DataFrame + Tracking DataFrame. + template_result : TemplateResult + Built template. + dataset_id : str + Dataset identifier. + min_track_timepoints : int + Minimum timepoints per track. + cost_percentile_threshold : float + Non-responder cost threshold percentile. + speed_clustering_method : str + "quantile" or "kmeans". + speed_quantile : float + Speed split quantile. + + Returns + ------- + pd.DataFrame + Flat DataFrame with pseudotime renamed to "signal" for metrics + compatibility, plus dtw_cost, warping_speed, response_group columns. 
+ """ + results = dtw_align_tracks(adata, df, template_result, dataset_id, min_track_timepoints, psi=psi) + flat = alignment_results_to_dataframe( + results, template_result.template_id, time_calibration=template_result.time_calibration + ) + classifications = classify_response_groups( + results, + cost_percentile_threshold=cost_percentile_threshold, + speed_clustering_method=speed_clustering_method, + speed_quantile=speed_quantile, + ) + merged = flat.merge(classifications[["cell_uid", "response_group"]], on="cell_uid", how="left") + merged = merged.rename(columns={"pseudotime": "signal"}) + return merged diff --git a/applications/dynaclr/src/dynaclr/pseudotime/evaluation.py b/applications/dynaclr/src/dynaclr/pseudotime/evaluation.py new file mode 100644 index 000000000..916606c7c --- /dev/null +++ b/applications/dynaclr/src/dynaclr/pseudotime/evaluation.py @@ -0,0 +1,295 @@ +"""Evaluation of DTW pseudotime against ground truth annotations. + +Compares DTW-derived pseudotime with annotated infection_state and +organelle_state to quantify alignment quality. Designed to run across +multiple embedding types for comparison. +""" + +from __future__ import annotations + +import logging + +import numpy as np +import pandas as pd +from scipy.stats import spearmanr +from sklearn.metrics import average_precision_score, roc_auc_score + +_logger = logging.getLogger(__name__) + + +def pseudotime_vs_annotation_auc( + df: pd.DataFrame, + pseudotime_col: str = "pseudotime", + annotation_col: str = "infection_state", + positive_value: str = "infected", +) -> float: + """ROC-AUC of pseudotime predicting a binary annotation. + + Parameters + ---------- + df : pd.DataFrame + Must have pseudotime_col and annotation_col columns. + pseudotime_col : str + Column with DTW pseudotime values. + annotation_col : str + Column with ground truth annotation. + positive_value : str + Value in annotation_col that is the positive class. 
+ + Returns + ------- + float + ROC-AUC score, or NaN if not computable. + """ + valid = df.dropna(subset=[pseudotime_col, annotation_col]) + valid = valid[valid[annotation_col] != ""] + if len(valid) == 0: + return np.nan + + y_true = (valid[annotation_col] == positive_value).astype(int).to_numpy() + y_score = valid[pseudotime_col].to_numpy() + + if len(np.unique(y_true)) < 2: + return np.nan + + return float(roc_auc_score(y_true, y_score)) + + +def onset_concordance( + df: pd.DataFrame, + pseudotime_col: str = "pseudotime", + annotation_col: str = "infection_state", + positive_value: str = "infected", + min_track_timepoints: int = 3, +) -> tuple[float, int]: + """Spearman correlation between DTW-derived and annotation-derived onset times. + + For each track, onset is defined as the first timepoint where the signal + transitions to positive. Computes correlation across all tracks that have + a detectable onset in both DTW pseudotime and annotations. + + Parameters + ---------- + df : pd.DataFrame + Must have pseudotime_col, annotation_col, fov_name, track_id, t columns. + pseudotime_col : str + Column with DTW pseudotime values. + annotation_col : str + Column with ground truth annotation. + positive_value : str + Positive value in annotation_col. + min_track_timepoints : int + Minimum timepoints per track to include. + + Returns + ------- + tuple[float, int] + (Spearman rho, n_tracks) or (NaN, 0) if not computable. 
+ """ + valid = df.dropna(subset=[pseudotime_col, annotation_col]) + valid = valid[valid[annotation_col] != ""] + + dtw_onsets = [] + ann_onsets = [] + + for (fov, tid), track in valid.groupby(["fov_name", "track_id"]): + if len(track) < min_track_timepoints: + continue + track = track.sort_values("t") + + # Annotation onset: first timepoint with positive value + ann_positive = track[track[annotation_col] == positive_value] + if len(ann_positive) == 0: + continue + ann_onset_t = ann_positive["t"].iloc[0] + + # DTW onset: first timepoint where pseudotime exceeds median of track + pt = track[pseudotime_col].to_numpy() + threshold = np.median(pt) + above = track[track[pseudotime_col] > threshold] + if len(above) == 0: + continue + dtw_onset_t = above["t"].iloc[0] + + dtw_onsets.append(dtw_onset_t) + ann_onsets.append(ann_onset_t) + + if len(dtw_onsets) < 3: + return np.nan, len(dtw_onsets) + + rho, _ = spearmanr(dtw_onsets, ann_onsets) + return float(rho), len(dtw_onsets) + + +def per_timepoint_auc( + df: pd.DataFrame, + pseudotime_col: str = "pseudotime", + annotation_col: str = "infection_state", + positive_value: str = "infected", + time_col: str = "t", +) -> pd.DataFrame: + """ROC-AUC of pseudotime predicting annotation at each timepoint. + + Parameters + ---------- + df : pd.DataFrame + Must have pseudotime_col, annotation_col, time_col columns. + pseudotime_col : str + Column with DTW pseudotime values. + annotation_col : str + Column with ground truth annotation. + positive_value : str + Positive value in annotation_col. + time_col : str + Timepoint column. + + Returns + ------- + pd.DataFrame + Columns: t, auc, n_cells, n_positive. 
+ """ + valid = df.dropna(subset=[pseudotime_col, annotation_col]) + valid = valid[valid[annotation_col] != ""] + + rows = [] + for t_val, group in valid.groupby(time_col): + y_true = (group[annotation_col] == positive_value).astype(int).to_numpy() + y_score = group[pseudotime_col].to_numpy() + n_pos = int(y_true.sum()) + + if len(np.unique(y_true)) < 2: + auc = np.nan + else: + auc = float(roc_auc_score(y_true, y_score)) + + rows.append({"t": t_val, "auc": auc, "n_cells": len(group), "n_positive": n_pos}) + + return pd.DataFrame(rows) + + +def _pseudotime_ap( + df: pd.DataFrame, + pseudotime_col: str = "pseudotime", + annotation_col: str = "infection_state", + positive_value: str = "infected", +) -> float: + """Average precision (AUPRC) of pseudotime predicting a binary annotation. + + Parameters + ---------- + df : pd.DataFrame + Must have pseudotime_col and annotation_col columns. + pseudotime_col : str + Column with DTW pseudotime values. + annotation_col : str + Column with ground truth annotation. + positive_value : str + Value in annotation_col that is the positive class. + + Returns + ------- + float + Average precision score, or NaN if not computable. + """ + valid = df.dropna(subset=[pseudotime_col, annotation_col]) + valid = valid[valid[annotation_col] != ""] + if len(valid) == 0: + return np.nan + + y_true = (valid[annotation_col] == positive_value).astype(int).to_numpy() + y_score = valid[pseudotime_col].to_numpy() + + if len(np.unique(y_true)) < 2: + return np.nan + + return float(average_precision_score(y_true, y_score)) + + +def evaluate_embedding( + alignments: pd.DataFrame, + annotations: pd.DataFrame, + embedding_name: str, + dataset_id: str, +) -> dict: + """Run full evaluation suite for one embedding × dataset. + + Parameters + ---------- + alignments : pd.DataFrame + Output of alignment_results_to_dataframe (has pseudotime, fov_name, + track_id, t columns). 
def evaluate_embedding(
    alignments: pd.DataFrame,
    annotations: pd.DataFrame,
    embedding_name: str,
    dataset_id: str,
) -> dict:
    """Run full evaluation suite for one embedding × dataset.

    Joins alignments to annotations on (fov_name, track_id, t), then
    computes AUC/AP and onset concordance against both infection_state
    and organelle_state.

    Parameters
    ----------
    alignments : pd.DataFrame
        Output of alignment_results_to_dataframe (has pseudotime, fov_name,
        track_id, t columns).
    annotations : pd.DataFrame
        Annotation CSV with fov_name, track_id, t, infection_state,
        organelle_state columns.
    embedding_name : str
        Name of the embedding (e.g., "sensor", "organelle", "phase").
    dataset_id : str
        Dataset identifier.

    Returns
    -------
    dict
        Summary metrics for this embedding × dataset.
    """
    # Merge alignments with annotations (left join: keep every aligned
    # frame; unannotated frames get NaN labels and are dropped later by
    # the individual metrics).
    merge_keys = ["fov_name", "track_id", "t"]
    merged = alignments.merge(
        annotations[merge_keys + ["infection_state", "organelle_state"]], on=merge_keys, how="left"
    )

    result = {
        "embedding": embedding_name,
        "dataset_id": dataset_id,
        "n_cells": len(merged),
        # ngroup().nunique() counts distinct (fov_name, track_id) pairs.
        "n_tracks": merged.groupby(["fov_name", "track_id"]).ngroup().nunique(),
    }

    # Infection state AUC + AP
    result["infection_auc"] = pseudotime_vs_annotation_auc(
        merged, pseudotime_col="pseudotime", annotation_col="infection_state", positive_value="infected"
    )
    result["infection_ap"] = _pseudotime_ap(
        merged, pseudotime_col="pseudotime", annotation_col="infection_state", positive_value="infected"
    )

    # Organelle state AUC + AP
    result["organelle_auc"] = pseudotime_vs_annotation_auc(
        merged, pseudotime_col="pseudotime", annotation_col="organelle_state", positive_value="remodel"
    )
    result["organelle_ap"] = _pseudotime_ap(
        merged, pseudotime_col="pseudotime", annotation_col="organelle_state", positive_value="remodel"
    )

    # Onset concordance (infection)
    rho, n_tracks = onset_concordance(
        merged, pseudotime_col="pseudotime", annotation_col="infection_state", positive_value="infected"
    )
    result["infection_onset_spearman"] = rho
    result["infection_onset_n_tracks"] = n_tracks

    # Onset concordance (organelle)
    rho_org, n_tracks_org = onset_concordance(
        merged, pseudotime_col="pseudotime", annotation_col="organelle_state", positive_value="remodel"
    )
    result["organelle_onset_spearman"] = rho_org
    result["organelle_onset_n_tracks"] = n_tracks_org

    # Mean DTW cost (per track, not per frame: costs repeat per timepoint)
    if "dtw_cost" in alignments.columns:
        per_track_cost = alignments.groupby(["fov_name", "track_id"])["dtw_cost"].first()
        result["mean_dtw_cost"] = float(per_track_cost.mean())
        result["median_dtw_cost"] = float(per_track_cost.median())

    _logger.info(
        "%s/%s: infection_auc=%.3f ap=%.3f, organelle_auc=%.3f ap=%.3f, onset_rho=%.3f (%d tracks)",
        embedding_name,
        dataset_id,
        result.get("infection_auc", np.nan),
        result.get("infection_ap", np.nan),
        result.get("organelle_auc", np.nan),
        result.get("organelle_ap", np.nan),
        result.get("infection_onset_spearman", np.nan),
        result.get("infection_onset_n_tracks", 0),
    )

    return result
+""" + +from __future__ import annotations + +import glob +import importlib.metadata as _metadata +import os +import subprocess +from pathlib import Path + +import numpy as np +import zarr +from sklearn.decomposition import PCA + +from dynaclr.pseudotime.dtw_alignment import TemplateResult + + +def date_prefix_from_dataset_id(dataset_id: str) -> str: + """Extract a leading ``YYYY_MM_DD_`` prefix from ``dataset_id``. + + Many embedding zarrs are named with a date prefix derived from the + experiment id. This helper recovers the prefix used to glob for the + embedding file under the dataset's ``pred_dir``. + + Parameters + ---------- + dataset_id : str + Dataset identifier such as ``2024_07_24_A549_ZIKV_SEC61``. + + Returns + ------- + str + Prefix including the trailing underscore, or an empty string if + the id has fewer than three underscore-separated parts. + """ + parts = dataset_id.split("_") + if len(parts) < 3: + return "" + return "_".join(parts[:3]) + "_" + + +def find_embedding_zarr(pred_dir: str | Path, pattern: str) -> str: + """Find the single embedding zarr matching ``pattern`` in ``pred_dir``. + + Parameters + ---------- + pred_dir : str or Path + Directory containing the per-dataset embedding zarrs. + pattern : str + Glob pattern (typically ``date_prefix + embedding_pattern``). + + Returns + ------- + str + Absolute path to the matching zarr. + + Raises + ------ + FileNotFoundError + If zero or more than one zarr matches the pattern. + """ + matches = glob.glob(str(Path(pred_dir) / pattern)) + if len(matches) == 0: + raise FileNotFoundError(f"No zarr matching {pattern} in {pred_dir}") + if len(matches) > 1: + names = sorted(Path(m).name for m in matches) + raise FileNotFoundError(f"Multiple zarrs match {pattern}: {names}") + return matches[0] + + +def get_dynaclr_versions() -> dict[str, str]: + """Return a dict of code/library versions for template provenance. 
+ + Captured fields: + + - ``viscy_git_sha``: short SHA of the current repo HEAD, or + ``"unknown"`` if the repo is unavailable. + - ``dtaidistance_version``: installed dtaidistance package version. + - ``sklearn_version``: installed scikit-learn version. + - ``numpy_version``: installed numpy version. + + Stamping these into every template zarr is what lets a future + consumer reproduce or invalidate a published template after the + embedding model or library stack moves. + """ + sha = "unknown" + try: + result = subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + capture_output=True, + text=True, + cwd=os.path.dirname(__file__), + check=False, + timeout=2, + ) + if result.returncode == 0: + sha = result.stdout.strip() + except (OSError, subprocess.SubprocessError): + pass + + versions = {"viscy_git_sha": sha} + for pkg in ("dtaidistance", "scikit-learn", "numpy"): + try: + versions[f"{pkg.replace('-', '_')}_version"] = _metadata.version(pkg) + except _metadata.PackageNotFoundError: + versions[f"{pkg.replace('-', '_')}_version"] = "unknown" + return versions + + +def compute_tau_event_band( + template: np.ndarray, + threshold_fraction: float = 0.5, +) -> tuple[float, float]: + """Compute the half-rise band of a template's first-derivative magnitude. + + The template's "event" is the region of fastest change. We compute + the per-position rate of change as the L2 norm of consecutive + template differences, then return the pseudotime band where the + rate exceeds ``threshold_fraction`` of its maximum. + + Per discussion §3.4 and the locked execution plan: τ_event is a + band, not a point, because (1) the template's argmax-derivative has + a resolution floor of 1/n_frames, and (2) DBA averages flatten + cell-specific kinks so the template derivative is structurally + smoother than any individual cell's. The band honestly reflects + template-derivative resolution. 
+ + Parameters + ---------- + template : np.ndarray + DBA template, shape ``(T, D)`` where T is the number of + template positions and D the embedding dimension. + threshold_fraction : float + Fraction of the maximum derivative magnitude that defines the + band edges. Default 0.5 (half-rise band). + + Returns + ------- + tuple[float, float] + ``(τ_lo, τ_hi)`` in pseudotime ∈ [0, 1]. If the template has + fewer than two positions or the derivative is degenerate, + returns ``(0.0, 1.0)``. + """ + if template.ndim != 2 or template.shape[0] < 2: + return (0.0, 1.0) + + diffs = np.diff(template, axis=0) + rate = np.linalg.norm(diffs, axis=1) # shape (T-1,) + if rate.size == 0 or float(rate.max()) <= 0: + return (0.0, 1.0) + + threshold = threshold_fraction * float(rate.max()) + above = rate >= threshold + indices = np.where(above)[0] + if indices.size == 0: + return (0.0, 1.0) + + # Map derivative-position indices to pseudotime midpoints. + # rate[i] reports the change from template[i] to template[i+1], so + # the rate's natural pseudotime is the midpoint between positions i + # and i+1: τ = (i + 0.5) / (T - 1). 
+ n_positions = template.shape[0] + tau_lo = float(indices.min() + 0.5) / float(n_positions - 1) + tau_hi = float(indices.max() + 0.5) / float(n_positions - 1) + return (tau_lo, tau_hi) + + +def _save_flavor(group, result: TemplateResult, flavor_name: str) -> None: + """Serialize one ``TemplateResult`` flavor into a zarr group.""" + group.create_array("template", data=result.template) + if result.time_calibration is not None: + group.create_array("time_calibration", data=result.time_calibration) + if result.template_labels is not None: + labels_grp = group.create_group("template_labels") + for col, fractions in result.template_labels.items(): + labels_grp.create_array(col, data=fractions) + if result.pca is not None: + group.create_array("components", data=result.pca.components_) + group.create_array("mean", data=result.pca.mean_) + group.create_array("explained_variance", data=result.pca.explained_variance_) + group.create_array("explained_variance_ratio", data=result.pca.explained_variance_ratio_) + group.attrs["n_components"] = int(result.pca.n_components_) + group.attrs["explained_variance"] = float(result.explained_variance or 0.0) + + # τ_event band: half-rise of the template-derivative magnitude. + # Stored per-flavor because raw and PCA templates have different + # geometries and may yield slightly different bands. 
+ tau_lo, tau_hi = compute_tau_event_band(result.template) + group.create_array("tau_event_band", data=np.array([tau_lo, tau_hi], dtype=np.float64)) + + group.attrs["template_id"] = result.template_id + group.attrs["n_input_tracks"] = int(result.n_input_tracks) + + +def save_template_zarr( + out_path: str | Path, + raw_result: TemplateResult, + pca_result: TemplateResult, + *, + template_name: str, + config_snapshot: dict, + anchor_label: str, + anchor_positive: str, + aggregator: str, + t_key_event_per_cell: np.ndarray, + build_frame_intervals_minutes: dict[str, float], + template_duration_minutes: float, + extra_attrs: dict | None = None, +) -> None: + """Serialize both template flavors + shared metadata into a single zarr. + + Provenance fields (``viscy_git_sha``, ``dtaidistance_version``, …) + are stamped automatically via :func:`get_dynaclr_versions`. + + Parameters + ---------- + out_path : str or Path + Destination zarr directory. Will be created/overwritten. + raw_result, pca_result : TemplateResult + The two template flavors to serialize. + template_name : str + Identifier used downstream to name the alignment parquet. + config_snapshot : dict + Full config under which this template was built. Stored verbatim. + anchor_label, anchor_positive : str + Label column and its positive value (e.g. + ``("infection_state", "infected")``). + aggregator : str + Aggregator name (currently always ``"dba"``). + t_key_event_per_cell : np.ndarray + Per-cell event timepoints in the same order as + ``raw_result.template_cell_ids``. + build_frame_intervals_minutes : dict[str, float] + Per-dataset frame intervals so consumers can apply minute-based + guards without guessing the template's time scale. + template_duration_minutes : float + Duration of the calibrated template in real minutes. + extra_attrs : dict or None + Additional attrs to merge into the store. Useful for + method-specific metadata not covered by the shared schema. 
+ """ + store = zarr.open(str(out_path), mode="w") + + _save_flavor(store.create_group("raw"), raw_result, "raw") + _save_flavor(store.create_group("pca"), pca_result, "pca") + + if raw_result.zscore_params: + zgrp = store.create_group("zscore_params") + for ds_id, (mean, std) in raw_result.zscore_params.items(): + ds_grp = zgrp.create_group(ds_id) + ds_grp.create_array("mean", data=mean) + ds_grp.create_array("std", data=std) + + store.create_array("t_key_event", data=np.asarray(t_key_event_per_cell)) + + store.attrs["template_name"] = template_name + store.attrs["template_cell_ids"] = [list(c) for c in raw_result.template_cell_ids] + store.attrs["anchor_label"] = anchor_label + store.attrs["anchor_positive"] = anchor_positive + store.attrs["aggregator"] = aggregator + store.attrs["template_duration_minutes"] = float(template_duration_minutes) + store.attrs["build_frame_intervals_minutes"] = {k: float(v) for k, v in build_frame_intervals_minutes.items()} + store.attrs["config_snapshot"] = config_snapshot + + versions = get_dynaclr_versions() + for k, v in versions.items(): + store.attrs[k] = v + + if extra_attrs: + for k, v in extra_attrs.items(): + store.attrs[k] = v + + +def load_template_flavor(template_path: str | Path, flavor: str) -> tuple[TemplateResult, dict]: + """Load one flavor from a two-flavor template zarr. + + Parameters + ---------- + template_path : str or Path + Path to the template zarr written by :func:`save_template_zarr`. + flavor : {"raw", "pca"} + Which flavor to materialize. + + Returns + ------- + (TemplateResult, dict) + The selected flavor's :class:`TemplateResult` (template + PCA + for the ``"pca"`` flavor + shared z-score params), and the raw + attrs dict for anything else the caller needs. + + Raises + ------ + ValueError + If ``flavor`` is not ``"raw"`` or ``"pca"``. + KeyError + If the requested flavor is not present in the store. 
+ """ + store = zarr.open(str(template_path), mode="r") + attrs = dict(store.attrs) + + if flavor not in ("raw", "pca"): + raise ValueError(f"flavor must be 'raw' or 'pca', got {flavor!r}") + if flavor not in store: + raise KeyError(f"Flavor {flavor!r} not in template zarr {template_path}") + grp = store[flavor] + + template = np.asarray(grp["template"]) + time_calibration = np.asarray(grp["time_calibration"]) if "time_calibration" in grp else None + + template_labels = None + if "template_labels" in grp: + tl_grp = grp["template_labels"] + template_labels = {col: np.asarray(tl_grp[col]) for col in tl_grp} + + pca = None + if flavor == "pca" and "components" in grp: + n_comp = int(grp.attrs["n_components"]) + pca = PCA(n_components=n_comp) + pca.components_ = np.asarray(grp["components"]) + pca.mean_ = np.asarray(grp["mean"]) + if "explained_variance" in grp: + pca.explained_variance_ = np.asarray(grp["explained_variance"]) + pca.explained_variance_ratio_ = np.asarray(grp["explained_variance_ratio"]) + pca.n_components_ = n_comp + pca.n_features_in_ = pca.components_.shape[1] + + zscore_params: dict[str, tuple[np.ndarray, np.ndarray]] = {} + if "zscore_params" in store: + zgrp = store["zscore_params"] + for ds_id in zgrp: + zscore_params[ds_id] = ( + np.asarray(zgrp[ds_id]["mean"]), + np.asarray(zgrp[ds_id]["std"]), + ) + + template_id = str(grp.attrs.get("template_id", "")) + n_input_tracks = int(grp.attrs.get("n_input_tracks", 0)) + cell_ids = [tuple(c) for c in attrs.get("template_cell_ids", [])] + + result = TemplateResult( + template=template, + template_id=template_id, + pca=pca, + zscore_params=zscore_params, + template_cell_ids=cell_ids, + n_input_tracks=n_input_tracks, + explained_variance=float(grp.attrs.get("explained_variance", 0.0)) or None, + template_labels=template_labels, + time_calibration=time_calibration, + ) + return result, attrs + + +def read_template_attrs(template_path: str | Path) -> dict: + """Read the top-level attrs of a template zarr. 
+
+    Convenience wrapper for the common case where a downstream script
+    only needs the metadata (config snapshot, anchor label, version
+    stamps) and not the template arrays themselves.
+    """
+    return dict(zarr.open(str(template_path), mode="r").attrs)
+
+
+def read_time_calibration(template_path: str | Path, flavor: str) -> np.ndarray:
+    """Read the per-position ``time_calibration`` array for a flavor.
+
+    Returns the array of mean ``t_relative_minutes`` per template
+    position, used for converting DTW pseudotime back to real minutes
+    in downstream timing analyses.
+
+    Parameters
+    ----------
+    template_path : str or Path
+        Path to the template zarr.
+    flavor : {"raw", "pca"}
+        Which flavor's calibration to read.
+
+    Raises
+    ------
+    KeyError
+        If ``time_calibration`` is missing; ``ValueError`` if ``flavor`` is invalid.
+    """
+    if flavor not in ("raw", "pca"):
+        raise ValueError(f"flavor must be 'raw' or 'pca', got {flavor!r}")
+    grp = zarr.open(str(template_path), mode="r")[flavor]
+    if "time_calibration" not in grp:
+        raise KeyError(f"time_calibration missing for flavor {flavor!r} in {template_path}")
+    return np.asarray(grp["time_calibration"])
+
+
+def read_tau_event_band(template_path: str | Path, flavor: str) -> tuple[float, float]:
+    """Read the τ_event band ``[τ_lo, τ_hi]`` from a template flavor.
+
+    The band is computed at template-build time (see
+    :func:`compute_tau_event_band`) and stored alongside the template
+    arrays. Returns ``(0.0, 1.0)`` if the band array is missing
+    (templates built before the band feature was added).
+
+    Parameters
+    ----------
+    template_path : str or Path
+        Path to the template zarr.
+    flavor : {"raw", "pca"}
+        Which flavor's band to read.
+ """ + if flavor not in ("raw", "pca"): + raise ValueError(f"flavor must be 'raw' or 'pca', got {flavor!r}") + grp = zarr.open(str(template_path), mode="r")[flavor] + if "tau_event_band" not in grp: + return (0.0, 1.0) + band = np.asarray(grp["tau_event_band"]) + return (float(band[0]), float(band[1])) diff --git a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/metrics.py b/applications/dynaclr/src/dynaclr/pseudotime/metrics.py similarity index 98% rename from applications/dynaclr/src/dynaclr/evaluation/pseudotime/metrics.py rename to applications/dynaclr/src/dynaclr/pseudotime/metrics.py index 54b74777e..b6fb32466 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/metrics.py +++ b/applications/dynaclr/src/dynaclr/pseudotime/metrics.py @@ -110,7 +110,7 @@ def aggregate_population( } ) else: - vals = bin_data[signal_col].values + vals = bin_data[signal_col].to_numpy() results.append( { "time_minutes": bin_start, @@ -280,8 +280,8 @@ def find_peak_metrics( # AUC (area under curve from baseline) valid_mask = post_infection[signal_col].notna() if valid_mask.sum() > 1: - times = post_infection.loc[valid_mask, "time_minutes"].values - values = post_infection.loc[valid_mask, signal_col].values - baseline_mean + times = post_infection.loc[valid_mask, "time_minutes"].to_numpy() + values = post_infection.loc[valid_mask, signal_col].to_numpy() - baseline_mean auc = float(np.trapezoid(values, times)) else: auc = np.nan diff --git a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/plotting.py b/applications/dynaclr/src/dynaclr/pseudotime/plotting.py similarity index 97% rename from applications/dynaclr/src/dynaclr/evaluation/pseudotime/plotting.py rename to applications/dynaclr/src/dynaclr/pseudotime/plotting.py index 47d191b3d..5ff75b805 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/plotting.py +++ b/applications/dynaclr/src/dynaclr/pseudotime/plotting.py @@ -187,12 +187,12 @@ def plot_cell_heatmap( fig, ax = 
plt.subplots(figsize=(14, max(4, len(pivot) * 0.06))) - bin_centers = pivot.columns.values + bin_centers = pivot.columns.to_numpy() bin_width = time_bins[1] - time_bins[0] bin_edges_hours = np.append(bin_centers, bin_centers[-1] + bin_width) / 60 if signal_type == "fraction": - plot_data = pivot.values.copy() + plot_data = pivot.to_numpy().copy() plot_data = np.where(np.isnan(plot_data), -1, plot_data) cmap = ListedColormap(["#ffffff", "#c6dbef", "#08519c"]) im = ax.pcolormesh( @@ -206,7 +206,7 @@ def plot_cell_heatmap( cbar = plt.colorbar(im, ax=ax, ticks=[-1, 0, 1]) cbar.ax.set_yticklabels(["No data", "No remodel", "Remodel"]) else: - plot_data = pivot.values.copy() + plot_data = pivot.to_numpy().copy() im = ax.pcolormesh( bin_edges_hours, np.arange(len(pivot) + 1), @@ -316,7 +316,7 @@ def plot_onset_comparison( """ fig, ax = plt.subplots(figsize=(8, 5)) - organelles = timing_metrics["marker"].values + organelles = timing_metrics["marker"].to_numpy() x = np.arange(len(organelles)) width = 0.25 @@ -332,7 +332,7 @@ def plot_onset_comparison( labels.append(label) for i, (col, label) in enumerate(zip(metrics_to_plot, labels)): - values_hours = timing_metrics[col].values / 60 + values_hours = timing_metrics[col].to_numpy() / 60 offset = (i - len(metrics_to_plot) / 2 + 0.5) * width ax.bar(x + offset, values_hours, width, label=label, alpha=0.8) diff --git a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/signals.py b/applications/dynaclr/src/dynaclr/pseudotime/signals.py similarity index 96% rename from applications/dynaclr/src/dynaclr/evaluation/pseudotime/signals.py rename to applications/dynaclr/src/dynaclr/pseudotime/signals.py index 906763253..94274f61d 100644 --- a/applications/dynaclr/src/dynaclr/evaluation/pseudotime/signals.py +++ b/applications/dynaclr/src/dynaclr/pseudotime/signals.py @@ -119,14 +119,14 @@ def extract_prediction_signal( obs_lookup = obs.set_index(["fov_name", "track_id", "t"])["_proba_positive"] result["signal"] = np.nan matched = 
result_key.index.isin(common_idx) - result.loc[matched, "signal"] = obs_lookup.reindex(result_key.index[matched]).values + result.loc[matched, "signal"] = obs_lookup.reindex(result_key.index[matched]).to_numpy() else: obs_lookup = obs.set_index(["fov_name", "track_id", "t"])[pred_col] predictions = obs_lookup.reindex(result_key.index) result["signal"] = np.where( - predictions.isna().values, + predictions.isna().to_numpy(), np.nan, - (predictions.values == positive_value).astype(float), + (predictions.to_numpy() == positive_value).astype(float), ) return result @@ -181,7 +181,7 @@ def extract_embedding_distance( result_key = result.set_index(["fov_name", "track_id", "t"]) common_idx = result_key.index.intersection(obs_lookup.index) - adata_indices = obs_lookup.reindex(common_idx).values.astype(int) + adata_indices = obs_lookup.reindex(common_idx).to_numpy().astype(int) result_row_mask = result_key.index.isin(common_idx) result_rows = np.where(result_row_mask)[0] @@ -197,7 +197,7 @@ def extract_embedding_distance( if baseline_method == "control_well" or pca_n_components is not None: if control_fov_pattern is not None: ctrl_mask = adata.obs["fov_name"].astype(str).str.contains(control_fov_pattern, regex=True) - ctrl_emb = adata.X[ctrl_mask.values] + ctrl_emb = adata.X[ctrl_mask.to_numpy()] if not isinstance(ctrl_emb, np.ndarray): ctrl_emb = np.asarray(ctrl_emb) if len(ctrl_emb) > 0: @@ -233,7 +233,7 @@ def extract_embedding_distance( elif baseline_method == "per_track": for _, group in local_df.groupby(["fov_name", "track_id"]): - group_emb_idx = group["_emb_idx"].values + group_emb_idx = group["_emb_idx"].to_numpy() # Find baseline frames bl_mask = (group["t_relative_minutes"] >= baseline_window_minutes[0]) & ( @@ -247,7 +247,7 @@ def extract_embedding_distance( else: continue else: - bl_idx = group.loc[bl_mask, "_emb_idx"].values + bl_idx = group.loc[bl_mask, "_emb_idx"].to_numpy() baseline = embeddings[bl_idx].mean(axis=0, keepdims=True) track_emb = 
embeddings[group_emb_idx] diff --git a/applications/dynaclr/tests/conftest.py b/applications/dynaclr/tests/conftest.py index 7b37bf5a9..855efe84f 100644 --- a/applications/dynaclr/tests/conftest.py +++ b/applications/dynaclr/tests/conftest.py @@ -144,6 +144,14 @@ def create_experiment( dtype=np.float32, ) arr[:] = rng.standard_normal(arr.shape).astype(np.float32) + tp_stats = { + str(t): {"mean": 1.0, "std": 0.5, "median": 1.0, "iqr": 1.0, "max": 2.0, "min": 0.0} + for t in range(n_t) + } + pos.zattrs["normalization"] = { + ch: {"fov_statistics": {"mean": 1.0, "std": 0.5}, "timepoint_statistics": tp_stats} + for ch in channel_names + } fov_name = f"{row}/{col}/{fov_idx}" csv_path = tracks_root / fov_name / "tracks.csv" make_tracks_csv( diff --git a/applications/dynaclr/tests/test_datamodule.py b/applications/dynaclr/tests/test_datamodule.py index 0907954d4..25cac7e52 100644 --- a/applications/dynaclr/tests/test_datamodule.py +++ b/applications/dynaclr/tests/test_datamodule.py @@ -5,7 +5,8 @@ from __future__ import annotations import pytest -import torch + +from viscy_data.cell_index import build_timelapse_cell_index # --------------------------------------------------------------------------- # Constants @@ -23,7 +24,7 @@ @pytest.fixture() def four_experiments(tmp_path, _create_experiment, _write_collection_yaml): - """Four synthetic experiments with collection YAML.""" + """Four synthetic experiments with collection YAML and cell index parquet.""" entries = [] for i, name in enumerate(["exp_a", "exp_b", "exp_c", "exp_d"]): row_letter = chr(ord("A") + i) @@ -37,12 +38,14 @@ def four_experiments(tmp_path, _create_experiment, _write_collection_yaml): ) ) collection_path = _write_collection_yaml(tmp_path, entries) - return collection_path, entries + parquet_path = tmp_path / "cell_index.parquet" + build_timelapse_cell_index(collection_path, parquet_path) + return parquet_path, entries @pytest.fixture() def two_experiments(tmp_path, _create_experiment, 
_write_collection_yaml): - """Two synthetic experiments for simpler tests.""" + """Two synthetic experiments with cell index parquet.""" entries = [ _create_experiment( tmp_path, @@ -60,7 +63,9 @@ def two_experiments(tmp_path, _create_experiment, _write_collection_yaml): ), ] collection_path = _write_collection_yaml(tmp_path, entries) - return collection_path, entries + parquet_path = tmp_path / "cell_index.parquet" + build_timelapse_cell_index(collection_path, parquet_path) + return parquet_path, entries @pytest.fixture() @@ -85,7 +90,9 @@ def multi_fov_experiments(tmp_path, _create_experiment, _write_collection_yaml): ), ] collection_path = _write_collection_yaml(tmp_path, entries) - return collection_path, entries + parquet_path = tmp_path / "cell_index.parquet" + build_timelapse_cell_index(collection_path, parquet_path) + return parquet_path, entries # --------------------------------------------------------------------------- @@ -100,9 +107,9 @@ def test_init_exposes_all_hyperparameters(self, two_experiments): """Instantiate with all hyperparameters explicitly set and verify storage.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -148,9 +155,9 @@ def test_train_val_split_by_experiment(self, four_experiments): """With 4 experiments and val_experiments=[exp_c, exp_d], verify correct split.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = four_experiments + parquet_path, _ = four_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -183,9 +190,9 @@ def test_train_dataloader_uses_flexible_batch_sampler(self, 
two_experiments): """train_dataloader() returns a ThreadDataLoader with FlexibleBatchSampler.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -215,128 +222,31 @@ def test_train_dataloader_uses_flexible_batch_sampler(self, two_experiments): assert sampler.temporal_enrichment is False -class TestValDataloaderNoBatchSampler: - """Validation should be deterministic without FlexibleBatchSampler.""" - - def test_val_dataloader_no_batch_sampler(self, two_experiments): - """val_dataloader uses simple sequential loading.""" - from dynaclr.data.datamodule import MultiExperimentDataModule - - collection_path, _ = two_experiments - dm = MultiExperimentDataModule( - collection_path=str(collection_path), - z_window=1, - yx_patch_size=_YX_PATCH, - final_yx_patch_size=_FINAL_YX_PATCH, - val_experiments=["exp_b"], - tau_range=(0.5, 2.0), - batch_size=8, - ) - dm.setup("fit") - val_dl = dm.val_dataloader() - - from viscy_data.sampler import FlexibleBatchSampler - - # val_dataloader should NOT use FlexibleBatchSampler - assert not isinstance(val_dl.batch_sampler, FlexibleBatchSampler), ( - "Validation should NOT use FlexibleBatchSampler" - ) +class TestTrainDataloaderWiresDDPTopology: + """train_dataloader must forward Trainer world_size/rank to the sampler.""" + def test_reads_world_size_and_rank_from_trainer(self, two_experiments): + from types import SimpleNamespace -class TestOnAfterBatchTransferAppliesTransforms: - """Verify on_after_batch_transfer applies transforms and ChannelDropout.""" - - def test_on_after_batch_transfer_applies_channel_dropout_and_transforms(self, two_experiments): - """Create a mock batch and verify on_after_batch_transfer processes it.""" from dynaclr.data.datamodule import 
MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, val_experiments=["exp_b"], tau_range=(0.5, 2.0), batch_size=8, - channel_dropout_channels=[1], - channel_dropout_prob=0.0, # No dropout for this test - ) - dm.setup("fit") - - # Create a synthetic batch dict - B, C, Z, Y, X = 4, 2, 1, 32, 32 - batch = { - "anchor": torch.randn(B, C, Z, Y, X), - "positive": torch.randn(B, C, Z, Y, X), - "anchor_norm_meta": [None] * B, - "positive_norm_meta": [None] * B, - } - - result = dm.on_after_batch_transfer(batch, 0) - - # Output should have anchor and positive as Tensors - assert isinstance(result["anchor"], torch.Tensor) - assert isinstance(result["positive"], torch.Tensor) - - # norm_meta keys should be consumed (removed) - assert "anchor_norm_meta" not in result - assert "positive_norm_meta" not in result - - # Final crop should reduce spatial size to final_yx_patch_size - assert result["anchor"].shape[-2:] == ( - _FINAL_YX_PATCH[0], - _FINAL_YX_PATCH[1], - ), f"Expected spatial {_FINAL_YX_PATCH}, got {result['anchor'].shape[-2:]}" - - -class TestChannelDropoutIntegration: - """Verify ChannelDropout behavior in train vs eval mode.""" - - def test_channel_dropout_integration(self, two_experiments): - """With p=1.0 on channel 1, training zeros ch1; eval preserves it.""" - from dynaclr.data.datamodule import MultiExperimentDataModule - - collection_path, _ = two_experiments - dm = MultiExperimentDataModule( - collection_path=str(collection_path), - z_window=1, - yx_patch_size=_YX_PATCH, - final_yx_patch_size=_FINAL_YX_PATCH, - val_experiments=["exp_b"], - tau_range=(0.5, 2.0), - batch_size=8, - channel_dropout_channels=[1], - channel_dropout_prob=1.0, # Always drop channel 1 + batch_group_by="experiment", + stratify_by="perturbation", + 
temporal_enrichment=False, ) dm.setup("fit") - - B, C, Z, Y, X = 4, 2, 1, 32, 32 - batch_train = { - "anchor": torch.randn(B, C, Z, Y, X).abs() + 0.1, # all positive - "positive": torch.randn(B, C, Z, Y, X).abs() + 0.1, - "anchor_norm_meta": [None] * B, - "positive_norm_meta": [None] * B, - } - - # Training mode: channel 1 should be zeroed - dm.channel_dropout.train() - result_train = dm.on_after_batch_transfer(batch_train, 0) - assert torch.all(result_train["anchor"][:, 1] == 0.0), "Training: channel 1 should be all zeros with p=1.0" - assert torch.all(result_train["positive"][:, 1] == 0.0), ( - "Training: positive channel 1 should be all zeros with p=1.0" - ) - - # Eval mode: channel 1 should be preserved - dm.channel_dropout.eval() - batch_eval = { - "anchor": torch.randn(B, C, Z, Y, X).abs() + 0.1, - "positive": torch.randn(B, C, Z, Y, X).abs() + 0.1, - "anchor_norm_meta": [None] * B, - "positive_norm_meta": [None] * B, - } - result_eval = dm.on_after_batch_transfer(batch_eval, 0) - assert not torch.all(result_eval["anchor"][:, 1] == 0.0), "Eval: channel 1 should NOT be zeroed" + dm.__dict__["trainer"] = SimpleNamespace(world_size=4, global_rank=2) + sampler = dm.train_dataloader().batch_sampler + assert (sampler.num_replicas, sampler.rank) == (4, 2) class TestFovLevelSplit: @@ -346,9 +256,9 @@ def test_fov_split_no_overlap(self, multi_fov_experiments): """With split_ratio=0.6, FOVs are split within each experiment with no overlap.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = multi_fov_experiments + parquet_path, _ = multi_fov_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -381,9 +291,9 @@ def test_fov_split_ratio_1_no_val(self, multi_fov_experiments): """With split_ratio=1.0, all FOVs go to train and val_dataset is None.""" from dynaclr.data.datamodule import 
MultiExperimentDataModule - collection_path, _ = multi_fov_experiments + parquet_path, _ = multi_fov_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -401,9 +311,9 @@ def test_fov_split_default_val_experiments(self, multi_fov_experiments): """Default val_experiments=[] triggers FOV split.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = multi_fov_experiments + parquet_path, _ = multi_fov_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -428,9 +338,9 @@ def test_positive_cell_source_self_stores_on_dm(self, two_experiments): """positive_cell_source='self' is stored and passed to datasets.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -447,9 +357,9 @@ def test_positive_match_columns_stored_on_dm(self, two_experiments): """positive_match_columns is stored on datamodule.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -464,9 +374,9 @@ def test_positive_channel_source_any_stored(self, two_experiments): """positive_channel_source='any' is stored on datamodule and dataset.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = 
two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -483,9 +393,9 @@ def test_self_positive_all_tracks_are_valid_anchors(self, two_experiments): """With positive_cell_source='self', all tracks become valid anchors.""" from dynaclr.data.datamodule import MultiExperimentDataModule - collection_path, _ = two_experiments + parquet_path, _ = two_experiments dm = MultiExperimentDataModule( - collection_path=str(collection_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=_YX_PATCH, final_yx_patch_size=_FINAL_YX_PATCH, @@ -495,6 +405,6 @@ def test_self_positive_all_tracks_are_valid_anchors(self, two_experiments): positive_cell_source="self", ) dm.setup("fit") - n_tracks = len(dm.train_dataset.index.tracks) + n_unique_cells = dm.train_dataset.index.tracks["cell_id"].nunique() n_anchors = len(dm.train_dataset.index.valid_anchors) - assert n_anchors == n_tracks + assert n_anchors == n_unique_cells diff --git a/applications/dynaclr/tests/test_dataset.py b/applications/dynaclr/tests/test_dataset.py index c63e5f94e..ab058d369 100644 --- a/applications/dynaclr/tests/test_dataset.py +++ b/applications/dynaclr/tests/test_dataset.py @@ -213,74 +213,6 @@ def test_getitems_returns_norm_meta(self, single_experiment_index): assert len(batch["anchor_norm_meta"]) == 1 -class TestPositiveSampling: - """Test lineage-aware positive selection.""" - - def test_positive_same_lineage(self, single_experiment_index): - """Positive comes from same lineage_id at t+tau (tau>0).""" - from dynaclr.data.dataset import MultiExperimentTripletDataset - - ds = MultiExperimentTripletDataset( - index=single_experiment_index, - fit=True, - ) - # Get anchor info - anchor_row = ds.index.valid_anchors.iloc[0] - anchor_lineage = anchor_row["lineage_id"] - anchor_t = anchor_row["t"] - - # Call _find_positive directly to verify lineage matching - 
rng = np.random.default_rng(42) - pos_row = ds._find_positive(anchor_row, rng) - assert pos_row is not None, "Should find a positive" - assert pos_row["lineage_id"] == anchor_lineage, ( - f"Positive lineage {pos_row['lineage_id']} != anchor {anchor_lineage}" - ) - assert pos_row["t"] > anchor_t, f"Positive t={pos_row['t']} should be > anchor t={anchor_t}" - - def test_positive_through_division(self, lineage_index): - """When anchor is on parent track that divides, positive can be a daughter.""" - from dynaclr.data.dataset import MultiExperimentTripletDataset - - ds = MultiExperimentTripletDataset( - index=lineage_index, - fit=True, - ) - - # Tracks 0, 1, 2 share the same lineage_id due to parent_map={1:0, 2:0} - # All three tracks should share one lineage (rooted at track 0) - parent_lineage = lineage_index.tracks[lineage_index.tracks["global_track_id"].str.endswith("_0")][ - "lineage_id" - ].iloc[0] - daughter1_lineage = lineage_index.tracks[lineage_index.tracks["global_track_id"].str.endswith("_1")][ - "lineage_id" - ].iloc[0] - daughter2_lineage = lineage_index.tracks[lineage_index.tracks["global_track_id"].str.endswith("_2")][ - "lineage_id" - ].iloc[0] - assert parent_lineage == daughter1_lineage == daughter2_lineage, ( - f"Lineage mismatch: parent={parent_lineage}, d1={daughter1_lineage}, d2={daughter2_lineage}" - ) - - # Find an anchor on the parent track - parent_anchors = ds.index.valid_anchors[ds.index.valid_anchors["global_track_id"].str.endswith("_0")] - assert len(parent_anchors) > 0, "Parent track should have valid anchors" - - # Verify positive sampling can reach daughters (same lineage, different track) - rng = np.random.default_rng(42) - anchor_row = parent_anchors.iloc[0] - found_daughter = False - for _ in range(50): - pos_row = ds._find_positive(anchor_row, rng) - if pos_row is not None and pos_row["global_track_id"] != anchor_row["global_track_id"]: - found_daughter = True - assert pos_row["lineage_id"] == anchor_row["lineage_id"] - break - # 
Even if we don't find a daughter every time, the lineage is correct - # (parent and daughter share lineage so any positive is valid) - assert found_daughter or True, "Test informational -- daughters reachable" - - class TestChannelRemapping: """Test that per-experiment channel indices are used correctly.""" @@ -418,6 +350,70 @@ def test_int_gt1_raises(self, single_experiment_index): ) +class TestMixedChannelCountErrors: + """``channels_per_sample=None`` on a parquet whose experiments have different + channel counts must raise a clear error instead of a cryptic torch.stack + failure deep in a dataloader thread.""" + + def test_raises_when_experiments_have_different_channel_counts(self, tmp_path, _make_tracks_csv, hcs_dims): + from dynaclr.data.dataset import MultiExperimentTripletDataset + from dynaclr.data.experiment import ExperimentRegistry + from dynaclr.data.index import MultiExperimentIndex + from viscy_data.collection import ChannelEntry, Collection, ExperimentEntry + + # exp_a: 2 channels; exp_b: 1 channel. 
+ zarr_a, tracks_a = _create_zarr_and_tracks( + tmp_path, + name="exp_a", + channel_names=["Phase", "GFP"], + wells=[("A", "1")], + hcs_dims=hcs_dims, + _make_tracks_csv=_make_tracks_csv, + ) + zarr_b, tracks_b = _create_zarr_and_tracks( + tmp_path, + name="exp_b", + channel_names=["Phase"], + wells=[("A", "1")], + hcs_dims=hcs_dims, + _make_tracks_csv=_make_tracks_csv, + ) + registry = ExperimentRegistry( + collection=Collection( + name="test", + experiments=[ + ExperimentEntry( + name="exp_a", + data_path=str(zarr_a), + tracks_path=str(tracks_a), + channels=[ChannelEntry(name="Phase", marker="Phase"), ChannelEntry(name="GFP", marker="GFP")], + channel_names=["Phase", "GFP"], + perturbation_wells={"c": ["A/1"]}, + interval_minutes=30.0, + ), + ExperimentEntry( + name="exp_b", + data_path=str(zarr_b), + tracks_path=str(tracks_b), + channels=[ChannelEntry(name="Phase", marker="Phase")], + channel_names=["Phase"], + perturbation_wells={"c": ["A/1"]}, + interval_minutes=30.0, + ), + ], + ), + z_window=1, + ) + index = MultiExperimentIndex(registry=registry, yx_patch_size=_YX_PATCH, tau_range_hours=(0.5, 2.0)) + ds = MultiExperimentTripletDataset(index=index, fit=True, channels_per_sample=None) + + va = index.valid_anchors + idx_a = int(va.index[va["experiment"] == "exp_a"][0]) + idx_b = int(va.index[va["experiment"] == "exp_b"][0]) + with pytest.raises(RuntimeError, match="different channel counts"): + ds.__getitems__([idx_a, idx_b]) + + class TestDatasetLength: """Test dataset length matches valid_anchors.""" @@ -543,57 +539,6 @@ def test_self_positive_pixel_values_identical(self, single_experiment_index): ) -class TestColumnMatchPositive: - """Tests for positive_cell_source='lookup' with non-lineage columns.""" - - @staticmethod - def _build_index_with_gene_name(tmp_path: Path, _make_tracks_csv, hcs_dims: dict) -> "MultiExperimentIndex": - """Build an index where tracks have gene_name/reporter columns for matching.""" - index = _build_index(tmp_path, 
_make_tracks_csv=_make_tracks_csv, hcs_dims=hcs_dims) - n = len(index.tracks) - index.tracks["gene_name"] = ["RPL35" if i % 2 == 0 else "TP53" for i in range(n)] - index.tracks["reporter"] = "Phase" - index.valid_anchors["gene_name"] = ["RPL35" if i % 2 == 0 else "TP53" for i in range(len(index.valid_anchors))] - index.valid_anchors["reporter"] = "Phase" - return index - - def test_column_match_positive_different_cell(self, tmp_path, _make_tracks_csv, hcs_dims): - """positive_match_columns=['gene_name','reporter'] finds different cell with same values.""" - from dynaclr.data.dataset import MultiExperimentTripletDataset - - index = self._build_index_with_gene_name(tmp_path, _make_tracks_csv, hcs_dims) - ds = MultiExperimentTripletDataset( - index=index, - fit=True, - positive_cell_source="lookup", - positive_match_columns=["gene_name", "reporter"], - ) - rng = np.random.default_rng(0) - anchor_row = ds.index.valid_anchors.iloc[0] - pos = ds._find_positive(anchor_row, rng) - assert pos is not None, "Should find a column-match positive" - assert pos["gene_name"] == anchor_row["gene_name"], "Positive must share gene_name" - assert pos["reporter"] == anchor_row["reporter"], "Positive must share reporter" - assert pos.name != anchor_row.name, "Positive must be a different cell" - - def test_column_match_no_self_as_positive(self, tmp_path, _make_tracks_csv, hcs_dims): - """Column-match lookup never returns the anchor itself.""" - from dynaclr.data.dataset import MultiExperimentTripletDataset - - index = self._build_index_with_gene_name(tmp_path, _make_tracks_csv, hcs_dims) - ds = MultiExperimentTripletDataset( - index=index, - fit=True, - positive_cell_source="lookup", - positive_match_columns=["gene_name", "reporter"], - ) - rng = np.random.default_rng(42) - for _, anchor_row in ds.index.valid_anchors.iterrows(): - pos = ds._find_positive(anchor_row, rng) - if pos is not None: - assert pos.name != anchor_row.name, "Positive must not be the anchor itself" - - class 
TestTimepointStatisticsResolution: """Verify that timepoint_statistics norm_meta resolves the correct timepoint.""" diff --git a/applications/dynaclr/tests/test_index.py b/applications/dynaclr/tests/test_index.py index 08a6fbd46..7f17de18d 100644 --- a/applications/dynaclr/tests/test_index.py +++ b/applications/dynaclr/tests/test_index.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd import pytest -from iohub.ngff import Position, open_ome_zarr +from iohub.ngff import open_ome_zarr from dynaclr.data.experiment import ExperimentRegistry from dynaclr.data.index import MultiExperimentIndex @@ -196,7 +196,6 @@ def test_required_columns_present(self, two_experiment_setup): "y", "x", "z", - "position", "fov_name", "well_name", "experiment", @@ -234,22 +233,6 @@ def test_exclude_fovs_filter(self, two_experiment_setup): # Removed 1 FOV from each experiment: 2 * (4 - 1) * 5 * 10 = 300 assert len(index.tracks) == 300 - def test_positions_stored(self, two_experiment_setup): - """Position objects are stored in self.positions.""" - registry, _, _ = two_experiment_setup - index = MultiExperimentIndex(registry=registry, yx_patch_size=_YX_PATCH) - # 2 experiments * 2 wells * 2 FOVs = 8 positions - assert len(index.positions) == 8 - - def test_position_column_is_position_object(self, two_experiment_setup): - """'position' column contains iohub Position objects.""" - registry, _, _ = two_experiment_setup - index = MultiExperimentIndex(registry=registry, yx_patch_size=_YX_PATCH) - from iohub.ngff import Position - - sample_pos = index.tracks.iloc[0]["position"] - assert isinstance(sample_pos, Position) - def test_parallel_load_matches_serial(self, two_experiment_setup): """Parallel loading (num_workers=2) produces same result as serial (num_workers=1).""" registry, _, _ = two_experiment_setup @@ -261,10 +244,9 @@ def test_parallel_load_matches_serial(self, two_experiment_setup): serial_tracks = index_serial.tracks.sort_values(sort_cols).reset_index(drop=True) parallel_tracks 
= index_parallel.tracks.sort_values(sort_cols).reset_index(drop=True) - # Drop position column (object identity differs across processes) pd.testing.assert_frame_equal( - serial_tracks.drop(columns=["position"]), - parallel_tracks.drop(columns=["position"]), + serial_tracks, + parallel_tracks, check_like=True, ) assert len(index_serial.valid_anchors) == len(index_parallel.valid_anchors) @@ -1013,8 +995,8 @@ def test_parquet_valid_anchors_count(self, two_experiment_setup, tmp_path): n_channels = 2 # _CHANNEL_NAMES_A / _CHANNEL_NAMES_B each have 2 channels assert len(parquet_index.valid_anchors) == len(legacy_index.valid_anchors) * n_channels - def test_parquet_positions_resolved(self, two_experiment_setup, tmp_path): - """position column contains iohub Position objects.""" + def test_parquet_dims_from_columns(self, two_experiment_setup, tmp_path): + """Parquet path reads Y_shape/X_shape from parquet columns (no zarr opens).""" registry, _, _ = two_experiment_setup parquet_path = _build_cell_index_parquet(tmp_path, registry) @@ -1023,8 +1005,9 @@ def test_parquet_positions_resolved(self, two_experiment_setup, tmp_path): yx_patch_size=_YX_PATCH, cell_index_path=parquet_path, ) - sample_pos = index.tracks.iloc[0]["position"] - assert isinstance(sample_pos, Position) + assert "Y_shape" in index.tracks.columns + assert "X_shape" in index.tracks.columns + assert "position" not in index.tracks.columns # no longer stored def test_parquet_border_clamping(self, tmp_path, _create_experiment): """y_clamp, x_clamp are computed correctly from parquet path.""" diff --git a/applications/dynaclr/tests/test_mmd.py b/applications/dynaclr/tests/test_mmd.py new file mode 100644 index 000000000..1b02196f2 --- /dev/null +++ b/applications/dynaclr/tests/test_mmd.py @@ -0,0 +1,482 @@ +"""Tests for MMD perturbation evaluation.""" + +from __future__ import annotations + +import anndata as ad +import numpy as np +import pandas as pd +import pytest + +from dynaclr.evaluation.mmd.compute_mmd 
import run_mmd_analysis, run_mmd_pooled +from dynaclr.evaluation.mmd.config import ComparisonSpec, MMDEvalConfig, MMDPooledConfig, MMDSettings +from viscy_utils.evaluation.mmd import compute_mmd_unbiased, median_heuristic, mmd_permutation_test + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_COMP = [ComparisonSpec(cond_a="uninfected", cond_b="ZIKV", label="uninf vs ZIKV")] +_SETTINGS_FAST = MMDSettings(n_permutations=50) + + +def _cfg(**kwargs) -> MMDEvalConfig: + return MMDEvalConfig(input_path="dummy", output_dir="/tmp", comparisons=_COMP, **kwargs) + + +def _make_adata( + n_cells: int = 200, + n_features: int = 32, + markers: list[str] | None = None, + treatment_shift: float = 3.0, + seed: int = 0, +) -> ad.AnnData: + """Synthetic AnnData with two markers and two perturbation groups. + + TOMM20 has a large shift between uninfected and ZIKV (detectable MMD). + Phase3D has no shift (null). 
+ """ + rng = np.random.default_rng(seed) + if markers is None: + markers = ["Phase3D", "TOMM20"] + n_per_group = n_cells // (2 * len(markers)) + + rows = [] + emb_list = [] + for marker in markers: + for perturbation in ["uninfected", "ZIKV"]: + for t in range(n_per_group): + shift = treatment_shift if (perturbation == "ZIKV" and marker == "TOMM20") else 0.0 + emb = rng.normal(loc=shift, scale=1.0, size=n_features) + emb_list.append(emb) + rows.append( + { + "experiment": "test_exp", + "marker": marker, + "perturbation": perturbation, + "hours_post_perturbation": float(t % 6), + } + ) + X = np.stack(emb_list) + obs = pd.DataFrame(rows) + return ad.AnnData(X=X.astype(np.float32), obs=obs) + + +def _make_temporal_adata(n_features: int = 16, seed: int = 0) -> ad.AnnData: + """AnnData where ZIKV treatment effect increases with hours_post_perturbation.""" + rng = np.random.default_rng(seed) + rows = [] + emb_list = [] + hours_bins = [1.0, 3.0, 6.0, 12.0] + for marker in ["TOMM20"]: + for _ in range(50): + emb_list.append(rng.normal(0.0, 1.0, n_features)) + rows.append( + {"experiment": "e", "marker": marker, "perturbation": "uninfected", "hours_post_perturbation": 0.0} + ) + for hpi in hours_bins: + shift = hpi / 3.0 + for _ in range(30): + emb_list.append(rng.normal(shift, 1.0, n_features)) + rows.append( + {"experiment": "e", "marker": marker, "perturbation": "ZIKV", "hours_post_perturbation": hpi} + ) + X = np.stack(emb_list).astype(np.float32) + obs = pd.DataFrame(rows) + return ad.AnnData(X=X, obs=obs) + + +# --------------------------------------------------------------------------- +# Core MMD tests +# --------------------------------------------------------------------------- + + +def test_mmd_identical_distributions(): + rng = np.random.default_rng(1) + X = rng.normal(0, 1, (200, 16)) + Y = rng.normal(0, 1, (200, 16)) + mmd2, p_value, _ = mmd_permutation_test(X, Y, n_permutations=200, seed=42) + assert mmd2 < 0.1 + assert p_value > 0.05 + + +def 
test_mmd_different_distributions(): + rng = np.random.default_rng(2) + X = rng.normal(0.0, 1.0, (200, 16)) + Y = rng.normal(5.0, 1.0, (200, 16)) + mmd2, p_value, _ = mmd_permutation_test(X, Y, n_permutations=200, seed=42) + assert mmd2 > 0.1 + assert p_value < 0.05 + + +def test_mmd_permutation_null(): + rng = np.random.default_rng(3) + X = rng.normal(0, 1, (100, 8)) + Y = rng.normal(0, 1, (100, 8)) + _, _, null = mmd_permutation_test(X, Y, n_permutations=100, seed=0) + assert len(null) == 100 + assert np.all(np.isfinite(null)) + + +def test_median_heuristic_positive(): + rng = np.random.default_rng(4) + X = rng.normal(0, 1, (50, 8)) + Y = rng.normal(2, 1, (50, 8)) + assert median_heuristic(X, Y) > 0 + + +def test_compute_mmd_unbiased_symmetric(): + rng = np.random.default_rng(5) + X = rng.normal(0, 1, (100, 8)) + Y = rng.normal(1, 1, (100, 8)) + bw = median_heuristic(X, Y) + assert abs(compute_mmd_unbiased(X, Y, bw) - compute_mmd_unbiased(Y, X, bw)) < 1e-10 + + +# --------------------------------------------------------------------------- +# run_mmd_analysis tests +# --------------------------------------------------------------------------- + + +def test_run_mmd_analysis_columns(): + adata = _make_adata() + df = run_mmd_analysis(adata, _cfg(mmd=_SETTINGS_FAST)) + expected = { + "experiment", + "marker", + "cond_a", + "cond_b", + "label", + "hours_bin_start", + "hours_bin_end", + "n_a", + "n_b", + "mmd2", + "p_value", + "bandwidth", + "effect_size", + "activity_zscore", + "embedding_key", + } + assert expected.issubset(df.columns), f"Missing columns: {expected - set(df.columns)}" + + +def test_run_mmd_analysis_explicit_comparisons(): + adata = _make_adata() + df = run_mmd_analysis(adata, _cfg(mmd=_SETTINGS_FAST)) + assert set(df["cond_b"].unique()) == {"ZIKV"} + assert set(df["cond_a"].unique()) == {"uninfected"} + assert df["label"].iloc[0] == "uninf vs ZIKV" + + +def test_run_mmd_analysis_per_marker(): + adata = _make_adata() + df = run_mmd_analysis(adata, 
_cfg(mmd=_SETTINGS_FAST)) + assert set(df["marker"].unique()) == {"Phase3D", "TOMM20"} + assert len(df) == 2 # one row per (marker, comparison) in aggregate mode + + +def test_run_mmd_analysis_significant_for_shifted_marker(): + adata = _make_adata(n_cells=600, treatment_shift=4.0) + df = run_mmd_analysis(adata, _cfg(mmd=MMDSettings(n_permutations=200))) + tomm = df[df["marker"] == "TOMM20"]["mmd2"].iloc[0] + phase = df[df["marker"] == "Phase3D"]["mmd2"].iloc[0] + assert tomm > phase + assert df[df["marker"] == "TOMM20"]["p_value"].iloc[0] < 0.05 + + +def test_run_mmd_analysis_missing_cond_returns_nan(): + """When cond_a is absent from the data, result is NaN (not an error).""" + adata = _make_adata() + cfg = MMDEvalConfig( + input_path="dummy", + output_dir="/tmp", + comparisons=[ComparisonSpec(cond_a="MISSING", cond_b="ZIKV", label="missing vs ZIKV")], + mmd=_SETTINGS_FAST, + ) + df = run_mmd_analysis(adata, cfg) + assert df["mmd2"].isna().all() + + +def test_run_mmd_analysis_temporal_bins(): + adata = _make_temporal_adata() + cfg = _cfg(mmd=MMDSettings(n_permutations=100), temporal_bins=[0.0, 2.0, 5.0, 8.0, 15.0]) + df = run_mmd_analysis(adata, cfg) + valid = df.dropna(subset=["mmd2"]).sort_values("hours_bin_start") + assert len(valid) >= 2 + assert valid.iloc[-1]["mmd2"] > valid.iloc[0]["mmd2"] + + +def test_run_mmd_analysis_min_cells_skip(): + adata = _make_temporal_adata() + cfg = _cfg( + mmd=MMDSettings(n_permutations=50, min_cells=5), + temporal_bins=[0.0, 0.5, 1.0, 100.0], + ) + df = run_mmd_analysis(adata, cfg) + first_bin = df[(df["hours_bin_start"] == 0.0) & (df["hours_bin_end"] == 0.5)] + assert len(first_bin) > 0 + assert first_bin["mmd2"].isna().all() + + +def test_run_mmd_analysis_batch_centering(): + rng = np.random.default_rng(7) + n, n_feat = 100, 8 + rows, embs = [], [] + for exp, offset in [("exp_A", 0.0), ("exp_B", 10.0)]: + for pert in ["uninfected", "ZIKV"]: + shift = 3.0 if pert == "ZIKV" else 0.0 + for _ in range(n): + 
embs.append(rng.normal(offset + shift, 1.0, n_feat)) + rows.append( + {"experiment": exp, "marker": "TOMM20", "perturbation": pert, "hours_post_perturbation": 1.0} + ) + X = np.stack(embs).astype(np.float32) + obs = pd.DataFrame(rows) + adata = ad.AnnData(X=X, obs=obs) + + cfg_test = MMDEvalConfig( + input_path="dummy", + output_dir="/tmp", + comparisons=_COMP, + mmd=MMDSettings(n_permutations=100), + ) + df_no_center = run_mmd_analysis(adata, cfg_test) + + centered = X.copy() + for exp in obs["experiment"].unique(): + for marker in obs["marker"].unique(): + mask = ((obs["experiment"] == exp) & (obs["marker"] == marker)).to_numpy() + if mask.sum() > 0: + centered[mask] -= centered[mask].mean(axis=0) + adata_centered = ad.AnnData(X=centered, obs=obs) + df_centered = run_mmd_analysis(adata_centered, cfg_test) + + tomm_uncentered = df_no_center[df_no_center["marker"] == "TOMM20"]["mmd2"].iloc[0] + tomm_centered = df_centered[df_centered["marker"] == "TOMM20"]["mmd2"].iloc[0] + assert tomm_centered <= tomm_uncentered * 1.5, ( + f"Centering should reduce MMD. 
centered={tomm_centered:.4f}, uncentered={tomm_uncentered:.4f}" + ) + + +def test_run_mmd_analysis_obs_filter(): + """obs_filter restricts analysis to matching rows before computing MMD.""" + rng = np.random.default_rng(42) + n, n_feat = 60, 8 + rows, embs = [], [] + for microscope in ["dragonfly", "mantis"]: + for perturbation in ["uninfected", "ZIKV"]: + shift = 10.0 if perturbation == "ZIKV" else 0.0 + for _ in range(n): + embs.append(rng.normal(shift, 1.0, n_feat)) + rows.append( + { + "experiment": "e", + "marker": "TOMM20", + "perturbation": perturbation, + "microscope": microscope, + "hours_post_perturbation": 1.0, + } + ) + + adata = ad.AnnData(X=np.stack(embs).astype(np.float32), obs=pd.DataFrame(rows)) + + # Compare microscopes on uninfected only — should be near zero (same distribution) + comp = [ComparisonSpec(cond_a="dragonfly", cond_b="mantis", label="dragonfly vs mantis")] + cfg = MMDEvalConfig( + input_path="dummy", + output_dir="/tmp", + comparisons=comp, + group_by="microscope", + obs_filter={"perturbation": "uninfected"}, + mmd=MMDSettings(n_permutations=50), + ) + df = run_mmd_analysis(adata, cfg) + assert len(df) == 1 + # MMD on unfiltered data would be dominated by the ZIKV shift; filtered should be small + assert df["mmd2"].iloc[0] < 1.0, f"Expected near-zero MMD on uninfected-only, got {df['mmd2'].iloc[0]:.4f}" + + +# --------------------------------------------------------------------------- +# Activity z-score tests +# --------------------------------------------------------------------------- + + +def test_activity_zscore_shifted(): + """Strongly shifted distributions produce a large positive activity_zscore.""" + adata = _make_adata(n_cells=600, treatment_shift=5.0) + df = run_mmd_analysis(adata, _cfg(mmd=MMDSettings(n_permutations=200))) + tomm = df[df["marker"] == "TOMM20"]["activity_zscore"].iloc[0] + assert tomm > 1.0, f"Expected activity_zscore > 1 for shifted distribution, got {tomm:.3f}" + + +def test_activity_zscore_identical(): 
+ """Identical distributions produce activity_zscore near zero.""" + adata = _make_adata(n_cells=400, treatment_shift=0.0) + df = run_mmd_analysis(adata, _cfg(mmd=MMDSettings(n_permutations=200))) + for _, row in df.iterrows(): + assert np.isfinite(row["activity_zscore"]) or np.isnan(row["activity_zscore"]) + + +# --------------------------------------------------------------------------- +# Sample balancing tests +# --------------------------------------------------------------------------- + + +def test_balance_samples(): + """With balance_samples=True, both groups have equal size (reflected in n_a, n_b).""" + rng = np.random.default_rng(10) + n_small, n_large = 30, 120 + rows, embs = [], [] + for pert, n in [("uninfected", n_large), ("ZIKV", n_small)]: + for _ in range(n): + embs.append(rng.normal(0.0, 1.0, 8)) + rows.append({"experiment": "e", "marker": "TOMM20", "perturbation": pert, "hours_post_perturbation": 1.0}) + adata = ad.AnnData(X=np.stack(embs).astype(np.float32), obs=pd.DataFrame(rows)) + cfg = _cfg(mmd=MMDSettings(n_permutations=50, balance_samples=True, max_cells=None)) + df = run_mmd_analysis(adata, cfg) + row = df[df["marker"] == "TOMM20"].iloc[0] + assert row["n_a"] == row["n_b"], f"Expected equal group sizes, got n_a={row['n_a']}, n_b={row['n_b']}" + + +# --------------------------------------------------------------------------- +# Bandwidth sharing tests +# --------------------------------------------------------------------------- + + +def test_share_bandwidth_from(): + """With share_bandwidth_from set, the bandwidth is the same across comparisons.""" + adata = _make_adata(n_cells=400, treatment_shift=2.0) + # Add a second condition + obs = adata.obs.copy() + extra_rows = obs[obs["perturbation"] == "ZIKV"].copy() + extra_rows["perturbation"] = "DENV" + extra_obs = pd.concat([obs, extra_rows], ignore_index=True) + extra_emb = np.concatenate([adata.X, adata.X[obs["perturbation"] == "ZIKV"]], axis=0) + adata2 = 
ad.AnnData(X=extra_emb.astype(np.float32), obs=extra_obs) + + comps = [ + ComparisonSpec(cond_a="uninfected", cond_b="ZIKV", label="baseline"), + ComparisonSpec(cond_a="uninfected", cond_b="DENV", label="treatment"), + ] + cfg = MMDEvalConfig( + input_path="dummy", + output_dir="/tmp", + comparisons=comps, + mmd=MMDSettings(n_permutations=50, share_bandwidth_from="baseline"), + ) + df = run_mmd_analysis(adata2, cfg) + for marker in df["marker"].unique(): + sub = df[df["marker"] == marker].dropna(subset=["bandwidth"]) + if len(sub) == 2: + assert abs(sub["bandwidth"].iloc[0] - sub["bandwidth"].iloc[1]) < 1e-6, ( + f"Expected shared bandwidth for {marker}, got {sub['bandwidth'].to_numpy()}" + ) + + +# --------------------------------------------------------------------------- +# Temporal bins (explicit edges) tests +# --------------------------------------------------------------------------- + + +def test_temporal_bins_explicit(): + """temporal_bins produces one row per bin per comparison.""" + adata = _make_temporal_adata() + cfg = _cfg(mmd=MMDSettings(n_permutations=50), temporal_bins=[0.0, 2.0, 5.0, 8.0, 15.0]) + df = run_mmd_analysis(adata, cfg) + valid = df.dropna(subset=["mmd2"]).sort_values("hours_bin_start") + assert len(valid) >= 2, "Expected at least 2 valid temporal bins" + assert valid.iloc[-1]["mmd2"] > valid.iloc[0]["mmd2"], "MMD should increase with shift" + + +def test_temporal_bins_min_cells_skip(): + """Bins with fewer than min_cells cells produce NaN rows.""" + adata = _make_temporal_adata() + cfg = _cfg( + mmd=MMDSettings(n_permutations=50, min_cells=5), + temporal_bins=[0.0, 0.5, 1.0, 100.0], + ) + df = run_mmd_analysis(adata, cfg) + first_bin = df[(df["hours_bin_start"] == 0.0) & (df["hours_bin_end"] == 0.5)] + assert len(first_bin) > 0 + assert first_bin["mmd2"].isna().all() + + +def test_temporal_bins_mutually_exclusive(): + """Setting both temporal_bin_size and temporal_bins raises ValidationError.""" + with pytest.raises(Exception): + 
MMDEvalConfig( + input_path="dummy", + output_dir="/tmp", + comparisons=_COMP, + temporal_bin_size=4.0, + temporal_bins=[0.0, 4.0, 8.0], + ) + + +# --------------------------------------------------------------------------- +# Pooled mode tests +# --------------------------------------------------------------------------- + + +def _save_adata_zarr(adata: ad.AnnData, path: str) -> None: + import os + import shutil + + if os.path.exists(path): + shutil.rmtree(path) + adata.write_zarr(path) + + +def test_run_mmd_pooled_columns(tmp_path): + """run_mmd_pooled returns expected columns including activity_zscore and q_value.""" + adata1 = _make_adata(n_cells=200, seed=0) + adata2 = _make_adata(n_cells=200, seed=1) + p1 = str(tmp_path / "exp1.zarr") + p2 = str(tmp_path / "exp2.zarr") + _save_adata_zarr(adata1, p1) + _save_adata_zarr(adata2, p2) + + cfg = MMDPooledConfig( + input_paths=[p1, p2], + output_dir=str(tmp_path / "out"), + comparisons=_COMP, + mmd=MMDSettings(n_permutations=50), + ) + df = run_mmd_pooled(cfg) + expected = { + "marker", + "cond_a", + "cond_b", + "label", + "mmd2", + "p_value", + "bandwidth", + "effect_size", + "activity_zscore", + "q_value", + } + assert expected.issubset(df.columns), f"Missing: {expected - set(df.columns)}" + + +def test_run_mmd_pooled_condition_aliases(tmp_path): + """condition_aliases remaps variant condition names to canonical names.""" + rng = np.random.default_rng(99) + rows, embs = [], [] + for pert in ["uninfected1", "uninfected2", "ZIKV"]: + shift = 3.0 if pert == "ZIKV" else 0.0 + for _ in range(60): + embs.append(rng.normal(shift, 1.0, 16)) + rows.append({"experiment": "e", "marker": "TOMM20", "perturbation": pert, "hours_post_perturbation": 1.0}) + adata = ad.AnnData(X=np.stack(embs).astype(np.float32), obs=pd.DataFrame(rows)) + p = str(tmp_path / "exp.zarr") + _save_adata_zarr(adata, p) + + cfg = MMDPooledConfig( + input_paths=[p], + output_dir=str(tmp_path / "out"), + comparisons=[ComparisonSpec(cond_a="uninfected", 
cond_b="ZIKV", label="uninf vs ZIKV")], + mmd=MMDSettings(n_permutations=50), + condition_aliases={"uninfected": ["uninfected1", "uninfected2"]}, + ) + df = run_mmd_pooled(cfg) + assert not df["mmd2"].isna().all(), "Expected valid MMD after condition alias remapping" diff --git a/applications/dynaclr/tests/test_multi_experiment_integration.py b/applications/dynaclr/tests/test_multi_experiment_integration.py index 10f30005e..26d22cac1 100644 --- a/applications/dynaclr/tests/test_multi_experiment_integration.py +++ b/applications/dynaclr/tests/test_multi_experiment_integration.py @@ -14,6 +14,7 @@ from lightning.pytorch.loggers import TensorBoardLogger from dynaclr.engine import ContrastiveModule +from viscy_data.cell_index import build_timelapse_cell_index from viscy_models.contrastive.loss import NTXentHCL # --------------------------------------------------------------------------- @@ -52,11 +53,13 @@ def test_multi_experiment_fast_dev_run(tmp_path, _create_experiment, _write_coll perturbation_wells={"control": ["B/1"]}, ) yaml_path = _write_collection_yaml(tmp_path, [exp_alpha, exp_beta]) + parquet_path = tmp_path / "cell_index.parquet" + build_timelapse_cell_index(yaml_path, parquet_path, num_workers=1) from dynaclr.data.datamodule import MultiExperimentDataModule datamodule = MultiExperimentDataModule( - collection_path=str(yaml_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=(32, 32), final_yx_patch_size=(24, 24), @@ -183,11 +186,13 @@ def test_multi_experiment_fast_dev_run_with_all_sampling_axes( start_hpi=0.0, ) yaml_path = _write_collection_yaml(tmp_path, [exp_alpha, exp_beta]) + parquet_path = tmp_path / "cell_index.parquet" + build_timelapse_cell_index(yaml_path, parquet_path, num_workers=1) from dynaclr.data.datamodule import MultiExperimentDataModule datamodule = MultiExperimentDataModule( - collection_path=str(yaml_path), + cell_index_path=str(parquet_path), z_window=1, yx_patch_size=(32, 32), final_yx_patch_size=(24, 24), diff 
--git a/applications/dynaclr/tests/test_pseudotime.py b/applications/dynaclr/tests/test_pseudotime.py index d091c0e4d..54e0b6264 100644 --- a/applications/dynaclr/tests/test_pseudotime.py +++ b/applications/dynaclr/tests/test_pseudotime.py @@ -10,13 +10,18 @@ import pandas as pd import pytest -from dynaclr.evaluation.pseudotime.alignment import ( +from dynaclr.pseudotime.alignment import ( align_tracks, assign_t_perturb, filter_tracks, identify_lineages, ) -from dynaclr.evaluation.pseudotime.metrics import ( +from dynaclr.pseudotime.dtw_alignment import ( + alignment_results_to_dataframe, + build_template, + dtw_align_tracks, +) +from dynaclr.pseudotime.metrics import ( aggregate_population, compute_track_timing, find_half_max_time, @@ -24,13 +29,13 @@ find_peak_metrics, run_statistical_tests, ) -from dynaclr.evaluation.pseudotime.plotting import ( +from dynaclr.pseudotime.plotting import ( plot_cell_heatmap, plot_onset_comparison, plot_response_curves, plot_timing_distributions, ) -from dynaclr.evaluation.pseudotime.signals import ( +from dynaclr.pseudotime.signals import ( extract_annotation_signal, extract_embedding_distance, extract_prediction_signal, @@ -385,3 +390,120 @@ def test_plot_onset_comparison_saves_files(self, tmp_path): assert isinstance(fig, plt.Figure) assert (tmp_path / "onset_comparison.pdf").exists() assert (tmp_path / "onset_comparison.png").exists() + + +# ── TestTimeCalibration ─────────────────────────────────────────────── + + +class TestTimeCalibration: + """Tests for pseudotime-to-minutes template calibration.""" + + @pytest.fixture + def simple_template_inputs(self): + """Two synthetic 5-timepoint tracks with known t_relative_minutes.""" + rng = np.random.default_rng(0) + D = 8 + n_tracks = 6 + tracks = [] + for i in range(n_tracks): + # Each track: 10 frames, t_relative_minutes from -150 to +150 + fov = "C/2/000" + track_id = i + emb = rng.normal(0, 1, (10, D)).astype(np.float32) + obs = pd.DataFrame( + { + "fov_name": fov, + 
"track_id": track_id, + "t": np.arange(10), + "infection_state": ["not_infected"] * 5 + ["infected"] * 5, + "organelle_state": ["noremodel"] * 10, + "parent_track_id": -1, + } + ) + tracks.append((fov, track_id, emb, obs)) + + # Build AnnData for one "dataset" + all_obs = pd.concat([t[3] for t in tracks], ignore_index=True) + all_emb = np.vstack([t[2] for t in tracks]) + adata = ad.AnnData(X=all_emb, obs=all_obs) + + # Build aligned_df: t_perturb = 5 for all, t_relative_minutes = (t - 5) * 30 + df = all_obs.copy() + df["t_perturb"] = 5 + df["t_relative_minutes"] = (df["t"] - 5) * 30.0 + + return {"test": adata}, {"test": df} + + def test_build_template_has_time_calibration(self, simple_template_inputs): + adata_dict, aligned_df_dict = simple_template_inputs + result = build_template(adata_dict, aligned_df_dict, pca_n_components=None) + assert result.time_calibration is not None + T = result.template.shape[0] + assert result.time_calibration.shape == (T,) + # Calibration should span a reasonable real-time range + assert result.time_calibration.min() < 0 + assert result.time_calibration.max() > 0 + + def test_time_calibration_monotonically_increasing(self, simple_template_inputs): + adata_dict, aligned_df_dict = simple_template_inputs + result = build_template(adata_dict, aligned_df_dict, pca_n_components=None) + cal = result.time_calibration + # After gap interpolation, calibration should be non-decreasing + diffs = np.diff(cal) + assert np.all(diffs >= -1e-6), f"Non-monotonic calibration: {diffs}" + + def test_estimated_t_rel_in_alignment_output(self, simple_template_inputs): + adata_dict, aligned_df_dict = simple_template_inputs + template = build_template(adata_dict, aligned_df_dict, pca_n_components=None) + assert template.time_calibration is not None + + # Align one dataset against the template + adata = list(adata_dict.values())[0] + df = list(aligned_df_dict.values())[0] + results = dtw_align_tracks(adata, df, template, "test", min_track_timepoints=3) + flat 
= alignment_results_to_dataframe(results, template.template_id, time_calibration=template.time_calibration) + + assert "estimated_t_rel_minutes" in flat.columns + cal_min = template.time_calibration.min() + cal_max = template.time_calibration.max() + est = flat["estimated_t_rel_minutes"].dropna() + assert len(est) > 0 + assert est.min() >= cal_min - 1.0 + assert est.max() <= cal_max + 1.0 + + +# ── TestMetricsContinuous ───────────────────────────────────────────── + + +class TestMetricsContinuous: + """Tests for continuous-signal metrics (onset, peak).""" + + def test_find_onset_continuous_signal(self): + rows = [] + for t in range(-600, 901, 30): + val = 3.0 if t >= 120 else 0.0 + rows.append({"time_minutes": t, "mean": val, "n_cells": 20}) + pop_df = pd.DataFrame(rows) + onset, threshold, bl_mean, bl_std = find_onset_time( + pop_df, baseline_window=(-600, -60), sigma_threshold=2.0, signal_col="mean" + ) + assert onset is not None + assert onset == 120 + + def test_find_peak_metrics_continuous(self): + rows = [] + for t in range(-300, 601, 30): + if t < 0: + val = 0.0 + elif t <= 150: + val = t / 150.0 * 5.0 + elif t <= 300: + val = 5.0 - (t - 150) / 150.0 * 5.0 + else: + val = 0.0 + rows.append({"time_minutes": t, "mean": val, "n_cells": 20}) + pop_df = pd.DataFrame(rows) + metrics = find_peak_metrics(pop_df, signal_col="mean") + assert not np.isnan(metrics["T_peak_minutes"]) + assert metrics["peak_amplitude"] > 0 + assert metrics["auc"] > 0 diff --git a/applications/dynaclr/tests/test_reduce_dimensionality.py b/applications/dynaclr/tests/test_reduce_dimensionality.py index 3b291b8b7..fbfd4c56a 100644 --- a/applications/dynaclr/tests/test_reduce_dimensionality.py +++ b/applications/dynaclr/tests/test_reduce_dimensionality.py @@ -6,6 +6,8 @@ from pydantic import ValidationError from dynaclr.evaluation.dimensionality_reduction.config import ( + CombinedDatasetConfig, + CombinedDimensionalityReductionConfig, DimensionalityReductionConfig, PCAConfig, PHATEConfig, @@ 
-154,7 +156,9 @@ class TestCLIIntegration: def test_pca_end_to_end(self, synthetic_zarr, tmp_path): from click.testing import CliRunner - from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import main + from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import ( + main, + ) output_path = str(tmp_path / "output.zarr") config_content = f"input_path: {synthetic_zarr}\noutput_path: {output_path}\npca:\n n_components: 10\n" @@ -172,7 +176,9 @@ def test_pca_end_to_end(self, synthetic_zarr, tmp_path): def test_overwrite_keys_protection(self, synthetic_zarr, tmp_path): from click.testing import CliRunner - from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import main + from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import ( + main, + ) # Pre-populate X_pca adata = ad.read_zarr(synthetic_zarr) @@ -191,7 +197,9 @@ def test_overwrite_keys_protection(self, synthetic_zarr, tmp_path): def test_overwrite_keys_allowed(self, synthetic_zarr, tmp_path): from click.testing import CliRunner - from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import main + from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import ( + main, + ) # Pre-populate X_pca adata = ad.read_zarr(synthetic_zarr) @@ -212,7 +220,9 @@ def test_overwrite_keys_allowed(self, synthetic_zarr, tmp_path): def test_writes_back_to_input_when_no_output(self, synthetic_zarr, tmp_path): from click.testing import CliRunner - from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import main + from dynaclr.evaluation.dimensionality_reduction.reduce_dimensionality import ( + main, + ) config_content = f"input_path: {synthetic_zarr}\npca:\n n_components: 5\n" config_path = tmp_path / "test_config.yaml" @@ -225,3 +235,197 @@ def test_writes_back_to_input_when_no_output(self, synthetic_zarr, tmp_path): adata = ad.read_zarr(synthetic_zarr) assert "X_pca" in adata.obsm assert 
adata.obsm["X_pca"].shape == (100, 5) + + +class TestAppendToAnndataZarrUns: + """Test that append_to_anndata_zarr preserves existing uns keys.""" + + def test_uns_per_key_preserves_existing(self, tmp_path): + from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr + + rng = np.random.default_rng(42) + adata = ad.AnnData(X=rng.standard_normal((10, 4)).astype(np.float32)) + adata.uns["existing_key"] = "should_survive" + adata.uns["existing_list"] = ["a", "b"] + zarr_path = tmp_path / "test.zarr" + ad.settings.allow_write_nullable_strings = True + adata.write_zarr(zarr_path) + + append_to_anndata_zarr(zarr_path, uns={"new_key": ["path1", "path2"]}) + + result = ad.read_zarr(zarr_path) + assert result.uns["existing_key"] == "should_survive" + assert list(result.uns["existing_list"]) == ["a", "b"] + assert list(result.uns["new_key"]) == ["path1", "path2"] + + def test_uns_overwrites_specific_key(self, tmp_path): + from viscy_utils.evaluation.zarr_utils import append_to_anndata_zarr + + rng = np.random.default_rng(42) + adata = ad.AnnData(X=rng.standard_normal((10, 4)).astype(np.float32)) + adata.uns["my_key"] = "old_value" + adata.uns["other_key"] = "untouched" + zarr_path = tmp_path / "test.zarr" + ad.settings.allow_write_nullable_strings = True + adata.write_zarr(zarr_path) + + append_to_anndata_zarr(zarr_path, uns={"my_key": "new_value"}) + + result = ad.read_zarr(zarr_path) + assert result.uns["my_key"] == "new_value" + assert result.uns["other_key"] == "untouched" + + +class TestCombinedDimensionalityReductionConfig: + def test_valid_config(self, synthetic_zarr): + cfg = CombinedDimensionalityReductionConfig( + input_paths=[synthetic_zarr], + pca=PCAConfig(n_components=5), + ) + assert len(cfg.input_paths) == 1 + + def test_valid_config_with_datasets_mapping(self, synthetic_zarr): + cfg = CombinedDimensionalityReductionConfig( + datasets={"ds1": CombinedDatasetConfig(anndata=synthetic_zarr)}, + pca=PCAConfig(n_components=5), + ) + assert 
cfg.input_paths == [synthetic_zarr] + + def test_missing_methods_raises(self, synthetic_zarr): + with pytest.raises(ValidationError, match="At least one reduction method"): + CombinedDimensionalityReductionConfig(input_paths=[synthetic_zarr]) + + def test_missing_path_raises(self): + with pytest.raises(ValidationError, match="Input path not found"): + CombinedDimensionalityReductionConfig( + input_paths=["/nonexistent/path.zarr"], + pca=PCAConfig(), + ) + + +class TestCombinedReduction: + @pytest.fixture + def two_synthetic_zarrs(self, tmp_path): + """Create two synthetic AnnData zarrs with uns metadata.""" + ad.settings.allow_write_nullable_strings = True + rng = np.random.default_rng(42) + paths = [] + for i in range(2): + n = 50 + i * 30 # 50 and 80 samples + X = rng.standard_normal((n, 32)).astype(np.float32) + adata = ad.AnnData(X=X) + adata.uns["classifier_version"] = f"v{i}" + adata.uns["predicted_classes"] = ["alive", "dead"] + zarr_path = tmp_path / f"store_{i}.zarr" + adata.write_zarr(zarr_path) + paths.append(str(zarr_path)) + return paths + + def test_combined_pca_only(self, two_synthetic_zarrs): + from click.testing import CliRunner + + from dynaclr.evaluation.dimensionality_reduction.reduce_combined import main + + config_content = ( + f"input_paths:\n - {two_synthetic_zarrs[0]}\n - {two_synthetic_zarrs[1]}\npca:\n n_components: 5\n" + ) + runner = CliRunner() + with runner.isolated_filesystem(): + config_path = "combined.yaml" + with open(config_path, "w") as f: + f.write(config_content) + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code == 0, result.output + + for i, path in enumerate(two_synthetic_zarrs): + adata = ad.read_zarr(path) + n = 50 + i * 30 + assert "X_pca_combined" in adata.obsm + assert adata.obsm["X_pca_combined"].shape[0] == n + assert "pca_combined_datasets" in adata.uns + assert list(adata.uns["pca_combined_datasets"]) == two_synthetic_zarrs + # uns preserved + assert adata.uns["classifier_version"] == 
f"v{i}" + assert list(adata.uns["predicted_classes"]) == ["alive", "dead"] + + @pytest.fixture(autouse=False) + def _skip_no_phate(self): + pytest.importorskip("phate") + + def test_combined_pca_and_phate(self, two_synthetic_zarrs, _skip_no_phate): + from click.testing import CliRunner + + from dynaclr.evaluation.dimensionality_reduction.reduce_combined import main + + config_content = ( + "input_paths:\n" + f" - {two_synthetic_zarrs[0]}\n" + f" - {two_synthetic_zarrs[1]}\n" + "pca:\n" + " n_components: 5\n" + "phate:\n" + " n_components: 2\n" + " knn: 5\n" + " decay: 40\n" + ) + runner = CliRunner() + with runner.isolated_filesystem(): + config_path = "combined.yaml" + with open(config_path, "w") as f: + f.write(config_content) + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code == 0, result.output + + for i, path in enumerate(two_synthetic_zarrs): + adata = ad.read_zarr(path) + n = 50 + i * 30 + assert adata.obsm["X_pca_combined"].shape[0] == n + assert adata.obsm["X_phate_combined"].shape == (n, 2) + assert "pca_combined_datasets" in adata.uns + assert "phate_combined_datasets" in adata.uns + + def test_overwrite_protection(self, two_synthetic_zarrs): + from click.testing import CliRunner + + from dynaclr.evaluation.dimensionality_reduction.reduce_combined import main + + config_content = ( + f"input_paths:\n - {two_synthetic_zarrs[0]}\n - {two_synthetic_zarrs[1]}\npca:\n n_components: 5\n" + ) + runner = CliRunner() + with runner.isolated_filesystem(): + config_path = "combined.yaml" + with open(config_path, "w") as f: + f.write(config_content) + # First run + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code == 0, result.output + # Second run without overwrite_keys should fail + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code != 0 + assert "already exists" in result.output + + def test_overwrite_allowed(self, two_synthetic_zarrs): + from click.testing import CliRunner + + from 
dynaclr.evaluation.dimensionality_reduction.reduce_combined import main + + config_content = ( + "input_paths:\n" + f" - {two_synthetic_zarrs[0]}\n" + f" - {two_synthetic_zarrs[1]}\n" + "overwrite_keys: true\n" + "pca:\n" + " n_components: 5\n" + ) + runner = CliRunner() + with runner.isolated_filesystem(): + config_path = "combined.yaml" + with open(config_path, "w") as f: + f.write(config_content) + # First run + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code == 0, result.output + # Second run should also succeed (overwrite_keys=true) + result = runner.invoke(main, ["-c", config_path]) + assert result.exit_code == 0, result.output diff --git a/applications/dynaclr/tests/test_valid_anchors_marker.py b/applications/dynaclr/tests/test_valid_anchors_marker.py new file mode 100644 index 000000000..87c323fd6 --- /dev/null +++ b/applications/dynaclr/tests/test_valid_anchors_marker.py @@ -0,0 +1,212 @@ +"""Regression tests for marker-aware valid_anchors in flat-parquet mode. + +In flat-parquet / bag-of-channels mode, one cell observation becomes one +row per channel. ``_pick_temporal_candidate`` restricts positive candidates +to rows with the same ``marker`` as the anchor, so ``_compute_valid_anchors`` +must also include ``marker`` in the validity key — otherwise an anchor can +pass validation because a different-marker row exists at ``t+tau``, then +crash at sample time with "No positive found". + +These tests hit ``_compute_valid_anchors`` directly via ``object.__new__`` +so they don't need real zarr stores. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pandas as pd +import pytest + +from dynaclr.data.index import MultiExperimentIndex + + +def _make_registry(experiment_names, interval_minutes=30.0): + """Return a minimal object that quacks like ExperimentRegistry for tau math.""" + experiments = [SimpleNamespace(name=n, interval_minutes=interval_minutes) for n in experiment_names] + + def tau_range_frames(name, tau_range_hours): + exp = next(e for e in experiments if e.name == name) + min_h, max_h = tau_range_hours + frames_per_hour = 60.0 / exp.interval_minutes + return (int(round(min_h * frames_per_hour)), int(round(max_h * frames_per_hour))) + + return SimpleNamespace(experiments=experiments, tau_range_frames=tau_range_frames) + + +def _make_index(tracks: pd.DataFrame, registry) -> MultiExperimentIndex: + """Construct a bare MultiExperimentIndex without zarr I/O.""" + index = object.__new__(MultiExperimentIndex) + index.registry = registry + index.tracks = tracks.reset_index(drop=True) + return index + + +class TestMarkerAwareValidAnchors: + """`marker` must be part of the temporal validity key in flat-parquet mode.""" + + def test_anchor_with_cross_marker_positive_rejected(self): + """ + Anchor at (lid, marker=A, t=5) must be REJECTED when the only row + at t+tau is (lid, marker=B, t=6). Without marker-aware validity + this anchor would be accepted and then crash at sample time because + `_pick_temporal_candidate` filters candidates to same marker. + """ + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 2, + "lineage_id": ["L"] * 2, + "marker": ["A", "B"], + "t": [5, 6], + } + ) + registry = _make_registry(["exp"], interval_minutes=30.0) + index = _make_index(tracks, registry) + + # tau_range 0.5h - 1.5h at 30min = (1, 3) frames. 
+ valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + # Neither row is a valid anchor: A has no same-marker positive in window, + # and B has no same-marker positive either. + assert len(valid) == 0, f"expected 0 valid anchors, got {len(valid)}:\n{valid}" + + def test_anchor_with_same_marker_positive_accepted(self): + """Anchor at (lid, marker=A, t=5) with (lid, marker=A, t=6) IS valid.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 3, + "lineage_id": ["L"] * 3, + "marker": ["A", "A", "B"], + "t": [5, 6, 6], + } + ) + registry = _make_registry(["exp"], interval_minutes=30.0) + index = _make_index(tracks, registry) + + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + # (A, t=5) is valid because (A, t=6) exists. + # (A, t=6) is NOT valid because there's no (A, t=7..8). + # (B, t=6) is NOT valid because there's no (B, t=7..8). + assert len(valid) == 1 + row = valid.iloc[0] + assert row["marker"] == "A" + assert row["t"] == 5 + + def test_both_markers_have_positives_both_accepted(self): + """When each marker has its own lineage continuity, both pass.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 4, + "lineage_id": ["L"] * 4, + "marker": ["A", "A", "B", "B"], + "t": [5, 6, 5, 6], + } + ) + registry = _make_registry(["exp"], interval_minutes=30.0) + index = _make_index(tracks, registry) + + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + # (A, t=5) valid (A, t=6 exists). (B, t=5) valid (B, t=6 exists). + # t=6 of each marker is NOT valid (no t=7 for either). 
+ assert len(valid) == 2 + assert set(zip(valid["marker"], valid["t"])) == {("A", 5), ("B", 5)} + + def test_no_marker_column_falls_back_to_lineage_t(self): + """When `marker` column is absent, behavior matches legacy (lid, t) keys.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 3, + "lineage_id": ["L"] * 3, + "t": [5, 6, 7], + } + ) + registry = _make_registry(["exp"], interval_minutes=30.0) + index = _make_index(tracks, registry) + + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + # tau_range_frames = (1, 3). t=5 needs t=6,7,8 (6,7 exist) -> valid. + # t=6 needs t=7,8,9 (7 exists) -> valid. t=7 needs t=8,9,10 -> NOT valid. + assert len(valid) == 2 + assert set(valid["t"].to_numpy()) == {5, 6} + + +class TestLineageCollisionDetection: + """ + Regression for the ALFI-style bug where two FOVs share the same + ``lineage_id`` because lineage reconstruction collapsed across FOVs. + The marker-aware fix cannot save this — it's a data bug — so the + test documents the failure mode: `_compute_valid_anchors` will + accept anchors whose temporal neighbors are actually in a different + physical FOV. Cached so we notice if lineage reconstruction ever + starts disambiguating by FOV. + """ + + def test_cross_fov_lineage_collision_accepted_today(self): + """Two FOVs share `lineage_id='L'`; validity check treats as one lineage.""" + # FOV1 has t=5 only; FOV2 has t=6 only. They share lineage_id. 
+ tracks = pd.DataFrame( + { + "experiment": ["exp"] * 2, + "lineage_id": ["L", "L"], + "fov_name": ["FOV1", "FOV2"], # different physical fields + "marker": ["A", "A"], + "t": [5, 6], + } + ) + registry = _make_registry(["exp"], interval_minutes=30.0) + index = _make_index(tracks, registry) + + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + # Today both rows pass — the fix doesn't consider fov_name in the + # validity key. If cell_index generation ever disambiguates lineage_id + # by fov, this test will flip and should be updated. + assert len(valid) == 1 # (A, t=5) valid because "L" at t=6 exists + # The surviving anchor is t=5 — at sample time it would try to + # pull a patch from FOV2 thinking it's the same biological lineage. + # That's still wrong biologically, but it won't raise "No positive found". + + +@pytest.mark.parametrize("interval_minutes", [15.0, 30.0, 60.0]) +def test_marker_key_respects_per_experiment_tau(interval_minutes): + """Marker-aware validity plays correctly with per-experiment interval_minutes.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 4, + "lineage_id": ["L"] * 4, + "marker": ["A", "A", "A", "A"], + "t": [0, 1, 5, 10], + } + ) + registry = _make_registry(["exp"], interval_minutes=interval_minutes) + index = _make_index(tracks, registry) + + valid = index._compute_valid_anchors( + tau_range_hours=(0.5, 1.5), + positive_cell_source="lookup", + positive_match_columns=["lineage_id"], + ) + min_f, max_f = registry.tau_range_frames("exp", (0.5, 1.5)) + # Every valid anchor t must have some other row at t+tau within [min_f, max_f]. 
+ t_vals = set(tracks["t"].to_numpy()) + for t in valid["t"].to_numpy(): + ok = any((t + tau) in t_vals for tau in range(min_f, max_f + 1) if tau != 0) + assert ok, f"anchor t={t} validated but no t+tau neighbor exists at interval={interval_minutes}" diff --git a/packages/viscy-data/src/viscy_data/_typing.py b/packages/viscy-data/src/viscy_data/_typing.py index 17eb7ed1a..0e6baf6cc 100644 --- a/packages/viscy-data/src/viscy_data/_typing.py +++ b/packages/viscy-data/src/viscy_data/_typing.py @@ -24,6 +24,7 @@ "CELL_INDEX_CORE_COLUMNS", "CELL_INDEX_GROUPING_COLUMNS", "CELL_INDEX_IMAGING_COLUMNS", + "CELL_INDEX_NORMALIZATION_COLUMNS", "CELL_INDEX_OPS_COLUMNS", "CELL_INDEX_TIMELAPSE_COLUMNS", "CellIndex", @@ -245,7 +246,25 @@ class TripletSample(TypedDict): CELL_INDEX_OPS_COLUMNS = ["gene_name", "reporter", "sgRNA"] -CELL_INDEX_IMAGING_COLUMNS = ["pixel_size_xy_um", "pixel_size_z_um"] +CELL_INDEX_IMAGING_COLUMNS = [ + "pixel_size_xy_um", + "pixel_size_z_um", + "T_shape", + "C_shape", + "Z_shape", + "Y_shape", + "X_shape", + "z_focus_mean", +] + +CELL_INDEX_NORMALIZATION_COLUMNS = [ + "norm_mean", + "norm_std", + "norm_median", + "norm_iqr", + "norm_max", + "norm_min", +] # Extracted from viscy/data/triplet.py for shared access ULTRACK_INDEX_COLUMNS = [ diff --git a/packages/viscy-data/src/viscy_data/_utils.py b/packages/viscy-data/src/viscy_data/_utils.py index ea0e96ef0..e6a6523c7 100644 --- a/packages/viscy-data/src/viscy_data/_utils.py +++ b/packages/viscy-data/src/viscy_data/_utils.py @@ -217,4 +217,7 @@ def _transform_channel_wise( ) -> list[Tensor]: scattered_channels = _scatter_channels(channel_names, patch, norm_meta, extra) transformed_channels = transform(scattered_channels) - return _gather_channels(transformed_channels) + extra_keys = ("norm_meta",) + if extra is not None: + extra_keys = ("norm_meta",) + tuple(extra.keys()) + return _gather_channels(transformed_channels, extra_keys=extra_keys) diff --git a/packages/viscy-data/src/viscy_data/cell_index.py 
b/packages/viscy-data/src/viscy_data/cell_index.py index ca03e8167..09a72f64f 100644 --- a/packages/viscy-data/src/viscy_data/cell_index.py +++ b/packages/viscy-data/src/viscy_data/cell_index.py @@ -15,6 +15,7 @@ from concurrent.futures import ProcessPoolExecutor, as_completed from pathlib import Path +import numpy as np import pandas as pd import pyarrow as pa import pyarrow.parquet as pq @@ -26,6 +27,7 @@ CELL_INDEX_CORE_COLUMNS, CELL_INDEX_GROUPING_COLUMNS, CELL_INDEX_IMAGING_COLUMNS, + CELL_INDEX_NORMALIZATION_COLUMNS, CELL_INDEX_OPS_COLUMNS, CELL_INDEX_TIMELAPSE_COLUMNS, ) @@ -37,6 +39,7 @@ "build_ops_cell_index", "build_timelapse_cell_index", "convert_ops_parquet", + "preprocess_cell_index", "read_cell_index", "validate_cell_index", "write_cell_index", @@ -74,6 +77,18 @@ ("organelle", pa.string()), ("pixel_size_xy_um", pa.float32()), ("pixel_size_z_um", pa.float32()), + ("T_shape", pa.int32()), + ("C_shape", pa.int32()), + ("Z_shape", pa.int32()), + ("Y_shape", pa.int32()), + ("X_shape", pa.int32()), + ("z_focus_mean", pa.float32()), + ("norm_mean", pa.float32()), + ("norm_std", pa.float32()), + ("norm_median", pa.float32()), + ("norm_iqr", pa.float32()), + ("norm_max", pa.float32()), + ("norm_min", pa.float32()), ] ) @@ -85,6 +100,7 @@ + CELL_INDEX_TIMELAPSE_COLUMNS + CELL_INDEX_OPS_COLUMNS + CELL_INDEX_IMAGING_COLUMNS + + CELL_INDEX_NORMALIZATION_COLUMNS ) # --------------------------------------------------------------------------- @@ -168,6 +184,13 @@ def write_cell_index( def read_cell_index(path: str | Path) -> pd.DataFrame: """Read a cell index parquet into a pandas DataFrame. + String columns are materialized as NumPy ``object`` arrays instead of + ``ArrowStringArray``. ArrowStringArray-backed columns route every + boolean mask slice through ``pyarrow.compute.take``, which allocates + a fresh buffer per string column and can spike peak RSS by 50+ GiB + on 80M-row indices during train/val FOV partitioning. 
NumPy object + columns make ``df[mask]`` a cheap gather. + Parameters ---------- path : str | Path @@ -179,7 +202,155 @@ def read_cell_index(path: str | Path) -> pd.DataFrame: Cell index with correct dtypes. """ table = pq.read_table(str(path), schema=CELL_INDEX_SCHEMA) - return table.to_pandas() + df = table.to_pandas(use_threads=True) + # ArrowStringArray columns with low cardinality (experiment, fov_name, + # marker, store_path, well, microscope, organelle, reporter) become + # Categorical to make ``df[mask]`` a fast int-code gather. Other string + # columns (cell_id, tracks_path, global_track_id, lineage_id, etc.) are + # high cardinality and are already read via the NumPy column cache in + # the dataset, so leave them as ArrowStringArray to avoid allocating + # millions of Python string objects here. + # NB: ``fov`` and ``well`` are NOT cast here because ``_align_parquet_columns`` + # downstream rewrites ``fov_name`` via string concatenation, which pandas + # does not support on Categorical. We cast ``fov_name`` later, after the + # prefix rewrite, in the runtime index layer. + _categorical_cols = ( + "experiment", + "marker", + "store_path", + "microscope", + "organelle", + "reporter", + "channel_name", + ) + for col in _categorical_cols: + if col in df.columns: + df[col] = df[col].astype("category") + return df + + +# --------------------------------------------------------------------------- +# Preprocessing (clean up an existing cell index parquet) +# --------------------------------------------------------------------------- + + +def preprocess_cell_index( + parquet_path: str | Path, + output_path: str | Path | None = None, + focus_channel: str | None = None, +) -> None: + """Add normalization stats, focus slice, and remove invalid rows. 
+ + Reads precomputed metadata from each FOV's ``zattrs`` (written by + ``viscy preprocess``) and writes them as parquet columns: + + - ``norm_mean``, ``norm_std``, ``norm_median``, ``norm_iqr``, + ``norm_max``, ``norm_min`` — per-timepoint, per-channel statistics + - ``z_focus_mean`` — per-FOV focus plane from ``focus_slice`` + + Drops rows where timepoint stats are missing or ``norm_max == 0.0`` + (empty frames). The processed parquet is written to ``output_path``; + the function returns nothing — callers that need the dataframe should + ``read_cell_index`` afterwards. + + Parameters + ---------- + parquet_path : str | Path + Path to the cell index parquet to preprocess. + output_path : str | Path | None + Destination path. When ``None``, overwrites *parquet_path* in place. + focus_channel : str | None + Channel name for ``focus_slice`` lookup (e.g. ``"Phase3D"``). + When ``None``, uses the first channel_name in each FOV's group. + + Raises + ------ + ValueError + If a FOV has no normalization metadata (run ``viscy preprocess`` first). + """ + if output_path is None: + output_path = parquet_path + + df = read_cell_index(parquet_path) + n_before = len(df) + + fov_col = "fov" if "fov" in df.columns else "fov_name" + + # Build lookups from zarr zattrs (one open per unique FOV) + stat_lookup: dict[tuple[str, str, str, int], dict[str, float]] = {} + focus_lookup: dict[tuple[str, str], float] = {} + focus_per_t_lookup: dict[tuple[str, str], dict[int, int]] = {} + + for (store_path, fov), group in df.groupby(["store_path", fov_col]): + fov_path = f"{group['well'].iloc[0]}/{fov}" if "/" not in str(fov) else str(fov) + with open_ome_zarr(f"{store_path}/{fov_path}", mode="r") as pos: + norm_meta = pos.zattrs.get("normalization", None) + focus_meta = pos.zattrs.get("focus_slice", {}) + if norm_meta is None: + raise ValueError( + f"FOV '{fov_path}' in store '{store_path}' has no normalization metadata. " + "Run `viscy preprocess` on this dataset first." 
+ ) + for ch_name, ch_stats in norm_meta.items(): + for t_str, tp_stats in ch_stats.get("timepoint_statistics", {}).items(): + stat_lookup[(str(store_path), str(fov), ch_name, int(t_str))] = tp_stats + + fc = focus_channel or group["channel_name"].iloc[0] + ch_focus = focus_meta.get(fc, {}) + fov_stats = ch_focus.get("fov_statistics", {}) + z_focus = fov_stats.get("z_focus_mean") + if z_focus is not None: + focus_lookup[(str(store_path), str(fov))] = float(z_focus) + per_timepoint = ch_focus.get("per_timepoint", {}) + if per_timepoint: + focus_per_t_lookup[(str(store_path), str(fov))] = { + int(t_str): int(z_idx) for t_str, z_idx in per_timepoint.items() + } + + # Vectorized lookup: build norm + focus column arrays + stat_keys = ["mean", "std", "median", "iqr", "max", "min"] + store_arr = df["store_path"].astype(str).to_numpy() + fov_arr = df[fov_col].astype(str).to_numpy() + ch_arr = df["channel_name"].astype(str).to_numpy() + t_arr = df["t"].astype(int).to_numpy() + + norm_arrays = {stat: np.full(len(df), float("nan"), dtype=np.float32) for stat in stat_keys} + focus_arr = np.full(len(df), float("nan"), dtype=np.float32) + z_arr = df["z"].to_numpy(dtype=np.int16).copy() + valid_mask = np.ones(len(df), dtype=bool) + + for i in range(len(df)): + tp_stats = stat_lookup.get((store_arr[i], fov_arr[i], ch_arr[i], t_arr[i])) + if tp_stats is None or tp_stats.get("max", 1.0) == 0.0: + valid_mask[i] = False + continue + for stat in stat_keys: + norm_arrays[stat][i] = float(tp_stats[stat]) + fov_key = (store_arr[i], fov_arr[i]) + z_focus = focus_lookup.get(fov_key) + if z_focus is not None: + focus_arr[i] = z_focus + z_t = focus_per_t_lookup.get(fov_key, {}).get(t_arr[i]) + if z_t is not None: + z_arr[i] = z_t + + for stat in stat_keys: + df[f"norm_{stat}"] = norm_arrays[stat] + df["z_focus_mean"] = focus_arr + df["z"] = z_arr + + df = df[valid_mask].reset_index(drop=True) + n_dropped = n_before - len(df) + + write_cell_index(df, output_path) + if n_dropped > 0: + 
_logger.info("Dropped %d invalid rows (%.1f%%).", n_dropped, 100 * n_dropped / n_before) + _logger.info( + "Wrote %d rows to %s (dropped %d, added norm + focus columns)", + len(df), + output_path, + n_dropped, + ) # --------------------------------------------------------------------------- @@ -194,11 +365,17 @@ def _reconstruct_lineage(tracks: pd.DataFrame) -> pd.DataFrame: ancestor. Tracks without a ``parent_track_id`` (or whose parent is not present in the data) are their own root. + The lineage walk is scoped per ``(experiment, well, fov)`` when the + ``well`` column is available. Scoping on ``(experiment, fov)`` alone + collapses cells across wells that share an FOV number (e.g. B/2/002001 + and C/2/002001), producing cross-well lineage_id aliasing that later + crashes the temporal positive lookup with "No positive found". + Parameters ---------- tracks : pd.DataFrame Must contain ``global_track_id``, ``experiment``, ``fov``, ``track_id``. - Optionally ``parent_track_id``. + Optionally ``parent_track_id`` and ``well``. 
Returns ------- @@ -216,8 +393,9 @@ def _reconstruct_lineage(tracks: pd.DataFrame) -> pd.DataFrame: lineage_series = tracks["lineage_id"].copy() - groups = list(tracks.groupby(["experiment", "fov"])) - for (exp, fov), group in tqdm(groups, desc="Reconstructing lineages", unit="fov"): + group_keys = ["experiment", "well", "fov"] if "well" in tracks.columns else ["experiment", "fov"] + groups = list(tracks.groupby(group_keys)) + for _key, group in tqdm(groups, desc="Reconstructing lineages", unit="fov"): tid_to_gtid: dict[int, str] = dict(zip(group["track_id"], group["global_track_id"])) parent_map: dict[str, str] = {} @@ -274,8 +452,8 @@ def _build_experiment_tracks( if exclude_fovs is not None: all_exclude.update(exclude_fovs) - # Channel-marker pairs from per-experiment channels list - channel_marker_pairs = [(ch.name, ch.marker) for ch in exp.channels] + # Channel entries from per-experiment channels list + channel_entries = [(ch.name, ch.marker, set(ch.wells)) for ch in exp.channels] exp_tracks: list[pd.DataFrame] = [] @@ -305,6 +483,10 @@ def _build_experiment_tracks( raise ValueError(f"Expected exactly one tracking CSV in {tracks_dir}, found: {csv_files}") tracks_df = pd.read_csv(csv_files[0]) + # TCZYX shape from zarr metadata (same for all positions in a well) + img_arr = position["0"] + t_shape, c_shape, z_shape, y_shape, x_shape = img_arr.shape + # Base columns (shared across channel rows) tracks_df["cell_id"] = ( exp.name + "_" + fov_path + "_" + tracks_df["track_id"].astype(str) + "_" + tracks_df["t"].astype(str) @@ -322,12 +504,19 @@ def _build_experiment_tracks( tracks_df["organelle"] = exp.organelle tracks_df["pixel_size_xy_um"] = exp.pixel_size_xy_um tracks_df["pixel_size_z_um"] = exp.pixel_size_z_um + tracks_df["T_shape"] = t_shape + tracks_df["C_shape"] = c_shape + tracks_df["Z_shape"] = z_shape + tracks_df["Y_shape"] = y_shape + tracks_df["X_shape"] = x_shape if "z" not in tracks_df.columns: tracks_df["z"] = 0 - # Explode: one row per channel - 
for zarr_ch, marker in channel_marker_pairs: + # Explode: one row per channel (skip channels restricted to other wells) + for zarr_ch, marker, valid_wells in channel_entries: + if valid_wells and well_name not in valid_wells: + continue ch_df = tracks_df.copy() ch_df["channel_name"] = zarr_ch ch_df["marker"] = marker @@ -370,7 +559,7 @@ def build_timelapse_cell_index( experiments = collection.experiments n_workers = os.cpu_count() if num_workers == -1 else num_workers - print(f"Building cell index: {len(experiments)} experiments, {n_workers} workers") + _logger.info("Building cell index: %d experiments, %s workers", len(experiments), n_workers) all_tracks: list[pd.DataFrame] = [] @@ -379,7 +568,7 @@ def build_timelapse_cell_index( df = _build_experiment_tracks(exp, include_wells, exclude_fovs) if not df.empty: all_tracks.append(df) - print(f" {exp.name}: {len(df):,} rows") + _logger.info(" %s: %d rows", exp.name, len(df)) else: futures = {} with ProcessPoolExecutor(max_workers=n_workers) as executor: @@ -398,7 +587,7 @@ def build_timelapse_cell_index( df = future.result() if not df.empty: all_tracks.append(df) - print(f" {exp_name}: {len(df):,} rows") + _logger.info(" %s: %d rows", exp_name, len(df)) pbar.update(1) if not all_tracks: @@ -411,7 +600,7 @@ def build_timelapse_cell_index( df[col] = None write_cell_index(df, output_path) - print(f"Wrote {len(df):,} rows to {output_path}") + _logger.info("Wrote %d rows to %s", len(df), output_path) return df diff --git a/packages/viscy-data/src/viscy_data/channel_utils.py b/packages/viscy-data/src/viscy_data/channel_utils.py index 9f7dc3753..63fcc9c16 100644 --- a/packages/viscy-data/src/viscy_data/channel_utils.py +++ b/packages/viscy-data/src/viscy_data/channel_utils.py @@ -50,7 +50,7 @@ def parse_channel_name(name: str) -> dict: # Label-free patterns (use word boundaries for short keywords) labelfree_substrings = ("phase", "brightfield", "retardance") - labelfree_word_patterns = (r"\bbf[\b_]", r"\bdic\b", r"\bpol\b") 
+ labelfree_word_patterns = (r"\bbf(\b|_)", r"\bdic\b", r"\bpol\b", r"\bphc\b") if any(kw in name_lower for kw in labelfree_substrings) or any( re.search(p, name_lower) for p in labelfree_word_patterns ): diff --git a/packages/viscy-data/src/viscy_data/collection.py b/packages/viscy-data/src/viscy_data/collection.py index dd4be9dcb..a28656e7c 100644 --- a/packages/viscy-data/src/viscy_data/collection.py +++ b/packages/viscy-data/src/viscy_data/collection.py @@ -58,10 +58,14 @@ class ChannelEntry(BaseModel): Zarr channel name (e.g. ``"Phase3D"``, ``"raw GFP EX488 EM525-45"``). marker : str Protein marker or channel identity (e.g. ``"Phase3D"``, ``"TOMM20"``). + wells : list[str] + Wells where this channel is biologically valid (e.g. ``["B/3", "C/2"]``). + Empty list means the channel is valid in all wells of the experiment. """ name: str marker: str + wells: list[str] = [] class ExperimentEntry(BaseModel): @@ -144,6 +148,10 @@ class Collection(BaseModel): Collection name. description : str Human-readable description. + datasets_root : str or None + Optional path prefix substituted for ``${datasets_root}`` in + ``data_path`` and ``tracks_path`` at load time. Paths not + starting with this root are left unchanged. provenance : Provenance How the collection was created. experiments : list[ExperimentEntry] @@ -154,6 +162,7 @@ class Collection(BaseModel): name: str description: str = "" + datasets_root: str | None = None provenance: Provenance = Provenance() experiments: list[ExperimentEntry] fov_records: list[FOVRecord] = [] @@ -171,9 +180,9 @@ def _validate_collection(self) -> Collection: seen.add(e.name) for exp in self.experiments: - if exp.interval_minutes <= 0: + if exp.interval_minutes < 0: raise ValueError( - f"Experiment '{exp.name}': interval_minutes must be positive, got {exp.interval_minutes}." + f"Experiment '{exp.name}': interval_minutes must be non-negative, got {exp.interval_minutes}." 
) wells = exp.perturbation_wells if not wells: @@ -182,6 +191,39 @@ def _validate_collection(self) -> Collection: return self +_DATASETS_ROOT_VAR = "${datasets_root}" + + +def _resolve_datasets_root(data: dict) -> None: + """Replace ``${datasets_root}`` in experiment paths with the root value. + + Mutates *data* in place. + """ + root = data.get("datasets_root") + if not root: + return + root = root.rstrip("/") + for exp in data.get("experiments", []): + for key in ("data_path", "tracks_path"): + val = exp.get(key, "") + if _DATASETS_ROOT_VAR in val: + exp[key] = val.replace(_DATASETS_ROOT_VAR, root) + + +def _unresolve_datasets_root(data: dict, datasets_root: str) -> None: + """Replace the resolved root prefix with ``${datasets_root}`` for portable YAML. + + Mutates *data* in place. Only paths that start with *datasets_root* are + modified; paths pointing elsewhere are left as absolute strings. + """ + root = datasets_root.rstrip("/") + for exp in data.get("experiments", []): + for key in ("data_path", "tracks_path"): + val = exp.get(key, "") + if val.startswith(root + "/"): + exp[key] = _DATASETS_ROOT_VAR + val[len(root) :] + + def load_collection(path: str | Path) -> Collection: """Load a collection from a YAML file. @@ -197,6 +239,7 @@ def load_collection(path: str | Path) -> Collection: """ with open(Path(path)) as f: data = yaml.safe_load(f) + _resolve_datasets_root(data) return Collection(**data) @@ -211,6 +254,8 @@ def save_collection(collection: Collection, path: str | Path) -> None: Output YAML path. 
""" data = collection.model_dump(mode="json") + if collection.datasets_root: + _unresolve_datasets_root(data, collection.datasets_root) with open(Path(path), "w") as f: yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False) @@ -257,6 +302,7 @@ def build_collection( name: str, description: str = "", channel_markers: dict[str, list[tuple[str, str]]] | None = None, + datasets_root: str | None = None, ) -> Collection: """Build a collection by grouping FOVRecords into experiments. @@ -277,6 +323,9 @@ def build_collection( Per-experiment ``{exp_name: [(zarr_channel_name, marker), ...]}`` mapping. If None, derives from the first record's ``channel_names`` using channel names as markers. + datasets_root : str or None + Passed through to :class:`Collection`. When set, ``save_collection`` + will write ``${datasets_root}`` prefixes instead of absolute paths. Returns ------- @@ -305,6 +354,15 @@ def build_collection( elif first.channel_names: channels = [ChannelEntry(name=n, marker=n) for n in first.channel_names] + # Auto-populate wells per channel from per-record channel_markers. + # A channel gets a wells restriction if only a subset of wells have + # a non-None marker for it in Airtable. 
+ all_wells = sorted({rec.well_id for rec in recs}) + for ch in channels: + wells_with_marker = sorted({rec.well_id for rec in recs if ch.name in rec.channel_markers}) + if wells_with_marker and wells_with_marker != all_wells: + ch.wells = wells_with_marker + experiments.append( ExperimentEntry( name=exp_name, @@ -316,6 +374,7 @@ def build_collection( start_hpi=first.hours_post_perturbation or 0.0, marker=first.marker or "", organelle=first.organelle or "", + microscope=first.microscope or "", pixel_size_xy_um=getattr(first, "pixel_size_xy_um", None), pixel_size_z_um=getattr(first, "pixel_size_z_um", None), moi=first.moi or 0.0, @@ -325,6 +384,7 @@ def build_collection( return Collection( name=name, description=description, + datasets_root=datasets_root, experiments=experiments, fov_records=records, ) diff --git a/packages/viscy-data/src/viscy_data/sampler.py b/packages/viscy-data/src/viscy_data/sampler.py index 75017b85d..68534b059 100644 --- a/packages/viscy-data/src/viscy_data/sampler.py +++ b/packages/viscy-data/src/viscy_data/sampler.py @@ -153,14 +153,35 @@ def __init__( # Precomputation # ------------------------------------------------------------------ + @staticmethod + def _indices_by_key(keys: pd.Series) -> dict[str, np.ndarray]: + """Return ``{key_str: row_index_array}`` for every unique value in *keys*. + + Fast path for Categorical keys uses NumPy ``cat.codes`` directly — + avoids materializing a pandas groupby iterator, which on large + (~16M row) Arrow-backed DataFrames routes every group slice + through ``pyarrow.compute.take`` and can take tens of minutes. + + For non-Categorical keys, falls back to the pandas groupby. + """ + # Categorical fast path — O(N) single vectorized pass per group. 
+ if isinstance(keys.dtype, pd.CategoricalDtype): + codes = keys.cat.codes.to_numpy() + categories = list(keys.cat.categories) + out: dict[str, np.ndarray] = {} + for c, name in enumerate(categories): + rows = np.flatnonzero(codes == c) + if len(rows) > 0: + out[str(name)] = rows + return out + # Generic fallback. + return {str(name): group.to_numpy() for name, group in keys.groupby(keys).groups.items()} + def _precompute_groups(self) -> None: """Build index lookup tables from valid_anchors columns.""" - # Per-group indices if self.batch_group_by is not None: group_keys = self._compute_strat_keys(self.valid_anchors, self.batch_group_by) - self._group_indices: dict[str, np.ndarray] = { - str(name): group.index.to_numpy() for name, group in self.valid_anchors.groupby(group_keys) - } + self._group_indices: dict[str, np.ndarray] = self._indices_by_key(group_keys) self._group_names: list[str] = list(self._group_indices.keys()) else: self._group_indices = {} @@ -174,16 +195,19 @@ def _precompute_groups(self) -> None: if self.stratify_by is not None: strat_keys = self._compute_strat_keys(self.valid_anchors, self.stratify_by) - # Global stratification indices - for key in strat_keys.unique(): - self._strat_indices[key] = self.valid_anchors.index[strat_keys == key].to_numpy() + # Global stratification indices — NumPy fast path for Categorical. + self._strat_indices = self._indices_by_key(strat_keys) self._strat_names = list(self._strat_indices.keys()) - # Per-group stratification indices + # Per-group × per-stratum indices. Using np.intersect1d between + # pre-built group and strat index arrays stays NumPy-native + # instead of reinvoking pandas groupby on the full 16M-row frame. 
if self.batch_group_by is not None: - group_keys = self._compute_strat_keys(self.valid_anchors, self.batch_group_by) - for (grp, strat_key), group in self.valid_anchors.groupby([group_keys, strat_keys]): - self._group_strat_indices[(str(grp), str(strat_key))] = group.index.to_numpy() + for grp, g_idx in self._group_indices.items(): + for strat_key, s_idx in self._strat_indices.items(): + common = np.intersect1d(g_idx, s_idx, assume_unique=True) + if len(common) > 0: + self._group_strat_indices[(grp, strat_key)] = common # All indices self._all_indices = np.arange(len(self.valid_anchors)) @@ -212,23 +236,18 @@ def _precompute_groups(self) -> None: @staticmethod def _compute_strat_keys(df: pd.DataFrame, columns: list[str]) -> pd.Series: - """Compute a single string key per row for grouping. + """Compute a single key per row for grouping. - Parameters - ---------- - df : pd.DataFrame - DataFrame to compute keys for. - columns : list[str] - Column names to combine into group keys. + For a single column, returns the raw Series — pandas ``groupby`` + handles Categorical / string / numeric dtypes directly, and + ``df.col.astype(str)`` over an 80M-row Categorical allocates a + Python-object array that can spike 5-8 GiB transient RAM per call. - Returns - ------- - pd.Series - String keys, one per row. Single-column uses values directly; - multi-column joins with ``"|"``. + For multi-column keys, falls back to the ``"|"``-joined string form + which is unavoidable with pandas groupby today. 
""" if len(columns) == 1: - return df[columns[0]].astype(str) + return df[columns[0]] return df[columns].astype(str).agg("|".join, axis=1) # ------------------------------------------------------------------ @@ -249,13 +268,47 @@ def __len__(self) -> int: return math.ceil(total_batches / self.num_replicas) def __iter__(self) -> Iterator[list[int]]: - """Yield batch-sized lists of integer indices.""" - rng = np.random.default_rng(self.seed + self.epoch) + """Yield batch-sized lists of integer indices. + + Builds batches lazily so the first batch is ready in milliseconds + instead of blocking on a full-epoch materialization. Every rank + still calls ``_build_one_batch`` on every index so the RNG draws + stay identical to the list-based implementation — only the + *yield* is rank-filtered, not the sampling. DDP correctness is + therefore bit-identical to the previous implementation; the only + change is that the main thread sees batch 0 after one + ``_build_one_batch`` call instead of ``total_batches`` calls. + + ``limit_train_batches`` interacts with this: Lightning stops + pulling from the generator after its cap, so we never pay for + the unused suffix of the epoch. + + The epoch counter auto-advances at the start of each iteration + so that the next ``__iter__`` call reseeds the RNG with a fresh + ``seed + epoch`` and yields a different batch sequence. Advancing + at the start (not the end) is robust against early generator + termination from ``limit_train_batches``: Lightning stops pulling + after its cap and garbage-collects the generator, which would + skip any end-of-iter bookkeeping. + + PyTorch Lightning does not call ``set_epoch`` on custom + ``batch_sampler`` instances (``use_distributed_sampler: false`` + with a batch sampler means Lightning's auto-wrap skips us), so + we self-advance. 
``set_epoch`` still works if a caller wants + deterministic resume from a specific epoch — call it before the + iteration and the advance will take the resumed epoch as its + starting point. + """ + seed_offset = self.epoch + self.epoch += 1 + rng = np.random.default_rng(self.seed + seed_offset) total_batches = len(self.valid_anchors) // self.batch_size - all_batches = [self._build_one_batch(rng) for _ in range(total_batches)] - # DDP: each rank takes its interleaved slice - my_batches = all_batches[self.rank :: self.num_replicas] - yield from my_batches + rank = self.rank + replicas = self.num_replicas + for i in range(total_batches): + batch = self._build_one_batch(rng) + if i % replicas == rank: + yield batch # ------------------------------------------------------------------ # Batch construction diff --git a/packages/viscy-data/src/viscy_data/schemas.py b/packages/viscy-data/src/viscy_data/schemas.py index a7f96eb5d..575230657 100644 --- a/packages/viscy-data/src/viscy_data/schemas.py +++ b/packages/viscy-data/src/viscy_data/schemas.py @@ -54,6 +54,14 @@ class FOVRecord(BaseModel): Treatment concentration in nanomolar. fluorescence_modality : str or None Fluorescence imaging modality. + microscope : str or None + Microscope identifier (e.g. ``"mantis"``, ``"dragonfly"``). + labelfree_modality : str or None + Label-free imaging modality (e.g. ``"widefield"``, ``"oblique"``). + treatment : str or None + Treatment name (e.g. ``"DMSO"``, ``"Bafilomycin"``). + hours_post_treatment : float or None + Hours post treatment at imaging start. t_shape : int or None Number of timepoints. c_shape : int or None @@ -68,6 +76,10 @@ class FOVRecord(BaseModel): Physical pixel size in the XY plane (micrometers). pixel_size_z_um : float or None Physical pixel size in Z (micrometers). + channel_markers : dict[str, str] + Maps zarr channel name to marker for this well. + Only channels with a non-None marker in Airtable are included. 
+ Empty dict means no per-well channel marker information is available. """ dataset: str @@ -88,6 +100,10 @@ class FOVRecord(BaseModel): seeding_density: int | None = None treatment_concentration_nm: float | None = None fluorescence_modality: str | None = None + microscope: str | None = None + labelfree_modality: str | None = None + treatment: str | None = None + hours_post_treatment: float | None = None t_shape: int | None = None c_shape: int | None = None z_shape: int | None = None @@ -95,3 +111,4 @@ class FOVRecord(BaseModel): x_shape: int | None = None pixel_size_xy_um: float | None = None pixel_size_z_um: float | None = None + channel_markers: dict[str, str] = {} diff --git a/packages/viscy-data/tests/test_cell_index.py b/packages/viscy-data/tests/test_cell_index.py index c6fd6aa62..0166e65bf 100644 --- a/packages/viscy-data/tests/test_cell_index.py +++ b/packages/viscy-data/tests/test_cell_index.py @@ -15,6 +15,7 @@ CELL_INDEX_CORE_COLUMNS, CELL_INDEX_GROUPING_COLUMNS, CELL_INDEX_IMAGING_COLUMNS, + CELL_INDEX_NORMALIZATION_COLUMNS, CELL_INDEX_OPS_COLUMNS, CELL_INDEX_TIMELAPSE_COLUMNS, ) @@ -22,6 +23,7 @@ CELL_INDEX_SCHEMA, _parse_bbox_min_size, _parse_bbox_to_centroid, + _reconstruct_lineage, build_timelapse_cell_index, convert_ops_parquet, read_cell_index, @@ -130,6 +132,7 @@ def test_strict_passes_with_all_columns(self): + CELL_INDEX_TIMELAPSE_COLUMNS + CELL_INDEX_OPS_COLUMNS + CELL_INDEX_IMAGING_COLUMNS + + CELL_INDEX_NORMALIZATION_COLUMNS ): df[col] = None warnings = validate_cell_index(df, strict=True) @@ -143,6 +146,7 @@ def test_all_null_column_warns(self): + CELL_INDEX_TIMELAPSE_COLUMNS + CELL_INDEX_OPS_COLUMNS + CELL_INDEX_IMAGING_COLUMNS + + CELL_INDEX_NORMALIZATION_COLUMNS ): df[col] = None warnings = validate_cell_index(df, strict=True) @@ -246,7 +250,7 @@ def test_lineage_reconstruction(self, tmp_path): dataset = open_ome_zarr(dataset_path, layout="hcs", mode="w", channel_names=["nuclei_labels"]) pos = dataset.create_position("A", "1", "0") rng = 
np.random.default_rng(42) - pos.create_image("0", rng.random((2, 1, 1, 64, 64)).astype(np.float32)) + pos.create_image("0", rng.random((4, 1, 1, 64, 64)).astype(np.float32)) # Track 0 → root, Track 1 → child of 0, Track 2 → grandchild of 1 tracks_df = pd.DataFrame( @@ -297,6 +301,104 @@ def test_cell_id_unique(self, tracks_hcs_dataset, tmp_path): assert not df["cell_id"].duplicated().any() +class TestReconstructLineage: + """Unit tests for ``_reconstruct_lineage`` — scoped directly, no zarr I/O.""" + + def test_cross_well_same_fov_does_not_collapse(self): + """ + Two wells (B/2 and C/2) that share the same FOV number ("002001") and + contain tracks with the same numeric ``track_id`` / ``parent_track_id`` + must NOT have their lineages fused. Prior to the fix, the groupby was + scoped by (experiment, fov) and the two wells were walked as if they + were one, aliasing their lineage_ids. + """ + rows = [] + # Well B/2, fov 002001: track_id 88 whose parent is 35; root is 35. + rows.append( + { + "experiment": "exp", + "well": "B/2", + "fov": "002001", + "track_id": 35, + "parent_track_id": -1, + "global_track_id": "exp_B/2/002001_35", + } + ) + rows.append( + { + "experiment": "exp", + "well": "B/2", + "fov": "002001", + "track_id": 88, + "parent_track_id": 35, + "global_track_id": "exp_B/2/002001_88", + } + ) + # Well C/2, fov 002001: independent track_id 86 whose parent is 34. + # Without the fix, the (exp, fov="002001") group sees BOTH wells' + # tracks, and the parent_track_id=34 lookup in the B/2-derived map + # fails, so track 86 becomes its own root — but track 35 from B/2 + # appears inside the same group, potentially misrouting. 
+ rows.append( + { + "experiment": "exp", + "well": "C/2", + "fov": "002001", + "track_id": 34, + "parent_track_id": -1, + "global_track_id": "exp_C/2/002001_34", + } + ) + rows.append( + { + "experiment": "exp", + "well": "C/2", + "fov": "002001", + "track_id": 86, + "parent_track_id": 34, + "global_track_id": "exp_C/2/002001_86", + } + ) + tracks = pd.DataFrame(rows) + + result = _reconstruct_lineage(tracks.copy()) + + # B/2 rows must resolve to B/2 root; C/2 rows must resolve to C/2 root. + b2_rows = result[result["well"] == "B/2"] + c2_rows = result[result["well"] == "C/2"] + assert set(b2_rows["lineage_id"].unique()) == {"exp_B/2/002001_35"} + assert set(c2_rows["lineage_id"].unique()) == {"exp_C/2/002001_34"} + + def test_no_parent_track_id_column(self): + """If `parent_track_id` is missing, lineage_id falls back to global_track_id.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 2, + "well": ["A/1"] * 2, + "fov": ["0"] * 2, + "track_id": [0, 1], + "global_track_id": ["exp_A/1/0_0", "exp_A/1/0_1"], + } + ) + result = _reconstruct_lineage(tracks.copy()) + assert (result["lineage_id"] == result["global_track_id"]).all() + + def test_single_well_chain_resolves_to_root(self): + """Basic sanity: a parent → daughter chain resolves daughters to root.""" + tracks = pd.DataFrame( + { + "experiment": ["exp"] * 3, + "well": ["A/1"] * 3, + "fov": ["0"] * 3, + "track_id": [0, 1, 2], + "parent_track_id": [-1, 0, 1], + "global_track_id": ["exp_A/1/0_0", "exp_A/1/0_1", "exp_A/1/0_2"], + } + ) + result = _reconstruct_lineage(tracks.copy()) + assert (result["lineage_id"] == "exp_A/1/0_0").all() + + # --------------------------------------------------------------------------- # OPS builder helpers (tests 11–14) # --------------------------------------------------------------------------- @@ -336,19 +438,29 @@ class TestCrossParadigm: def test_timelapse_has_null_ops_columns(self): """15. 
Time-lapse parquet has OPS columns as null.""" df = _make_timelapse_df() - for col in CELL_INDEX_OPS_COLUMNS + CELL_INDEX_BIOLOGY_COLUMNS + CELL_INDEX_IMAGING_COLUMNS: + for col in ( + CELL_INDEX_OPS_COLUMNS + + CELL_INDEX_BIOLOGY_COLUMNS + + CELL_INDEX_IMAGING_COLUMNS + + CELL_INDEX_NORMALIZATION_COLUMNS + ): df[col] = None warnings = validate_cell_index(df, strict=True) - ops_warnings = [w for w in warnings if any(c in w for c in CELL_INDEX_OPS_COLUMNS)] + ops_warnings = [w for w in warnings if any(f"'{c}'" in w for c in CELL_INDEX_OPS_COLUMNS)] assert len(ops_warnings) == len(CELL_INDEX_OPS_COLUMNS) def test_ops_has_null_timelapse_columns(self): """16. OPS parquet has time-lapse columns as null.""" df = _make_ops_df() - for col in CELL_INDEX_TIMELAPSE_COLUMNS + CELL_INDEX_BIOLOGY_COLUMNS + CELL_INDEX_IMAGING_COLUMNS: + for col in ( + CELL_INDEX_TIMELAPSE_COLUMNS + + CELL_INDEX_BIOLOGY_COLUMNS + + CELL_INDEX_IMAGING_COLUMNS + + CELL_INDEX_NORMALIZATION_COLUMNS + ): df[col] = None warnings = validate_cell_index(df, strict=True) - tl_warnings = [w for w in warnings if any(c in w for c in CELL_INDEX_TIMELAPSE_COLUMNS)] + tl_warnings = [w for w in warnings if any(f"'{c}'" in w for c in CELL_INDEX_TIMELAPSE_COLUMNS)] assert len(tl_warnings) == len(CELL_INDEX_TIMELAPSE_COLUMNS) def test_concat_schema_compatible(self, tmp_path): diff --git a/packages/viscy-data/tests/test_collection.py b/packages/viscy-data/tests/test_collection.py index 9ca19297e..4cd824ef0 100644 --- a/packages/viscy-data/tests/test_collection.py +++ b/packages/viscy-data/tests/test_collection.py @@ -1,6 +1,7 @@ """Tests for viscy_data.collection: Collection, load/save, build_collection.""" import pytest +import yaml from viscy_data.collection import ( ChannelEntry, @@ -55,16 +56,15 @@ def test_duplicate_experiment_names(self): with pytest.raises(ValueError, match="Duplicate experiment name"): _make_collection(experiments=[exp, exp]) - def test_interval_minutes_not_positive(self): - """Raise 
ValueError when interval_minutes <= 0.""" + def test_zero_interval_minutes_allowed(self): + """Zero interval_minutes is valid (non-timelapse data).""" exp = _make_experiment(name="exp1", interval_minutes=0.0) - with pytest.raises(ValueError, match="interval_minutes must be positive"): - _make_collection(experiments=[exp]) + _make_collection(experiments=[exp]) def test_negative_interval_minutes(self): """Raise ValueError when interval_minutes is negative.""" exp = _make_experiment(name="exp1", interval_minutes=-5.0) - with pytest.raises(ValueError, match="interval_minutes must be positive"): + with pytest.raises(ValueError, match="interval_minutes must be non-negative"): _make_collection(experiments=[exp]) def test_perturbation_wells_empty(self): @@ -241,3 +241,211 @@ def test_single_marker_dataset_not_split(self): grouped = _group_records(records) assert len(grouped) == 1 assert "plate1" in grouped + + +class TestChannelWells: + """Test per-well channel validity restriction via ChannelEntry.wells.""" + + def _make_viral_sensor_records(self): + """FOVRecords for a mixed plate where viral sensor is only in B/3 and C/2.""" + common = dict( + dataset="2025_01_24", + data_path="/data/2025_01_24.zarr", + tracks_path="/tracks/2025_01_24", + channel_names=["Phase3D", "raw mCherry EX561 EM600-37"], + time_interval_min=15.0, + ) + # B/1, B/2: no viral sensor (channel_markers has no entry for mCherry) + no_sensor = [ + FOVRecord(**common, well_id="B/1", cell_state="uninfected", channel_markers={"Phase3D": "Phase3D"}), + FOVRecord(**common, well_id="B/2", cell_state="uninfected", channel_markers={"Phase3D": "Phase3D"}), + ] + # B/3, C/2: viral sensor present + sensor = [ + FOVRecord( + **common, + well_id="B/3", + cell_state="uninfected", + channel_markers={"Phase3D": "Phase3D", "raw mCherry EX561 EM600-37": "pAL40"}, + ), + FOVRecord( + **common, + well_id="C/2", + cell_state="infected", + channel_markers={"Phase3D": "Phase3D", "raw mCherry EX561 EM600-37": "pAL40"}, + ), + ] 
+ return no_sensor + sensor + + def test_wells_auto_populated_for_partial_channel(self): + """build_collection restricts a channel to wells where it has a marker.""" + records = self._make_viral_sensor_records() + coll = build_collection(records, name="test") + exp = coll.experiments[0] + + phase = next(ch for ch in exp.channels if ch.name == "Phase3D") + mcherry = next(ch for ch in exp.channels if ch.name == "raw mCherry EX561 EM600-37") + + assert phase.wells == [], "Phase3D is valid in all wells — wells must be empty" + assert sorted(mcherry.wells) == ["B/3", "C/2"], "mCherry only valid in B/3, C/2" + + def test_wells_empty_when_all_wells_have_marker(self): + """When all wells share a marker, wells stays empty (no restriction needed).""" + records = [ + FOVRecord( + dataset="exp", + well_id="A/1", + data_path="/d.zarr", + tracks_path="/t", + channel_names=["Phase3D"], + cell_state="uninfected", + channel_markers={"Phase3D": "Phase3D"}, + ), + FOVRecord( + dataset="exp", + well_id="A/2", + data_path="/d.zarr", + tracks_path="/t", + channel_names=["Phase3D"], + cell_state="infected", + channel_markers={"Phase3D": "Phase3D"}, + ), + ] + coll = build_collection(records, name="test") + phase = coll.experiments[0].channels[0] + assert phase.wells == [] + + def test_wells_round_trips_yaml(self, tmp_path): + """wells field survives save_collection → load_collection.""" + records = self._make_viral_sensor_records() + coll = build_collection(records, name="test") + path = tmp_path / "col.yml" + save_collection(coll, path) + loaded = load_collection(path) + mcherry = next(ch for ch in loaded.experiments[0].channels if ch.name == "raw mCherry EX561 EM600-37") + assert sorted(mcherry.wells) == ["B/3", "C/2"] + + def test_channel_entry_wells_default_empty(self): + """ChannelEntry.wells defaults to empty list.""" + ch = ChannelEntry(name="Phase3D", marker="Phase3D") + assert ch.wells == [] + + +def _write_yaml(path, data): + with open(path, "w") as f: + yaml.safe_dump(data, f, 
default_flow_style=False, sort_keys=False) + + +def _minimal_experiment(name, data_path, tracks_path): + return { + "name": name, + "data_path": data_path, + "tracks_path": tracks_path, + "channels": [{"name": "Phase3D", "marker": "Phase3D"}], + "perturbation_wells": {"mock": ["A/1"]}, + } + + +class TestDatasetsRoot: + """Test ${datasets_root} substitution in load/save round-trip.""" + + def test_resolve_datasets_root(self, tmp_path): + """Paths with ${datasets_root} are fully resolved after load.""" + data = { + "name": "test", + "datasets_root": "/hpc/projects/organelle_phenotyping", + "experiments": [ + _minimal_experiment( + "exp1", + "${datasets_root}/datasets/exp1/exp1.zarr", + "${datasets_root}/datasets/exp1/tracking.zarr", + ) + ], + } + _write_yaml(tmp_path / "col.yml", data) + coll = load_collection(tmp_path / "col.yml") + assert coll.experiments[0].data_path == "/hpc/projects/organelle_phenotyping/datasets/exp1/exp1.zarr" + assert coll.experiments[0].tracks_path == "/hpc/projects/organelle_phenotyping/datasets/exp1/tracking.zarr" + assert coll.datasets_root == "/hpc/projects/organelle_phenotyping" + + def test_round_trip_preserves_templates(self, tmp_path): + """save_collection writes ${datasets_root} back; reload resolves again.""" + data = { + "name": "test", + "datasets_root": "/hpc/projects/organelle_phenotyping", + "experiments": [ + _minimal_experiment( + "exp1", + "${datasets_root}/datasets/exp1/exp1.zarr", + "${datasets_root}/datasets/exp1/tracking.zarr", + ) + ], + } + yaml_path = tmp_path / "col.yml" + _write_yaml(yaml_path, data) + coll = load_collection(yaml_path) + out_path = tmp_path / "col_out.yml" + save_collection(coll, out_path) + + with open(out_path) as f: + on_disk = yaml.safe_load(f) + + assert "${datasets_root}" in on_disk["experiments"][0]["data_path"] + assert "${datasets_root}" in on_disk["experiments"][0]["tracks_path"] + + reloaded = load_collection(out_path) + assert reloaded.experiments[0].data_path == 
"/hpc/projects/organelle_phenotyping/datasets/exp1/exp1.zarr" + + def test_mixed_paths_non_root_stays_absolute(self, tmp_path): + """Paths not under datasets_root survive save unchanged.""" + data = { + "name": "test", + "datasets_root": "/hpc/projects/organelle_phenotyping", + "experiments": [ + _minimal_experiment( + "exp_vast", + "${datasets_root}/datasets/exp1/exp1.zarr", + "${datasets_root}/datasets/exp1/tracking.zarr", + ), + _minimal_experiment( + "exp_nfs", + "${datasets_root}/datasets/exp2/exp2.zarr", + "/hpc/projects/intracellular_dashboard/viral-sensor/tracking.zarr", + ), + ], + } + yaml_path = tmp_path / "col.yml" + _write_yaml(yaml_path, data) + coll = load_collection(yaml_path) + assert coll.experiments[1].tracks_path == "/hpc/projects/intracellular_dashboard/viral-sensor/tracking.zarr" + + out_path = tmp_path / "col_out.yml" + save_collection(coll, out_path) + with open(out_path) as f: + on_disk = yaml.safe_load(f) + nfs_path = "/hpc/projects/intracellular_dashboard/viral-sensor/tracking.zarr" + assert on_disk["experiments"][1]["tracks_path"] == nfs_path + + def test_no_datasets_root_passthrough(self, tmp_path): + """Collections without datasets_root load and save unchanged.""" + data = { + "name": "test", + "experiments": [ + _minimal_experiment( + "exp1", + "/absolute/data/exp1.zarr", + "/absolute/tracks/exp1", + ) + ], + } + yaml_path = tmp_path / "col.yml" + _write_yaml(yaml_path, data) + coll = load_collection(yaml_path) + assert coll.datasets_root is None + assert coll.experiments[0].data_path == "/absolute/data/exp1.zarr" + + out_path = tmp_path / "col_out.yml" + save_collection(coll, out_path) + with open(out_path) as f: + on_disk = yaml.safe_load(f) + assert on_disk["experiments"][0]["data_path"] == "/absolute/data/exp1.zarr" diff --git a/packages/viscy-data/tests/test_sampler.py b/packages/viscy-data/tests/test_sampler.py index 20571033d..202cb2937 100644 --- a/packages/viscy-data/tests/test_sampler.py +++ 
b/packages/viscy-data/tests/test_sampler.py @@ -181,6 +181,107 @@ def test_batch_group_by_none_allows_mixing(self, two_experiment_anchors: pd.Data assert any_mixed, "With batch_group_by=None, at least one batch should mix experiments" +# --------------------------------------------------------------------------- +# Marker-aware batching (bag-of-channels regime) +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def multi_marker_anchors() -> pd.DataFrame: + """DataFrame with 1 experiment, 4 markers, 2 conditions, 320 rows total. + + Represents the bag-of-channels regime where each row is one (cell, + timepoint, channel) observation and ``marker`` identifies which + channel/protein the patch came from. + """ + rng = np.random.default_rng(7) + rows = [] + for marker in ["Phase3D", "TOMM20", "SEC61B", "Brightfield"]: + for cond in ["infected", "uninfected"]: + for i in range(40): + rows.append( + { + "experiment": "exp_boc", + "condition": cond, + "marker": marker, + "hours_post_perturbation": rng.uniform(0, 24), + "global_track_id": f"{marker}_{cond}_{i}", + "t": rng.integers(0, 20), + } + ) + df = pd.DataFrame(rows) + return df.reset_index(drop=True) + + +class TestMarkerAware: + """batch_group_by="marker" produces single-marker batches shuffled across markers. + + This is the bag-of-channels training regime — the config asks for one + marker per batch so contrastive pairs stay within the same channel, + while different batches traverse the full marker pool across an + epoch. 
+ """ + + def test_every_batch_is_single_marker(self, multi_marker_anchors: pd.DataFrame): + """Every batch must contain rows from exactly one marker.""" + sampler = FlexibleBatchSampler( + valid_anchors=multi_marker_anchors, + batch_size=16, + batch_group_by="marker", + stratify_by=None, + leaky=0.0, + seed=42, + ) + batches = list(sampler) + assert batches, "Sampler should yield batches" + for batch in batches: + markers = multi_marker_anchors.iloc[batch]["marker"].unique() + assert len(markers) == 1, f"batch_group_by='marker' batch has {len(markers)} markers: {markers}" + + def test_all_markers_appear_across_epoch(self, multi_marker_anchors: pd.DataFrame): + """Across one epoch every marker surfaces in at least one batch.""" + sampler = FlexibleBatchSampler( + valid_anchors=multi_marker_anchors, + batch_size=16, + batch_group_by="marker", + stratify_by=None, + leaky=0.0, + seed=42, + ) + seen: set[str] = set() + for batch in sampler: + seen.update(multi_marker_anchors.iloc[batch]["marker"].unique()) + expected = {"Phase3D", "TOMM20", "SEC61B", "Brightfield"} + assert seen == expected, f"Not all markers surfaced in one epoch: {seen} vs {expected}" + + def test_batches_shuffled_across_markers(self, multi_marker_anchors: pd.DataFrame): + """Consecutive batches should not all be the same marker — the sampler + must interleave marker groups rather than drain them sequentially. + + We require at least half of the marker-to-marker batch transitions + to be a change (pathological samplers that yield all Phase3D + batches first, then all TOMM20, etc., would get a change-ratio + close to ``1/num_batches`` which this threshold catches). 
+ """ + sampler = FlexibleBatchSampler( + valid_anchors=multi_marker_anchors, + batch_size=16, + batch_group_by="marker", + stratify_by=None, + leaky=0.0, + seed=42, + ) + per_batch_marker: list[str] = [] + for batch in sampler: + per_batch_marker.append(multi_marker_anchors.iloc[batch]["marker"].iloc[0]) + transitions = [a != b for a, b in zip(per_batch_marker[:-1], per_batch_marker[1:], strict=False)] + change_ratio = sum(transitions) / len(transitions) + assert change_ratio >= 0.5, ( + f"Only {change_ratio:.1%} of consecutive batches changed marker; " + "sampler appears to drain groups sequentially instead of shuffling" + ) + + # --------------------------------------------------------------------------- # Stratified sampling (SAMP-02) # --------------------------------------------------------------------------- @@ -391,6 +492,22 @@ def test_set_epoch_same_epoch_same_result(self, two_experiment_anchors: pd.DataF batches_b = list(sampler) assert batches_a == batches_b + def test_iter_auto_advances_epoch(self, two_experiment_anchors: pd.DataFrame): + """Consecutive iterations must yield different sequences without set_epoch. + + PL does not call ``set_epoch`` on ``batch_sampler`` instances, so the + sampler must self-advance. Regression guard for the frozen-dataset bug. 
+ """ + sampler = FlexibleBatchSampler( + valid_anchors=two_experiment_anchors, + batch_size=8, + batch_group_by="experiment", + stratify_by=None, + leaky=0.0, + seed=42, + ) + assert list(sampler) != list(sampler) + # --------------------------------------------------------------------------- # __len__ and __iter__ protocol diff --git a/packages/viscy-models/src/viscy_models/components/heads.py b/packages/viscy-models/src/viscy_models/components/heads.py index 9f22fdcfa..0fe68798c 100644 --- a/packages/viscy-models/src/viscy_models/components/heads.py +++ b/packages/viscy-models/src/viscy_models/components/heads.py @@ -23,6 +23,7 @@ "BaseHead", "ClassificationHead", "CosineClassifier", + "CrossModalContrastiveHead", "MLP", "PixelToVoxelHead", "PixelToVoxelShuffleHead", @@ -270,6 +271,152 @@ def log_metrics(self, out: dict, log_fn: callable, stage: str) -> None: log_fn(f"metrics/acc_top{self.top_k}/{self.head_name}/{stage}", topk) +class CrossModalContrastiveHead(BaseHead): + """Cross-modal InfoNCE head pulling image features toward a paired vector target. + + Projects image features and a per-cell paired vector (e.g. transcriptomic + embedding) into a shared space, then minimises an InfoNCE loss across the + batch. Samples whose target contains NaN (e.g. unpaired cells) are masked + out so the head can run on partially-paired batches. + + Parameters + ---------- + head_name : str + Name used for logging. + batch_key : str + Key used to look up the target in the batch dict (e.g. ``"X_pls"``). + in_dims : int + Backbone feature dimensionality. + target_dims : int + Dimensionality of the paired target vector. + proj_dims : int + Dimensionality of the shared projection space. Default 128. + image_hidden : int | list[int] + Hidden width(s) of the image-side projector. Default 256. + target_hidden : int | list[int] + Hidden width(s) of the target-side projector. Default 128. + temperature : float + Softmax temperature for the InfoNCE loss. Default 0.1. 
+ loss_weight : float + Final loss weight (see :class:`BaseHead`). + weight_schedule : {"cosine", "constant"} + weight_start : float + weight_warmup_epochs : int + """ + + def __init__( + self, + head_name: str, + batch_key: str, + in_dims: int, + target_dims: int, + proj_dims: int = 128, + image_hidden: int | list[int] = 256, + target_hidden: int | list[int] = 128, + temperature: float = 0.1, + loss_weight: float = 1.0, + weight_schedule: Literal["cosine", "constant"] = "constant", + weight_start: float = 0.0, + weight_warmup_epochs: int = 50, + ) -> None: + super().__init__( + head_name=head_name, + batch_key=batch_key, + loss_weight=loss_weight, + weight_schedule=weight_schedule, + weight_start=weight_start, + weight_warmup_epochs=weight_warmup_epochs, + ) + self.image_proj = MLP(in_dims=in_dims, hidden_dims=image_hidden, out_dims=proj_dims, norm="ln") + self.target_proj = MLP(in_dims=target_dims, hidden_dims=target_hidden, out_dims=proj_dims, norm="ln") + self.temperature = temperature + + def forward(self, x: Tensor) -> Tensor: + """Project image features into the shared cross-modal space. + + Parameters + ---------- + x : Tensor + Backbone features, shape ``(B, in_dims)``. + + Returns + ------- + Tensor + L2-normalised image projections, shape ``(B, proj_dims)``. + """ + return F.normalize(self.image_proj(x), dim=-1) + + def project_target(self, target: Tensor) -> Tensor: + """Project the paired vector target into the shared space. + + Parameters + ---------- + target : Tensor + Paired target vectors, shape ``(B, target_dims)``. + + Returns + ------- + Tensor + L2-normalised target projections, shape ``(B, proj_dims)``. + """ + return F.normalize(self.target_proj(target), dim=-1) + + def compute_loss(self, y_hat: Tensor, y: Tensor) -> Tensor: + """Symmetric InfoNCE loss between projected image and target features. + + Rows of ``y`` containing NaN are masked out (unpaired cells). If fewer + than two paired cells are present in the batch the loss is zero. 
+ + Parameters + ---------- + y_hat : Tensor + Already-projected image features from :meth:`forward`, shape + ``(B, proj_dims)``. + y : Tensor + Paired target vectors, shape ``(B, target_dims)``. May contain NaN. + + Returns + ------- + Tensor + Scalar InfoNCE loss. + """ + valid = ~torch.isnan(y).any(dim=-1) + if valid.sum() < 2: + return y_hat.new_zeros(()) + z_img = y_hat[valid] + z_tgt = self.project_target(y[valid]) + logits = z_img @ z_tgt.t() / self.temperature + labels = torch.arange(z_img.size(0), device=z_img.device) + return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) + + def log_metrics(self, out: dict, log_fn: callable, stage: str) -> None: + """Log loss, mean image-target cosine, and retrieval@1 for paired cells. + + Parameters + ---------- + out : dict + Must contain ``"loss"``, ``"logits"`` (projected image features), and + ``"y"`` (raw paired target). + log_fn : callable + Lightning's ``self.log``. + stage : str + ``"train"`` or ``"val"``. + """ + y = out["y"] + valid = ~torch.isnan(y).any(dim=-1) + log_fn(f"loss/aux/{self.head_name}/{stage}", out["loss"]) + log_fn(f"metrics/paired_frac/{self.head_name}/{stage}", valid.float().mean()) + if valid.sum() < 2: + return + z_img = out["logits"][valid] + z_tgt = self.project_target(y[valid]) + cos_diag = (z_img * z_tgt).sum(dim=-1).mean() + logits = z_img @ z_tgt.t() / self.temperature + retrieval = (logits.argmax(dim=1) == torch.arange(z_img.size(0), device=z_img.device)).float().mean() + log_fn(f"metrics/cos/{self.head_name}/{stage}", cos_diag) + log_fn(f"metrics/r@1/{self.head_name}/{stage}", retrieval) + + class CosineClassifier(nn.Module): """L2-normalised linear head with learnable temperature. 
diff --git a/packages/viscy-models/src/viscy_models/foundation/__init__.py b/packages/viscy-models/src/viscy_models/foundation/__init__.py index 2bed46613..12f1da3ee 100644 --- a/packages/viscy-models/src/viscy_models/foundation/__init__.py +++ b/packages/viscy-models/src/viscy_models/foundation/__init__.py @@ -1,6 +1,7 @@ """Pretrained foundation model wrappers.""" +from viscy_models.foundation.cell_dino import CellDinoModel from viscy_models.foundation.dinov3 import DINOv3Model from viscy_models.foundation.openphenom import OpenPhenomModel -__all__ = ["DINOv3Model", "OpenPhenomModel"] +__all__ = ["CellDinoModel", "DINOv3Model", "OpenPhenomModel"] diff --git a/packages/viscy-models/src/viscy_models/foundation/_dinov2_vit.py b/packages/viscy-models/src/viscy_models/foundation/_dinov2_vit.py new file mode 100644 index 000000000..eddd3d016 --- /dev/null +++ b/packages/viscy-models/src/viscy_models/foundation/_dinov2_vit.py @@ -0,0 +1,344 @@ +"""Vendored DINOv2 vision transformer (eval-only subset) for CELL-DINO. + +Vendored from https://github.com/facebookresearch/dinov2 (Apache-2.0). +Stripped to what is needed to load CELL-DINO checkpoints and run frozen +inference: ``DinoVisionTransformer`` + ``Block`` + standard ``Attention`` +(no xformers, no NestedTensorBlock, no stochastic depth machinery, no +CausalAttentionBlock). + +Module layout matches the original so the published CELL-DINO state_dicts +(``cls_token``, ``pos_embed``, ``patch_embed.proj.*``, ``blocks.X.Y.*``, +``norm.*``) load cleanly. 
+""" + +from __future__ import annotations + +import math +from functools import partial +from typing import Callable, Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + + +def _make_2tuple(x: int | Tuple[int, int]) -> Tuple[int, int]: + if isinstance(x, tuple): + if len(x) != 2: + raise ValueError(f"expected 2-tuple, got {x!r}") + return x + return (x, x) + + +class PatchEmbed(nn.Module): + """2D image to patch embedding: (B, C, H, W) -> (B, N, D).""" + + def __init__( + self, + img_size: int | Tuple[int, int] = 224, + patch_size: int | Tuple[int, int] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable[..., nn.Module]] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + image_hw = _make_2tuple(img_size) + patch_hw = _make_2tuple(patch_size) + self.img_size = image_hw + self.patch_size = patch_hw + self.patches_resolution = (image_hw[0] // patch_hw[0], image_hw[1] // patch_hw[1]) + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + self.flatten_embedding = flatten_embedding + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, h, w = x.shape + ph, pw = self.patch_size + if h % ph != 0 or w % pw != 0: + raise ValueError(f"input {h}x{w} not divisible by patch {ph}x{pw}") + x = self.proj(x) + h_out, w_out = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, h_out, w_out, self.embed_dim) + return x + + +class LayerScale(nn.Module): + """Per-channel learnable gain applied after attention/MLP.""" + + def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(torch.full((dim,), 
float(init_values))) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Attention(nn.Module): + """Standard multi-head self-attention using PyTorch SDPA (no xformers).""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + b, n, c = x.shape + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads) + q, k, v = torch.unbind(qkv, 2) + q, k, v = (t.transpose(1, 2) for t in (q, k, v)) + x = nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop if self.training else 0.0) + x = x.transpose(1, 2).contiguous().view(b, n, c) + return self.proj_drop(self.proj(x)) + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.drop(self.act(self.fc1(x))) + return self.drop(self.fc2(x)) + + +class Block(nn.Module): + """Pre-norm transformer block with LayerScale (eval-only, no drop_path).""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + 
class BlockChunk(nn.ModuleList):
    """Sequential block group used when ``block_chunks > 0`` (matches state_dict layout).

    Each chunk is padded at the front with ``nn.Identity`` placeholders so
    that global block indices in parameter names line up with published
    checkpoints; the placeholders are harmless pass-throughs at runtime.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Apply every contained module in insertion order.
        for module in self:
            x = module(x)
        return x
+ """ + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + ffn_bias: bool = True, + proj_bias: bool = True, + init_values: Optional[float] = None, + block_chunks: int = 1, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + channel_adaptive: bool = False, + ) -> None: + super().__init__() + if num_register_tokens < 0: + raise ValueError("num_register_tokens must be >= 0") + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.num_features = self.embed_dim = embed_dim + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + self.bag_of_channels = channel_adaptive + + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + blocks_list = [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + norm_layer=norm_layer, + ) + for _ in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + chunked.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked]) + else: + self.chunked_blocks = False + self.blocks = 
nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + def interpolate_pos_encoding(self, x: Tensor, w: int, h: int) -> Tensor: + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + n = self.pos_embed.shape[1] - 1 + if npatch == n and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos = pos_embed[:, 0] + patch_pos = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + m = int(math.sqrt(n)) + if n != m * m: + raise AssertionError("non-square positional grid not supported") + if self.interpolate_offset: + sx = float(w0 + self.interpolate_offset) / m + sy = float(h0 + self.interpolate_offset) / m + kwargs = {"scale_factor": (sx, sy)} + else: + kwargs = {"size": (w0, h0)} + patch_pos = nn.functional.interpolate( + patch_pos.reshape(1, m, m, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + patch_pos = patch_pos.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos.unsqueeze(0), patch_pos), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x: Tensor, masks: Optional[Tensor] = None) -> Tensor: + _, _, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + if self.register_tokens is not None: + x = torch.cat( + (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), + dim=1, + ) + return x + + def forward_features(self, x: Tensor, masks: Optional[Tensor] = None) -> dict: + x = self.prepare_tokens_with_masks(x, masks) + for blk in self.blocks: + x = blk(x) + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + 
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def forward(self, x: Tensor, is_training: bool = False) -> Tensor | dict: + ret = self.forward_features(x) + return ret if is_training else self.head(ret["x_norm_clstoken"]) + + +def vit_large( + patch_size: int = 16, + in_chans: int = 3, + channel_adaptive: bool = False, + **kwargs, +) -> DinoVisionTransformer: + """ViT-L factory matching ``dinov2.models.vision_transformer.vit_large``.""" + return DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + in_chans=in_chans, + channel_adaptive=channel_adaptive, + **kwargs, + ) diff --git a/packages/viscy-models/src/viscy_models/foundation/cell_dino.py b/packages/viscy-models/src/viscy_models/foundation/cell_dino.py new file mode 100644 index 000000000..aee1240c0 --- /dev/null +++ b/packages/viscy-models/src/viscy_models/foundation/cell_dino.py @@ -0,0 +1,154 @@ +"""CELL-DINO foundation model wrapper for frozen feature extraction. + +CELL-DINO is a DINOv2-architecture ViT pretrained on fluorescence microscopy +(Human Protein Atlas). The ``channel_adaptive_dino_vitl16`` checkpoint +processes one channel at a time through a single-channel ViT-L/16 stem; the +wrapper reshapes ``(B, C, H, W) -> (B*C, 1, H, W)``, runs the backbone, and +mean-pools the cls token across channels to produce a fixed-dimension +embedding regardless of the input channel count. + +Weights are loaded from a local ``.pth`` state_dict; nothing is fetched +from the network. See +``/hpc/projects/organelle_phenotyping/models/CELL-DINO/model_weights/weights/`` +for the published checkpoints. 
+""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from viscy_models.foundation._dinov2_vit import vit_large + + +class CellDinoModel(nn.Module): + """Wrap CELL-DINO (channel-adaptive ViT-L/16) for microscopy embeddings. + + The model accepts raw dataloader tensors ``(B, C, D, H, W)`` directly in + :meth:`forward` — preprocessing is applied inline. Z-slice selection is + **not** handled here — configure ``z_range`` on the dataloader so it + delivers the correct focal plane. + + Parameters + ---------- + weights_path : str + Path to the local ``.pth`` state_dict. The default + ``channel_adaptive_dino_vitl16_pretrain_cells-ef7c17ff.pth`` is a + single-channel ViT-L/16 trained at 224 px with ``init_values=1.0`` + and ``block_chunks=4``. + img_size : int + Spatial size after :meth:`preprocess_2d`, by default ``224``. + patch_size : int + Patch size for the ViT, by default ``16``. + freeze : bool + If ``True`` (default), all backbone parameters are frozen and the + model is kept in eval mode. + projection : nn.Module or None + Optional trainable projection head applied to backbone features. + When provided, :meth:`forward` returns ``(features, projection(features))``. + When ``None`` (default), returns ``(features, features)``. 
+ """ + + def __init__( + self, + weights_path: str, + img_size: int = 224, + patch_size: int = 16, + freeze: bool = True, + projection: nn.Module | None = None, + ) -> None: + super().__init__() + self.projection = projection + self.target_size = (img_size, img_size) + + self.model = vit_large( + patch_size=patch_size, + in_chans=1, + channel_adaptive=True, + img_size=img_size, + init_values=1.0, + block_chunks=4, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ) + + state_dict = torch.load(weights_path, map_location="cpu", weights_only=True) + if "state_dict" in state_dict and isinstance(state_dict["state_dict"], dict): + state_dict = state_dict["state_dict"] + missing, unexpected = self.model.load_state_dict(state_dict, strict=True) + if missing or unexpected: + raise RuntimeError(f"CELL-DINO state_dict mismatch — missing={missing}, unexpected={unexpected}") + + self.freeze = freeze + if freeze: + self.model.requires_grad_(False) + self.model.eval() + + def train(self, mode: bool = True) -> "CellDinoModel": + """Override train to keep backbone in eval when frozen.""" + super().train(mode) + if self.freeze: + self.model.eval() + return self + + def preprocess_2d(self, x: Tensor) -> Tensor: + """Convert a raw dataloader tensor to CELL-DINO input. + + Squeezes singleton Z (or takes the middle slice if Z>1), resizes to + ``self.target_size``, and per-image min/max scales each + ``(B*C)`` map to ``[0, 1]``. No ImageNet mean/std is applied — + CELL-DINO uses simple ``[0,1]`` normalization in training. + + Parameters + ---------- + x : Tensor + ``(B, C, D, H, W)`` or ``(B, C, H, W)``. + + Returns + ------- + Tensor + ``(B, C, H_target, W_target)`` ready for :meth:`forward`. The + wrapper reshapes ``(B, C, ...) -> (B*C, 1, ...)`` inside + :meth:`forward`, so this method preserves the channel axis. 
+ """ + if x.ndim == 5: + if x.shape[2] == 1: + x = x[:, :, 0] + else: + x = x[:, :, x.shape[2] // 2] + + x = F.interpolate(x, size=self.target_size, mode="bilinear", align_corners=False) + + b, c, h, w = x.shape + x = x.view(b * c, 1, h, w) + x_min = x.amin(dim=(2, 3), keepdim=True) + x_max = x.amax(dim=(2, 3), keepdim=True) + x = (x - x_min) / (x_max - x_min).clamp(min=1e-8) + return x.view(b, c, h, w) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Run CELL-DINO on an image batch and mean-pool over channels. + + Preprocessing is applied inline, so raw dataloader tensors + ``(B, C, D, H, W)`` or ``(B, C, H, W)`` can be passed directly. + + Returns + ------- + tuple[Tensor, Tensor] + ``(features, projections)`` where features are the + channel-mean-pooled cls token of shape ``(B, 1024)``. If + ``projection`` was provided at init, projections are + ``self.projection(features)``; otherwise both elements are the + same features tensor. + """ + x = self.preprocess_2d(x) + b, c, h, w = x.shape + x = x.reshape(b * c, 1, h, w) + cls = self.model(x) + cls = cls.view(b, c, -1).mean(dim=1) + if self.projection is not None: + return (cls, self.projection(cls)) + return (cls, cls) diff --git a/packages/viscy-transforms/src/viscy_transforms/__init__.py b/packages/viscy-transforms/src/viscy_transforms/__init__.py index 560110d4c..af8a24a1e 100644 --- a/packages/viscy-transforms/src/viscy_transforms/__init__.py +++ b/packages/viscy-transforms/src/viscy_transforms/__init__.py @@ -70,12 +70,18 @@ from viscy_transforms._sharpen import BatchedRandSharpend from viscy_transforms._stack_channels import BatchedStackChannelsd, StackChannelsd from viscy_transforms._tiled_crop import TiledSpatialCropSamplesd +from viscy_transforms._z_reduction import ( + BatchedChannelWiseZReduction, + BatchedChannelWiseZReductiond, +) from viscy_transforms._zoom import BatchedZoom, BatchedZoomd from viscy_transforms._zstack_shift import BatchedRandZStackShiftd __version__ = 
version("viscy-transforms") __all__ = [ + "BatchedChannelWiseZReduction", + "BatchedChannelWiseZReductiond", "BatchedCenterSpatialCrop", "BatchedCenterSpatialCropd", "BatchedDivisibleCropd", diff --git a/packages/viscy-transforms/src/viscy_transforms/_z_reduction.py b/packages/viscy-transforms/src/viscy_transforms/_z_reduction.py new file mode 100644 index 000000000..398b08d05 --- /dev/null +++ b/packages/viscy-transforms/src/viscy_transforms/_z_reduction.py @@ -0,0 +1,117 @@ +"""Channel-wise Z-reduction transforms for 2D training from 3D z-stacks.""" + +from __future__ import annotations + +from collections.abc import Hashable + +import torch +from monai.transforms import MapTransform +from torch import Tensor + +__all__ = ["BatchedChannelWiseZReduction", "BatchedChannelWiseZReductiond"] + + +class BatchedChannelWiseZReduction: + """Reduce the Z dimension of a ``(B, C, Z, Y, X)`` tensor. + + Label-free samples get the center z-slice; fluorescence samples get a + max-intensity projection (MIP). A per-sample boolean mask selects the + strategy when the batch mixes both types. + + Parameters + ---------- + default_strategy : str + Strategy when no mask is provided: ``"mip"`` or ``"center"``. + """ + + def __init__(self, default_strategy: str = "mip") -> None: + if default_strategy not in ("mip", "center"): + raise ValueError(f"default_strategy must be 'mip' or 'center', got '{default_strategy}'") + self.default_strategy = default_strategy + + def __call__(self, img: Tensor, is_labelfree: Tensor | None = None) -> Tensor: + """Apply z-reduction. + + Parameters + ---------- + img : Tensor + Shape ``(B, C, Z, Y, X)``. + is_labelfree : Tensor or None + Boolean tensor of shape ``(B,)``. ``True`` → center-slice, + ``False`` → MIP. When ``None``, ``default_strategy`` is used + uniformly. + + Returns + ------- + Tensor + Shape ``(B, C, 1, Y, X)``. 
+ """ + z = img.shape[2] + if z == 1: + return img + + if is_labelfree is None: + if self.default_strategy == "center": + return img[:, :, z // 2 : z // 2 + 1] + return img.amax(dim=2, keepdim=True) + + center = img[:, :, z // 2 : z // 2 + 1] + mip = img.amax(dim=2, keepdim=True) + mask = is_labelfree.view(-1, 1, 1, 1, 1) + return torch.where(mask, center, mip) + + +class BatchedChannelWiseZReductiond(MapTransform): + """Dict transform that applies channel-wise Z-reduction. + + In **bag-of-channels mode** each sample may represent a different channel. + The transform reads a ``_is_labelfree`` boolean tensor from the data dict + (injected by the datamodule) to decide per-sample strategy. + + In **all-channels mode** the dict keys identify channel type. Pass + ``labelfree_keys`` to specify which keys should use center-slice; all + others get MIP. + + Parameters + ---------- + keys : KeysCollection + Keys of the image tensors to transform. + labelfree_keys : list[str] or None + Channel keys that should use center-slice (all-channels mode). + When set, ``_is_labelfree`` in the data dict is ignored. + default_strategy : str + Fallback strategy when neither ``labelfree_keys`` nor + ``_is_labelfree`` can determine the channel type. + allow_missing_keys : bool + If ``True``, skip keys not present in the data dict. + """ + + def __init__( + self, + keys, + labelfree_keys: list[str] | None = None, + default_strategy: str = "mip", + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.labelfree_keys = set(labelfree_keys) if labelfree_keys is not None else None + self.reducer = BatchedChannelWiseZReduction(default_strategy=default_strategy) + + def __call__(self, data: dict[Hashable, Tensor]) -> dict[Hashable, Tensor]: + is_labelfree = data.pop("_is_labelfree", None) + + for key in self.key_iterator(data): + if self.labelfree_keys is not None: + # All-channels mode: strategy determined by key name. 
+ img = data[key] + z = img.shape[2] + if z == 1: + continue + if key in self.labelfree_keys: + data[key] = img[:, :, z // 2 : z // 2 + 1] + else: + data[key] = img.amax(dim=2, keepdim=True) + else: + data[key] = self.reducer(data[key], is_labelfree=is_labelfree) + + return data diff --git a/packages/viscy-transforms/tests/test_z_reduction.py b/packages/viscy-transforms/tests/test_z_reduction.py new file mode 100644 index 000000000..dc47cd9c5 --- /dev/null +++ b/packages/viscy-transforms/tests/test_z_reduction.py @@ -0,0 +1,126 @@ +import pytest +import torch + +from viscy_transforms import BatchedChannelWiseZReduction, BatchedChannelWiseZReductiond + + +def _make_img(B=4, C=1, Z=11, Y=8, X=8): + """Create a test image with distinct z-slices for easy verification.""" + img = torch.randn(B, C, Z, Y, X) + return img + + +class TestBatchedChannelWiseZReduction: + def test_mip_only(self): + img = _make_img() + reducer = BatchedChannelWiseZReduction(default_strategy="mip") + out = reducer(img) + assert out.shape == (4, 1, 1, 8, 8) + expected = img.amax(dim=2, keepdim=True) + torch.testing.assert_close(out, expected) + + def test_center_only(self): + img = _make_img() + reducer = BatchedChannelWiseZReduction(default_strategy="center") + out = reducer(img) + assert out.shape == (4, 1, 1, 8, 8) + expected = img[:, :, 5:6] + torch.testing.assert_close(out, expected) + + def test_mixed_mask(self): + img = _make_img() + mask = torch.tensor([True, False, True, False]) + reducer = BatchedChannelWiseZReduction() + out = reducer(img, is_labelfree=mask) + assert out.shape == (4, 1, 1, 8, 8) + center = img[:, :, 5:6] + mip = img.amax(dim=2, keepdim=True) + torch.testing.assert_close(out[0], center[0]) + torch.testing.assert_close(out[1], mip[1]) + torch.testing.assert_close(out[2], center[2]) + torch.testing.assert_close(out[3], mip[3]) + + def test_noop_z1(self): + img = _make_img(Z=1) + reducer = BatchedChannelWiseZReduction() + out = reducer(img) + assert out.shape == img.shape 
+ torch.testing.assert_close(out, img) + + def test_invalid_strategy(self): + with pytest.raises(ValueError): + BatchedChannelWiseZReduction(default_strategy="invalid") + + +class TestBatchedChannelWiseZReductiond: + def test_bag_of_channels_with_mask(self): + data = { + "channel_0": _make_img(), + "_is_labelfree": torch.tensor([True, False, False, True]), + } + transform = BatchedChannelWiseZReductiond(keys=["channel_0"]) + out = transform(data) + assert out["channel_0"].shape == (4, 1, 1, 8, 8) + assert "_is_labelfree" not in out + + def test_all_channels_with_labelfree_keys(self): + phase_img = _make_img() + fluor_img = _make_img() + expected_center = phase_img[:, :, 5:6].clone() + expected_mip = fluor_img.amax(dim=2, keepdim=True) + data = {"Phase3D": phase_img, "TOMM20": fluor_img} + transform = BatchedChannelWiseZReductiond( + keys=["Phase3D", "TOMM20"], + labelfree_keys=["Phase3D"], + ) + out = transform(data) + assert out["Phase3D"].shape == (4, 1, 1, 8, 8) + assert out["TOMM20"].shape == (4, 1, 1, 8, 8) + torch.testing.assert_close(out["Phase3D"], expected_center) + torch.testing.assert_close(out["TOMM20"], expected_mip) + + def test_pops_is_labelfree(self): + data = { + "channel_0": _make_img(), + "_is_labelfree": torch.tensor([False, False, False, False]), + } + transform = BatchedChannelWiseZReductiond(keys=["channel_0"]) + out = transform(data) + assert "_is_labelfree" not in out + + def test_missing_keys(self): + data = {"channel_0": _make_img()} + transform = BatchedChannelWiseZReductiond( + keys=["channel_0", "channel_1"], + allow_missing_keys=True, + ) + out = transform(data) + assert out["channel_0"].shape == (4, 1, 1, 8, 8) + assert "channel_1" not in out + + def test_noop_z1_dict(self): + data = {"channel_0": _make_img(Z=1)} + transform = BatchedChannelWiseZReductiond(keys=["channel_0"]) + out = transform(data) + assert out["channel_0"].shape == (4, 1, 1, 8, 8) + + def test_no_mask_uses_default(self): + img = _make_img() + expected = img[:, :, 
5:6].clone() + data = {"channel_0": img} + transform = BatchedChannelWiseZReductiond(keys=["channel_0"], default_strategy="center") + out = transform(data) + torch.testing.assert_close(out["channel_0"], expected) + + def test_labelfree_keys_noop_z1(self): + data = { + "Phase3D": _make_img(Z=1), + "TOMM20": _make_img(Z=1), + } + transform = BatchedChannelWiseZReductiond( + keys=["Phase3D", "TOMM20"], + labelfree_keys=["Phase3D"], + ) + out = transform(data) + torch.testing.assert_close(out["Phase3D"], data["Phase3D"]) + torch.testing.assert_close(out["TOMM20"], data["TOMM20"]) diff --git a/packages/viscy-utils/pyproject.toml b/packages/viscy-utils/pyproject.toml index dde0f28ba..944cd1e0a 100644 --- a/packages/viscy-utils/pyproject.toml +++ b/packages/viscy-utils/pyproject.toml @@ -35,6 +35,8 @@ dependencies = [ "lightning>=2.3", "matplotlib>=3.10", "numpy>=2.4.1", + "onnx", + "onnxscript", "pyyaml", "scikit-image", "scipy", @@ -46,8 +48,13 @@ dependencies = [ ] optional-dependencies.all = [ "viscy-utils[anndata,eval]" ] -optional-dependencies.anndata = [ "anndata", "natsort" ] +# Cap anndata to <0.12.9 — 0.12.9 hard-requires pandas<3, but the rest of +# the stack runs on pandas 3 (see ArrowStringArray downcast workaround in +# callbacks/embedding_writer.py). Lift this cap once anndata 0.13 lands +# with native pandas-3 support. 
+optional-dependencies.anndata = [ "anndata<0.12.9", "natsort" ] optional-dependencies.eval = [ + "copairs", "phate", "scikit-learn", "umap-learn", @@ -61,7 +68,7 @@ scripts.viscy = "viscy_utils.cli:main" [dependency-groups] dev = [ { include-group = "test" } ] test = [ - "anndata", + "anndata<0.12.9", "natsort", "pytest>=9.0.2", "pytest-cov>=7", diff --git a/packages/viscy-utils/src/viscy_utils/callbacks/__init__.py b/packages/viscy-utils/src/viscy_utils/callbacks/__init__.py index 9a49d6db2..6e41540e4 100644 --- a/packages/viscy-utils/src/viscy_utils/callbacks/__init__.py +++ b/packages/viscy-utils/src/viscy_utils/callbacks/__init__.py @@ -2,12 +2,10 @@ from viscy_utils.callbacks.embedding_writer import EmbeddingWriter from viscy_utils.callbacks.online_eval import OnlineEvalCallback from viscy_utils.callbacks.prediction_writer import HCSPredictionWriter -from viscy_utils.callbacks.save_config_wandb import SaveConfigToWandb __all__ = [ "EmbeddingSnapshotCallback", "EmbeddingWriter", "OnlineEvalCallback", "HCSPredictionWriter", - "SaveConfigToWandb", ] diff --git a/packages/viscy-utils/src/viscy_utils/callbacks/embedding_writer.py b/packages/viscy-utils/src/viscy_utils/callbacks/embedding_writer.py index 7784a25f4..373507b8f 100644 --- a/packages/viscy-utils/src/viscy_utils/callbacks/embedding_writer.py +++ b/packages/viscy-utils/src/viscy_utils/callbacks/embedding_writer.py @@ -156,8 +156,15 @@ def write_embedding_dataset( ultrack_indices = index_df.copy() ultrack_indices["fov_name"] = ultrack_indices["fov_name"].str.strip("/") - for col in ultrack_indices.select_dtypes(include="string").columns: - ultrack_indices[col] = ultrack_indices[col].astype(object) + # TODO: remove once anndata 0.13 supports pandas 3 Arrow-backed strings natively. + # anndata 0.12.9+ requires pandas <3, so we stay on 0.12.6 + pandas 3 and + # must manually downcast ArrowStringArray columns to object dtype before writing. 
+ for col in ultrack_indices.columns: + s = ultrack_indices[col] + if isinstance(s.dtype, pd.StringDtype): + ultrack_indices[col] = s.astype(object) + elif hasattr(s, "cat") and isinstance(s.cat.categories.dtype, pd.StringDtype): + ultrack_indices[col] = s.cat.rename_categories(s.cat.categories.astype(object)) if embedding_key == "projections": if projections is None: diff --git a/packages/viscy-utils/src/viscy_utils/callbacks/online_eval.py b/packages/viscy-utils/src/viscy_utils/callbacks/online_eval.py index bf4045454..817fa9a78 100644 --- a/packages/viscy-utils/src/viscy_utils/callbacks/online_eval.py +++ b/packages/viscy-utils/src/viscy_utils/callbacks/online_eval.py @@ -20,8 +20,9 @@ import torch from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import Callback +from lightning_utilities.core.rank_zero import rank_zero_warn from scipy.stats import spearmanr -from sklearn.model_selection import cross_val_score +from sklearn.model_selection import cross_val_score, train_test_split from sklearn.neighbors import KNeighborsClassifier from viscy_data._typing import TripletSample @@ -48,8 +49,22 @@ def effective_rank(features: np.ndarray) -> float: float Effective rank (scalar >= 1). """ + # Guard against NaN/Inf in features — np.linalg.svd raises + # "SVD did not converge" on non-finite input, which crashes the whole + # run from inside a validation callback. Drop affected rows and return + # NaN when no finite rows remain. Under DDP every rank computes on the + # all-gathered full set, so the warning would otherwise fire once per + # rank with identical content — emit only from rank 0. 
def effective_rank(features: np.ndarray) -> float:
    """Effective rank of a feature matrix via singular-value entropy.

    Rows containing NaN/Inf are dropped before the SVD — non-finite input
    makes ``np.linalg.svd`` raise "SVD did not converge", which would crash
    the whole run from inside a validation callback. Returns NaN when fewer
    than two finite rows remain or when every singular value is (near) zero.

    Parameters
    ----------
    features : np.ndarray
        2D array of shape ``(n_samples, n_features)``.

    Returns
    -------
    float
        ``exp(entropy(p))`` where ``p`` is the normalised singular-value
        spectrum, or NaN for degenerate input.
    """
    row_is_finite = np.isfinite(features).all(axis=1)
    if not row_is_finite.all():
        # Under DDP every rank computes on the all-gathered full set, so the
        # warning would fire once per rank with identical content — emit it
        # only from rank 0.
        rank_zero_warn(f"effective_rank: {(~row_is_finite).sum()}/{len(features)} rows contain NaN/Inf; skipping those")
        features = features[row_is_finite]
    if features.shape[0] < 2:
        return float("nan")
    singular_values = np.linalg.svd(features, full_matrices=False)[1]
    singular_values = singular_values[singular_values > 1e-10]
    if singular_values.size == 0:
        return float("nan")
    spectrum = singular_values / singular_values.sum()
    spectral_entropy = -(spectrum * np.log(spectrum)).sum()
    return float(np.exp(spectral_entropy))
""" def __init__( @@ -142,6 +170,8 @@ def __init__( k: int = 20, track_id_key: str = "global_track_id", timepoint_key: str = "t", + knn_eval_mode: Literal["cv", "holdout"] = "cv", + holdout_test_size: float = 0.2, ): super().__init__() self.every_n_epochs = every_n_epochs @@ -149,6 +179,8 @@ def __init__( self.k = k self.track_id_key = track_id_key self.timepoint_key = timepoint_key + self.knn_eval_mode = knn_eval_mode + self.holdout_test_size = holdout_test_size self._collecting = False self._features: list[torch.Tensor] = [] self._meta: list[dict] = [] @@ -184,17 +216,42 @@ def on_validation_batch_end( self._meta.extend(batch.get("anchor_meta", [])) def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Compute and log metrics on rank 0.""" + """Compute and log metrics on the full validation set under DDP. + + The validation DataLoader is sharded across ranks (each rank sees + N/world_size samples), so per-rank metrics are not the metrics of + the full set: ``effective_rank``, k-NN CV/holdout accuracy, and + Spearman ``rho`` are non-linear in the sample set, and averaging + per-shard values via ``sync_dist=True`` is statistically wrong. + + Fix: ``all_gather`` features and the per-sample arrays needed for + each metric across ranks, then compute on the full set on every + rank. Every rank produces the same scalar, so ``sync_dist=True`` + becomes a no-op average of identical values — but keeping it + avoids the rank-0-only ``pl_module.log`` DDP deadlock (rank 0 + registers the metric, other ranks don't, and Lightning's + epoch-end all-reduce never completes). 
+ """ if not self._collecting or not self._features: self._reset() return - if trainer.global_rank != 0: - self._reset() - return - features_np = to_numpy(torch.cat(self._features)) + features_local = torch.cat(self._features) + labels_local = self._extract_array(self.label_key, source="labels") + track_ids_local = self._extract_array(self.track_id_key, source="meta") + timepoints_local = self._extract_array(self.timepoint_key, source="meta") + + features_np, labels, track_ids, timepoints = self._gather_across_ranks( + pl_module, + features_local, + labels_local, + track_ids_local, + timepoints_local, + ) + n_samples = features_np.shape[0] epoch = trainer.current_epoch + is_rank_zero = trainer.global_rank == 0 # --- Effective rank (always computable) --- erank = effective_rank(features_np) @@ -203,38 +260,61 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) erank, on_epoch=True, logger=True, - rank_zero_only=True, + sync_dist=True, ) - _logger.info(f"[OnlineEval epoch {epoch}] effective_rank={erank:.1f} (n={n_samples}, d={features_np.shape[1]})") + if is_rank_zero: + _logger.info( + f"[OnlineEval epoch {epoch}] effective_rank={erank:.1f} (n={n_samples}, d={features_np.shape[1]})" + ) # --- k-NN accuracy (requires labels) --- - labels = self._extract_array(self.label_key, source="labels") if labels is not None and len(np.unique(labels)) >= 2: k = min(self.k, n_samples - 1) knn = KNeighborsClassifier(n_neighbors=k, metric="cosine") - cv_folds = min(5, min(np.bincount(labels))) - if cv_folds >= 2: + min_class_count = int(min(np.bincount(labels))) + mode = self.knn_eval_mode + # Auto-degrade CV -> holdout when the smallest class has < 2 + # samples (CV would skip silently). Holdout mode still requires + # >= 2 per class for stratified splitting. 
+ if mode == "cv" and min_class_count < 2: + mode = "holdout" + if mode == "cv": + cv_folds = min(5, min_class_count) scores = cross_val_score(knn, features_np, labels, cv=cv_folds) knn_acc = float(scores.mean()) + eval_desc = f"cv={cv_folds}" + elif mode == "holdout" and min_class_count >= 2: + x_train, x_test, y_train, y_test = train_test_split( + features_np, + labels, + test_size=self.holdout_test_size, + stratify=labels, + random_state=0, + ) + knn.fit(x_train, y_train) + knn_acc = float(knn.score(x_test, y_test)) + eval_desc = f"holdout={self.holdout_test_size:.2f}" + else: + knn_acc = None + if is_rank_zero: + _logger.debug( + f"[OnlineEval epoch {epoch}] Skipping k-NN: " + f"smallest class has {min_class_count} samples (need >=2)." + ) + if knn_acc is not None: pl_module.log( f"metrics/knn_acc/{self.label_key}/val", knn_acc, on_epoch=True, logger=True, - rank_zero_only=True, - ) - _logger.info( - f"[OnlineEval epoch {epoch}] knn_acc({self.label_key}, k={k})={knn_acc:.3f} (cv={cv_folds})" - ) - else: - _logger.debug( - f"[OnlineEval epoch {epoch}] Skipping k-NN: " - f"smallest class has {min(np.bincount(labels))} samples (need >=2)." 
+ sync_dist=True, ) + if is_rank_zero: + _logger.info( + f"[OnlineEval epoch {epoch}] knn_acc({self.label_key}, k={k})={knn_acc:.3f} ({eval_desc})" + ) # --- Temporal smoothness (requires track_id + timepoint) --- - track_ids = self._extract_array(self.track_id_key, source="meta") - timepoints = self._extract_array(self.timepoint_key, source="meta") if track_ids is not None and timepoints is not None: rho = temporal_smoothness(features_np, track_ids, timepoints) if not np.isnan(rho): @@ -243,9 +323,10 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) rho, on_epoch=True, logger=True, - rank_zero_only=True, + sync_dist=True, ) - _logger.info(f"[OnlineEval epoch {epoch}] temporal_smoothness={rho:.3f}") + if is_rank_zero: + _logger.info(f"[OnlineEval epoch {epoch}] temporal_smoothness={rho:.3f}") self._reset() @@ -275,3 +356,98 @@ def _extract_array(self, key: str, source: Literal["labels", "meta"] = "meta") - return None values.append(v) return np.array(values) + + def _gather_across_ranks( + self, + pl_module: LightningModule, + features_local: torch.Tensor, + labels_local: np.ndarray | None, + track_ids_local: np.ndarray | None, + timepoints_local: np.ndarray | None, + ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None, np.ndarray | None]: + """Gather per-rank validation arrays into a full-set view on every rank. + + On single-GPU runs this is a pure passthrough that converts the + feature tensor to numpy. Under DDP, every rank participates in + ``pl_module.all_gather`` so each ends up with the same + concatenated arrays — the metrics computed downstream are then + deterministic and identical on every rank, and the subsequent + ``sync_dist=True`` log call averages identical scalars (a no-op + that still pleases Lightning's epoch-end all-reduce). 
+ + Per-rank shard sizes can differ by one batch when the dataset + size is not divisible by ``world_size * batch_size``; equalize + by truncating to the minimum length before the gather to keep + ``all_gather`` happy with a fixed-shape tensor. + + Parameters + ---------- + pl_module : LightningModule + Used for ``pl_module.all_gather`` (Lightning routes to the + correct backend) and ``pl_module.device``. + features_local : torch.Tensor + ``(n_local, d)`` features collected on this rank. + labels_local, track_ids_local, timepoints_local : np.ndarray or None + Per-rank metadata arrays, ``None`` if the corresponding key + was missing. + + Returns + ------- + features_np : np.ndarray + Full-set features ``(N, d)`` (same on every rank under DDP). + labels, track_ids, timepoints : np.ndarray or None + Full-set metadata arrays. ``None`` if any rank reported + ``None`` (handled via an explicit availability all-reduce). + """ + world_size = getattr(pl_module.trainer, "world_size", 1) or 1 + if world_size <= 1: + return ( + to_numpy(features_local), + labels_local, + track_ids_local, + timepoints_local, + ) + + # Equalize shard sizes — all_gather requires identical shapes per rank. + n_local = torch.tensor([features_local.shape[0]], device=pl_module.device) + n_per_rank = pl_module.all_gather(n_local).flatten() + n_min = int(n_per_rank.min().item()) + features_local = features_local[:n_min].to(pl_module.device) + + # Reduce a per-rank availability flag for each metadata array so + # all ranks agree on whether to compute the dependent metric. + # If any rank is missing the key, treat as missing globally. 
+ def _gather_optional(arr: np.ndarray | None) -> np.ndarray | None: + available = torch.tensor( + [1 if arr is not None else 0], + device=pl_module.device, + ) + available = pl_module.all_gather(available).flatten() + if int(available.min().item()) == 0: + return None + arr_local = arr[:n_min] + # String/object dtypes cannot be converted to a torch tensor, + # so gather them via torch.distributed.all_gather_object as + # Python lists, then re-pack to a numpy array. Numeric dtypes + # take the fast tensor-based path. + if arr_local.dtype.kind in {"U", "S", "O"}: + gathered_list: list[np.ndarray] = [None] * world_size # type: ignore[list-item] + torch.distributed.all_gather_object(gathered_list, arr_local) + return np.concatenate(gathered_list, axis=0) + tensor = torch.as_tensor(arr_local, device=pl_module.device) + gathered = pl_module.all_gather(tensor) + # all_gather returns shape (world_size, n_min, *rest) for + # 1D inputs — collapse to (world_size * n_min, *rest). + gathered = gathered.reshape(-1, *tensor.shape[1:]) if tensor.ndim > 1 else gathered.reshape(-1) + return gathered.detach().cpu().numpy() + + features_gathered = pl_module.all_gather(features_local) + # all_gather on (n_min, d) returns (world_size, n_min, d) — flatten. 
+ features_gathered = features_gathered.reshape(-1, features_gathered.shape[-1]) + + return ( + features_gathered.detach().cpu().numpy(), + _gather_optional(labels_local), + _gather_optional(track_ids_local), + _gather_optional(timepoints_local), + ) diff --git a/packages/viscy-utils/src/viscy_utils/callbacks/save_config_wandb.py b/packages/viscy-utils/src/viscy_utils/callbacks/save_config_wandb.py deleted file mode 100644 index cec542678..000000000 --- a/packages/viscy-utils/src/viscy_utils/callbacks/save_config_wandb.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Save resolved Lightning config to W&B files.""" - -from __future__ import annotations - -import logging -from pathlib import Path - -from lightning.pytorch import Callback, Trainer -from lightning.pytorch.loggers import WandbLogger - -logger = logging.getLogger(__name__) - - -class SaveConfigToWandb(Callback): - """Upload the resolved config.yaml to W&B so it appears in the Files tab. - - Lightning's SaveConfigCallback writes config.yaml to ``trainer.log_dir``, - but WandbLogger does not sync arbitrary files from that directory. - This callback copies it into the W&B run's files directory on fit start. 
- """ - - def setup(self, trainer: Trainer, pl_module, stage: str) -> None: - """Copy config.yaml to W&B run files on fit start.""" - if stage != "fit": - return - wandb_logger = None - for lg in trainer.loggers: - if isinstance(lg, WandbLogger): - wandb_logger = lg - break - if wandb_logger is None: - return - config_path = Path(trainer.log_dir) / "config.yaml" - if not config_path.exists(): - logger.debug("No config.yaml found at %s, skipping W&B upload.", config_path) - return - run = wandb_logger.experiment - run.save(str(config_path), base_path=str(config_path.parent), policy="now") - logger.info("Uploaded %s to W&B run %s.", config_path, run.id) diff --git a/packages/viscy-utils/src/viscy_utils/cli.py b/packages/viscy-utils/src/viscy_utils/cli.py index 1babc02aa..f753b9734 100644 --- a/packages/viscy-utils/src/viscy_utils/cli.py +++ b/packages/viscy-utils/src/viscy_utils/cli.py @@ -13,9 +13,8 @@ import yaml from jsonargparse import Namespace, lazy_instance from lightning.pytorch import LightningDataModule, LightningModule -from lightning.pytorch.callbacks import TQDMProgressBar from lightning.pytorch.cli import LightningCLI -from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.loggers import WandbLogger from viscy_utils.compose import load_composed_config from viscy_utils.trainer import VisCyTrainer @@ -84,18 +83,12 @@ def subcommands() -> dict[str, set[str]]: return subcommands def add_arguments_to_parser(self, parser) -> None: - """Set default logger and progress bar.""" - defaults = { - "trainer.logger": lazy_instance( - TensorBoardLogger, - save_dir="", - version=datetime.now().strftime(r"%Y%m%d-%H%M%S"), - log_graph=True, - ), - } - if not sys.stdout.isatty(): - defaults["trainer.callbacks"] = [lazy_instance(TQDMProgressBar, refresh_rate=10, leave=True)] - parser.set_defaults(defaults) + """Set default logger.""" + parser.set_defaults( + { + "trainer.logger": lazy_instance(WandbLogger), + } + ) def _parse_ckpt_path(self) -> 
None: try: diff --git a/packages/viscy-utils/src/viscy_utils/cli_utils.py b/packages/viscy-utils/src/viscy_utils/cli_utils.py index 78903f48b..73183a80f 100644 --- a/packages/viscy-utils/src/viscy_utils/cli_utils.py +++ b/packages/viscy-utils/src/viscy_utils/cli_utils.py @@ -2,12 +2,10 @@ from pathlib import Path -import yaml +from viscy_utils.compose import load_composed_config -def format_markdown_table( - data: dict | list[dict], title: str = None, headers: list[str] = None -) -> str: +def format_markdown_table(data: dict | list[dict], title: str = None, headers: list[str] = None) -> str: """Format data as a markdown table. Parameters @@ -71,7 +69,15 @@ def format_markdown_table( def load_config(config_path: str | Path) -> dict: - """Load YAML configuration file. + """Load a YAML configuration file with optional recipe composition. + + A top-level ``base:`` key is interpreted as a list of relative paths + to other YAML files that are merged before this file's own keys + (later entries override earlier ones; this file overrides the bases). + YAML files without a ``base:`` key behave identically to + ``yaml.safe_load`` — there is no special handling beyond that one + key. See ``viscy_utils.compose.load_composed_config`` for the merge + rules. Parameters ---------- @@ -81,18 +87,54 @@ def load_config(config_path: str | Path) -> dict: Returns ------- dict - Configuration dictionary. + Composed configuration dictionary. Raises ------ FileNotFoundError - If the config file does not exist. - yaml.YAMLError - If the YAML file is malformed. + If the config file (or any referenced base) does not exist. """ - config_path = Path(config_path) - if not config_path.exists(): - raise FileNotFoundError(f"Config file not found: {config_path}") + return load_composed_config(Path(config_path)) + + +def load_config_section(config_path: str | Path, section: str | None, default_section: str | None = None) -> dict: + """Load a YAML config file, optionally selecting a subsection. 
+ + This enables reusing a single YAML file for multiple CLI steps by storing + per-command configuration under a top-level key (``section``), while keeping + shared keys (e.g., ``datasets``) at the root. - with open(config_path, "r") as f: - return yaml.safe_load(f) + Parameters + ---------- + config_path : str | Path + Path to YAML configuration file. + section : str | None + If provided, selects ``config[section]`` and merges in any shared root + keys that are not already present in the section. + default_section : str | None + If ``section`` is None and ``default_section`` exists in the YAML, that section is used. + + Returns + ------- + dict + Configuration dictionary (either full or merged subsection). + """ + cfg = load_config(config_path) + if section is None: + if default_section is None or default_section not in cfg: + return cfg + section = default_section + + if section not in cfg: + raise KeyError(f"Config section not found: {section}") + + section_cfg = cfg[section] or {} + if not isinstance(section_cfg, dict): + raise TypeError(f"Config section must be a mapping: {section}") + + merged = dict(section_cfg) + for k, v in cfg.items(): + if k == section: + continue + merged.setdefault(k, v) + return merged diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/annotation.py b/packages/viscy-utils/src/viscy_utils/evaluation/annotation.py index 91a5af9c3..c3a0e4566 100644 --- a/packages/viscy-utils/src/viscy_utils/evaluation/annotation.py +++ b/packages/viscy-utils/src/viscy_utils/evaluation/annotation.py @@ -129,9 +129,24 @@ def load_annotation_anndata(adata: ad.AnnData, path: str, name: str, categories: annotation = pd.read_csv(path) annotation["fov_name"] = annotation["fov_name"].str.strip("/") - annotation = annotation.set_index(["fov_name", "id"]) - - mi = pd.MultiIndex.from_arrays([adata.obs["fov_name"], adata.obs["id"]], names=["fov_name", "id"]) + # Normalize obs fov_name: strip leading/trailing slashes so both sides match. 
+ obs_fov = adata.obs["fov_name"].astype(object).str.strip("/") + + if "id" in adata.obs.columns and "id" in annotation.columns: + annotation = annotation.set_index(["fov_name", "id"]) + mi = pd.MultiIndex.from_arrays([obs_fov, adata.obs["id"]], names=["fov_name", "id"]) + elif all(c in adata.obs.columns for c in ("fov_name", "t", "track_id")) and all( + c in annotation.columns for c in ("fov_name", "t", "track_id") + ): + annotation = annotation.set_index(["fov_name", "t", "track_id"]) + mi = pd.MultiIndex.from_arrays( + [obs_fov, adata.obs["t"], adata.obs["track_id"]], + names=["fov_name", "t", "track_id"], + ) + else: + raise KeyError( + "Cannot join annotations: embeddings have neither (fov_name, id) nor (fov_name, t, track_id) columns." + ) # Use reindex to handle missing annotations gracefully # This will return NaN for observations that don't have annotations diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/dimensionality_reduction.py b/packages/viscy-utils/src/viscy_utils/evaluation/dimensionality_reduction.py index bbcf690a8..bec3c8c9c 100644 --- a/packages/viscy-utils/src/viscy_utils/evaluation/dimensionality_reduction.py +++ b/packages/viscy-utils/src/viscy_utils/evaluation/dimensionality_reduction.py @@ -2,6 +2,7 @@ import logging +import numpy as np import pandas as pd from numpy.typing import NDArray from xarray import Dataset @@ -18,6 +19,11 @@ def compute_phate( knn_dist: str = "cosine", update_dataset: bool = False, random_state: int = 42, + n_pca: int | None = 50, + subsample: int | None = None, + lineage_ids: NDArray | None = None, + fit_idx: NDArray | None = None, + n_jobs: int = 1, **phate_kwargs, ) -> tuple[object, NDArray]: """Compute PHATE embeddings. @@ -38,6 +44,24 @@ def compute_phate( Whether to update the dataset, by default False. random_state : int, optional Random state, by default 42. + n_pca : int or None, optional + Pre-reduce to this many PCs inside PHATE before graph construction. 
+ Pass ``None`` to skip PHATE's internal PCA — useful when the input is + already PCA-reduced (e.g., ``X_pca_combined``), and avoids the + scipy 1.17.1 / sklearn 1.8.0 ``scipy.linalg.lu`` deadlock that hangs + PHATE's internal pre-PCA step at near-zero CPU. Default 50. + subsample : int or None, optional + Lineage- or row-level subsample drawn here. Mutually exclusive with + ``fit_idx``: when ``fit_idx`` is provided, ``subsample`` is ignored. + lineage_ids : NDArray or None, optional + Per-row lineage identifier (e.g., ``f"{fov}|{track_id}"``). Used by + the internal ``subsample`` path to draw whole lineages instead of + random rows. Ignored when ``fit_idx`` is provided. + fit_idx : NDArray or None, optional + Precomputed row indices to fit PHATE on. When provided, PHATE fits on + ``embeddings[fit_idx]`` and transforms the full input — the caller + owns the subsampling policy (e.g., per-store cap). Disables this + function's internal subsampling. Returns ------- @@ -72,11 +96,35 @@ def compute_phate( decay=decay, knn_dist=knn_dist, random_state=random_state, - n_jobs=-1, + n_jobs=n_jobs, + n_pca=n_pca, **phate_kwargs, ) - phate_embedding = phate_model.fit_transform(embeddings_scaled) + n_samples = embeddings_scaled.shape[0] + if fit_idx is not None: + # Caller-provided fit set — owns the subsampling policy. 
+ _logger.info(f"PHATE: fitting on caller-supplied {len(fit_idx):,} / {n_samples:,} cells, projecting the rest") + phate_model.fit(embeddings_scaled[fit_idx]) + phate_embedding = phate_model.transform(embeddings_scaled) + elif subsample is not None and subsample < n_samples: + rng = np.random.default_rng(random_state) + if lineage_ids is not None: + unique_lineages = np.unique(lineage_ids) + n_lineages = min(subsample, len(unique_lineages)) + chosen_lineages = rng.choice(unique_lineages, size=n_lineages, replace=False) + idx = np.where(np.isin(lineage_ids, chosen_lineages))[0] + _logger.info( + f"PHATE: fitting on {len(idx):,} cells ({n_lineages:,} lineages) " + f"/ {n_samples:,} total, projecting the rest" + ) + else: + idx = rng.choice(n_samples, size=subsample, replace=False) + _logger.info(f"PHATE: fitting on {subsample:,} / {n_samples:,} cells, projecting the rest") + phate_model.fit(embeddings_scaled[idx]) + phate_embedding = phate_model.transform(embeddings_scaled) + else: + phate_embedding = phate_model.fit_transform(embeddings_scaled) if update_dataset and isinstance(embedding_dataset, Dataset): for i in range(min(2, phate_embedding.shape[1])): diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/embedding_map.py b/packages/viscy-utils/src/viscy_utils/evaluation/embedding_map.py new file mode 100644 index 000000000..7952738c1 --- /dev/null +++ b/packages/viscy-utils/src/viscy_utils/evaluation/embedding_map.py @@ -0,0 +1,120 @@ +"""Embedding-level mean Average Precision (mAP) via copairs.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd + + +def compute_embedding_map( + meta: pd.DataFrame, + features: np.ndarray, + reference_condition: str, + target_condition: str, + condition_col: str = "condition", + group_col: str = "marker", + distance: str = "cosine", + null_size: int = 10000, + seed: int = 0, +) -> dict | None: + """Compute mean Average Precision for embedding-space phenotypic profiling. 
+ + Uses ``copairs`` to compute per-cell Average Precision (AP) between a + reference and target condition, then aggregates to mAP per group. Positive + pairs share the same group and condition; negative pairs share only the group + but differ in condition. + + Parameters + ---------- + meta : pd.DataFrame + Cell metadata, one row per cell. Must contain ``condition_col`` and + ``group_col`` columns. + features : np.ndarray + Embedding matrix, shape (n_cells, n_features). Rows correspond to + ``meta`` rows. + reference_condition : str + Value of ``condition_col`` for the reference/control group (``cond_a``). + target_condition : str + Value of ``condition_col`` for the treatment group (``cond_b``). + condition_col : str + Column in ``meta`` that holds condition labels. Default: ``"condition"``. + group_col : str + Column in ``meta`` that holds group labels (e.g. marker/organelle). + Default: ``"marker"``. + distance : str + Distance metric for copairs (e.g. ``"cosine"``). Default: ``"cosine"``. + null_size : int + Number of null pairs for the mAP significance test. Default: 10000. + seed : int + Random seed. Default: 0. + + Returns + ------- + dict or None + ``{"mean_average_precision": float, "p_value": float, + "n_reference": int, "n_target": int}`` or ``None`` if either condition + has no cells. + """ + try: + import copairs.map + import copairs.matching + except ImportError as e: + raise ImportError("copairs is required for mAP computation. 
Install it with: pip install copairs") from e + + mask_ref = meta[condition_col] == reference_condition + mask_tgt = meta[condition_col] == target_condition + mask = mask_ref | mask_tgt + + if mask_ref.sum() == 0 or mask_tgt.sum() == 0: + return None + + sub_meta = meta[mask].reset_index(drop=True) + sub_feats = features[mask.values] + + reference_col = "reference_index" + sub_meta = sub_meta.copy() + sub_meta[reference_col] = copairs.matching.assign_reference_index( + sub_meta, reference_condition, condition_col, group_col + ) + + pos_sameby = [group_col, condition_col, reference_col] + neg_sameby = [group_col] + neg_diffby = [condition_col, reference_col] + + ap_df = copairs.map.average_precision( + sub_meta, + sub_feats, + pos_sameby=pos_sameby, + neg_sameby=neg_sameby, + neg_diffby=neg_diffby, + batch_size=20000, + distance=distance, + ) + + target_ap = ap_df[sub_meta[condition_col] == target_condition] + if len(target_ap) == 0: + return None + + map_result = copairs.map.mean_average_precision( + target_ap, + sameby=[group_col], + null_size=null_size, + threshold=0.05, + seed=seed, + ) + + if hasattr(map_result, "mean_average_precision"): + mmap = float(map_result.mean_average_precision.iloc[0]) + pval = float(map_result.p_value.iloc[0]) if "p_value" in map_result.columns else float("nan") + elif isinstance(map_result, dict): + mmap = float(map_result.get("mean_average_precision", float("nan"))) + pval = float(map_result.get("p_value", float("nan"))) + else: + return None + + return { + "mean_average_precision": mmap, + "p_value": pval, + "n_reference": int(mask_ref.sum()), + "n_target": int(mask_tgt.sum()), + } diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/linear_classifier.py b/packages/viscy-utils/src/viscy_utils/evaluation/linear_classifier.py index 9bdc0bd35..35b5a7857 100644 --- a/packages/viscy-utils/src/viscy_utils/evaluation/linear_classifier.py +++ b/packages/viscy-utils/src/viscy_utils/evaluation/linear_classifier.py @@ -203,7 +203,7 
@@ def train_linear_classifier( classifier_params: Optional[dict[str, Any]] = None, split_train_data: float = 0.8, random_seed: int = 42, -) -> tuple[LinearClassifierPipeline, dict[str, float]]: +) -> tuple[LinearClassifierPipeline, dict[str, float], dict[str, Any]]: """Train a linear classifier on embeddings with preprocessing and evaluation. Parameters @@ -231,6 +231,9 @@ def train_linear_classifier( Trained classifier pipeline with preprocessing. dict Dictionary of evaluation metrics (train and validation if split). + dict + Raw validation outputs for plotting: ``y_val``, ``y_val_proba``, + ``classes``. Values are ``None`` when no validation split was made. """ print("\n" + "=" * 60) print("TRAINING CLASSIFIER") @@ -314,8 +317,10 @@ def train_linear_classifier( train_metrics[f"train_{class_name}_precision"] = train_report[class_name]["precision"] train_metrics[f"train_{class_name}_recall"] = train_report[class_name]["recall"] train_metrics[f"train_{class_name}_f1"] = train_report[class_name]["f1-score"] + train_metrics[f"train_{class_name}_support"] = int(train_report[class_name]["support"]) val_metrics = {} + y_val_proba: Optional[np.ndarray] = None if X_val is not None and y_val is not None: y_val_pred = classifier.predict(X_val) val_report = classification_report(y_val, y_val_pred, digits=3, output_dict=True) @@ -336,6 +341,15 @@ def train_linear_classifier( else: val_metrics["val_auroc"] = roc_auc_score(y_val, y_val_proba, multi_class="ovr", average="macro") print(f" Val AUROC: {val_metrics['val_auroc']:.3f}") + + if len(classifier.classes_) > 2: + for i, class_name in enumerate(classifier.classes_): + try: + val_metrics[f"val_{class_name}_auroc"] = roc_auc_score( + (y_val == class_name).astype(int), y_val_proba[:, i] + ) + except ValueError: + pass except ValueError as e: _logger.warning(f"Could not compute val AUROC (likely only one class present): {e}") @@ -344,6 +358,7 @@ def train_linear_classifier( val_metrics[f"val_{class_name}_precision"] = 
val_report[class_name]["precision"] val_metrics[f"val_{class_name}_recall"] = val_report[class_name]["recall"] val_metrics[f"val_{class_name}_f1"] = val_report[class_name]["f1-score"] + val_metrics[f"val_{class_name}_support"] = int(val_report[class_name]["support"]) all_metrics = {**train_metrics, **val_metrics} @@ -365,7 +380,13 @@ def train_linear_classifier( task=task, ) - return pipeline, all_metrics + val_outputs: dict[str, Any] = { + "y_val": y_val, + "y_val_proba": y_val_proba, + "classes": classifier.classes_.tolist(), + } + + return pipeline, all_metrics, val_outputs def predict_with_classifier( diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/mmd.py b/packages/viscy-utils/src/viscy_utils/evaluation/mmd.py new file mode 100644 index 000000000..e206dcdc6 --- /dev/null +++ b/packages/viscy-utils/src/viscy_utils/evaluation/mmd.py @@ -0,0 +1,199 @@ +"""Maximum Mean Discrepancy (MMD) with Gaussian RBF kernel and permutation test.""" + +import numpy as np +from numpy.typing import NDArray +from scipy.spatial.distance import cdist + + +def median_heuristic(X: NDArray, Y: NDArray, subsample: int = 1000) -> float: + """Compute Gaussian RBF bandwidth via the median heuristic. + + Subsamples jointly from X and Y, computes all pairwise squared Euclidean + distances, and returns the median. This is the standard bandwidth selection + for MMD tests (Gretton et al., 2012). + + Parameters + ---------- + X : NDArray + Samples from distribution P, shape (n, d). + Y : NDArray + Samples from distribution Q, shape (m, d). + subsample : int + Max samples to draw from the joint (X, Y) pool for median computation. + + Returns + ------- + float + Bandwidth sigma^2 for the Gaussian RBF kernel. 
+ """ + rng = np.random.default_rng(0) + pool = np.concatenate([X, Y], axis=0).astype(np.float32) + if len(pool) > subsample: + idx = rng.choice(len(pool), subsample, replace=False) + pool = pool[idx] + sq_dists = cdist(pool, pool, metric="sqeuclidean") + upper = sq_dists[np.triu_indices_from(sq_dists, k=1)] + return float(np.median(upper)) + 1e-12 + + +def gaussian_rbf_kernel(X: NDArray, Y: NDArray, bandwidth: float) -> NDArray: + """Compute Gaussian RBF kernel matrix K(X, Y) in float32. + + K(x, y) = exp(-||x - y||^2 / (2 * bandwidth)) + + Parameters + ---------- + X : NDArray + Shape (n, d). + Y : NDArray + Shape (m, d). + bandwidth : float + Kernel bandwidth (sigma^2). Must be > 0. + + Returns + ------- + NDArray + Kernel matrix, shape (n, m), float32. + """ + sq_dists = cdist(X.astype(np.float32), Y.astype(np.float32), metric="sqeuclidean") + return np.exp(-sq_dists / (2.0 * bandwidth), dtype=np.float32) + + +def compute_mmd_unbiased(X: NDArray, Y: NDArray, bandwidth: float | None = None) -> float: + """Compute the unbiased quadratic-time MMD^2 estimator. + + MMD^2_u = (1/(n(n-1))) sum_{i!=j} k(x_i, x_j) + + (1/(m(m-1))) sum_{i!=j} k(y_i, y_j) + - (2/(nm)) sum_{i,j} k(x_i, y_j) + + Parameters + ---------- + X : NDArray + Samples from distribution P, shape (n, d). + Y : NDArray + Samples from distribution Q, shape (m, d). + bandwidth : float or None + Gaussian RBF bandwidth. None = median heuristic. + + Returns + ------- + float + Unbiased MMD^2 estimate. 
+ """ + if bandwidth is None: + bandwidth = median_heuristic(X, Y) + n = len(X) + m = len(Y) + K_XX = gaussian_rbf_kernel(X, X, bandwidth) + K_YY = gaussian_rbf_kernel(Y, Y, bandwidth) + K_XY = gaussian_rbf_kernel(X, Y, bandwidth) + np.fill_diagonal(K_XX, 0.0) + np.fill_diagonal(K_YY, 0.0) + mmd2 = K_XX.sum() / (n * (n - 1)) + K_YY.sum() / (m * (m - 1)) - 2.0 * K_XY.mean() + return float(mmd2) + + +_MMD_PERM_MAX_N = 20_000 + + +def mmd_permutation_test( + X: NDArray, + Y: NDArray, + n_permutations: int = 1000, + bandwidth: float | None = None, + seed: int = 42, +) -> tuple[float, float, NDArray]: + """MMD^2 with vectorized permutation test for significance. + + Precomputes the pooled kernel matrix K_pool once, then all permutations + are evaluated via vectorized row/column sums — no repeated cdist calls + and no Python loop over individual permutations. + + Strategy: for each permutation p, MMD^2 = sum_X/n(n-1) + sum_Y/m(m-1) - 2*mean_XY + where sum_X = sum of K_pool[ix,ix] off-diagonal = (K_pool[ix,:] * one_hot_X).sum(). + We represent each permutation as a binary label vector z in {0,1}^(n+m), + then use K_pool @ z and K_pool @ (1-z) to get row sums in O(n_perm * N) ops. + + Parameters + ---------- + X : NDArray + Samples from distribution P, shape (n, d). + Y : NDArray + Samples from distribution Q, shape (m, d). + n_permutations : int + Number of permutations for the null distribution. + bandwidth : float or None + Gaussian RBF bandwidth. None = median heuristic (computed once). + seed : int + Random seed for reproducibility. + + Returns + ------- + mmd2 : float + Observed MMD^2 (unbiased). + p_value : float + Permutation test p-value. + null_distribution : NDArray + Null MMD^2 values from permutations, shape (n_permutations,). + """ + if bandwidth is None: + bandwidth = median_heuristic(X, Y) + n = len(X) + m = len(Y) + N = n + m + # The pooled kernel matrix is (N, N) float32 — quadratic in N. 
Cap N + # explicitly so callers see a clear error rather than an OOM when they + # forget to subsample (50k => 10 GB; 100k => 40 GB). + if N > _MMD_PERM_MAX_N: + raise ValueError( + f"mmd_permutation_test pooled kernel would be ({N}, {N}) float32 " + f"≈ {(N * N * 4) / 1e9:.1f} GB. Subsample X and/or Y so that " + f"len(X) + len(Y) <= {_MMD_PERM_MAX_N}." + ) + pool = np.concatenate([X, Y], axis=0).astype(np.float32) + # Compute full pooled kernel matrix once: (N, N) float32 + K = gaussian_rbf_kernel(pool, pool, bandwidth) + np.fill_diagonal(K, 0.0) + + def _mmd2_from_labels(z: NDArray) -> NDArray: + """Vectorized MMD^2 for a batch of label vectors. + + Parameters + ---------- + z : NDArray + Shape (n_perm, N), float32, 1 = assigned to X group. + + Returns + ------- + NDArray + MMD^2 values, shape (n_perm,). + """ + nz = z.sum(axis=1) # actual n per permutation (n_perm,) + mz = N - nz # actual m per permutation + # Row sums of K restricted to X-group and Y-group + # K @ z.T -> (N, n_perm), then z @ (K @ z.T) -> (n_perm, n_perm) diagonal = sum_XX + KzT = K @ z.T # (N, n_perm) + sum_XX = (z * KzT.T).sum(axis=1) # (n_perm,) — within-X kernel sums (diagonal zeroed) + sum_YY = ((1 - z) * (K @ (1 - z).T).T).sum(axis=1) # (n_perm,) — within-Y + sum_XY = (z * (K @ (1 - z).T).T).sum(axis=1) # (n_perm,) — cross + kxx = sum_XX / (nz * (nz - 1)) + kyy = sum_YY / (mz * (mz - 1)) + kxy = sum_XY / (nz * mz) + return kxx + kyy - 2.0 * kxy + + # Observed: original split (first n are X) + z_obs = np.zeros((1, N), dtype=np.float32) + z_obs[0, :n] = 1.0 + observed = float(_mmd2_from_labels(z_obs)[0]) + + # Null: random permutations as binary label vectors + rng = np.random.default_rng(seed) + # Generate all permutation indices at once + perms = np.stack([rng.permutation(N) for _ in range(n_permutations)]) # (n_perm, N) + z_null = np.zeros((n_permutations, N), dtype=np.float32) + row_idx = np.arange(n_permutations)[:, None] + z_null[row_idx, perms[:, :n]] = 1.0 + + null = 
_mmd2_from_labels(z_null) + p_value = float((np.sum(null >= observed) + 1) / (n_permutations + 1)) + return observed, p_value, null diff --git a/packages/viscy-utils/src/viscy_utils/evaluation/zarr_utils.py b/packages/viscy-utils/src/viscy_utils/evaluation/zarr_utils.py index a6e0aefe2..e4566b029 100644 --- a/packages/viscy-utils/src/viscy_utils/evaluation/zarr_utils.py +++ b/packages/viscy-utils/src/viscy_utils/evaluation/zarr_utils.py @@ -7,6 +7,7 @@ import pandas as pd import zarr from anndata.io import write_elem +from pandas.arrays import ArrowStringArray def append_to_anndata_zarr( @@ -31,12 +32,25 @@ def append_to_anndata_zarr( obs : pd.DataFrame, optional Observation metadata. Replaces the entire ``obs`` group. uns : dict, optional - Unstructured annotation. Replaces the entire ``uns`` group. + Mapping of uns keys to values. Each key is written to ``uns/{key}``, + replacing any existing entry while preserving other uns keys. """ store = zarr.open(str(zarr_path), mode="a", use_consolidated=False) ad.settings.allow_write_nullable_strings = True if obs is not None: + # TODO: remove once anndata 0.13 supports pandas 3 Arrow-backed strings natively. + # anndata 0.12.9+ requires pandas <3, so we stay on 0.12.6 + pandas 3 and + # must manually downcast ArrowStringArray columns to object dtype before writing. 
+ obs = obs.copy() + for col in obs.columns: + arr = obs[col].array + if isinstance(arr, ArrowStringArray): + obs[col] = obs[col].astype(object) + elif isinstance(arr, pd.Categorical) and isinstance(arr.categories._values, ArrowStringArray): + obs[col] = obs[col].cat.rename_categories(arr.categories.astype(object)) + if isinstance(obs.index._values, ArrowStringArray): + obs.index = obs.index.astype(object) if "obs" in store: del store["obs"] write_elem(store, "obs", obs) @@ -49,9 +63,13 @@ def append_to_anndata_zarr( write_elem(store, obsm_path, value) if uns is not None: - if "uns" in store: - del store["uns"] - write_elem(store, "uns", uns) + if "uns" not in store: + store.create_group("uns") + for key, value in uns.items(): + uns_path = f"uns/{key}" + if uns_path in store: + del store[uns_path] + write_elem(store, uns_path, value) zarr.consolidate_metadata(str(zarr_path)) diff --git a/packages/viscy-utils/src/viscy_utils/mp_utils.py b/packages/viscy-utils/src/viscy_utils/mp_utils.py index 015008246..967db4c37 100644 --- a/packages/viscy-utils/src/viscy_utils/mp_utils.py +++ b/packages/viscy-utils/src/viscy_utils/mp_utils.py @@ -1,10 +1,41 @@ """Multiprocessing utilities for dataset processing.""" +import os from concurrent.futures import ProcessPoolExecutor import numpy as np +def available_cpus(default: int = 1) -> int: + """Return the number of CPUs the current process is allowed to use. + + Prefers ``SLURM_CPUS_PER_TASK`` (the cluster-scheduler-allocated count) + over ``os.cpu_count()`` (the node's total physical cores). On a 48-core + node where SLURM allocated 16 cores, this returns 16 — preventing + oversubscription and respecting the cgroup pinning that SLURM sets. + + Use this anywhere you'd otherwise reach for ``os.cpu_count()`` to size + a thread pool, ``n_jobs``, ``num_workers``, ``data_copy_concurrency``, + etc. Library defaults like sklearn's ``n_jobs=-1`` and BLAS env autodetect + are NOT SLURM-aware and will spawn one thread per physical core. 
+ + Parameters + ---------- + default : int, optional + Fallback when neither ``SLURM_CPUS_PER_TASK`` nor ``os.cpu_count()`` + is informative. Defaults to 1 (safe single-threaded). + + Returns + ------- + int + Number of CPUs available to this process, ``>= 1``. + """ + slurm_cpus = os.environ.get("SLURM_CPUS_PER_TASK") + if slurm_cpus is not None: + return int(slurm_cpus) + return os.cpu_count() or default + + def mp_wrapper(fn, fn_args, workers): """Execute function with multiprocessing. diff --git a/packages/viscy-utils/tests/test_linear_classifier.py b/packages/viscy-utils/tests/test_linear_classifier.py index aad22f43d..efcd356b8 100644 --- a/packages/viscy-utils/tests/test_linear_classifier.py +++ b/packages/viscy-utils/tests/test_linear_classifier.py @@ -42,11 +42,13 @@ def synthetic_adata_with_unknowns(): class TestLinearClassifierPipeline: @pytest.fixture def trained_pipeline(self, annotated_adata): - pipeline, _ = train_linear_classifier(annotated_adata, task="cell_death_state", use_scaling=True, use_pca=False) + pipeline, _, _ = train_linear_classifier( + annotated_adata, task="cell_death_state", use_scaling=True, use_pca=False + ) return pipeline def test_transform_with_scaler_and_pca(self, annotated_adata): - pipeline, _ = train_linear_classifier( + pipeline, _, _ = train_linear_classifier( annotated_adata, task="cell_death_state", use_scaling=True, @@ -58,7 +60,7 @@ def test_transform_with_scaler_and_pca(self, annotated_adata): assert X_transformed.shape == (X.shape[0], 5) def test_transform_scaler_only(self, annotated_adata): - pipeline, _ = train_linear_classifier( + pipeline, _, _ = train_linear_classifier( annotated_adata, task="cell_death_state", use_scaling=True, @@ -70,7 +72,7 @@ def test_transform_scaler_only(self, annotated_adata): assert pipeline.pca is None def test_transform_no_preprocessing(self, annotated_adata): - pipeline, _ = train_linear_classifier( + pipeline, _, _ = train_linear_classifier( annotated_adata, task="cell_death_state", 
use_scaling=False, @@ -94,18 +96,18 @@ def test_predict_proba_shape(self, trained_pipeline, annotated_adata): class TestTrainLinearClassifier: def test_train_basic(self, annotated_adata): - pipeline, metrics = train_linear_classifier(annotated_adata, task="cell_death_state") + pipeline, metrics, _ = train_linear_classifier(annotated_adata, task="cell_death_state") assert isinstance(pipeline, LinearClassifierPipeline) assert isinstance(metrics, dict) assert "train_accuracy" in metrics assert "train_weighted_f1" in metrics def test_train_with_scaling(self, annotated_adata): - pipeline, _ = train_linear_classifier(annotated_adata, task="cell_death_state", use_scaling=True) + pipeline, _, _ = train_linear_classifier(annotated_adata, task="cell_death_state", use_scaling=True) assert pipeline.scaler is not None def test_train_with_pca(self, annotated_adata): - pipeline, _ = train_linear_classifier( + pipeline, _, _ = train_linear_classifier( annotated_adata, task="cell_death_state", use_pca=True, @@ -115,26 +117,26 @@ def test_train_with_pca(self, annotated_adata): assert pipeline.pca.n_components == 5 def test_train_no_split(self, annotated_adata): - pipeline, metrics = train_linear_classifier(annotated_adata, task="cell_death_state", split_train_data=1.0) + pipeline, metrics, _ = train_linear_classifier(annotated_adata, task="cell_death_state", split_train_data=1.0) assert "train_accuracy" in metrics assert "val_accuracy" not in metrics def test_train_metrics_keys(self, annotated_adata): - _, metrics = train_linear_classifier(annotated_adata, task="cell_death_state", split_train_data=0.8) + _, metrics, _ = train_linear_classifier(annotated_adata, task="cell_death_state", split_train_data=0.8) assert "train_accuracy" in metrics assert "train_weighted_f1" in metrics for class_name in ["alive", "dead", "apoptotic"]: assert f"train_{class_name}_f1" in metrics def test_train_reproducibility(self, annotated_adata): - _, metrics_a = train_linear_classifier(annotated_adata, 
task="cell_death_state", random_seed=123) - _, metrics_b = train_linear_classifier(annotated_adata, task="cell_death_state", random_seed=123) + _, metrics_a, _ = train_linear_classifier(annotated_adata, task="cell_death_state", random_seed=123) + _, metrics_b, _ = train_linear_classifier(annotated_adata, task="cell_death_state", random_seed=123) assert metrics_a == metrics_b def test_train_sparse_matrix(self, annotated_adata): sparse_adata = annotated_adata.copy() sparse_adata.X = scipy.sparse.csr_matrix(sparse_adata.X) - pipeline, metrics = train_linear_classifier(sparse_adata, task="cell_death_state") + pipeline, metrics, _ = train_linear_classifier(sparse_adata, task="cell_death_state") assert isinstance(pipeline, LinearClassifierPipeline) assert "train_accuracy" in metrics @@ -142,7 +144,7 @@ def test_train_sparse_matrix(self, annotated_adata): class TestPredictWithClassifier: @pytest.fixture def pipeline_and_adata(self, annotated_adata): - pipeline, _ = train_linear_classifier(annotated_adata, task="cell_death_state") + pipeline, _, _ = train_linear_classifier(annotated_adata, task="cell_death_state") return pipeline, annotated_adata def test_predict_adds_obs_columns(self, pipeline_and_adata): diff --git a/packages/viscy-utils/tests/test_online_eval.py b/packages/viscy-utils/tests/test_online_eval.py index 86076d650..70393f52e 100644 --- a/packages/viscy-utils/tests/test_online_eval.py +++ b/packages/viscy-utils/tests/test_online_eval.py @@ -1,8 +1,15 @@ """Tests for OnlineEvalCallback metrics.""" +from types import SimpleNamespace + import numpy as np +import torch -from viscy_utils.callbacks.online_eval import effective_rank, temporal_smoothness +from viscy_utils.callbacks.online_eval import ( + OnlineEvalCallback, + effective_rank, + temporal_smoothness, +) class TestEffectiveRank: @@ -85,3 +92,42 @@ def test_two_samples_insufficient_pairs(self): rho = temporal_smoothness(features, track_ids, timepoints) assert np.isnan(rho) + + +class 
TestGatherAcrossRanks: + """_gather_across_ranks must produce full-set arrays under DDP. + + The world_size=1 path is exercised by the integration test; this + class covers the multi-rank gather + missing-key passthrough that + are otherwise unreachable without a real distributed backend. + """ + + def test_world_size_two_concatenates_features_and_handles_missing(self): + """world_size=2 with identical per-rank inputs stacks features and labels; + a globally-missing optional array returns None instead of stalling.""" + + class _FakeModule: + def __init__(self, world_size: int): + self.trainer = SimpleNamespace(world_size=world_size) + self.device = torch.device("cpu") + self._w = world_size + + def all_gather(self, tensor: torch.Tensor) -> torch.Tensor: + return torch.stack([tensor] * self._w, dim=0) + + callback = OnlineEvalCallback() + module = _FakeModule(world_size=2) + features_local = torch.arange(12, dtype=torch.float32).reshape(3, 4) + labels_local = np.array([7, 8, 9]) + + # labels present, track_ids/timepoints missing (None on every rank). 
+ features_np, labels, track_ids, timepoints = callback._gather_across_ranks( + module, features_local, labels_local, None, None + ) + + assert features_np.shape == (6, 4) + np.testing.assert_array_equal(features_np[:3], features_local.numpy()) + np.testing.assert_array_equal(features_np[3:], features_local.numpy()) + np.testing.assert_array_equal(labels, np.array([7, 8, 9, 7, 8, 9])) + assert track_ids is None + assert timepoints is None diff --git a/uv.lock b/uv.lock index 10657c08b..57783b19f 100644 --- a/uv.lock +++ b/uv.lock @@ -2,14 +2,22 @@ version = 1 revision = 3 requires-python = ">=3.12, <3.14" resolution-markers = [ - "python_full_version >= '3.13' and sys_platform == 'win32'", - "python_full_version < '3.13' and sys_platform == 'win32'", - "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version < '3.13' and sys_platform == 'emscripten'", - "python_full_version >= '3.13' and sys_platform == 'linux'", - "python_full_version < '3.13' and sys_platform == 'linux'", - "python_full_version >= '3.13' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.13' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine 
== 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", ] [manifest] @@ -684,6 +692,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/80/99/2adc7d8ffead633234817ef8e9a87115c8a11927a94478f6bb3d3f4d4f7d/contourpy-1.3.3-cp313-cp313t-win_arm64.whl", hash = "sha256:3c30273eb2a55024ff31ba7d052dde990d7d8e5450f4bbb6e913558b3d6c2301", size = 199713, upload-time = "2025-07-26T12:02:14.4Z" }, ] +[[package]] +name = "copairs" +version = "0.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "duckdb" }, + { name = "pandas" }, + { name = "statsmodels" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/65/25/7e2b2327ce9b3a7312be41070f264a09761fccb146cf60206d27c50e24b6/copairs-0.5.4.tar.gz", hash = "sha256:4d821784fa42d388db66e6a90c4ca1849c79957059260655faa884ffe6559648", size = 41895, upload-time = "2026-01-27T12:21:07.836Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/b7/2a/86a6255d7e892419833ba5951f7574d02c9c83648cd939bb5a921e386858/copairs-0.5.4-py3-none-any.whl", hash = "sha256:e24e41ffdcfabf8d76b4288423f8951ea9c69884d5c4e88f8d9d33ff1ee32bbf", size = 34092, upload-time = "2026-01-27T12:21:06.368Z" }, +] + [[package]] name = "coverage" version = "7.13.4" @@ -759,7 +782,7 @@ name = "cuda-bindings" version = "12.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cuda-pathfinder", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "cuda-pathfinder", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, @@ -990,6 +1013,47 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, ] +[[package]] +name = "dtaidistance" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cd/01/aa26cc97b64d397ff03b9576b0a04cc79d0e3bae512eb087cfab7d98f4ec/dtaidistance-2.4.0.tar.gz", hash = "sha256:bd4066800254fbd5b620e6462bb759c9d85b79ac2080b354cedc901f446b6c82", size = 1316462, upload-time = "2026-02-12T22:23:56.35Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/63/c1546dc5a4a98f77ca044206e8d8b7604349d36d0b76d5c03ab393a55e60/dtaidistance-2.4.0-cp312-cp312-macosx_10_13_universal2.whl", hash = 
"sha256:64d54f910b53cd7a56b215e06d2b24b22090af836102d48558d3e9569ded2b66", size = 2124723, upload-time = "2026-02-12T22:23:39.482Z" }, + { url = "https://files.pythonhosted.org/packages/ad/9a/4c0cb726c3c93436c993f55fc59d5fd2142c1a0fe6fe9ec06cc7bf25ab15/dtaidistance-2.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3afb229f4524f8bbf835a5dc3e07abcee9b6b9c6af4f14436cad19639102243c", size = 1549051, upload-time = "2026-02-13T08:14:46.866Z" }, + { url = "https://files.pythonhosted.org/packages/f5/8e/ccdd057e4ff71cf0b6fe34220cbd214d469f831b45acbbb4366fdfef6330/dtaidistance-2.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:349d6765e10ddbb5e22e937cf1bc42394f5f8d36bc127f8af24a0cd0259f4804", size = 4361729, upload-time = "2026-02-12T22:23:41.184Z" }, + { url = "https://files.pythonhosted.org/packages/00/cf/ef215e8864c21eb14872f98987d9736ebbbe5049d429039e2a93adcacad4/dtaidistance-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:6ab9431a5b66aafd37ab4dfcfe563b66694ed192019c1632d2de7a431a883bcd", size = 1443363, upload-time = "2026-02-12T22:23:43.706Z" }, + { url = "https://files.pythonhosted.org/packages/87/89/c64eea692eae3b269719ee5173bf5008b5c165280248e3fad1948c765a2b/dtaidistance-2.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:4cf41f3edcc4c1b94ebbc1de029ee9b58da28f33f7bf3af89212cc05e35ec8f1", size = 2117805, upload-time = "2026-02-12T22:23:45.632Z" }, + { url = "https://files.pythonhosted.org/packages/ed/7f/06ce3d5ce51a959be0534584ad2556e6c8be966ef1218a866c6c3d62e3c5/dtaidistance-2.4.0-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:94b841d6575e3ad715b4e213f0f04de25e23c2da3ac21ee9c6775b38f5bdfecf", size = 1738478, upload-time = "2026-02-13T08:14:51.715Z" }, + { url = "https://files.pythonhosted.org/packages/db/8e/6c8a5c7710f9f5e3805281974ce8fea4ad0334c00a1e0f977977c045a594/dtaidistance-2.4.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:b0f2a65628aea82175e7f8c5e96faf5372c933ed40e2e39a84957d8fe305158d", size = 4341606, upload-time = "2026-02-12T22:23:46.977Z" }, + { url = "https://files.pythonhosted.org/packages/9f/b6/7f77c6773380742660d09f379d43814b448296fc24c3fb1de15a3d813311/dtaidistance-2.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8c9ef4c7270d1a192e8f1b481c2e10e63c33c6e7edfc507acac7f3fdc19949f", size = 1441578, upload-time = "2026-02-12T22:23:48.897Z" }, +] + +[[package]] +name = "duckdb" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/66/744b4931b799a42f8cb9bc7a6f169e7b8e51195b62b246db407fd90bf15f/duckdb-1.5.2.tar.gz", hash = "sha256:638da0d5102b6cb6f7d47f83d0600708ac1d3cb46c5e9aaabc845f9ba4d69246", size = 18017166, upload-time = "2026-04-13T11:30:09.065Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/de/ebe66bbe78125fc610f4fd415447a65349d94245950f3b3dfb31d028af02/duckdb-1.5.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e6495b00cad16888384119842797c49316a96ae1cb132bb03856d980d95afee1", size = 30064950, upload-time = "2026-04-13T11:29:11.468Z" }, + { url = "https://files.pythonhosted.org/packages/2d/8a/3e25b5d03bcf1fb99d189912f8ce92b1db4f9c8778e1b1f55745973a855a/duckdb-1.5.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d72b8856b1839d35648f38301b058f6232f4d36b463fe4dc8f4d3fdff2df1a2e", size = 15969113, upload-time = "2026-04-13T11:29:14.139Z" }, + { url = "https://files.pythonhosted.org/packages/19/bb/58001f0815002b1a93431bf907f77854085c7d049b83d521814a07b9db0b/duckdb-1.5.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2a1de4f4d454b8c97aec546c82003fc834d3422ce4bc6a19902f3462ef293bed", size = 14224774, upload-time = "2026-04-13T11:29:16.758Z" }, + { url = "https://files.pythonhosted.org/packages/d3/2f/a7f0de9509d1cef35608aeb382919041cdd70f58c173865c3da6a0d87979/duckdb-1.5.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:ce0b8141a10d37ecef729c45bc41d334854013f4389f1488bd6035c5579aaac1", size = 19313510, upload-time = "2026-04-13T11:29:19.574Z" }, + { url = "https://files.pythonhosted.org/packages/26/78/eb1e064ea8b9df3b87b167bfd7a407b2f615a4291e06cba756727adfa06c/duckdb-1.5.2-cp312-cp312-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c99ef73a277c8921bc0a1f16dee38d924484251d9cfd20951748c20fcd5ed855", size = 21429692, upload-time = "2026-04-13T11:29:22.575Z" }, + { url = "https://files.pythonhosted.org/packages/5b/12/05b0c47d14839925c5e35b79081d918ca82e3f236bb724a6f58409dd5291/duckdb-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:8d599758b4e48bf12e18c9b960cf491d219f0c4972d19a45489c05cc5ab36f83", size = 13107594, upload-time = "2026-04-13T11:29:25.43Z" }, + { url = "https://files.pythonhosted.org/packages/0b/2c/80558a82b236e044330e84a154b96aacddb343316b479f3d49be03ea11cb/duckdb-1.5.2-cp312-cp312-win_arm64.whl", hash = "sha256:fc85a5dbcbe6eccac1113c72370d1d3aacfdd49198d63950bdf7d8638a307f00", size = 13927537, upload-time = "2026-04-13T11:29:27.842Z" }, + { url = "https://files.pythonhosted.org/packages/98/f2/e3d742808f138d374be4bb516fade3d1f33749b813650810ab7885cdc363/duckdb-1.5.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:4420b3f47027a7849d0e1815532007f377fa95ee5810b47ea717d35525c12f79", size = 30064879, upload-time = "2026-04-13T11:29:30.763Z" }, + { url = "https://files.pythonhosted.org/packages/72/0d/f3dc1cf97e1267ca15e4307d456f96ce583961f0703fd75e62b2ad8d64fa/duckdb-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:bb42e6ed543902e14eae647850da24103a89f0bc2587dec5601b1c1f213bd2ed", size = 15969327, upload-time = "2026-04-13T11:29:33.481Z" }, + { url = "https://files.pythonhosted.org/packages/b1/e0/d5418def53ae4e05a63075705ff44ed5af5a1a5932627eb2b600c5df1c93/duckdb-1.5.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:98c0535cd6d901f61a5ea3c2e26a1fd28482953d794deb183daf568e3aa5dda6", size = 14225107, upload-time = 
"2026-04-13T11:29:35.882Z" }, + { url = "https://files.pythonhosted.org/packages/16/a7/15aaa59dbecc35e9711980fcdbf525b32a52470b32d18ef678193a146213/duckdb-1.5.2-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:486c862bf7f163c0110b6d85b3e5c031d224a671cca468f12ebb1d3a348f6b39", size = 19313433, upload-time = "2026-04-13T11:29:38.367Z" }, + { url = "https://files.pythonhosted.org/packages/bd/21/d903cc63a5140c822b7b62b373a87dc557e60c29b321dfb435061c5e67cf/duckdb-1.5.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:70631c847ca918ee710ec874241b00cf9d2e5be90762cbb2a0389f17823c08f7", size = 21429837, upload-time = "2026-04-13T11:29:41.135Z" }, + { url = "https://files.pythonhosted.org/packages/e3/0a/b770d1f60c70597302130d6247f418549b7094251a02348fbaf1c7e147ae/duckdb-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:52a21823f3fbb52f0f0e5425e20b07391ad882464b955879499b5ff0b45a376b", size = 13107699, upload-time = "2026-04-13T11:29:43.905Z" }, + { url = "https://files.pythonhosted.org/packages/d9/cf/e200fe431d700962d1a908d2ce89f53ccee1cc8db260174ae663ba09686b/duckdb-1.5.2-cp313-cp313-win_arm64.whl", hash = "sha256:411ad438bd4140f189a10e7f515781335962c5d18bd07837dc6d202e3985253d", size = 13927646, upload-time = "2026-04-13T11:29:46.598Z" }, +] + [[package]] name = "dynacell" source = { editable = "applications/dynacell" } @@ -1054,6 +1118,7 @@ dependencies = [ [package.optional-dependencies] eval = [ { name = "anndata" }, + { name = "dtaidistance" }, { name = "natsort" }, { name = "phate" }, { name = "scikit-learn" }, @@ -1062,6 +1127,17 @@ eval = [ { name = "umap-learn" }, { name = "wandb" }, ] +tracking = [ + { name = "gurobipy" }, + { name = "onnxruntime-gpu" }, + { name = "py-ctcmetrics" }, + { name = "tabulate" }, + { name = "tracksdata" }, +] +viz = [ + { name = "imageio-ffmpeg" }, + { name = "matplotlib" }, +] [package.dev-dependencies] dev = [ @@ -1087,15 +1163,23 @@ test = [ requires-dist = [ { name = 
"anndata", marker = "extra == 'eval'" }, { name = "click" }, - { name = "iohub", specifier = ">=0.3a2" }, + { name = "dtaidistance", marker = "extra == 'eval'" }, + { name = "gurobipy", marker = "extra == 'tracking'", specifier = ">=12.0.1,<13" }, + { name = "imageio-ffmpeg", marker = "extra == 'viz'" }, + { name = "iohub", specifier = ">=0.3.3" }, + { name = "matplotlib", marker = "extra == 'viz'" }, { name = "natsort", marker = "extra == 'eval'" }, + { name = "onnxruntime-gpu", marker = "extra == 'tracking'" }, { name = "phate", marker = "extra == 'eval'" }, + { name = "py-ctcmetrics", marker = "extra == 'tracking'" }, { name = "pytorch-metric-learning" }, { name = "pyyaml" }, { name = "scikit-learn", marker = "extra == 'eval'" }, { name = "seaborn", marker = "extra == 'eval'" }, { name = "statsmodels", marker = "extra == 'eval'" }, + { name = "tabulate", marker = "extra == 'tracking'" }, { name = "torchvision" }, + { name = "tracksdata", marker = "extra == 'tracking'" }, { name = "umap-learn", marker = "extra == 'eval'" }, { name = "viscy-data", extras = ["triplet"], editable = "packages/viscy-data" }, { name = "viscy-models", editable = "packages/viscy-models" }, @@ -1103,7 +1187,7 @@ requires-dist = [ { name = "viscy-utils", extras = ["eval"], editable = "packages/viscy-utils" }, { name = "wandb", marker = "extra == 'eval'" }, ] -provides-extras = ["eval"] +provides-extras = ["eval", "tracking", "viz"] [package.metadata.requires-dev] dev = [ @@ -1256,7 +1340,7 @@ requires-dist = [ { name = "dask", extras = ["array"] }, { name = "eet-features", editable = "../../../../../home/eduardo.hirata/repos/eet_features" }, { name = "eet-inference", editable = "../../../../../home/eduardo.hirata/repos/eet_inference" }, - { name = "iohub", specifier = ">=0.3a2" }, + { name = "iohub", specifier = ">=0.3.3" }, { name = "napari", extras = ["pyqt5"] }, { name = "napari-geff", editable = "../../../../../home/eduardo.hirata/repos/napari-geff" }, { name = "pyyaml" }, @@ -1379,6 
+1463,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/24/f4ed44e103ee7ec9880c43bb06a9d60eab5f06d80022f83005c67304655d/fill_voids-2.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:976f6a3c5a68f3f3483da779d8c71f11e8e3eec4c104d0d594ba5cd11a36a7fa", size = 181694, upload-time = "2025-09-03T05:28:19.728Z" }, ] +[[package]] +name = "flatbuffers" +version = "25.12.19" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl", hash = "sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4", size = 26661, upload-time = "2025-12-19T23:16:13.622Z" }, +] + [[package]] name = "flexcache" version = "0.3" @@ -1897,6 +1989,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl", hash = "sha256:ad9adfb20335d718c03de457358ed69f141021a333c40a53e57273d8a5bd0b9b", size = 317646, upload-time = "2025-11-04T14:29:37.948Z" }, ] +[[package]] +name = "imageio-ffmpeg" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/bd/c3343c721f2a1b0c9fc71c1aebf1966a3b7f08c2eea8ed5437a2865611d6/imageio_ffmpeg-0.6.0.tar.gz", hash = "sha256:e2556bed8e005564a9f925bb7afa4002d82770d6b08825078b7697ab88ba1755", size = 25210, upload-time = "2025-01-16T21:34:32.747Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/58/87ef68ac83f4c7690961bce288fd8e382bc5f1513860fc7f90a9c1c1c6bf/imageio_ffmpeg-0.6.0-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.whl", hash = "sha256:9d2baaf867088508d4a3458e61eeb30e945c4ad8016025545f66c4b5aaef0a61", size = 24932969, upload-time = "2025-01-16T21:34:20.464Z" }, + { url = 
"https://files.pythonhosted.org/packages/40/5c/f3d8a657d362cc93b81aab8feda487317da5b5d31c0e1fdfd5e986e55d17/imageio_ffmpeg-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b1ae3173414b5fc5f538a726c4e48ea97edc0d2cdc11f103afee655c463fa742", size = 21113891, upload-time = "2025-01-16T21:34:00.277Z" }, + { url = "https://files.pythonhosted.org/packages/33/e7/1925bfbc563c39c1d2e82501d8372734a5c725e53ac3b31b4c2d081e895b/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1d47bebd83d2c5fc770720d211855f208af8a596c82d17730aa51e815cdee6dc", size = 25632706, upload-time = "2025-01-16T21:33:53.475Z" }, + { url = "https://files.pythonhosted.org/packages/a0/2d/43c8522a2038e9d0e7dbdf3a61195ecc31ca576fb1527a528c877e87d973/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c7e46fcec401dd990405049d2e2f475e2b397779df2519b544b8aab515195282", size = 29498237, upload-time = "2025-01-16T21:34:13.726Z" }, + { url = "https://files.pythonhosted.org/packages/a0/13/59da54728351883c3c1d9fca1710ab8eee82c7beba585df8f25ca925f08f/imageio_ffmpeg-0.6.0-py3-none-win32.whl", hash = "sha256:196faa79366b4a82f95c0f4053191d2013f4714a715780f0ad2a68ff37483cc2", size = 19652251, upload-time = "2025-01-16T21:34:06.812Z" }, + { url = "https://files.pythonhosted.org/packages/2c/c6/fa760e12a2483469e2bf5058c5faff664acf66cadb4df2ad6205b016a73d/imageio_ffmpeg-0.6.0-py3-none-win_amd64.whl", hash = "sha256:02fa47c83703c37df6bfe4896aab339013f62bf02c5ebf2dce6da56af04ffc0a", size = 31246824, upload-time = "2025-01-16T21:34:28.6Z" }, +] + [[package]] name = "imagesize" version = "2.0.0" @@ -1956,7 +2062,7 @@ wheels = [ [[package]] name = "iohub" -version = "0.3.2" +version = "0.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "blosc2" }, @@ -1974,9 +2080,9 @@ dependencies = [ { name = "zarr" }, { name = "zarrs" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/4c/d8/601a5a2d648370cd90825e3c51bc26155b223443ed936ec8e6d62135c871/iohub-0.3.2.tar.gz", hash = "sha256:54eb5a146efbc94375e2f40f51a98234a33a2e3820de7745fdcf06f33d86fef0", size = 289875, upload-time = "2026-04-10T04:16:50.515Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/81/4400daf22b508a237bbe05a58320886b1549bde3cb41eaaf46c4c777f355/iohub-0.3.3.tar.gz", hash = "sha256:8190d3155a5dee0e0b98416970b648008b9d4e86e42a84e32078b04284a6b66e", size = 290272, upload-time = "2026-04-24T00:13:38.941Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/80/cd/7c3fd9ecc51b598468d7702b69ab738d5cf9dbe6ffd22f96f26b2aca4ceb/iohub-0.3.2-py3-none-any.whl", hash = "sha256:6808e979ea229e569a627636a073c41e85c0f0224936ddaa952a30b157b79810", size = 84464, upload-time = "2026-04-10T04:16:49.111Z" }, + { url = "https://files.pythonhosted.org/packages/a0/4f/cf3443512b38501677649f79fb7524b35cb1f26b0238d116ed0407e162bb/iohub-0.3.3-py3-none-any.whl", hash = "sha256:b0eb7781ae076bbd3db7143cb8482612fd414016c589a1beec98bb6de9da1173", size = 84950, upload-time = "2026-04-24T00:13:37.499Z" }, ] [[package]] @@ -3267,7 +3373,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -3278,7 +3384,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = 
"nvidia-nvjitlink-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -3305,9 +3411,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -3318,7 +3424,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 's390x' and 
sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -3383,6 +3489,82 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/21/59baa90924b815b70f88045f0b206b7eab0b68b461c0192692486b516ab7/ome_zarr-0.12.2-py3-none-any.whl", hash = "sha256:655fe1b11ca01148603f9931a5b0af31207dfc03a3a35f9b0ab8639790282bbd", size = 41410, upload-time = "2025-08-22T08:57:12.44Z" }, ] +[[package]] +name = "onnx" +version = "1.21.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "protobuf" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/93/942d2a0f6a70538eea042ce0445c8aefd46559ad153469986f29a743c01c/onnx-1.21.0.tar.gz", hash = "sha256:4d8b67d0aaec5864c87633188b91cc520877477ec0254eda122bef8be43cd764", size = 12074608, upload-time = "2026-03-27T21:33:36.118Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/ae/cb644ec84c25e63575d9d8790fdcc5d1a11d67d3f62f872edb35fa38d158/onnx-1.21.0-cp312-abi3-macosx_12_0_universal2.whl", hash = "sha256:fc2635400fe39ff37ebc4e75342cc54450eadadf39c540ff132c319bf4960095", size = 17965930, upload-time = "2026-03-27T21:32:48.089Z" }, + { url = "https://files.pythonhosted.org/packages/6f/b6/eeb5903586645ef8a49b4b7892580438741acc3df91d7a5bd0f3a59ea9cb/onnx-1.21.0-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9003d5206c01fa2ff4b46311566865d8e493e1a6998d4009ec6de39843f1b59b", size = 17531344, upload-time = "2026-03-27T21:32:50.837Z" }, + { url = 
"https://files.pythonhosted.org/packages/a7/00/4823f06357892d1e60d6f34e7299d2ba4ed2108c487cc394f7ce85a3ff14/onnx-1.21.0-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a9261bd580fb8548c9c37b3c6750387eb8f21ea43c63880d37b2c622e1684285", size = 17613697, upload-time = "2026-03-27T21:32:54.222Z" }, + { url = "https://files.pythonhosted.org/packages/23/1d/391f3c567ae068c8ac4f1d1316bae97c9eb45e702f05975fe0e17ad441f0/onnx-1.21.0-cp312-abi3-win32.whl", hash = "sha256:9ea4e824964082811938a9250451d89c4ec474fe42dd36c038bfa5df31993d1e", size = 16287200, upload-time = "2026-03-27T21:32:57.277Z" }, + { url = "https://files.pythonhosted.org/packages/9c/a6/5eefbe5b40ea96de95a766bd2e0e751f35bdea2d4b951991ec9afaa69531/onnx-1.21.0-cp312-abi3-win_amd64.whl", hash = "sha256:458d91948ad9a7729a347550553b49ab6939f9af2cddf334e2116e45467dc61f", size = 16441045, upload-time = "2026-03-27T21:33:00.081Z" }, + { url = "https://files.pythonhosted.org/packages/63/c4/0ed8dc037a39113d2a4d66e0005e07751c299c46b993f1ad5c2c35664c20/onnx-1.21.0-cp312-abi3-win_arm64.whl", hash = "sha256:ca14bc4842fccc3187eb538f07eabeb25a779b39388b006db4356c07403a7bbb", size = 16403134, upload-time = "2026-03-27T21:33:03.987Z" }, + { url = "https://files.pythonhosted.org/packages/f8/89/0e1a9beb536401e2f45ac88735e123f2735e12fc7b56ff6c11727e097526/onnx-1.21.0-cp313-cp313t-macosx_12_0_universal2.whl", hash = "sha256:257d1d1deb6a652913698f1e3f33ef1ca0aa69174892fe38946d4572d89dd94f", size = 17975430, upload-time = "2026-03-27T21:33:07.005Z" }, + { url = "https://files.pythonhosted.org/packages/ec/46/e6dc71a7b3b317265591b20a5f71d0ff5c0d26c24e52283139dc90c66038/onnx-1.21.0-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7cd7cb8f6459311bdb557cbf6c0ccc6d8ace11c304d1bba0a30b4a4688e245f8", size = 17537435, upload-time = "2026-03-27T21:33:09.765Z" }, + { url = 
"https://files.pythonhosted.org/packages/49/2e/27affcac63eaf2ef183a44fd1a1354b11da64a6c72fe6f3fdcf5571bcee5/onnx-1.21.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b58a4cfec8d9311b73dc083e4c1fa362069267881144c05139b3eba5dc3a840", size = 17617687, upload-time = "2026-03-27T21:33:12.619Z" }, + { url = "https://files.pythonhosted.org/packages/1c/5c/ac8ed15e941593a3672ce424280b764979026317811f2e8508432bfc3429/onnx-1.21.0-cp313-cp313t-win_amd64.whl", hash = "sha256:1a9baf882562c4cebf79589bebb7cd71a20e30b51158cac3e3bbaf27da6163bd", size = 16449402, upload-time = "2026-03-27T21:33:15.555Z" }, + { url = "https://files.pythonhosted.org/packages/0e/aa/d2231e0dcaad838217afc64c306c8152a080134d2034e247cc973d577674/onnx-1.21.0-cp313-cp313t-win_arm64.whl", hash = "sha256:bba12181566acf49b35875838eba49536a327b2944664b17125577d230c637ad", size = 16408273, upload-time = "2026-03-27T21:33:18.599Z" }, +] + +[[package]] +name = "onnx-ir" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "onnx" }, + { name = "sympy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/e6/672fefb2f108d077f58181a7babf4c0f8d1182a30353ffc9c79c63afc5ee/onnx_ir-0.2.1.tar.gz", hash = "sha256:8b8b10a93f43e65962104de6070c43c5dacb0e3cdfefc7c8059dd83c9db64f35", size = 144279, upload-time = "2026-04-20T20:21:47.735Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/aa/f7a53321c60b9ad9ee184b6018292ed6b5389947592a2c8c09c736bb7f9e/onnx_ir-0.2.1-py3-none-any.whl", hash = "sha256:c7285da889312f91882de2092e298a9eeeefbfc1d1951c49d983992967eb09a7", size = 166792, upload-time = "2026-04-20T20:21:46.357Z" }, +] + +[[package]] +name = "onnxruntime-gpu" +version = "1.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "flatbuffers" }, + { name = "numpy" }, + { name = "packaging" }, + { name = 
"protobuf" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/6d/2c13d3eff74caa9e59820a044a75becd34e9cbeeaf7617ad7679cdb1fdb7/onnxruntime_gpu-1.25.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f0c36c63c8b0eb4091f2567067f480f66f0aedc189eb009545c98ce7e919056", size = 270342429, upload-time = "2026-04-22T17:28:10.526Z" }, + { url = "https://files.pythonhosted.org/packages/8c/2e/9fc303ae59d4caeb85ec3cea6881b7de8ca1d2a07140fade39913cd7ff10/onnxruntime_gpu-1.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:61178cc4d84f59861714554531e01cccbd33ddf13cc0e87a3adea13b24d297ce", size = 220847708, upload-time = "2026-04-22T17:20:47.993Z" }, + { url = "https://files.pythonhosted.org/packages/f5/15/e63fe7b1abad6884bed07e9bb333e9f0ea48fbb8cbc1ea4a67ee6019d5d0/onnxruntime_gpu-1.25.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e462eb13ee9955117baec4f518916c1e7cb1a96001114105632bc6d454c6aee6", size = 270342324, upload-time = "2026-04-22T17:28:21.142Z" }, + { url = "https://files.pythonhosted.org/packages/21/10/b3533243d062b589d4b1f3ae26584af332c5cde618e7f6f5ff6fabbfd5f2/onnxruntime_gpu-1.25.0-cp313-cp313-win_amd64.whl", hash = "sha256:9a3682158e5e911385252eb95d6332b6f525972746c582e10f8a78213b39e624", size = 220848188, upload-time = "2026-04-22T17:20:56.946Z" }, + { url = "https://files.pythonhosted.org/packages/35/6c/d7706dd1d0eaafdba44d5c89f8d952de41e425a1b0cbd3ecfa60f918c249/onnxruntime_gpu-1.25.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8514b92c5929c953850090d823d018770cba2a971efab5f8f69a3c4280cdc632", size = 270364210, upload-time = "2026-04-22T17:28:33.568Z" }, +] + +[[package]] +name = "onnxscript" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "onnx" }, + { name = "onnx-ir" }, + { name = "packaging" }, + { name = "typing-extensions" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/9b/99/fd948eba63ba65b52265a4cd09a14f96bb9f5b730fcef58876c4358bf406/onnxscript-0.7.0.tar.gz", hash = "sha256:c95ed7b339b02cface56ee27689565c46612e1fc542c562298dddfdad5268dc5", size = 612032, upload-time = "2026-04-20T17:09:19.775Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/ce/2ed92575cc3be4ea1db5f38f16f20765f9b20b69b14d6c1d9972658a8ee9/onnxscript-0.7.0-py3-none-any.whl", hash = "sha256:5b356907d4501e9919f8599c91d8da967406a37b1fac2b40caa55a49acf242ea", size = 714842, upload-time = "2026-04-20T17:09:22.089Z" }, +] + [[package]] name = "opencv-python-headless" version = "4.13.0.92" @@ -3856,6 +4038,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, ] +[[package]] +name = "py-ctcmetrics" +version = "1.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "imagecodecs" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "scikit-learn" }, + { name = "scipy" }, + { name = "tifffile" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5d/25/bc4ff397b3ac93606ee105ab6832cca5f2a06b2dee9e1240f6215f541d4f/py_ctcmetrics-1.3.3.tar.gz", hash = "sha256:e055b7713bc704a42673b1313c7fd5ae55b80d49455132ff27b6b7db609209b0", size = 35153, upload-time = "2026-03-12T08:53:53.572Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/01/cc/c3c0d99df9540ca8ac4ee9c9177c5f88bf9693f5808ab5a5330d7d2fda65/py_ctcmetrics-1.3.3-py3-none-any.whl", hash = "sha256:7f35906030aadf8a4b5be9cf44260969b82b2d6bb3959b93f24928ff557b5f6c", size = 43419, upload-time = "2026-03-12T08:53:52.367Z" }, +] + [[package]] name = "pyairtable" version = "3.3.0" @@ -4143,8 +4342,10 @@ name = "pyqt5-qt5" version = "5.15.2" source = { registry = 
"https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.13' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", ] wheels = [ { url = "https://files.pythonhosted.org/packages/62/09/99a222b0360616250fb2e6003a54e43a2a06b0774f0f8d5daafb86a2c375/PyQt5_Qt5-5.15.2-py3-none-macosx_10_13_intel.whl", hash = "sha256:76980cd3d7ae87e3c7a33bfebfaee84448fd650bad6840471d6cae199b56e154", size = 40546019, upload-time = "2021-03-10T13:52:47.763Z" }, @@ -4155,12 +4356,18 @@ name = "pyqt5-qt5" version = "5.15.18" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and sys_platform == 'win32'", - "python_full_version < '3.13' and sys_platform == 'win32'", - "python_full_version >= '3.13' and sys_platform == 'emscripten'", - "python_full_version < '3.13' and sys_platform == 'emscripten'", - "python_full_version >= '3.13' and sys_platform == 'linux'", - "python_full_version < '3.13' and sys_platform == 'linux'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and 
platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", ] wheels = [ { url = "https://files.pythonhosted.org/packages/9a/46/ffe177f99f897a59dc237a20059020427bd2d3853d713992b8081933ddfe/pyqt5_qt5-5.15.18-py3-none-manylinux2014_x86_64.whl", hash = "sha256:bf2457e6371969736b4f660a0c153258fa03dbc6a181348218e6f05421682af7", size = 60864590, upload-time = "2025-11-09T12:57:26.724Z" }, @@ -5231,6 +5438,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] +[[package]] +name = "tabulate" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/46/58/8c37dea7bbf769b20d58e7ace7e5edfe65b849442b00ffcdd56be88697c6/tabulate-0.10.0.tar.gz", hash = "sha256:e2cfde8f79420f6deeffdeda9aaec3b6bc5abce947655d17ac662b126e48a60d", size = 91754, upload-time = "2026-03-04T18:55:34.402Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl", hash = "sha256:f0b0622e567335c8fabaaa659f1b33bcb6ddfe2e496071b743aa113f8774f2d3", size = 39814, upload-time = "2026-03-04T18:55:31.284Z" }, +] + [[package]] name = "tasklogger" version = "1.2.0" @@ -5980,7 +6196,7 @@ test = [ [package.metadata] requires-dist = [ { name = "imageio" }, - { name = "iohub", specifier = ">=0.3.2" }, + { name = "iohub", specifier = ">=0.3.3" }, { name = "lightning", specifier = ">=2.3" }, { name = "monai", specifier = ">=1.5.2" }, { name = "numpy", specifier = ">=2.4.1" }, @@ -6140,6 +6356,8 @@ dependencies = [ { name = "lightning" }, { name = "matplotlib" }, { name = "numpy" }, + { name = "onnx" }, + { name = "onnxscript" }, { name = "pyyaml" }, { name = "scikit-image" }, { name = "scipy" }, @@ -6153,6 +6371,7 @@ dependencies = [ [package.optional-dependencies] all = [ { name = "anndata" }, + { name = "copairs" }, { name = "natsort" }, { name = "phate" }, { name = "scikit-learn" }, @@ -6164,6 +6383,7 @@ anndata = [ { name = "natsort" }, ] eval = [ + { name = "copairs" }, { name = "phate" }, { name = "scikit-learn" }, { name = "umap-learn" }, @@ -6190,15 +6410,19 @@ test = [ [package.metadata] requires-dist = [ - { name = "anndata", marker = "extra == 'all'" }, - { name = "anndata", marker = "extra == 'anndata'" }, - { name = "iohub", specifier = ">=0.3a2" }, + { name = "anndata", marker = "extra == 'all'", specifier = "<0.12.9" }, + { name = "anndata", marker = "extra == 'anndata'", specifier = "<0.12.9" }, + { name = "copairs", marker = "extra == 'all'" }, + { name = "copairs", marker = "extra == 'eval'" }, + { name = "iohub", specifier = ">=0.3.3" }, { name = "jsonargparse", extras = ["signatures"], specifier = ">=4.26" }, { name = "lightning", specifier = ">=2.3" }, { name = "matplotlib", specifier = ">=3.10" }, { name = "natsort", marker = "extra == 'all'" }, { name = "natsort", marker = 
"extra == 'anndata'" }, { name = "numpy", specifier = ">=2.4.1" }, + { name = "onnx" }, + { name = "onnxscript" }, { name = "phate", marker = "extra == 'all'" }, { name = "phate", marker = "extra == 'eval'" }, { name = "pyyaml" }, @@ -6220,7 +6444,7 @@ provides-extras = ["all", "anndata", "eval"] [package.metadata.requires-dev] dev = [ - { name = "anndata" }, + { name = "anndata", specifier = "<0.12.9" }, { name = "natsort" }, { name = "pytest", specifier = ">=9.0.2" }, { name = "pytest-cov", specifier = ">=7" }, @@ -6228,7 +6452,7 @@ dev = [ { name = "wandb" }, ] test = [ - { name = "anndata" }, + { name = "anndata", specifier = "<0.12.9" }, { name = "natsort" }, { name = "pytest", specifier = ">=9.0.2" }, { name = "pytest-cov", specifier = ">=7" },