diff --git a/applications/DynaCLR/evaluation/lot_correction/__init__.py b/applications/DynaCLR/evaluation/lot_correction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/DynaCLR/evaluation/lot_correction/apply_lot_correction.py b/applications/DynaCLR/evaluation/lot_correction/apply_lot_correction.py new file mode 100644 index 000000000..9c5c55fb3 --- /dev/null +++ b/applications/DynaCLR/evaluation/lot_correction/apply_lot_correction.py @@ -0,0 +1,84 @@ +"""CLI for applying a fitted LOT pipeline to an embedding zarr. + +Usage +----- + viscy-dynaclr apply-lot-correction -c config.yaml + +Transforms all cells through StandardScaler → PCA → LOT and writes a new +zarr whose ``.X`` contains the corrected embeddings (shape n_cells × n_pca). +All ``.obs`` metadata from the input zarr is preserved. + +Example config (YAML) +--------------------- + input_zarr: /path/to/lightsheet_organelle.zarr + pipeline: /path/to/lot_pipeline.pkl + output_zarr: /path/to/corrected_organelle.zarr + overwrite: false +""" + +import logging +from pathlib import Path + +import click +from pydantic import ValidationError + +from viscy.representation.evaluation.lot_correction import ( + apply_lot_correction, + load_lot_pipeline, +) +from viscy.representation.evaluation.lot_correction_config import LotApplyConfig +from viscy.utils.cli_utils import load_config + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file.", +) +def main(config: Path): + """Apply a fitted LOT pipeline to correct batch effects in an embedding zarr.""" + click.echo("=" * 60) + click.echo("LOT BATCH CORRECTION — APPLY") + click.echo("=" * 60) + + try: + config_dict = load_config(config) + apply_config = LotApplyConfig(**config_dict) + except 
ValidationError as e: + click.echo(f"\nConfiguration validation failed:\n{e}", err=True) + raise click.Abort() + except Exception as e: + click.echo(f"\nFailed to load configuration: {e}", err=True) + raise click.Abort() + + click.echo(f"\nConfiguration loaded: {config}") + click.echo(f" Input zarr: {apply_config.input_zarr}") + click.echo(f" Pipeline: {apply_config.pipeline}") + click.echo(f" Output zarr: {apply_config.output_zarr}") + click.echo(f" Overwrite: {apply_config.overwrite}") + + try: + pipeline = load_lot_pipeline(apply_config.pipeline) + click.echo( + f"\nPipeline loaded — n_pca={pipeline['n_pca']}, " + f"PCA variance={pipeline.get('pca_variance_explained', float('nan')):.1f}%" + ) + apply_lot_correction( + input_zarr=apply_config.input_zarr, + pipeline=pipeline, + output_zarr=apply_config.output_zarr, + overwrite=apply_config.overwrite, + ) + click.echo(f"\nCorrected zarr written to: {apply_config.output_zarr}") + except Exception as e: + click.echo(f"\nApplication failed: {e}", err=True) + raise click.Abort() + + +if __name__ == "__main__": + main() diff --git a/applications/DynaCLR/evaluation/lot_correction/configs/apply_lot_organelle_example.yaml b/applications/DynaCLR/evaluation/lot_correction/configs/apply_lot_organelle_example.yaml new file mode 100644 index 000000000..149143ed0 --- /dev/null +++ b/applications/DynaCLR/evaluation/lot_correction/configs/apply_lot_organelle_example.yaml @@ -0,0 +1,8 @@ +# Example: apply fitted LOT pipeline to the light-sheet organelle zarr +# Output zarr will have corrected embeddings in target (confocal) PCA space +# Shape: (n_cells, n_pca=50) — all obs metadata preserved + +input_zarr: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3/timeaware_organelle_160patch_104ckpt.zarr +pipeline: /tmp/lot_test/lot_organelle_pipeline.pkl +output_zarr: /tmp/lot_test/lot_corrected_organelle.zarr +overwrite: false 
diff --git a/applications/DynaCLR/evaluation/lot_correction/configs/fit_lot_organelle_example.yaml b/applications/DynaCLR/evaluation/lot_correction/configs/fit_lot_organelle_example.yaml new file mode 100644 index 000000000..deddeb7a1 --- /dev/null +++ b/applications/DynaCLR/evaluation/lot_correction/configs/fit_lot_organelle_example.yaml @@ -0,0 +1,24 @@ +# Example: fit LOT correction pipeline on G3BP1 organelle channel +# Source = light-sheet (LS1 + LS2 combined), Target = confocal 223-patch +# +# LS G3BP1 uninfected wells: C/1/ +# Confocal G3BP1 uninfected: fov_name starts with "G3BP1/uninfected" + +source_zarr: /hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3/timeaware_organelle_160patch_104ckpt.zarr +target_zarr: /hpc/projects/intracellular_dashboard/organelle_box/2026_03_10_A549_strong_organelles_DENV_ZIKV_time_course/5-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3/organelle_223patch_104ckpt.zarr + +source_uninf_filter: + column: fov_name + startswith: + - "C/1/" + +target_uninf_filter: + column: fov_name + startswith: + - "G3BP1/uninfected" + +n_pca: 50 +ns_lot: 3000 +random_seed: 42 + +output_pipeline: /tmp/lot_test/lot_organelle_pipeline.pkl diff --git a/applications/DynaCLR/evaluation/lot_correction/fit_lot_correction.py b/applications/DynaCLR/evaluation/lot_correction/fit_lot_correction.py new file mode 100644 index 000000000..0f0ff3b82 --- /dev/null +++ b/applications/DynaCLR/evaluation/lot_correction/fit_lot_correction.py @@ -0,0 +1,98 @@ +"""CLI for fitting a LOT batch-correction pipeline on embedding zarrs. + +Usage +----- + viscy-dynaclr fit-lot-correction -c config.yaml + +The fitted pipeline (StandardScaler + PCA + LinearTransport) is saved to +the path specified by ``output_pipeline`` in the config file. 
+ +Example config (YAML) +--------------------- + source_zarr: /path/to/lightsheet_organelle.zarr + target_zarr: /path/to/confocal_organelle.zarr + source_uninf_filter: + column: fov_name + startswith: + - "C/1/" + target_uninf_filter: + column: fov_name + startswith: + - "G3BP1/uninfected" + n_pca: 50 + ns_lot: 3000 + random_seed: 42 + output_pipeline: /path/to/lot_pipeline.pkl +""" + +import logging +from pathlib import Path + +import click +from pydantic import ValidationError + +from viscy.representation.evaluation.lot_correction import ( + fit_lot_correction, + save_lot_pipeline, +) +from viscy.representation.evaluation.lot_correction_config import LotFitConfig +from viscy.utils.cli_utils import load_config + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file.", +) +def main(config: Path): + """Fit a LOT batch-correction pipeline on source and target embedding zarrs.""" + click.echo("=" * 60) + click.echo("LOT BATCH CORRECTION — FIT") + click.echo("=" * 60) + + try: + config_dict = load_config(config) + fit_config = LotFitConfig(**config_dict) + except ValidationError as e: + click.echo(f"\nConfiguration validation failed:\n{e}", err=True) + raise click.Abort() + except Exception as e: + click.echo(f"\nFailed to load configuration: {e}", err=True) + raise click.Abort() + + click.echo(f"\nConfiguration loaded: {config}") + click.echo(f" Source zarr: {fit_config.source_zarr}") + click.echo(f" Target zarr: {fit_config.target_zarr}") + click.echo(f" n_pca: {fit_config.n_pca}") + click.echo(f" ns_lot: {fit_config.ns_lot}") + click.echo(f" Random seed: {fit_config.random_seed}") + click.echo(f" Output: {fit_config.output_pipeline}") + + try: + pipeline = fit_lot_correction( + source_zarr=fit_config.source_zarr, + 
target_zarr=fit_config.target_zarr, + source_uninf_filter=fit_config.source_uninf_filter.to_dict(), + target_uninf_filter=fit_config.target_uninf_filter.to_dict(), + n_pca=fit_config.n_pca, + ns_lot=fit_config.ns_lot, + random_seed=fit_config.random_seed, + ) + click.echo( + f"\nPipeline fitted — PCA explained variance: " + f"{pipeline['pca_variance_explained']:.1f}%" + ) + save_lot_pipeline(pipeline, fit_config.output_pipeline) + click.echo(f"Pipeline saved to: {fit_config.output_pipeline}") + except Exception as e: + click.echo(f"\nFitting failed: {e}", err=True) + raise click.Abort() + + +if __name__ == "__main__": + main() diff --git a/applications/DynaCLR/evaluation/mmd/__init__.py b/applications/DynaCLR/evaluation/mmd/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/DynaCLR/evaluation/mmd/compute_mmd.py b/applications/DynaCLR/evaluation/mmd/compute_mmd.py new file mode 100644 index 000000000..dcf84a883 --- /dev/null +++ b/applications/DynaCLR/evaluation/mmd/compute_mmd.py @@ -0,0 +1,132 @@ +"""CLI for computing MMD² between two groups of cell embeddings. + +Usage +----- + viscy-dynaclr compute-mmd -c config.yaml + +The command compares two groups (A and B) defined by obs filters on one or +two AnnData zarrs. An optional ``group_by`` field splits the comparison into +per-group rows (e.g. per organelle per timepoint). 
+ +Example configs +--------------- + +**Biological signal** (ZIKV vs uninfected, same zarr, per organelle/timepoint): + + zarr_a: /path/to/organelle_embeddings.zarr + filter_a: + column: condition + startswith: ["uninfected"] + filter_b: + column: condition + equals: "ZIKV" + group_by: + - organelle + - timepoint + use_pca: true + n_pca: 50 + n_perm: 1000 + max_cells: 2000 + random_seed: 42 + output_csv: mmd_results.csv + +**Batch effect** (light-sheet vs confocal, two zarrs): + + zarr_a: /path/to/lightsheet.zarr + zarr_b: /path/to/confocal.zarr + filter_a: + column: fov_name + startswith: ["C/1/"] + filter_b: + column: fov_name + startswith: ["G3BP1/uninfected"] + group_by: [] + use_pca: true + n_pca: 50 + n_perm: 0 + max_cells: 2000 + random_seed: 42 + output_csv: batch_mmd.csv +""" + +import logging +from pathlib import Path + +import click +from pydantic import ValidationError + +from viscy.representation.evaluation.mmd import compute_mmd +from viscy.representation.evaluation.mmd_config import ComputeMMDConfig +from viscy.utils.cli_utils import load_config + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "-c", + "--config", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to YAML configuration file.", +) +def main(config: Path): + """Compute MMD² between two groups of cell embeddings from AnnData zarrs.""" + click.echo("=" * 60) + click.echo("MMD COMPUTATION") + click.echo("=" * 60) + + try: + config_dict = load_config(config) + mmd_config = ComputeMMDConfig(**config_dict) + except ValidationError as e: + click.echo(f"\nConfiguration validation failed:\n{e}", err=True) + raise click.Abort() + except Exception as e: + click.echo(f"\nFailed to load configuration: {e}", err=True) + raise click.Abort() + + click.echo(f"\nConfiguration loaded: {config}") + click.echo(f" Zarr A: {mmd_config.zarr_a}") + 
click.echo(f" Zarr B: {mmd_config.zarr_b or '(same as A)'}") + click.echo(f" Filter A: {mmd_config.filter_a}") + click.echo(f" Filter B: {mmd_config.filter_b}") + click.echo(f" Group by: {mmd_config.group_by or '(none — single overall)'}") + click.echo(f" PCA: {'yes, n=' + str(mmd_config.n_pca) if mmd_config.use_pca else 'no'}") + click.echo(f" n_perm: {mmd_config.n_perm or 'skipped'}") + click.echo(f" max_cells: {mmd_config.max_cells}") + click.echo(f" Output: {mmd_config.output_csv}") + + try: + results = compute_mmd( + zarr_a=mmd_config.zarr_a, + zarr_b=mmd_config.zarr_b, + filter_a=mmd_config.filter_a.to_dict() if mmd_config.filter_a else None, + filter_b=mmd_config.filter_b.to_dict() if mmd_config.filter_b else None, + group_by=mmd_config.group_by or None, + use_pca=mmd_config.use_pca, + n_pca=mmd_config.n_pca, + n_perm=mmd_config.n_perm, + max_cells=mmd_config.max_cells, + random_seed=mmd_config.random_seed, + ) + + if results.empty: + click.echo("\nNo results computed — check filters and group_by columns.") + raise click.Abort() + + output_path = Path(mmd_config.output_csv) + output_path.parent.mkdir(parents=True, exist_ok=True) + results.to_csv(output_path, index=False) + click.echo(f"\nResults ({len(results)} rows) written to: {output_path}") + click.echo("\n" + results.to_string(index=False)) + + except click.Abort: + raise + except Exception as e: + click.echo(f"\nMMD computation failed: {e}", err=True) + raise click.Abort() + + +if __name__ == "__main__": + main() diff --git a/applications/DynaCLR/evaluation/mmd/configs/mmd_batch_effect_example.yaml b/applications/DynaCLR/evaluation/mmd/configs/mmd_batch_effect_example.yaml new file mode 100644 index 000000000..bef430ab6 --- /dev/null +++ b/applications/DynaCLR/evaluation/mmd/configs/mmd_batch_effect_example.yaml @@ -0,0 +1,25 @@ +# Example: batch effect — light-sheet vs confocal (two different zarrs) +# Uninfected G3BP1 cells only, no grouping (single overall MMD) + +zarr_a: 
/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3/timeaware_organelle_160patch_104ckpt.zarr +zarr_b: /hpc/projects/intracellular_dashboard/organelle_box/2026_03_10_A549_strong_organelles_DENV_ZIKV_time_course/5-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3/organelle_223patch_104ckpt.zarr + +filter_a: + column: fov_name + startswith: + - "C/1/" + +filter_b: + column: fov_name + startswith: + - "G3BP1/uninfected" + +group_by: [] # single overall comparison + +use_pca: true +n_pca: 50 +n_perm: 0 +max_cells: 2000 +random_seed: 42 + +output_csv: /tmp/mmd_test/mmd_batch_effect.csv diff --git a/applications/DynaCLR/evaluation/mmd/configs/mmd_biological_signal_example.yaml b/applications/DynaCLR/evaluation/mmd/configs/mmd_biological_signal_example.yaml new file mode 100644 index 000000000..9ec9d8858 --- /dev/null +++ b/applications/DynaCLR/evaluation/mmd/configs/mmd_biological_signal_example.yaml @@ -0,0 +1,26 @@ +# Example: biological signal — ZIKV vs uninfected per organelle per timepoint +# Both groups come from the same zarr (condition filter on obs) + +zarr_a: /hpc/projects/intracellular_dashboard/organelle_box/2026_03_10_A549_strong_organelles_DENV_ZIKV_time_course/5-phenotyping/predictions/DynaCLR-2D-BagOfChannels-timeaware/v3/organelle_223patch_104ckpt.zarr +# zarr_b omitted — same zarr used for both groups + +filter_a: + column: fov_name + startswith: + - "G3BP1/uninfected" + +filter_b: + column: fov_name + startswith: + - "G3BP1/ZIKV" + +group_by: + - t + +use_pca: true +n_pca: 50 +n_perm: 0 # set to 1000 for p-values (slower) +max_cells: 2000 +random_seed: 42 + +output_csv: /tmp/mmd_test/mmd_biological_signal.csv diff --git a/pyproject.toml b/pyproject.toml index 419e9e7aa..6c708b6a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ optional-dependencies.visual = [ ] scripts.viscy = "viscy.cli:main" 
scripts."viscy-dynaclr" = "viscy.cli_dynaclr:main" +scripts.dynaclr = "viscy.cli_dynaclr:main" [tool.setuptools.packages.find] include = [ "viscy", "applications*" ] diff --git a/viscy/cli_dynaclr.py b/viscy/cli_dynaclr.py index 11d3badbd..4a8d18371 100644 --- a/viscy/cli_dynaclr.py +++ b/viscy/cli_dynaclr.py @@ -48,6 +48,24 @@ def get_params(self, ctx): "help": "Apply a trained linear classifier to new embeddings", "short_help": "Apply linear classifier", }, + { + "name": "fit-lot-correction", + "import_path": "applications.DynaCLR.evaluation.lot_correction.fit_lot_correction.main", + "help": "Fit a LOT batch-correction pipeline on source and target embedding zarrs", + "short_help": "Fit LOT batch correction", + }, + { + "name": "apply-lot-correction", + "import_path": "applications.DynaCLR.evaluation.lot_correction.apply_lot_correction.main", + "help": "Apply a fitted LOT pipeline to correct batch effects in an embedding zarr", + "short_help": "Apply LOT batch correction", + }, + { + "name": "compute-mmd", + "import_path": "applications.DynaCLR.evaluation.mmd.compute_mmd.main", + "help": "Compute MMD² between two groups of cell embeddings from AnnData zarrs", + "short_help": "Compute MMD between two groups", + }, ] diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 489a09c65..aee28b1a1 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -17,10 +17,30 @@ from viscy.data.hcs import HCSDataModule, _read_norm_meta from viscy.data.select import _filter_fovs, _filter_wells from viscy.data.typing import DictTransform, NormMeta -from viscy.transforms import BatchedCenterSpatialCropd +from viscy.transforms import BatchedCenterSpatialCropd, BatchedRescaleYXd _logger = logging.getLogger("lightning.pytorch") + +def _read_pixel_size(data_path: str | Path) -> float: + """Read the X pixel size (µm/pixel) from the first FOV in an OME-Zarr dataset. + + Parameters + ---------- + data_path : str | Path + Path to the OME-Zarr plate or position. 
+ + Returns + ------- + float + X pixel size in micrometers per pixel. + """ + with open_ome_zarr(data_path, mode="r") as store: + for _, pos in store.positions(): + return float(pos.scale[-1]) + raise ValueError(f"No positions found in {data_path}") + + INDEX_COLUMNS = [ "fov_name", "track_id", @@ -349,6 +369,7 @@ def __init__( pin_memory: bool = False, z_window_size: int | None = None, cache_pool_bytes: int = 0, + reference_pixel_size: float | None = None, ): """Lightning data module for triplet sampling of patches. @@ -363,9 +384,20 @@ def __init__( z_range : tuple[int, int] Range of valid z-slices initial_yx_patch_size : tuple[int, int], optional - XY size of the initially sampled image patch, by default (512, 512) + YX size of the initially sampled image patch, by default (512, 512). + Ignored when ``reference_pixel_size`` is set — the patch size is then + computed automatically from the pixel-size ratio. final_yx_patch_size : tuple[int, int], optional Output patch size, by default (224, 224) + reference_pixel_size : float | None, optional + X pixel size (µm/pixel) of the dataset used to train the model. + When provided the data module reads the pixel size of the inference + dataset from its OME-Zarr metadata and computes + ``initial_yx_patch_size = round(final_yx_patch_size * + reference_pixel_size / inference_pixel_size)`` so that the same + physical area is covered. The extracted patch is then rescaled to + ``final_yx_patch_size`` with bilinear interpolation before being + fed to the model. By default ``None`` (no rescaling). 
split_ratio : float, optional Ratio of training samples, by default 0.8 batch_size : int, optional @@ -445,11 +477,28 @@ def __init__( self.return_negative = return_negative self.augment_validation = augment_validation self._cache_pool_bytes = cache_pool_bytes + if reference_pixel_size is not None: + inference_pixel_size = _read_pixel_size(data_path) + scale = reference_pixel_size / inference_pixel_size + self.initial_yx_patch_size = tuple( + round(s * scale) for s in final_yx_patch_size + ) + _logger.info( + f"Pixel size rescaling enabled: " + f"reference={reference_pixel_size:.4f} µm/px, " + f"inference={inference_pixel_size:.4f} µm/px, " + f"scale={scale:.4f}. " + f"Extracting {self.initial_yx_patch_size} px patches " + f"and resizing to {final_yx_patch_size} px." + ) + self._rescale_to_final = True + else: + self._rescale_to_final = False self._augmentation_transform = Compose( - self.normalizations + self.augmentations + [self._final_crop()] + self.normalizations + self.augmentations + [self._final_spatial_transform()] ) self._no_augmentation_transform = Compose( - self.normalizations + [self._final_crop()] + self.normalizations + [self._final_spatial_transform()] ) def _align_tracks_tables_with_positions( @@ -586,6 +635,22 @@ def _final_crop(self) -> BatchedCenterSpatialCropd: ), ) + def _final_spatial_transform(self) -> BatchedCenterSpatialCropd | BatchedRescaleYXd: + """Return the final spatial transform. + + When ``reference_pixel_size`` was set at construction time, returns a + bilinear resize from the (larger) ``initial_yx_patch_size`` down to + ``final_yx_patch_size``. Otherwise falls back to a centre crop. 
+ """ + if self._rescale_to_final: + return BatchedRescaleYXd( + keys=self.source_channel, + target_yx_size=self.yx_patch_size, + mode="bilinear", + antialias=True, + ) + return self._final_crop() + def _find_transform(self, key: str): if self.trainer: if self.trainer.predicting: diff --git a/viscy/representation/evaluation/lot_correction.py b/viscy/representation/evaluation/lot_correction.py new file mode 100644 index 000000000..c3c9144d2 --- /dev/null +++ b/viscy/representation/evaluation/lot_correction.py @@ -0,0 +1,317 @@ +"""Core functions for LOT (Linear Optimal Transport) batch correction. + +Pipeline +-------- +1. Load source and target embedding zarrs (AnnData format). +2. Filter cells to the uninfected reference population in each dataset. +3. Fit a shared StandardScaler + PCA on the combined source + target cells. +4. Fit a LinearTransport (LOT) map in PCA space using uninfected cells only, + mapping source → target distribution. +5. Save the fitted pipeline (scaler, PCA, LOT) to disk with joblib. + +The saved pipeline can then be applied to any source zarr to produce a new +zarr whose embeddings are in the target's PCA coordinate system, corrected +for cross-platform batch effects. +""" + +import logging +from pathlib import Path +from typing import Union + +import anndata as ad +import joblib +import numpy as np +import ot +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +_logger = logging.getLogger(__name__) + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def _to_np(X) -> np.ndarray: + """Convert sparse or dense matrix to float32 numpy array.""" + return np.array(X.toarray() if hasattr(X, "toarray") else X, dtype=np.float32) + + +def _apply_filter(obs, filter_spec: dict) -> np.ndarray: + """Return a boolean mask for rows of *obs* matching *filter_spec*. + + Parameters + ---------- + obs : pd.DataFrame + AnnData ``.obs`` table. 
+ filter_spec : dict + Must contain ``"column"`` plus one of: + + * ``"startswith"`` – str or list[str]: keep rows where the column + value starts with any of the given prefixes. + * ``"equals"`` – str: keep rows where the column value equals the + given string. + + Returns + ------- + np.ndarray of bool + Boolean mask with the same length as *obs*. + """ + col = filter_spec["column"] + values = obs[col].astype(str) + + if "startswith" in filter_spec: + prefixes = filter_spec["startswith"] + if isinstance(prefixes, str): + prefixes = [prefixes] + mask = np.zeros(len(obs), dtype=bool) + for p in prefixes: + mask |= values.str.startswith(p).values + return mask + + if "equals" in filter_spec: + return (values == str(filter_spec["equals"])).values + + raise ValueError( + "filter_spec must contain either 'startswith' or 'equals'. " + f"Got: {list(filter_spec.keys())}" + ) + + +# ── public API ──────────────────────────────────────────────────────────────── + +def fit_lot_correction( + source_zarr: Union[str, Path], + target_zarr: Union[str, Path], + source_uninf_filter: dict, + target_uninf_filter: dict, + n_pca: int = 50, + ns_lot: int = 3000, + random_seed: int = 42, +) -> dict: + """Fit a shared PCA + LOT batch-correction pipeline. + + Loads source (e.g. light-sheet) and target (e.g. confocal) embedding + zarrs, identifies uninfected reference cells in each, and fits: + + * A shared ``StandardScaler`` on the combined source + target cells. + * A shared ``PCA`` (``n_pca`` components) on the combined scaled cells. + * A ``LinearTransport`` (POT library) that maps the source uninfected + distribution to the target uninfected distribution in PCA space. + + Parameters + ---------- + source_zarr : str or Path + Path to the source AnnData zarr (e.g. light-sheet embeddings). + target_zarr : str or Path + Path to the target AnnData zarr (e.g. confocal embeddings). 
+ source_uninf_filter : dict + Filter spec (see :func:`_apply_filter`) selecting uninfected source + cells used to fit LOT. + target_uninf_filter : dict + Filter spec selecting uninfected target cells used to fit LOT. + n_pca : int, optional + Number of PCA components, by default 50. + ns_lot : int, optional + Maximum number of cells subsampled per dataset for LOT fitting, + by default 3000. + random_seed : int, optional + Random seed for reproducibility, by default 42. + + Returns + ------- + dict with keys: + ``"scaler"`` : fitted StandardScaler + ``"pca"`` : fitted PCA + ``"lot"`` : fitted ot.da.LinearTransport + ``"n_pca"`` : int + ``"ns_lot"`` : int + ``"random_seed"`` : int + """ + rng = np.random.default_rng(random_seed) + + _logger.info("Loading source zarr: %s", source_zarr) + adata_src = ad.read_zarr(source_zarr) + adata_src.obs_names_make_unique() + + _logger.info("Loading target zarr: %s", target_zarr) + adata_tgt = ad.read_zarr(target_zarr) + adata_tgt.obs_names_make_unique() + + _logger.info( + "Source shape: %s Target shape: %s", adata_src.shape, adata_tgt.shape + ) + + X_src = _to_np(adata_src.X) + X_tgt = _to_np(adata_tgt.X) + + src_uninf_mask = _apply_filter(adata_src.obs, source_uninf_filter) + tgt_uninf_mask = _apply_filter(adata_tgt.obs, target_uninf_filter) + + _logger.info( + "Uninfected cells — source: %d / %d, target: %d / %d", + src_uninf_mask.sum(), len(X_src), + tgt_uninf_mask.sum(), len(X_tgt), + ) + + if src_uninf_mask.sum() < 5 or tgt_uninf_mask.sum() < 5: + raise ValueError( + "Too few uninfected cells to fit LOT " + f"(source={src_uninf_mask.sum()}, target={tgt_uninf_mask.sum()}). " + "Check your filter specifications." 
+ ) + + # Fit shared PCA on ALL source + target cells + _logger.info("Fitting shared StandardScaler + PCA-%d ...", n_pca) + scaler = StandardScaler() + X_combined_scaled = scaler.fit_transform(np.vstack([X_src, X_tgt])) + pca = PCA(n_components=n_pca, random_state=random_seed) + Z_all = pca.fit_transform(X_combined_scaled) + var_exp = pca.explained_variance_ratio_.sum() * 100 + _logger.info("PCA explained variance: %.1f%%", var_exp) + + n_src = len(X_src) + Z_src_uninf = Z_all[:n_src][src_uninf_mask] + Z_tgt_uninf = Z_all[n_src:][tgt_uninf_mask] + + # Subsample for LOT fitting + ns_src = min(len(Z_src_uninf), ns_lot) + ns_tgt = min(len(Z_tgt_uninf), ns_lot) + idx_src = rng.choice(len(Z_src_uninf), ns_src, replace=False) + idx_tgt = rng.choice(len(Z_tgt_uninf), ns_tgt, replace=False) + + _logger.info( + "Fitting LOT (source subsample=%d, target subsample=%d) ...", ns_src, ns_tgt + ) + lot = ot.da.LinearTransport(reg=1e-3) + lot.fit(Xs=Z_src_uninf[idx_src], Xt=Z_tgt_uninf[idx_tgt]) + _logger.info("LOT fitted.") + + return { + "scaler": scaler, + "pca": pca, + "lot": lot, + "n_pca": n_pca, + "ns_lot": ns_lot, + "random_seed": random_seed, + "pca_variance_explained": float(var_exp), + } + + +def apply_lot_correction( + input_zarr: Union[str, Path], + pipeline: dict, + output_zarr: Union[str, Path], + overwrite: bool = False, +) -> None: + """Apply a fitted LOT pipeline to an embedding zarr. + + Transforms all cells in *input_zarr* through the pipeline + (StandardScaler → PCA → LOT) and writes an AnnData zarr whose ``.X`` + contains the corrected embeddings in the target's PCA space + (shape ``n_cells × n_pca``). All ``.obs`` metadata is preserved. + + Parameters + ---------- + input_zarr : str or Path + Path to the source AnnData zarr to correct. + pipeline : dict + Fitted pipeline as returned by :func:`fit_lot_correction`. + output_zarr : str or Path + Path to write the corrected AnnData zarr. 
+ overwrite : bool, optional + If ``False`` (default) and *output_zarr* already exists, raise. + """ + output_zarr = Path(output_zarr) + if output_zarr.exists(): + if not overwrite: + raise FileExistsError( + f"Output path already exists: {output_zarr}. " + "Set overwrite=true to overwrite." + ) + import shutil + shutil.rmtree(output_zarr) + + _logger.info("Loading input zarr: %s", input_zarr) + adata_in = ad.read_zarr(input_zarr) + adata_in.obs_names_make_unique() + + X = _to_np(adata_in.X) + _logger.info("Input shape: %s", adata_in.shape) + + scaler = pipeline["scaler"] + pca = pipeline["pca"] + lot = pipeline["lot"] + + _logger.info("Applying StandardScaler → PCA → LOT ...") + Z = pca.transform(scaler.transform(X)) + Z_corrected = lot.transform(Z) + _logger.info( + "Corrected embeddings shape: %s (n_pca=%d)", Z_corrected.shape, pipeline["n_pca"] + ) + + obs = adata_in.obs.copy() + # Convert StringDtype columns (and categoricals with StringDtype categories) + # to object dtype for broad anndata / zarr compatibility. + import pandas as pd + for col in obs.columns: + dtype = obs[col].dtype + if isinstance(dtype, pd.StringDtype): + obs[col] = obs[col].astype(object) + elif isinstance(dtype, pd.CategoricalDtype) and isinstance( + dtype.categories.dtype, pd.StringDtype + ): + obs[col] = obs[col].astype(object).astype("category") + + # Also enable anndata's opt-in for writing nullable string arrays in case + # any StringArray-backed column is still present after the conversion above. 
+ try: + ad.settings.allow_write_nullable_strings = True + except AttributeError: + pass # older anndata versions don't have this setting + + adata_out = ad.AnnData(X=Z_corrected.astype(np.float32), obs=obs) + adata_out.uns["lot_correction"] = { + "source_zarr": str(input_zarr), + "n_pca": pipeline["n_pca"], + "pca_variance_explained": pipeline.get("pca_variance_explained"), + } + + _logger.info("Writing corrected zarr: %s", output_zarr) + adata_out.write_zarr(output_zarr) + _logger.info("Done.") + + +def save_lot_pipeline(pipeline: dict, path: Union[str, Path]) -> None: + """Save a fitted LOT pipeline to disk using joblib. + + Parameters + ---------- + pipeline : dict + Fitted pipeline as returned by :func:`fit_lot_correction`. + path : str or Path + Output path (e.g. ``lot_pipeline.pkl``). + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + joblib.dump(pipeline, path) + _logger.info("Pipeline saved to %s", path) + + +def load_lot_pipeline(path: Union[str, Path]) -> dict: + """Load a fitted LOT pipeline from disk. + + Parameters + ---------- + path : str or Path + Path to the saved pipeline file. + + Returns + ------- + dict + Pipeline with keys ``"scaler"``, ``"pca"``, ``"lot"``. 
+ """ + pipeline = joblib.load(path) + _logger.info( + "Pipeline loaded from %s (n_pca=%d, pca_var=%.1f%%)", + path, pipeline["n_pca"], pipeline.get("pca_variance_explained", float("nan")), + ) + return pipeline diff --git a/viscy/representation/evaluation/lot_correction_config.py b/viscy/representation/evaluation/lot_correction_config.py new file mode 100644 index 000000000..4c216801f --- /dev/null +++ b/viscy/representation/evaluation/lot_correction_config.py @@ -0,0 +1,121 @@ +"""Pydantic configuration models for LOT batch correction.""" + +from pathlib import Path +from typing import Optional, Union + +from pydantic import BaseModel, Field, field_validator, model_validator + + +class UninfFilter(BaseModel): + """Specification for selecting uninfected reference cells from an obs table. + + Exactly one of ``startswith`` or ``equals`` must be provided. + + Parameters + ---------- + column : str + Name of the ``.obs`` column to filter on (e.g. ``"fov_name"``). + startswith : str or list[str], optional + Keep cells whose column value starts with any of these prefixes. + equals : str, optional + Keep cells whose column value equals this string. 
+ """ + + column: str = Field(..., min_length=1) + startswith: Optional[Union[str, list[str]]] = Field(default=None) + equals: Optional[str] = Field(default=None) + + @model_validator(mode="after") + def exactly_one_filter(self): + has_sw = self.startswith is not None + has_eq = self.equals is not None + if not has_sw and not has_eq: + raise ValueError("UninfFilter must specify either 'startswith' or 'equals'.") + if has_sw and has_eq: + raise ValueError("UninfFilter must specify only one of 'startswith' or 'equals'.") + return self + + def to_dict(self) -> dict: + """Convert to the dict format expected by _apply_filter.""" + d = {"column": self.column} + if self.startswith is not None: + d["startswith"] = self.startswith + else: + d["equals"] = self.equals + return d + + +class LotFitConfig(BaseModel): + """Configuration for fitting a LOT batch-correction pipeline. + + Parameters + ---------- + source_zarr : str + Path to the source AnnData zarr (e.g. light-sheet embeddings). + target_zarr : str + Path to the target AnnData zarr (e.g. confocal embeddings). + source_uninf_filter : UninfFilter + Filter identifying uninfected cells in the source dataset. + target_uninf_filter : UninfFilter + Filter identifying uninfected cells in the target dataset. + n_pca : int, optional + Number of PCA components for the shared PCA, by default 50. + ns_lot : int, optional + Maximum cells subsampled per dataset for LOT fitting, by default 3000. + random_seed : int, optional + Random seed, by default 42. + output_pipeline : str + Path to save the fitted pipeline (joblib pickle). 
+ """ + + source_zarr: str = Field(..., min_length=1) + target_zarr: str = Field(..., min_length=1) + source_uninf_filter: UninfFilter + target_uninf_filter: UninfFilter + n_pca: int = Field(default=50, gt=0) + ns_lot: int = Field(default=3000, gt=0) + random_seed: int = Field(default=42) + output_pipeline: str = Field(..., min_length=1) + + @model_validator(mode="after") + def validate_paths(self): + if not Path(self.source_zarr).exists(): + raise ValueError(f"source_zarr not found: {self.source_zarr}") + if not Path(self.target_zarr).exists(): + raise ValueError(f"target_zarr not found: {self.target_zarr}") + return self + + +class LotApplyConfig(BaseModel): + """Configuration for applying a fitted LOT pipeline to a zarr. + + Parameters + ---------- + input_zarr : str + Path to the source AnnData zarr to correct. + pipeline : str + Path to the fitted pipeline file (joblib pickle). + output_zarr : str + Path to write the corrected AnnData zarr. + overwrite : bool, optional + Overwrite output if it exists, by default False. + """ + + input_zarr: str = Field(..., min_length=1) + pipeline: str = Field(..., min_length=1) + output_zarr: str = Field(..., min_length=1) + overwrite: bool = Field(default=False) + + @model_validator(mode="after") + def validate_paths(self): + if not Path(self.input_zarr).exists(): + raise ValueError(f"input_zarr not found: {self.input_zarr}") + if not Path(self.pipeline).exists(): + raise ValueError(f"pipeline file not found: {self.pipeline}") + output = Path(self.output_zarr) + if output.exists() and not self.overwrite: + raise ValueError( + f"output_zarr already exists: {self.output_zarr}. " + "Set overwrite: true to overwrite." + ) + return self diff --git a/viscy/representation/evaluation/mmd.py b/viscy/representation/evaluation/mmd.py new file mode 100644 index 000000000..375343448 --- /dev/null +++ b/viscy/representation/evaluation/mmd.py @@ -0,0 +1,355 @@ +"""Maximum Mean Discrepancy (MMD) computation for embedding zarrs. 
+
+Implements the unbiased RBF-kernel MMD² estimator with median-heuristic
+bandwidth selection and an optional permutation test for p-values.
+
+This module is the library backend used by the ``compute-mmd`` CLI command.
+"""
+
+import logging
+from itertools import product
+from pathlib import Path
+from typing import Optional, Union
+
+import anndata as ad
+import numpy as np
+import pandas as pd
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+
+_logger = logging.getLogger(__name__)
+
+
+# ── low-level kernel helpers ──────────────────────────────────────────────────
+
+def _sq_dist(A: np.ndarray, B: np.ndarray) -> np.ndarray:
+    """Pairwise squared Euclidean distances (n × m)."""
+    A2 = np.sum(A * A, axis=1, keepdims=True)
+    B2 = np.sum(B * B, axis=1, keepdims=True).T
+    return np.maximum(A2 + B2 - 2.0 * (A @ B.T), 0.0)  # clamp tiny FP-cancellation negatives
+
+
+def _rbf_kernel(A: np.ndarray, B: np.ndarray, gamma: float) -> np.ndarray:
+    return np.exp(-gamma * _sq_dist(A, B))
+
+
+def median_heuristic_gamma(
+    X: np.ndarray,
+    Y: np.ndarray,
+    max_points: int = 2000,
+    rng: Union[int, np.random.Generator] = 0,
+) -> float:
+    """Estimate RBF bandwidth via the median heuristic.
+
+    Parameters
+    ----------
+    X, Y : np.ndarray
+        Sample arrays.
+    max_points : int
+        Subsample size used for the distance computation.
+    rng : int or Generator
+        Random seed or generator.
+
+    Returns
+    -------
+    float
+        ``gamma = 1 / (2 * median_squared_distance)``.
+    """
+    if isinstance(rng, int):
+        rng = np.random.default_rng(rng)
+    Z = np.vstack([X, Y])
+    if Z.shape[0] > max_points:
+        Z = Z[rng.choice(Z.shape[0], size=max_points, replace=False)]
+    D2 = _sq_dist(Z, Z)
+    tri = D2[np.triu_indices(D2.shape[0], k=1)]
+    med = np.median(tri)
+    return 1.0 / (2.0 * med + 1e-12)  # epsilon guards against a zero median distance
+
+
+def mmd2_unbiased(
+    X: np.ndarray,
+    Y: np.ndarray,
+    gamma: float,
+) -> float:
+    """Unbiased MMD² estimate with RBF kernel.
+
+    Parameters
+    ----------
+    X : np.ndarray, shape (n, d)
+    Y : np.ndarray, shape (m, d)
+    gamma : float
+        RBF kernel bandwidth parameter.
+
+    Returns
+    -------
+    float
+        Unbiased MMD².
+    """
+    n, m = X.shape[0], Y.shape[0]
+    Kxx = _rbf_kernel(X, X, gamma); np.fill_diagonal(Kxx, 0.0)  # zero i == j self-terms (unbiased)
+    Kyy = _rbf_kernel(Y, Y, gamma); np.fill_diagonal(Kyy, 0.0)
+    Kxy = _rbf_kernel(X, Y, gamma)
+    return (
+        Kxx.sum() / (n * (n - 1))
+        + Kyy.sum() / (m * (m - 1))
+        - 2.0 * Kxy.mean()  # mean == sum / (n * m)
+    )
+
+
+def mmd_permutation_test(
+    X: np.ndarray,
+    Y: np.ndarray,
+    gamma: Optional[float] = None,
+    n_perm: int = 1000,
+    max_cells: int = 2000,
+    rng: Union[int, np.random.Generator] = 0,
+) -> tuple[float, Optional[float], float]:
+    """Compute MMD² with optional permutation test p-value.
+
+    Parameters
+    ----------
+    X : np.ndarray, shape (n, d)
+        Embeddings for group A.
+    Y : np.ndarray, shape (m, d)
+        Embeddings for group B.
+    gamma : float or None
+        RBF gamma; estimated via median heuristic if None.
+    n_perm : int
+        Number of permutations. Set to 0 to skip the test (p = None).
+    max_cells : int
+        Subsample each group to at most this many cells before computing.
+    rng : int or Generator
+        Random seed or generator.
+
+    Returns
+    -------
+    (mmd2, p_value, gamma)
+        ``p_value`` is None when ``n_perm == 0``.
+    """
+    if isinstance(rng, int):
+        rng = np.random.default_rng(rng)
+
+    # Subsample
+    if len(X) > max_cells:
+        X = X[rng.choice(len(X), max_cells, replace=False)]
+    if len(Y) > max_cells:
+        Y = Y[rng.choice(len(Y), max_cells, replace=False)]
+
+    X = X.astype(np.float64)
+    Y = Y.astype(np.float64)
+
+    if gamma is None:
+        gamma = median_heuristic_gamma(X, Y, rng=rng)
+
+    observed = mmd2_unbiased(X, Y, gamma)
+
+    if n_perm == 0:
+        return observed, None, gamma
+
+    Z = np.vstack([X, Y])
+    n = len(X)
+    perm_stats = np.empty(n_perm, dtype=float)
+    for b in range(n_perm):
+        idx = rng.permutation(len(Z))  # pooled relabeling under the null hypothesis
+        perm_stats[b] = mmd2_unbiased(Z[idx[:n]], Z[idx[n:]], gamma)
+
+    # +1 smoothing (Phipson & Smyth 2010)
+    p_value = (np.sum(perm_stats >= observed) + 1) / (n_perm + 1)
+    return observed, p_value, gamma
+
+
+# ── data helpers ──────────────────────────────────────────────────────────────
+
+def _to_np(X) -> np.ndarray:
+    return np.array(X.toarray() if hasattr(X, "toarray") else X, dtype=np.float32)  # densify sparse .X
+
+
+def _apply_filter(obs: pd.DataFrame, filter_spec: Optional[dict]) -> np.ndarray:
+    """Return boolean mask for obs rows matching filter_spec.
+
+    If filter_spec is None, all rows are selected.
+    Supported keys: ``startswith`` (str or list[str]) or ``equals`` (str),
+    paired with ``column``.
+    """
+    if filter_spec is None:
+        return np.ones(len(obs), dtype=bool)
+
+    col = filter_spec["column"]
+    values = obs[col].astype(str)
+
+    if "startswith" in filter_spec:
+        prefixes = filter_spec["startswith"]
+        if isinstance(prefixes, str):
+            prefixes = [prefixes]
+        mask = np.zeros(len(obs), dtype=bool)
+        for p in prefixes:
+            mask |= values.str.startswith(p).values  # OR across prefixes
+        return mask
+
+    if "equals" in filter_spec:
+        return (values == str(filter_spec["equals"])).values
+
+    raise ValueError(
+        "filter_spec must contain 'startswith' or 'equals'. "
+        f"Got: {list(filter_spec.keys())}"
+    )
+
+
+# ── main computation ──────────────────────────────────────────────────────────
+
+def compute_mmd(
+    zarr_a: Union[str, Path],
+    zarr_b: Optional[Union[str, Path]],
+    filter_a: Optional[dict],
+    filter_b: Optional[dict],
+    group_by: Optional[list[str]],
+    use_pca: bool,
+    n_pca: int,
+    n_perm: int,
+    max_cells: int,
+    random_seed: int,
+) -> pd.DataFrame:
+    """Compute MMD² between two groups of embeddings.
+
+    Loads zarr(s), applies obs filters, optionally groups by obs columns,
+    fits a shared PCA if requested, and computes MMD² (with optional
+    permutation p-value) for every group combination.
+
+    Parameters
+    ----------
+    zarr_a : str or Path
+        Path to AnnData zarr for group A.
+    zarr_b : str or Path or None
+        Path to AnnData zarr for group B. If None or same as zarr_a,
+        the single zarr is loaded once and both filters are applied to it.
+    filter_a : dict or None
+        Obs filter for group A (see :func:`_apply_filter`).
+    filter_b : dict or None
+        Obs filter for group B.
+    group_by : list[str] or None
+        Obs column names to stratify comparisons by. MMD is computed
+        separately for each unique value combination found in **both** groups.
+        Pass None or an empty list for a single overall comparison.
+    use_pca : bool
+        Fit shared PCA on the combined (filtered) embeddings before MMD.
+    n_pca : int
+        Number of PCA components.
+    n_perm : int
+        Permutation test iterations (0 = skip, p_value will be NaN).
+    max_cells : int
+        Maximum cells per group passed to the MMD kernel computation.
+    random_seed : int
+        Global random seed.
+
+    Returns
+    -------
+    pd.DataFrame
+        One row per group combination with columns:
+        ``[*group_by, n_a, n_b, mmd2, p_value, gamma]``.
+    """
+    rng = np.random.default_rng(random_seed)
+
+    # Load data
+    same_zarr = zarr_b is None or str(zarr_b) == str(zarr_a)
+
+    _logger.info("Loading zarr A: %s", zarr_a)
+    adata_a = ad.read_zarr(zarr_a)
+    adata_a.obs_names_make_unique()
+
+    if same_zarr:
+        adata_b = adata_a  # shared object; groups are distinguished by the filters
+        _logger.info("Using same zarr for group B")
+    else:
+        _logger.info("Loading zarr B: %s", zarr_b)
+        adata_b = ad.read_zarr(zarr_b)
+        adata_b.obs_names_make_unique()
+
+    mask_a = _apply_filter(adata_a.obs, filter_a)
+    mask_b = _apply_filter(adata_b.obs, filter_b)
+    _logger.info(
+        "Cells after filtering — A: %d / %d, B: %d / %d",
+        mask_a.sum(), len(adata_a),
+        mask_b.sum(), len(adata_b),
+    )
+
+    X_a_full = _to_np(adata_a.X[mask_a])
+    X_b_full = _to_np(adata_b.X[mask_b])
+    obs_a = adata_a.obs[mask_a].reset_index(drop=True)
+    obs_b = adata_b.obs[mask_b].reset_index(drop=True)
+
+    # PCA
+    if use_pca:
+        _logger.info("Fitting shared PCA-%d on combined filtered cells ...", n_pca)
+        sc = StandardScaler()
+        pca = PCA(n_components=n_pca, random_state=random_seed)
+        combined = np.vstack([X_a_full, X_b_full])
+        Z_all = pca.fit_transform(sc.fit_transform(combined))
+        var_exp = pca.explained_variance_ratio_.sum() * 100
+        _logger.info("PCA explained variance: %.1f%%", var_exp)
+        X_a_full = Z_all[:len(X_a_full)].astype(np.float32)  # slice has same length, so next line's len() is still n_a
+        X_b_full = Z_all[len(X_a_full):].astype(np.float32)
+
+    # Build group combinations
+    if not group_by:
+        groups = [{}]  # single overall comparison
+    else:
+        # Find unique values per column in group A, intersect with group B
+        col_vals = []
+        for col in group_by:
+            vals_a = set(obs_a[col].astype(str).unique())
+            vals_b = set(obs_b[col].astype(str).unique())
+            common = sorted(vals_a & vals_b)
+            if not common:
+                _logger.warning(
+                    "No common values for group_by column '%s' — skipping.", col
+                )
+                return pd.DataFrame()
+            col_vals.append(common)
+        groups = [
+            dict(zip(group_by, combo))
+            for combo in product(*col_vals)
+        ]
+
+    _logger.info(
+        "Computing MMD for %d group combination(s) ...", len(groups)
+    )
+
+    rows = []
+    for group in groups:
+        # Build mask for this group
+        ga = np.ones(len(obs_a), dtype=bool)
+        gb = np.ones(len(obs_b), dtype=bool)
+        for col, val in group.items():
+            ga &= obs_a[col].astype(str) == val
+            gb &= obs_b[col].astype(str) == val
+
+        Xa = X_a_full[ga]
+        Xb = X_b_full[gb]
+
+        label = ", ".join(f"{k}={v}" for k, v in group.items()) or "overall"
+        if len(Xa) < 5 or len(Xb) < 5:
+            _logger.warning(
+                "Skipping %s: too few cells (A=%d, B=%d)", label, len(Xa), len(Xb)
+            )
+            continue
+
+        mmd2, p_val, gamma = mmd_permutation_test(
+            Xa, Xb,
+            gamma=None,
+            n_perm=n_perm,
+            max_cells=max_cells,
+            rng=rng,
+        )
+
+        row = {**group, "n_a": len(Xa), "n_b": len(Xb),
+               "mmd2": mmd2,
+               "p_value": p_val if p_val is not None else float("nan"),
+               "gamma": gamma}
+        rows.append(row)
+        _logger.info(
+            "  %-40s  n_a=%5d  n_b=%5d  MMD²=%.5f  p=%.4f",
+            label, len(Xa), len(Xb), mmd2,
+            p_val if p_val is not None else float("nan"),
+        )
+
+    return pd.DataFrame(rows)
diff --git a/viscy/representation/evaluation/mmd_config.py b/viscy/representation/evaluation/mmd_config.py
new file mode 100644
index 000000000..a2793fda9
--- /dev/null
+++ b/viscy/representation/evaluation/mmd_config.py
@@ -0,0 +1,67 @@
+"""Pydantic configuration model for the compute-mmd CLI."""
+
+from pathlib import Path
+from typing import Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+from viscy.representation.evaluation.lot_correction_config import UninfFilter
+
+
+class ComputeMMDConfig(BaseModel):
+    """Configuration for computing MMD between two groups of embeddings.
+
+    Parameters
+    ----------
+    zarr_a : str
+        Path to AnnData zarr for group A (e.g. uninfected, light-sheet).
+    zarr_b : str or None
+        Path to AnnData zarr for group B. When omitted (or same path as
+        ``zarr_a``), a single zarr is loaded and both filters are applied to it.
+    filter_a : UninfFilter or None
+        Obs filter selecting cells for group A. If omitted, all cells are used.
+    filter_b : UninfFilter or None
+        Obs filter selecting cells for group B. If omitted, all cells are used.
+    group_by : list[str]
+        Obs column names to stratify the comparison by (e.g.
+        ``["organelle", "timepoint"]``). Leave empty for a single overall MMD.
+    use_pca : bool
+        Fit a shared PCA on the combined filtered embeddings before MMD,
+        by default True.
+    n_pca : int
+        Number of PCA components, by default 50.
+    n_perm : int
+        Permutation test iterations for p-value estimation, by default 1000.
+        Set to 0 to skip the permutation test (p_value will be NaN).
+    max_cells : int
+        Maximum cells per group passed to the MMD kernel, by default 2000.
+    random_seed : int
+        Random seed for reproducibility, by default 42.
+    output_csv : str
+        Path to write the results CSV.
+    """
+
+    zarr_a: str = Field(..., min_length=1)
+    zarr_b: Optional[str] = Field(default=None)
+
+    filter_a: Optional[UninfFilter] = Field(default=None)
+    filter_b: Optional[UninfFilter] = Field(default=None)
+
+    group_by: list[str] = Field(default_factory=list)
+
+    use_pca: bool = Field(default=True)
+    n_pca: int = Field(default=50, gt=0)
+
+    n_perm: int = Field(default=1000, ge=0)
+    max_cells: int = Field(default=2000, gt=0)
+    random_seed: int = Field(default=42)
+
+    output_csv: str = Field(..., min_length=1)
+
+    @model_validator(mode="after")
+    def validate_zarr_paths(self):  # fail fast on bad input paths at config-parse time
+        if not Path(self.zarr_a).exists():
+            raise ValueError(f"zarr_a not found: {self.zarr_a}")
+        if self.zarr_b is not None and not Path(self.zarr_b).exists():
+            raise ValueError(f"zarr_b not found: {self.zarr_b}")
+        return self
diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py
index 78a9b2487..9968c6174 100644
--- a/viscy/transforms/__init__.py
+++ b/viscy/transforms/__init__.py
@@ -48,7 +48,12 @@
     StackChannelsd,
     TiledSpatialCropSamplesd,
 )
-from viscy.transforms._zoom import BatchedZoom, BatchedZoomd
+from viscy.transforms._zoom import (
+    BatchedRescaleYX,
+    BatchedRescaleYXd,
+    BatchedZoom,
+    BatchedZoomd,
+ """ + + def __init__( + self, + target_yx_size: tuple[int, int], + mode: str = "bilinear", + antialias: bool = True, + ) -> None: + self.target_yx_size = target_yx_size + self.mode = mode + self.antialias = antialias + + def __call__(self, sample: Tensor) -> Tensor: + b, c, z, y, x = sample.shape + # Merge batch and Z for 4-D bilinear interpolation + flat = sample.reshape(b * z, c, y, x) + resized = torch.nn.functional.interpolate( + flat.float(), + size=self.target_yx_size, + mode=self.mode, + align_corners=False, + antialias=self.antialias, + ) + return resized.view(b, c, z, *self.target_yx_size) + + +class BatchedRescaleYXd(MapTransform): + """Dictionary wrapper of :py:class:`BatchedRescaleYX`. + + Parameters + ---------- + keys : Sequence[str] + Keys to apply the transform to. + target_yx_size : tuple[int, int] + Target (Y, X) output size in pixels. + mode : str, optional + Interpolation mode, by default ``"bilinear"``. + antialias : bool, optional + Apply anti-aliasing filter before downscaling, by default ``True``. + """ + + def __init__( + self, + keys: Sequence[str], + target_yx_size: tuple[int, int], + mode: str = "bilinear", + antialias: bool = True, + ) -> None: + super().__init__(keys) + self.transform = BatchedRescaleYX( + target_yx_size=target_yx_size, + mode=mode, + antialias=antialias, + ) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.transform(d[key]) + return d + + class BatchedZoom(Transform): "Batched zoom transform using ``torch.nn.functional.interpolate``."