Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ def format_predictions_markdown(adata, task: str) -> str:
lines.append(f"**Classes:** {', '.join(adata.uns[classes_key])}")
lines.append("")

artifact_key = f"classifier_{task}_artifact"
if artifact_key in adata.uns.keys():
lines.append("### Classifier Provenance")
lines.append("")
lines.append(f"- **Artifact:** {adata.uns[artifact_key]}")
id_key = f"classifier_{task}_id"
if id_key in adata.uns.keys():
lines.append(f"- **Artifact ID:** {adata.uns[id_key]}")
version_key = f"classifier_{task}_version"
if version_key in adata.uns.keys():
lines.append(f"- **Artifact Version:** {adata.uns[version_key]}")
lines.append("")

return "\n".join(lines)


Expand Down Expand Up @@ -88,36 +101,52 @@ def main(config: Path):
click.echo(f"\n❌ Failed to load configuration: {e}", err=True)
raise click.Abort()

write_path = (
Path(inference_config.output_path)
if inference_config.output_path is not None
else Path(inference_config.embeddings_path)
)

click.echo(f"\n✓ Configuration loaded: {config}")
click.echo(f" Model: {inference_config.model_name}")
click.echo(f" Version: {inference_config.version}")
click.echo(f" Embeddings: {inference_config.embeddings_path}")
click.echo(f" Output: {inference_config.output_path}")
click.echo(f" Output: {write_path}")

try:
pipeline, loaded_config = load_pipeline_from_wandb(
pipeline, loaded_config, artifact_metadata = load_pipeline_from_wandb(
wandb_project=inference_config.wandb_project,
model_name=inference_config.model_name,
version=inference_config.version,
wandb_entity=inference_config.wandb_entity,
)

task = loaded_config["task"]
marker = loaded_config.get("marker")
task_key = f"{task}_{marker}" if marker else task

click.echo(f"\nLoading embeddings from: {inference_config.embeddings_path}")
adata = read_zarr(inference_config.embeddings_path)
click.echo(f"✓ Loaded embeddings: {adata.shape}")

adata = predict_with_classifier(adata, pipeline, task)
if inference_config.include_wells:
click.echo(f" Well filter: {inference_config.include_wells}")

adata = predict_with_classifier(
adata,
pipeline,
task_key,
artifact_metadata=artifact_metadata,
include_wells=inference_config.include_wells,
)

output_path = Path(inference_config.output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
write_path.parent.mkdir(parents=True, exist_ok=True)

click.echo(f"\nSaving predictions to: {output_path}")
adata.write_zarr(output_path)
click.echo(f"\nSaving predictions to: {write_path}")
adata.write_zarr(write_path)
click.echo("✓ Saved predictions")

click.echo("\n" + format_predictions_markdown(adata, task))
click.echo("\n" + format_predictions_markdown(adata, task_key))

click.echo("\n✓ Inference complete!")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,16 @@ wandb_entity: null
# Path to embeddings zarr file for inference
embeddings_path: /path/to/embeddings.zarr

# Path to save output zarr file with predictions
output_path: /path/to/output_with_predictions.zarr
# Path to save output zarr file with predictions.
# When omitted (or null), predictions are written back to embeddings_path.
# output_path: /path/to/output_with_predictions.zarr

# Whether to overwrite output if it already exists
# Well prefixes to restrict predictions to (optional).
# When set, only cells in these wells are predicted; cells in other wells get NaN.
# When omitted, all cells are predicted.
# Useful for organelle-specific classifiers where different wells have different markers.
# include_wells:
# - A/1
# - A/2

# Whether to overwrite output if it already exists (only used when output_path is set)
overwrite: false
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def main(config: Path):
click.echo(f"\n✓ Configuration loaded: {config}")
click.echo(f" Task: {train_config.task}")
click.echo(f" Input channel: {train_config.input_channel}")
if train_config.marker:
click.echo(f" Marker: {train_config.marker}")
click.echo(f" Model: {train_config.embedding_model}")
click.echo(f" Datasets: {len(train_config.train_datasets)}")

Expand Down
82 changes: 82 additions & 0 deletions tests/representation/evaluation/test_linear_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,57 @@ def test_predict_adds_uns_classes(self, pipeline_and_adata):
pipeline.classifier.classes_
)

def test_predict_stores_provenance(self, pipeline_and_adata):
    """Provenance keys from artifact metadata are written into ``.uns``."""
    pipeline, adata = pipeline_and_adata
    provenance = {
        "artifact_name": "linear-classifier-cell_death_state-phase:v2",
        "artifact_id": "abc123",
        "artifact_version": "v2",
    }
    annotated = predict_with_classifier(
        adata.copy(), pipeline, "cell_death_state", artifact_metadata=provenance
    )
    expected_uns = {
        "classifier_cell_death_state_artifact": "linear-classifier-cell_death_state-phase:v2",
        "classifier_cell_death_state_id": "abc123",
        "classifier_cell_death_state_version": "v2",
    }
    for key, value in expected_uns.items():
        assert annotated.uns[key] == value

def test_predict_no_provenance_by_default(self, pipeline_and_adata):
    """Without artifact metadata, no provenance keys appear in ``.uns``."""
    pipeline, adata = pipeline_and_adata
    annotated = predict_with_classifier(adata.copy(), pipeline, "cell_death_state")
    for suffix in ("artifact", "id", "version"):
        assert f"classifier_cell_death_state_{suffix}" not in annotated.uns

def test_predict_with_include_wells(self, pipeline_and_adata):
    """Only cells in the requested wells get predictions; the rest get NaN."""
    pipeline, adata = pipeline_and_adata
    annotated = predict_with_classifier(
        adata.copy(), pipeline, "cell_death_state", include_wells=["A/1"]
    )

    in_well = annotated.obs["fov_name"].str.startswith("A/1/")
    labels = annotated.obs["predicted_cell_death_state"]
    assert labels[in_well].notna().all()
    assert labels[~in_well].isna().all()

    probabilities = annotated.obsm["predicted_cell_death_state_proba"]
    assert np.isfinite(probabilities[in_well]).all()
    assert np.isnan(probabilities[~in_well]).all()

def test_predict_marker_namespaced_task(self, pipeline_and_adata):
    """Marker-suffixed task names propagate into the obs/obsm/uns keys."""
    pipeline, adata = pipeline_and_adata
    annotated = predict_with_classifier(
        adata.copy(), pipeline, "organelle_state_g3bp1", include_wells=["A/1"]
    )
    assert "predicted_organelle_state_g3bp1" in annotated.obs.columns
    assert "predicted_organelle_state_g3bp1_proba" in annotated.obsm
    assert "predicted_organelle_state_g3bp1_classes" in annotated.uns


class TestLoadAndCombineDatasets:
"""Tests for the load_and_combine_datasets function."""
Expand Down Expand Up @@ -434,3 +485,34 @@ def test_output_exists_with_overwrite(self, tmp_path):
overwrite=True,
)
assert config.overwrite is True

def test_output_path_none_defaults_to_inplace(self, tmp_path):
    """``output_path`` defaults to None (caller writes back to embeddings_path)."""
    embeddings_dir = tmp_path / "emb.zarr"
    embeddings_dir.mkdir()
    config = LinearClassifierInferenceConfig(
        wandb_project="test_project",
        model_name="test_model",
        embeddings_path=str(embeddings_dir),
    )
    assert config.output_path is None

def test_include_wells(self, tmp_path):
    """A configured well list round-trips through the config unchanged."""
    embeddings_dir = tmp_path / "emb.zarr"
    embeddings_dir.mkdir()
    wells = ["A/1", "B/2"]
    config = LinearClassifierInferenceConfig(
        wandb_project="test_project",
        model_name="test_model",
        embeddings_path=str(embeddings_dir),
        include_wells=wells,
    )
    assert config.include_wells == wells

def test_include_wells_none_by_default(self, tmp_path):
    """``include_wells`` defaults to None (all cells are predicted)."""
    embeddings_dir = tmp_path / "emb.zarr"
    embeddings_dir.mkdir()
    config = LinearClassifierInferenceConfig(
        wandb_project="test_project",
        model_name="test_model",
        embeddings_path=str(embeddings_dir),
    )
    assert config.include_wells is None
70 changes: 59 additions & 11 deletions viscy/representation/evaluation/linear_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ def predict_with_classifier(
adata: ad.AnnData,
pipeline: LinearClassifierPipeline,
task: str,
artifact_metadata: Optional[dict] = None,
include_wells: Optional[list[str]] = None,
) -> ad.AnnData:
"""Apply trained classifier to make predictions on new data.

Expand All @@ -371,7 +373,16 @@ def predict_with_classifier(
pipeline : LinearClassifierPipeline
Trained classifier pipeline with preprocessing.
task : str
Name of the classification task.
Name of the classification task (used as column suffix).
artifact_metadata : Optional[dict]
W&B artifact metadata from ``load_pipeline_from_wandb``. When provided,
provenance keys are stored in ``adata.uns`` under
``classifier_{task}_artifact``, ``classifier_{task}_id``, and
``classifier_{task}_version``.
include_wells : Optional[list[str]]
Well prefixes to restrict prediction to (e.g. ``["A/1", "B/2"]``).
Cells in other wells will have ``NaN`` for prediction columns.
When ``None``, all cells are predicted.

Returns
-------
Expand All @@ -381,19 +392,44 @@ def predict_with_classifier(
and class labels in .uns[f"predicted_{task}_classes"].
"""
print("\nApplying preprocessing and making predictions...")
X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()

predictions = pipeline.predict(X)
prediction_proba = pipeline.predict_proba(X)
if include_wells is not None:
well_mask = adata.obs["fov_name"].str.startswith(
tuple(w + "/" for w in include_wells)
)
n_matched = well_mask.sum()
print(f" Well filter: {include_wells} -> {n_matched}/{len(adata)} cells")
else:
well_mask = np.ones(len(adata), dtype=bool)

X_full = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()
X_subset = X_full[well_mask]

predictions_subset = pipeline.predict(X_subset)
proba_subset = pipeline.predict_proba(X_subset)
n_classes = proba_subset.shape[1]

adata.obs[f"predicted_{task}"] = predictions
adata.obsm[f"predicted_{task}_proba"] = prediction_proba
all_predictions = np.full(len(adata), np.nan, dtype=object)
all_predictions[well_mask] = predictions_subset

all_proba = np.full((len(adata), n_classes), np.nan)
all_proba[well_mask] = proba_subset

adata.obs[f"predicted_{task}"] = all_predictions
adata.obsm[f"predicted_{task}_proba"] = all_proba
adata.uns[f"predicted_{task}_classes"] = pipeline.classifier.classes_.tolist()

if artifact_metadata is not None:
adata.uns[f"classifier_{task}_artifact"] = artifact_metadata["artifact_name"]
adata.uns[f"classifier_{task}_id"] = artifact_metadata["artifact_id"]
adata.uns[f"classifier_{task}_version"] = artifact_metadata["artifact_version"]

predicted_values = adata.obs[f"predicted_{task}"].dropna()
print("✓ Predictions complete")
print(f" Predicted {len(predicted_values)}/{len(adata)} cells")
print(" Predicted class distribution:")
print(adata.obs[f"predicted_{task}"].value_counts())
print(f" Probability matrix shape: {prediction_proba.shape}")
print(predicted_values.value_counts())
print(f" Probability matrix shape: {all_proba.shape}")
print(f" Classes: {pipeline.classifier.classes_.tolist()}")

return adata
Expand Down Expand Up @@ -435,17 +471,21 @@ def save_pipeline_to_wandb(

task = config["task"]
input_channel = config["input_channel"]
marker = config.get("marker")
use_pca = config.get("preprocessing", {}).get("use_pca", False)
n_pca = config.get("preprocessing", {}).get("n_pca_components")

model_name = f"linear-classifier-{task}-{input_channel}"
if marker:
model_name += f"-{marker}"
if use_pca:
model_name += f"-pca{n_pca}"

run = wandb.init(
project=wandb_project,
entity=wandb_entity,
job_type=f"linear-classifier-{task}-{input_channel}",
job_type=f"linear-classifier-{task}-{input_channel}"
+ (f"-{marker}" if marker else ""),
name=model_name,
group=model_name,
config=config,
Expand Down Expand Up @@ -503,7 +543,7 @@ def load_pipeline_from_wandb(
model_name: str,
version: str = "latest",
wandb_entity: Optional[str] = None,
) -> tuple[LinearClassifierPipeline, dict]:
) -> tuple[LinearClassifierPipeline, dict, dict]:
"""Load trained pipeline and config from Weights & Biases.

Parameters
Expand All @@ -523,6 +563,9 @@ def load_pipeline_from_wandb(
Loaded classifier pipeline.
dict
Configuration used for training.
dict
Artifact metadata with keys ``artifact_name``, ``artifact_id``,
and ``artifact_version``.
"""
print("\n" + "=" * 60)
print("LOADING MODEL FROM WANDB")
Expand All @@ -535,6 +578,11 @@ def load_pipeline_from_wandb(
)

artifact = run.use_artifact(f"{model_name}:{version}")
artifact_metadata = {
"artifact_name": f"{model_name}:{artifact.version}",
"artifact_id": artifact.id,
"artifact_version": artifact.version,
}
artifact_dir = Path(artifact.download())

config_path = artifact_dir / f"{model_name}_config.json"
Expand Down Expand Up @@ -573,4 +621,4 @@ def load_pipeline_from_wandb(

run.finish()

return pipeline, config
return pipeline, config, artifact_metadata
Loading