diff --git a/applications/DynaCLR/evaluation/linear_classifiers/apply_linear_classifier.py b/applications/DynaCLR/evaluation/linear_classifiers/apply_linear_classifier.py index e86bea050..c62927984 100644 --- a/applications/DynaCLR/evaluation/linear_classifiers/apply_linear_classifier.py +++ b/applications/DynaCLR/evaluation/linear_classifiers/apply_linear_classifier.py @@ -61,6 +61,19 @@ def format_predictions_markdown(adata, task: str) -> str: lines.append(f"**Classes:** {', '.join(adata.uns[classes_key])}") lines.append("") + artifact_key = f"classifier_{task}_artifact" + if artifact_key in adata.uns.keys(): + lines.append("### Classifier Provenance") + lines.append("") + lines.append(f"- **Artifact:** {adata.uns[artifact_key]}") + id_key = f"classifier_{task}_id" + if id_key in adata.uns.keys(): + lines.append(f"- **Artifact ID:** {adata.uns[id_key]}") + version_key = f"classifier_{task}_version" + if version_key in adata.uns.keys(): + lines.append(f"- **Artifact Version:** {adata.uns[version_key]}") + lines.append("") + return "\n".join(lines) @@ -88,14 +101,20 @@ def main(config: Path): click.echo(f"\n❌ Failed to load configuration: {e}", err=True) raise click.Abort() + write_path = ( + Path(inference_config.output_path) + if inference_config.output_path is not None + else Path(inference_config.embeddings_path) + ) + click.echo(f"\n✓ Configuration loaded: {config}") click.echo(f" Model: {inference_config.model_name}") click.echo(f" Version: {inference_config.version}") click.echo(f" Embeddings: {inference_config.embeddings_path}") - click.echo(f" Output: {inference_config.output_path}") + click.echo(f" Output: {write_path}") try: - pipeline, loaded_config = load_pipeline_from_wandb( + pipeline, loaded_config, artifact_metadata = load_pipeline_from_wandb( wandb_project=inference_config.wandb_project, model_name=inference_config.model_name, version=inference_config.version, @@ -103,21 +122,31 @@ def main(config: Path): ) task = loaded_config["task"] + marker = loaded_config.get("marker") + task_key = f"{task}_{marker}" if marker else task click.echo(f"\nLoading embeddings from: {inference_config.embeddings_path}") adata = read_zarr(inference_config.embeddings_path) click.echo(f"✓ Loaded embeddings: {adata.shape}") - adata = predict_with_classifier(adata, pipeline, task) + if inference_config.include_wells: + click.echo(f" Well filter: {inference_config.include_wells}") + + adata = predict_with_classifier( + adata, + pipeline, + task_key, + artifact_metadata=artifact_metadata, + include_wells=inference_config.include_wells, + ) - output_path = Path(inference_config.output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) + write_path.parent.mkdir(parents=True, exist_ok=True) - click.echo(f"\nSaving predictions to: {output_path}") - adata.write_zarr(output_path) + click.echo(f"\nSaving predictions to: {write_path}") + adata.write_zarr(write_path) click.echo("✓ Saved predictions") - click.echo("\n" + format_predictions_markdown(adata, task)) + click.echo("\n" + format_predictions_markdown(adata, task_key)) click.echo("\n✓ Inference complete!") diff --git a/applications/DynaCLR/evaluation/linear_classifiers/configs/example_linear_classifier_inference.yaml b/applications/DynaCLR/evaluation/linear_classifiers/configs/example_linear_classifier_inference.yaml index a6e882365..7f7fe16a6 100644 --- a/applications/DynaCLR/evaluation/linear_classifiers/configs/example_linear_classifier_inference.yaml +++ b/applications/DynaCLR/evaluation/linear_classifiers/configs/example_linear_classifier_inference.yaml @@ -21,8 +21,16 @@ wandb_entity: null # Path to embeddings zarr file for inference embeddings_path: /path/to/embeddings.zarr -# Path to save output zarr file with predictions -output_path: /path/to/output_with_predictions.zarr +# Path to save output zarr file with predictions. +# When omitted (or null), predictions are written back to embeddings_path. +# output_path: /path/to/output_with_predictions.zarr -# Whether to overwrite output if it already exists +# Well prefixes to restrict predictions to (optional). +# When omitted, all cells are predicted. Cells in other wells get NaN. +# Useful for organelle-specific classifiers where different wells have different markers. +# include_wells: +# - A/1 +# - A/2 + +# Whether to overwrite output if it already exists (only used when output_path is set) overwrite: false diff --git a/applications/DynaCLR/evaluation/linear_classifiers/train_linear_classifier.py b/applications/DynaCLR/evaluation/linear_classifiers/train_linear_classifier.py index c554f5be9..d89f12a53 100644 --- a/applications/DynaCLR/evaluation/linear_classifiers/train_linear_classifier.py +++ b/applications/DynaCLR/evaluation/linear_classifiers/train_linear_classifier.py @@ -84,6 +84,8 @@ def main(config: Path): click.echo(f"\n✓ Configuration loaded: {config}") click.echo(f" Task: {train_config.task}") click.echo(f" Input channel: {train_config.input_channel}") + if train_config.marker: + click.echo(f" Marker: {train_config.marker}") click.echo(f" Model: {train_config.embedding_model}") click.echo(f" Datasets: {len(train_config.train_datasets)}") diff --git a/tests/representation/evaluation/test_linear_classifier.py b/tests/representation/evaluation/test_linear_classifier.py index e9a6975ab..6cf816b34 100644 --- a/tests/representation/evaluation/test_linear_classifier.py +++ b/tests/representation/evaluation/test_linear_classifier.py @@ -195,6 +195,57 @@ def test_predict_adds_uns_classes(self, pipeline_and_adata): pipeline.classifier.classes_ ) + def test_predict_stores_provenance(self, pipeline_and_adata): + pipeline, adata = pipeline_and_adata + metadata = { + "artifact_name": "linear-classifier-cell_death_state-phase:v2", + "artifact_id": "abc123", + "artifact_version": "v2", + } + result = predict_with_classifier( + adata.copy(), pipeline, "cell_death_state", artifact_metadata=metadata + ) + assert ( + result.uns["classifier_cell_death_state_artifact"] + == "linear-classifier-cell_death_state-phase:v2" + ) + assert result.uns["classifier_cell_death_state_id"] == "abc123" + assert result.uns["classifier_cell_death_state_version"] == "v2" + + def test_predict_no_provenance_by_default(self, pipeline_and_adata): + pipeline, adata = pipeline_and_adata + result = predict_with_classifier(adata.copy(), pipeline, "cell_death_state") + assert "classifier_cell_death_state_artifact" not in result.uns + assert "classifier_cell_death_state_id" not in result.uns + assert "classifier_cell_death_state_version" not in result.uns + + def test_predict_with_include_wells(self, pipeline_and_adata): + pipeline, adata = pipeline_and_adata + data = adata.copy() + result = predict_with_classifier( + data, pipeline, "cell_death_state", include_wells=["A/1"] + ) + well_mask = result.obs["fov_name"].str.startswith("A/1/") + predicted = result.obs["predicted_cell_death_state"] + assert predicted[well_mask].notna().all() + assert predicted[~well_mask].isna().all() + + proba = result.obsm["predicted_cell_death_state_proba"] + assert np.isfinite(proba[well_mask]).all() + assert np.isnan(proba[~well_mask]).all() + + def test_predict_marker_namespaced_task(self, pipeline_and_adata): + pipeline, adata = pipeline_and_adata + result = predict_with_classifier( + adata.copy(), + pipeline, + "organelle_state_g3bp1", + include_wells=["A/1"], + ) + assert "predicted_organelle_state_g3bp1" in result.obs.columns + assert "predicted_organelle_state_g3bp1_proba" in result.obsm + assert "predicted_organelle_state_g3bp1_classes" in result.uns + class TestLoadAndCombineDatasets: """Tests for the load_and_combine_datasets function.""" @@ -434,3 +485,34 @@ def test_output_exists_with_overwrite(self, tmp_path): overwrite=True, ) assert config.overwrite is True + + def test_output_path_none_defaults_to_inplace(self, tmp_path): + emb = tmp_path / "emb.zarr" + emb.mkdir() + config = LinearClassifierInferenceConfig( + wandb_project="test_project", + model_name="test_model", + embeddings_path=str(emb), + ) + assert config.output_path is None + + def test_include_wells(self, tmp_path): + emb = tmp_path / "emb.zarr" + emb.mkdir() + config = LinearClassifierInferenceConfig( + wandb_project="test_project", + model_name="test_model", + embeddings_path=str(emb), + include_wells=["A/1", "B/2"], + ) + assert config.include_wells == ["A/1", "B/2"] + + def test_include_wells_none_by_default(self, tmp_path): + emb = tmp_path / "emb.zarr" + emb.mkdir() + config = LinearClassifierInferenceConfig( + wandb_project="test_project", + model_name="test_model", + embeddings_path=str(emb), + ) + assert config.include_wells is None diff --git a/viscy/representation/evaluation/linear_classifier.py b/viscy/representation/evaluation/linear_classifier.py index 84a2730d3..7b9027800 100644 --- a/viscy/representation/evaluation/linear_classifier.py +++ b/viscy/representation/evaluation/linear_classifier.py @@ -361,6 +361,8 @@ def predict_with_classifier( adata: ad.AnnData, pipeline: LinearClassifierPipeline, task: str, + artifact_metadata: Optional[dict] = None, + include_wells: Optional[list[str]] = None, ) -> ad.AnnData: """Apply trained classifier to make predictions on new data. @@ -371,7 +373,16 @@ def predict_with_classifier( pipeline : LinearClassifierPipeline Trained classifier pipeline with preprocessing. task : str - Name of the classification task. + Name of the classification task (used as column suffix). + artifact_metadata : Optional[dict] + W&B artifact metadata from ``load_pipeline_from_wandb``. When provided, + provenance keys are stored in ``adata.uns`` under + ``classifier_{task}_artifact``, ``classifier_{task}_id``, and + ``classifier_{task}_version``. + include_wells : Optional[list[str]] + Well prefixes to restrict prediction to (e.g. ``["A/1", "B/2"]``). + Cells in other wells will have ``NaN`` for prediction columns. + When ``None``, all cells are predicted. Returns ------- @@ -381,19 +392,44 @@ def predict_with_classifier( and class labels in .uns[f"predicted_{task}_classes"]. """ print("\nApplying preprocessing and making predictions...") - X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray() - predictions = pipeline.predict(X) - prediction_proba = pipeline.predict_proba(X) + if include_wells is not None: + well_mask = adata.obs["fov_name"].str.startswith( + tuple(w + "/" for w in include_wells) + ) + n_matched = well_mask.sum() + print(f" Well filter: {include_wells} -> {n_matched}/{len(adata)} cells") + else: + well_mask = np.ones(len(adata), dtype=bool) + + X_full = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray() + X_subset = X_full[well_mask] + + predictions_subset = pipeline.predict(X_subset) + proba_subset = pipeline.predict_proba(X_subset) + n_classes = proba_subset.shape[1] - adata.obs[f"predicted_{task}"] = predictions - adata.obsm[f"predicted_{task}_proba"] = prediction_proba + all_predictions = np.full(len(adata), np.nan, dtype=object) + all_predictions[well_mask] = predictions_subset + + all_proba = np.full((len(adata), n_classes), np.nan) + all_proba[well_mask] = proba_subset + + adata.obs[f"predicted_{task}"] = all_predictions + adata.obsm[f"predicted_{task}_proba"] = all_proba adata.uns[f"predicted_{task}_classes"] = pipeline.classifier.classes_.tolist() + if artifact_metadata is not None: + adata.uns[f"classifier_{task}_artifact"] = artifact_metadata["artifact_name"] + adata.uns[f"classifier_{task}_id"] = artifact_metadata["artifact_id"] + adata.uns[f"classifier_{task}_version"] = artifact_metadata["artifact_version"] + + predicted_values = adata.obs[f"predicted_{task}"].dropna() print("✓ Predictions complete") + print(f" Predicted {len(predicted_values)}/{len(adata)} cells") print(" Predicted class distribution:") - print(adata.obs[f"predicted_{task}"].value_counts()) - print(f" Probability matrix shape: {prediction_proba.shape}") + print(predicted_values.value_counts()) + print(f" Probability matrix shape: {all_proba.shape}") print(f" Classes: {pipeline.classifier.classes_.tolist()}") return adata @@ -435,17 +471,21 @@ def save_pipeline_to_wandb( task = config["task"] input_channel = config["input_channel"] + marker = config.get("marker") use_pca = config.get("preprocessing", {}).get("use_pca", False) n_pca = config.get("preprocessing", {}).get("n_pca_components") model_name = f"linear-classifier-{task}-{input_channel}" + if marker: + model_name += f"-{marker}" if use_pca: model_name += f"-pca{n_pca}" run = wandb.init( project=wandb_project, entity=wandb_entity, - job_type=f"linear-classifier-{task}-{input_channel}", + job_type=f"linear-classifier-{task}-{input_channel}" + + (f"-{marker}" if marker else ""), name=model_name, group=model_name, config=config, @@ -503,7 +543,7 @@ def load_pipeline_from_wandb( model_name: str, version: str = "latest", wandb_entity: Optional[str] = None, -) -> tuple[LinearClassifierPipeline, dict]: +) -> tuple[LinearClassifierPipeline, dict, dict]: """Load trained pipeline and config from Weights & Biases. Parameters @@ -523,6 +563,9 @@ def load_pipeline_from_wandb( Loaded classifier pipeline. dict Configuration used for training. + dict + Artifact metadata with keys ``artifact_name``, ``artifact_id``, + and ``artifact_version``. """ print("\n" + "=" * 60) print("LOADING MODEL FROM WANDB") @@ -535,6 +578,11 @@ def load_pipeline_from_wandb( ) artifact = run.use_artifact(f"{model_name}:{version}") + artifact_metadata = { + "artifact_name": f"{model_name}:{artifact.version}", + "artifact_id": artifact.id, + "artifact_version": artifact.version, + } artifact_dir = Path(artifact.download()) config_path = artifact_dir / f"{model_name}_config.json" @@ -573,4 +621,4 @@ def load_pipeline_from_wandb( run.finish() - return pipeline, config + return pipeline, config, artifact_metadata diff --git a/viscy/representation/evaluation/linear_classifier_config.py b/viscy/representation/evaluation/linear_classifier_config.py index c1c4812d3..170c8b8e6 100644 --- a/viscy/representation/evaluation/linear_classifier_config.py +++ b/viscy/representation/evaluation/linear_classifier_config.py @@ -57,6 +57,10 @@ class LinearClassifierTrainConfig(BaseModel): # Task metadata task: VALID_TASKS = Field(...) input_channel: VALID_CHANNELS = Field(...) + marker: Optional[str] = Field( + default=None, + description="Marker name for marker-specific tasks (e.g. g3bp1, sec61b, tomm20).", + ) embedding_model: str = Field(..., min_length=1) # Training datasets @@ -138,8 +142,13 @@ class LinearClassifierInferenceConfig(BaseModel): W&B entity (username or team). embeddings_path : str Path to embeddings zarr file for inference. - output_path : str - Path to save output zarr file with predictions. + output_path : Optional[str] + Path to save output zarr file with predictions. When ``None`` + (the default), predictions are written back to ``embeddings_path``. + include_wells : Optional[list[str]] + Well prefixes to restrict prediction to (e.g. ``["A/1", "B/2"]``). + Cells in other wells will have ``NaN`` for prediction columns. + When ``None`` (the default), all cells are predicted. overwrite : bool Whether to overwrite output if it exists. """ @@ -149,12 +158,11 @@ class LinearClassifierInferenceConfig(BaseModel): version: str = Field(default="latest", min_length=1) wandb_entity: Optional[str] = Field(default=None) embeddings_path: str = Field(..., min_length=1) - output_path: str = Field(..., min_length=1) + output_path: Optional[str] = Field(default=None) + include_wells: Optional[list[str]] = Field(default=None) overwrite: bool = Field(default=False) - @field_validator( - "wandb_project", "model_name", "version", "embeddings_path", "output_path" - ) + @field_validator("wandb_project", "model_name", "version", "embeddings_path") @classmethod def validate_non_empty(cls, v: str) -> str: """Ensure string fields are non-empty.""" @@ -166,14 +174,15 @@ def validate_non_empty(cls, v: str) -> str: def validate_paths(self): """Validate input exists and output doesn't exist unless overwrite=True.""" embeddings_path = Path(self.embeddings_path) - output_path = Path(self.output_path) if not embeddings_path.exists(): raise ValueError(f"Embeddings file not found: {self.embeddings_path}") - if output_path.exists() and not self.overwrite: - raise ValueError( - f"Output file already exists: {self.output_path}. " - f"Set overwrite=true to overwrite." - ) + if self.output_path is not None: + output_path = Path(self.output_path) + if output_path.exists() and not self.overwrite: + raise ValueError( + f"Output file already exists: {self.output_path}. " + f"Set overwrite=true to overwrite." + ) return self