Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,18 +406,27 @@ def get_crop_class_matrix(self) -> np.ndarray:
# Class counts
# ------------------------------------------------------------------

@property
@cached_property
def class_counts(self) -> dict[str, Any]:
"""Aggregate per-class foreground voxel counts from all target sources."""
"""Aggregate per-class foreground voxel counts from all target sources.

Returns a dict with:
- ``"totals"``: per-class foreground voxel counts at training resolution.
- ``"totals_total"``: per-class total voxel counts (full array size) at
training resolution.
"""
totals: dict[str, int] = {}
totals_total: dict[str, int] = {}
for cls in self.classes:
src = self.target_sources.get(cls)
if src is not None:
counts = src.class_counts
totals[cls] = counts.get(cls, 0)
totals_total[cls] = src.total_voxels
else:
totals[cls] = 0
return {"totals": totals}
totals_total[cls] = 0
return {"totals": totals, "totals_total": totals_total}

# ------------------------------------------------------------------
# Misc
Expand Down
4 changes: 4 additions & 0 deletions src/cellmap_data/empty_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def bounding_box(self) -> None:
def sampling_box(self) -> None:
return None

@property
def total_voxels(self) -> int:
return 0

@property
def class_counts(self) -> dict[str, int]:
return {self.label_class: 0}
Expand Down
20 changes: 20 additions & 0 deletions src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,26 @@ def class_counts(self) -> dict[str, int]:
logger.warning("class_counts failed for %s: %s", self.path, exc)
return {self.label_class: 0}

@cached_property
def total_voxels(self) -> int:
"""Total number of voxels in the data volume at training resolution.

Derived from the cached :attr:`bounding_box` (world-space extent of the
dataset) divided by the training-resolution voxel size, so no additional
zarr I/O is needed beyond what is already cached.
"""
try:
total = 1
for ax, (start, end) in self.bounding_box.items():
axis_voxels = int(round((end - start) / self.scale[ax]))
if axis_voxels < 1:
axis_voxels = 1
total *= axis_voxels
return total
except Exception as exc:
logger.warning("total_voxels failed for %s: %s", self.path, exc)
return 0

def _scale_count(self, s0_count: int, s0_idx: int = 0) -> int:
"""Scale a voxel count from s0 resolution to training resolution."""
try:
Expand Down
29 changes: 21 additions & 8 deletions src/cellmap_data/multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,36 @@ def class_counts(self) -> dict[str, Any]:

Sequential scan (parallelism offers no benefit over NFS; see
project MEMORY.md notes on ``CellMapMultiDataset.class_counts``).

Returns a dict with:
- ``"totals"``: per-class foreground voxel counts.
- ``"totals_total"``: per-class total voxel counts (full array sizes).
"""
totals: dict[str, int] = {cls: 0 for cls in self.classes}
totals_total: dict[str, int] = {cls: 0 for cls in self.classes}
for ds in tqdm(self.datasets, desc="Counting class voxels", leave=False):
ds_counts = ds.class_counts.get("totals", {})
ds_counts = ds.class_counts
for cls in self.classes:
totals[cls] += ds_counts.get(cls, 0)
return {"totals": totals}
totals[cls] += ds_counts.get("totals", {}).get(cls, 0)
totals_total[cls] += ds_counts.get("totals_total", {}).get(cls, 0)
return {"totals": totals, "totals_total": totals_total}

@property
def class_weights(self) -> dict[str, float]:
"""Per-class sampling weight: ``bg_voxels / fg_voxels``."""
counts = self.class_counts["totals"]
total_voxels = sum(counts.values())
"""Per-class sampling weight: ``bg_voxels / fg_voxels``.

Background voxels for each class are derived from the actual data
volume size (``totals_total``) minus foreground voxels, so the ratio
correctly reflects the class imbalance within each volume.
"""
counts = self.class_counts
fg_counts = counts["totals"]
total_counts = counts["totals_total"]
weights: dict[str, float] = {}
for cls in self.classes:
fg = counts.get(cls, 0)
bg = total_voxels - fg
fg = fg_counts.get(cls, 0)
total = total_counts.get(cls, 0)
bg = max(total - fg, 0)
weights[cls] = float(bg) / float(max(fg, 1))
return weights

Expand Down
8 changes: 8 additions & 0 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,11 @@ def test_class_counts_keys(self, tmp_path):
counts = img.class_counts
assert "mito" in counts
assert counts["mito"] >= 0

def test_total_voxels_equals_array_size(self, tmp_path):
shape = (10, 10, 10)
data = np.zeros(shape, dtype=np.uint8)
data[2:5, 2:5, 2:5] = 1
path = create_test_zarr(tmp_path, shape=shape, data=data)
img = CellMapImage(path, "mito", [8.0, 8.0, 8.0], [4, 4, 4])
assert img.total_voxels == int(np.prod(shape))
33 changes: 33 additions & 0 deletions tests/test_multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,40 @@ def test_class_counts_keys(self, tmp_path):
multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS)
counts = multi.class_counts
assert "totals" in counts
assert "totals_total" in counts
assert all(c in counts["totals"] for c in CLASSES)
assert all(c in counts["totals_total"] for c in CLASSES)

def test_class_counts_total_equals_volume_size(self, tmp_path):
"""totals_total should reflect the actual array volume, not sum of fg counts."""
shape = (8, 8, 8)
voxel_size = [8.0, 8.0, 8.0]
ds1 = _make_ds(tmp_path, "d1", shape=shape, voxel_size=voxel_size)
multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS)
counts = multi.class_counts
expected_total = int(np.prod(shape))
for cls in CLASSES:
assert counts["totals_total"][cls] == expected_total, (
f"totals_total[{cls!r}] should equal array volume {expected_total}, "
f"got {counts['totals_total'][cls]}"
)

def test_class_weights_bg_uses_volume_size(self, tmp_path):
"""bg in class_weights should be total_voxels - fg, not sum(fg) - fg."""
shape = (8, 8, 8)
voxel_size = [8.0, 8.0, 8.0]
ds1 = _make_ds(tmp_path, "d1", shape=shape, voxel_size=voxel_size)
multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS)
counts = multi.class_counts
weights = multi.class_weights
total = int(np.prod(shape))
for cls in CLASSES:
fg = counts["totals"][cls]
expected_bg = total - fg
expected_weight = float(max(expected_bg, 0)) / float(max(fg, 1))
assert (
abs(weights[cls] - expected_weight) < 1e-6
), f"class_weights[{cls!r}] mismatch: expected {expected_weight}, got {weights[cls]}"

def test_get_crop_class_matrix_shape(self, tmp_path):
ds1 = _make_ds(tmp_path, "d1")
Expand Down
Loading