Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/cellmap_data/dataset_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,22 @@ def __setitem__(
for key, val in arrays.items():
if key in _SKIP_KEYS:
continue
if isinstance(val, (int, float)):
item[key] = val
elif isinstance(val, dict):
if np.isscalar(val):
raise TypeError(
f"Scalar writes are not supported (key={key!r}). "
"Pass an array or tensor with a leading batch dimension."
)
if isinstance(val, dict):
item[key] = {k: v[batch_i] for k, v in val.items()}
else:
elif hasattr(val, "__getitem__") and not isinstance(val, str):
item[key] = val[batch_i]
else:
raise TypeError(
"Unsupported batched value type for key "
f"{key!r}: {type(val).__name__}. Expected a dict of "
"batch-indexable values or a batch-indexable "
"array/tensor/sequence."
)
self.__setitem__(int(i), item)
return

Expand Down
23 changes: 19 additions & 4 deletions src/cellmap_data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,27 @@ def __setitem__(
Raises
------
TypeError
If *data* is a scalar (int or float). Use an array or tensor instead.
If *data* is a scalar (i.e. ``np.isscalar(data)`` is ``True``, including
Python and NumPy scalar types). Use a non-scalar array or tensor with
shape matching the patch instead. Zero-dimensional arrays/tensors are
also not supported for writes.
"""
if isinstance(data, (int, float)):
if np.isscalar(data):
raise TypeError(
"Scalar writes are not supported. "
"Provide an array or tensor with the patch shape instead."
"Pass an array or tensor with shape matching the patch."
)
# Explicitly reject zero-dimensional arrays/tensors, which are not caught
# by np.isscalar and are documented as unsupported for writes.
if isinstance(data, np.ndarray) and data.ndim == 0:
raise TypeError(
"Zero-dimensional NumPy arrays are not supported for writes. "
"Pass a non-scalar array or tensor with shape matching the patch."
)
if torch.is_tensor(data) and data.dim() == 0:
raise TypeError(
"Zero-dimensional torch.Tensors are not supported for writes. "
"Pass a non-scalar tensor or array with shape matching the patch."
)
first = next(iter(coords.values()))
if isinstance(first, (int, float)):
Expand Down Expand Up @@ -223,7 +238,7 @@ def _write_single(
if data_np.ndim == 0:
raise TypeError(
"Scalar writes are not supported. "
"Provide an array or tensor with the patch shape instead."
"Pass an array or tensor with shape matching the patch."
)

data_np = data_np.astype(self.dtype)
Expand Down
43 changes: 43 additions & 0 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import numpy as np
import pytest
import torch

from cellmap_data import CellMapDatasetWriter
Expand Down Expand Up @@ -56,6 +57,30 @@ def test_repr(self, tmp_path):
)
assert "ImageWriter" in repr(writer)

def _make_image_writer(self, tmp_path):
return ImageWriter(
path=str(tmp_path / "out.zarr" / "mito"),
target_class="mito",
scale={"z": 8.0, "y": 8.0, "x": 8.0},
bounding_box={"z": (0.0, 128.0), "y": (0.0, 128.0), "x": (0.0, 128.0)},
write_voxel_shape={"z": 4, "y": 4, "x": 4},
overwrite=True,
)

def test_setitem_zero_dim_ndarray_raises(self, tmp_path):
"""A 0-D NumPy array must raise TypeError with a clear message."""
writer = self._make_image_writer(tmp_path)
center = {"z": 16.0, "y": 16.0, "x": 16.0}
with pytest.raises(TypeError, match="Zero-dimensional NumPy arrays"):
writer[center] = np.array(1.0)

def test_setitem_zero_dim_tensor_raises(self, tmp_path):
"""A 0-D torch.Tensor must raise TypeError with a clear message."""
writer = self._make_image_writer(tmp_path)
center = {"z": 16.0, "y": 16.0, "x": 16.0}
with pytest.raises(TypeError, match="Zero-dimensional torch.Tensors"):
writer[center] = torch.tensor(1.0)


class TestCellMapDatasetWriter:
def _make_writer(self, tmp_path):
Expand Down Expand Up @@ -111,6 +136,24 @@ def test_setitem_batch(self, tmp_path):
output = {"mito": torch.zeros(2, 4, 4, 4)}
writer[idx_tensor] = output # should not raise

def test_setitem_batch_scalar_raises(self, tmp_path):
"""Passing a scalar value in a batch write must raise TypeError."""
writer = self._make_writer(tmp_path)
indices = writer.writer_indices[:2]
idx_tensor = torch.tensor(indices)
# Scalar instead of a batched array — should raise
output = {"mito": 1.0}
with pytest.raises(TypeError, match="Scalar writes are not supported"):
writer[idx_tensor] = output

def test_setitem_batch_unsupported_type_raises(self, tmp_path):
"""A non-dict, non-indexable value in a batch write must raise TypeError."""
writer = self._make_writer(tmp_path)
indices = writer.writer_indices[:2]
idx_tensor = torch.tensor(indices)
with pytest.raises(TypeError, match="Unsupported batched value type"):
writer[idx_tensor] = {"mito": object()}

def test_loader_iterable(self, tmp_path):
writer = self._make_writer(tmp_path)
loader = writer.loader(batch_size=2)
Expand Down
Loading