diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 780dbfa..6a35f60 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -252,12 +252,22 @@ def __setitem__( for key, val in arrays.items(): if key in _SKIP_KEYS: continue - if isinstance(val, (int, float)): - item[key] = val - elif isinstance(val, dict): + if np.isscalar(val): + raise TypeError( + f"Scalar writes are not supported (key={key!r}). " + "Pass an array or tensor with a leading batch dimension." + ) + if isinstance(val, dict): item[key] = {k: v[batch_i] for k, v in val.items()} - else: + elif hasattr(val, "__getitem__") and not isinstance(val, str): item[key] = val[batch_i] + else: + raise TypeError( + "Unsupported batched value type for key " + f"{key!r}: {type(val).__name__}. Expected a dict of " + "batch-indexable values or a batch-indexable " + "array/tensor/sequence." + ) self.__setitem__(int(i), item) return diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index f5fd726..7efb0d2 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -181,12 +181,27 @@ def __setitem__( Raises ------ TypeError - If *data* is a scalar (int or float). Use an array or tensor instead. + If *data* is a scalar (i.e. ``np.isscalar(data)`` is ``True``, including + Python and NumPy scalar types). Use a non-scalar array or tensor with + shape matching the patch instead. Zero-dimensional arrays/tensors are + also not supported for writes. """ - if isinstance(data, (int, float)): + if np.isscalar(data): raise TypeError( "Scalar writes are not supported. " - "Provide an array or tensor with the patch shape instead." + "Pass an array or tensor with shape matching the patch." + ) + # Explicitly reject zero-dimensional arrays/tensors, which are not caught + # by np.isscalar and are documented as unsupported for writes. + if isinstance(data, np.ndarray) and data.ndim == 0: + raise TypeError( + "Zero-dimensional NumPy arrays are not supported for writes. " + "Pass a non-scalar array or tensor with shape matching the patch." + ) + if torch.is_tensor(data) and data.dim() == 0: + raise TypeError( + "Zero-dimensional torch.Tensors are not supported for writes. " + "Pass a non-scalar tensor or array with shape matching the patch." ) first = next(iter(coords.values())) if isinstance(first, (int, float)): @@ -223,7 +238,7 @@ def _write_single( if data_np.ndim == 0: raise TypeError( "Scalar writes are not supported. " - "Provide an array or tensor with the patch shape instead." + "Pass an array or tensor with shape matching the patch." ) data_np = data_np.astype(self.dtype) diff --git a/tests/test_writer.py b/tests/test_writer.py index 9395197..5fca73b 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -3,6 +3,7 @@ from __future__ import annotations import numpy as np +import pytest import torch from cellmap_data import CellMapDatasetWriter @@ -56,6 +57,30 @@ def test_repr(self, tmp_path): ) assert "ImageWriter" in repr(writer) + def _make_image_writer(self, tmp_path): + return ImageWriter( + path=str(tmp_path / "out.zarr" / "mito"), + target_class="mito", + scale={"z": 8.0, "y": 8.0, "x": 8.0}, + bounding_box={"z": (0.0, 128.0), "y": (0.0, 128.0), "x": (0.0, 128.0)}, + write_voxel_shape={"z": 4, "y": 4, "x": 4}, + overwrite=True, + ) + + def test_setitem_zero_dim_ndarray_raises(self, tmp_path): + """A 0-D NumPy array must raise TypeError with a clear message.""" + writer = self._make_image_writer(tmp_path) + center = {"z": 16.0, "y": 16.0, "x": 16.0} + with pytest.raises(TypeError, match="Zero-dimensional NumPy arrays"): + writer[center] = np.array(1.0) + + def test_setitem_zero_dim_tensor_raises(self, tmp_path): + """A 0-D torch.Tensor must raise TypeError with a clear message.""" + writer = self._make_image_writer(tmp_path) + center = {"z": 16.0, "y": 16.0, "x": 16.0} + with pytest.raises(TypeError, match="Zero-dimensional torch.Tensors"): + writer[center] = torch.tensor(1.0) + class TestCellMapDatasetWriter: def _make_writer(self, tmp_path): @@ -111,6 +136,24 @@ def test_setitem_batch(self, tmp_path): output = {"mito": torch.zeros(2, 4, 4, 4)} writer[idx_tensor] = output # should not raise + def test_setitem_batch_scalar_raises(self, tmp_path): + """Passing a scalar value in a batch write must raise TypeError.""" + writer = self._make_writer(tmp_path) + indices = writer.writer_indices[:2] + idx_tensor = torch.tensor(indices) + # Scalar instead of a batched array — should raise + output = {"mito": 1.0} + with pytest.raises(TypeError, match="Scalar writes are not supported"): + writer[idx_tensor] = output + + def test_setitem_batch_unsupported_type_raises(self, tmp_path): + """A non-dict, non-indexable value in a batch write must raise TypeError.""" + writer = self._make_writer(tmp_path) + indices = writer.writer_indices[:2] + idx_tensor = torch.tensor(indices) + with pytest.raises(TypeError, match="Unsupported batched value type"): + writer[idx_tensor] = {"mito": object()} + def test_loader_iterable(self, tmp_path): writer = self._make_writer(tmp_path) loader = writer.loader(batch_size=2)