Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions src/cellmap_data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ class ClassBalancedSampler(Sampler):
``dataset.get_crop_class_matrix()`` → ``bool[n_crops, n_classes]``.
2. Maintain running counts of how many times each class has been seen.
3. At each step: pick the class with the lowest count (ties broken
randomly), sample a crop annotating it, yield that crop index, then
increment counts for *all* classes that crop annotates.
randomly), sample a matrix row (crop) annotating it, map that row to
an actual dataset sample index, and yield the sample index. Then
increment counts for *all* classes that row annotates.

This guarantees rare classes get sampled as often as common ones.

Expand Down Expand Up @@ -64,6 +65,12 @@ def __init__(
self.class_to_crops[c] = indices

self.active_classes: list[int] = sorted(self.class_to_crops.keys())
if not self.active_classes:
raise ValueError(
"ClassBalancedSampler: no active classes found in crop-class "
"matrix. This can occur when all requested classes are only "
"represented by empty crops (e.g., EmptyImage)."
)

def __iter__(self) -> Iterator[int]:
class_counts = np.zeros(self.n_classes, dtype=np.float64)
Expand All @@ -79,14 +86,60 @@ def __iter__(self) -> Iterator[int]:
]
target_class = int(self.rng.choice(tied))

# Sample a crop that annotates this class
crop_idx = int(self.rng.choice(self.class_to_crops[target_class]))
# Sample a matrix row (crop) that annotates this class
row_idx = int(self.rng.choice(self.class_to_crops[target_class]))

# Increment counts for all classes this crop annotates
annotated = np.where(self.crop_class_matrix[crop_idx])[0]
# Increment counts for all classes this row annotates
annotated = np.where(self.crop_class_matrix[row_idx])[0]
class_counts[annotated] += 1.0

yield crop_idx
# Map matrix row (dataset-level row) to an actual sample index.
# If n_crops equals len(dataset), the row index IS the sample index.
if self.n_crops == len(self.dataset):
sample_idx = row_idx
elif hasattr(self.dataset, "datasets") and hasattr(
self.dataset, "cumulative_sizes"
):
# ConcatDataset / CellMapMultiDataset: each row corresponds
# to one sub-dataset; pick a random sample within that sub-dataset.
cumulative_sizes = self.dataset.cumulative_sizes
if row_idx < len(cumulative_sizes):
start = int(cumulative_sizes[row_idx - 1]) if row_idx > 0 else 0
end = int(cumulative_sizes[row_idx])
else:
raise ValueError(
"ClassBalancedSampler: crop index out of range for "
"ConcatDataset/CellMapMultiDataset mapping. "
f"row_idx={row_idx}, n_subdatasets={len(cumulative_sizes)}"
)
if start >= end or end > len(self.dataset):
raise ValueError(
"ClassBalancedSampler: invalid sub-dataset slice computed "
"from cumulative_sizes for row index "
f"{row_idx}: start={start}, end={end}, "
f"len(dataset)={len(self.dataset)}"
)
sample_idx = int(self.rng.integers(start, end))
else:
# Generic fallback: partition [0, len(dataset)) into n_crops
# contiguous segments and sample within this row's segment.
total = len(self.dataset)
if self.n_crops <= 1 or total <= 0:
start, end = 0, max(total, 1)
Comment on lines +127 to +128
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the generic fallback mapping, when len(self.dataset) == 0 this branch sets end = max(total, 1) and can yield sample_idx = 0, which is an invalid index for an empty dataset (and will likely crash downstream when DataLoader tries to fetch it). Consider explicitly erroring out when total == 0 (or forcing samples_per_epoch to 0) instead of manufacturing a [0, 1) range.

Suggested change
if self.n_crops <= 1 or total <= 0:
start, end = 0, max(total, 1)
if total <= 0:
raise ValueError(
"ClassBalancedSampler: cannot sample from an empty dataset "
"in generic fallback mapping (len(dataset) == 0). "
"Consider setting samples_per_epoch=0 or removing the "
"sampler for this split."
)
if self.n_crops <= 1:
start, end = 0, total

Copilot uses AI. Check for mistakes.
else:
base = total // self.n_crops
remainder = total % self.n_crops
if row_idx < remainder:
start = row_idx * (base + 1)
end = start + (base + 1)
else:
start = remainder * (base + 1) + (row_idx - remainder) * base
end = start + base
if start >= end or end > total:
start, end = 0, total
sample_idx = int(self.rng.integers(start, end))

yield sample_idx

def __len__(self) -> int:
return self.samples_per_epoch
15 changes: 11 additions & 4 deletions src/cellmap_data/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def min_redundant_inds(
) -> torch.Tensor:
"""Returns k indices from 0 to n-1 with minimum redundancy.
If replacement is False, the indices are unique.
If replacement is False and k <= n, the indices are unique.
If replacement is False and k > n, duplicates are unavoidable; indices
are unique within each block of size n (minimum redundancy overall).
If replacement is True, the indices can have duplicates.
Args:
Expand All @@ -134,13 +136,18 @@ def min_redundant_inds(
rng (torch.Generator, optional): The random number generator. Defaults to None.
Returns:
torch.Tensor: A tensor of k indices.
torch.Tensor: A tensor of exactly k indices.
"""
if replacement:
return torch.randint(n, (k,), generator=rng)
else:
if k > n:
# Repeat the unique indices until we have k indices
return torch.cat([torch.randperm(n, generator=rng) for _ in range(k // n)])
# Repeat unique indices until we have k indices (handle remainder)
full_perms = k // n
remainder = k % n
parts = [torch.randperm(n, generator=rng) for _ in range(full_perms)]
if remainder > 0:
parts.append(torch.randperm(n, generator=rng)[:remainder])
return torch.cat(parts)
else:
return torch.randperm(n, generator=rng)[:k]
1 change: 0 additions & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def test_small_crop_pad_false_excluded(self, tmp_path):

gt_base = str(tmp_path / "gt.zarr")
os.makedirs(gt_base, exist_ok=True)
import json

with open(os.path.join(gt_base, ".zgroup"), "w") as f:
f.write('{"zarr_format": 2}')
Expand Down
2 changes: 0 additions & 2 deletions tests/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

import pytest

from cellmap_data.utils.geometry import box_intersection, box_shape, box_union


Expand Down
2 changes: 0 additions & 2 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,6 @@ def test_rotation_output_shape_preserved(self, tmp_path):

class TestCellMapImageClassCounts:
def test_class_counts_keys(self, tmp_path):
import zarr as z

data = np.zeros((10, 10, 10), dtype=np.uint8)
data[2:5, 2:5, 2:5] = 1 # some foreground
path = create_test_zarr(tmp_path, shape=(10, 10, 10), data=data)
Expand Down
3 changes: 0 additions & 3 deletions tests/test_multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import numpy as np
import torch

from cellmap_data import CellMapDataset, CellMapMultiDataset

Expand All @@ -15,8 +14,6 @@


def _make_ds(tmp_path, suffix="", **kwargs):
import tempfile, pathlib

sub = tmp_path / suffix if suffix else tmp_path / "ds0"
sub.mkdir(parents=True, exist_ok=True)
info = create_test_dataset(sub, classes=CLASSES, **kwargs)
Expand Down
87 changes: 87 additions & 0 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from cellmap_data.sampler import ClassBalancedSampler
from cellmap_data.utils.misc import min_redundant_inds


class FakeDataset:
Expand All @@ -21,6 +22,27 @@ def __len__(self) -> int:
return self._matrix.shape[0]


class FakeConcatDataset:
    """Minimal ConcatDataset-like dataset with datasets + cumulative_sizes."""

    def __init__(self, sub_lengths: list[int], matrix: np.ndarray):
        self._matrix = matrix
        self._sub_lengths = sub_lengths
        # Build running totals, matching torch.utils.data.ConcatDataset:
        # cumulative_sizes[i] is the total sample count through sub-dataset i.
        running = 0
        sizes: list[int] = []
        for length in sub_lengths:
            running += length
            sizes.append(running)
        self.cumulative_sizes: list[int] = sizes
        # One placeholder entry per sub-dataset.
        self.datasets = [None for _ in sub_lengths]

    def get_crop_class_matrix(self) -> np.ndarray:
        """Return the boolean crop-class matrix supplied at construction."""
        return self._matrix

    def __len__(self) -> int:
        """Total number of samples across all sub-datasets."""
        return sum(self._sub_lengths)

class TestClassBalancedSampler:
def _make_sampler(self, matrix, samples_per_epoch=None, seed=42):
ds = FakeDataset(matrix)
Expand Down Expand Up @@ -86,3 +108,68 @@ def test_yields_valid_indices_for_single_class(self):
indices = list(sampler)
assert len(indices) == 10
assert all(0 <= i < 5 for i in indices)

def test_raises_when_no_active_classes(self):
    """An all-False crop-class matrix must fail fast with ValueError."""
    # No row annotates any class, so the sampler has nothing to balance.
    empty_matrix = np.full((4, 3), False, dtype=bool)
    with pytest.raises(ValueError, match="no active classes"):
        self._make_sampler(empty_matrix, samples_per_epoch=5)
Comment on lines +112 to +116
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description says import pytest was removed from tests/test_sampler.py, but this file still imports (and now uses) pytest in the newly added test_raises_when_no_active_classes. Consider updating the PR description to avoid confusion when reviewing the unused-import cleanup.

Copilot uses AI. Check for mistakes.

def test_concat_dataset_indices_in_correct_subdataset(self):
    """ConcatDataset path: each yielded index falls in the expected sub-dataset range."""
    # Sub-dataset 0 holds samples [0, 10); sub-dataset 1 holds [10, 30).
    # Row 0 annotates only class 0; row 1 annotates only class 1.
    matrix = np.array([[True, False], [False, True]], dtype=bool)
    ds = FakeConcatDataset([10, 20], matrix)
    sampler = ClassBalancedSampler(ds, samples_per_epoch=40, seed=0)
    drawn = list(sampler)
    assert len(drawn) == 40
    # Every index must be a valid dataset index.
    assert all(0 <= idx < len(ds) for idx in drawn)
    # The sampler alternates between the two classes, so both sub-datasets
    # should be represented: row 0 maps into [0, 10), row 1 into [10, 30).
    seen = set(drawn)
    assert any(idx < 10 for idx in seen), "No index from sub-dataset 0"
    assert any(10 <= idx < 30 for idx in seen), "No index from sub-dataset 1"

def test_concat_dataset_all_indices_in_range(self):
    """ConcatDataset path: all yielded indices are within [0, len(dataset))."""
    # Three sub-datasets of 5 samples each; identity matrix gives every
    # row exactly one distinct class.
    ds = FakeConcatDataset([5, 5, 5], np.eye(3, dtype=bool))
    sampler = ClassBalancedSampler(ds, samples_per_epoch=30, seed=7)
    drawn = list(sampler)
    assert all(0 <= idx < len(ds) for idx in drawn)


class TestMinRedundantInds:
    """Behavioural tests for the min_redundant_inds index generator."""

    def test_replacement_returns_k(self):
        """With replacement, exactly k indices come back."""
        inds = min_redundant_inds(5, 12, replacement=True)
        assert len(inds) == 12

    def test_no_replacement_k_leq_n(self):
        """Without replacement and k <= n, all k indices are distinct."""
        inds = min_redundant_inds(10, 4, replacement=False)
        assert len(inds) == 4
        assert len(set(inds.tolist())) == 4  # all unique

    def test_no_replacement_k_equals_n(self):
        """k == n yields a full permutation of range(n)."""
        inds = min_redundant_inds(5, 5, replacement=False)
        assert len(inds) == 5
        assert sorted(inds.tolist()) == list(range(5))

    def test_no_replacement_k_gt_n_exact_multiple(self):
        """k=6, n=3: two full permutations, exactly 6 indices returned."""
        inds = min_redundant_inds(3, 6, replacement=False)
        assert len(inds) == 6

    def test_no_replacement_k_gt_n_with_remainder(self):
        """k=7, n=3: must return exactly 7 indices, not 6."""
        inds = min_redundant_inds(3, 7, replacement=False)
        assert len(inds) == 7

    def test_no_replacement_all_values_in_range(self):
        """Every returned index lies in [0, n)."""
        inds = min_redundant_inds(4, 11, replacement=False)
        assert len(inds) == 11
        assert all(0 <= v < 4 for v in inds.tolist())
Loading