-
Notifications
You must be signed in to change notification settings - Fork 3
Fix unused imports, min_redundant_inds remainder bug, and ClassBalancedSampler index mapping #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
295b90d
6c895fa
ddb6246
4448a56
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -24,8 +24,9 @@ class ClassBalancedSampler(Sampler): | |||||||||||||||||||||||
| ``dataset.get_crop_class_matrix()`` → ``bool[n_crops, n_classes]``. | ||||||||||||||||||||||||
| 2. Maintain running counts of how many times each class has been seen. | ||||||||||||||||||||||||
| 3. At each step: pick the class with the lowest count (ties broken | ||||||||||||||||||||||||
| randomly), sample a crop annotating it, yield that crop index, then | ||||||||||||||||||||||||
| increment counts for *all* classes that crop annotates. | ||||||||||||||||||||||||
| randomly), sample a matrix row (crop) annotating it, map that row to | ||||||||||||||||||||||||
| an actual dataset sample index, and yield the sample index. Then | ||||||||||||||||||||||||
| increment counts for *all* classes that row annotates. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| This guarantees rare classes get sampled as often as common ones. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
@@ -64,6 +65,12 @@ def __init__( | |||||||||||||||||||||||
| self.class_to_crops[c] = indices | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| self.active_classes: list[int] = sorted(self.class_to_crops.keys()) | ||||||||||||||||||||||||
| if not self.active_classes: | ||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||
| "ClassBalancedSampler: no active classes found in crop-class " | ||||||||||||||||||||||||
| "matrix. This can occur when all requested classes are only " | ||||||||||||||||||||||||
| "represented by empty crops (e.g., EmptyImage)." | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def __iter__(self) -> Iterator[int]: | ||||||||||||||||||||||||
| class_counts = np.zeros(self.n_classes, dtype=np.float64) | ||||||||||||||||||||||||
|
|
@@ -79,14 +86,60 @@ def __iter__(self) -> Iterator[int]: | |||||||||||||||||||||||
| ] | ||||||||||||||||||||||||
| target_class = int(self.rng.choice(tied)) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Sample a crop that annotates this class | ||||||||||||||||||||||||
| crop_idx = int(self.rng.choice(self.class_to_crops[target_class])) | ||||||||||||||||||||||||
| # Sample a matrix row (crop) that annotates this class | ||||||||||||||||||||||||
| row_idx = int(self.rng.choice(self.class_to_crops[target_class])) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Increment counts for all classes this crop annotates | ||||||||||||||||||||||||
| annotated = np.where(self.crop_class_matrix[crop_idx])[0] | ||||||||||||||||||||||||
| # Increment counts for all classes this row annotates | ||||||||||||||||||||||||
| annotated = np.where(self.crop_class_matrix[row_idx])[0] | ||||||||||||||||||||||||
| class_counts[annotated] += 1.0 | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| yield crop_idx | ||||||||||||||||||||||||
| # Map matrix row (dataset-level row) to an actual sample index. | ||||||||||||||||||||||||
| # If n_crops equals len(dataset), the row index IS the sample index. | ||||||||||||||||||||||||
| if self.n_crops == len(self.dataset): | ||||||||||||||||||||||||
| sample_idx = row_idx | ||||||||||||||||||||||||
| elif hasattr(self.dataset, "datasets") and hasattr( | ||||||||||||||||||||||||
| self.dataset, "cumulative_sizes" | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
| # ConcatDataset / CellMapMultiDataset: each row corresponds | ||||||||||||||||||||||||
| # to one sub-dataset; pick a random sample within that sub-dataset. | ||||||||||||||||||||||||
| cumulative_sizes = self.dataset.cumulative_sizes | ||||||||||||||||||||||||
| if row_idx < len(cumulative_sizes): | ||||||||||||||||||||||||
| start = int(cumulative_sizes[row_idx - 1]) if row_idx > 0 else 0 | ||||||||||||||||||||||||
| end = int(cumulative_sizes[row_idx]) | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||
| "ClassBalancedSampler: crop index out of range for " | ||||||||||||||||||||||||
| "ConcatDataset/CellMapMultiDataset mapping. " | ||||||||||||||||||||||||
| f"row_idx={row_idx}, n_subdatasets={len(cumulative_sizes)}" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| if start >= end or end > len(self.dataset): | ||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||
| "ClassBalancedSampler: invalid sub-dataset slice computed " | ||||||||||||||||||||||||
| "from cumulative_sizes for row index " | ||||||||||||||||||||||||
| f"{row_idx}: start={start}, end={end}, " | ||||||||||||||||||||||||
| f"len(dataset)={len(self.dataset)}" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| sample_idx = int(self.rng.integers(start, end)) | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| # Generic fallback: partition [0, len(dataset)) into n_crops | ||||||||||||||||||||||||
| # contiguous segments and sample within this row's segment. | ||||||||||||||||||||||||
| total = len(self.dataset) | ||||||||||||||||||||||||
| if self.n_crops <= 1 or total <= 0: | ||||||||||||||||||||||||
| start, end = 0, max(total, 1) | ||||||||||||||||||||||||
|
Comment on lines +127 to +128
|
||||||||||||||||||||||||
| if self.n_crops <= 1 or total <= 0: | |
| start, end = 0, max(total, 1) | |
| if total <= 0: | |
| raise ValueError( | |
| "ClassBalancedSampler: cannot sample from an empty dataset " | |
| "in generic fallback mapping (len(dataset) == 0). " | |
| "Consider setting samples_per_epoch=0 or removing the " | |
| "sampler for this split." | |
| ) | |
| if self.n_crops <= 1: | |
| start, end = 0, total |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| import pytest | ||
|
|
||
| from cellmap_data.sampler import ClassBalancedSampler | ||
| from cellmap_data.utils.misc import min_redundant_inds | ||
|
|
||
|
|
||
| class FakeDataset: | ||
|
|
@@ -21,6 +22,27 @@ def __len__(self) -> int: | |
| return self._matrix.shape[0] | ||
|
|
||
|
|
||
| class FakeConcatDataset: | ||
| """Minimal ConcatDataset-like dataset with datasets + cumulative_sizes.""" | ||
|
|
||
| def __init__(self, sub_lengths: list[int], matrix: np.ndarray): | ||
| self._matrix = matrix | ||
| self._sub_lengths = sub_lengths | ||
| # Cumulative sizes mirrors torch.utils.data.ConcatDataset behaviour | ||
| self.cumulative_sizes: list[int] = [] | ||
| total = 0 | ||
| for length in sub_lengths: | ||
| total += length | ||
| self.cumulative_sizes.append(total) | ||
| self.datasets = [None] * len(sub_lengths) # placeholders | ||
|
|
||
| def get_crop_class_matrix(self) -> np.ndarray: | ||
| return self._matrix | ||
|
|
||
| def __len__(self) -> int: | ||
| return sum(self._sub_lengths) | ||
|
|
||
|
|
||
| class TestClassBalancedSampler: | ||
| def _make_sampler(self, matrix, samples_per_epoch=None, seed=42): | ||
| ds = FakeDataset(matrix) | ||
|
|
@@ -86,3 +108,68 @@ def test_yields_valid_indices_for_single_class(self): | |
| indices = list(sampler) | ||
| assert len(indices) == 10 | ||
| assert all(0 <= i < 5 for i in indices) | ||
|
|
||
| def test_raises_when_no_active_classes(self): | ||
| """All-False crop-class matrix must raise ValueError immediately.""" | ||
| matrix = np.zeros((4, 3), dtype=bool) | ||
| with pytest.raises(ValueError, match="no active classes"): | ||
| self._make_sampler(matrix, samples_per_epoch=5) | ||
|
Comment on lines +112 to +116
|
||
|
|
||
| def test_concat_dataset_indices_in_correct_subdataset(self): | ||
| """ConcatDataset path: each yielded index falls in the expected sub-dataset range.""" | ||
| # Two sub-datasets: first has 10 samples, second has 20 samples | ||
| sub_lengths = [10, 20] | ||
| # Row 0 → only class 0 annotated; Row 1 → only class 1 annotated | ||
| matrix = np.array([[True, False], [False, True]], dtype=bool) | ||
| ds = FakeConcatDataset(sub_lengths, matrix) | ||
| sampler = ClassBalancedSampler(ds, samples_per_epoch=40, seed=0) | ||
| indices = list(sampler) | ||
| assert len(indices) == 40 | ||
| # All indices must be valid dataset indices | ||
| assert all(0 <= i < len(ds) for i in indices) | ||
| # Indices from class-0 crops (row 0 → sub-dataset 0) must be in [0, 10) | ||
| # Indices from class-1 crops (row 1 → sub-dataset 1) must be in [10, 30) | ||
| # Because the sampler alternates classes, roughly half go to each sub-dataset | ||
| indices_set = set(indices) | ||
| assert any(i < 10 for i in indices_set), "No index from sub-dataset 0" | ||
| assert any(10 <= i < 30 for i in indices_set), "No index from sub-dataset 1" | ||
|
|
||
| def test_concat_dataset_all_indices_in_range(self): | ||
| """ConcatDataset path: all yielded indices are within [0, len(dataset)).""" | ||
| sub_lengths = [5, 5, 5] | ||
| matrix = np.eye(3, dtype=bool) | ||
| ds = FakeConcatDataset(sub_lengths, matrix) | ||
| sampler = ClassBalancedSampler(ds, samples_per_epoch=30, seed=7) | ||
| indices = list(sampler) | ||
| assert all(0 <= i < len(ds) for i in indices) | ||
|
|
||
|
|
||
| class TestMinRedundantInds: | ||
| def test_replacement_returns_k(self): | ||
| result = min_redundant_inds(5, 12, replacement=True) | ||
| assert len(result) == 12 | ||
|
|
||
| def test_no_replacement_k_leq_n(self): | ||
| result = min_redundant_inds(10, 4, replacement=False) | ||
| assert len(result) == 4 | ||
| assert len(set(result.tolist())) == 4 # all unique | ||
|
|
||
| def test_no_replacement_k_equals_n(self): | ||
| result = min_redundant_inds(5, 5, replacement=False) | ||
| assert len(result) == 5 | ||
| assert sorted(result.tolist()) == list(range(5)) | ||
|
|
||
| def test_no_replacement_k_gt_n_exact_multiple(self): | ||
| """k=6, n=3: two full permutations, exactly 6 indices returned.""" | ||
| result = min_redundant_inds(3, 6, replacement=False) | ||
| assert len(result) == 6 | ||
|
|
||
| def test_no_replacement_k_gt_n_with_remainder(self): | ||
| """k=7, n=3: must return exactly 7 indices, not 6.""" | ||
| result = min_redundant_inds(3, 7, replacement=False) | ||
| assert len(result) == 7 | ||
|
|
||
| def test_no_replacement_all_values_in_range(self): | ||
| result = min_redundant_inds(4, 11, replacement=False) | ||
| assert len(result) == 11 | ||
| assert all(0 <= v < 4 for v in result.tolist()) | ||
Uh oh!
There was an error while loading. Please reload this page.