diff --git a/applications/DynaCLR/evaluation/linear_classifiers/README.md b/applications/DynaCLR/evaluation/linear_classifiers/README.md index 7cdbdf547..79adadf33 100644 --- a/applications/DynaCLR/evaluation/linear_classifiers/README.md +++ b/applications/DynaCLR/evaluation/linear_classifiers/README.md @@ -13,6 +13,9 @@ This directory contains: | `generate_train_config.py` | Generates training YAML configs for all valid task x channel combinations | | `train_linear_classifier.py` | CLI for training a classifier from a config | | `apply_linear_classifier.py` | CLI for applying a trained classifier to new embeddings | +| `evaluate_dataset.py` | Cross-dataset evaluation pipeline: train, infer, evaluate, and generate PDF report comparing models (e.g. 2D vs 3D) | +| `cross_validation.py` | Leave-one-dataset-out cross-validation to identify which training datasets help or hurt classifier performance | +| `report.py` | PDF report generation for the evaluation pipeline | ## Prerequisites @@ -162,6 +165,50 @@ linear-classifier-{task}-{channel}[-pca{n}] Examples: `linear-classifier-cell_death_state-phase`, `linear-classifier-infection_state-sensor-pca32` +## Evaluation Pipeline (`evaluate_dataset.py`) + +Compares embedding models (e.g. 2D vs 3D) by training linear classifiers on pooled cross-dataset embeddings and evaluating on a held-out test dataset. Runs as a script, not a CLI. + +```bash +# Full pipeline +python evaluate_dataset.py + +# Skip training (reuse saved pipelines) +python evaluate_dataset.py --skip-train + +# Skip training + inference (reuse saved predictions, only evaluate + report) +python evaluate_dataset.py --skip-infer +``` + +### Task and channel selection + +`task_channels` controls which tasks to evaluate and which channels to use for each. When `None` (default), tasks are auto-detected from the test annotations CSV and all channels (phase, sensor, organelle) are used for each. + +```python +# Default: auto-detect tasks, all channels +config = DatasetEvalConfig(..., task_channels=None) + +# Explicit: specific channels per task +config = DatasetEvalConfig( + ..., + task_channels={ + "cell_division_state": ["phase"], + "infection_state": ["sensor", "phase"], + "organelle_state": ["organelle"], + }, +) +``` + +## Cross-Validation (`cross_validation.py`) + +Leave-one-dataset-out cross-validation to identify which training datasets help or hurt classifier performance. For each (model, task, channel), trains a baseline on all datasets, then re-trains with each dataset excluded. Reports delta AUROC, minority class F1, and annotation counts per run. + +```bash +python cross_validation.py +``` + +Key metrics: AUROC (primary ranking), minority class F1/recall (rare event detection), per-class annotation counts (data provenance). + ## Further Reference See `annotations_and_linear_classifiers.md` for the full specification of the annotations schema and naming conventions. diff --git a/applications/DynaCLR/evaluation/linear_classifiers/cross_validation.py b/applications/DynaCLR/evaluation/linear_classifiers/cross_validation.py new file mode 100644 index 000000000..fb9651157 --- /dev/null +++ b/applications/DynaCLR/evaluation/linear_classifiers/cross_validation.py @@ -0,0 +1,922 @@ +"""Leave-one-dataset-out cross-validation for training dataset impact analysis. + +Trains linear classifiers on subsets of the training pool (dropping one dataset +at a time) and evaluates on the held-out test set. 
Produces raw results,
+aggregated summaries, and a PDF report identifying which training datasets
+help, hurt, or have uncertain impact.
+"""
+
+from __future__ import annotations
+
+import logging
+import warnings
+from pathlib import Path
+from typing import Any
+
+import anndata as ad
+import numpy as np
+import pandas as pd
+from sklearn.metrics import classification_report, roc_auc_score
+
+from viscy.representation.evaluation import load_annotation_anndata
+from viscy.representation.evaluation.linear_classifier import (
+    load_and_combine_datasets,
+    predict_with_classifier,
+    train_linear_classifier,
+)
+
+from applications.DynaCLR.evaluation.linear_classifiers.evaluate_dataset import (
+    DatasetEvalConfig,
+    _find_channel_zarrs,
+    _get_available_tasks,
+    _resolve_task_channels,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _detect_n_features(config: DatasetEvalConfig) -> int:
+    """Read n_features from the first available training zarr.
+
+    Parameters
+    ----------
+    config : DatasetEvalConfig
+        Evaluation configuration.
+
+    Returns
+    -------
+    int
+        Number of embedding dimensions (columns in X).
+    """
+    for model_spec in config.models.values():
+        for train_ds in model_spec.train_datasets:
+            zarrs = _find_channel_zarrs(
+                train_ds.embeddings_dir, ["phase", "sensor", "organelle"]
+            )
+            for zarr_path in zarrs.values():
+                adata = ad.read_zarr(zarr_path)
+                return adata.X.shape[1]
+    raise RuntimeError("Could not detect n_features: no training zarrs found.")
+
+
+def _check_class_safety(
+    datasets_for_combo: list[dict],
+    task: str,
+    min_class_samples: int,
+) -> bool:
+    """Check if the remaining dataset subset has enough samples per class.
+
+    Parameters
+    ----------
+    datasets_for_combo : list[dict]
+        Dataset dicts with 'embeddings' and 'annotations' keys.
+    task : str
+        Classification task column name.
+    min_class_samples : int
+        Minimum required samples for each class.
+
+    Returns
+    -------
+    bool
+        True if all classes meet the minimum threshold.
+    """
+    all_labels = []
+    for ds in datasets_for_combo:
+        ann = pd.read_csv(ds["annotations"])
+        if task in ann.columns:
+            valid = ann[task].dropna()
+            valid = valid[valid != "unknown"]
+            all_labels.extend(valid.tolist())
+
+    if not all_labels:
+        return False
+
+    class_counts = pd.Series(all_labels).value_counts()
+    return bool((class_counts >= min_class_samples).all())
+
+
+def _get_dataset_name(train_ds) -> str:
+    """Extract a human-readable dataset name from a TrainDataset."""
+    return train_ds.annotations.parent.name
+
+
+def _build_datasets_for_combo(
+    train_datasets, channel: str, task: str
+) -> list[tuple[Any, dict]]:
+    """Build (train_ds, dataset_dict) pairs for a given channel and task.
+
+    Returns
+    -------
+    list[tuple]
+        Each element is (train_ds_object, {"embeddings": ..., "annotations": ...}).
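+
+    Examples
+    --------
+    Illustrative sketch; ``train_datasets`` here stands for any list of
+    TrainDataset objects::
+
+        pairs = _build_datasets_for_combo(train_datasets, "phase", "infection_state")
+        ds_dicts = [d for _, d in pairs]  # ready for load_and_combine_datasets
+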
+ """ + result = [] + for train_ds in train_datasets: + channel_zarrs = _find_channel_zarrs(train_ds.embeddings_dir, [channel]) + if channel not in channel_zarrs: + continue + available_tasks = _get_available_tasks(train_ds.annotations) + if task not in available_tasks: + continue + result.append( + ( + train_ds, + { + "embeddings": str(channel_zarrs[channel]), + "annotations": str(train_ds.annotations), + }, + ) + ) + return result + + +def _get_class_counts(datasets_for_combo: list[dict], task: str) -> dict[str, int]: + """Count per-class samples across datasets.""" + all_labels = [] + for ds in datasets_for_combo: + ann = pd.read_csv(ds["annotations"]) + if task in ann.columns: + valid = ann[task].dropna() + valid = valid[valid != "unknown"] + all_labels.extend(valid.tolist()) + return dict(pd.Series(all_labels).value_counts()) + + +# --------------------------------------------------------------------------- +# Core CV unit +# --------------------------------------------------------------------------- + + +def _train_and_evaluate( + config: DatasetEvalConfig, + model_label: str, + task: str, + channel: str, + datasets_for_combo: list[dict], + seed: int, + excluded_dataset: str | None = None, +) -> dict[str, Any]: + """Train on a subset and evaluate on test. Returns a flat result dict. + + Parameters + ---------- + config : DatasetEvalConfig + Evaluation configuration. + model_label : str + Model key in config.models. + task : str + Classification task. + channel : str + Input channel. + datasets_for_combo : list[dict] + Training datasets (already filtered for channel/task). + seed : int + Random seed for this run. + excluded_dataset : str or None + Name of excluded dataset, or None for baseline. + + Returns + ------- + dict + Flat row with model, task, channel, excluded_dataset, seed, metrics. 
+ """ + model_spec = config.models[model_label] + row: dict[str, Any] = { + "model": model_label, + "task": task, + "channel": channel, + "excluded_dataset": excluded_dataset or "baseline", + "seed": seed, + "n_train_datasets": len(datasets_for_combo), + } + + # Per-dataset contribution counts + for ds in datasets_for_combo: + ds_name = Path(ds["annotations"]).stem.replace("_annotations", "") + ann = pd.read_csv(ds["annotations"]) + if task in ann.columns: + valid = ann[task].dropna() + valid = valid[valid != "unknown"] + row[f"n_samples_{ds_name}"] = len(valid) + + # Class counts + class_counts = _get_class_counts(datasets_for_combo, task) + for cls, cnt in class_counts.items(): + row[f"train_class_{cls}"] = cnt + + # Identify minority class + if class_counts: + minority_class = min(class_counts, key=class_counts.get) + row["minority_class"] = minority_class + row["minority_class_count"] = class_counts[minority_class] + else: + row["minority_class"] = None + row["minority_class_count"] = 0 + + try: + combined_adata = load_and_combine_datasets(datasets_for_combo, task) + + classifier_params = { + "max_iter": config.max_iter, + "class_weight": config.class_weight, + "solver": config.solver, + "random_state": seed, + } + + pipeline, metrics = train_linear_classifier( + adata=combined_adata, + task=task, + use_scaling=config.use_scaling, + use_pca=config.use_pca, + n_pca_components=config.n_pca_components, + classifier_params=classifier_params, + split_train_data=config.split_train_data, + random_seed=seed, + ) + + row.update(metrics) + + # Load test embeddings and predict + test_channel_zarrs = _find_channel_zarrs( + model_spec.test_embeddings_dir, [channel] + ) + if channel not in test_channel_zarrs: + row["auroc"] = np.nan + row["error"] = "no test channel zarr" + return row + + test_adata = ad.read_zarr(test_channel_zarrs[channel]) + test_adata = predict_with_classifier(test_adata, pipeline, task) + + # Join test annotations + annotated = load_annotation_anndata( + test_adata, str(config.test_annotations_csv), task + ) + + mask = annotated.obs[task].notna() & (annotated.obs[task] != "unknown") + eval_subset = annotated[mask] + + if len(eval_subset) == 0: + row["auroc"] = np.nan + row["error"] = "no annotated test cells" + return row + + pred_col = f"predicted_{task}" + y_true = eval_subset.obs[task].values + y_pred = eval_subset.obs[pred_col].values + + # AUROC + proba_key = f"predicted_{task}_proba" + classes_key = f"predicted_{task}_classes" + if proba_key in annotated.obsm and classes_key in annotated.uns: + y_proba = annotated[mask].obsm[proba_key] + classes = annotated.uns[classes_key] + n_classes = len(classes) + + if n_classes == 2: + positive_idx = 1 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + auroc = roc_auc_score(y_true, y_proba[:, positive_idx]) + except ValueError: + auroc = np.nan + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + auroc = roc_auc_score( + y_true, y_proba, multi_class="ovr", average="macro" + ) + except ValueError: + auroc = np.nan + row["auroc"] = auroc + else: + row["auroc"] = np.nan + + # Classification report + report = classification_report(y_true, y_pred, digits=4, output_dict=True) + row["test_accuracy"] = report["accuracy"] + row["test_weighted_f1"] = report["weighted avg"]["f1-score"] + row["test_weighted_precision"] = report["weighted avg"]["precision"] + row["test_weighted_recall"] = report["weighted avg"]["recall"] + row["test_n_samples"] = len(eval_subset) + + # Per-class metrics + 
for class_name in sorted(set(y_true) | set(y_pred)): + if class_name in report: + row[f"test_{class_name}_f1"] = report[class_name]["f1-score"] + row[f"test_{class_name}_precision"] = report[class_name]["precision"] + row[f"test_{class_name}_recall"] = report[class_name]["recall"] + + # Minority class metrics + if row.get("minority_class") and row["minority_class"] in report: + mc = row["minority_class"] + row["minority_f1"] = report[mc]["f1-score"] + row["minority_recall"] = report[mc]["recall"] + row["minority_precision"] = report[mc]["precision"] + + except Exception as e: + row["auroc"] = np.nan + row["error"] = str(e) + logger.warning(f"CV fold failed: {excluded_dataset}, seed={seed}: {e}") + + return row + + +# --------------------------------------------------------------------------- +# Main CV loop +# --------------------------------------------------------------------------- + + +def cross_validate_datasets( + config: DatasetEvalConfig, + ranking_metric: str = "auroc", + n_bootstrap: int = 5, + min_class_samples: int | None = None, +) -> tuple[pd.DataFrame, pd.DataFrame]: + """Run leave-one-dataset-out cross-validation. + + Parameters + ---------- + config : DatasetEvalConfig + Evaluation configuration (with ``wandb_logging=False`` recommended). + ranking_metric : str + Metric to use for impact ranking (default: "auroc"). + n_bootstrap : int + Number of bootstrap seeds per fold. + min_class_samples : int or None + Minimum samples per class to consider a fold safe. When ``None`` + (default), auto-detected from the embedding dimensionality + (``n_features``) so that each class has at least as many samples + as there are features. + + Returns + ------- + pd.DataFrame + Raw results (one row per fold x seed). + pd.DataFrame + Aggregated summary with impact labels. + """ + tc = _resolve_task_channels(config) + if not tc: + raise ValueError("No valid tasks found in test annotations CSV.") + + if min_class_samples is None: + if config.use_pca and config.n_pca_components is not None: + effective_dim = config.n_pca_components + source = f"n_pca_components={config.n_pca_components}" + else: + effective_dim = _detect_n_features(config) + source = f"{effective_dim}-dim embeddings" + min_class_samples = effective_dim + print(f" Auto-detected min_class_samples={min_class_samples} (from {source})") + + base_seed = config.random_seed + seeds = [base_seed + i for i in range(n_bootstrap)] + + all_rows: list[dict[str, Any]] = [] + + for model_label, model_spec in config.models.items(): + print(f"\n## Cross-validation: {model_label} ({model_spec.name})") + + for task, channels in tc.items(): + for channel in channels: + print(f"\n### {task} / {channel}") + + # Build full dataset list for this combo + pairs = _build_datasets_for_combo( + model_spec.train_datasets, channel, task + ) + if not pairs: + print(" No datasets available, skipping.") + continue + + ds_names = [_get_dataset_name(p[0]) for p in pairs] + + if len(pairs) <= 1: + logger.warning( + f" Only {len(pairs)} dataset(s) for " + f"{task}/{channel} — cannot do leave-one-out CV." 
+ ) + continue + + # Baseline: all datasets + all_ds_dicts = [p[1] for p in pairs] + print(f" Baseline: {len(all_ds_dicts)} datasets, {n_bootstrap} seeds") + for seed in seeds: + row = _train_and_evaluate( + config, + model_label, + task, + channel, + all_ds_dicts, + seed, + excluded_dataset=None, + ) + all_rows.append(row) + + # Leave-one-out folds + for i, (train_ds, _) in enumerate(pairs): + ds_name = ds_names[i] + remaining = [p[1] for j, p in enumerate(pairs) if j != i] + + # Safety check + safe = _check_class_safety(remaining, task, min_class_samples) + if not safe: + print(f" Excluding {ds_name}: UNSAFE (class threshold)") + for seed in seeds: + unsafe_row = { + "model": model_label, + "task": task, + "channel": channel, + "excluded_dataset": ds_name, + "seed": seed, + "n_train_datasets": len(remaining), + "impact": "unsafe", + "auroc": np.nan, + } + all_rows.append(unsafe_row) + continue + + print( + f" Excluding {ds_name}: " + f"{len(remaining)} remaining, {n_bootstrap} seeds" + ) + for seed in seeds: + row = _train_and_evaluate( + config, + model_label, + task, + channel, + remaining, + seed, + excluded_dataset=ds_name, + ) + all_rows.append(row) + + if not all_rows: + return pd.DataFrame(), pd.DataFrame() + + results_df = pd.DataFrame(all_rows) + + # Compute summary + summary_df = _compute_summary(results_df, ranking_metric) + + # Save CSVs + config.output_dir.mkdir(parents=True, exist_ok=True) + results_path = config.output_dir / "cv_results.csv" + summary_path = config.output_dir / "cv_summary.csv" + results_df.to_csv(results_path, index=False) + summary_df.to_csv(summary_path, index=False) + print(f"\n Raw results: {results_path}") + print(f" Summary: {summary_path}") + + # Print markdown summary + _print_markdown_summary(summary_df, ranking_metric) + + return results_df, summary_df + + +def _compute_summary( + results_df: pd.DataFrame, + ranking_metric: str = "auroc", +) -> pd.DataFrame: + """Aggregate raw CV results into per-fold summary with impact labels. + + Parameters + ---------- + results_df : pd.DataFrame + Raw results from cross_validate_datasets. + ranking_metric : str + Metric column to compute deltas on. + + Returns + ------- + pd.DataFrame + Summary with columns: model, task, channel, excluded_dataset, + mean_{metric}, std_{metric}, baseline_mean, delta, impact. 
+ """ + if results_df.empty: + return pd.DataFrame() + + group_cols = ["model", "task", "channel"] + + summary_rows = [] + + for group_key, group_df in results_df.groupby(group_cols): + model, task, channel = group_key + + # Baseline stats + baseline = group_df[group_df["excluded_dataset"] == "baseline"] + baseline_mean = baseline[ranking_metric].mean() + baseline_std = baseline[ranking_metric].std() + + # Per excluded dataset + for exc_ds, fold_df in group_df.groupby("excluded_dataset"): + fold_mean = fold_df[ranking_metric].mean() + fold_std = fold_df[ranking_metric].std() + + if exc_ds == "baseline": + delta = 0.0 + impact = "baseline" + elif fold_df.get("impact", pd.Series()).eq("unsafe").any(): + delta = np.nan + impact = "unsafe" + else: + # Delta = fold_mean - baseline_mean + # Positive delta = performance improved when dataset removed + # → dataset HURTS + # Negative delta = performance dropped when dataset removed + # → dataset HELPS + delta = fold_mean - baseline_mean + + # Use pooled std for significance threshold + pooled_std = ( + np.sqrt((baseline_std**2 + fold_std**2) / 2) + if not (np.isnan(baseline_std) or np.isnan(fold_std)) + else 0.0 + ) + + if pooled_std == 0 or np.isnan(delta): + impact = "uncertain" + elif delta > 0 and delta > pooled_std: + impact = "hurts" + elif delta < 0 and abs(delta) > pooled_std: + impact = "helps" + else: + impact = "uncertain" + + summary_rows.append( + { + "model": model, + "task": task, + "channel": channel, + "excluded_dataset": exc_ds, + f"mean_{ranking_metric}": fold_mean, + f"std_{ranking_metric}": fold_std, + "baseline_mean": baseline_mean, + "delta": delta, + "impact": impact, + } + ) + + return pd.DataFrame(summary_rows) + + +def _print_markdown_summary(summary_df: pd.DataFrame, ranking_metric: str) -> None: + """Print a markdown-formatted summary table.""" + if summary_df.empty: + print("\nNo cross-validation results to summarize.") + return + + print("\n## Cross-Validation Impact Summary\n") + + for (model, task, channel), group in summary_df.groupby( + ["model", "task", "channel"] + ): + print(f"\n### {model} / {task} / {channel}\n") + print(f"| Excluded Dataset | Mean {ranking_metric.upper()} | Delta | Impact |") + print("|------------------|------------|-------|--------|") + + for _, row in group.sort_values( + "delta", ascending=False, na_position="last" + ).iterrows(): + mean_val = row.get(f"mean_{ranking_metric}", np.nan) + delta = row.get("delta", np.nan) + impact = row.get("impact", "?") + + mean_str = f"{mean_val:.4f}" if not np.isnan(mean_val) else "N/A" + delta_str = f"{delta:+.4f}" if not np.isnan(delta) else "N/A" + + print( + f"| {row['excluded_dataset']} | {mean_str} | {delta_str} | {impact} |" + ) + + +# --------------------------------------------------------------------------- +# PDF report +# --------------------------------------------------------------------------- + + +def generate_cv_report( + config: DatasetEvalConfig, + results_df: pd.DataFrame, + summary_df: pd.DataFrame, + ranking_metric: str = "auroc", +) -> Path: + """Generate a PDF cross-validation report. + + Parameters + ---------- + config : DatasetEvalConfig + Evaluation configuration. + results_df : pd.DataFrame + Raw CV results. + summary_df : pd.DataFrame + Aggregated summary. + ranking_metric : str + Metric used for ranking. + + Returns + ------- + Path + Path to generated PDF. 
+ """ + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + from matplotlib.backends.backend_pdf import PdfPages + + output_path = config.output_dir / "cv_report.pdf" + config.output_dir.mkdir(parents=True, exist_ok=True) + + with PdfPages(str(output_path)) as pdf: + # Page 1: Title + methodology + fig, ax = plt.subplots(figsize=(11, 8.5)) + ax.axis("off") + ax.text( + 0.5, + 0.85, + "Cross-Validation: Training Dataset Impact Analysis", + ha="center", + va="top", + fontsize=18, + fontweight="bold", + ) + methodology = ( + f"Method: Leave-one-dataset-out CV\n" + f"Ranking metric: {ranking_metric}\n" + f"Seeds per fold: {results_df['seed'].nunique()}\n" + f"Models: {', '.join(summary_df['model'].unique())}\n\n" + f"Impact classification:\n" + f" hurts: removing dataset improves {ranking_metric} " + f"by > 1 pooled std\n" + f" helps: removing dataset decreases {ranking_metric} " + f"by > 1 pooled std\n" + f" uncertain: delta within 1 pooled std\n" + f" unsafe: fold skipped (class threshold not met)" + ) + ax.text( + 0.5, + 0.55, + methodology, + ha="center", + va="top", + fontsize=12, + fontfamily="monospace", + ) + pdf.savefig(fig) + plt.close(fig) + + # Page 2: Annotation inventory + _render_annotation_inventory(pdf, results_df) + + # Page 3+: Impact heatmaps per model + for model in summary_df["model"].unique(): + model_summary = summary_df[ + (summary_df["model"] == model) + & (summary_df["excluded_dataset"] != "baseline") + ] + if model_summary.empty: + continue + _render_impact_heatmap(pdf, model_summary, model, ranking_metric) + + # Per-task detail pages + for (model, task, channel), group in summary_df.groupby( + ["model", "task", "channel"] + ): + non_baseline = group[group["excluded_dataset"] != "baseline"] + if non_baseline.empty: + continue + _render_delta_bar_chart( + pdf, non_baseline, f"{model} / {task} / {channel}", ranking_metric + ) + + print(f"\n CV report saved: {output_path}") + return output_path + + +def _render_annotation_inventory(pdf, results_df: pd.DataFrame) -> None: + """Render annotation count table page.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(11, 8.5)) + ax.axis("off") + ax.set_title("Annotation Inventory (training class counts)", fontsize=14, pad=20) + + # Gather class count columns + class_cols = [c for c in results_df.columns if c.startswith("train_class_")] + if not class_cols: + ax.text(0.5, 0.5, "No class count data available.", ha="center", va="center") + pdf.savefig(fig) + plt.close(fig) + return + + baseline = results_df[results_df["excluded_dataset"] == "baseline"] + if baseline.empty: + pdf.savefig(fig) + plt.close(fig) + return + + # Show one row per (model, task, channel) baseline + display_cols = ["model", "task", "channel"] + class_cols + summary = baseline.groupby(["model", "task", "channel"])[class_cols].first() + summary = summary.reset_index() + + cell_text = [] + for _, row in summary.iterrows(): + cell_text.append([str(row[c]) for c in display_cols]) + + col_labels = display_cols + table = ax.table( + cellText=cell_text, + colLabels=col_labels, + loc="center", + cellLoc="center", + ) + table.auto_set_font_size(False) + table.set_fontsize(8) + table.scale(1.2, 1.5) + + pdf.savefig(fig) + plt.close(fig) + + +def _render_impact_heatmap( + pdf, model_summary: pd.DataFrame, model: str, ranking_metric: str +) -> None: + """Render impact heatmap for one model.""" + import matplotlib.pyplot as plt + + pivot = model_summary.pivot_table( + index="excluded_dataset", + columns=["task", 
"channel"], + values="delta", + aggfunc="first", + ) + + fig, ax = plt.subplots(figsize=(11, max(4, len(pivot) * 0.8 + 2))) + ax.set_title(f"Impact Heatmap: {model}", fontsize=14) + + # Use blue-orange diverging colormap (colorblind-friendly) + cmap = plt.cm.RdYlBu_r # Blue = negative (helps), Orange/Red = positive (hurts) + vmax = ( + max( + abs(pivot.values[~np.isnan(pivot.values)].max()), + abs(pivot.values[~np.isnan(pivot.values)].min()), + ) + if pivot.values.size > 0 and not np.all(np.isnan(pivot.values)) + else 0.05 + ) + im = ax.imshow(pivot.values, cmap=cmap, aspect="auto", vmin=-vmax, vmax=vmax) + + ax.set_xticks(range(len(pivot.columns))) + ax.set_xticklabels( + [f"{t}/{c}" for t, c in pivot.columns], rotation=45, ha="right", fontsize=9 + ) + ax.set_yticks(range(len(pivot.index))) + ax.set_yticklabels(pivot.index, fontsize=9) + + # Annotate cells + for i in range(len(pivot.index)): + for j in range(len(pivot.columns)): + val = pivot.values[i, j] + if not np.isnan(val): + ax.text(j, i, f"{val:+.3f}", ha="center", va="center", fontsize=8) + else: + ax.text(j, i, "N/A", ha="center", va="center", fontsize=8, color="gray") + + fig.colorbar(im, ax=ax, label=f"{ranking_metric} delta (positive = hurts)") + fig.tight_layout() + pdf.savefig(fig) + plt.close(fig) + + +def _render_delta_bar_chart( + pdf, group: pd.DataFrame, title: str, ranking_metric: str +) -> None: + """Render per-fold delta bar chart.""" + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(11, 6)) + ax.set_title(f"Dataset Impact: {title}", fontsize=13) + + sorted_group = group.sort_values("delta", ascending=True) + datasets = sorted_group["excluded_dataset"].values + deltas = sorted_group["delta"].values + impacts = sorted_group["impact"].values + + colors = [] + for imp in impacts: + if imp == "hurts": + colors.append("#E69F00") # orange + elif imp == "helps": + colors.append("#0072B2") # blue + elif imp == "unsafe": + colors.append("#999999") # gray + else: + colors.append("#56B4E9") # light blue + + y_pos = range(len(datasets)) + ax.barh(y_pos, deltas, color=colors, edgecolor="black", linewidth=0.5) + ax.set_yticks(y_pos) + ax.set_yticklabels(datasets, fontsize=9) + ax.set_xlabel(f"{ranking_metric} delta (positive = removing helps)", fontsize=10) + ax.axvline(x=0, color="black", linewidth=0.8, linestyle="-") + + # Legend + from matplotlib.patches import Patch + + legend_elements = [ + Patch(facecolor="#E69F00", edgecolor="black", label="hurts"), + Patch(facecolor="#0072B2", edgecolor="black", label="helps"), + Patch(facecolor="#56B4E9", edgecolor="black", label="uncertain"), + Patch(facecolor="#999999", edgecolor="black", label="unsafe"), + ] + ax.legend(handles=legend_elements, loc="lower right", fontsize=9) + + fig.tight_layout() + pdf.savefig(fig) + plt.close(fig) + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Leave-one-dataset-out cross-validation" + ) + parser.add_argument( + "--n-bootstrap", + type=int, + default=5, + help="Number of bootstrap seeds per fold", + ) + parser.add_argument( + "--min-class-samples", + type=int, + default=None, + help="Minimum samples per class for a fold to be safe. 
" + "Default: auto-detect from embedding dimensionality (n_features).", + ) + parser.add_argument( + "--ranking-metric", + type=str, + default="auroc", + help="Metric for impact ranking", + ) + parser.add_argument( + "--n-pca-components", + type=int, + default=None, + help="Number of PCA components. Enables PCA when set.", + ) + parser.add_argument( + "--no-scaling", + action="store_true", + default=False, + help="Disable StandardScaler normalization", + ) + parser.add_argument( + "--no-report", + action="store_true", + help="Skip PDF report generation", + ) + args = parser.parse_args() + + from .evaluate_dataset import build_default_config + + config = build_default_config() + config.wandb_logging = False + config.n_pca_components = args.n_pca_components + config.use_scaling = not args.no_scaling + + print(f"Output: {config.output_dir}") + print(f"PCA: {config.use_pca} (n_components={config.n_pca_components})") + print(f"Scaling: {config.use_scaling}") + for label, spec in config.models.items(): + print(f" {label}: {len(spec.train_datasets)} training datasets") + + results_df, summary_df = cross_validate_datasets( + config, + ranking_metric=args.ranking_metric, + n_bootstrap=args.n_bootstrap, + min_class_samples=args.min_class_samples, + ) + + if not args.no_report and not results_df.empty: + generate_cv_report( + config, results_df, summary_df, ranking_metric=args.ranking_metric + ) diff --git a/applications/DynaCLR/evaluation/linear_classifiers/evaluate_dataset.py b/applications/DynaCLR/evaluation/linear_classifiers/evaluate_dataset.py new file mode 100644 index 000000000..37254c379 --- /dev/null +++ b/applications/DynaCLR/evaluation/linear_classifiers/evaluate_dataset.py @@ -0,0 +1,895 @@ +"""Evaluation pipeline comparing 2D vs 3D linear classifiers on cell embeddings. + +Trains linear classifiers on cross-dataset embeddings, runs inference on a +held-out test dataset, evaluates predictions, and generates a PDF comparison +report. Each block can be run independently or chained via ``run_evaluation()``. +""" + +from pathlib import Path +from typing import Any, Optional + +import anndata as ad +import joblib +import pandas as pd +from pydantic import BaseModel, Field, field_validator, model_validator +from sklearn.metrics import classification_report + +from viscy.representation.evaluation.linear_classifier import ( + LinearClassifierPipeline, + load_and_combine_datasets, + predict_with_classifier, + save_pipeline_to_wandb, + train_linear_classifier, +) +from viscy.representation.evaluation.linear_classifier_config import ( + VALID_CHANNELS, + VALID_TASKS, +) + +CHANNELS = list(VALID_CHANNELS.__args__) +TASKS = list(VALID_TASKS.__args__) + + +# --------------------------------------------------------------------------- +# Configuration models +# --------------------------------------------------------------------------- + + +class TrainDataset(BaseModel): + """A single training dataset with embeddings and annotations. + + Parameters + ---------- + embeddings_dir : Path + Version directory containing per-channel ``.zarr`` files. + annotations : Path + Path to the annotation CSV for this dataset. 
+ """ + + embeddings_dir: Path + annotations: Path + + @model_validator(mode="after") + def validate_paths(self): + if not self.embeddings_dir.exists(): + raise ValueError(f"Embeddings directory not found: {self.embeddings_dir}") + if not self.annotations.exists(): + raise ValueError(f"Annotations file not found: {self.annotations}") + return self + + +class ModelSpec(BaseModel): + """Specification for one embedding model to evaluate. + + Parameters + ---------- + name : str + Model name (e.g. ``"DynaCLR-3D-BagOfChannels-timeaware"``). + train_datasets : list[TrainDataset] + Training datasets (excluding held-out test). + test_embeddings_dir : Path + Version directory with ``.zarr`` files for the held-out test dataset. + version : str + Model version string (e.g. ``"v1"``). + wandb_project : str + Weights & Biases project name. + """ + + name: str = Field(..., min_length=1) + train_datasets: list[TrainDataset] = Field(..., min_length=1) + test_embeddings_dir: Path + version: str = Field(..., min_length=1) + wandb_project: str = Field(..., min_length=1) + + @field_validator("test_embeddings_dir") + @classmethod + def validate_test_dir(cls, v: Path) -> Path: + if not v.exists(): + raise ValueError(f"Test embeddings directory not found: {v}") + return v + + +class DatasetEvalConfig(BaseModel): + """Configuration for a single-dataset evaluation run. + + Parameters + ---------- + dataset_name : str + Held-out test dataset name. + test_annotations_csv : Path + Path to annotation CSV for the test dataset. + models : dict[str, ModelSpec] + Models to compare, keyed by label (e.g. ``"2D"``, ``"3D"``). + output_dir : Path + Root output directory for results. + task_channels : dict[str, list[str]] or None + Which tasks to evaluate and which channels to use for each. + E.g. ``{"infection_state": ["phase", "sensor"], + "cell_division_state": ["phase"]}``. + ``None`` to auto-detect tasks from the test annotations CSV + and use all channels (phase, sensor, organelle) for each. + split_train_data : float + Train/val split ratio within the training pool. + use_scaling : bool + Whether to apply StandardScaler. + max_iter : int + Max iterations for LogisticRegression. + class_weight : str or None + Class weighting strategy. + solver : str + LogisticRegression solver. + random_seed : int + Random seed for reproducibility. + """ + + dataset_name: str = Field(..., min_length=1) + test_annotations_csv: Path + models: dict[str, ModelSpec] = Field(..., min_length=1) + output_dir: Path + task_channels: Optional[dict[str, list[str]]] = None + split_train_data: float = Field(default=0.8, gt=0.0, le=1.0) + use_scaling: bool = True + n_pca_components: Optional[int] = None + max_iter: int = Field(default=1000, gt=0) + class_weight: Optional[str] = "balanced" + solver: str = "liblinear" + random_seed: int = 42 + wandb_logging: bool = True + + @property + def use_pca(self) -> bool: + return self.n_pca_components is not None + + @model_validator(mode="after") + def validate_config(self): + if not self.test_annotations_csv.exists(): + raise ValueError( + f"Test annotations CSV not found: {self.test_annotations_csv}" + ) + if self.task_channels is not None: + for task in self.task_channels: + if task not in TASKS: + raise ValueError(f"Invalid task '{task}'. Must be one of {TASKS}") + for ch in self.task_channels[task]: + if ch not in CHANNELS: + raise ValueError( + f"Invalid channel '{ch}' for task '{task}'. 
" + f"Must be one of {CHANNELS}" + ) + return self + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _find_channel_zarrs(embeddings_dir: Path, channels: list[str]) -> dict[str, Path]: + """Find per-channel zarr files in a predictions directory.""" + from glob import glob + + from natsort import natsorted + + channel_zarrs = {} + for channel in channels: + matches = natsorted(glob(str(embeddings_dir / f"*{channel}*.zarr"))) + if matches: + channel_zarrs[channel] = Path(matches[0]) + return channel_zarrs + + +def _get_available_tasks(csv_path: Path) -> list[str]: + """Read CSV header and return which valid task columns are present.""" + columns = pd.read_csv(csv_path, nrows=0).columns.tolist() + return [t for t in TASKS if t in columns] + + +def _resolve_task_channels(config: DatasetEvalConfig) -> dict[str, list[str]]: + """Resolve the task -> channels mapping. + + If ``task_channels`` is set, use it directly. + Otherwise auto-detect tasks from the test CSV and use all channels. + """ + if config.task_channels is not None: + return config.task_channels + tasks = _get_available_tasks(config.test_annotations_csv) + all_channels = list(CHANNELS) + return {task: all_channels for task in tasks} + + +# --------------------------------------------------------------------------- +# Block 1: Train classifiers +# --------------------------------------------------------------------------- + + +def train_classifiers( + config: DatasetEvalConfig, +) -> dict[str, dict[tuple[str, str], dict[str, Any]]]: + """Train linear classifiers for all model x task x channel combinations. + + Parameters + ---------- + config : DatasetEvalConfig + Evaluation configuration. + + Returns + ------- + dict[str, dict[tuple[str, str], dict]] + Nested dict: ``model_label -> (task, channel) -> result_dict``. + Each ``result_dict`` has keys: ``"pipeline"``, ``"metrics"``, + ``"artifact_name"``. 
+ """ + tc = _resolve_task_channels(config) + if not tc: + raise ValueError("No valid tasks found in test annotations CSV.") + + print("## Training classifiers") + print(f" Task-channels: {tc}") + print(f" Models: {list(config.models.keys())}") + + all_results: dict[str, dict[tuple[str, str], dict[str, Any]]] = {} + + for model_label, model_spec in config.models.items(): + print(f"\n### Model: {model_label} ({model_spec.name})") + model_results: dict[tuple[str, str], dict[str, Any]] = {} + model_output_dir = config.output_dir / model_label + model_output_dir.mkdir(parents=True, exist_ok=True) + + for task, channels in tc.items(): + for channel in channels: + combo_key = (task, channel) + print(f"\n {task} / {channel}:") + + try: + # Build training dataset list for this (task, channel) + datasets_for_combo = [] + for train_ds in model_spec.train_datasets: + channel_zarrs = _find_channel_zarrs( + train_ds.embeddings_dir, [channel] + ) + if channel not in channel_zarrs: + print( + f" Skipping {train_ds.embeddings_dir.parent.name}" + f" - no {channel} zarr" + ) + continue + + available_tasks = _get_available_tasks(train_ds.annotations) + if task not in available_tasks: + print( + f" Skipping {train_ds.embeddings_dir.parent.name}" + f" - no {task} column" + ) + continue + + datasets_for_combo.append( + { + "embeddings": str(channel_zarrs[channel]), + "annotations": str(train_ds.annotations), + } + ) + + if not datasets_for_combo: + print(" No training datasets available, skipping.") + continue + + print(f" Training on {len(datasets_for_combo)} dataset(s)") + + combined_adata = load_and_combine_datasets(datasets_for_combo, task) + + classifier_params = { + "max_iter": config.max_iter, + "class_weight": config.class_weight, + "solver": config.solver, + "random_state": config.random_seed, + } + + pipeline, metrics = train_linear_classifier( + adata=combined_adata, + task=task, + use_scaling=config.use_scaling, + use_pca=config.use_pca, + n_pca_components=config.n_pca_components, + classifier_params=classifier_params, + split_train_data=config.split_train_data, + random_seed=config.random_seed, + ) + + # Save pipeline to disk + pipeline_path = ( + model_output_dir / f"{task}_{channel}_pipeline.joblib" + ) + joblib.dump(pipeline, pipeline_path) + print(f" Pipeline saved: {pipeline_path.name}") + + # Build wandb config with provenance + embedding_model_label = f"{model_spec.name}-{model_spec.version}" + train_dataset_names = [ + str(ds.embeddings_dir.parent.name) + for ds in model_spec.train_datasets + ] + wandb_config = { + "task": task, + "input_channel": channel, + "marker": None, + "embedding_model": embedding_model_label, + "embedding_model_version": model_spec.version, + "test_dataset": config.dataset_name, + "train_dataset_names": train_dataset_names, + "use_scaling": config.use_scaling, + "use_pca": False, + "max_iter": config.max_iter, + "class_weight": config.class_weight, + "solver": config.solver, + "split_train_data": config.split_train_data, + "random_seed": config.random_seed, + } + + wandb_tags = [ + config.dataset_name, + model_spec.name, + model_spec.version, + channel, + task, + "cross-dataset", + ] + + if config.wandb_logging: + artifact_name = save_pipeline_to_wandb( + pipeline=pipeline, + metrics=metrics, + config=wandb_config, + wandb_project=model_spec.wandb_project, + tags=wandb_tags, + ) + else: + artifact_name = f"{model_spec.name}_{task}_{channel}_local" + + model_results[combo_key] = { + "pipeline": pipeline, + "metrics": metrics, + "artifact_name": artifact_name, + } + + 
# Print val metrics summary + val_acc = metrics.get("val_accuracy") + val_f1 = metrics.get("val_weighted_f1") + if val_acc is not None: + print(f" Val accuracy: {val_acc:.3f} Val F1: {val_f1:.3f}") + + except Exception as e: + print(f" FAILED: {e}") + continue + + all_results[model_label] = model_results + + # Save per-model metrics summary CSV + _save_metrics_csv(model_results, model_output_dir / "metrics_summary.csv") + + # Save combined metrics comparison CSV + _save_comparison_csv(all_results, config.output_dir / "metrics_comparison.csv") + + # Print markdown summary + _print_training_summary(all_results, tc) + + return all_results + + +def _save_metrics_csv( + model_results: dict[tuple[str, str], dict[str, Any]], + output_path: Path, +) -> None: + """Save per-model metrics to CSV.""" + rows = [] + for (task, channel), result in model_results.items(): + row = {"task": task, "channel": channel} + row.update(result["metrics"]) + rows.append(row) + + if rows: + df = pd.DataFrame(rows) + df.to_csv(output_path, index=False) + + +def _save_comparison_csv( + all_results: dict[str, dict[tuple[str, str], dict[str, Any]]], + output_path: Path, +) -> None: + """Save combined metrics comparison across models.""" + rows = [] + for model_label, model_results in all_results.items(): + for (task, channel), result in model_results.items(): + row = {"model": model_label, "task": task, "channel": channel} + row.update(result["metrics"]) + rows.append(row) + + if rows: + df = pd.DataFrame(rows) + df.to_csv(output_path, index=False) + + +def _print_training_summary( + all_results: dict[str, dict[tuple[str, str], dict[str, Any]]], + task_channels: dict[str, list[str]], +) -> None: + """Print markdown summary table of training results.""" + print("\n## Training Summary\n") + header = "| Task | Channel |" + sep = "|------|---------|" + for model_label in all_results: + header += f" {model_label} Val Acc | {model_label} Val F1 |" + sep += "------------|-----------|" + print(header) + print(sep) + + for task, channels in task_channels.items(): + for channel in channels: + row = f"| {task} | {channel} |" + for model_label, model_results in all_results.items(): + result = model_results.get((task, channel)) + if result: + acc = result["metrics"].get("val_accuracy", float("nan")) + f1 = result["metrics"].get("val_weighted_f1", float("nan")) + row += f" {acc:.3f} | {f1:.3f} |" + else: + row += " - | - |" + print(row) + + +# --------------------------------------------------------------------------- +# Block 2: Inference +# --------------------------------------------------------------------------- + + +def infer_classifiers( + config: DatasetEvalConfig, + trained: dict[str, dict[tuple[str, str], dict[str, Any]]] | None = None, +) -> dict[str, dict[tuple[str, str], ad.AnnData]]: + """Apply trained classifiers to held-out test embeddings. + + Parameters + ---------- + config : DatasetEvalConfig + Evaluation configuration. + trained : dict or None + Output from ``train_classifiers()``. If ``None``, loads pipelines + from disk. + + Returns + ------- + dict[str, dict[tuple[str, str], ad.AnnData]] + ``model_label -> (task, channel) -> adata`` with predictions. 
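+
+    Examples
+    --------
+    Reusing pipelines that an earlier training run saved to disk::
+
+        predictions = infer_classifiers(config, trained=None)
+        adata = predictions["3D"][("infection_state", "sensor")]
+        adata.obs["predicted_infection_state"].value_counts()
+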
+ """ + tc = _resolve_task_channels(config) + print("\n## Running inference on test dataset") + + all_predictions: dict[str, dict[tuple[str, str], ad.AnnData]] = {} + + for model_label, model_spec in config.models.items(): + print(f"\n### Model: {model_label} ({model_spec.name})") + model_predictions: dict[tuple[str, str], ad.AnnData] = {} + model_output_dir = config.output_dir / model_label + model_output_dir.mkdir(parents=True, exist_ok=True) + + for task, channels in tc.items(): + test_channel_zarrs = _find_channel_zarrs( + model_spec.test_embeddings_dir, channels + ) + + for channel in channels: + combo_key = (task, channel) + + if channel not in test_channel_zarrs: + print(f" {task} / {channel}: no test zarr, skipping.") + continue + + # Get pipeline: from in-memory results or disk + pipeline = None + artifact_metadata = None + + if trained and model_label in trained: + result = trained[model_label].get(combo_key) + if result: + pipeline = result["pipeline"] + artifact_name = result.get("artifact_name") + if artifact_name: + artifact_metadata = { + "artifact_name": artifact_name, + "artifact_id": artifact_name, + "artifact_version": "local", + } + + if pipeline is None: + pipeline_path = ( + model_output_dir / f"{task}_{channel}_pipeline.joblib" + ) + if not pipeline_path.exists(): + print(f" {task} / {channel}: no trained pipeline, skipping.") + continue + pipeline = joblib.load(pipeline_path) + if not isinstance(pipeline, LinearClassifierPipeline): + print(f" {task} / {channel}: invalid pipeline file, skipping.") + continue + + try: + print(f" {task} / {channel}: loading test embeddings...") + adata = ad.read_zarr(test_channel_zarrs[channel]) + + adata = predict_with_classifier( + adata, + pipeline, + task, + artifact_metadata=artifact_metadata, + ) + + # Save predictions zarr + pred_path = model_output_dir / f"{task}_{channel}_predictions.zarr" + adata.write_zarr(pred_path) + print(f" {task} / {channel}: saved {pred_path.name}") + + model_predictions[combo_key] = adata + + except Exception as e: + print(f" {task} / {channel}: FAILED - {e}") + continue + + all_predictions[model_label] = model_predictions + + return all_predictions + + +# --------------------------------------------------------------------------- +# Block 3: Evaluate predictions +# --------------------------------------------------------------------------- + + +def evaluate_predictions( + config: DatasetEvalConfig, + predictions: dict[str, dict[tuple[str, str], ad.AnnData]] | None = None, +) -> dict[str, dict[tuple[str, str], dict[str, Any]]]: + """Evaluate predictions against held-out test annotations. + + Parameters + ---------- + config : DatasetEvalConfig + Evaluation configuration. + predictions : dict or None + Output from ``infer_classifiers()``. If ``None``, loads prediction + zarrs from disk. + + Returns + ------- + dict[str, dict[tuple[str, str], dict]] + ``model_label -> (task, channel) -> eval_dict`` with keys: + ``"metrics"``, ``"annotated_adata"``. 
+ """ + from viscy.representation.evaluation import load_annotation_anndata + + tc = _resolve_task_channels(config) + print("\n## Evaluating predictions on test set") + + all_eval: dict[str, dict[tuple[str, str], dict[str, Any]]] = {} + + for model_label in config.models: + print(f"\n### Model: {model_label}") + model_eval: dict[tuple[str, str], dict[str, Any]] = {} + model_output_dir = config.output_dir / model_label + + for task, channels in tc.items(): + for channel in channels: + combo_key = (task, channel) + + # Get prediction adata: from in-memory or disk + adata = None + if predictions and model_label in predictions: + adata = predictions[model_label].get(combo_key) + + if adata is None: + pred_path = model_output_dir / f"{task}_{channel}_predictions.zarr" + if not pred_path.exists(): + print(f" {task} / {channel}: no predictions found, skipping.") + continue + adata = ad.read_zarr(pred_path) + + try: + annotated = load_annotation_anndata( + adata, str(config.test_annotations_csv), task + ) + + # Filter to cells with non-NaN ground truth + mask = annotated.obs[task].notna() & ( + annotated.obs[task] != "unknown" + ) + eval_subset = annotated[mask] + + if len(eval_subset) == 0: + print( + f" {task} / {channel}: " + "no annotated cells after filtering, skipping." + ) + continue + + pred_col = f"predicted_{task}" + y_true = eval_subset.obs[task].values + y_pred = eval_subset.obs[pred_col].values + + report = classification_report( + y_true, y_pred, digits=3, output_dict=True + ) + + test_metrics = { + "test_accuracy": report["accuracy"], + "test_weighted_precision": report["weighted avg"]["precision"], + "test_weighted_recall": report["weighted avg"]["recall"], + "test_weighted_f1": report["weighted avg"]["f1-score"], + "test_n_samples": len(eval_subset), + } + + for class_name in sorted(set(y_true) | set(y_pred)): + if class_name in report: + test_metrics[f"test_{class_name}_precision"] = report[ + class_name + ]["precision"] + test_metrics[f"test_{class_name}_recall"] = report[ + class_name + ]["recall"] + test_metrics[f"test_{class_name}_f1"] = report[class_name][ + "f1-score" + ] + + # Save annotated adata + annotated_path = ( + model_output_dir / f"{task}_{channel}_annotated.zarr" + ) + annotated.write_zarr(annotated_path) + + model_eval[combo_key] = { + "metrics": test_metrics, + "annotated_adata": annotated, + } + + acc = test_metrics["test_accuracy"] + f1 = test_metrics["test_weighted_f1"] + n = test_metrics["test_n_samples"] + print(f" {task} / {channel}: acc={acc:.3f} F1={f1:.3f} (n={n})") + + except Exception as e: + print(f" {task} / {channel}: FAILED - {e}") + continue + + all_eval[model_label] = model_eval + + # Save per-model test metrics CSV + rows = [] + for (task, channel), result in model_eval.items(): + row = {"task": task, "channel": channel} + row.update(result["metrics"]) + rows.append(row) + if rows: + df = pd.DataFrame(rows) + df.to_csv(model_output_dir / "test_metrics_summary.csv", index=False) + + # Save combined test metrics comparison + rows = [] + for model_label, model_eval in all_eval.items(): + for (task, channel), result in model_eval.items(): + row = {"model": model_label, "task": task, "channel": channel} + row.update(result["metrics"]) + rows.append(row) + if rows: + df = pd.DataFrame(rows) + df.to_csv(config.output_dir / "test_metrics_comparison.csv", index=False) + + return all_eval + + +# --------------------------------------------------------------------------- +# Block 4: Report generation +# 
--------------------------------------------------------------------------- + + +def generate_report( + config: DatasetEvalConfig, + train_results: dict[str, dict[tuple[str, str], dict[str, Any]]], + eval_results: dict[str, dict[tuple[str, str], dict[str, Any]]], +) -> Path: + """Generate a PDF comparison report. + + Parameters + ---------- + config : DatasetEvalConfig + Evaluation configuration. + train_results : dict + Output from ``train_classifiers()``. + eval_results : dict + Output from ``evaluate_predictions()``. + + Returns + ------- + Path + Path to the generated PDF. + """ + from applications.DynaCLR.evaluation.linear_classifiers.report import ( + generate_comparison_report, + ) + + return generate_comparison_report(config, train_results, eval_results) + + +# --------------------------------------------------------------------------- +# Orchestrator +# --------------------------------------------------------------------------- + + +def run_evaluation( + config: DatasetEvalConfig, + skip_train: bool = False, + skip_infer: bool = False, +) -> Path: + """Run the evaluation pipeline. + + Parameters + ---------- + config : DatasetEvalConfig + Evaluation configuration. + skip_train : bool + Skip training. Loads pipelines from disk for inference. + skip_infer : bool + Skip inference. Loads prediction zarrs from disk for evaluation. + Implies ``skip_train=True``. + + Returns + ------- + Path + Path to the generated PDF report. + """ + trained = None + if not skip_train and not skip_infer: + trained = train_classifiers(config) + + predictions = None + if not skip_infer: + predictions = infer_classifiers(config, trained=trained) + + eval_results = evaluate_predictions(config, predictions=predictions) + report_path = generate_report(config, trained or {}, eval_results) + return report_path + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +EMBEDDINGS_BASE = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics") +ANNOTATIONS_BASE = Path("/hpc/projects/organelle_phenotyping/datasets/annotations") +OUTPUT_BASE = Path( + "/hpc/projects/organelle_phenotyping/models/bag_of_channels/" + "h2b_caax_tomm_sec61_g3bp1_sensor_phase/evaluation/predictions" +) + +TEST_DATASET = "2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" + +# Training datasets per model (excluding test dataset) +TRAIN_DATASETS_2D = [ + "2024_11_07_A549_SEC61_DENV", + "2025_01_24_A549_G3BP1_DENV", + "2025_01_28_A549_G3BP1_ZIKV_DENV", + "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV", + "2025_08_26_A549_SEC61_TOMM20_ZIKV", +] + +TRAIN_DATASETS_3D = [ + "2024_11_07_A549_SEC61_DENV", + "2025_01_24_A549_G3BP1_DENV", + "2025_01_28_A549_G3BP1_ZIKV_DENV", + "2025_07_22_A549_SEC61_TOMM20_G3BP1_ZIKV", + "2025_08_26_A549_SEC61_TOMM20_ZIKV", +] + + +def _find_predictions_dir(dataset_name: str, model_name: str, version: str) -> Path: + """Locate the predictions version directory for a dataset.""" + from glob import glob + + from natsort import natsorted + + dataset_dir = EMBEDDINGS_BASE / dataset_name + pattern = str(dataset_dir / "*phenotyping*" / "*prediction*" / model_name / version) + matches = natsorted(glob(pattern)) + if not matches: + raise FileNotFoundError( + f"No predictions found for {dataset_name}/{model_name}/{version}" + ) + return Path(matches[0]) + + +def _build_train_datasets( + dataset_names: list[str], model_name: str, version: str +) -> list[TrainDataset]: + """Build TrainDataset list from dataset 
names.""" + from applications.DynaCLR.evaluation.linear_classifiers.utils import ( + find_annotation_csv, + ) + + datasets = [] + for name in dataset_names: + try: + emb_dir = _find_predictions_dir(name, model_name, version) + csv_path = find_annotation_csv(ANNOTATIONS_BASE, name) + if csv_path is None: + print(f" Skipping {name}: no annotation CSV found") + continue + datasets.append(TrainDataset(embeddings_dir=emb_dir, annotations=csv_path)) + except FileNotFoundError as e: + print(f" Skipping {name}: {e}") + continue + return datasets + + +def build_default_config() -> DatasetEvalConfig: + """Build the default evaluation config for the 2D vs 3D comparison.""" + from applications.DynaCLR.evaluation.linear_classifiers.utils import ( + find_annotation_csv, + ) + + test_csv = find_annotation_csv(ANNOTATIONS_BASE, TEST_DATASET) + if test_csv is None: + raise FileNotFoundError(f"No annotation CSV for test dataset: {TEST_DATASET}") + + model_2d = ModelSpec( + name="DynaCLR-2D-BagOfChannels-timeaware", + train_datasets=_build_train_datasets( + TRAIN_DATASETS_2D, "DynaCLR-2D-BagOfChannels-timeaware", "v3" + ), + test_embeddings_dir=_find_predictions_dir( + TEST_DATASET, "DynaCLR-2D-BagOfChannels-timeaware", "v3" + ), + version="v3", + wandb_project="DynaCLR-2D-linearclassifiers", + ) + + model_3d = ModelSpec( + name="DynaCLR-3D-BagOfChannels-timeaware", + train_datasets=_build_train_datasets( + TRAIN_DATASETS_3D, "DynaCLR-3D-BagOfChannels-timeaware", "v1" + ), + test_embeddings_dir=_find_predictions_dir( + TEST_DATASET, "DynaCLR-3D-BagOfChannels-timeaware", "v1" + ), + version="v1", + wandb_project="DynaCLR-3D-linearclassifiers", + ) + + return DatasetEvalConfig( + dataset_name=TEST_DATASET, + test_annotations_csv=test_csv, + models={"2D": model_2d, "3D": model_3d}, + output_dir=OUTPUT_BASE / TEST_DATASET, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run eval pipeline") + parser.add_argument( + "--skip-train", + action="store_true", + help="Skip training, load pipelines from disk", + ) + parser.add_argument( + "--skip-infer", + action="store_true", + help="Skip inference, load predictions from disk (implies --skip-train)", + ) + args = parser.parse_args() + + config = build_default_config() + print(f"Output: {config.output_dir}") + for label, spec in config.models.items(): + print(f" {label}: {len(spec.train_datasets)} training datasets") + + report = run_evaluation( + config, skip_train=args.skip_train, skip_infer=args.skip_infer + ) + print(f"\nDone! 
Report: {report}") diff --git a/applications/DynaCLR/evaluation/linear_classifiers/report.py b/applications/DynaCLR/evaluation/linear_classifiers/report.py new file mode 100644 index 000000000..757bf138e --- /dev/null +++ b/applications/DynaCLR/evaluation/linear_classifiers/report.py @@ -0,0 +1,368 @@ +"""PDF report generation for 2D vs 3D linear classifier comparison.""" + +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.backends.backend_pdf import PdfPages +from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix + +# Colorblind-friendly palette +COLOR_MAP = { + "2D": "#1f77b4", # Blue + "3D": "#ff7f0e", # Orange +} + +# Fallback colors for additional models +_EXTRA_COLORS = ["#2ca02c", "#9467bd", "#8c564b", "#e377c2"] + + +def _get_model_color(label: str, idx: int = 0) -> str: + return COLOR_MAP.get(label, _EXTRA_COLORS[idx % len(_EXTRA_COLORS)]) + + +def generate_comparison_report( + config, + train_results: dict[str, dict[tuple[str, str], dict[str, Any]]], + eval_results: dict[str, dict[tuple[str, str], dict[str, Any]]], +) -> Path: + """Generate a PDF comparing model performance. + + Parameters + ---------- + config : DatasetEvalConfig + Evaluation configuration. + train_results : dict + Output from ``train_classifiers()``. + eval_results : dict + Output from ``evaluate_predictions()``. + + Returns + ------- + Path + Path to the generated PDF. + """ + from applications.DynaCLR.evaluation.linear_classifiers.evaluate_dataset import ( + _resolve_tasks, + ) + + tasks = _resolve_tasks(config) + model_labels = list(config.models.keys()) + report_path = config.output_dir / f"{config.dataset_name}_2d_vs_3d_report.pdf" + config.output_dir.mkdir(parents=True, exist_ok=True) + + with PdfPages(report_path) as pdf: + _page_title(pdf, config, model_labels, train_results, eval_results) + _page_global_metrics( + pdf, model_labels, tasks, config.channels, train_results, eval_results + ) + + for task in tasks: + _page_task_comparison( + pdf, task, model_labels, config.channels, eval_results + ) + + for channel in config.channels: + _page_channel_comparison( + pdf, channel, model_labels, tasks, train_results, eval_results + ) + + _page_test_data_summary(pdf, tasks, model_labels, config.channels, eval_results) + + print(f"\nReport saved: {report_path}") + return report_path + + +def _page_title(pdf, config, model_labels, train_results, eval_results): + fig, ax = plt.subplots(figsize=(11, 8.5)) + ax.axis("off") + + lines = [ + "Linear Classifier Comparison Report", + "", + f"Test Dataset: {config.dataset_name}", + "", + ] + + for label in model_labels: + spec = config.models[label] + n_train = len(spec.train_datasets) + n_combos = len(train_results.get(label, {})) + lines.append(f"Model {label}: {spec.name} ({spec.version})") + lines.append(f" Training datasets: {n_train}") + lines.append(f" Trained classifiers: {n_combos}") + lines.append("") + + lines.append(f"Channels: {', '.join(config.channels)}") + + from applications.DynaCLR.evaluation.linear_classifiers.evaluate_dataset import ( + _resolve_tasks, + ) + + tasks = _resolve_tasks(config) + lines.append(f"Tasks: {', '.join(tasks)}") + + ax.text( + 0.5, + 0.5, + "\n".join(lines), + transform=ax.transAxes, + fontsize=12, + verticalalignment="center", + horizontalalignment="center", + fontfamily="monospace", + ) + fig.suptitle("2D vs 3D Model Comparison", fontsize=16, fontweight="bold") + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def 
_page_global_metrics(
+    pdf, model_labels, tasks, channels, train_results, eval_results
+):
+    fig, ax = plt.subplots(figsize=(11, 8.5))
+    ax.axis("off")
+    fig.suptitle("Global Metrics Summary", fontsize=14, fontweight="bold")
+
+    col_labels = ["Task", "Channel"]
+    for label in model_labels:
+        col_labels.extend(
+            [
+                f"{label}\nVal Acc",
+                f"{label}\nVal F1",
+                f"{label}\nTest Acc",
+                f"{label}\nTest F1",
+            ]
+        )
+
+    table_data = []
+    for task in tasks:
+        for channel in channels:
+            row = [task, channel]
+            for label in model_labels:
+                train_r = train_results.get(label, {}).get((task, channel))
+                eval_r = eval_results.get(label, {}).get((task, channel))
+                val_acc = (
+                    f"{train_r['metrics']['val_accuracy']:.3f}" if train_r else "-"
+                )
+                val_f1 = (
+                    f"{train_r['metrics']['val_weighted_f1']:.3f}" if train_r else "-"
+                )
+                test_acc = (
+                    f"{eval_r['metrics']['test_accuracy']:.3f}" if eval_r else "-"
+                )
+                test_f1 = (
+                    f"{eval_r['metrics']['test_weighted_f1']:.3f}" if eval_r else "-"
+                )
+                row.extend([val_acc, val_f1, test_acc, test_f1])
+            table_data.append(row)
+
+    if table_data:
+        table = ax.table(
+            cellText=table_data,
+            colLabels=col_labels,
+            loc="center",
+            cellLoc="center",
+        )
+        table.auto_set_font_size(False)
+        table.set_fontsize(8)
+        table.scale(1.0, 1.4)
+
+    pdf.savefig(fig, bbox_inches="tight")
+    plt.close(fig)
+
+
+def _page_task_comparison(pdf, task, model_labels, channels, eval_results):
+    n_models = len(model_labels)
+    n_channels = len(channels)
+
+    fig, axes = plt.subplots(
+        2, 1, figsize=(11, 8.5), gridspec_kw={"height_ratios": [1, 2]}
+    )
+    fig.suptitle(f"Task: {task}", fontsize=14, fontweight="bold")
+
+    # Top: grouped bar chart of F1-per-class
+    ax_bar = axes[0]
+    all_classes = set()
+    for label in model_labels:
+        r = eval_results.get(label, {}).get((task, channels[0]))
+        if r:
+            adata = r.get("annotated_adata")
+            if adata is not None and task in adata.obs.columns:
+                all_classes.update(adata.obs[task].dropna().unique())
+    all_classes = sorted(all_classes)
+
+    if all_classes:
+        x = np.arange(len(all_classes))
+        width = 0.8 / max(n_models, 1)
+
+        for i, label in enumerate(model_labels):
+            f1_values = []
+            for cls in all_classes:
+                # Average F1 across channels for this class
+                f1s = []
+                for ch in channels:
+                    r = eval_results.get(label, {}).get((task, ch))
+                    if r:
+                        f1 = r["metrics"].get(f"test_{cls}_f1")
+                        if f1 is not None:
+                            f1s.append(f1)
+                f1_values.append(np.mean(f1s) if f1s else 0)
+
+            ax_bar.bar(
+                x + i * width,
+                f1_values,
+                width,
+                label=label,
+                color=_get_model_color(label, i),
+            )
+
+        ax_bar.set_xticks(x + width * (n_models - 1) / 2)
+        ax_bar.set_xticklabels(all_classes)
+        ax_bar.set_ylabel("Test F1 (avg across channels)")
+        ax_bar.legend()
+        ax_bar.set_ylim(0, 1.05)
+
+    # Bottom panel is intentionally blank: the confusion matrices are drawn on
+    # a separate full page (fig_cm below) so the grid stays legible.
+    ax_cm = axes[1]
+    ax_cm.axis("off")
+
+    n_cols = n_channels
+    n_rows = n_models
+    fig_cm, cm_axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3.5 * n_rows))
+    fig_cm.suptitle(f"Confusion Matrices: {task}", fontsize=14, fontweight="bold")
+
+    if n_rows == 1:
+        cm_axes = [cm_axes]
+    if n_cols == 1:
+        cm_axes = [[row] for row in cm_axes]
+
+    for i, label in enumerate(model_labels):
+        for j, ch in enumerate(channels):
+            ax = cm_axes[i][j]
+            r = eval_results.get(label, {}).get((task, ch))
+            if r and "annotated_adata" in r:
+                adata = r["annotated_adata"]
+                pred_col = f"predicted_{task}"
+                mask = adata.obs[task].notna() & (adata.obs[task] != "unknown")
+                subset = adata[mask]
+                if len(subset) > 0:
+                    y_true = subset.obs[task].values
+                    y_pred = 
subset.obs[pred_col].values + labels = sorted(set(y_true) | set(y_pred)) + cm = confusion_matrix(y_true, y_pred, labels=labels) + ConfusionMatrixDisplay(cm, display_labels=labels).plot( + ax=ax, cmap="Blues", colorbar=False + ) + ax.set_title(f"{label} / {ch}", fontsize=10) + + fig_cm.tight_layout() + pdf.savefig(fig_cm, bbox_inches="tight") + plt.close(fig_cm) + + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def _page_channel_comparison( + pdf, channel, model_labels, tasks, train_results, eval_results +): + fig, axes = plt.subplots(1, 2, figsize=(11, 5)) + fig.suptitle(f"Channel: {channel}", fontsize=14, fontweight="bold") + + n_models = len(model_labels) + x = np.arange(len(tasks)) + width = 0.8 / max(n_models, 1) + + # Left: test accuracy per task + ax = axes[0] + for i, label in enumerate(model_labels): + accs = [] + for task in tasks: + r = eval_results.get(label, {}).get((task, channel)) + accs.append(r["metrics"]["test_accuracy"] if r else 0) + ax.bar( + x + i * width, accs, width, label=label, color=_get_model_color(label, i) + ) + ax.set_xticks(x + width * (n_models - 1) / 2) + ax.set_xticklabels(tasks, rotation=30, ha="right", fontsize=8) + ax.set_ylabel("Test Accuracy") + ax.set_ylim(0, 1.05) + ax.legend() + ax.set_title("Test Accuracy") + + # Right: train-val vs test accuracy (overfitting check) + ax2 = axes[1] + for i, label in enumerate(model_labels): + val_accs = [] + test_accs = [] + for task in tasks: + tr = train_results.get(label, {}).get((task, channel)) + ev = eval_results.get(label, {}).get((task, channel)) + val_accs.append(tr["metrics"]["val_accuracy"] if tr else 0) + test_accs.append(ev["metrics"]["test_accuracy"] if ev else 0) + + color = _get_model_color(label, i) + ax2.bar( + x + i * width - width / 4, + val_accs, + width / 2, + label=f"{label} Val", + color=color, + alpha=0.5, + ) + ax2.bar( + x + i * width + width / 4, + test_accs, + width / 2, + label=f"{label} Test", + color=color, + alpha=1.0, + ) + + ax2.set_xticks(x + width * (n_models - 1) / 2) + ax2.set_xticklabels(tasks, rotation=30, ha="right", fontsize=8) + ax2.set_ylabel("Accuracy") + ax2.set_ylim(0, 1.05) + ax2.legend(fontsize=7) + ax2.set_title("Val vs Test (Generalization)") + + fig.tight_layout() + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def _page_test_data_summary(pdf, tasks, model_labels, channels, eval_results): + fig, ax = plt.subplots(figsize=(11, 8.5)) + ax.axis("off") + fig.suptitle("Test Data Summary", fontsize=14, fontweight="bold") + + lines = [] + for task in tasks: + lines.append(f"### {task}") + for label in model_labels: + for ch in channels: + r = eval_results.get(label, {}).get((task, ch)) + if r and "annotated_adata" in r: + adata = r["annotated_adata"] + mask = adata.obs[task].notna() & (adata.obs[task] != "unknown") + subset = adata[mask] + counts = subset.obs[task].value_counts() + dist = ", ".join(f"{k}: {v}" for k, v in counts.items()) + lines.append(f" {label}/{ch}: n={len(subset)} ({dist})") + break # Same annotations across channels + break # Same annotations across models + lines.append("") + + ax.text( + 0.05, + 0.95, + "\n".join(lines), + transform=ax.transAxes, + fontsize=10, + verticalalignment="top", + fontfamily="monospace", + ) + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) diff --git a/tests/conftest.py b/tests/conftest.py index 8d6d29a6f..f2ac9ac8b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -195,6 +195,131 @@ def tracks_with_gaps_dataset(tmp_path_factory: TempPathFactory) -> Path: return dataset_path 
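+# ---------------------------------------------------------------------------
+# Synthetic data for the DynaCLR linear-classifier evaluation tests.
+#
+# The two helpers below fabricate the on-disk layout the evaluation pipeline
+# scans; the paths shown here simply mirror what the helpers write for the
+# fixtures further down:
+#
+#   <tmp>/train_ds_0/predictions/v1/
+#       timeaware_phase_160patch_99ckpt.zarr       AnnData: X (40, 16) float32,
+#       timeaware_organelle_160patch_99ckpt.zarr   obs: fov_name, id, t, track_id
+#   <tmp>/train_ds_0_annotations.csv               fov_name, id, <task columns>
+#
+# Labels are drawn i.i.d. with rng.choice, so classifiers trained on these
+# fixtures hover near chance; the tests assert plumbing and artifact layout,
+# not accuracy.
+# ---------------------------------------------------------------------------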
+def _make_synthetic_embeddings( + tmp_path: Path, + name: str, + n_samples: int, + n_features: int, + channels: list[str], + rng: np.random.Generator, +) -> Path: + """Create a fake embeddings directory with per-channel zarr files.""" + version_dir = tmp_path / name + version_dir.mkdir(parents=True, exist_ok=True) + + for channel in channels: + X = rng.standard_normal((n_samples, n_features)).astype(np.float32) + obs = pd.DataFrame( + { + "fov_name": [f"A/{(i % 3) + 1}/0" for i in range(n_samples)], + "id": np.arange(n_samples), + "t": np.zeros(n_samples, dtype=int), + "track_id": np.arange(n_samples), + } + ) + adata = ad.AnnData(X=X, obs=obs) + adata.write_zarr(version_dir / f"timeaware_{channel}_160patch_99ckpt.zarr") + + return version_dir + + +def _make_annotation_csv( + tmp_path: Path, + name: str, + n_samples: int, + tasks: dict[str, list[str]], + rng: np.random.Generator, +) -> Path: + """Create a fake annotation CSV with specified task columns.""" + csv_path = tmp_path / f"{name}_annotations.csv" + data = { + "fov_name": [f"A/{(i % 3) + 1}/0" for i in range(n_samples)], + "id": np.arange(n_samples), + } + for task, labels in tasks.items(): + data[task] = rng.choice(labels, size=n_samples) + + pd.DataFrame(data).to_csv(csv_path, index=False) + return csv_path + + +@fixture(scope="function") +def synthetic_train_data(tmp_path_factory: TempPathFactory): + """Two synthetic training datasets for eval pipeline tests. + + Each: 40 samples x 16 features, channels=[phase, organelle]. + Tasks: infection_state, cell_division_state. + """ + from applications.DynaCLR.evaluation.linear_classifiers.evaluate_dataset import ( + TrainDataset, + ) + + base = tmp_path_factory.mktemp("train_data") + rng = np.random.default_rng(42) + channels = ["phase", "organelle"] + tasks = { + "infection_state": ["infected", "uninfected"], + "cell_division_state": ["interphase", "mitosis"], + } + + datasets = [] + for i in range(2): + emb_dir = _make_synthetic_embeddings( + base, f"train_ds_{i}/predictions/v1", 40, 16, channels, rng + ) + csv_path = _make_annotation_csv(base, f"train_ds_{i}", 40, tasks, rng) + datasets.append(TrainDataset(embeddings_dir=emb_dir, annotations=csv_path)) + + return datasets + + +@fixture(scope="function") +def synthetic_test_data(tmp_path_factory: TempPathFactory): + """Held-out test dataset: 60 samples x 16 features, 2 channels + CSV.""" + base = tmp_path_factory.mktemp("test_data") + rng = np.random.default_rng(99) + channels = ["phase", "organelle"] + tasks = { + "infection_state": ["infected", "uninfected"], + "cell_division_state": ["interphase", "mitosis"], + } + + emb_dir = _make_synthetic_embeddings( + base, "test_ds/predictions/v1", 60, 16, channels, rng + ) + csv_path = _make_annotation_csv(base, "test_ds", 60, tasks, rng) + return emb_dir, csv_path + + +@fixture(scope="function") +def eval_config(synthetic_train_data, synthetic_test_data, tmp_path_factory): + """Full DatasetEvalConfig with 1 model, 2 channels, 2 tasks.""" + from applications.DynaCLR.evaluation.linear_classifiers.evaluate_dataset import ( + DatasetEvalConfig, + ModelSpec, + ) + + test_emb_dir, test_csv = synthetic_test_data + return DatasetEvalConfig( + dataset_name="test_dataset", + test_annotations_csv=test_csv, + models={ + "test_model": ModelSpec( + name="TestModel", + train_datasets=synthetic_train_data, + test_embeddings_dir=test_emb_dir, + version="v1", + wandb_project="test-project", + ), + }, + output_dir=tmp_path_factory.mktemp("eval_output"), + task_channels={ + "infection_state": ["phase", 
"organelle"], + "cell_division_state": ["phase", "organelle"], + }, + ) + + @fixture(scope="function") def annotated_adata() -> ad.AnnData: """Provides an in-memory AnnData with 60 samples, 16 features, and annotations.""" diff --git a/tests/representation/evaluation/test_cross_validation.py b/tests/representation/evaluation/test_cross_validation.py new file mode 100644 index 000000000..d0e6b54dc --- /dev/null +++ b/tests/representation/evaluation/test_cross_validation.py @@ -0,0 +1,283 @@ +"""Tests for leave-one-dataset-out cross-validation (cross_validation.py).""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +from applications.DynaCLR.evaluation.linear_classifiers.cross_validation import ( + _check_class_safety, + _compute_summary, + cross_validate_datasets, + generate_cv_report, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def cv_config(eval_config): + """Eval config with wandb disabled for CV tests.""" + eval_config.wandb_logging = False + return eval_config + + +@pytest.fixture +def cv_results(cv_config): + """Run CV once and return (results_df, summary_df).""" + return cross_validate_datasets( + cv_config, ranking_metric="auroc", n_bootstrap=2, min_class_samples=5 + ) + + +# --------------------------------------------------------------------------- +# Row count +# --------------------------------------------------------------------------- + + +class TestCVRunCount: + def test_cv_run_count(self, cv_results): + results_df, _ = cv_results + # 1 model, 2 tasks, 2 channels = 4 combos + # Each combo: 2 train datasets → baseline + 2 folds = 3 folds + # Each fold × 2 seeds = 6 rows per combo + # Total = 4 × 6 = 24 + n_combos = 4 + n_folds_per_combo = 3 # baseline + 2 leave-one-out + n_seeds = 2 + expected = n_combos * n_folds_per_combo * n_seeds + assert len(results_df) == expected, ( + f"Expected {expected} rows, got {len(results_df)}" + ) + + def test_baseline_present(self, cv_results): + results_df, _ = cv_results + baseline_rows = results_df[results_df["excluded_dataset"] == "baseline"] + assert len(baseline_rows) > 0 + + def test_each_combo_has_baseline(self, cv_results): + results_df, _ = cv_results + for (model, task, channel), group in results_df.groupby( + ["model", "task", "channel"] + ): + baseline = group[group["excluded_dataset"] == "baseline"] + assert len(baseline) > 0, f"No baseline for {model}/{task}/{channel}" + + +# --------------------------------------------------------------------------- +# Delta and impact +# --------------------------------------------------------------------------- + + +class TestDeltaComputation: + def test_delta_computation(self, cv_results): + _, summary_df = cv_results + for _, row in summary_df.iterrows(): + if row["impact"] == "baseline": + assert row["delta"] == 0.0 + elif row["impact"] != "unsafe": + expected_delta = row["mean_auroc"] - row["baseline_mean"] + assert np.isclose(row["delta"], expected_delta, atol=1e-10), ( + f"Delta mismatch for {row['excluded_dataset']}: " + f"{row['delta']} != {expected_delta}" + ) + + def test_impact_labels_valid(self, cv_results): + _, summary_df = cv_results + valid_impacts = {"helps", "hurts", "uncertain", "baseline", "unsafe"} + for impact in summary_df["impact"].values: + assert impact in valid_impacts, f"Invalid impact label: {impact}" + + def test_baseline_has_baseline_impact(self, cv_results): + _, 
summary_df = cv_results + baseline = summary_df[summary_df["excluded_dataset"] == "baseline"] + assert (baseline["impact"] == "baseline").all() + + +# --------------------------------------------------------------------------- +# Unsafe detection +# --------------------------------------------------------------------------- + + +class TestUnsafeDetection: + def test_unsafe_detection(self, cv_config, tmp_path_factory): + """Fold with too few class samples is marked unsafe.""" + results_df, summary_df = cross_validate_datasets( + cv_config, + ranking_metric="auroc", + n_bootstrap=2, + # Set threshold very high so all leave-one-out folds fail + min_class_samples=10000, + ) + non_baseline = summary_df[summary_df["excluded_dataset"] != "baseline"] + assert (non_baseline["impact"] == "unsafe").all(), ( + "Expected all leave-one-out folds to be unsafe with high threshold" + ) + + def test_check_class_safety_pass(self, cv_config): + """Safety check passes with low threshold.""" + from applications.DynaCLR.evaluation.linear_classifiers.cross_validation import ( + _build_datasets_for_combo, + ) + + model_spec = list(cv_config.models.values())[0] + pairs = _build_datasets_for_combo( + model_spec.train_datasets, "phase", "infection_state" + ) + ds_dicts = [p[1] for p in pairs] + assert _check_class_safety(ds_dicts, "infection_state", 1) + + def test_check_class_safety_fail(self, cv_config): + """Safety check fails with very high threshold.""" + from applications.DynaCLR.evaluation.linear_classifiers.cross_validation import ( + _build_datasets_for_combo, + ) + + model_spec = list(cv_config.models.values())[0] + pairs = _build_datasets_for_combo( + model_spec.train_datasets, "phase", "infection_state" + ) + ds_dicts = [p[1] for p in pairs] + assert not _check_class_safety(ds_dicts, "infection_state", 10000) + + +# --------------------------------------------------------------------------- +# Metric presence +# --------------------------------------------------------------------------- + + +class TestMetrics: + def test_auroc_computed(self, cv_results): + results_df, _ = cv_results + auroc_vals = results_df["auroc"].dropna() + assert len(auroc_vals) > 0, "No AUROC values computed" + assert (auroc_vals >= 0).all() and (auroc_vals <= 1).all(), ( + "AUROC out of [0, 1] range" + ) + + def test_minority_metrics(self, cv_results): + results_df, _ = cv_results + baseline = results_df[results_df["excluded_dataset"] == "baseline"] + has_minority = baseline["minority_f1"].notna().any() + if has_minority: + assert "minority_recall" in results_df.columns + assert "minority_precision" in results_df.columns + + def test_annotation_counts(self, cv_results): + results_df, _ = cv_results + class_cols = [c for c in results_df.columns if c.startswith("train_class_")] + assert len(class_cols) > 0, "No train class count columns found" + + def test_test_accuracy_present(self, cv_results): + results_df, _ = cv_results + baseline = results_df[results_df["excluded_dataset"] == "baseline"] + assert "test_accuracy" in baseline.columns + assert baseline["test_accuracy"].notna().any() + + +# --------------------------------------------------------------------------- +# CSV output +# --------------------------------------------------------------------------- + + +class TestCSVOutput: + def test_csv_output(self, cv_config, cv_results): + results_path = cv_config.output_dir / "cv_results.csv" + summary_path = cv_config.output_dir / "cv_summary.csv" + assert results_path.exists(), "cv_results.csv not written" + assert 
summary_path.exists(), "cv_summary.csv not written" + + results = pd.read_csv(results_path) + summary = pd.read_csv(summary_path) + assert len(results) > 0 + assert len(summary) > 0 + assert "excluded_dataset" in results.columns + assert "impact" in summary.columns + + +# --------------------------------------------------------------------------- +# No WandB calls +# --------------------------------------------------------------------------- + + +class TestNoWandB: + def test_no_wandb_calls(self, cv_config, monkeypatch): + """WandB should not be called when wandb_logging=False.""" + wandb_called = {"init": False} + + def mock_init(*args, **kwargs): + wandb_called["init"] = True + raise RuntimeError("wandb.init should not be called") + + monkeypatch.setattr( + "viscy.representation.evaluation.linear_classifier.wandb.init", + mock_init, + ) + + cv_config.wandb_logging = False + # This should not trigger wandb since CV doesn't call save_pipeline_to_wandb + results_df, _ = cross_validate_datasets( + cv_config, n_bootstrap=1, min_class_samples=5 + ) + assert not wandb_called["init"], "wandb.init was called during CV" + + +# --------------------------------------------------------------------------- +# Single dataset edge case +# --------------------------------------------------------------------------- + + +class TestSingleDataset: + def test_single_dataset_skips(self, cv_config): + """With only 1 training dataset, leave-one-out is not possible.""" + model_spec = list(cv_config.models.values())[0] + model_spec.train_datasets = [model_spec.train_datasets[0]] + + results_df, summary_df = cross_validate_datasets( + cv_config, n_bootstrap=2, min_class_samples=5 + ) + # Should produce empty results since we can't do leave-one-out with 1 dataset + assert results_df.empty + + +# --------------------------------------------------------------------------- +# PDF report +# --------------------------------------------------------------------------- + + +class TestPDFReport: + def test_report_generated(self, cv_config, cv_results): + results_df, summary_df = cv_results + report_path = generate_cv_report( + cv_config, results_df, summary_df, ranking_metric="auroc" + ) + assert report_path.exists() + assert report_path.suffix == ".pdf" + + +# --------------------------------------------------------------------------- +# _compute_summary unit test +# --------------------------------------------------------------------------- + + +class TestComputeSummary: + def test_empty_input(self): + result = _compute_summary(pd.DataFrame(), "auroc") + assert result.empty + + def test_correct_columns(self, cv_results): + _, summary_df = cv_results + expected_cols = { + "model", + "task", + "channel", + "excluded_dataset", + "mean_auroc", + "std_auroc", + "baseline_mean", + "delta", + "impact", + } + assert expected_cols.issubset(set(summary_df.columns)) diff --git a/tests/representation/evaluation/test_evaluate_dataset.py b/tests/representation/evaluation/test_evaluate_dataset.py new file mode 100644 index 000000000..67b3a5d52 --- /dev/null +++ b/tests/representation/evaluation/test_evaluate_dataset.py @@ -0,0 +1,401 @@ +"""Tests for the evaluation pipeline (evaluate_dataset.py).""" + +from unittest.mock import MagicMock, patch + +import joblib +import pandas as pd +import pytest +from applications.DynaCLR.evaluation.linear_classifiers.evaluate_dataset import ( + DatasetEvalConfig, + ModelSpec, + TrainDataset, + evaluate_predictions, + infer_classifiers, + train_classifiers, +) +from pydantic import ValidationError + 
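+# LinearClassifierPipeline is needed so the disk round-trip test can
+# isinstance-check the pipelines reloaded from the joblib files that
+# train_classifiers() writes.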
+from viscy.representation.evaluation.linear_classifier import ( + LinearClassifierPipeline, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _mock_wandb(): + """Return a patch context that mocks all wandb calls.""" + mock_run = MagicMock() + mock_run.summary = {} + mock_artifact = MagicMock() + mock_artifact.version = "v0" + mock_run.log_artifact.return_value = mock_artifact + + return patch.multiple( + "viscy.representation.evaluation.linear_classifier", + wandb=MagicMock( + init=MagicMock(return_value=mock_run), + Artifact=MagicMock(return_value=mock_artifact), + ), + ) + + +# --------------------------------------------------------------------------- +# Config validation +# --------------------------------------------------------------------------- + + +class TestDatasetEvalConfig: + def test_valid_config(self, eval_config): + assert eval_config.dataset_name == "test_dataset" + assert len(eval_config.models) == 1 + assert "test_model" in eval_config.models + + def test_missing_test_annotations(self, synthetic_train_data, tmp_path): + emb_dir = tmp_path / "fake_emb" + emb_dir.mkdir() + with pytest.raises(ValidationError, match="Test annotations CSV not found"): + DatasetEvalConfig( + dataset_name="test", + test_annotations_csv=tmp_path / "nonexistent.csv", + models={ + "m": ModelSpec( + name="M", + train_datasets=synthetic_train_data, + test_embeddings_dir=emb_dir, + version="v1", + wandb_project="p", + ) + }, + output_dir=tmp_path / "out", + ) + + def test_missing_embeddings_dir(self, synthetic_train_data, tmp_path): + csv = tmp_path / "ann.csv" + csv.write_text("fov_name,id,infection_state\nA/1/0,0,infected\n") + with pytest.raises( + ValidationError, match="Test embeddings directory not found" + ): + DatasetEvalConfig( + dataset_name="test", + test_annotations_csv=csv, + models={ + "m": ModelSpec( + name="M", + train_datasets=synthetic_train_data, + test_embeddings_dir=tmp_path / "nonexistent", + version="v1", + wandb_project="p", + ) + }, + output_dir=tmp_path / "out", + ) + + def test_missing_train_annotations(self, tmp_path): + emb_dir = tmp_path / "emb" + emb_dir.mkdir() + csv = tmp_path / "ann.csv" + csv.write_text("fov_name,id,infection_state\nA/1/0,0,infected\n") + with pytest.raises(ValidationError, match="Annotations file not found"): + DatasetEvalConfig( + dataset_name="test", + test_annotations_csv=csv, + models={ + "m": ModelSpec( + name="M", + train_datasets=[ + TrainDataset( + embeddings_dir=emb_dir, + annotations=tmp_path / "missing.csv", + ) + ], + test_embeddings_dir=emb_dir, + version="v1", + wandb_project="p", + ) + }, + output_dir=tmp_path / "out", + ) + + def test_invalid_task(self, synthetic_train_data, synthetic_test_data, tmp_path): + test_emb_dir, test_csv = synthetic_test_data + with pytest.raises(ValidationError, match="Invalid task"): + DatasetEvalConfig( + dataset_name="test", + test_annotations_csv=test_csv, + models={ + "m": ModelSpec( + name="M", + train_datasets=synthetic_train_data, + test_embeddings_dir=test_emb_dir, + version="v1", + wandb_project="p", + ) + }, + output_dir=tmp_path / "out", + task_channels={"not_a_valid_task": ["phase"]}, + ) + + def test_auto_detect_tasks( + self, synthetic_train_data, synthetic_test_data, tmp_path + ): + test_emb_dir, test_csv = synthetic_test_data + config = DatasetEvalConfig( + dataset_name="test", + test_annotations_csv=test_csv, + models={ + "m": ModelSpec( + name="M", + 
train_datasets=synthetic_train_data, + test_embeddings_dir=test_emb_dir, + version="v1", + wandb_project="p", + ) + }, + output_dir=tmp_path / "out", + task_channels=None, + ) + assert config.task_channels is None + + +# --------------------------------------------------------------------------- +# Train classifiers +# --------------------------------------------------------------------------- + + +class TestTrainClassifiers: + def test_trains_all_combinations(self, eval_config): + with _mock_wandb(): + results = train_classifiers(eval_config) + + model_results = results["test_model"] + expected_combos = { + ("infection_state", "phase"), + ("infection_state", "organelle"), + ("cell_division_state", "phase"), + ("cell_division_state", "organelle"), + } + assert set(model_results.keys()) == expected_combos + + def test_metrics_contain_val_keys(self, eval_config): + with _mock_wandb(): + results = train_classifiers(eval_config) + + for combo_key, result in results["test_model"].items(): + metrics = result["metrics"] + assert "val_accuracy" in metrics, f"Missing val_accuracy for {combo_key}" + assert "val_weighted_f1" in metrics, ( + f"Missing val_weighted_f1 for {combo_key}" + ) + + def test_pipeline_saved_to_disk(self, eval_config): + with _mock_wandb(): + train_classifiers(eval_config) + + model_dir = eval_config.output_dir / "test_model" + for task, channels in eval_config.task_channels.items(): + for channel in channels: + pipeline_path = model_dir / f"{task}_{channel}_pipeline.joblib" + assert pipeline_path.exists(), f"Missing {pipeline_path.name}" + pipeline = joblib.load(pipeline_path) + assert isinstance(pipeline, LinearClassifierPipeline) + + def test_metrics_csv_written(self, eval_config): + with _mock_wandb(): + train_classifiers(eval_config) + + model_csv = eval_config.output_dir / "test_model" / "metrics_summary.csv" + assert model_csv.exists() + df = pd.read_csv(model_csv) + assert "task" in df.columns + assert "channel" in df.columns + assert "val_accuracy" in df.columns + assert len(df) == 4 # 2 tasks x 2 channels + + def test_comparison_csv_written(self, eval_config): + with _mock_wandb(): + train_classifiers(eval_config) + + comparison_csv = eval_config.output_dir / "metrics_comparison.csv" + assert comparison_csv.exists() + df = pd.read_csv(comparison_csv) + assert "model" in df.columns + + def test_skips_missing_channel(self, eval_config): + eval_config.task_channels = { + "infection_state": ["phase", "organelle", "sensor"], + "cell_division_state": ["phase", "organelle", "sensor"], + } + with _mock_wandb(): + results = train_classifiers(eval_config) + + model_results = results["test_model"] + trained_channels = {ch for _, ch in model_results.keys()} + assert "sensor" not in trained_channels + assert "phase" in trained_channels + assert "organelle" in trained_channels + + def test_skips_missing_task(self, eval_config): + eval_config.task_channels = { + "infection_state": ["phase", "organelle"], + "cell_division_state": ["phase", "organelle"], + "cell_death_state": ["phase", "organelle"], + } + with _mock_wandb(): + results = train_classifiers(eval_config) + + model_results = results["test_model"] + trained_tasks = {t for t, _ in model_results.keys()} + assert "cell_death_state" not in trained_tasks + assert "infection_state" in trained_tasks + + def test_wandb_config_has_provenance(self, eval_config): + mock_run = MagicMock() + mock_run.summary = {} + mock_artifact = MagicMock() + mock_artifact.version = "v0" + mock_run.log_artifact.return_value = mock_artifact + + 
mock_wandb_module = MagicMock() + mock_wandb_module.init.return_value = mock_run + mock_wandb_module.Artifact.return_value = mock_artifact + + with patch.multiple( + "viscy.representation.evaluation.linear_classifier", + wandb=mock_wandb_module, + ): + train_classifiers(eval_config) + + calls = mock_wandb_module.init.call_args_list + assert len(calls) > 0 + for call in calls: + wandb_config = call.kwargs.get("config", call.args[0] if call.args else {}) + assert "embedding_model" in wandb_config + assert "test_dataset" in wandb_config + assert "train_dataset_names" in wandb_config + assert wandb_config["embedding_model"] == "TestModel-v1" + assert wandb_config["test_dataset"] == "test_dataset" + + +# --------------------------------------------------------------------------- +# Infer classifiers +# --------------------------------------------------------------------------- + + +class TestInferClassifiers: + def test_predictions_on_all_cells(self, eval_config): + with _mock_wandb(): + trained = train_classifiers(eval_config) + predictions = infer_classifiers(eval_config, trained=trained) + + for model_label, model_preds in predictions.items(): + for (task, channel), adata in model_preds.items(): + pred_col = f"predicted_{task}" + assert pred_col in adata.obs.columns + assert adata.obs[pred_col].notna().all(), ( + f"NaN predictions for {model_label}/{task}/{channel}" + ) + + def test_prediction_columns_exist(self, eval_config): + with _mock_wandb(): + trained = train_classifiers(eval_config) + predictions = infer_classifiers(eval_config, trained=trained) + + for model_preds in predictions.values(): + for (task, _), adata in model_preds.items(): + assert f"predicted_{task}" in adata.obs.columns + assert f"predicted_{task}_proba" in adata.obsm + assert f"predicted_{task}_classes" in adata.uns + + def test_predictions_zarr_saved(self, eval_config): + with _mock_wandb(): + trained = train_classifiers(eval_config) + infer_classifiers(eval_config, trained=trained) + + for task, channels in eval_config.task_channels.items(): + for channel in channels: + pred_path = ( + eval_config.output_dir + / "test_model" + / f"{task}_{channel}_predictions.zarr" + ) + assert pred_path.exists(), f"Missing {pred_path.name}" + + def test_loads_pipeline_from_disk(self, eval_config): + with _mock_wandb(): + train_classifiers(eval_config) + + predictions = infer_classifiers(eval_config, trained=None) + model_preds = predictions["test_model"] + assert len(model_preds) == 4 # 2 tasks x 2 channels + + def test_provenance_in_uns(self, eval_config): + with _mock_wandb(): + trained = train_classifiers(eval_config) + predictions = infer_classifiers(eval_config, trained=trained) + + for model_preds in predictions.values(): + for (task, _), adata in model_preds.items(): + assert f"classifier_{task}_artifact" in adata.uns + + +# --------------------------------------------------------------------------- +# Evaluate predictions +# --------------------------------------------------------------------------- + + +class TestEvaluatePredictions: + def test_test_metrics_computed(self, eval_config): + with _mock_wandb(): + trained = train_classifiers(eval_config) + predictions = infer_classifiers(eval_config, trained=trained) + eval_results = evaluate_predictions(eval_config, predictions=predictions) + + for model_eval in eval_results.values(): + for (task, channel), result in model_eval.items(): + metrics = result["metrics"] + assert "test_accuracy" in metrics + assert "test_weighted_f1" in metrics + assert "test_n_samples" in metrics 
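+                # test_n_samples should count only annotated (non-"unknown")
+                # cells; zero would mean the annotation/embedding join
+                # matched nothing.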
+ assert metrics["test_n_samples"] > 0 + + def test_annotated_zarr_saved(self, eval_config): + with _mock_wandb(): + trained = train_classifiers(eval_config) + predictions = infer_classifiers(eval_config, trained=trained) + evaluate_predictions(eval_config, predictions=predictions) + + for task, channels in eval_config.task_channels.items(): + for channel in channels: + annotated_path = ( + eval_config.output_dir + / "test_model" + / f"{task}_{channel}_annotated.zarr" + ) + assert annotated_path.exists(), f"Missing {annotated_path.name}" + + def test_annotated_has_both_columns(self, eval_config): + with _mock_wandb(): + trained = train_classifiers(eval_config) + predictions = infer_classifiers(eval_config, trained=trained) + eval_results = evaluate_predictions(eval_config, predictions=predictions) + + for model_eval in eval_results.values(): + for (task, _), result in model_eval.items(): + adata = result["annotated_adata"] + assert task in adata.obs.columns + assert f"predicted_{task}" in adata.obs.columns + + def test_metrics_comparison_csv(self, eval_config): + with _mock_wandb(): + trained = train_classifiers(eval_config) + predictions = infer_classifiers(eval_config, trained=trained) + evaluate_predictions(eval_config, predictions=predictions) + + csv_path = eval_config.output_dir / "test_metrics_comparison.csv" + assert csv_path.exists() + df = pd.read_csv(csv_path) + assert "model" in df.columns + assert "test_accuracy" in df.columns + assert len(df) == 4 # 2 tasks x 2 channels diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 12d732873..cade82ed6 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -207,11 +207,11 @@ class EmbeddingWriter(BasePredictionWriter): write_interval : Literal["batch", "epoch", "batch_and_epoch"], optional When to write the embeddings, by default 'epoch'. umap_kwargs : dict, optional - Keyword arguments passed to UMAP, by default None (i.e. UMAP is not computed). + Keyword arguments passed to UMAP, by default UMAP is not computed. phate_kwargs : dict, optional - Keyword arguments passed to PHATE, by default PHATE is computed with default parameters. + Keyword arguments passed to PHATE, by default PHATE is not computed. pca_kwargs : dict, optional - Keyword arguments passed to PCA, by default PCA is computed with default parameters. + Keyword arguments passed to PCA, by default PCA is not computed. """ def __init__( @@ -219,12 +219,8 @@ def __init__( output_path: Path, write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "epoch", umap_kwargs: dict | None = None, - phate_kwargs: dict | None = { - "knn": 5, - "decay": 40, - "n_jobs": -1, - }, - pca_kwargs: dict | None = {"n_components": 8}, + phate_kwargs: dict | None = None, + pca_kwargs: dict | None = None, overwrite: bool = False, ): super().__init__(write_interval)