Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion core/initial_selection_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging
from abc import ABC, abstractmethod

import kmedoids
import numpy as np
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.metrics import pairwise_distances, pairwise_distances_argmin_min
Expand Down Expand Up @@ -142,6 +141,19 @@ def __init__(
max_iter: int = 100,
init: str = "build",
) -> None:
# Imported lazily so that environments without the optional
# ``kmedoids`` package can still use the other strategies. Hydra
# only constructs this class when it's actually selected, so the
# ImportError surfaces immediately and only for users who opted
# in to k-medoids.
try:
import kmedoids # noqa: F401 (probe install, used in _kmedoids)
except ImportError as exc: # pragma: no cover - exercised only without dep
raise ImportError(
"KMedoidsInitialSelection requires the 'kmedoids' package. "
"Install it with `uv pip install kmedoids` (or `uv sync`)."
) from exc

super().__init__("KMEDOIDS", starting_batch_size=starting_batch_size)
self.seed = seed
self.metric = metric
Expand All @@ -163,6 +175,8 @@ def select(
return selected

def _kmedoids(self, embeddings: np.ndarray) -> list[int]:
import kmedoids # already validated in __init__

num_samples = embeddings.shape[0]
k = min(self.starting_batch_size, num_samples)
if k == 0:
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ classifiers = [
dependencies = [
"botorch>=0.16.1",
"hydra-core>=1.3.2",
"kmedoids>=0.5.1",
"numpy>=1.21.0",
"pandas>=1.3.0",
"scikit-learn>=1.6.0",
Expand All @@ -42,6 +41,9 @@ cluster = [
"hydra-submitit-launcher>=1.2.0",
"submitit>=1.5.3",
]
kmedoids = [
"kmedoids>=0.5.1",
]

[project.urls]
Homepage = "https://github.com/cellethology/deepdraw"
Expand Down
8 changes: 5 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading