Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,15 @@ def class_counts(self) -> dict[str, Any]:
totals[cls] = 0
return {"totals": totals}

@property
def total_voxels(self) -> dict[str, int]:
    """Total voxels in the data volume per class, normalised to training-resolution voxels.

    Returns:
        A mapping with one entry per name in ``self.classes``. The value is
        the ``total_voxels`` of that class's entry in ``self.target_sources``,
        or ``0`` when the class has no target source.
    """
    totals: dict[str, int] = {}
    for cls in self.classes:
        src = self.target_sources.get(cls)
        # Classes without a target source still get an entry so that
        # callers can index the result by any name in ``self.classes``.
        totals[cls] = src.total_voxels if src is not None else 0
    return totals

# ------------------------------------------------------------------
# Misc
# ------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions src/cellmap_data/empty_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def sampling_box(self) -> None:
def class_counts(self) -> dict[str, int]:
return {self.label_class: 0}

@property
def total_voxels(self) -> int:
    """Total number of voxels in this image; always ``0`` here."""
    return 0

def to(self, device: str | torch.device) -> "EmptyImage":
self._nan_tensor = self._nan_tensor.to(device)
return self
Expand Down
14 changes: 14 additions & 0 deletions src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,20 @@ def class_counts(self) -> dict[str, int]:
logger.warning("class_counts failed for %s: %s", self.path, exc)
return {self.label_class: 0}

@property
def total_voxels(self) -> int:
    """Total number of voxels in the s0 data volume, normalised to training-resolution voxels.

    The shape of the array at the first entry of ``self._level_info`` is
    read, its trailing ``len(self.axes)`` dimensions are multiplied
    together, and the product is passed through ``self._scale_count``.

    Returns:
        The scaled voxel count, or ``0`` if anything in the lookup raises
        (the failure is logged as a warning rather than propagated).
    """
    # Local import keeps this block self-contained; ``math`` is stdlib.
    import math

    try:
        s0_path = self._level_info[0][0]
        s0_arr = zarr.open_array(f"{self.path}/{s0_path}", mode="r")
        n_spatial = len(self.axes)
        spatial_shape = s0_arr.shape[-n_spatial:]
        # math.prod multiplies Python ints exactly. np.prod would run in a
        # fixed-width integer dtype and wrap silently on overflow, which
        # large volumes can hit where the default integer is 32-bit.
        s0_total = math.prod(int(dim) for dim in spatial_shape)
        return self._scale_count(s0_total, s0_idx=0)
    except Exception as exc:
        # Best-effort: a missing or unreadable volume counts as empty.
        logger.warning("total_voxels failed for %s: %s", self.path, exc)
        return 0

def _scale_count(self, s0_count: int, s0_idx: int = 0) -> int:
"""Scale a voxel count from s0 resolution to training resolution."""
try:
Expand Down
28 changes: 23 additions & 5 deletions src/cellmap_data/multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,31 @@ def class_counts(self) -> dict[str, Any]:

@property
def class_weights(self) -> dict[str, float]:
    """Per-class sampling weight: ``bg_voxels / fg_voxels``.

    Background voxels are computed as the total voxels in the data volume
    minus the foreground voxels for each class.

    Returns:
        A mapping with one entry per name in ``self.classes``. The
        background count is clamped at zero, and the foreground count is
        treated as at least one so the division is always defined.
    """
    fg_counts = self.class_counts["totals"]

    # Aggregate actual total voxels per class across all datasets.
    total_voxels: dict[str, int] = {cls: 0 for cls in self.classes}
    for ds in self.datasets:
        # Read the property once per dataset rather than once per class.
        ds_total = ds.total_voxels
        for cls in self.classes:
            total_voxels[cls] += ds_total.get(cls, 0)

    weights: dict[str, float] = {}
    for cls in self.classes:
        fg = fg_counts.get(cls, 0)
        total = total_voxels.get(cls, 0)
        # A zero total is not reported here; only a non-zero total that is
        # smaller than the foreground count is flagged.
        if fg > total > 0:
            logger.warning(
                "class_weights: fg (%d) > total_voxels (%d) for class %r; "
                "this may indicate a counting error upstream.",
                fg,
                total,
                cls,
            )
        bg = max(total - fg, 0)
        weights[cls] = float(bg) / float(max(fg, 1))
    return weights
Comment on lines 70 to 97

Expand Down
Loading