diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 6c2d243d9..00dea509a 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -7,7 +7,6 @@ from lightning.pytorch import LightningModule from pytorch_metric_learning.losses import NTXentLoss from torch import Tensor, nn -from umap import UMAP from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder @@ -183,6 +182,8 @@ def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]): output_list.extend(detach_sample(samples, self.log_samples_per_batch)) def log_embedding_umap(self, embeddings: Tensor, tag: str): + from umap import UMAP + _logger.debug(f"Computing UMAP for {tag} embeddings.") umap = UMAP(n_components=2) embeddings_np = embeddings.detach().cpu().numpy()