From 2b31a91dd41a8d9ae166b9f76c6d2132d065427b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:32:52 +0000 Subject: [PATCH 1/2] Initial plan From c8151b30cd3f5ce0140d4792464fb30601002b39 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:40:41 +0000 Subject: [PATCH 2/2] fix: compute bg_voxels from actual total_voxels in data volumes for class_weights Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> --- src/cellmap_data/dataset.py | 9 +++++++++ src/cellmap_data/empty_image.py | 4 ++++ src/cellmap_data/image.py | 14 ++++++++++++++ src/cellmap_data/multidataset.py | 28 +++++++++++++++++++++++----- 4 files changed, 50 insertions(+), 5 deletions(-) diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index c0b3630..44d5d2f 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -419,6 +419,15 @@ def class_counts(self) -> dict[str, Any]: totals[cls] = 0 return {"totals": totals} + @property + def total_voxels(self) -> dict[str, int]: + """Total voxels in the data volume per class, normalised to training-resolution voxels.""" + totals: dict[str, int] = {} + for cls in self.classes: + src = self.target_sources.get(cls) + totals[cls] = src.total_voxels if src is not None else 0 + return totals + # ------------------------------------------------------------------ # Misc # ------------------------------------------------------------------ diff --git a/src/cellmap_data/empty_image.py b/src/cellmap_data/empty_image.py index 3869efc..384ead1 100644 --- a/src/cellmap_data/empty_image.py +++ b/src/cellmap_data/empty_image.py @@ -66,6 +66,10 @@ def sampling_box(self) -> None: def class_counts(self) -> dict[str, int]: return {self.label_class: 0} + @property + def total_voxels(self) -> int: + return 0 + def to(self, device: str | torch.device) -> "EmptyImage": self._nan_tensor = self._nan_tensor.to(device) return self diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index a2ae5bc..13c8144 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -553,6 +553,20 @@ def class_counts(self) -> dict[str, int]: logger.warning("class_counts failed for %s: %s", self.path, exc) return {self.label_class: 0} + @property + def total_voxels(self) -> int: + """Total number of voxels in the s0 data volume, normalised to training-resolution voxels.""" + try: + s0_path = self._level_info[0][0] + s0_arr = zarr.open_array(f"{self.path}/{s0_path}", mode="r") + n_spatial = len(self.axes) + spatial_shape = s0_arr.shape[-n_spatial:] + s0_total = int(np.prod(spatial_shape)) + return self._scale_count(s0_total, s0_idx=0) + except Exception as exc: + logger.warning("total_voxels failed for %s: %s", self.path, exc) + return 0 + def _scale_count(self, s0_count: int, s0_idx: int = 0) -> int: """Scale a voxel count from s0 resolution to training resolution.""" try: diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py index a9cc49c..75b2f47 100644 --- a/src/cellmap_data/multidataset.py +++ b/src/cellmap_data/multidataset.py @@ -68,13 +68,31 @@ def class_counts(self) -> dict[str, Any]: @property def class_weights(self) -> dict[str, float]: - """Per-class sampling weight: ``bg_voxels / fg_voxels``.""" - counts = self.class_counts["totals"] - total_voxels = sum(counts.values()) + """Per-class sampling weight: ``bg_voxels / fg_voxels``. + + Background voxels are computed as the total voxels in the data volume + minus the foreground voxels for each class. + """ + fg_counts = self.class_counts["totals"] + # Aggregate actual total voxels per class across all datasets + total_voxels: dict[str, int] = {cls: 0 for cls in self.classes} + for ds in self.datasets: + ds_total = ds.total_voxels + for cls in self.classes: + total_voxels[cls] += ds_total.get(cls, 0) weights: dict[str, float] = {} for cls in self.classes: - fg = counts.get(cls, 0) - bg = total_voxels - fg + fg = fg_counts.get(cls, 0) + total = total_voxels.get(cls, 0) + if fg > total > 0: + logger.warning( + "class_weights: fg (%d) > total_voxels (%d) for class %r; " + "this may indicate a counting error upstream.", + fg, + total, + cls, + ) + bg = max(total - fg, 0) weights[cls] = float(bg) / float(max(fg, 1)) return weights