diff --git a/src/cellmap_data/sampler.py b/src/cellmap_data/sampler.py index 426d829..477fa73 100644 --- a/src/cellmap_data/sampler.py +++ b/src/cellmap_data/sampler.py @@ -24,8 +24,9 @@ class ClassBalancedSampler(Sampler): ``dataset.get_crop_class_matrix()`` → ``bool[n_crops, n_classes]``. 2. Maintain running counts of how many times each class has been seen. 3. At each step: pick the class with the lowest count (ties broken - randomly), sample a crop annotating it, yield that crop index, then - increment counts for *all* classes that crop annotates. + randomly), sample a matrix row (crop) annotating it, map that row to + an actual dataset sample index, and yield the sample index. Then + increment counts for *all* classes that row annotates. This guarantees rare classes get sampled as often as common ones. @@ -64,6 +65,12 @@ def __init__( self.class_to_crops[c] = indices self.active_classes: list[int] = sorted(self.class_to_crops.keys()) + if not self.active_classes: + raise ValueError( + "ClassBalancedSampler: no active classes found in crop-class " + "matrix. This can occur when all requested classes are only " + "represented by empty crops (e.g., EmptyImage)." + ) def __iter__(self) -> Iterator[int]: class_counts = np.zeros(self.n_classes, dtype=np.float64) @@ -79,14 +86,60 @@ def __iter__(self) -> Iterator[int]: ] target_class = int(self.rng.choice(tied)) - # Sample a crop that annotates this class - crop_idx = int(self.rng.choice(self.class_to_crops[target_class])) + # Sample a matrix row (crop) that annotates this class + row_idx = int(self.rng.choice(self.class_to_crops[target_class])) - # Increment counts for all classes this crop annotates - annotated = np.where(self.crop_class_matrix[crop_idx])[0] + # Increment counts for all classes this row annotates + annotated = np.where(self.crop_class_matrix[row_idx])[0] class_counts[annotated] += 1.0 - yield crop_idx + # Map matrix row (dataset-level row) to an actual sample index. + # If n_crops equals len(dataset), the row index IS the sample index. + if self.n_crops == len(self.dataset): + sample_idx = row_idx + elif hasattr(self.dataset, "datasets") and hasattr( + self.dataset, "cumulative_sizes" + ): + # ConcatDataset / CellMapMultiDataset: each row corresponds + # to one sub-dataset; pick a random sample within that sub-dataset. + cumulative_sizes = self.dataset.cumulative_sizes + if row_idx < len(cumulative_sizes): + start = int(cumulative_sizes[row_idx - 1]) if row_idx > 0 else 0 + end = int(cumulative_sizes[row_idx]) + else: + raise ValueError( + "ClassBalancedSampler: crop index out of range for " + "ConcatDataset/CellMapMultiDataset mapping. " + f"row_idx={row_idx}, n_subdatasets={len(cumulative_sizes)}" + ) + if start >= end or end > len(self.dataset): + raise ValueError( + "ClassBalancedSampler: invalid sub-dataset slice computed " + "from cumulative_sizes for row index " + f"{row_idx}: start={start}, end={end}, " + f"len(dataset)={len(self.dataset)}" + ) + sample_idx = int(self.rng.integers(start, end)) + else: + # Generic fallback: partition [0, len(dataset)) into n_crops + # contiguous segments and sample within this row's segment. + total = len(self.dataset) + if self.n_crops <= 1 or total <= 0: + start, end = 0, max(total, 1) + else: + base = total // self.n_crops + remainder = total % self.n_crops + if row_idx < remainder: + start = row_idx * (base + 1) + end = start + (base + 1) + else: + start = remainder * (base + 1) + (row_idx - remainder) * base + end = start + base + if start >= end or end > total: + start, end = 0, total + sample_idx = int(self.rng.integers(start, end)) + + yield sample_idx def __len__(self) -> int: return self.samples_per_epoch diff --git a/src/cellmap_data/utils/misc.py b/src/cellmap_data/utils/misc.py index 42a523b..4ac7782 100644 --- a/src/cellmap_data/utils/misc.py +++ b/src/cellmap_data/utils/misc.py @@ -124,7 +124,9 @@ def min_redundant_inds( ) -> torch.Tensor: """Returns k indices from 0 to n-1 with minimum redundancy. - If replacement is False, the indices are unique. + If replacement is False and k <= n, the indices are unique. + If replacement is False and k > n, duplicates are unavoidable; indices + are unique within each block of size n (minimum redundancy overall). If replacement is True, the indices can have duplicates. Args: @@ -134,13 +136,18 @@ def min_redundant_inds( rng (torch.Generator, optional): The random number generator. Defaults to None. Returns: - torch.Tensor: A tensor of k indices. + torch.Tensor: A tensor of exactly k indices. """ if replacement: return torch.randint(n, (k,), generator=rng) else: if k > n: - # Repeat the unique indices until we have k indices - return torch.cat([torch.randperm(n, generator=rng) for _ in range(k // n)]) + # Repeat unique indices until we have k indices (handle remainder) + full_perms = k // n + remainder = k % n + parts = [torch.randperm(n, generator=rng) for _ in range(full_perms)] + if remainder > 0: + parts.append(torch.randperm(n, generator=rng)[:remainder]) + return torch.cat(parts) else: return torch.randperm(n, generator=rng)[:k] diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c89f0e5..2ee1166 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -254,7 +254,6 @@ def test_small_crop_pad_false_excluded(self, tmp_path): gt_base = str(tmp_path / "gt.zarr") os.makedirs(gt_base, exist_ok=True) - import json with open(os.path.join(gt_base, ".zgroup"), "w") as f: f.write('{"zarr_format": 2}') diff --git a/tests/test_geometry.py b/tests/test_geometry.py index ff1d840..dd1110f 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from cellmap_data.utils.geometry import box_intersection, box_shape, box_union diff --git a/tests/test_image.py b/tests/test_image.py index 92bd769..9b555c6 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -252,8 +252,6 @@ def test_rotation_output_shape_preserved(self, tmp_path): class TestCellMapImageClassCounts: def test_class_counts_keys(self, tmp_path): - import zarr as z - data = np.zeros((10, 10, 10), dtype=np.uint8) data[2:5, 2:5, 2:5] = 1 # some foreground path = create_test_zarr(tmp_path, shape=(10, 10, 10), data=data) diff --git a/tests/test_multidataset.py b/tests/test_multidataset.py index 29f02ee..44ca43b 100644 --- a/tests/test_multidataset.py +++ b/tests/test_multidataset.py @@ -3,7 +3,6 @@ from __future__ import annotations import numpy as np -import torch from cellmap_data import CellMapDataset, CellMapMultiDataset @@ -15,8 +14,6 @@ def _make_ds(tmp_path, suffix="", **kwargs): - import tempfile, pathlib - sub = tmp_path / suffix if suffix else tmp_path / "ds0" sub.mkdir(parents=True, exist_ok=True) info = create_test_dataset(sub, classes=CLASSES, **kwargs) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 6b7f4b5..e4d168d 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -6,6 +6,7 @@ import pytest from cellmap_data.sampler import ClassBalancedSampler +from cellmap_data.utils.misc import min_redundant_inds class FakeDataset: @@ -21,6 +22,27 @@ def __len__(self) -> int: return self._matrix.shape[0] +class FakeConcatDataset: + """Minimal ConcatDataset-like dataset with datasets + cumulative_sizes.""" + + def __init__(self, sub_lengths: list[int], matrix: np.ndarray): + self._matrix = matrix + self._sub_lengths = sub_lengths + # Cumulative sizes mirrors torch.utils.data.ConcatDataset behaviour + self.cumulative_sizes: list[int] = [] + total = 0 + for length in sub_lengths: + total += length + self.cumulative_sizes.append(total) + self.datasets = [None] * len(sub_lengths) # placeholders + + def get_crop_class_matrix(self) -> np.ndarray: + return self._matrix + + def __len__(self) -> int: + return sum(self._sub_lengths) + + class TestClassBalancedSampler: def _make_sampler(self, matrix, samples_per_epoch=None, seed=42): ds = FakeDataset(matrix) @@ -86,3 +108,68 @@ def test_yields_valid_indices_for_single_class(self): indices = list(sampler) assert len(indices) == 10 assert all(0 <= i < 5 for i in indices) + + def test_raises_when_no_active_classes(self): + """All-False crop-class matrix must raise ValueError immediately.""" + matrix = np.zeros((4, 3), dtype=bool) + with pytest.raises(ValueError, match="no active classes"): + self._make_sampler(matrix, samples_per_epoch=5) + + def test_concat_dataset_indices_in_correct_subdataset(self): + """ConcatDataset path: each yielded index falls in the expected sub-dataset range.""" + # Two sub-datasets: first has 10 samples, second has 20 samples + sub_lengths = [10, 20] + # Row 0 → only class 0 annotated; Row 1 → only class 1 annotated + matrix = np.array([[True, False], [False, True]], dtype=bool) + ds = FakeConcatDataset(sub_lengths, matrix) + sampler = ClassBalancedSampler(ds, samples_per_epoch=40, seed=0) + indices = list(sampler) + assert len(indices) == 40 + # All indices must be valid dataset indices + assert all(0 <= i < len(ds) for i in indices) + # Indices from class-0 crops (row 0 → sub-dataset 0) must be in [0, 10) + # Indices from class-1 crops (row 1 → sub-dataset 1) must be in [10, 30) + # Because the sampler alternates classes, roughly half go to each sub-dataset + indices_set = set(indices) + assert any(i < 10 for i in indices_set), "No index from sub-dataset 0" + assert any(10 <= i < 30 for i in indices_set), "No index from sub-dataset 1" + + def test_concat_dataset_all_indices_in_range(self): + """ConcatDataset path: all yielded indices are within [0, len(dataset)).""" + sub_lengths = [5, 5, 5] + matrix = np.eye(3, dtype=bool) + ds = FakeConcatDataset(sub_lengths, matrix) + sampler = ClassBalancedSampler(ds, samples_per_epoch=30, seed=7) + indices = list(sampler) + assert all(0 <= i < len(ds) for i in indices) + + +class TestMinRedundantInds: + def test_replacement_returns_k(self): + result = min_redundant_inds(5, 12, replacement=True) + assert len(result) == 12 + + def test_no_replacement_k_leq_n(self): + result = min_redundant_inds(10, 4, replacement=False) + assert len(result) == 4 + assert len(set(result.tolist())) == 4 # all unique + + def test_no_replacement_k_equals_n(self): + result = min_redundant_inds(5, 5, replacement=False) + assert len(result) == 5 + assert sorted(result.tolist()) == list(range(5)) + + def test_no_replacement_k_gt_n_exact_multiple(self): + """k=6, n=3: two full permutations, exactly 6 indices returned.""" + result = min_redundant_inds(3, 6, replacement=False) + assert len(result) == 6 + + def test_no_replacement_k_gt_n_with_remainder(self): + """k=7, n=3: must return exactly 7 indices, not 6.""" + result = min_redundant_inds(3, 7, replacement=False) + assert len(result) == 7 + + def test_no_replacement_all_values_in_range(self): + result = min_redundant_inds(4, 11, replacement=False) + assert len(result) == 11 + assert all(0 <= v < 4 for v in result.tolist())