diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index c0b3630..6b19a6d 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -406,18 +406,27 @@ def get_crop_class_matrix(self) -> np.ndarray: # Class counts # ------------------------------------------------------------------ - @property + @cached_property def class_counts(self) -> dict[str, Any]: - """Aggregate per-class foreground voxel counts from all target sources.""" + """Aggregate per-class foreground voxel counts from all target sources. + + Returns a dict with: + - ``"totals"``: per-class foreground voxel counts at training resolution. + - ``"totals_total"``: per-class total voxel counts (full array size) at + training resolution. + """ totals: dict[str, int] = {} + totals_total: dict[str, int] = {} for cls in self.classes: src = self.target_sources.get(cls) if src is not None: counts = src.class_counts totals[cls] = counts.get(cls, 0) + totals_total[cls] = src.total_voxels else: totals[cls] = 0 - return {"totals": totals} + totals_total[cls] = 0 + return {"totals": totals, "totals_total": totals_total} # ------------------------------------------------------------------ # Misc diff --git a/src/cellmap_data/empty_image.py b/src/cellmap_data/empty_image.py index 3869efc..062f268 100644 --- a/src/cellmap_data/empty_image.py +++ b/src/cellmap_data/empty_image.py @@ -62,6 +62,10 @@ def bounding_box(self) -> None: def sampling_box(self) -> None: return None + @property + def total_voxels(self) -> int: + return 0 + @property def class_counts(self) -> dict[str, int]: return {self.label_class: 0} diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index a2ae5bc..8fb113d 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -553,6 +553,26 @@ def class_counts(self) -> dict[str, int]: logger.warning("class_counts failed for %s: %s", self.path, exc) return {self.label_class: 0} + @cached_property + def total_voxels(self) -> int: + """Total number of voxels in the data volume at training resolution. + + Derived from the cached :attr:`bounding_box` (world-space extent of the + dataset) divided by the training-resolution voxel size, so no additional + zarr I/O is needed beyond what is already cached. + """ + try: + total = 1 + for ax, (start, end) in self.bounding_box.items(): + axis_voxels = int(round((end - start) / self.scale[ax])) + if axis_voxels < 1: + axis_voxels = 1 + total *= axis_voxels + return total + except Exception as exc: + logger.warning("total_voxels failed for %s: %s", self.path, exc) + return 0 + def _scale_count(self, s0_count: int, s0_idx: int = 0) -> int: """Scale a voxel count from s0 resolution to training resolution.""" try: diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py index a9cc49c..41170da 100644 --- a/src/cellmap_data/multidataset.py +++ b/src/cellmap_data/multidataset.py @@ -58,23 +58,36 @@ def class_counts(self) -> dict[str, Any]: Sequential scan (parallelism offers no benefit over NFS; see project MEMORY.md notes on ``CellMapMultiDataset.class_counts``). + + Returns a dict with: + - ``"totals"``: per-class foreground voxel counts. + - ``"totals_total"``: per-class total voxel counts (full array sizes). """ totals: dict[str, int] = {cls: 0 for cls in self.classes} + totals_total: dict[str, int] = {cls: 0 for cls in self.classes} for ds in tqdm(self.datasets, desc="Counting class voxels", leave=False): - ds_counts = ds.class_counts.get("totals", {}) + ds_counts = ds.class_counts for cls in self.classes: - totals[cls] += ds_counts.get(cls, 0) - return {"totals": totals} + totals[cls] += ds_counts.get("totals", {}).get(cls, 0) + totals_total[cls] += ds_counts.get("totals_total", {}).get(cls, 0) + return {"totals": totals, "totals_total": totals_total} @property def class_weights(self) -> dict[str, float]: - """Per-class sampling weight: ``bg_voxels / fg_voxels``.""" - counts = self.class_counts["totals"] - total_voxels = sum(counts.values()) + """Per-class sampling weight: ``bg_voxels / fg_voxels``. + + Background voxels for each class are derived from the actual data + volume size (``totals_total``) minus foreground voxels, so the ratio + correctly reflects the class imbalance within each volume. + """ + counts = self.class_counts + fg_counts = counts["totals"] + total_counts = counts["totals_total"] weights: dict[str, float] = {} for cls in self.classes: - fg = counts.get(cls, 0) - bg = total_voxels - fg + fg = fg_counts.get(cls, 0) + total = total_counts.get(cls, 0) + bg = max(total - fg, 0) weights[cls] = float(bg) / float(max(fg, 1)) return weights diff --git a/tests/test_image.py b/tests/test_image.py index da7be0a..92bd769 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -261,3 +261,11 @@ def test_class_counts_keys(self, tmp_path): counts = img.class_counts assert "mito" in counts assert counts["mito"] >= 0 + + def test_total_voxels_equals_array_size(self, tmp_path): + shape = (10, 10, 10) + data = np.zeros(shape, dtype=np.uint8) + data[2:5, 2:5, 2:5] = 1 + path = create_test_zarr(tmp_path, shape=shape, data=data) + img = CellMapImage(path, "mito", [8.0, 8.0, 8.0], [4, 4, 4]) + assert img.total_voxels == int(np.prod(shape)) diff --git a/tests/test_multidataset.py b/tests/test_multidataset.py index bb4b95c..29f02ee 100644 --- a/tests/test_multidataset.py +++ b/tests/test_multidataset.py @@ -62,7 +62,40 @@ def test_class_counts_keys(self, tmp_path): multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) counts = multi.class_counts assert "totals" in counts + assert "totals_total" in counts assert all(c in counts["totals"] for c in CLASSES) + assert all(c in counts["totals_total"] for c in CLASSES) + + def test_class_counts_total_equals_volume_size(self, tmp_path): + """totals_total should reflect the actual array volume, not sum of fg counts.""" + shape = (8, 8, 8) + voxel_size = [8.0, 8.0, 8.0] + ds1 = _make_ds(tmp_path, "d1", shape=shape, voxel_size=voxel_size) + multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + counts = multi.class_counts + expected_total = int(np.prod(shape)) + for cls in CLASSES: + assert counts["totals_total"][cls] == expected_total, ( + f"totals_total[{cls!r}] should equal array volume {expected_total}, " + f"got {counts['totals_total'][cls]}" + ) + + def test_class_weights_bg_uses_volume_size(self, tmp_path): + """bg in class_weights should be total_voxels - fg, not sum(fg) - fg.""" + shape = (8, 8, 8) + voxel_size = [8.0, 8.0, 8.0] + ds1 = _make_ds(tmp_path, "d1", shape=shape, voxel_size=voxel_size) + multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + counts = multi.class_counts + weights = multi.class_weights + total = int(np.prod(shape)) + for cls in CLASSES: + fg = counts["totals"][cls] + expected_bg = total - fg + expected_weight = float(max(expected_bg, 0)) / float(max(fg, 1)) + assert ( + abs(weights[cls] - expected_weight) < 1e-6 + ), f"class_weights[{cls!r}] mismatch: expected {expected_weight}, got {weights[cls]}" def test_get_crop_class_matrix_shape(self, tmp_path): ds1 = _make_ds(tmp_path, "d1")