From 295b90da2e054dc0c7b5a7d6119b4d49570c843a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 17:30:47 +0000 Subject: [PATCH 1/4] Initial plan From 6c895fa81a24c11c6ec261b6669e8f2401ddee1a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 17:38:39 +0000 Subject: [PATCH 2/4] Fix review comments: unused imports, min_redundant_inds, ClassBalancedSampler Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> Agent-Logs-Url: https://github.com/janelia-cellmap/cellmap-data/sessions/1e323ffb-2a14-4aae-99e1-8e70c1d393d3 --- src/cellmap_data/sampler.py | 45 +++++++++++++++++++++++++++++++++- src/cellmap_data/utils/misc.py | 9 +++++-- tests/test_dataset.py | 1 - tests/test_geometry.py | 2 -- tests/test_image.py | 2 -- tests/test_multidataset.py | 3 --- tests/test_sampler.py | 1 - 7 files changed, 51 insertions(+), 12 deletions(-) diff --git a/src/cellmap_data/sampler.py b/src/cellmap_data/sampler.py index 426d829..4de777a 100644 --- a/src/cellmap_data/sampler.py +++ b/src/cellmap_data/sampler.py @@ -64,6 +64,12 @@ def __init__( self.class_to_crops[c] = indices self.active_classes: list[int] = sorted(self.class_to_crops.keys()) + if not self.active_classes: + raise ValueError( + "ClassBalancedSampler: no active classes found in crop-class " + "matrix. This can occur when all requested classes are only " + "represented by empty crops (e.g., EmptyImage)." + ) def __iter__(self) -> Iterator[int]: class_counts = np.zeros(self.n_classes, dtype=np.float64) @@ -86,7 +92,44 @@ def __iter__(self) -> Iterator[int]: annotated = np.where(self.crop_class_matrix[crop_idx])[0] class_counts[annotated] += 1.0 - yield crop_idx + # Map crop index (dataset-level row) to an actual sample index. + # If n_crops equals len(dataset), the crop index IS the sample index. + if self.n_crops == len(self.dataset): + sample_idx = crop_idx + elif hasattr(self.dataset, "datasets") and hasattr( + self.dataset, "cumulative_sizes" + ): + # ConcatDataset / CellMapMultiDataset: each crop row corresponds + # to one sub-dataset; pick a random sample within that sub-dataset. + cumulative_sizes = self.dataset.cumulative_sizes + if crop_idx < len(cumulative_sizes): + start = int(cumulative_sizes[crop_idx - 1]) if crop_idx > 0 else 0 + end = int(cumulative_sizes[crop_idx]) + else: + start, end = 0, len(self.dataset) + if start >= end or end > len(self.dataset): + start, end = 0, len(self.dataset) + sample_idx = int(self.rng.integers(start, end)) + else: + # Generic fallback: partition [0, len(dataset)) into n_crops + # contiguous segments and sample within this crop's segment. + total = len(self.dataset) + if self.n_crops <= 1 or total <= 0: + start, end = 0, max(total, 1) + else: + base = total // self.n_crops + remainder = total % self.n_crops + if crop_idx < remainder: + start = crop_idx * (base + 1) + end = start + (base + 1) + else: + start = remainder * (base + 1) + (crop_idx - remainder) * base + end = start + base + if start >= end or end > total: + start, end = 0, total + sample_idx = int(self.rng.integers(start, end)) + + yield sample_idx def __len__(self) -> int: return self.samples_per_epoch diff --git a/src/cellmap_data/utils/misc.py b/src/cellmap_data/utils/misc.py index 42a523b..8341a71 100644 --- a/src/cellmap_data/utils/misc.py +++ b/src/cellmap_data/utils/misc.py @@ -140,7 +140,12 @@ def min_redundant_inds( return torch.randint(n, (k,), generator=rng) else: if k > n: - # Repeat the unique indices until we have k indices - return torch.cat([torch.randperm(n, generator=rng) for _ in range(k // n)]) + # Repeat unique indices until we have k indices (handle remainder) + full_perms = k // n + remainder = k % n + parts = [torch.randperm(n, generator=rng) for _ in range(full_perms)] + if remainder > 0: + parts.append(torch.randperm(n, generator=rng)[:remainder]) + return torch.cat(parts) else: return torch.randperm(n, generator=rng)[:k] diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c89f0e5..2ee1166 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -254,7 +254,6 @@ def test_small_crop_pad_false_excluded(self, tmp_path): gt_base = str(tmp_path / "gt.zarr") os.makedirs(gt_base, exist_ok=True) - import json with open(os.path.join(gt_base, ".zgroup"), "w") as f: f.write('{"zarr_format": 2}') diff --git a/tests/test_geometry.py b/tests/test_geometry.py index ff1d840..dd1110f 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from cellmap_data.utils.geometry import box_intersection, box_shape, box_union diff --git a/tests/test_image.py b/tests/test_image.py index 92bd769..9b555c6 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -252,8 +252,6 @@ def test_rotation_output_shape_preserved(self, tmp_path): class TestCellMapImageClassCounts: def test_class_counts_keys(self, tmp_path): - import zarr as z - data = np.zeros((10, 10, 10), dtype=np.uint8) data[2:5, 2:5, 2:5] = 1 # some foreground path = create_test_zarr(tmp_path, shape=(10, 10, 10), data=data) diff --git a/tests/test_multidataset.py b/tests/test_multidataset.py index 29f02ee..44ca43b 100644 --- a/tests/test_multidataset.py +++ b/tests/test_multidataset.py @@ -3,7 +3,6 @@ from __future__ import annotations import numpy as np -import torch from cellmap_data import CellMapDataset, CellMapMultiDataset @@ -15,8 +14,6 @@ def _make_ds(tmp_path, suffix="", **kwargs): - import tempfile, pathlib - sub = tmp_path / suffix if suffix else tmp_path / "ds0" sub.mkdir(parents=True, exist_ok=True) info = create_test_dataset(sub, classes=CLASSES, **kwargs) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 6b7f4b5..d23b294 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -3,7 +3,6 @@ from __future__ import annotations import numpy as np -import pytest from cellmap_data.sampler import ClassBalancedSampler From ddb6246890d877a5317b776def5dcb381d80a039 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Fri, 20 Mar 2026 14:36:21 -0400 Subject: [PATCH 3/4] Update src/cellmap_data/sampler.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/cellmap_data/sampler.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/cellmap_data/sampler.py b/src/cellmap_data/sampler.py index 4de777a..7ed2e57 100644 --- a/src/cellmap_data/sampler.py +++ b/src/cellmap_data/sampler.py @@ -106,9 +106,18 @@ def __iter__(self) -> Iterator[int]: start = int(cumulative_sizes[crop_idx - 1]) if crop_idx > 0 else 0 end = int(cumulative_sizes[crop_idx]) else: - start, end = 0, len(self.dataset) + raise ValueError( + "ClassBalancedSampler: crop index out of range for " + "ConcatDataset/CellMapMultiDataset mapping. " + f"crop_idx={crop_idx}, n_subdatasets={len(cumulative_sizes)}" + ) if start >= end or end > len(self.dataset): - start, end = 0, len(self.dataset) + raise ValueError( + "ClassBalancedSampler: invalid sub-dataset slice computed " + "from cumulative_sizes for crop index " + f"{crop_idx}: start={start}, end={end}, " + f"len(dataset)={len(self.dataset)}" + ) sample_idx = int(self.rng.integers(start, end)) else: # Generic fallback: partition [0, len(dataset)) into n_crops From 4448a5642b6af723ddbe914df49ccaf002fcf33e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 18:43:54 +0000 Subject: [PATCH 4/4] Add tests for ClassBalancedSampler edge cases and min_redundant_inds; fix naming and docstrings Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> Agent-Logs-Url: https://github.com/janelia-cellmap/cellmap-data/sessions/b1ce1f6e-fdc4-47b9-aa82-562c86c4800d --- src/cellmap_data/sampler.py | 41 ++++++++-------- src/cellmap_data/utils/misc.py | 6 ++- tests/test_sampler.py | 88 ++++++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 22 deletions(-) diff --git a/src/cellmap_data/sampler.py b/src/cellmap_data/sampler.py index 7ed2e57..477fa73 100644 --- a/src/cellmap_data/sampler.py +++ b/src/cellmap_data/sampler.py @@ -24,8 +24,9 @@ class ClassBalancedSampler(Sampler): ``dataset.get_crop_class_matrix()`` → ``bool[n_crops, n_classes]``. 2. Maintain running counts of how many times each class has been seen. 3. At each step: pick the class with the lowest count (ties broken - randomly), sample a crop annotating it, yield that crop index, then - increment counts for *all* classes that crop annotates. + randomly), sample a matrix row (crop) annotating it, map that row to + an actual dataset sample index, and yield the sample index. Then + increment counts for *all* classes that row annotates. This guarantees rare classes get sampled as often as common ones. @@ -85,54 +86,54 @@ def __iter__(self) -> Iterator[int]: ] target_class = int(self.rng.choice(tied)) - # Sample a crop that annotates this class - crop_idx = int(self.rng.choice(self.class_to_crops[target_class])) + # Sample a matrix row (crop) that annotates this class + row_idx = int(self.rng.choice(self.class_to_crops[target_class])) - # Increment counts for all classes this crop annotates - annotated = np.where(self.crop_class_matrix[crop_idx])[0] + # Increment counts for all classes this row annotates + annotated = np.where(self.crop_class_matrix[row_idx])[0] class_counts[annotated] += 1.0 - # Map crop index (dataset-level row) to an actual sample index. - # If n_crops equals len(dataset), the crop index IS the sample index. + # Map matrix row (dataset-level row) to an actual sample index. + # If n_crops equals len(dataset), the row index IS the sample index. if self.n_crops == len(self.dataset): - sample_idx = crop_idx + sample_idx = row_idx elif hasattr(self.dataset, "datasets") and hasattr( self.dataset, "cumulative_sizes" ): - # ConcatDataset / CellMapMultiDataset: each crop row corresponds + # ConcatDataset / CellMapMultiDataset: each row corresponds # to one sub-dataset; pick a random sample within that sub-dataset. cumulative_sizes = self.dataset.cumulative_sizes - if crop_idx < len(cumulative_sizes): - start = int(cumulative_sizes[crop_idx - 1]) if crop_idx > 0 else 0 - end = int(cumulative_sizes[crop_idx]) + if row_idx < len(cumulative_sizes): + start = int(cumulative_sizes[row_idx - 1]) if row_idx > 0 else 0 + end = int(cumulative_sizes[row_idx]) else: raise ValueError( "ClassBalancedSampler: crop index out of range for " "ConcatDataset/CellMapMultiDataset mapping. " - f"crop_idx={crop_idx}, n_subdatasets={len(cumulative_sizes)}" + f"row_idx={row_idx}, n_subdatasets={len(cumulative_sizes)}" ) if start >= end or end > len(self.dataset): raise ValueError( "ClassBalancedSampler: invalid sub-dataset slice computed " - "from cumulative_sizes for crop index " - f"{crop_idx}: start={start}, end={end}, " + "from cumulative_sizes for row index " + f"{row_idx}: start={start}, end={end}, " f"len(dataset)={len(self.dataset)}" ) sample_idx = int(self.rng.integers(start, end)) else: # Generic fallback: partition [0, len(dataset)) into n_crops - # contiguous segments and sample within this crop's segment. + # contiguous segments and sample within this row's segment. total = len(self.dataset) if self.n_crops <= 1 or total <= 0: start, end = 0, max(total, 1) else: base = total // self.n_crops remainder = total % self.n_crops - if crop_idx < remainder: - start = crop_idx * (base + 1) + if row_idx < remainder: + start = row_idx * (base + 1) end = start + (base + 1) else: - start = remainder * (base + 1) + (crop_idx - remainder) * base + start = remainder * (base + 1) + (row_idx - remainder) * base end = start + base if start >= end or end > total: start, end = 0, total diff --git a/src/cellmap_data/utils/misc.py b/src/cellmap_data/utils/misc.py index 8341a71..4ac7782 100644 --- a/src/cellmap_data/utils/misc.py +++ b/src/cellmap_data/utils/misc.py @@ -124,7 +124,9 @@ def min_redundant_inds( ) -> torch.Tensor: """Returns k indices from 0 to n-1 with minimum redundancy. - If replacement is False, the indices are unique. + If replacement is False and k <= n, the indices are unique. + If replacement is False and k > n, duplicates are unavoidable; indices + are unique within each block of size n (minimum redundancy overall). If replacement is True, the indices can have duplicates. Args: @@ -134,7 +136,7 @@ def min_redundant_inds( rng (torch.Generator, optional): The random number generator. Defaults to None. Returns: - torch.Tensor: A tensor of k indices. + torch.Tensor: A tensor of exactly k indices. """ if replacement: return torch.randint(n, (k,), generator=rng) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index d23b294..e4d168d 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -3,8 +3,10 @@ from __future__ import annotations import numpy as np +import pytest from cellmap_data.sampler import ClassBalancedSampler +from cellmap_data.utils.misc import min_redundant_inds class FakeDataset: @@ -20,6 +22,27 @@ def __len__(self) -> int: return self._matrix.shape[0] +class FakeConcatDataset: + """Minimal ConcatDataset-like dataset with datasets + cumulative_sizes.""" + + def __init__(self, sub_lengths: list[int], matrix: np.ndarray): + self._matrix = matrix + self._sub_lengths = sub_lengths + # Cumulative sizes mirrors torch.utils.data.ConcatDataset behaviour + self.cumulative_sizes: list[int] = [] + total = 0 + for length in sub_lengths: + total += length + self.cumulative_sizes.append(total) + self.datasets = [None] * len(sub_lengths) # placeholders + + def get_crop_class_matrix(self) -> np.ndarray: + return self._matrix + + def __len__(self) -> int: + return sum(self._sub_lengths) + + class TestClassBalancedSampler: def _make_sampler(self, matrix, samples_per_epoch=None, seed=42): ds = FakeDataset(matrix) @@ -85,3 +108,68 @@ def test_yields_valid_indices_for_single_class(self): indices = list(sampler) assert len(indices) == 10 assert all(0 <= i < 5 for i in indices) + + def test_raises_when_no_active_classes(self): + """All-False crop-class matrix must raise ValueError immediately.""" + matrix = np.zeros((4, 3), dtype=bool) + with pytest.raises(ValueError, match="no active classes"): + self._make_sampler(matrix, samples_per_epoch=5) + + def test_concat_dataset_indices_in_correct_subdataset(self): + """ConcatDataset path: each yielded index falls in the expected sub-dataset range.""" + # Two sub-datasets: first has 10 samples, second has 20 samples + sub_lengths = [10, 20] + # Row 0 → only class 0 annotated; Row 1 → only class 1 annotated + matrix = np.array([[True, False], [False, True]], dtype=bool) + ds = FakeConcatDataset(sub_lengths, matrix) + sampler = ClassBalancedSampler(ds, samples_per_epoch=40, seed=0) + indices = list(sampler) + assert len(indices) == 40 + # All indices must be valid dataset indices + assert all(0 <= i < len(ds) for i in indices) + # Indices from class-0 crops (row 0 → sub-dataset 0) must be in [0, 10) + # Indices from class-1 crops (row 1 → sub-dataset 1) must be in [10, 30) + # Because the sampler alternates classes, roughly half go to each sub-dataset + indices_set = set(indices) + assert any(i < 10 for i in indices_set), "No index from sub-dataset 0" + assert any(10 <= i < 30 for i in indices_set), "No index from sub-dataset 1" + + def test_concat_dataset_all_indices_in_range(self): + """ConcatDataset path: all yielded indices are within [0, len(dataset)).""" + sub_lengths = [5, 5, 5] + matrix = np.eye(3, dtype=bool) + ds = FakeConcatDataset(sub_lengths, matrix) + sampler = ClassBalancedSampler(ds, samples_per_epoch=30, seed=7) + indices = list(sampler) + assert all(0 <= i < len(ds) for i in indices) + + +class TestMinRedundantInds: + def test_replacement_returns_k(self): + result = min_redundant_inds(5, 12, replacement=True) + assert len(result) == 12 + + def test_no_replacement_k_leq_n(self): + result = min_redundant_inds(10, 4, replacement=False) + assert len(result) == 4 + assert len(set(result.tolist())) == 4 # all unique + + def test_no_replacement_k_equals_n(self): + result = min_redundant_inds(5, 5, replacement=False) + assert len(result) == 5 + assert sorted(result.tolist()) == list(range(5)) + + def test_no_replacement_k_gt_n_exact_multiple(self): + """k=6, n=3: two full permutations, exactly 6 indices returned.""" + result = min_redundant_inds(3, 6, replacement=False) + assert len(result) == 6 + + def test_no_replacement_k_gt_n_with_remainder(self): + """k=7, n=3: must return exactly 7 indices, not 6.""" + result = min_redundant_inds(3, 7, replacement=False) + assert len(result) == 7 + + def test_no_replacement_all_values_in_range(self): + result = min_redundant_inds(4, 11, replacement=False) + assert len(result) == 11 + assert all(0 <= v < 4 for v in result.tolist())