From 0ed343b2f8503072ee8439052cf9576ed01a960a Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 4 Mar 2026 17:16:50 -0500 Subject: [PATCH 01/33] feat(tests): add comprehensive unit tests for CellMapDataset, CellMapDataSplit, EmptyImage, CellMapImage, CellMapMultiDataset, ClassBalancedSampler, and CellMapDatasetWriter - Implement tests for CellMapDataset including initialization, data retrieval, and validation. - Add tests for CellMapDataSplit covering dataset initialization from dictionary and CSV, validation datasets, and class counts. - Create tests for EmptyImage to ensure it returns NaN values and handles bounding boxes correctly. - Develop tests for CellMapImage focusing on initialization, bounding box calculations, and spatial transformations. - Introduce tests for CellMapMultiDataset to validate dataset combination and index mapping. - Implement tests for ClassBalancedSampler to ensure balanced sampling across classes. - Add tests for CellMapDatasetWriter to verify writing and reading of data, including shape properties and batch writing. --- pyproject.toml | 12 +- src/cellmap_data/__init__.py | 10 +- src/cellmap_data/base_dataset.py | 108 -- src/cellmap_data/base_image.py | 100 -- src/cellmap_data/dataloader.py | 477 ++------- src/cellmap_data/dataset.py | 1332 ++++++------------------ src/cellmap_data/dataset_writer.py | 691 +++++------- src/cellmap_data/datasplit.py | 659 +++++------- src/cellmap_data/empty_image.py | 119 +-- src/cellmap_data/image.py | 994 ++++++++---------- src/cellmap_data/image_writer.py | 434 +++----- src/cellmap_data/multidataset.py | 492 ++------- src/cellmap_data/mutable_sampler.py | 39 - src/cellmap_data/sampler.py | 92 ++ src/cellmap_data/subdataset.py | 102 -- src/cellmap_data/utils/__init__.py | 11 +- src/cellmap_data/utils/geometry.py | 48 + src/cellmap_data/utils/read_limiter.py | 70 -- src/cellmap_data/utils/sampling.py | 39 - src/cellmap_data/utils/view.py | 516 --------- tests/README.md | 326 ------ tests/demo_memory_fix.py | 294 ------ tests/test_api_contract.py | 569 ++++++++++ tests/test_base_classes.py | 220 ---- tests/test_cellmap_dataset.py | 617 ----------- tests/test_cellmap_image.py | 282 ----- tests/test_dataloader.py | 819 ++------------- tests/test_dataset.py | 217 ++++ tests/test_dataset_edge_cases.py | 471 --------- tests/test_dataset_writer.py | 580 ----------- tests/test_dataset_writer_batch.py | 209 ---- tests/test_datasplit.py | 139 +++ tests/test_empty_image.py | 48 + tests/test_empty_image_writer.py | 393 ------- tests/test_geometry.py | 83 ++ tests/test_helpers.py | 404 ++----- tests/test_image.py | 190 ++++ tests/test_image_edge_cases.py | 744 ------------- tests/test_init_optimizations.py | 532 ---------- tests/test_integration.py | 447 -------- tests/test_memory_management.py | 217 ---- tests/test_metadata.py | 291 ------ tests/test_multidataset.py | 96 ++ tests/test_multidataset_datasplit.py | 821 --------------- tests/test_mutable_sampler.py | 279 ----- tests/test_sampler.py | 88 ++ tests/test_subdataset.py | 252 ----- tests/test_transforms.py | 614 ++++++----- tests/test_utils.py | 454 -------- tests/test_windows_stress.py | 415 -------- tests/test_writer.py | 124 +++ 51 files changed, 3894 insertions(+), 13686 deletions(-) delete mode 100644 src/cellmap_data/base_dataset.py delete mode 100644 src/cellmap_data/base_image.py delete mode 100644 src/cellmap_data/mutable_sampler.py create mode 100644 src/cellmap_data/sampler.py delete mode 100644 src/cellmap_data/subdataset.py create mode 100644 src/cellmap_data/utils/geometry.py delete mode 100644 src/cellmap_data/utils/read_limiter.py delete mode 100644 src/cellmap_data/utils/sampling.py delete mode 100644 src/cellmap_data/utils/view.py delete mode 100644 tests/README.md delete mode 100755 tests/demo_memory_fix.py create mode 100644 tests/test_api_contract.py delete mode 100644 tests/test_base_classes.py delete mode 100644 tests/test_cellmap_dataset.py delete mode 100644 tests/test_cellmap_image.py create mode 100644 tests/test_dataset.py delete mode 100644 tests/test_dataset_edge_cases.py delete mode 100644 tests/test_dataset_writer.py delete mode 100644 tests/test_dataset_writer_batch.py create mode 100644 tests/test_datasplit.py create mode 100644 tests/test_empty_image.py delete mode 100644 tests/test_empty_image_writer.py create mode 100644 tests/test_geometry.py create mode 100644 tests/test_image.py delete mode 100644 tests/test_image_edge_cases.py delete mode 100644 tests/test_init_optimizations.py delete mode 100644 tests/test_integration.py delete mode 100644 tests/test_memory_management.py delete mode 100644 tests/test_metadata.py create mode 100644 tests/test_multidataset.py delete mode 100644 tests/test_multidataset_datasplit.py delete mode 100644 tests/test_mutable_sampler.py create mode 100644 tests/test_sampler.py delete mode 100644 tests/test_subdataset.py delete mode 100644 tests/test_utils.py delete mode 100644 tests/test_windows_stress.py create mode 100644 tests/test_writer.py diff --git a/pyproject.toml b/pyproject.toml index 89a069b..d4bbfac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "hatchling.build" # https://peps.python.org/pep-0621/ [project] name = "cellmap-data" -description = "Utility for loading CellMap data for machine learning training, utilizing PyTorch, Xarray, TensorStore, and PyDantic." +description = "Utility for loading CellMap data for machine learning training, utilizing PyTorch and zarr." readme = "README.md" requires-python = ">=3.11" license = { text = "BSD 3-Clause License" } @@ -28,18 +28,10 @@ dependencies = [ "torchvision", "numpy", "matplotlib", - "pydantic_ome_ngff", - "xarray_ome_ngff", - "tensorstore", - # "xarray=2024.10.0", - "xarray-tensorstore==0.1.5", + "zarr>=2.0,<3.0", "universal_pathlib>=0.2.0", "fsspec[s3,http]", - "neuroglancer", - "h5py", # Only needed until the new cellmap-flow is released - # "cellmap-flow", "ipython", - # "py_distance_transforms", "scipy", "tqdm", ] diff --git a/src/cellmap_data/__init__.py b/src/cellmap_data/__init__.py index 9e59ea0..8f5e24a 100644 --- a/src/cellmap_data/__init__.py +++ b/src/cellmap_data/__init__.py @@ -12,8 +12,6 @@ __author__ = "Jeff Rhoades" __email__ = "rhoadesj@hhmi.org" -from .base_dataset import CellMapBaseDataset -from .base_image import CellMapImageBase from .dataloader import CellMapDataLoader from .dataset import CellMapDataset from .dataset_writer import CellMapDatasetWriter @@ -22,12 +20,9 @@ from .image import CellMapImage from .image_writer import ImageWriter from .multidataset import CellMapMultiDataset -from .mutable_sampler import MutableSubsetRandomSampler -from .subdataset import CellMapSubset +from .sampler import ClassBalancedSampler __all__ = [ - "CellMapBaseDataset", - "CellMapImageBase", "CellMapDataLoader", "CellMapDataset", "CellMapDatasetWriter", @@ -35,7 +30,6 @@ "CellMapImage", "ImageWriter", "CellMapMultiDataset", - "CellMapSubset", "EmptyImage", - "MutableSubsetRandomSampler", + "ClassBalancedSampler", ] diff --git a/src/cellmap_data/base_dataset.py b/src/cellmap_data/base_dataset.py deleted file mode 100644 index d27c629..0000000 --- a/src/cellmap_data/base_dataset.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Abstract base class for CellMap dataset objects.""" - -from abc import ABC, abstractmethod -from typing import Any, Callable, Mapping, Sequence - -import torch - - -class CellMapBaseDataset(ABC): - """ - Abstract base class for CellMap dataset objects. - - This class defines the common interface that all CellMap dataset objects - must implement, ensuring consistency across different dataset types. - - Note: `classes`, `input_arrays`, and `target_arrays` are not abstract - properties because implementing classes define them as instance attributes - in __init__, not as properties. - """ - - # These are instance attributes set in __init__, not properties - classes: Sequence[str] | None - input_arrays: Mapping[str, Mapping[str, Any]] - target_arrays: Mapping[str, Mapping[str, Any]] | None - - @property - @abstractmethod - def class_counts(self) -> dict[str, float]: - """ - Return the number of samples in each class, normalized by resolution. - - Returns - ------- - dict[str, float] - Dictionary mapping class names to their counts. - """ - pass - - @property - @abstractmethod - def class_weights(self) -> dict[str, float]: - """ - Return the class weights based on the number of samples in each class. - - Returns - ------- - dict[str, float] - Dictionary mapping class names to their weights. - """ - pass - - @property - @abstractmethod - def validation_indices(self) -> Sequence[int]: - """ - Return the indices for the validation set. - - Returns - ------- - Sequence[int] - List of validation indices. - """ - pass - - @abstractmethod - def to( - self, device: str | torch.device, non_blocking: bool = True - ) -> "CellMapBaseDataset": - """ - Move the dataset to the specified device. - - Parameters - ---------- - device : str | torch.device - The target device. - non_blocking : bool, optional - Whether to use non-blocking transfer, by default True. - - Returns - ------- - CellMapBaseDataset - Self for method chaining. - """ - pass - - @abstractmethod - def set_raw_value_transforms(self, transforms: Callable) -> None: - """ - Set the value transforms for raw input data. - - Parameters - ---------- - transforms : Callable - Transform function to apply to raw data. - """ - pass - - @abstractmethod - def set_target_value_transforms(self, transforms: Callable) -> None: - """ - Set the value transforms for target data. - - Parameters - ---------- - transforms : Callable - Transform function to apply to target data. - """ - pass diff --git a/src/cellmap_data/base_image.py b/src/cellmap_data/base_image.py deleted file mode 100644 index 57e157d..0000000 --- a/src/cellmap_data/base_image.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Abstract base class for CellMap image objects.""" - -from abc import ABC, abstractmethod -from typing import Any, Mapping - -import torch - - -class CellMapImageBase(ABC): - """ - Abstract base class for CellMap image objects. - - This class defines the common interface that all CellMap image objects - must implement, ensuring consistency across different image types. - """ - - @abstractmethod - def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor: - """ - Return image data centered around the given point. - - Parameters - ---------- - center : Mapping[str, float] - The center coordinates in world units. - - Returns - ------- - torch.Tensor - The image data as a PyTorch tensor. - """ - pass - - @property - @abstractmethod - def bounding_box(self) -> Mapping[str, tuple[float, float]] | None: - """ - Return the bounding box of the image in world units. - - Returns - ------- - Mapping[str, tuple[float, float]] | None - Dictionary mapping axis names to (min, max) tuples, or None. - """ - pass - - @property - @abstractmethod - def sampling_box(self) -> Mapping[str, tuple[float, float]] | None: - """ - Return the sampling box of the image in world units. - - The sampling box is the region where centers can be drawn from and - still have full samples drawn from within the bounding box. - - Returns - ------- - Mapping[str, tuple[float, float]] | None - Dictionary mapping axis names to (min, max) tuples, or None. - """ - pass - - @property - @abstractmethod - def class_counts(self) -> float | dict[str, float]: - """ - Return the number of voxels for each class in the image. - - Returns - ------- - float | dict[str, float] - Class counts, either as a single float or dictionary. - """ - pass - - @abstractmethod - def to(self, device: str | torch.device, non_blocking: bool = True) -> None: - """ - Move the image data to the specified device. - - Parameters - ---------- - device : str | torch.device - The target device. - non_blocking : bool, optional - Whether to use non-blocking transfer, by default True. - """ - pass - - @abstractmethod - def set_spatial_transforms(self, transforms: Mapping[str, Any] | None) -> None: - """ - Set spatial transformations for the image data. - - Parameters - ---------- - transforms : Mapping[str, Any] | None - Dictionary of spatial transformations to apply. - """ - pass diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py index a92d111..25313b4 100644 --- a/src/cellmap_data/dataloader.py +++ b/src/cellmap_data/dataloader.py @@ -1,164 +1,80 @@ -import os -import platform +"""CellMapDataLoader: thin wrapper around PyTorch DataLoader.""" + +from __future__ import annotations + import logging -from typing import Callable, Optional, Sequence, Union +from typing import Any, Callable, Iterator, Optional, Sequence, Union import torch import torch.utils.data +from torch.utils.data import DataLoader, Dataset, Subset -from .dataset import CellMapDataset -from .dataset_writer import CellMapDatasetWriter -from .image import CellMapImage -from .multidataset import CellMapMultiDataset -from .mutable_sampler import MutableSubsetRandomSampler -from .subdataset import CellMapSubset +from .sampler import ClassBalancedSampler logger = logging.getLogger(__name__) -# Default TensorStore chunk-cache budget applied when neither the constructor -# argument nor CELLMAP_TENSORSTORE_CACHE_BYTES env var is set. -# 2 GiB is a conservative bound that prevents unbounded RAM growth across -# epochs while still providing meaningful caching for hot regions. -# Override via CELLMAP_TENSORSTORE_CACHE_BYTES or the tensorstore_cache_bytes -# constructor argument. -_DEFAULT_TENSORSTORE_CACHE_BYTES = 2 * 1024**3 # 2 GiB - - -def _set_tensorstore_context(dataset, context) -> None: - """ - Recursively set a TensorStore context on every CellMapImage in the dataset tree. - - This must be called before workers are spawned so the bounded cache_pool - limit is picked up by every worker process (via fork inheritance on Linux, - or via pickle on Windows/macOS spawn). - - If an image's TensorStore array has already been opened (``array`` cached), - the new context cannot affect that array; a warning is emitted. - """ - if isinstance(dataset, CellMapMultiDataset): - for ds in dataset.datasets: - _set_tensorstore_context(ds, context) - elif isinstance(dataset, CellMapSubset): - _set_tensorstore_context(dataset.dataset, context) - elif isinstance(dataset, CellMapDataset): - dataset.context = context - all_sources = list(dataset.input_sources.values()) + list( - dataset.target_sources.values() - ) - for source in all_sources: - if isinstance(source, CellMapImage): - _apply_context_to_image(source, context) - elif isinstance(source, dict): - for sub_source in source.values(): - if isinstance(sub_source, CellMapImage): - _apply_context_to_image(sub_source, context) - else: - logger.warning( - "Unsupported dataset type %s in _set_tensorstore_context; " - "TensorStore context was not applied.", - type(dataset).__name__, - ) - - -def _apply_context_to_image(image: "CellMapImage", context) -> None: - """Set the TensorStore context on a single CellMapImage, warning if already opened.""" - if "array" in getattr(image, "__dict__", {}): - logger.warning( - "TensorStore array already opened for %s; " - "cache_pool limit will not apply to this image.", - getattr(image, "path", image), - ) - image.context = context - class CellMapDataLoader: - """ - Optimized DataLoader wrapper for CellMapDataset that uses PyTorch's native DataLoader. + """PyTorch-compatible DataLoader for CellMap datasets. - This class provides a simplified, high-performance interface to PyTorch's DataLoader - with optimizations for GPU training including prefetch_factor, persistent_workers, - and pin_memory support. + Wraps :class:`torch.utils.data.DataLoader` with optional + :class:`~cellmap_data.sampler.ClassBalancedSampler` for class-balanced + training. - Attributes + Parameters ---------- - dataset (CellMapMultiDataset | CellMapDataset | CellMapSubset): Dataset to load. - classes (Iterable[str]): Classes to load. - batch_size (int): Batch size. - num_workers (int): Number of workers. - weighted_sampler (bool): Whether to use a weighted sampler. - sampler (Union[MutableSubsetRandomSampler, Callable, None]): Sampler to use. - is_train (bool): Whether data is for training (shuffled). - rng (Optional[torch.Generator]): Random number generator. - loader (torch.utils.data.DataLoader): Underlying PyTorch DataLoader. - default_kwargs (dict): Default arguments for compatibility. + dataset: + A :class:`~cellmap_data.dataset.CellMapDataset`, + :class:`~cellmap_data.multidataset.CellMapMultiDataset`, or any + compatible dataset / ``Subset``. + classes: + Class names. Defaults to ``dataset.classes``. + batch_size: + Samples per batch. + num_workers: + DataLoader worker processes (0 = main process only). + weighted_sampler: + If ``True`` and ``is_train=True``, use + :class:`ClassBalancedSampler` (requires ``dataset`` to implement + ``get_crop_class_matrix()``). + sampler: + Explicit sampler; overrides *weighted_sampler*. + is_train: + Training mode (enables random sampling and the weighted sampler). + device: + Ignored — tensors are returned on CPU; move them in training loop. + iterations_per_epoch: + Number of samples per epoch when using ClassBalancedSampler. + Defaults to ``len(dataset)``. + **kwargs: + Forwarded to :class:`torch.utils.data.DataLoader`. """ def __init__( self, - dataset: ( - CellMapMultiDataset | CellMapDataset | CellMapSubset | CellMapDatasetWriter - ), - classes: Sequence[str] | None = None, + dataset: Dataset, + classes: Optional[Sequence[str]] = None, batch_size: int = 1, num_workers: int = 0, weighted_sampler: bool = False, - sampler: Union[MutableSubsetRandomSampler, Callable, None] = None, + sampler: Optional[Union[torch.utils.data.Sampler, Callable]] = None, is_train: bool = True, rng: Optional[torch.Generator] = None, device: Optional[str | torch.device] = None, iterations_per_epoch: Optional[int] = None, - tensorstore_cache_bytes: Optional[int] = None, - **kwargs, - ): - """ - Initializes the CellMapDataLoader with an optimized PyTorch DataLoader backend. - - Args: - ---- - dataset: The dataset to load. - classes: The classes to load. - batch_size: The batch size. - num_workers: The number of workers. - weighted_sampler: Whether to use a weighted sampler. - sampler: The sampler to use. - is_train: Whether the data is for training (shuffled). - rng: The random number generator. - device: The device to use ("cuda", "mps", or "cpu"). - iterations_per_epoch: Iterations per epoch for large datasets. - tensorstore_cache_bytes: Total TensorStore chunk-cache budget in bytes - shared across all worker processes. The budget is split evenly: - ``per_worker = tensorstore_cache_bytes // max(1, num_workers)``. - **Important:** When ``tensorstore_cache_bytes < num_workers``, each worker - receives a minimum of 1 byte (instead of 0, which TensorStore treats as - unlimited), so the effective aggregate cache may exceed the requested total. - To avoid this, ensure ``tensorstore_cache_bytes >= num_workers``. - Resolution order: explicit argument → ``CELLMAP_TENSORSTORE_CACHE_BYTES`` - env var → built-in default of 2 GiB. Bounding this value prevents - persistent worker processes from accumulating chunk data unboundedly - across epochs. **Note:** TensorStore ignores a limit of ``0`` (treats - it as unlimited); to minimize caching use a small positive value such - as ``1``. Pass ``0`` only if you explicitly want an unbounded cache. - **kwargs: Additional PyTorch DataLoader arguments. - """ + **kwargs: Any, + ) -> None: self.dataset = dataset - self.classes = classes if classes is not None else dataset.classes + self.classes = ( + classes if classes is not None else getattr(dataset, "classes", []) + ) self.batch_size = batch_size self.num_workers = num_workers - self.weighted_sampler = weighted_sampler - self.sampler = sampler self.is_train = is_train - self.rng = rng - - if platform.system() == "Windows" and num_workers > 0: - logger.warning( - "CellMapDataLoader: num_workers=%d on Windows. " - "The dataset uses a synchronous (single-thread) executor internally " - "so TensorStore reads are never dispatched to ThreadPoolExecutor " - "worker threads. If crashes persist, try num_workers=0.", - num_workers, - ) + self.iterations_per_epoch = iterations_per_epoch + self._kwargs = kwargs - # Set device + # Resolve device if device is None: if torch.cuda.is_available(): device = "cuda" @@ -167,253 +83,80 @@ def __init__( else: device = "cpu" self.device = device - self.iterations_per_epoch = iterations_per_epoch - # Bound TensorStore chunk-cache to prevent unbounded RAM growth in - # persistent worker processes (Linux fork, Windows/macOS spawn). - # Resolve from parameter → env var → built-in default. - if tensorstore_cache_bytes is None: - _env = os.environ.get("CELLMAP_TENSORSTORE_CACHE_BYTES") - if _env is not None: - try: - tensorstore_cache_bytes = int(_env) - except ValueError as exc: - raise ValueError( - "Invalid value for environment variable " - "CELLMAP_TENSORSTORE_CACHE_BYTES: " - f"{_env!r}. Expected an integer number of bytes." - ) from exc + # Build sampler + if sampler is not None: + self._sampler = sampler + elif weighted_sampler and is_train: + if hasattr(dataset, "get_crop_class_matrix"): + n_samples = iterations_per_epoch or len(dataset) + self._sampler: Any = ClassBalancedSampler(dataset, n_samples) else: - tensorstore_cache_bytes = _DEFAULT_TENSORSTORE_CACHE_BYTES - logger.info( - "TensorStore cache limit not set; applying default of %d bytes " - "(%.1f GiB). Override via tensorstore_cache_bytes= or " - "CELLMAP_TENSORSTORE_CACHE_BYTES env var.", - _DEFAULT_TENSORSTORE_CACHE_BYTES, - _DEFAULT_TENSORSTORE_CACHE_BYTES / 1024**3, - ) - if tensorstore_cache_bytes is not None and tensorstore_cache_bytes < 0: - raise ValueError( - f"tensorstore_cache_bytes must be >= 0 when set; got {tensorstore_cache_bytes}" - ) - self.tensorstore_cache_bytes = tensorstore_cache_bytes - - # NOTE: TensorStore silently treats total_bytes_limit=0 as "no limit" - # (it strips zero values from the context spec). Only positive values - # actually register a bound, so we skip context creation for 0. - if ( - tensorstore_cache_bytes is not None - and tensorstore_cache_bytes > 0 - and not isinstance(dataset, CellMapDatasetWriter) - ): - import tensorstore as ts - - effective_workers = max(1, num_workers) - per_worker_bytes = tensorstore_cache_bytes // effective_workers - if per_worker_bytes == 0 and tensorstore_cache_bytes > 0: - per_worker_bytes = 1 logger.warning( - "tensorstore_cache_bytes=%d with num_workers=%d results in " - "per-worker cache limit of 0 bytes, which TensorStore treats as " - "unlimited. Setting per-worker cache limit to 1 byte to enforce " - "a meaningful bound. To avoid this warning, set tensorstore_cache_bytes " - "to at least %d bytes for num_workers=%d.", - tensorstore_cache_bytes, - num_workers, - effective_workers, - effective_workers, - ) - bounded_ctx = ts.Context( - {"cache_pool": {"total_bytes_limit": per_worker_bytes}} - ) - _set_tensorstore_context(dataset, bounded_ctx) - logger.info( - "TensorStore cache bounded: total=%d bytes / %d worker(s) = %d bytes each", - tensorstore_cache_bytes, - effective_workers, - per_worker_bytes, - ) - elif tensorstore_cache_bytes == 0: - logger.warning( - "tensorstore_cache_bytes=0: TensorStore does not support a 0-byte " - "cache limit (it treats 0 as unlimited). The cache is unbounded. " - "To meaningfully limit caching, set a positive value.", - ) - - # Extract DataLoader parameters with optimized defaults - # pin_memory only works with CUDA, so default to True only when CUDA is available - # and device is CUDA - pin_memory_default = ( - torch.cuda.is_available() - and str(device).startswith("cuda") - and platform.system() != "Windows" - ) # pin_memory has issues on Windows with CUDA - self._pin_memory = kwargs.pop("pin_memory", pin_memory_default) - - # Validate pin_memory setting - if self._pin_memory and not str(device).startswith("cuda"): - logger.warning( - "pin_memory=True is only supported with CUDA. Disabling for %s.", - device, - ) - self._pin_memory = False - - self._persistent_workers = kwargs.pop("persistent_workers", num_workers > 0) - self._drop_last = kwargs.pop("drop_last", False) - - # Set prefetch_factor for better GPU utilization (default 2, increase for GPU training) - # Only applicable when num_workers > 0 - if num_workers > 0: - prefetch_factor = kwargs.pop("prefetch_factor", 2) - if not isinstance(prefetch_factor, int) or prefetch_factor < 1: - raise ValueError( - f"prefetch_factor must be a positive integer, got {prefetch_factor}" + "weighted_sampler=True but dataset does not implement " + "get_crop_class_matrix(); falling back to default sampler." ) - self._prefetch_factor = prefetch_factor + self._sampler = None else: - kwargs.pop("prefetch_factor", None) - self._prefetch_factor = None - - # Setup sampler - if self.sampler is None: - if iterations_per_epoch is not None or ( - weighted_sampler and len(self.dataset) > 2**24 - ): - if iterations_per_epoch is None: - raise ValueError( - "iterations_per_epoch must be specified for large datasets." - ) - if isinstance(self.dataset, CellMapDatasetWriter): - raise TypeError( - "CellMapDatasetWriter does not support random sampling." - ) - self.sampler = self.dataset.get_subset_random_sampler( - num_samples=iterations_per_epoch * batch_size, - weighted=weighted_sampler, - rng=self.rng, - ) - elif weighted_sampler and isinstance(self.dataset, CellMapMultiDataset): - self.sampler = self.dataset.get_weighted_sampler( - self.batch_size, self.rng - ) - - self.default_kwargs = kwargs - self.default_kwargs.update( - { - "pin_memory": self._pin_memory, - "persistent_workers": self._persistent_workers, - "drop_last": self._drop_last, - } + self._sampler = None + + # pin_memory: use on CUDA, skip otherwise to avoid issues + pin = kwargs.pop("pin_memory", str(device).startswith("cuda")) + + self.loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=(is_train and self._sampler is None), + sampler=self._sampler, + num_workers=num_workers, + collate_fn=self.collate_fn, + pin_memory=pin, + **self._kwargs, ) - if self._prefetch_factor is not None: - self.default_kwargs["prefetch_factor"] = self._prefetch_factor - - self._pytorch_loader = None - self.refresh() - @property - def loader(self) -> torch.utils.data.DataLoader | None: - """Return the DataLoader.""" - return self._pytorch_loader + # ------------------------------------------------------------------ + # Collation + # ------------------------------------------------------------------ + + @staticmethod + def collate_fn(batch: list[dict]) -> dict[str, Any]: + """Stack tensor values; preserve string / non-tensor items.""" + if not batch: + return {} + keys = batch[0].keys() + result: dict[str, Any] = {} + for key in keys: + values = [item[key] for item in batch] + if isinstance(values[0], torch.Tensor): + result[key] = torch.stack(values) + elif key == "__metadata__": + result[key] = values + else: + try: + result[key] = torch.stack(values) + except (TypeError, RuntimeError): + result[key] = values + return result - def __getitem__(self, indices: Union[int, Sequence[int]]) -> dict: - """Get an item from the DataLoader.""" - if isinstance(indices, int): - indices = [indices] - return self.collate_fn([self.dataset[index] for index in indices]) + # ------------------------------------------------------------------ + # DataLoader interface + # ------------------------------------------------------------------ - def __iter__(self): - """Create an iterator over the dataset.""" - if self._pytorch_loader is None: - self.refresh() - return iter(self._pytorch_loader) + def __iter__(self) -> Iterator[dict]: + return iter(self.loader) - def __len__(self) -> int | None: - """Return the number of batches per epoch.""" - if self._pytorch_loader is None: - return None - return len(self._pytorch_loader) + def __len__(self) -> int: + return len(self.loader) - def to(self, device: str | torch.device, non_blocking: bool = True): - """Move the dataset to the specified device.""" - self.dataset.to(device, non_blocking=non_blocking) + def to(self, device: str | torch.device) -> "CellMapDataLoader": + """Move the underlying dataset to *device* (no-op for CPU datasets).""" + if hasattr(self.dataset, "to"): + self.dataset.to(device) self.device = device return self - def refresh(self): - """Refresh the DataLoader with the current sampler state.""" - if self._pytorch_loader is not None: - # Explicitly drop the old loader before creating a new one. - # With persistent_workers=True, simply reassigning self._pytorch_loader - # keeps workers alive until GC; explicit deletion triggers immediate - # shutdown via DataLoader.__del__ → reference counting. - old_loader = self._pytorch_loader - self._pytorch_loader = None - del old_loader - if isinstance(self.sampler, MutableSubsetRandomSampler): - self.sampler.refresh() - - dataloader_sampler = None - shuffle = False - - if self.sampler is not None: - if isinstance(self.sampler, MutableSubsetRandomSampler): - dataloader_sampler = self.sampler - elif callable(self.sampler): - dataloader_sampler = self.sampler() - else: - dataloader_sampler = self.sampler - else: - shuffle = self.is_train - - dataloader_kwargs = { - "batch_size": self.batch_size, - "shuffle": shuffle if dataloader_sampler is None else False, - "num_workers": self.num_workers, - "collate_fn": self.collate_fn, - "pin_memory": self._pin_memory, - "drop_last": self._drop_last, - "generator": self.rng, - } - - # Add sampler if provided - if dataloader_sampler is not None: - dataloader_kwargs["sampler"] = dataloader_sampler - - # Add persistent_workers only if num_workers > 0 - if self.num_workers > 0: - dataloader_kwargs["persistent_workers"] = self._persistent_workers - if self._prefetch_factor is not None: - dataloader_kwargs["prefetch_factor"] = self._prefetch_factor - - # Add any additional kwargs - for key, value in self.default_kwargs.items(): - if key not in dataloader_kwargs: - dataloader_kwargs[key] = value - - dataloader_kwargs.pop("force_has_data", None) - - # Ensure that dataset is loaded onto CPU if pin_memory is used - if self._pin_memory: - self.dataset.to("cpu") - - self._pytorch_loader = torch.utils.data.DataLoader( - self.dataset, **dataloader_kwargs + def __repr__(self) -> str: + return ( + f"CellMapDataLoader(dataset={self.dataset!r}, " + f"batch_size={self.batch_size}, is_train={self.is_train})" ) - - def collate_fn(self, batch: Sequence) -> dict[str, torch.Tensor]: - """ - Collates a batch of samples into a single dictionary of tensors. - """ - outputs = {} - for b in batch: - for key, value in b.items(): - if key not in outputs: - outputs[key] = [] - outputs[key].append(value) - - for key, value in outputs.items(): - if key != "__metadata__": - outputs[key] = torch.stack(value) - - return outputs diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index 00b1cdd..91a6aff 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -1,1071 +1,421 @@ -# %% -import atexit -import functools -from functools import cached_property +"""CellMapDataset: PyTorch Dataset for CellMap OME-NGFF data.""" + +from __future__ import annotations + import logging import os -import platform -from concurrent.futures import Executor as _ConcurrentExecutor -from concurrent.futures import Future as _ConcurrentFuture -from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import cached_property from typing import Any, Callable, Mapping, Optional, Sequence import numpy as np -import tensorstore import torch -from numpy.typing import ArrayLike from torch.utils.data import Dataset -from .base_dataset import CellMapBaseDataset from .empty_image import EmptyImage from .image import CellMapImage -from .mutable_sampler import MutableSubsetRandomSampler -from .utils.read_limiter import MAX_CONCURRENT_READS, limit_tensorstore_reads -from .utils import get_sliced_shape, is_array_2D, min_redundant_inds, split_target_path +from .utils import split_target_path +from .utils.geometry import box_intersection, box_shape logger = logging.getLogger(__name__) -if logger.level == logging.NOTSET: - logger.setLevel(logging.INFO) - -# Cache system values to avoid repeated calls during dataset instantiation -_OS_NAME = platform.system() -_DATA_BACKEND = os.environ.get("CELLMAP_DATA_BACKEND", "tensorstore") - -# On Windows + TensorStore, calling tensorstore's .read().result() from a -# Python ThreadPoolExecutor worker thread causes a hard native crash -# (STATUS_STACK_BUFFER_OVERRUN / abort, exit code 0xC0000409). The -# limit_tensorstore_reads semaphore only prevents *concurrent* Python reads -# but does not fix the per-thread crash. The safest fix is to run all -# dataset __getitem__ work synchronously in the calling thread so that -# TensorStore is never invoked from a ThreadPoolExecutor worker on Windows. -_USE_IMMEDIATE_EXECUTOR = ( - _OS_NAME == "Windows" and _DATA_BACKEND.lower() == "tensorstore" -) - -# Per-process executor singleton. -# -# Using one ThreadPoolExecutor per CellMapDataset causes a thread explosion -# when many dataset instances exist inside DataLoader worker processes: -# Before: N_datasets × num_workers × max_workers threads -# After: num_workers × max_workers threads -# -# _PROCESS_EXECUTORS is keyed by PID so that forked child processes (DataLoader -# workers on Linux) automatically get their own fresh pool on first access. -# dict.setdefault() is atomic under CPython's GIL, avoiding the need for an -# explicit lock (and the fork-safety issues that come with one). -_PROCESS_EXECUTORS: dict[int, ThreadPoolExecutor] = {} - - -def _get_process_executor(max_workers: int) -> ThreadPoolExecutor: - """Return the per-process shared ThreadPoolExecutor, creating it on first call. - - The first CellMapDataset created in a process determines the pool size. - Subsequent datasets in the same process reuse the existing pool regardless - of their own max_workers setting. - """ - pid = os.getpid() - if pid not in _PROCESS_EXECUTORS: - executor = ThreadPoolExecutor(max_workers=max_workers) - # setdefault is atomic under the GIL; if two threads race here the - # loser's executor is discarded cleanly. - existing = _PROCESS_EXECUTORS.setdefault(pid, executor) - if existing is not executor: - executor.shutdown(wait=False) - return _PROCESS_EXECUTORS[pid] - - -def _shutdown_process_executor() -> None: - """Shut down the executor for the current process (registered via atexit).""" - pid = os.getpid() - executor = _PROCESS_EXECUTORS.pop(pid, None) - if executor is not None: - executor.shutdown(wait=True, cancel_futures=True) -atexit.register(_shutdown_process_executor) +def _make_rotation_matrix(axes: list[str], rotation_config: dict) -> np.ndarray | None: + """Build a rotation matrix from a per-axis angle dict (degrees). - -class _ImmediateExecutor(_ConcurrentExecutor): - """Drop-in for ThreadPoolExecutor that runs tasks in the calling thread. - - On Windows + TensorStore the real ThreadPoolExecutor causes native crashes. - This executor avoids that by executing every submitted callable synchronously - before returning, so the returned Future is already resolved. - ``as_completed`` handles pre-resolved futures correctly (yields immediately). - ``map`` is inherited from ``concurrent.futures.Executor`` and works correctly - because it calls ``submit`` internally (which returns pre-resolved futures). - ``shutdown`` is a no-op because there are no threads to join. + Returns a (n, n) orthonormal rotation matrix, or ``None`` if all angles + are zero. """ - - def submit(self, fn, /, *args, **kwargs): - f = _ConcurrentFuture() - try: - f.set_result(fn(*args, **kwargs)) - except Exception as exc: # noqa: BLE001 - f.set_exception(exc) - return f - - def shutdown(self, wait=True, *, cancel_futures=False): - pass # nothing to shut down - - -_IMMEDIATE_EXECUTOR: _ImmediateExecutor | None = ( - _ImmediateExecutor() if _USE_IMMEDIATE_EXECUTOR else None -) - - -# %% -class CellMapDataset(CellMapBaseDataset, Dataset): - """ - Subclasses PyTorch Dataset to load CellMap data for training. - - This class handles data sources for raw and ground truth data, including paths, - segmentation classes, and input/target array configurations. It retrieves data, - calculates class-specific pixel counts, and generates random crops for training. - It also combines images for different classes into a single output array, - which is useful for training multi-class segmentation networks. + n = len(axes) + R = np.eye(n) + for ax, angle_deg in rotation_config.items(): + if angle_deg == 0 or ax not in axes: + continue + theta = np.deg2rad(angle_deg) + # Determine the two axes to rotate in (perpendicular to *ax*) + ax_idx = axes.index(ax) + other = [i for i in range(n) if i != ax_idx] + if len(other) < 2: + continue + i, j = other[0], other[1] + Ri = np.eye(n) + Ri[i, i] = np.cos(theta) + Ri[i, j] = -np.sin(theta) + Ri[j, i] = np.sin(theta) + Ri[j, j] = np.cos(theta) + R = R @ Ri + return None if np.allclose(R, np.eye(n)) else R + + +class CellMapDataset(Dataset): + """PyTorch Dataset that reads patches from a single CellMap zarr dataset. + + Parameters + ---------- + raw_path: + Path to the raw EM zarr group. + target_path: + Path template for GT labels, with classes in brackets, e.g. + ``"/data/jrc.zarr/labels/[mito,er]"``. Each class occupies a + sub-group of the base path (``base/mito``, ``base/er``, …). + classes: + Segmentation classes to load. + input_arrays: + ``{array_name: {"shape": (z,y,x), "scale": (z,y,x)}}`` specs for + input patches. + target_arrays: + ``{array_name: {"shape": (z,y,x), "scale": (z,y,x)}}`` specs for + target patches. All classes share these specs. + pad: + Whether to pad reads that extend beyond array bounds with + ``pad_value`` (NaN by default). + spatial_transforms: + Augmentation config dict with optional keys ``"mirror"``, + ``"transpose"``, and ``"rotate"``. Example:: + + { + "mirror": {"z": True, "y": True, "x": True}, + "transpose": True, + "rotate": {"z": 45}, # max degrees + } + raw_value_transforms: + Callable applied to each raw input tensor. + target_value_transforms: + Callable (or ``{class: callable}`` dict) applied to each target + tensor. + class_relation_dict: + Stored for API compatibility; not used in inference currently. + force_has_data: + Skip the empty-data check when ``True``. + device: + Ignored — all tensors are returned on CPU. """ def __init__( self, raw_path: str, target_path: str, - classes: Sequence[str] | None, - input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], - target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]] | None = None, - spatial_transforms: Optional[Mapping[str, Mapping]] = None, # type: ignore + classes: Sequence[str], + input_arrays: Mapping[str, Mapping[str, Any]], + target_arrays: Mapping[str, Mapping[str, Any]], + pad: bool = False, + spatial_transforms: Optional[Mapping[str, Any]] = None, raw_value_transforms: Optional[Callable] = None, - target_value_transforms: Optional[ - Callable | Sequence[Callable] | Mapping[str, Callable] - ] = None, + target_value_transforms: Optional[Callable | Mapping[str, Callable]] = None, class_relation_dict: Optional[Mapping[str, Sequence[str]]] = None, - is_train: bool = False, - axis_order: str = "zyx", - context: Optional[tensorstore.Context] = None, # type: ignore - rng: Optional[torch.Generator] = None, force_has_data: bool = False, - empty_value: float | int = torch.nan, - pad: bool = True, - device: Optional[str | torch.device] = "cpu", - max_workers: Optional[int] = None, + device: Optional[str | torch.device] = None, ) -> None: - """Initializes the CellMapDataset class. - - Args: - ---- - raw_path: Path to the raw data. - target_path: Path to the ground truth data. - classes: List of classes for segmentation training. - input_arrays: Dictionary of input arrays with shape and scale. - target_arrays: Dictionary of target arrays with shape and scale. - spatial_transforms: Spatial transformations to apply. - raw_value_transforms: Transforms for raw data (e.g., normalization). - target_value_transforms: Transforms for target data (e.g., distance transform). - class_relation_dict: Defines mutual exclusivity between classes. - is_train: Whether the dataset is for training. - axis_order: The order of axes (e.g., "zyx"). - context: TensorStore context. - rng: Random number generator. - force_has_data: If True, forces the dataset to report having data. - empty_value: Value for empty data. - pad: Whether to pad data to match requested array shapes. - device: The device for torch tensors. Defaults to CPU. - max_workers: Max worker threads for data loading. - """ - super().__init__() self.raw_path = raw_path self.target_path = target_path - self.target_path_str, self.classes_with_path = split_target_path(target_path) - self.classes = classes if classes is not None else [] - self.raw_only = classes is None - self.input_arrays = input_arrays - self.target_arrays = target_arrays if target_arrays is not None else {} - self.spatial_transforms = spatial_transforms + self.classes = list(classes) + self.input_arrays = dict(input_arrays) + self.target_arrays = dict(target_arrays) + self.pad = pad + self.spatial_transforms_config = spatial_transforms self.raw_value_transforms = raw_value_transforms self.target_value_transforms = target_value_transforms self.class_relation_dict = class_relation_dict - self.is_train = is_train - self.axis_order = axis_order - self.context = context - self._rng = rng self.force_has_data = force_has_data - self.empty_value = empty_value - self.pad = pad - self._current_center = None - self._current_spatial_transforms = None - self.input_sources: dict[str, CellMapImage] = {} - self._device = ( - torch.device(device) if device is not None else torch.device("cpu") - ) - for array_name, array_info in self.input_arrays.items(): - self.input_sources[array_name] = CellMapImage( - self.raw_path, - "raw", - array_info["scale"], # type: ignore - tuple(map(int, array_info["shape"])), - value_transform=self.raw_value_transforms, - context=self.context, - pad=self.pad, - pad_value=0, - interpolation="linear", - device=self._device, - ) - self.target_sources = {} - self.has_data = force_has_data or ( - False if (len(self.target_arrays) > 0 and len(self.classes) > 0) else True - ) - for array_name, array_info in self.target_arrays.items(): - if classes is None: - self.target_sources[array_name] = CellMapImage( - self.raw_path, - "raw", - array_info["scale"], # type: ignore - tuple(map(int, array_info["shape"])), - value_transform=self.target_value_transforms, - context=self.context, - pad=self.pad, - pad_value=0, - interpolation="linear", - device=self._device, - ) - else: - self.target_sources[array_name] = self.get_target_array(array_info) - - self._executor = None - self._executor_pid = None - if max_workers is not None: - self._max_workers = max_workers - else: - # For HPC with I/O lag: prioritize I/O parallelism over CPU count - # Estimate based on number of concurrent I/O operations needed - estimated_concurrent_io = len(self.input_arrays) + len(self.target_arrays) - # Use at least 2 workers (input + target), cap at reasonable limit - # to avoid thread overhead while allowing parallel I/O requests - self._max_workers = min( - max(estimated_concurrent_io, 2), # At least 2 workers - int(os.environ.get("CELLMAP_MAX_WORKERS", 8)), # Cap at 8 by default - ) - - logger.info( - "CellMapDataset: OS=%s backend=%s max_workers=%d max_concurrent_reads=%s " - "inputs=%d targets=%d classes=%d", - _OS_NAME, - _DATA_BACKEND, - self._max_workers, - ( - str(MAX_CONCURRENT_READS) - if MAX_CONCURRENT_READS is not None - else "unlimited" - ), - len(self.input_arrays), - len(self.target_arrays), - len(self.classes), - ) - - @property - def executor(self) -> ThreadPoolExecutor | _ImmediateExecutor: - """ - Lazy accessor for the per-process shared executor. - - On Windows + TensorStore returns the module-level ``_ImmediateExecutor`` - singleton (runs tasks synchronously to avoid native crashes). - - On all other platforms returns the per-process ``ThreadPoolExecutor`` - singleton from ``_get_process_executor``. Using a process-level pool - instead of a per-dataset pool prevents thread explosion when many - ``CellMapDataset`` instances exist inside DataLoader worker processes. - - ``self._executor`` and ``self._executor_pid`` cache the result so the - PID lookup only happens when the PID changes (i.e. after fork). - """ - current_pid = os.getpid() - - if _USE_IMMEDIATE_EXECUTOR: - if self._executor is None or self._executor_pid != current_pid: - self._executor = _IMMEDIATE_EXECUTOR - self._executor_pid = current_pid - return self._executor # type: ignore[return-value] - - # Re-fetch the process executor if PID changed (post-fork child). - if self._executor_pid != current_pid or self._executor is None: - self._executor = _get_process_executor(self._max_workers) - self._executor_pid = current_pid - return self._executor - - def __str__(self) -> str: - return f"CellMapDataset(raw_path={self.raw_path}, target_path={self.target_path}, classes={self.classes})" - - def __del__(self): - """Release the cached executor reference. - - The shared per-process pool is shut down by the module-level atexit - handler (_shutdown_process_executor), not per-dataset, so we must not - call shutdown() here. - """ - self._executor = None + self._rng = np.random.default_rng() - def close(self) -> None: - """Release the cached executor reference. + # Parse target path to get template and annotated classes + gt_path_template, annotated_classes = split_target_path(target_path) + self._gt_path_template = gt_path_template + self._annotated_classes = set(annotated_classes) - The per-process shared pool is shut down via the module-level atexit - handler (_shutdown_process_executor). Individual datasets must not - shut it down, as doing so would break all other datasets in the process. - """ - self._executor = None - - def __new__( - cls, - raw_path: str, - target_path: str, - classes: Sequence[str] | None, - input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], - target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]] | None = None, - spatial_transforms: Optional[Mapping[str, Mapping]] = None, # type: ignore - raw_value_transforms: Optional[Callable] = None, - target_value_transforms: Optional[ - Callable | Sequence[Callable] | Mapping[str, Callable] - ] = None, - class_relation_dict: Optional[Mapping[str, Sequence[str]]] = None, - is_train: bool = False, - axis_order: str = "zyx", - context: Optional[tensorstore.Context] = None, # type: ignore - rng: Optional[torch.Generator] = None, - force_has_data: bool = False, - empty_value: float | int = torch.nan, - pad: bool = True, - device: Optional[str | torch.device] = None, - max_workers: Optional[int] = None, - ): - # If 2D arrays are requested without a slicing axis, create a - # multidataset with 3 datasets, each slicing along one axis. - if is_array_2D(input_arrays, summary=any) or is_array_2D( - target_arrays, summary=any - ): - from cellmap_data.multidataset import CellMapMultiDataset - - logger.info( - "2D arrays requested without slicing axis. Creating datasets " - "that each slice along one axis. If this is not intended, " - "specify the slicing axis in the input and target arrays." + # Build input sources + self.input_sources: dict[str, CellMapImage] = {} + for arr_name, arr_spec in self.input_arrays.items(): + self.input_sources[arr_name] = CellMapImage( + path=raw_path, + target_class=arr_name, + target_scale=arr_spec["scale"], + target_voxel_shape=arr_spec["shape"], + pad=pad, + value_transform=raw_value_transforms, ) - datasets = [] - for axis in range(3): - logger.debug("Creating dataset for axis %d", axis) - input_arrays_2d = { - name: { - "shape": get_sliced_shape( - tuple(map(int, array_info["shape"])), axis - ), - "scale": array_info["scale"], - } - for name, array_info in input_arrays.items() - } - target_arrays_2d = ( - { - name: { - "shape": get_sliced_shape( - tuple(map(int, array_info["shape"])), axis - ), - "scale": array_info["scale"], - } - for name, array_info in target_arrays.items() - } - if target_arrays is not None - else None - ) - logger.debug("Input arrays for axis %d: %s", axis, input_arrays_2d) - logger.debug("Target arrays for axis %d: %s", axis, target_arrays_2d) - dataset_instance = super(CellMapDataset, cls).__new__(cls) - dataset_instance.__init__( - raw_path, - target_path, - classes, - input_arrays_2d, - target_arrays_2d, - spatial_transforms=spatial_transforms, - raw_value_transforms=raw_value_transforms, - target_value_transforms=target_value_transforms, - class_relation_dict=class_relation_dict, - is_train=is_train, - axis_order=axis_order, - context=context, - rng=rng, - force_has_data=force_has_data, - empty_value=empty_value, + + # Build target sources: one CellMapImage or EmptyImage per class + # Use the first (and typically only) target array spec + first_target_spec = next(iter(target_arrays.values())) + self.target_sources: dict[str, CellMapImage | EmptyImage] = {} + for cls in self.classes: + cls_path = gt_path_template.format(label=cls) + value_tx = self._class_value_transform(cls) + if cls in self._annotated_classes and os.path.exists(cls_path): + self.target_sources[cls] = CellMapImage( + path=cls_path, + target_class=cls, + target_scale=first_target_spec["scale"], + target_voxel_shape=first_target_spec["shape"], pad=pad, - device=device, - max_workers=max_workers, + interpolation="nearest", + value_transform=value_tx, + ) + else: + self.target_sources[cls] = EmptyImage( + path=cls_path, + target_class=cls, + target_scale=first_target_spec["scale"], + target_voxel_shape=first_target_spec["shape"], ) - datasets.append(dataset_instance) - return CellMapMultiDataset( - classes=classes, - input_arrays=input_arrays, - target_arrays=target_arrays, - datasets=datasets, - ) - else: - return super().__new__(cls) - - def __reduce__(self): - """ - Support pickling for multiprocessing DataLoader. - """ - args = ( - self.raw_path, - self.target_path, - self.classes, - self.input_arrays, - self.target_arrays, - self.spatial_transforms, - self.raw_value_transforms, - self.target_value_transforms, - self.class_relation_dict, - self.is_train, - self.axis_order, - self.context, - self._rng, - self.force_has_data, - self.empty_value, - self.pad, - self.device.type if hasattr(self.device, "type") else self.device, - self._max_workers, - ) - return (self.__class__, args, self.__dict__) - @cached_property - def center(self) -> Mapping[str, float] | None: - """Returns the center of the dataset in world units.""" - if self.bounding_box is None: + def _class_value_transform(self, cls: str) -> Optional[Callable]: + """Return the value transform for a specific class.""" + if self.target_value_transforms is None: return None - return { - c: start + (stop - start) / 2 - for c, (start, stop) in self.bounding_box.items() - } + if callable(self.target_value_transforms): + return self.target_value_transforms + if isinstance(self.target_value_transforms, Mapping): + return self.target_value_transforms.get(cls) + return None + + # ------------------------------------------------------------------ + # Spatial properties + # ------------------------------------------------------------------ @cached_property - def largest_voxel_sizes(self) -> Mapping[str, float]: - """Returns the largest voxel size of the dataset.""" - largest_voxel_size = dict.fromkeys(self.axis_order, 0.0) - for source in list(self.input_sources.values()) + list( + def bounding_box(self) -> dict[str, tuple[float, float]] | None: + """Intersection of all source bounding boxes.""" + box = None + for src in list(self.input_sources.values()) + list( self.target_sources.values() ): - if isinstance(source, dict): - for _, source in source.items(): - if not hasattr(source, "scale") or source.scale is None: # type: ignore - continue - for c, size in source.scale.items(): # type: ignore - largest_voxel_size[c] = max(largest_voxel_size[c], size) - else: - if not hasattr(source, "scale") or source.scale is None: - continue - for c, size in source.scale.items(): - largest_voxel_size[c] = max(largest_voxel_size[c], size) - return largest_voxel_size + bb = src.bounding_box + if bb is None: + continue + box = bb if box is None else box_intersection(box, bb) + if box is None: + return None + return box @cached_property - def bounding_box(self) -> Mapping[str, list[float]]: - """Returns the bounding box of the dataset.""" - all_sources = list(self.input_sources.values()) + list( + def sampling_box(self) -> dict[str, tuple[float, float]] | None: + """Intersection of all source sampling boxes.""" + box = None + for src in list(self.input_sources.values()) + list( self.target_sources.values() - ) - # Flatten to individual CellMapImage objects - flat_sources = [] - for source in all_sources: - if isinstance(source, dict): - flat_sources.extend( - s for s in source.values() if hasattr(s, "bounding_box") - ) - elif hasattr(source, "bounding_box"): - flat_sources.append(source) - - # Prefetch bounding boxes in parallel (each triggers a zarr group open) - # Use self.executor to respect Windows+TensorStore immediate executor handling - boxes = list(self.executor.map(lambda s: s.bounding_box, flat_sources)) - - bounding_box: dict[str, list[float]] | None = None - for box in boxes: - bounding_box = self._get_box_intersection(box, bounding_box) - - if bounding_box is None: - logger.warning( - "Bounding box is None. This may cause errors during sampling." - ) - bounding_box = {c: [-np.inf, np.inf] for c in self.axis_order} - return bounding_box + ): + sb = src.sampling_box + if sb is None: + continue + box = sb if box is None else box_intersection(box, sb) + if box is None: + return None + return box @cached_property - def bounding_box_shape(self) -> Mapping[str, int]: - """Returns the shape of the bounding box of the dataset in voxels of the largest voxel size requested.""" - return self._get_box_shape(self.bounding_box) + def _target_scale(self) -> dict[str, float]: + """Scale of the first target array spec.""" + first = next(iter(self.target_arrays.values())) + scale_seq = first["scale"] + first_target_src = next(iter(self.target_sources.values())) + axes = first_target_src.axes + return {c: float(s) for c, s in zip(axes, scale_seq)} + + # ------------------------------------------------------------------ + # Dataset interface + # ------------------------------------------------------------------ - @cached_property - def sampling_box(self) -> Mapping[str, list[float]]: - """Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box).""" - all_sources = list(self.input_sources.values()) + list( + def __len__(self) -> int: + sb = self.sampling_box + if sb is None: + return 0 + grid = box_shape(sb, self._target_scale) + total = 1 + for v in grid.values(): + total *= v + return total + + def __getitem__(self, idx: int) -> dict[str, Any]: + center = self._idx_to_center(idx) + transforms = self._generate_spatial_transforms() + + # Set transforms on all sources + for src in list(self.input_sources.values()) + list( self.target_sources.values() - ) - flat_sources = [] - for source in all_sources: - if isinstance(source, dict): - flat_sources.extend( - s for s in source.values() if hasattr(s, "sampling_box") - ) - elif hasattr(source, "sampling_box"): - flat_sources.append(source) - - # Prefetch sampling boxes in parallel; bounding_box is already cached - # from the bounding_box property so these are cheap if called after it. - # Use self.executor to respect Windows+TensorStore immediate executor handling - boxes = list(self.executor.map(lambda s: s.sampling_box, flat_sources)) - - sampling_box: dict[str, list[float]] | None = None - for box in boxes: - sampling_box = self._get_box_intersection(box, sampling_box) - - if sampling_box is None: - logger.warning( - "Sampling box is None. This may cause errors during sampling." - ) - sampling_box = {c: [-np.inf, np.inf] for c in self.axis_order} - return sampling_box + ): + src.set_spatial_transforms(transforms) - @cached_property - def sampling_box_shape(self) -> dict[str, int]: - """Returns the shape of the sampling box of the dataset in voxels of the largest voxel size requested.""" - shape = self._get_box_shape(self.sampling_box) - if self.pad: - for c, size in shape.items(): - if size <= 0: - logger.debug( - "Sampling box for axis %s has size %d <= 0. " - "Setting to 1 and padding.", - c, - size, - ) - shape[c] = 1 - return shape + result: dict[str, Any] = {"idx": torch.tensor(idx)} - @cached_property - def size(self) -> int: - """Returns the size of the dataset in voxels of the largest voxel size requested.""" - return int( - np.prod([stop - start for start, stop in self.bounding_box.values()]) - ) + for name, src in self.input_sources.items(): + result[name] = src[center] - @cached_property - def class_counts(self) -> Mapping[str, Mapping[str, float]]: - """Returns the number of pixels for each class in the ground truth data, normalized by the resolution.""" - class_counts = {"totals": dict.fromkeys(self.classes, 0.0)} - class_counts["totals"].update({c + "_bg": 0.0 for c in self.classes}) - for array_name, sources in self.target_sources.items(): - class_counts[array_name] = {} - for label, source in sources.items(): - if isinstance(source, CellMapImage): - class_counts[array_name][label] = source.class_counts - class_counts[array_name][label + "_bg"] = source.bg_count - class_counts["totals"][label] += source.class_counts - class_counts["totals"][label + "_bg"] += source.bg_count - else: - class_counts[array_name][label] = 0.0 - class_counts[array_name][label + "_bg"] = 0.0 - return class_counts + for cls, src in self.target_sources.items(): + result[cls] = src[center] - @cached_property - def class_weights(self) -> dict[str, float]: - """Returns the class weights for the dataset based on the number of samples in each class. Classes without any samples will have a weight of 1.""" - if self.classes is None: - return {} + # Reset spatial transforms + for src in list(self.input_sources.values()) + list( + self.target_sources.values() + ): + src.set_spatial_transforms(None) + + return result + + def _idx_to_center(self, idx: int) -> dict[str, float]: + """Convert flat index to world centre coordinates.""" + sb = self.sampling_box + if sb is None: + raise IndexError(f"sampling_box is None for {self.raw_path!r}") + scale = self._target_scale + grid = box_shape(sb, scale) + axes = list(sb.keys()) + shape_tuple = tuple(grid[ax] for ax in axes) + vox_idx = np.unravel_index(int(idx) % max(1, len(self)), shape_tuple) return { - c: ( - self.class_counts["totals"][c + "_bg"] / self.class_counts["totals"][c] - if self.class_counts["totals"][c] != 0 - else 1 - ) - for c in self.classes - } - - @cached_property - def validation_indices(self) -> Sequence[int]: - """Returns the indices of the dataset that will produce non-overlapping tiles for use in validation, based on the largest requested voxel size.""" - chunk_size = { - c: np.ceil(size - self.sampling_box_shape[c]).astype(int) - for c, size in self.bounding_box_shape.items() - } - return self.get_indices(chunk_size) - - @property - def device(self) -> torch.device: - """Returns the device for the dataset.""" - return self._device - - def __len__(self) -> int: - """Returns the number of unique patches in the dataset.""" - if not self.has_data and not self.force_has_data: - return 0 - # Return at least 1 if the dataset has data, so that samplers can be initialized - return int(max(np.prod(list(self.sampling_box_shape.values())), 1)) - - def __getitem__(self, idx: ArrayLike) -> dict[str, torch.Tensor]: - """Returns a crop of the input and target data as PyTorch tensors, corresponding to the coordinate of the unwrapped index.""" - try: - idx_arr = np.array(idx) - if np.any(idx_arr < 0): - idx_arr[idx_arr < 0] = len(self) + idx_arr[idx_arr < 0] - - center_indices = np.unravel_index( - idx_arr, [self.sampling_box_shape[c] for c in self.axis_order] - ) - except ValueError: - logger.error( - "Index %s out of bounds for dataset of length %d", idx, len(self) - ) - logger.warning("Returning closest index in bounds") - center_indices = [self.sampling_box_shape[c] - 1 for c in self.axis_order] - center = { - c: float( - center_indices[i] * self.largest_voxel_sizes[c] - + self.sampling_box[c][0] - ) - for i, c in enumerate(self.axis_order) + ax: sb[ax][0] + (vox_idx[i] + 0.5) * scale[ax] for i, ax in enumerate(axes) } - self._current_idx = idx - self._current_center = center - spatial_transforms = self.generate_spatial_transforms() - - def get_input_array(array_name: str) -> tuple[str, torch.Tensor]: - self.input_sources[array_name].set_spatial_transforms(spatial_transforms) - with limit_tensorstore_reads(): - array = self.input_sources[array_name][center] - return array_name, array.squeeze()[None, ...] - - futures = [ - self.executor.submit(get_input_array, array_name) - for array_name in self.input_arrays.keys() - ] - - if self.raw_only: - - def get_target_array(array_name: str) -> tuple[str, torch.Tensor]: - self.target_sources[array_name].set_spatial_transforms( - spatial_transforms - ) - with limit_tensorstore_reads(): - array = self.target_sources[array_name][center] - return array_name, array.squeeze()[None, ...] - - else: - - def get_target_array(array_name: str) -> tuple[str, torch.Tensor]: - class_arrays = dict.fromkeys(self.classes) # Force order of classes - inferred_arrays = [] + def _generate_spatial_transforms(self) -> dict | None: + """Generate random spatial transforms from the config for one sample.""" + cfg = self.spatial_transforms_config + if not cfg: + return None - def get_label_array( - label: str, - ) -> tuple[str, torch.Tensor | None]: - source = self.target_sources[array_name].get(label) - if isinstance(source, (CellMapImage, EmptyImage)): - source.set_spatial_transforms(spatial_transforms) - with limit_tensorstore_reads(): - array = source[center].squeeze() - else: - array = None - return label, array + result: dict[str, Any] = {} - # Run label reads synchronously within this pool thread. - # Submitting sub-futures to the same shared pool and blocking - # on as_completed() inside a pool thread causes deadlock when - # all pool slots are occupied by blocking get_target_array tasks. - for label in self.classes: - lbl, array = get_label_array(label) - if array is not None: - class_arrays[lbl] = array + # Mirror + mirror_cfg = cfg.get("mirror") + if mirror_cfg: + if isinstance(mirror_cfg, dict): + result["mirror"] = { + ax: bool(self._rng.random() < 0.5) if enabled else False + for ax, enabled in mirror_cfg.items() + } + else: + axes = next(iter(self.input_sources.values())).axes + result["mirror"] = {ax: bool(self._rng.random() < 0.5) for ax in axes} + + # Transpose + if cfg.get("transpose"): + axes = next(iter(self.input_sources.values())).axes + n = len(axes) + perm = list(self._rng.permutation(n)) + result["transpose"] = perm + + # Rotate + rotate_cfg = cfg.get("rotate") + if rotate_cfg: + axes = next(iter(self.input_sources.values())).axes + if isinstance(rotate_cfg, dict): + # e.g. {"z": 45} → random angle in [-45, 45] degrees + angle_dict: dict[str, float] = {} + for ax, max_angle in rotate_cfg.items(): + if isinstance(max_angle, (list, tuple)): + lo, hi = max_angle else: - inferred_arrays.append(lbl) - - empty_array = self.get_empty_store( - self.target_arrays[array_name], device=self.device - ) - - def infer_label_array(label: str) -> tuple[str, torch.Tensor]: - array = empty_array.clone() - other_labels = self.target_sources[array_name].get(label, []) - for other_label in other_labels: - other_array = class_arrays.get(other_label) - if other_array is not None: - mask = other_array > 0 - array[mask] = 0 - return label, array - - for label in inferred_arrays: - lbl, array = infer_label_array(label) - class_arrays[lbl] = array - - stacked_arrays = [] - for label in self.classes: - arr = class_arrays.get(label) - if arr is not None: - stacked_arrays.append( - arr.to(self.device, non_blocking=True) - if arr.device != self.device - else arr - ) - - array = torch.stack(stacked_arrays) - if array.shape[0] != len(self.classes): - raise ValueError( - f"Target array {array_name} has {array.shape[0]} classes, " - f"but {len(self.classes)} were expected." - ) - return array_name, array + lo, hi = -float(max_angle), float(max_angle) + angle_dict[ax] = float(self._rng.uniform(lo, hi)) + R = _make_rotation_matrix(axes, angle_dict) + if R is not None: + result["rotation_matrix"] = R - futures += [ - self.executor.submit(get_target_array, array_name) - for array_name in self.target_arrays.keys() - ] + return result if result else None - outputs: dict[str, Any] = { - "__metadata__": self.metadata, - } - - for future in as_completed(futures): - array_name, array = future.result() - outputs[array_name] = array - - return outputs - - @property - def metadata(self) -> dict[str, Any]: - """Returns metadata about the dataset.""" - metadata = { - "raw_path": self.raw_path, - "current_center": self._current_center, - "current_idx": self._current_idx, - } + # ------------------------------------------------------------------ + # Sampling utilities + # ------------------------------------------------------------------ - if self._current_spatial_transforms is not None: - metadata["current_spatial_transforms"] = self._current_spatial_transforms - if not self.raw_only: - metadata["target_path_str"] = self.target_path_str - metadata["class_weights"] = self.class_weights - return metadata + def get_indices(self, chunk_size: Mapping[str, float]) -> list[int]: + """Flat indices that tile the sampling box without overlap. - def __repr__(self) -> str: - """Returns a string representation of the dataset.""" - return ( - f"CellMapDataset(\n\tRaw path: {self.raw_path}\n\t" - f"GT path(s): {self.target_path}\n\tClasses: {self.classes})" - ) - - def get_empty_store( - self, array_info: Mapping[str, Sequence[int | float]], device: torch.device - ) -> torch.Tensor: - """Returns an empty store, based on the requested array.""" - shape = tuple(map(int, array_info["shape"])) - empty_store = torch.ones(shape, device=device) * self.empty_value - return empty_store.squeeze() - - def get_target_array( - self, array_info: Mapping[str, Sequence[int | float]] - ) -> dict[str, CellMapImage | EmptyImage | Sequence[str]]: - """ - Returns a target array source for the dataset. - - Creates a dictionary of image sources for each class. If ground truth - data is missing for a class, it can be inferred from other mutually - exclusive classes. + Parameters + ---------- + chunk_size: + World-space tile size per axis in nm (e.g. target output size). """ - empty_store = self.get_empty_store(array_info, device=torch.device("cpu")) - target_array = {} - for i, label in enumerate(self.classes): - target_array[label] = self.get_label_array( - label, i, array_info, empty_store + sb = self.sampling_box + if sb is None: + return [] + scale = self._target_scale + grid = box_shape(sb, scale) + axes = list(sb.keys()) + + # Tile with the given chunk size + chunk_grid = { + ax: max( + 1, int(round((sb[ax][1] - sb[ax][0]) / chunk_size.get(ax, scale[ax]))) ) - - for label in self.classes: - if isinstance(target_array.get(label), (CellMapImage, EmptyImage)): - continue - - is_empty = True - related_labels = target_array.get(label) - if isinstance(related_labels, list): - for other_label in related_labels: - if isinstance(target_array.get(other_label), CellMapImage): - is_empty = False - break - if is_empty: - shape = tuple(map(int, array_info["shape"])) - target_array[label] = EmptyImage( - label, array_info["scale"], shape, empty_store # type: ignore - ) - - return target_array - - def get_label_array( - self, - label: str, - i: int, - array_info: Mapping[str, Sequence[int | float]], - empty_store: torch.Tensor, - ) -> CellMapImage | EmptyImage | Sequence[str]: - """Returns a target array source for a specific class in the dataset.""" - if label in self.classes_with_path: - value_transform: Callable | None = None - if isinstance(self.target_value_transforms, dict): - value_transform = self.target_value_transforms.get(label) - elif isinstance(self.target_value_transforms, list): - value_transform = self.target_value_transforms[i] - elif callable(self.target_value_transforms): - value_transform = self.target_value_transforms - - array = CellMapImage( - self.target_path_str.format(label=label), - label, - array_info["scale"], # type: ignore - tuple(map(int, array_info["shape"])), - value_transform=value_transform, - context=self.context, - pad=self.pad, - pad_value=self.empty_value, - interpolation="nearest", - device=self._device, - ) - if not self.has_data and not self.force_has_data: - self.has_data = array.class_counts > 0 - logger.debug(f"{str(self)} has data: {self.has_data}") - else: - if ( - self.class_relation_dict is not None - and label in self.class_relation_dict - ): - array = self.class_relation_dict[label] - else: - shape = tuple(map(int, array_info["shape"])) - array = EmptyImage( - label, array_info["scale"], shape, empty_store, device=self._device # type: ignore - ) - return array - - def _get_box_shape(self, source_box: Mapping[str, list[float]]) -> dict[str, int]: - """Returns the shape of the box in voxels of the largest voxel size requested.""" - box_shape = {} - for c, (start, stop) in source_box.items(): - size = stop - start - size /= self.largest_voxel_sizes[c] - box_shape[c] = int(np.floor(size)) - return box_shape - - def _get_box_intersection( - self, - source_box: Mapping[str, list[float]] | None, - current_box: dict[str, list[float]] | None, - ) -> dict[str, list[float]] | None: - """Returns the intersection of the source and current boxes.""" - if source_box is None: - return current_box - if current_box is None: - return {k: v[:] for k, v in source_box.items()} - - result_box = {k: v[:] for k, v in current_box.items()} - for c, (start, stop) in source_box.items(): - if stop <= start: - raise ValueError(f"Invalid box: start={start}, stop={stop}") - result_box[c][0] = max(result_box[c][0], start) - result_box[c][1] = min(result_box[c][1], stop) - return result_box - - def verify(self) -> bool: - """Verifies that the dataset is valid to draw samples from.""" - try: - return len(self) > 0 - except Exception as e: - logger.warning("Dataset verification failed: %s", e) - return False - - def get_indices(self, chunk_size: Mapping[str, int]) -> Sequence[int]: - """Returns the indices of the dataset that will tile the dataset according to the chunk_size.""" - # TODO: ADD TEST - # Get padding per axis - indices_dict = {} - for c, size in chunk_size.items(): - if size <= 0: - indices_dict[c] = np.array([0], dtype=int) - else: - indices_dict[c] = np.arange( - 0, self.sampling_box_shape[c], size, dtype=int - ) - + for ax in axes + } indices = [] - shape_values = [self.sampling_box_shape[c] for c in self.axis_order] - for i in np.ndindex(*[len(indices_dict[c]) for c in self.axis_order]): - index = [indices_dict[c][j] for c, j in zip(self.axis_order, i)] - index = np.ravel_multi_index(index, shape_values) - indices.append(index) + shape_tuple = tuple(grid[ax] for ax in axes) + chunk_tuple = tuple(chunk_grid[ax] for ax in axes) + # Step through the sampling box in chunks + for chunk_idx in np.ndindex(*chunk_tuple): + vox_idx = tuple( + int(chunk_idx[i] * shape_tuple[i] / chunk_tuple[i]) + for i in range(len(axes)) + ) + flat = int(np.ravel_multi_index(vox_idx, shape_tuple)) + indices.append(flat) return indices - def to( - self, device: str | torch.device, non_blocking: bool = True - ) -> "CellMapDataset": - """Sets the device for the dataset.""" - self._device = torch.device(device) - device_str = str(self._device) - all_sources = list(self.input_sources.values()) + list( - self.target_sources.values() - ) - for source in all_sources: - if isinstance(source, dict): - for sub_source in source.values(): - if hasattr(sub_source, "to"): - sub_source.to(device_str, non_blocking=non_blocking) - elif hasattr(source, "to"): - source.to(device_str, non_blocking=non_blocking) - return self - - def generate_spatial_transforms(self) -> Optional[Mapping[str, Any]]: - """ - Generates random spatial transforms for training. + def get_crop_class_matrix(self) -> np.ndarray: + """Return a ``[1, n_classes]`` boolean row for ClassBalancedSampler. - Available transforms: - - "mirror": {"axes": {"x": 0.5, "y": 0.5}} - - "transpose": {"axes": ["x", "z"]} - - "rotate": {"axes": {"z": [-90, 90]}} + True where the class is annotated (non-empty CellMapImage). """ - if not self.is_train or self.spatial_transforms is None: - return None + row = np.array( + [ + not isinstance(self.target_sources.get(cls), EmptyImage) + for cls in self.classes + ], + dtype=bool, + ).reshape(1, -1) + return row + + # ------------------------------------------------------------------ + # Class counts + # ------------------------------------------------------------------ - spatial_transforms: dict[str, Any] = {} - for transform, params in self.spatial_transforms.items(): - if transform == "mirror": - mirrored_axes = [ - axis - for axis, prob in params["axes"].items() - if torch.rand(1, generator=self._rng).item() < prob - ] - if mirrored_axes: - spatial_transforms[transform] = mirrored_axes - elif transform == "transpose": - axes = {axis: i for i, axis in enumerate(self.axis_order)} - permuted_axes = [axes[a] for a in params["axes"]] - permuted_indices = torch.randperm( - len(permuted_axes), generator=self._rng - ) - shuffled_axes = [permuted_axes[i] for i in permuted_indices] - axes.update( - {axis: shuffled_axes[i] for i, axis in enumerate(params["axes"])} - ) - spatial_transforms[transform] = axes - elif transform == "rotate": - rotated_axes = {} - for axis, limits in params["axes"].items(): - angle = ( - torch.rand(1, generator=self._rng).item() - * (limits[1] - limits[0]) - + limits[0] - ) - rotated_axes[axis] = angle - if rotated_axes: - spatial_transforms[transform] = rotated_axes + @property + def class_counts(self) -> dict[str, Any]: + """Aggregate per-class foreground voxel counts from all target sources.""" + totals: dict[str, int] = {} + for cls in self.classes: + src = self.target_sources.get(cls) + if src is not None: + counts = src.class_counts + totals[cls] = counts.get(cls, 0) else: - raise ValueError(f"Unknown spatial transform: {transform}") + totals[cls] = 0 + return {"totals": totals} - self._current_spatial_transforms = spatial_transforms - return spatial_transforms + # ------------------------------------------------------------------ + # Misc + # ------------------------------------------------------------------ - def set_raw_value_transforms(self, transforms: Callable) -> None: - """Sets the raw value transforms for the dataset.""" - self.raw_value_transforms = transforms - for source in self.input_sources.values(): - source.value_transform = transforms + def verify(self) -> bool: + """Return True if the dataset has at least one valid sample.""" + return len(self) > 0 or self.force_has_data - def set_target_value_transforms(self, transforms: Callable) -> None: - """Sets the ground truth value transforms for the dataset.""" + def set_raw_value_transforms(self, transforms: Optional[Callable]) -> None: + self.raw_value_transforms = transforms + for src in self.input_sources.values(): + src.value_transform = transforms + # Reset cached properties that depend on sources + self.__dict__.pop("bounding_box", None) + self.__dict__.pop("sampling_box", None) + + def set_target_value_transforms( + self, transforms: Optional[Callable | Mapping[str, Callable]] + ) -> None: self.target_value_transforms = transforms - for sources in self.target_sources.values(): - for source in sources.values(): - if isinstance(source, CellMapImage): - source.value_transform = transforms - - def reset_arrays(self, array_type: str = "target") -> None: - """Resets the specified arrays for the dataset.""" - if array_type.lower() == "input": - self.input_sources = {} - for array_name, array_info in self.input_arrays.items(): - self.input_sources[array_name] = CellMapImage( - self.raw_path, - "raw", - array_info["scale"], # type: ignore - tuple(map(int, array_info["shape"])), - value_transform=self.raw_value_transforms, - context=self.context, - pad=self.pad, - pad_value=0, - ) - elif array_type.lower() == "target": - self.target_sources = {} - self.has_data = False - for array_name, array_info in self.target_arrays.items(): - self.target_sources[array_name] = self.get_target_array(array_info) - else: - raise ValueError(f"Unknown dataset array type: {array_type}") - - def get_random_subset_sampler( - self, num_samples: int, rng: Optional[torch.Generator] = None, **kwargs: Any - ) -> MutableSubsetRandomSampler: - """ - Returns a random sampler that yields exactly `num_samples` indices from this subset. - - If `num_samples` ≤ total number of available indices, samples without replacement. - - If `num_samples` > total number of available indices, samples with replacement using repeated shuffles to minimize duplicates. - """ - indices_generator = functools.partial( - self.get_random_subset_indices, num_samples, rng, **kwargs - ) - - return MutableSubsetRandomSampler(indices_generator) + for cls, src in self.target_sources.items(): + if isinstance(src, CellMapImage): + src.value_transform = self._class_value_transform(cls) - def get_random_subset_indices( - self, num_samples: int, rng: Optional[torch.Generator] = None, **kwargs: Any - ) -> Sequence[int]: - inds = min_redundant_inds(len(self), num_samples, rng=rng) - return inds.tolist() + def set_spatial_transforms(self, transforms: Optional[Mapping[str, Any]]) -> None: + self.spatial_transforms_config = transforms - def get_subset_random_sampler( - self, - num_samples: int, - weighted: bool = False, - rng: Optional[torch.Generator] = None, - ) -> MutableSubsetRandomSampler: - """ - Returns a subset random sampler for the dataset. - - Args: - ---- - num_samples: The number of samples. - weighted: Whether to use weighted sampling. - rng: The random number generator. - - Returns: - ------- - A subset random sampler. - """ - if num_samples is None: - num_samples = len(self) * 2 - - if weighted: - raise NotImplementedError("Weighted sampling is not yet implemented.") - else: - indices_generator = lambda: min_redundant_inds( - len(self), num_samples, rng=rng - ) - - return MutableSubsetRandomSampler(indices_generator, rng=rng) + def to(self, device: str | torch.device) -> "CellMapDataset": + """No-op for API compatibility (tensors returned on CPU).""" + return self - @staticmethod - def empty() -> "CellMapDataset": - """Creates an empty dataset.""" - # Directly instantiate to bypass __new__ logic - instance = super(CellMapDataset, CellMapDataset).__new__(CellMapDataset) - instance.__init__("", "", [], {}, {}, force_has_data=False) - instance.has_data = False - # Set cached_property value directly in __dict__ to bypass computation - instance.__dict__["sampling_box_shape"] = {c: 0 for c in instance.axis_order} - return instance + def __repr__(self) -> str: + return ( + f"CellMapDataset(raw={self.raw_path!r}, " + f"classes={self.classes}, len={len(self)})" + ) diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index d5dcca7..780dbfa 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -1,31 +1,52 @@ -# %% +"""CellMapDatasetWriter: writes model predictions to zarr.""" + +from __future__ import annotations + import logging from functools import cached_property -from typing import Callable, Mapping, Optional, Sequence +from typing import Any, Callable, Mapping, Optional, Sequence import numpy as np -import tensorstore import torch from torch.utils.data import Dataset, Subset -from upath import UPath from .image import CellMapImage from .image_writer import ImageWriter +from .utils.geometry import box_shape, box_union logger = logging.getLogger(__name__) -if logger.level == logging.NOTSET: - logger.setLevel(logging.INFO) -# Special keys that should not be written to disk -_METADATA_KEYS = {"idx"} +_SKIP_KEYS = frozenset({"idx", "__metadata__"}) -# %% class CellMapDatasetWriter(Dataset): - """ - Writes a dataset to disk in a format readable by CellMapDataset. - - This is useful for saving model predictions to disk. + """Writes model predictions back into zarr arrays at world coordinates. + + Parameters + ---------- + raw_path: + Path to the raw EM zarr group (for reading input patches). + target_path: + Base path for output zarr groups (class sub-groups are written + under this path). + classes: + Classes to write. + input_arrays: + Input array specs for reading raw EM patches. + target_arrays: + Output array specs (shape/scale) for each prediction. + target_bounds: + ``{array_name: {axis: (min_nm, max_nm)}}`` bounding boxes for + each target array. Determines the spatial extent of the output. + overwrite: + If ``True``, existing output data is overwritten. + device: + Ignored (API compatibility). + raw_value_transforms: + Value transform applied to raw input patches. + model_classes: + Full list of classes the model was trained on (superset of + *classes*). Used to map model outputs to the correct channel. """ def __init__( @@ -33,490 +54,264 @@ def __init__( raw_path: str, target_path: str, classes: Sequence[str], - input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], - target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], - target_bounds: Mapping[str, Mapping[str, list[float]]], - raw_value_transforms: Optional[Callable] = None, - axis_order: str = "zyx", - context: Optional[tensorstore.Context] = None, # type: ignore - rng: Optional[torch.Generator] = None, - empty_value: float | int = 0, + input_arrays: Mapping[str, Mapping[str, Any]], + target_arrays: Mapping[str, Mapping[str, Any]], + target_bounds: Optional[Mapping[str, Mapping[str, Sequence[float]]]] = None, overwrite: bool = False, device: Optional[str | torch.device] = None, + raw_value_transforms: Optional[Callable] = None, + model_classes: Optional[Sequence[str]] = None, + axis_order: str = "zyx", + context: Optional[Any] = None, # ignored – API compat + **kwargs: Any, ) -> None: - """Initializes the CellMapDatasetWriter. - - Args: - ---- - raw_path: Full path to the raw data Zarr, excluding multiscale level. - target_path: Full path to the ground truth Zarr, excluding class name. - classes: The classes in the dataset. - input_arrays: Input arrays for processing, with shape, scale, and - optional scale_level. - target_arrays: Target arrays to write, with the same format as input_arrays. - target_bounds: Bounding boxes for each target array in world units. - raw_value_transforms: Value transforms for raw data. - axis_order: Order of axes (e.g., "zyx"). - context: TensorStore context. - rng: Random number generator. - empty_value: Value for empty data. - overwrite: Whether to overwrite existing data. - device: Device for torch tensors ("cuda", "mps", or "cpu"). - """ self.raw_path = raw_path self.target_path = target_path - self.classes = classes - self.input_arrays = input_arrays - self.target_arrays = target_arrays - self.target_bounds = target_bounds + self.classes = list(classes) + self.input_arrays = dict(input_arrays) + self.target_arrays = dict(target_arrays) + self.target_bounds = dict(target_bounds) if target_bounds else {} + self.overwrite = overwrite self.raw_value_transforms = raw_value_transforms + self.model_classes = list(model_classes) if model_classes else list(classes) self.axis_order = axis_order - self.context = context - self._rng = rng - self.empty_value = empty_value - self.overwrite = overwrite - self._current_center = None - self._current_idx = None + + # Build input sources self.input_sources: dict[str, CellMapImage] = {} - for array_name, array_info in self.input_arrays.items(): - self.input_sources[array_name] = CellMapImage( - self.raw_path, - "raw", - array_info["scale"], - array_info["shape"], # type: ignore - value_transform=self.raw_value_transforms, - context=self.context, + for arr_name, arr_spec in self.input_arrays.items(): + self.input_sources[arr_name] = CellMapImage( + path=raw_path, + target_class=arr_name, + target_scale=arr_spec["scale"], + target_voxel_shape=arr_spec["shape"], pad=True, - pad_value=0, + pad_value=0.0, interpolation="linear", + value_transform=raw_value_transforms, ) + + # Build output ImageWriter instances per (target_array, class) self.target_array_writers: dict[str, dict[str, ImageWriter]] = {} - for array_name, array_info in self.target_arrays.items(): - self.target_array_writers[array_name] = self.get_target_array_writer( - array_name, array_info - ) - self._device: str | torch.device = device if device is not None else "cpu" - if device is not None: - self.to(device, non_blocking=True) + for arr_name, arr_spec in self.target_arrays.items(): + bounds = self.target_bounds.get(arr_name, {}) + self.target_array_writers[arr_name] = {} + for cls in self.classes: + writer_path = f"{target_path}/{cls}" + self.target_array_writers[arr_name][cls] = ImageWriter( + path=writer_path, + target_class=cls, + scale=arr_spec["scale"], + bounding_box=bounds, + write_voxel_shape=arr_spec["shape"], + axis_order=axis_order, + overwrite=overwrite, + ) - @cached_property - def center(self) -> Mapping[str, float] | None: - """Returns the center of the dataset in world units.""" - if self.bounding_box is None: - return None - return { - c: start + (stop - start) / 2 - for c, (start, stop) in self.bounding_box.items() - } + # ------------------------------------------------------------------ + # Spatial properties + # ------------------------------------------------------------------ @cached_property - def smallest_voxel_sizes(self) -> Mapping[str, float]: - """Returns the smallest voxel size of the dataset.""" - smallest_voxel_size = {c: np.inf for c in self.axis_order} - all_sources = list(self.input_sources.values()) + list( - self.target_array_writers.values() - ) - for source in all_sources: - if isinstance(source, dict): - for sub_source in source.values(): - if hasattr(sub_source, "scale") and sub_source.scale is not None: - for c, size in sub_source.scale.items(): - smallest_voxel_size[c] = min(smallest_voxel_size[c], size) - elif hasattr(source, "scale") and source.scale is not None: - for c, size in source.scale.items(): - smallest_voxel_size[c] = min(smallest_voxel_size[c], size) - return smallest_voxel_size + def bounding_box(self) -> dict[str, tuple[float, float]] | None: + """Union of all target bounds.""" + result = None + for bounds in self.target_bounds.values(): + box = {ax: (float(bounds[ax][0]), float(bounds[ax][1])) for ax in bounds} + result = box if result is None else box_union(result, box) + return result @cached_property - def smallest_target_array(self) -> Mapping[str, float]: - """Returns the smallest target array in world units.""" - smallest_target_array = {c: np.inf for c in self.axis_order} - for writer in self.target_array_writers.values(): - for _, writer in writer.items(): - for c, size in writer.write_world_shape.items(): - smallest_target_array[c] = min(smallest_target_array[c], size) - return smallest_target_array + def _write_scale(self) -> dict[str, float]: + """Scale of the first target array.""" + first_spec = next(iter(self.target_arrays.values())) + axes = list(self.axis_order[-len(first_spec["scale"]) :]) + return {c: float(s) for c, s in zip(axes, first_spec["scale"])} @cached_property - def bounding_box(self) -> Mapping[str, list[float]]: - """Returns the bounding box inclusive of all the target images.""" - bounding_box = None - for current_box in self.target_bounds.values(): - bounding_box = self._get_box_union(current_box, bounding_box) - if bounding_box is None: - logger.warning( - "Bounding box is None. This may cause errors during sampling." - ) - bounding_box = {c: [-np.inf, np.inf] for c in self.axis_order} - return bounding_box + def _write_voxel_shape(self) -> dict[str, int]: + first_spec = next(iter(self.target_arrays.values())) + axes = list(self.axis_order[-len(first_spec["shape"]) :]) + return {c: int(t) for c, t in zip(axes, first_spec["shape"])} @cached_property - def bounding_box_shape(self) -> Mapping[str, int]: - """Returns the shape of the bounding box of the dataset in voxels of the smallest voxel size requested.""" - return self._get_box_shape(self.bounding_box) + def sampling_box(self) -> dict[str, tuple[float, float]] | None: + """Bounding box shrunk by half the write patch size.""" + bb = self.bounding_box + if bb is None: + return None + result: dict[str, tuple[float, float]] = {} + half = { + ax: self._write_scale[ax] * self._write_voxel_shape[ax] / 2.0 + for ax in self._write_scale + } + for ax in bb: + h = half.get(ax, 0.0) + lo = bb[ax][0] + h + hi = bb[ax][1] - h + if lo >= hi: + return None + result[ax] = (lo, hi) + return result @cached_property - def sampling_box(self) -> Mapping[str, list[float]]: - """Returns the sampling box of the dataset (i.e. where centers should be drawn from and to fully sample within the bounding box).""" - sampling_box = None - for array_name, array_info in self.target_arrays.items(): - padding = { - c: np.ceil((shape * scale) / 2) - for c, shape, scale in zip( - self.axis_order, array_info["shape"], array_info["scale"] - ) - } - this_box = { - c: [bounds[0] + padding[c], bounds[1] - padding[c]] - for c, bounds in self.target_bounds[array_name].items() + def writer_indices(self) -> list[int]: + """Non-overlapping tile indices covering the sampling box.""" + return self.get_indices( + { + ax: self._write_scale[ax] * self._write_voxel_shape[ax] + for ax in self._write_scale } - sampling_box = self._get_box_union(this_box, sampling_box) - if sampling_box is None: - logger.warning( - "Sampling box is None. This may cause errors during sampling." - ) - sampling_box = {c: [-np.inf, np.inf] for c in self.axis_order} - return sampling_box - - @cached_property - def sampling_box_shape(self) -> dict[str, int]: - """Returns the shape of the sampling box.""" - shape = self._get_box_shape(self.sampling_box) - for c, size in shape.items(): - if size <= 0: - logger.debug( - "Sampling box for axis %s has size %d <= 0. " - "Setting to 1 and padding.", - c, - size, - ) - shape[c] = 1 - return shape - - def __len__(self) -> int: - """Returns the number of samples in the dataset.""" - return int(np.prod(list(self.sampling_box_shape.values()))) - - @cached_property - def size(self) -> int: - """Returns the number of samples in the dataset.""" - return int( - np.prod([stop - start for start, stop in self.bounding_box.values()]) ) - @cached_property - def writer_indices(self) -> Sequence[int]: - """Returns the indices of the dataset that will produce non-overlapping tiles for use in writer, based on the smallest requested target array.""" - return self.get_indices(self.smallest_target_array) - @cached_property def blocks(self) -> Subset: - """A subset of the validation datasets, tiling the validation datasets with non-overlapping blocks.""" + """Subset of this dataset covering non-overlapping write tiles.""" return Subset(self, self.writer_indices) - def loader( - self, - batch_size: int = 1, - num_workers: int = 0, - **kwargs, - ): - """Returns a CellMapDataLoader for the dataset.""" - from .dataloader import CellMapDataLoader - from .subdataset import CellMapSubset - - return CellMapDataLoader( - CellMapSubset(self, self.writer_indices), - batch_size=batch_size, - num_workers=num_workers, - device=self.device, - is_train=False, - sampler=None, - **kwargs, - ).loader - - @property - def device(self) -> str | torch.device: - """Returns the device for the dataset.""" - return self._device + def __len__(self) -> int: + sb = self.sampling_box + if sb is None: + return 0 + grid = box_shape(sb, self._write_scale) + total = 1 + for v in grid.values(): + total *= v + return total def get_center(self, idx: int) -> dict[str, float]: - """ - Gets the center coordinates for a given index. - - Args: - ---- - idx: The index to get the center for. + """World centre coordinates for flat index *idx*.""" + sb = self.sampling_box + if sb is None: + raise IndexError("sampling_box is None") + scale = self._write_scale + grid = box_shape(sb, scale) + axes = list(sb.keys()) + shape_tuple = tuple(grid[ax] for ax in axes) + vox_idx = np.unravel_index(int(idx) % max(1, len(self)), shape_tuple) + return { + ax: sb[ax][0] + (vox_idx[i] + 0.5) * scale[ax] for i, ax in enumerate(axes) + } - Returns: - ------- - A dictionary of center coordinates. - """ - if idx < 0: - idx = len(self) + idx - try: - center_indices = np.unravel_index( - idx, [self.sampling_box_shape[c] for c in self.axis_order] + def get_indices(self, chunk_size: Mapping[str, float]) -> list[int]: + """Flat indices tiling the sampling box with chunk_size steps.""" + sb = self.sampling_box + if sb is None: + return [] + scale = self._write_scale + grid = box_shape(sb, scale) + axes = list(sb.keys()) + shape_tuple = tuple(grid[ax] for ax in axes) + + chunk_grid = { + ax: max( + 1, int(round((sb[ax][1] - sb[ax][0]) / chunk_size.get(ax, scale[ax]))) ) - except ValueError: - logger.error( - "Index %s out of bounds for dataset of length %d", idx, len(self) - ) - logger.warning("Returning closest index in bounds") - center_indices = [self.sampling_box_shape[c] - 1 for c in self.axis_order] - center = { - c: float( - center_indices[i] * self.smallest_voxel_sizes[c] - + self.sampling_box[c][0] - ) - for i, c in enumerate(self.axis_order) + for ax in axes } - return center - - def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: - """Returns a crop of the input and target data as PyTorch tensors, corresponding to the coordinate of the unwrapped index.""" - self._current_idx = idx - self._current_center = self.get_center(idx) - outputs = {} - for array_name in self.input_arrays.keys(): - array = self.input_sources[array_name][self._current_center] - if array.shape[0] != 1: - outputs[array_name] = array[None, ...] - else: - outputs[array_name] = array - outputs["idx"] = torch.tensor(idx) - - return outputs + chunk_tuple = tuple(chunk_grid[ax] for ax in axes) + indices = [] + for chunk_idx in np.ndindex(*chunk_tuple): + vox_idx = tuple( + int(chunk_idx[i] * shape_tuple[i] / chunk_tuple[i]) + for i in range(len(axes)) + ) + flat = int(np.ravel_multi_index(vox_idx, shape_tuple)) + indices.append(flat) + return indices + + # ------------------------------------------------------------------ + # Dataset interface + # ------------------------------------------------------------------ + + def __getitem__(self, idx: int) -> dict[str, Any]: + """Return input patches for index *idx* (for DataLoader iteration).""" + center = self.get_center(int(idx)) + result: dict[str, Any] = {"idx": torch.tensor(idx)} + for arr_name, src in self.input_sources.items(): + patch = src[center] + if patch.ndim > 0 and patch.shape[0] != 1: + patch = patch.unsqueeze(0) + result[arr_name] = patch + return result def __setitem__( self, idx: int | torch.Tensor | np.ndarray | Sequence[int], arrays: dict[str, torch.Tensor | np.ndarray], ) -> None: + """Write prediction *arrays* at the spatial location of *idx*. + + *idx* can be a scalar or a 1-D batch tensor. *arrays* is a dict + ``{class_name: tensor}`` (or ``{array_name: {class_name: tensor}}``). """ - Writes values for the given arrays at the given index. - - Args: - ---- - idx: The index or indices to write to. - arrays: Dictionary of arrays to write to disk. Data can be a - single array with channels for classes, or a dictionary - of arrays per class. - """ - if isinstance(idx, (torch.Tensor, np.ndarray, Sequence)): - if isinstance(idx, torch.Tensor): - idx = idx.cpu().numpy() - for batch_idx, i in enumerate(idx): - # Extract the data for this specific item in the batch - item_arrays = {} - for array_name, array in arrays.items(): - # Skip special metadata keys - if array_name in _METADATA_KEYS: + if isinstance(idx, torch.Tensor): + idx = idx.cpu().numpy() + if isinstance(idx, (np.ndarray, list, tuple)) and np.ndim(idx) > 0: + for batch_i, i in enumerate(idx): + item: dict[str, Any] = {} + for key, val in arrays.items(): + if key in _SKIP_KEYS: continue - if isinstance(array, (int, float)): - # Scalar values are the same for all items - item_arrays[array_name] = array - elif isinstance(array, dict): - # Dictionary of arrays - extract batch item from each - item_arrays[array_name] = { - label: label_array[batch_idx] - for label, label_array in array.items() - } + if isinstance(val, (int, float)): + item[key] = val + elif isinstance(val, dict): + item[key] = {k: v[batch_i] for k, v in val.items()} else: - # Regular array - extract the batch item - item_arrays[array_name] = array[batch_idx] - self.__setitem__(i, item_arrays) + item[key] = val[batch_i] + self.__setitem__(int(i), item) return - self._current_idx = idx - self._current_center = self.get_center(self._current_idx) - for array_name, array in arrays.items(): - # Skip special metadata keys - if array_name in _METADATA_KEYS: - continue - if isinstance(array, (int, float)): - for label in self.classes: - self.target_array_writers[array_name][label][ - self._current_center - ] = array - elif isinstance(array, dict): - for label, label_array in array.items(): - self.target_array_writers[array_name][label][ - self._current_center - ] = label_array - else: - for c, label in enumerate(self.classes): - self.target_array_writers[array_name][label][ - self._current_center - ] = array[c, ...] + center = self.get_center(int(idx)) - def __repr__(self) -> str: - """Returns a string representation of the dataset.""" - return ( - f"CellMapDatasetWriter(\n\tRaw path: {self.raw_path}\n\t" - f"Output path(s): {self.target_path}\n\tClasses: {self.classes})" - ) - - def get_target_array_writer( - self, array_name: str, array_info: Mapping[str, Sequence[int | float]] - ) -> dict[str, ImageWriter]: - """Returns a dictionary of ImageWriter for the target images (per class) for a given array.""" - target_image_writers = {} - for label in self.classes: - target_image_writers[label] = self.get_image_writer( - array_name, label, array_info - ) - - return target_image_writers - - def get_image_writer( - self, - array_name: str, - label: str, - array_info: Mapping[str, Sequence[int | float] | int], - ) -> ImageWriter: - """Returns an ImageWriter for a specific target image.""" - scale = array_info["scale"] - if not isinstance(scale, (Mapping, Sequence)): - raise TypeError(f"Scale must be a Mapping or Sequence, not {type(scale)}") - shape = array_info["shape"] - if not isinstance(shape, (Mapping, Sequence)): - raise TypeError(f"Shape must be a Mapping or Sequence, not {type(shape)}") - if "n_channels" in array_info: - shape = [array_info["n_channels"]] + list(shape) - if "c" not in self.axis_order: - self.axis_order = "c" + self.axis_order - scale_level = array_info.get("scale_level", 0) - if not isinstance(scale_level, int): - raise TypeError(f"Scale level must be an int, not {type(scale_level)}") - - return ImageWriter( - path=str(UPath(self.target_path) / label), - target_class=label, - scale=scale, # type: ignore - bounding_box=self.target_bounds[array_name], - write_voxel_shape=shape, # type: ignore - scale_level=scale_level, - axis_order=self.axis_order, - context=self.context, - fill_value=self.empty_value, - overwrite=self.overwrite, - ) - - def _get_box_shape(self, source_box: Mapping[str, list[float]]) -> dict[str, int]: - """Returns the shape of the box in voxels of the smallest voxel size requested.""" - box_shape = {} - for c, (start, stop) in source_box.items(): - size = stop - start - size /= self.smallest_voxel_sizes[c] - box_shape[c] = int(np.ceil(size)) - return box_shape + for key, val in arrays.items(): + if key in _SKIP_KEYS: + continue + # Find which target array and class this key maps to + for arr_name, writers in self.target_array_writers.items(): + if key in writers: + writers[key][center] = val + elif key in self.classes: + # Flat class key — write to first matching target array + if key in writers: + writers[key][center] = val + else: + # Write per channel if val is multi-channel + cls_idx = ( + self.model_classes.index(key) + if key in self.model_classes + else None + ) + if key in writers: + writers[key][center] = ( + val[cls_idx] + if cls_idx is not None and val.ndim > 0 + else val + ) + break + + # ------------------------------------------------------------------ + # DataLoader helper + # ------------------------------------------------------------------ - def _get_box_union( - self, - source_box: Mapping[str, list[float]] | None, - current_box: Mapping[str, list[float]] | None, - ) -> Mapping[str, list[float]] | None: - """Returns the union of the source and current boxes.""" - if source_box is not None: - if current_box is None: - return source_box - for c, (start, stop) in source_box.items(): - if stop <= start: - raise ValueError(f"Invalid box: start={start}, stop={stop}") - current_box[c][0] = min(current_box[c][0], start) - current_box[c][1] = max(current_box[c][1], stop) - return current_box - - def _get_box_intersection( + def loader( self, - source_box: Mapping[str, list[float]] | None, - current_box: Mapping[str, list[float]] | None, - ) -> Mapping[str, list[float]] | None: - """Returns the intersection of the source and current boxes.""" - if source_box is not None: - if current_box is None: - return source_box - for c, (start, stop) in source_box.items(): - if stop <= start: - raise ValueError(f"Invalid box: start={start}, stop={stop}") - current_box[c][0] = max(current_box[c][0], start) - current_box[c][1] = min(current_box[c][1], stop) - return current_box - - def verify(self) -> bool: - """Verifies that the dataset is valid to draw samples from.""" - # TODO: make more robust - try: - return len(self) > 0 - except Exception as e: - logger.warning("Dataset verification failed: %s", e) - return False - - def get_indices(self, chunk_size: Mapping[str, float]) -> Sequence[int]: - """Returns the indices of the dataset that will tile the dataset according to the chunk_size (supplied in world units).""" - # TODO: ADD TEST - - # Convert the target chunk size in world units to voxel units - chunk_size = { - c: int(size // self.smallest_voxel_sizes[c]) - for c, size in chunk_size.items() - } - - indices_dict = {} - for c, size in chunk_size.items(): - indices_dict[c] = np.arange(0, self.sampling_box_shape[c], size, dtype=int) - - if indices_dict[c][-1] != self.sampling_box_shape[c] - 1: - indices_dict[c] = np.append( - indices_dict[c], self.sampling_box_shape[c] - 1 - ) - - indices = [] - shape_values = list(self.sampling_box_shape.values()) - for i in np.ndindex(*[len(indices_dict[c]) for c in self.axis_order]): - index = [indices_dict[c][j] for c, j in zip(self.axis_order, i)] - index = np.ravel_multi_index(index, shape_values) - indices.append(index) - return indices + batch_size: int = 1, + num_workers: int = 0, + **kwargs: Any, + ) -> Any: + """Return a DataLoader that iterates over non-overlapping write tiles.""" + from .dataloader import CellMapDataLoader - def to( - self, device: str | torch.device, non_blocking: bool = True - ) -> "CellMapDatasetWriter": - """Sets the device for the dataset.""" - if device is None: - device = self.device - self._device = torch.device(device) - for source in self.input_sources.values(): - if isinstance(source, dict): - for source in source.values(): - if not hasattr(source, "to"): - continue - source.to(device, non_blocking=non_blocking) - else: - if not hasattr(source, "to"): - continue - source.to(str(device), non_blocking=non_blocking) - return self - - def set_raw_value_transforms(self, transforms: Callable) -> None: - """Sets the raw value transforms for the dataset.""" - self.raw_value_transforms = transforms - for source in self.input_sources.values(): - source.value_transform = transforms - - def get_weighted_sampler( - self, batch_size: int = 1, rng: Optional[torch.Generator] = None - ): - raise NotImplementedError( - "Weighted sampling is not typically used for writer datasets." - ) + return CellMapDataLoader( + Subset(self, self.writer_indices), + batch_size=batch_size, + num_workers=num_workers, + is_train=False, + **kwargs, + ).loader - def get_subset_random_sampler( - self, num_samples: int, rng: Optional[torch.Generator] = None - ): - raise NotImplementedError( - "Random sampling is not typically used for writer datasets." + def __repr__(self) -> str: + return ( + f"CellMapDatasetWriter(target={self.target_path!r}, " + f"classes={self.classes}, len={len(self)})" ) - - -# %% diff --git a/src/cellmap_data/datasplit.py b/src/cellmap_data/datasplit.py index c6fb110..e1902ae 100644 --- a/src/cellmap_data/datasplit.py +++ b/src/cellmap_data/datasplit.py @@ -1,97 +1,63 @@ +"""CellMapDataSplit: train/validation dataset management.""" + +from __future__ import annotations + import csv import logging import os -from functools import cached_property from typing import Any, Callable, Mapping, Optional, Sequence -import tensorstore import torch import torchvision.transforms.v2 as T +from torch.utils.data import Subset from tqdm import tqdm from .dataset import CellMapDataset from .multidataset import CellMapMultiDataset -from .subdataset import CellMapSubset from .transforms import Binarize, NaNtoNum logger = logging.getLogger(__name__) class CellMapDataSplit: - """ - A class to split the data into training and validation datasets. + """Manages train/validation splits for CellMap data. - Attributes: - ---------- - input_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure:: - { - "array_name": { - "shape": tuple[int], - "scale": Sequence[float], - }, - ... - } - - target_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the same structure as input_arrays. - classes (Sequence[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. - empty_value (int | float): The value to use for empty data. Defaults to torch.nan. - pad (bool | str): Whether to pad the data. If a string, it should be either "train" or "validate". Defaults to False. - datasets (Optional[Mapping[str, Sequence[CellMapDataset]]]): A dictionary containing the dataset objects. The dictionary should have the following structure: - { - "train": Iterable[CellMapDataset], - "validate": Iterable[CellMapDataset], - }. Defaults to None. - dataset_dict (Optional[Mapping[str, Sequence[Mapping[str, str]]]): A dictionary containing the dataset data. Defaults to None. The dictionary should have the following structure:: - - { - "train" | "validate": [{ - "raw": str (path to raw data), - "gt": str (path to ground truth data), - }], - ... - } - - csv_path (Optional[str]): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure: - train | validate, raw path, gt path - spatial_transforms (Optional[Sequence[dict[str, Any]]]): A sequence of dictionaries containing the spatial transformations to apply to the data. Defaults to None. The dictionary should have the following structure:: - - {transform_name: {transform_args}} - - train_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in training datasets. Example is to add gaussian noise to the raw data. Defailts to Normalize, ToDtype, and NaNtoNum. - val_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in validation datasets. Example is to normalize the raw data. Defaults to Normalize, ToDtype, and NaNtoNum. - target_value_transforms (Optional[Callable | Sequence[Callable] | Mapping[str, Callable]]): A function to convert the ground truth data to target arrays. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order. Defaults to ToDtype and Binarize. - class_relation_dict (Optional[Mapping[str, Sequence[str]]]): A dictionary containing the class relations. The dictionary should have the following structure:: - - { - "class_name": [mutually_exclusive_class_name, ...], - ... - } - - force_has_data (bool): Whether to force the datasets to have data even if no ground truth data is found. Defaults to False. Useful for training with only raw data. - context (Optional[tensorstore.Context]): The TensorStore context for the image data. Defaults to None. - device (Optional[str | torch.device]): Device to use for the dataloaders. Defaults to None. - - Note: - ---- - The csv_path, dataset_dict, and datasets arguments are mutually exclusive, but one must be supplied. - - Methods: - ------- - __repr__(): Returns the string representation of the class. - from_csv(csv_path: str): Loads the dataset data from a csv file. - construct(dataset_dict: Mapping[str, Sequence[Mapping[str, str]]]): Constructs the datasets from the dataset dictionary. - verify_datasets(): Verifies that the datasets have data, and removes ones that don't from 'self.train_datasets' and 'self.validation_datasets'. - set_raw_value_transforms(train_transforms: Optional[Callable] = None, val_transforms: Optional[Callable] = None): Sets the raw value transforms for each dataset in the training/validation multi-datasets. - set_target_value_transforms(transforms: Callable): Sets the target value transforms for each dataset in both training and validation multi-datasets. - set_spatial_transforms(spatial_transforms: dict[str, Any] | None): Sets the spatial transforms for each dataset in the training multi-dataset. - set_arrays(arrays: Mapping[str, Mapping[str, Sequence[int | float]]], type: str = "target", usage: str = "validate"): Sets the input or target arrays for the training or validation datasets. - - Properties: - train_datasets_combined: A multi-dataset from the combination of all training datasets. - validation_datasets_combined: A multi-dataset from the combination of all validation datasets. - validation_blocks: A subset of the validation datasets, tiling the validation datasets with non-overlapping blocks. - class_counts: A dictionary containing the class counts for the training and validation datasets. + Reads dataset paths from a CSV file, a ``dataset_dict``, or a + pre-built ``datasets`` mapping, then constructs + :class:`CellMapDataset` objects and exposes combined datasets for + training and validation. + Parameters + ---------- + input_arrays: + ``{name: {"shape": (z,y,x), "scale": (z,y,x)}}`` + target_arrays: + Same structure as *input_arrays*. + classes: + Segmentation class names. + pad: + Pad strategy: ``False``, ``True``, ``"train"``, or ``"validate"``. + datasets: + Pre-built ``{"train": [CellMapDataset, …], "validate": […]}`` + mapping. Mutually exclusive with *dataset_dict* / *csv_path*. + dataset_dict: + ``{"train": [{"raw": path, "gt": path}, …], "validate": […]}``. + csv_path: + Path to CSV with rows ``split,raw_path,gt_path[,raw_name,gt_name]``. + spatial_transforms: + Augmentation config for training datasets. + train_raw_value_transforms: + Transform applied to raw data during training. + val_raw_value_transforms: + Transform applied to raw data during validation. + target_value_transforms: + Transform applied to GT labels. + class_relation_dict: + Mutual-exclusion class relations (stored, not used for inference). + force_has_data: + Skip empty-data check on each dataset. + device: + Ignored (API compatibility). """ def __init__( @@ -100,342 +66,258 @@ def __init__( target_arrays: Optional[ Mapping[str, Mapping[str, Sequence[int | float]]] ] = None, - classes: Sequence[str] | None = None, - empty_value: int | float = torch.nan, + classes: Optional[Sequence[str]] = None, pad: bool | str = False, datasets: Optional[Mapping[str, Sequence[CellMapDataset]]] = None, dataset_dict: Optional[Mapping[str, Sequence[Mapping[str, str]]]] = None, csv_path: Optional[str] = None, spatial_transforms: Optional[Mapping[str, Any]] = None, - train_raw_value_transforms: Optional[T.Transform] = T.Compose( + train_raw_value_transforms: Optional[Callable] = T.Compose( [ T.ToDtype(torch.float, scale=True), NaNtoNum({"nan": 0, "posinf": None, "neginf": None}), - ], + ] ), - val_raw_value_transforms: Optional[T.Transform] = T.Compose( + val_raw_value_transforms: Optional[Callable] = T.Compose( [ T.ToDtype(torch.float, scale=True), NaNtoNum({"nan": 0, "posinf": None, "neginf": None}), - ], + ] ), - target_value_transforms: Optional[T.Transform] = T.Compose( + target_value_transforms: Optional[Callable] = T.Compose( [T.ToDtype(torch.float), Binarize()] ), class_relation_dict: Optional[Mapping[str, Sequence[str]]] = None, force_has_data: bool = False, - context: Optional[tensorstore.Context] = None, # type: ignore + context: Optional[Any] = None, # ignored, kept for API compat device: Optional[str | torch.device] = None, ) -> None: - """Initializes the CellMapDatasets class. - - Args: - ---- - input_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure:: - - { - "array_name": { - "shape": tuple[int], - "scale": Sequence[float], - }, - ... - } - - target_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the same structure as input_arrays. - classes (Sequence[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays. - empty_value (int | float): The value to use for empty data. Defaults to torch.nan. - pad (bool | str): Whether to pad the data. If a string, it should be either "train" or "validate". Defaults to False. - datasets (Optional[Mapping[str, Sequence[CellMapDataset]]]): A dictionary containing the dataset objects. Defaults to None. The dictionary should have the following structure:: - - { - "train": Iterable[CellMapDataset], - "validate": Iterable[CellMapDataset], - }. - - dataset_dict (Optional[Mapping[str, Sequence[Mapping[str, str]]]): A dictionary containing the dataset data. Defaults to None. The dictionary should have the following structure:: - - { - "train" | "validate": [{ - "raw": str (path to raw data), - "gt": str (path to ground truth data), - }], - ... - } - - csv_path (Optional[str]): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure:" - - train | validate, raw path, gt path - - spatial_transforms (Optional[Sequence[dict[str, Any]]]): A sequence of dictionaries containing the spatial transformations to apply to the data. Defaults to None. The dictionary should have the following structure:: - - {transform_name: {transform_args}} - - train_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in training datasets. Defaults to None. Example is to add gaussian noise to the raw data. - val_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in validation datasets. Defaults to None. Example is to normalize the raw data. - target_value_transforms (Optional[Callable | Sequence[Callable] | Mapping[str, Callable]]): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order. - class_relation_dict (Optional[Mapping[str, Sequence[str]]]): A dictionary containing the class relations. The dictionary should have the following structure:: - - { - "class_name": [mutually_exclusive_class_name, ...], - ... - } - - force_has_data (bool): Whether to force the datasets to have data even if no ground truth data is found. Defaults to False. Useful for training with only raw data. - context (Optional[tensorstore.Context]): The TensorStore context for the image data. Defaults to None. - device (Optional[str | torch.device]): Device to use for the dataloaders. Defaults to None. - - Note: - ---- - The csv_path, dataset_dict, and datasets arguments are mutually exclusive, but one must be supplied. - - """ - logger.info("Initializing CellMapDataSplit...") - self.input_arrays = input_arrays - self.target_arrays = target_arrays - self.classes = classes - self.empty_value = empty_value + self.input_arrays = dict(input_arrays) + self.target_arrays = dict(target_arrays) if target_arrays else {} + self.classes = list(classes) if classes else [] self.pad = pad - self.device = device - if isinstance(pad, str): - self.pad_training = pad.lower() == "train" - self.pad_validation = pad.lower() == "validate" - else: - self.pad_training = pad - self.pad_validation = pad + self.spatial_transforms = spatial_transforms + self.train_raw_value_transforms = train_raw_value_transforms + self.val_raw_value_transforms = val_raw_value_transforms + self.target_value_transforms = target_value_transforms + self.class_relation_dict = class_relation_dict self.force_has_data = force_has_data + # Storage for train/val datasets + self.train_datasets: list[CellMapDataset] = [] + self._validation_datasets: list[CellMapDataset] = [] + if datasets is not None: - self.datasets = datasets - self.train_datasets = datasets["train"] - if "validate" in datasets: - self.validation_datasets = datasets["validate"] - else: - self.validation_datasets = [] - self.dataset_dict = None + self.train_datasets = list(datasets.get("train", [])) + self._validation_datasets = list(datasets.get("validate", [])) elif dataset_dict is not None: - self.dataset_dict = dataset_dict + self._construct(dataset_dict) + self._verify_datasets() elif csv_path is not None: - self.dataset_dict = self.from_csv(csv_path) - else: - # No data source provided - this should raise an error - raise ValueError( - "One of 'datasets', 'dataset_dict', or 'csv_path' must be provided" + dataset_dict = self._parse_csv(csv_path) + self._construct(dataset_dict) + self._verify_datasets() + # else: empty split; user can call _construct later or set datasets directly + + # ------------------------------------------------------------------ + # CSV parsing + # ------------------------------------------------------------------ + + @staticmethod + def _parse_csv( + csv_path: str, + ) -> dict[str, list[dict[str, str]]]: + """Parse the dataset CSV into a ``dataset_dict``. + + Expected CSV columns: ``split, raw_path, gt_path`` (and optionally + ``raw_name``, ``gt_name`` which are ignored). + """ + result: dict[str, list[dict[str, str]]] = { + "train": [], + "validate": [], + } + with open(csv_path, newline="") as fh: + reader = csv.reader(fh) + for row in reader: + if not row or row[0].startswith("#"): + continue + if len(row) < 3: + logger.warning("Skipping malformed CSV row: %s", row) + continue + split = row[0].strip() + raw_path = row[1].strip() + gt_path = row[2].strip() + if split not in result: + result[split] = [] + result[split].append({"raw": raw_path, "gt": gt_path}) + return result + + # ------------------------------------------------------------------ + # Dataset construction + # ------------------------------------------------------------------ + + def _construct( + self, dataset_dict: Mapping[str, Sequence[Mapping[str, str]]] + ) -> None: + """Build CellMapDataset objects from the dict of raw/gt path pairs.""" + for split, entries in dataset_dict.items(): + is_train = split.lower().startswith("train") + pad = ( + self.pad + if isinstance(self.pad, bool) + else (split.lower() in self.pad.lower()) ) + raw_tx = ( + self.train_raw_value_transforms + if is_train + else self.val_raw_value_transforms + ) + spatial_tx = self.spatial_transforms if is_train else None - # Temporary initialization of datasets lists for dataset_dict and csv_path paths. - # These will be immediately overwritten by the construct() method for non-'datasets' paths. - if datasets is None: - self.train_datasets = [] - self.validation_datasets = [] + for entry in entries: + raw_path = entry.get("raw", "") + gt_path = entry.get("gt", "") + if not raw_path: + continue + try: + ds = CellMapDataset( + raw_path=raw_path, + target_path=gt_path, + classes=self.classes, + input_arrays=self.input_arrays, + target_arrays=self.target_arrays, + pad=pad, + spatial_transforms=spatial_tx, + raw_value_transforms=raw_tx, + target_value_transforms=self.target_value_transforms, + class_relation_dict=self.class_relation_dict, + force_has_data=self.force_has_data, + ) + if is_train: + self.train_datasets.append(ds) + else: + self._validation_datasets.append(ds) + except Exception as exc: + logger.warning( + "Skipping dataset raw=%r gt=%r: %s", raw_path, gt_path, exc + ) - self.spatial_transforms = spatial_transforms - self.train_raw_value_transforms = train_raw_value_transforms - self.val_raw_value_transforms = val_raw_value_transforms - self.target_value_transforms = target_value_transforms - self.class_relation_dict = class_relation_dict - self.context = context - if self.dataset_dict is not None: - self.construct(self.dataset_dict) - self.verify_datasets() - # Require training datasets unless force_has_data is True - if not self.force_has_data and not (len(self.train_datasets) > 0): - raise ValueError("No valid training datasets found.") - logger.info("CellMapDataSplit initialized.") + def _verify_datasets(self) -> None: + """Remove datasets that report no valid data.""" + if self.force_has_data: + return + self.train_datasets = [ + ds + for ds in tqdm( + self.train_datasets, + desc="Verifying train datasets", + leave=False, + ) + if ds.verify() + ] + self._validation_datasets = [ + ds + for ds in tqdm( + self._validation_datasets, + desc="Verifying val datasets", + leave=False, + ) + if ds.verify() + ] - def __repr__(self) -> str: - """Returns the string representation of the class.""" - return f"CellMapDataSplit(\n\tInput arrays: {self.input_arrays}\n\tTarget arrays:{self.target_arrays}\n\tClasses: {self.classes}\n\tDataset dict: {self.dataset_dict}\n\tSpatial transforms: {self.spatial_transforms}\n\tRaw value transforms: {self.train_raw_value_transforms}\n\tGT value transforms: {self.target_value_transforms}\n\tForce has data: {self.force_has_data}\n\tContext: {self.context})" + # ------------------------------------------------------------------ + # Cached combined datasets + # ------------------------------------------------------------------ - @cached_property + @property def train_datasets_combined(self) -> CellMapMultiDataset: - """A multi-dataset from the combination of all training datasets.""" - return CellMapMultiDataset( - self.classes, - self.input_arrays, - self.target_arrays, - [ds for ds in self.train_datasets if self.force_has_data or ds.has_data], - ) + """Combined training dataset for use with DataLoader.""" + if "train_datasets_combined" not in self.__dict__: + self.__dict__["train_datasets_combined"] = CellMapMultiDataset( + datasets=self.train_datasets, + classes=self.classes, + input_arrays=self.input_arrays, + target_arrays=self.target_arrays, + ) + return self.__dict__["train_datasets_combined"] - @cached_property + @property def validation_datasets_combined(self) -> CellMapMultiDataset: - """A multi-dataset from the combination of all validation datasets.""" - if len(self.validation_datasets) == 0: - logger.warning("Validation datasets not loaded.") - return CellMapMultiDataset.empty() - return CellMapMultiDataset( - self.classes, - self.input_arrays, - self.target_arrays, - [ - ds - for ds in self.validation_datasets - if self.force_has_data or ds.has_data - ], - ) - - @cached_property - def validation_blocks(self) -> CellMapSubset: - """A subset of the validation datasets, tiling the validation datasets with non-overlapping blocks.""" - return CellMapSubset( - self.validation_datasets_combined, - self.validation_datasets_combined.validation_indices, - ) - - @cached_property - def class_counts(self) -> dict[str, dict[str, float]]: - """A dictionary containing the class counts for the training and validation datasets.""" + """Combined validation dataset.""" + if "validation_datasets_combined" not in self.__dict__: + self.__dict__["validation_datasets_combined"] = CellMapMultiDataset( + datasets=self._validation_datasets, + classes=self.classes, + input_arrays=self.input_arrays, + target_arrays=self.target_arrays, + ) + return self.__dict__["validation_datasets_combined"] + + @property + def validation_datasets(self) -> list[CellMapDataset]: + """List of individual validation datasets.""" + return self._validation_datasets + + @property + def validation_blocks(self) -> Subset: + """Non-overlapping validation tile indices wrapped in a Subset.""" + if "validation_blocks" not in self.__dict__: + combined = self.validation_datasets_combined + indices = combined.validation_indices + self.__dict__["validation_blocks"] = Subset(combined, indices) + return self.__dict__["validation_blocks"] + + @property + def class_counts(self) -> dict[str, Any]: + """Train and validation class counts.""" return { "train": self.train_datasets_combined.class_counts, "validate": self.validation_datasets_combined.class_counts, } - @classmethod - def from_csv(cls, csv_path) -> dict[str, Sequence[dict[str, str]]]: - """Loads the dataset_dict data from a csv file.""" - dataset_dict = {} - with open(csv_path) as f: - reader = csv.reader(f) - logger.info("Reading csv file...") - for row in reader: - try: - if row[0] not in dataset_dict: - dataset_dict[row[0]] = [] - dataset_dict[row[0]].append( - { - "raw": os.path.join(row[1], row[2]), - "gt": os.path.join(row[3], row[4]) if len(row) > 3 else "", - } - ) - except Exception as e: - logger.warning(f"Skipping row {reader.line_num} due to error: {e}") - - return dataset_dict - - def construct(self, dataset_dict) -> None: - """Constructs the datasets from the dataset dictionary.""" - self.train_datasets = [] - self.validation_datasets = [] - self.datasets = {} - logger.info("Constructing datasets...") - if "train" in dataset_dict: - for data_paths in tqdm(dataset_dict["train"], desc="Training datasets"): - try: - self.train_datasets.append( - CellMapDataset( - data_paths["raw"], - data_paths["gt"], - self.classes, - self.input_arrays, - self.target_arrays, - spatial_transforms=self.spatial_transforms, - raw_value_transforms=self.train_raw_value_transforms, - target_value_transforms=self.target_value_transforms, - is_train=True, - context=self.context, - force_has_data=self.force_has_data, - empty_value=self.empty_value, - class_relation_dict=self.class_relation_dict, - pad=self.pad_training, - ) - ) - except Exception as e: - logger.warning(f"Skipping training dataset due to error: {e}") - if "validate" in dataset_dict: - for data_paths in tqdm( - dataset_dict["validate"], desc="Validation datasets" - ): - try: - self.validation_datasets.append( - CellMapDataset( - data_paths["raw"], - data_paths["gt"], - self.classes, - self.input_arrays, - self.target_arrays, - spatial_transforms=self.spatial_transforms, - raw_value_transforms=self.val_raw_value_transforms, - target_value_transforms=self.target_value_transforms, - is_train=False, - context=self.context, - force_has_data=self.force_has_data, - empty_value=self.empty_value, - class_relation_dict=self.class_relation_dict, - pad=self.pad_validation, - ) - ) - except Exception as e: - logger.warning(f"Skipping validation dataset due to error: {e}") - self.datasets = { - "train": self.train_datasets, - "validate": self.validation_datasets, - } + # ------------------------------------------------------------------ + # Cache invalidation + # ------------------------------------------------------------------ - def verify_datasets(self) -> None: - """Verifies that the datasets have data, and removes ones that don't from ``self.train_datasets`` and ``self.validation_datasets``.""" - if self.force_has_data: - return - logger.info("Verifying datasets...") - verified_datasets = [] - for ds in tqdm(self.train_datasets, desc="Training datasets"): - if ds.verify(): - verified_datasets.append(ds) - self.train_datasets = verified_datasets - - verified_datasets = [] - for ds in tqdm(self.validation_datasets, desc="Validation datasets"): - if ds.verify(): - verified_datasets.append(ds) - self.validation_datasets = verified_datasets + def _invalidate(self) -> None: + """Clear all cached combined-dataset properties.""" + for key in ( + "train_datasets_combined", + "validation_datasets_combined", + "validation_blocks", + ): + self.__dict__.pop(key, None) + + # ------------------------------------------------------------------ + # Setters (invalidate cache) + # ------------------------------------------------------------------ def set_raw_value_transforms( self, train_transforms: Optional[Callable] = None, val_transforms: Optional[Callable] = None, ) -> None: - """Sets the raw value transforms for each dataset in the training/validation multi-datasets.""" - if train_transforms is not None: - for dataset in self.train_datasets: - dataset.set_raw_value_transforms(train_transforms) - if "train_datasets_combined" in self.__dict__: - self.train_datasets_combined.set_raw_value_transforms(train_transforms) - if val_transforms is not None: - for dataset in self.validation_datasets: - dataset.set_raw_value_transforms(val_transforms) - if "validation_datasets_combined" in self.__dict__: - self.validation_datasets_combined.set_raw_value_transforms( - val_transforms - ) - - def set_target_value_transforms(self, transforms: Callable) -> None: - """Sets the target value transforms for each dataset in the multi-datasets.""" - for dataset in self.train_datasets: - dataset.set_target_value_transforms(transforms) - if "train_datasets_combined" in self.__dict__: - self.train_datasets_combined.set_target_value_transforms(transforms) - - for dataset in self.validation_datasets: - dataset.set_target_value_transforms(transforms) - if "validation_datasets_combined" in self.__dict__: - self.validation_datasets_combined.set_target_value_transforms(transforms) - if "validation_blocks" in self.__dict__: - self.validation_blocks.set_target_value_transforms(transforms) + self.train_raw_value_transforms = train_transforms + self.val_raw_value_transforms = val_transforms + for ds in self.train_datasets: + ds.set_raw_value_transforms(train_transforms) + for ds in self._validation_datasets: + ds.set_raw_value_transforms(val_transforms) + self._invalidate() + + def set_target_value_transforms(self, transforms: Optional[Callable]) -> None: + self.target_value_transforms = transforms + for ds in self.train_datasets + self._validation_datasets: + ds.set_target_value_transforms(transforms) + self._invalidate() def set_spatial_transforms( - self, - train_transforms: Optional[dict[str, Any]] = None, - val_transforms: Optional[dict[str, Any]] = None, + self, spatial_transforms: Optional[Mapping[str, Any]] ) -> None: - """Sets the raw value transforms for each dataset in the training/validation multi-dataset.""" - if train_transforms is not None: - for dataset in self.train_datasets: - dataset.spatial_transforms = train_transforms - if "train_datasets_combined" in self.__dict__: - self.train_datasets_combined.set_spatial_transforms(train_transforms) - if val_transforms is not None: - for dataset in self.validation_datasets: - dataset.spatial_transforms = val_transforms - if "validation_datasets_combined" in self.__dict__: - self.validation_datasets_combined.set_spatial_transforms(val_transforms) + self.spatial_transforms = spatial_transforms + for ds in self.train_datasets: + ds.set_spatial_transforms(spatial_transforms) + self._invalidate() def set_arrays( self, @@ -443,36 +325,21 @@ def set_arrays( type: str = "target", usage: str = "validate", ) -> None: - """Sets the input or target arrays for the training or validation datasets.""" - reset_attrs = [] - for dataset in self.datasets[usage]: - if type == "inputs": - dataset.input_arrays = arrays - elif type == "target": - dataset.target_arrays = arrays - else: - raise ValueError("Type must be 'inputs' or 'target'.") - dataset.reset_arrays(type) - - if usage == "train": - self.train_datasets = self.datasets["train"] - reset_attrs.append("train_datasets_combined") - elif usage == "validate": - self.validation_datasets = self.datasets["validate"] - reset_attrs.extend(["validation_datasets_combined", "validation_blocks"]) - for attr in reset_attrs: - self.__dict__.pop(attr, None) - - def to(self, device: str | torch.device, non_blocking: bool = True) -> None: - """Sets the device for the dataloaders.""" - self.device = device - for dataset in self.train_datasets: - dataset.to(device, non_blocking=non_blocking) - for dataset in self.validation_datasets: - dataset.to(device, non_blocking=non_blocking) - if "train_datasets_combined" in self.__dict__: - self.train_datasets_combined.to(device, non_blocking=non_blocking) - if "validation_datasets_combined" in self.__dict__: - self.validation_datasets_combined.to(device, non_blocking=non_blocking) - if "validation_blocks" in self.__dict__: - self.validation_blocks.to(device, non_blocking=non_blocking) + if type == "target": + self.target_arrays = dict(arrays) + else: + self.input_arrays = dict(arrays) + self._invalidate() + + def to(self, device: str | torch.device) -> "CellMapDataSplit": + self.train_datasets_combined.to(device) + self.validation_datasets_combined.to(device) + return self + + def __repr__(self) -> str: + return ( + f"CellMapDataSplit(" + f"train={len(self.train_datasets)}, " + f"val={len(self._validation_datasets)}, " + f"classes={self.classes})" + ) diff --git a/src/cellmap_data/empty_image.py b/src/cellmap_data/empty_image.py index 83b54c3..32ea702 100644 --- a/src/cellmap_data/empty_image.py +++ b/src/cellmap_data/empty_image.py @@ -1,88 +1,73 @@ -from typing import Any, Mapping, Optional, Sequence +"""EmptyImage: a NaN-filled placeholder for unannotated classes.""" -import torch - -from .base_image import CellMapImageBase +from __future__ import annotations +from typing import Any, Callable, Mapping, Optional, Sequence -class EmptyImage(CellMapImageBase): - """ - A class for handling empty image data. +import torch - This class is used to create an empty image object, which can be used as a placeholder for images that do not exist in the dataset. It can be used to maintain a consistent API for image objects even when no data is present. - Attributes - ---------- - label_class (str): The intended label class of the image. - target_scale (Sequence[float]): The intended scale of the image in physical space. - target_voxel_shape (Sequence[int]): The intended shape of the image in voxels. - store (Optional[torch.Tensor]): The tensor to return. - axis_order (str): The intended order of the axes in the image. - empty_value (float | int): The value to fill the image with. +class EmptyImage: + """Returns a NaN tensor for every read. - Methods - ------- - __getitem__(center: Mapping[str, float]) -> torch.Tensor: Returns the empty image data. - to(device: str): Moves the image data to the given device. - set_spatial_transforms(transforms: Mapping[str, Any] | None): - Imitates the method in CellMapImage, but does nothing for an EmptyImage object. + Used when a class is not annotated in a given dataset. NaN signals + *unknown* to the model, as opposed to *absent* (zeros). - Properties: - bounding_box (None): Returns None. - sampling_box (None): Returns None. - bg_count (float): Returns zero. - class_counts (float): Returns zero. + The constructor signature mirrors :class:`CellMapImage` so the two can + be used interchangeably inside :class:`CellMapDataset`. """ def __init__( self, - label_class: str, - scale: Sequence[float], - voxel_shape: Sequence[int], - store: Optional[torch.Tensor] = None, - axis_order: str = "zyx", - empty_value: float | int = -100, - ): - self.label_class = label_class - self.scale_tuple = scale - if len(voxel_shape) < len(axis_order): - axis_order = axis_order[-len(voxel_shape) :] - self.output_shape = {c: voxel_shape[i] for i, c in enumerate(axis_order)} - self.output_size = {c: t * s for c, t, s in zip(axis_order, voxel_shape, scale)} - self.axes = axis_order - self._bounding_box = None - self._class_counts = 0.0 - self._bg_count = 0.0 - self.scale = {c: sc for c, sc in zip(self.axes, self.scale_tuple)} - self.empty_value = empty_value - if store is not None: - self.store = store - else: - self.store = ( - torch.ones([self.output_shape[c] for c in self.axes]) * self.empty_value - ) + path: str, + target_class: str, + target_scale: Sequence[float], + target_voxel_shape: Sequence[int], + pad: bool = False, + pad_value: float = float("nan"), + interpolation: str = "nearest", + axis_order: str | Sequence[str] = "zyx", + value_transform: Optional[Callable] = None, + context: Any = None, + device: Optional[str | torch.device] = None, + ) -> None: + self.path = path + self.label_class = target_class + axis_order = list(axis_order) + if len(axis_order) > len(target_voxel_shape): + ndim_fix = len(axis_order) - len(target_voxel_shape) + target_voxel_shape = [1] * ndim_fix + list(target_voxel_shape) + self.axes: list[str] = axis_order[: len(target_voxel_shape)] + self.scale = {c: float(s) for c, s in zip(axis_order, target_scale)} + self.output_shape = {c: int(t) for c, t in zip(axis_order, target_voxel_shape)} + self._nan_tensor = torch.full( + [self.output_shape[ax] for ax in self.axes], float("nan") + ) def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor: - return self.store + return self._nan_tensor.clone() + + def set_spatial_transforms(self, transforms: dict | None) -> None: + pass @property def bounding_box(self) -> None: - return self._bounding_box + return None @property def sampling_box(self) -> None: - return self._bounding_box + return None @property - def bg_count(self) -> float: - return self._bg_count - - @property - def class_counts(self) -> float: - return self._class_counts - - def to(self, device: str | torch.device, non_blocking: bool = True) -> None: - self.store = self.store.to(device, non_blocking=non_blocking) - - def set_spatial_transforms(self, transforms: Mapping[str, Any] | None) -> None: - pass + def class_counts(self) -> dict[str, int]: + return {self.label_class: 0} + + def to(self, device: str | torch.device) -> "EmptyImage": + self._nan_tensor = self._nan_tensor.to(device) + return self + + def __repr__(self) -> str: + return ( + f"EmptyImage(class={self.label_class!r}, " + f"shape={list(self.output_shape.values())})" + ) diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 47f90f0..43f44ff 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -1,650 +1,502 @@ -from functools import cached_property -import json +"""CellMapImage: reads patches from OME-NGFF multiscale zarr arrays.""" + +from __future__ import annotations + import logging -import os +from functools import cached_property from typing import Any, Callable, Mapping, Optional, Sequence -import dask.array as da import numpy as np -import tensorstore as ts import torch -import xarray -import xarray_tensorstore as xt +import torch.nn.functional as F import zarr -from pydantic_ome_ngff.v04.multiscale import MultiscaleGroupAttrs, MultiscaleMetadata -from pydantic_ome_ngff.v04.transform import Scale, Translation, VectorScale -from scipy.spatial.transform import Rotation as rot -from xarray_ome_ngff.v04.multiscale import coords_from_transforms +from scipy.ndimage import rotate as scipy_rotate -from .base_image import CellMapImageBase +from .utils.geometry import box_shape logger = logging.getLogger(__name__) -class CellMapImage(CellMapImageBase): - """ - A class for handling image data from a CellMap dataset. - - This class is used to load image data from a CellMap dataset, and can apply spatial transformations to the data. It also handles the loading of the image data from the dataset, and can apply value transformations to the data. The image data is returned as a PyTorch tensor formatted for use in training, and can be loaded onto a specified device. +class CellMapImage: + """Load a patch from a single OME-NGFF multiscale zarr array. + + Parameters + ---------- + path: + Path to the zarr group (e.g. ``/data/jrc_hela-2.zarr/raw``). + target_class: + Semantic label for this array (e.g. ``"raw"``, ``"mito"``). + target_scale: + Desired voxel size in nm per axis, ordered according to *axis_order*. + target_voxel_shape: + Output patch size in voxels, ordered according to *axis_order*. + pad: + Whether to pad with *pad_value* when the patch extends beyond the + array bounds. If ``False`` the center is clamped so patches never + extend outside. + pad_value: + Fill value for out-of-bounds regions. Defaults to ``nan`` so the + model can mask unknown data in the loss. + interpolation: + ``"linear"`` for bilinear/trilinear resampling (raw EM), + ``"nearest"`` for nearest-neighbour (labels). + axis_order: + Axis names in the order they appear in *target_scale* / + *target_voxel_shape*. Defaults to ``"zyx"``. + value_transform: + Optional callable applied to the output tensor (e.g. ``Binarize``). + context: + Ignored (kept for API compatibility with the old TensorStore code). + device: + Ignored – tensors are returned on CPU; the DataLoader moves them. """ def __init__( self, path: str, target_class: str, - target_scale: Sequence[float], # TODO: make work with dict - target_voxel_shape: Sequence[int], # TODO: make work with dict + target_scale: Sequence[float], + target_voxel_shape: Sequence[int], pad: bool = False, - pad_value: float | int = np.nan, + pad_value: float = float("nan"), interpolation: str = "nearest", axis_order: str | Sequence[str] = "zyx", value_transform: Optional[Callable] = None, - context: Optional[ts.Context] = None, # type: ignore + context: Any = None, device: Optional[str | torch.device] = None, ) -> None: - """Initializes a CellMapImage object. - - Args: - ---- - path (str): The path to the image file. - target_class (str): The label class of the image. - target_scale (Sequence[float]): The scale of the image data to return in physical space. - target_voxel_shape (Sequence[int]): The shape of the image data to return in voxels. - axis_order (str, optional): The order of the axes in the image. Defaults to "zyx". - value_transform (Optional[callable], optional): A function to transform the image pixel data. Defaults to None. - context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None. - device (Optional[str | torch.device], optional): The device to load the image data onto. Defaults to "cuda" if available, then "mps", then "cpu". - """ self.path = path self.label_class = target_class - # Below makes assumptions about image scale, and also locks which axis is sliced to 2D (this should only be encountered if bypassing dataset) + axis_order = list(axis_order) + + # Pad scale / shape to match axis_order length (preserve existing behaviour) if len(axis_order) > len(target_scale): - logger.info( - f"Axis order {axis_order} has more axes than target scale {target_scale}. Padding target scale with first given scale ({target_scale[0]})." - ) target_scale = [target_scale[0]] * ( len(axis_order) - len(target_scale) ) + list(target_scale) if len(axis_order) > len(target_voxel_shape): ndim_fix = len(axis_order) - len(target_voxel_shape) - logger.warning( - f"Axis order {axis_order} has more axes than target voxel shape {target_voxel_shape}. Padding first {ndim_fix} target voxel shapes with 1s." - ) target_voxel_shape = [1] * ndim_fix + list(target_voxel_shape) + self.pad = pad self.pad_value = pad_value self.interpolation = interpolation - self.scale = {c: s for c, s in zip(axis_order, target_scale)} - self.output_shape = {c: t for c, t in zip(axis_order, target_voxel_shape)} - self.output_size = { - c: t * s for c, t, s in zip(axis_order, target_voxel_shape, target_scale) + self.axes: list[str] = axis_order[: len(target_voxel_shape)] + self.scale: dict[str, float] = { + c: float(s) for c, s in zip(axis_order, target_scale) + } + self.output_shape: dict[str, int] = { + c: int(t) for c, t in zip(axis_order, target_voxel_shape) + } + self.output_size: dict[str, float] = { + c: self.output_shape[c] * self.scale[c] for c in self.axes } - self.axes = axis_order[: len(target_voxel_shape)] self.value_transform = value_transform - self.context = context - self._current_spatial_transforms = None - self._current_coords: Any = None - self._current_center = None - if device is not None: - self.device = device - elif torch.cuda.is_available(): - self.device = "cuda" - elif torch.backends.mps.is_available(): - self.device = "mps" - else: - self.device = "cpu" + self._current_spatial_transforms: Optional[dict] = None - def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor: - """Returns image data centered around the given point, based on the scale and shape of the target output image.""" - try: - if isinstance(list(center.values())[0], int | float): - self._current_center = center - - # Use cached coordinate offsets + translation (much faster than np.linspace) - # This eliminates repeated coordinate grid generation - coords = {c: self.coord_offsets[c] + center[c] for c in self.axes} - - # Bounds checking: warn when crop edges extend beyond the - # annotation bounding box. This is normal and handled by - # padding when the crop window is larger than the annotation - # volume, so log at DEBUG to avoid noise during training. - for c in self.axes: - if center[c] - self.output_size[c] / 2 < self.bounding_box[c][0]: - logger.debug( - f"Crop edge for axis {c} in image {self.path} extends " - f"below annotation bounds: {center[c] - self.output_size[c] / 2} " - f"< {self.bounding_box[c][0]} (center={center[c]})" - ) - if center[c] + self.output_size[c] / 2 > self.bounding_box[c][1]: - logger.debug( - f"Crop edge for axis {c} in image {self.path} extends " - f"above annotation bounds: {center[c] + self.output_size[c] / 2} " - f"> {self.bounding_box[c][1]} (center={center[c]})" - ) - - # Apply any spatial transformations to the coordinates and return the image data as a PyTorch tensor - data = self.apply_spatial_transforms(coords) - else: - self._current_center = {k: np.mean(v) for k, v in center.items()} - self._current_coords = center - # Optimized tensor creation: use torch.from_numpy when possible to avoid data copying - array_data = self.return_data(self._current_coords).values - if isinstance(array_data, np.ndarray): - data = torch.from_numpy(array_data) - else: - data = torch.tensor(array_data) - - # Apply any value transformations to the data - if self.value_transform is not None: - data = self.value_transform(data) - - # Return data on CPU - let the DataLoader handle device transfer with streams - # This avoids redundant transfers and allows for optimized batch transfers - return data - finally: - # Clear cached array property to prevent memory accumulation from xarray - # operations (interp/reindex/sel) during training iterations. The array - # will be reopened on next access if needed. Use finally to ensure cleanup - # even if an exception occurs during data retrieval. - self._clear_array_cache() + # ------------------------------------------------------------------ + # Zarr / OME-NGFF metadata (all cached after first access) + # ------------------------------------------------------------------ - def __repr__(self) -> str: - """Returns a string representation of the CellMapImage object.""" - return f"CellMapImage({self.array_path})" - - def _clear_array_cache(self) -> None: - """ - Clear the cached xarray DataArray to release intermediate objects. + @cached_property + def _zarr_group(self) -> zarr.Group: + return zarr.open_group(self.path, mode="r") - xarray operations (interp, reindex, sel) create intermediate arrays that - remain referenced through the DataArray. Clearing the cache after each - __getitem__ releases those references without closing the underlying - TensorStore handle, which is separately cached in _ts_store and reused. - """ - if "array" in self.__dict__: - del self.__dict__["array"] + @cached_property + def _multiscale_attrs(self) -> dict: + return dict(self._zarr_group.attrs)["multiscales"][0] @cached_property - def coord_offsets(self) -> Mapping[str, np.ndarray]: - """ - Cached coordinate offsets from center. + def _spatial_axes_order(self) -> list[str]: + """Spatial axis names as declared in the multiscale metadata.""" + result = [] + for ax in self._multiscale_attrs.get("axes", []): + if isinstance(ax, dict): + if ax.get("type", "space") == "space": + result.append(ax["name"]) + else: + result.append(str(ax)) + return result or self.axes - These offsets are constant for a given scale/shape and are used to - construct coordinate grids by simply adding the center position. - This eliminates repeated np.linspace calls in __getitem__. + @cached_property + def _level_info(self) -> list[tuple[str, dict[str, float], dict[str, float]]]: + """For each scale level: (path, voxel_size, origin) dicts keyed by axis.""" + levels = [] + spatial = self._spatial_axes_order + for ds in self._multiscale_attrs["datasets"]: + level_path: str = ds["path"] + voxel_size: dict[str, float] = {} + origin: dict[str, float] = {} + for tx in ds.get("coordinateTransformations", []): + if tx["type"] == "scale": + scales = tx["scale"] + # Skip leading non-spatial dims (e.g. channel) + spatial_scales = scales[-len(spatial) :] + voxel_size = {c: float(s) for c, s in zip(spatial, spatial_scales)} + elif tx["type"] == "translation": + trans = tx["translation"] + spatial_trans = trans[-len(spatial) :] + origin = {c: float(t) for c, t in zip(spatial, spatial_trans)} + if not origin: + origin = {c: 0.0 for c in spatial} + levels.append((level_path, voxel_size, origin)) + return levels - Returns - ------- - Mapping[str, np.ndarray] - Dictionary mapping axis names to coordinate offset arrays. - """ - return { - c: np.linspace( - -self.output_size[c] / 2 + self.scale[c] / 2, - self.output_size[c] / 2 - self.scale[c] / 2, - self.output_shape[c], - ) - for c in self.axes - } + @cached_property + def scale_level(self) -> int: + """Index of the best-matching scale level for ``self.scale``.""" + best_idx = 0 + best_score = float("inf") + for i, (_, vox_size, _) in enumerate(self._level_info): + # Sum of relative differences across requested axes + diffs = [ + abs(vox_size.get(ax, 0) - self.scale[ax]) / max(self.scale[ax], 1e-9) + for ax in self.axes + if ax in vox_size + ] + score = sum(diffs) + # Prefer finer (smaller) voxel sizes when equally close + if score < best_score: + best_score = score + best_idx = i + return best_idx @cached_property - def _array_shape(self) -> tuple[int, ...]: - """Shape of the selected scale-level array. + def _selected_level(self) -> tuple[str, dict[str, float], dict[str, float]]: + return self._level_info[self.scale_level] - Cached as a compact tuple of ints rather than opening the zarr array - on every access. Used by ``full_coords`` and ``shape`` so that both - share the same single zarr metadata read. - """ - return tuple(int(s) for s in self.group[self.scale_level].shape) + @cached_property + def _zarr_array(self) -> zarr.Array: + level_path, _, _ = self._selected_level + return zarr.open_array(f"{self.path}/{level_path}", mode="r") @cached_property - def shape(self) -> Mapping[str, int]: - """Returns the shape of the image.""" - return {c: s for c, s in zip(self.axes, self._array_shape)} + def _voxel_size(self) -> dict[str, float]: + _, vox_size, _ = self._selected_level + return vox_size @cached_property - def center(self) -> Mapping[str, float]: - """Returns the center of the image in world units.""" - return { - c: start + (stop - start) / 2 - for c, (start, stop) in self.bounding_box.items() - } + def _origin(self) -> dict[str, float]: + _, _, origin = self._selected_level + return origin + + # ------------------------------------------------------------------ + # Spatial properties + # ------------------------------------------------------------------ @cached_property - def multiscale_attrs(self) -> MultiscaleMetadata: - """Returns the multiscale metadata of the image.""" - return MultiscaleGroupAttrs( - multiscales=self.group.attrs["multiscales"] - ).multiscales[0] + def bounding_box(self) -> dict[str, tuple[float, float]]: + """World bounding box ``{axis: (min_nm, max_nm)}`` of the selected level.""" + arr_shape = self._zarr_array.shape + n_spatial = len(self.axes) + spatial_shape = arr_shape[-n_spatial:] + result: dict[str, tuple[float, float]] = {} + for i, ax in enumerate(self.axes): + start = self._origin.get(ax, 0.0) + end = start + spatial_shape[i] * self._voxel_size.get(ax, 1.0) + result[ax] = (start, end) + return result @cached_property - def coordinateTransformations( - self, - ) -> tuple[Scale] | tuple[Scale, Translation]: - """Returns the coordinate transformations of the image, based on the multiscale metadata.""" - # multi_tx = multi.coordinateTransformations - dset = [ - ds for ds in self.multiscale_attrs.datasets if ds.path == self.scale_level - ][0] - # tx_fused = normalize_transforms(multi_tx, dset.coordinateTransformations) - return dset.coordinateTransformations + def sampling_box(self) -> dict[str, tuple[float, float]] | None: + """Shrunk bounding box where patch centres can be drawn without going OOB. - @property - def full_coords(self) -> tuple[xarray.DataArray, ...]: - """Returns the full coordinates of the image's axes in world units. - - This is a plain ``@property`` (not ``@cached_property``) so the large - coordinate arrays are NOT kept alive between ``__getitem__`` calls. - ``bounding_box`` accesses it once (during its own cached initialisation) - and then the arrays are freed. ``array`` accesses it each time it is - rebuilt (after ``_clear_array_cache``), which is also once per - ``__getitem__``, so there is no performance regression. - - All inputs (``multiscale_attrs``, ``coordinateTransformations``, - ``_array_shape``) are individually cached, so reconstruction is pure - in-process arithmetic — no NFS reads after the first call. + Returns ``None`` if the array is smaller than the requested patch. """ - return coords_from_transforms( - axes=self.multiscale_attrs.axes, - transforms=self.coordinateTransformations, # type: ignore - shape=self._array_shape, # type: ignore + bb = self.bounding_box + result: dict[str, tuple[float, float]] = {} + for ax in self.axes: + half = self.output_size[ax] / 2.0 + lo = bb[ax][0] + half + hi = bb[ax][1] - half + if lo >= hi: + return None + result[ax] = (lo, hi) + return result + + def get_center(self, idx: int) -> dict[str, float]: + """World coordinates of the centre voxel for flat index *idx*. + + *idx* indexes into the regular grid defined by the sampling box and + ``self.scale`` (one point per output voxel). + """ + sb = self.sampling_box + if sb is None: + raise ValueError( + f"sampling_box is None for {self.path!r} " + f"(array too small for requested patch size)" + ) + grid = box_shape(sb, self.scale) + axes = list(sb.keys()) + shape_tuple = tuple(grid[ax] for ax in axes) + vox_idx = np.unravel_index(int(idx), shape_tuple) + return { + ax: sb[ax][0] + (vox_idx[i] + 0.5) * self.scale[ax] + for i, ax in enumerate(axes) + } + + # ------------------------------------------------------------------ + # Spatial-transform API (called by CellMapDataset before each read) + # ------------------------------------------------------------------ + + def set_spatial_transforms(self, transforms: dict | None) -> None: + """Store the spatial transforms that will be applied in the next ``__getitem__``.""" + self._current_spatial_transforms = transforms + + # ------------------------------------------------------------------ + # Read + # ------------------------------------------------------------------ + + def _compute_read_shape(self) -> list[int]: + """Read shape large enough to accommodate the current rotation.""" + base = [self.output_shape[ax] for ax in self.axes] + if self._current_spatial_transforms is None: + return base + R = self._current_spatial_transforms.get("rotation_matrix") + if R is None: + return base + R = np.asarray(R, dtype=float) + n = len(self.axes) + return [ + int(np.ceil(base[i] * sum(abs(R[i, j]) for j in range(n)))) + for i in range(n) + ] + + def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor: + """Return a tensor patch centred at *center* (world coords in nm).""" + # Legacy: centre values may be arrays (take mean) + center = { + k: float(np.mean(v)) if not isinstance(v, (int, float)) else float(v) + for k, v in center.items() + } + + read_shape = self._compute_read_shape() + arr = self._zarr_array + spatial_ndim = len(self.axes) + arr_shape = arr.shape[-spatial_ndim:] + n_leading = arr.ndim - spatial_ndim + + slices: list[slice] = [] + pad_widths: list[tuple[int, int]] = [] + + for i, ax in enumerate(self.axes): + vs = self._voxel_size.get(ax, 1.0) + vox_center = (center[ax] - self._origin.get(ax, 0.0)) / vs + start = int(np.floor(vox_center - read_shape[i] / 2.0)) + end = start + read_shape[i] + + pad_lo = max(0, -start) + pad_hi = max(0, end - arr_shape[i]) + slices.append(slice(max(0, start), min(arr_shape[i], end))) + pad_widths.append((pad_lo, pad_hi)) + + # Prepend slices for leading non-spatial dims (e.g. channel) + full_slices: list[Any] = [slice(None)] * n_leading + slices + leading_pads: list[tuple[int, int]] = [(0, 0)] * n_leading + + data = torch.from_numpy(np.asarray(arr[tuple(full_slices)], dtype=np.float32)) + + # Pad if any region was out of bounds + if any(p for pair in pad_widths for p in pair): + if self.pad: + flat_pad: list[int] = [] + for pw in reversed(leading_pads + pad_widths): + flat_pad += [pw[0], pw[1]] + data = F.pad(data, flat_pad, mode="constant", value=self.pad_value) + else: + # Clamp: just use what we could read (shape will be smaller than + # requested if near an edge; caller gets less data) + pass + + # Resample if voxel size differs from target scale + needs_resample = any( + abs(self._voxel_size.get(ax, 1.0) - self.scale[ax]) + / max(self.scale[ax], 1e-9) + > 0.01 + for ax in self.axes ) + if needs_resample: + zoom = [self._voxel_size.get(ax, 1.0) / self.scale[ax] for ax in self.axes] + out_spatial = [ + max(1, int(round(read_shape[i] * zoom[i]))) for i in range(spatial_ndim) + ] + # Bring data to [N, C, *spatial] for interpolate + orig_ndim = data.ndim + while data.ndim < spatial_ndim + 2: + data = data.unsqueeze(0) + + if self.interpolation == "nearest": + mode = "nearest" + extra: dict = {} + else: + mode = "trilinear" if spatial_ndim == 3 else "bilinear" + extra = {"align_corners": False} - @cached_property - def scale_level(self) -> str: - """Returns the multiscale level of the image.""" - return self.find_level(self.scale) + data = F.interpolate(data, size=out_spatial, mode=mode, **extra) - @cached_property - def group(self) -> zarr.Group: - """Returns the zarr group object for the multiscale image.""" - if self.path[:5] == "s3://": - return zarr.open_group(zarr.N5FSStore(self.path, anon=True), mode="r") - return zarr.open_group(self.path, mode="r") + while data.ndim > orig_ndim: + data = data.squeeze(0) - @cached_property - def array_path(self) -> str: - """Returns the path to the single-scale image array.""" - return os.path.join(self.path, self.scale_level) + # Apply rotation if requested + R = ( + self._current_spatial_transforms.get("rotation_matrix") + if self._current_spatial_transforms + else None + ) + if R is not None: + data = self._apply_rotation(data, np.asarray(R, dtype=float)) + # Crop centre to target output shape + target_shape = [self.output_shape[ax] for ax in self.axes] + crop_slices: list[Any] = [slice(None)] * (data.ndim - spatial_ndim) + for i in range(spatial_ndim): + curr = data.shape[data.ndim - spatial_ndim + i] + lo = (curr - target_shape[i]) // 2 + crop_slices.append(slice(lo, lo + target_shape[i])) + data = data[tuple(crop_slices)] + + # Mirror + if self._current_spatial_transforms is not None: + mirror = self._current_spatial_transforms.get("mirror") + if mirror is not None: + for i, ax in enumerate(self.axes): + flag = ( + mirror.get(ax, False) if isinstance(mirror, dict) else mirror[i] + ) + if flag: + data = data.flip(data.ndim - spatial_ndim + i) + + # Transpose + if self._current_spatial_transforms is not None: + perm = self._current_spatial_transforms.get("transpose") + if perm is not None: + n_lead = data.ndim - spatial_ndim + full_perm = list(range(n_lead)) + [n_lead + p for p in perm] + data = data.permute(*full_perm).contiguous() - @cached_property - def _ts_store(self) -> ts.TensorStore: # type: ignore - """ - Opens and caches the TensorStore array handle. + if self.value_transform is not None: + data = self.value_transform(data) - ts.open() is called exactly once per CellMapImage instance and the - resulting handle is kept alive for the instance's lifetime. The handle - is lightweight (it holds a reference to the shared context and chunk - cache) and is safe to reuse across many __getitem__ calls. + return data - Separating this from the `array` cached_property means that clearing - `array` after each __getitem__ (to release xarray intermediate objects) - does not trigger a new ts.open() call on the next access. - """ - spec = xt._zarr_spec_from_path(self.array_path) - array_future = ts.open(spec, read=True, write=False, context=self.context) - try: - return array_future.result() - except ValueError as e: - logger.warning( - "Failed to open with default driver: %s. Falling back to zarr3 driver.", - e, - ) - spec["driver"] = "zarr3" - return ts.open(spec, read=True, write=False, context=self.context).result() + def _apply_rotation(self, data: torch.Tensor, R: np.ndarray) -> torch.Tensor: + """Apply a rotation matrix to the spatial dimensions of *data*. - @cached_property - def array(self) -> xarray.DataArray: + Uses ``torch.nn.functional.affine_grid`` + ``grid_sample`` so that + gradients can flow through if needed. For labels (``interpolation == + "nearest"``) nearest-neighbour resampling is used. """ - Returns the image data as an xarray DataArray. + spatial_ndim = len(self.axes) + original_ndim = data.ndim - This property is cached but is explicitly cleared after each __getitem__ - call to release xarray intermediate objects (from interp/reindex/sel) - that would otherwise accumulate during training. Clearing it is cheap - because the underlying TensorStore handle is separately cached in - _ts_store and is not reopened. - """ - if ( - os.environ.get("CELLMAP_DATA_BACKEND", "tensorstore").lower() - != "tensorstore" - ): - data = da.from_array( - self.group[self.scale_level], - chunks="auto", + # Bring to [N, C, *spatial] + while data.ndim < spatial_ndim + 2: + data = data.unsqueeze(0) + + N = data.shape[0] + + # affine_grid expects the INVERSE transform (output → input mapping) + R_inv = torch.tensor(R.T, dtype=torch.float32) + + if spatial_ndim == 3: + theta = ( + torch.cat([R_inv, torch.zeros(3, 1)], dim=1) + .unsqueeze(0) + .expand(N, -1, -1) ) else: - data = xt._TensorStoreAdapter(self._ts_store) - return xarray.DataArray(data=data, coords=self.full_coords) + theta = ( + torch.cat([R_inv[:2, :2], torch.zeros(2, 1)], dim=1) + .unsqueeze(0) + .expand(N, -1, -1) + ) - @cached_property - def translation(self) -> Mapping[str, float]: - """Returns the translation of the image.""" - return {c: self.bounding_box[c][0] for c in self.axes} + grid = F.affine_grid(theta, list(data.shape), align_corners=False) + mode = "nearest" if self.interpolation == "nearest" else "bilinear" + rotated = F.grid_sample( + data.float(), grid, mode=mode, padding_mode="zeros", align_corners=False + ) - @cached_property - def bounding_box(self) -> Mapping[str, list[float]]: - """Returns the bounding box of the dataset in world units.""" - bounding_box = {} - for coord in self.full_coords: - bounding_box[coord.dims[0]] = [ - coord.data.min(), - coord.data.max(), - ] - return bounding_box + # Replace zero-padded corners with pad_value (only matters for continuous data) + if not np.isnan(self.pad_value) and self.pad_value != 0.0: + # grid_sample fills OOB with 0; patch with pad_value + oob_mask = (grid[..., 0].abs() > 1) | (grid[..., 1].abs() > 1) + if spatial_ndim == 3: + oob_mask = oob_mask | (grid[..., 2].abs() > 1) + oob_mask = oob_mask.unsqueeze(1).expand_as(rotated) + rotated[oob_mask] = self.pad_value - @cached_property - def sampling_box(self) -> Optional[Mapping[str, list[float]]]: - """Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box), in world units.""" - sampling_box = {} - output_padding = {c: np.ceil(s / 2) for c, s in self.output_size.items()} - for c, (start, stop) in self.bounding_box.items(): - sampling_box[c] = [ - start + output_padding[c], - stop - output_padding[c], - ] - try: - assert ( - sampling_box[c][0] < sampling_box[c][1] - ), f"Sampling box for axis {c} is invalid: {sampling_box[c]} for image {self.path}. Image is not large enough to sample from as requested." - except AssertionError as e: - if self.pad: - sampling_box[c] = [ - self.center[c] - self.scale[c], - self.center[c] + self.scale[c], - ] - else: - raise e - return sampling_box + while rotated.ndim > original_ndim: + rotated = rotated.squeeze(0) - @cached_property - def bg_count(self) -> float: - """Returns the number of background pixels in the ground truth data, normalized by the resolution.""" - # Trigger class_counts, which sets self._bg_count as a side effect - _ = self.class_counts - return self._bg_count + return rotated - @cached_property - def class_counts(self) -> float: - """Returns the number of voxels for the contained class at the training resolution. + # ------------------------------------------------------------------ + # Class counts (for weighted sampling) + # ------------------------------------------------------------------ - Reads ``complement_counts`` from the s0-level ``.zattrs`` JSON directly - (bypassing zarr Array construction) then scales the s0 voxel counts to - the training resolution so that counts are comparable across datasets - regardless of their native s0 resolution. + @property + def class_counts(self) -> dict[str, int]: + """Foreground voxel count at s0, normalised to training-resolution voxels. + + Fast path reads pre-cached counts from ``s0/.zattrs``. Slow path + counts non-zero voxels in the s0 array and writes the result back. """ + # Fast path: check for cached counts in s0 attrs try: - # Read s0 attrs directly from the zarr store to avoid the overhead - # of opening a full zarr Array (which reads .zarray + .zattrs). - store = self.group.store - try: - # zarr v2 layout: s0/.zattrs and s0/.zarray - s0_attrs = json.loads(bytes(store["s0/.zattrs"]).decode("utf-8")) - s0_shape = json.loads(bytes(store["s0/.zarray"]).decode("utf-8"))[ - "shape" - ] - except KeyError: - # N5 / zarr v3 fallback: open via zarr API - s0_arr = self.group["s0"] - s0_attrs = dict(s0_arr.attrs) - s0_shape = list(s0_arr.shape) - - bg_s0 = int( - s0_attrs["cellmap"]["annotation"]["complement_counts"]["absent"] - ) - fg_s0 = int(np.prod(s0_shape)) - bg_s0 - - # s0 physical scale from the already-cached multiscale metadata. - s0_scale = None - for dataset in self.multiscale_attrs.datasets: - if dataset.path == "s0": - for transform in dataset.coordinateTransformations: - if isinstance(transform, VectorScale): - s0_scale = list(transform.scale) - break - break - - if s0_scale is None: - raise ValueError("s0 scale not found in multiscale metadata") - - # Convert s0 voxel counts to training-resolution voxel counts so - # that datasets with different native resolutions are weighted by - # how much data they actually contribute at the training scale. - # training_voxels = s0_voxels * (s0_voxel_vol / training_voxel_vol) - s0_voxel_vol = float(np.prod(s0_scale)) - training_voxel_vol = float(np.prod(list(self.scale.values()))) - scale_ratio = s0_voxel_vol / training_voxel_vol - - class_counts = float(fg_s0) * scale_ratio - self._bg_count = float(bg_s0) * scale_ratio - - except Exception as e: - # TODO: This fallback is very expensive; precompute complement_counts - # for all images via cellmap-schemas to avoid this path. - logger.warning( - "Unable to get class counts for %s from metadata, " - "falling back to calculating from array. Error: %s, %s", - self.path, - e, - type(e), - ) - # Fallback: read the array at training resolution and count voxels. - # No scale multiplication — counts are already in training voxels. - array_data = self.array.compute() - fg_training = int(np.count_nonzero(array_data)) - class_counts = float(fg_training) - self._bg_count = float(array_data.size - fg_training) - - # Write the computed counts back to s0/.zattrs so future calls use - # the fast path. Counts must be converted from training resolution - # back to s0 resolution because the metadata always lives at s0. + s0_path = self._level_info[0][0] + s0_attrs = dict(zarr.open_group(self.path, mode="r")[s0_path].attrs) + counts = s0_attrs.get("class_counts") + if counts is not None and self.label_class in counts: + raw_count = counts[self.label_class] + return {self.label_class: self._scale_count(raw_count, s0_idx=0)} + except Exception: + pass + + # Slow path: count non-zero voxels in s0 + try: + s0_path = self._level_info[0][0] + s0_arr = zarr.open_array(f"{self.path}/{s0_path}", mode="r") + fg_count = int(np.count_nonzero(s0_arr[:])) + # Cache result in s0/.zattrs try: - s0_scale = None - for _dataset in self.multiscale_attrs.datasets: - if _dataset.path == "s0": - for transform in _dataset.coordinateTransformations: - if isinstance(transform, VectorScale): - s0_scale = list(transform.scale) - break - break - - if s0_scale is None: - raise ValueError("s0 scale not found in multiscale metadata") - - # Get s0 array shape to compute the total s0 voxel count. - _store = self.group.store - try: - s0_shape = json.loads(bytes(_store["s0/.zarray"]).decode("utf-8"))[ - "shape" - ] - except KeyError: - s0_shape = list(self.group["s0"].shape) - - total_s0 = int(np.prod(s0_shape)) - s0_voxel_vol = float(np.prod(s0_scale)) - training_voxel_vol = float(np.prod(list(self.scale.values()))) - - # fg_s0 = fg_training * (training_voxel_vol / s0_voxel_vol) - fg_s0 = int(round(fg_training * training_voxel_vol / s0_voxel_vol)) - fg_s0 = min(fg_s0, total_s0) # clamp to valid range - bg_s0 = total_s0 - fg_s0 - - # Merge into existing s0 attrs via a writable group handle. - writable_group = zarr.open_group(self.path, mode="r+") - s0_zarr = writable_group["s0"] - existing_attrs = dict(s0_zarr.attrs) - cellmap = existing_attrs.setdefault("cellmap", {}) - annotation = cellmap.setdefault("annotation", {}) - complement_counts = annotation.setdefault("complement_counts", {}) - complement_counts["absent"] = bg_s0 - s0_zarr.attrs.update(existing_attrs) - logger.info( - "Wrote complement_counts metadata for %s: absent=%d", - self.path, - bg_s0, - ) - except Exception as write_err: - logger.warning( - "Unable to write complement_counts metadata for %s: %s", - self.path, - write_err, - ) - - return class_counts - - def to(self, device: str, *args, **kwargs) -> None: - """Sets what device returned image data will be loaded onto.""" - self.device = device - - def find_level(self, target_scale: Mapping[str, float]) -> str: - """Finds the multiscale level that is closest to the target scale.""" - # Get the order of axes in the image - axes = [] - for axis in self.group.attrs["multiscales"][0]["axes"]: - if axis["type"] == "space": - axes.append(axis["name"]) - - last_path: str | None = None - scale = {} - for level in self.group.attrs["multiscales"][0]["datasets"]: - for transform in level["coordinateTransformations"]: - if "scale" in transform: - scale = {c: s for c, s in zip(axes, transform["scale"])} - break - for c in axes: - if scale[c] > target_scale[c]: - if last_path is None: - return level["path"] - else: - return last_path - last_path = level["path"] - return last_path # type: ignore - - def rotate_coords( - self, coords: Mapping[str, Sequence[float]], angles: Mapping[str, float] - ) -> Mapping[str, tuple[Sequence[str], np.ndarray]] | Mapping[str, Sequence[float]]: - """Rotates the given coordinates by the given angles.""" - # Check to see if a rotation is necessary - if not any([a != 0 for a in angles.values()]): - return coords - - # Convert the coordinates dictionary to a vector - coords_vector, axes_lengths = self._coord_dict_to_vector(coords) - - # Recenter the coordinates around the origin - center = coords_vector.mean(axis=0) - coords_vector -= center - - rotation_vector = [angles[c] if c in angles else 0 for c in self.axes] - rotator = rot.from_rotvec(rotation_vector, degrees=True) - - # Apply the rotation - rotated_coords = rotator.apply(coords_vector) - - # Recenter the coordinates around the original center - rotated_coords += center - return self._coord_vector_to_grid_dict(rotated_coords, axes_lengths) - - def _coord_dict_to_vector( - self, coords_dict: Mapping[str, Sequence[float]] - ) -> tuple[np.ndarray, Mapping[str, int]]: - """Converts a dictionary of coordinates to a vector, for use with rotate_coords.""" - coord_vector = np.stack( - np.meshgrid(*[coords_dict[c] for c in self.axes]), axis=-1 - ).reshape(-1, len(self.axes)) - axes_lengths = {c: len(coords_dict[c]) for c in self.axes} - return coord_vector, axes_lengths - - def _coord_vector_to_grid_dict( - self, coords_vector: np.ndarray, axes_lengths: Mapping[str, int] - ) -> Mapping[str, tuple[Sequence[str], np.ndarray]]: - """Converts a vector of coordinates to a grid type dictionary.""" - shape = [axes_lengths[c] for c in self.axes] - axes = [c for c in self.axes] - coords_dict = { - c: (axes, coords_vector[:, self.axes.index(c)].reshape(shape)) - for c in self.axes - } - - return coords_dict - - def set_spatial_transforms(self, transforms: Mapping[str, Any] | None) -> None: - """Sets spatial transformations for the image data, for setting global transforms at the 'dataset' level.""" - self._current_spatial_transforms = transforms - - def apply_spatial_transforms(self, coords) -> torch.Tensor: - """Applies spatial transformations to the given coordinates.""" - # Apply spatial transformations to the coordinates - # Because some spatial transformations require the image array, we need to apply them after pulling the data. This is done by separating the transforms into two groups - if self._current_spatial_transforms is not None: - # Because of the implementation details, we explicitly apply transforms in a specific order - if "mirror" in self._current_spatial_transforms: - for axis in self._current_spatial_transforms["mirror"]: - # Assumes the coords are the default xarray format - coords[axis] = coords[axis][::-1] - if "rotate" in self._current_spatial_transforms: - # Assumes the coords are the default xarray format, and that the rotation is in degrees - # Converts the coordinates to a vector, rotates them, then converts them to a grid dictionary - coords = self.rotate_coords( - coords, self._current_spatial_transforms["rotate"] - ) - if "deform" in self._current_spatial_transforms: - raise NotImplementedError("Deformations are not yet implemented.") - self._current_coords = coords - - # Pull data from the image - data = self.return_data(coords) - data = data.values - - # Apply and spatial transformations that require the image array (e.g. transpose) - if self._current_spatial_transforms is not None: - if "transpose" in self._current_spatial_transforms: - new_order = [ - self._current_spatial_transforms["transpose"][c] for c in self.axes - ] - data = np.transpose(data, new_order) - - # Optimized tensor creation: use torch.from_numpy when possible to avoid data copying - if isinstance(data, np.ndarray): - return torch.from_numpy(data) - else: - return torch.tensor(data) + g = zarr.open_group(self.path, mode="r+") + attrs = dict(g[s0_path].attrs) + if "class_counts" not in attrs: + attrs["class_counts"] = {} + attrs["class_counts"][self.label_class] = fg_count + g[s0_path].attrs.update(attrs) + except Exception: + pass + return {self.label_class: self._scale_count(fg_count, s0_idx=0)} + except Exception as exc: + logger.warning("class_counts failed for %s: %s", self.path, exc) + return {self.label_class: 0} + + def _scale_count(self, s0_count: int, s0_idx: int = 0) -> int: + """Scale a voxel count from s0 resolution to training resolution.""" + try: + _, s0_vox, _ = self._level_info[s0_idx] + s0_vol = 1.0 + train_vol = 1.0 + for ax in self.axes: + s0_vol *= s0_vox.get(ax, 1.0) + train_vol *= self.scale.get(ax, 1.0) + if train_vol == 0: + return s0_count + return int(s0_count * (s0_vol / train_vol)) + except Exception: + return s0_count + + # ------------------------------------------------------------------ + # Misc + # ------------------------------------------------------------------ + + def to(self, device: str | torch.device) -> "CellMapImage": + """No-op (tensors are always returned on CPU). Kept for API compatibility.""" + return self - @cached_property - def tolerance(self) -> float: - """Returns the tolerance for nearest neighbor interpolation.""" - # Calculate the tolerance as half the norm of the original image scale (i.e. traversing half a pixel diagonally) + epsilon (1e-6) - actual_scale = [ - ct for ct in self.coordinateTransformations if isinstance(ct, VectorScale) - ][0].scale - half_diagonal = np.linalg.norm(actual_scale) / 2 - return float(half_diagonal + 1e-6) - - def return_data( - self, - coords: ( - Mapping[str, Sequence[float]] - | Mapping[str, tuple[Sequence[str], np.ndarray]] - ), - ) -> xarray.DataArray: - """Pulls data from the image based on the given coordinates, applying interpolation if necessary, and returns the data as an xarray DataArray.""" - if not isinstance(list(coords.values())[0][0], (float, int)): - data = self.array.interp( - coords=coords, - method=self.interpolation, # type: ignore - ) - elif self.pad: - data = self.array.reindex( - **(coords), # type: ignore - method="nearest", - tolerance=self.tolerance, - fill_value=self.pad_value, - ) - else: - data = self.array.sel(**(coords), method="nearest") # type: ignore - if ( - os.environ.get("CELLMAP_DATA_BACKEND", "tensorstore").lower() - != "tensorstore" - ): - # NOTE: Forcing eager loading of dask array here may cause high memory usage and block further lazy optimizations. - data = data.compute() - return data + def __repr__(self) -> str: + return ( + f"CellMapImage({self.path!r}, class={self.label_class!r}, " + f"scale={list(self.scale.values())}, shape={list(self.output_shape.values())})" + ) diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index 090a758..9145954 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -1,28 +1,48 @@ +"""ImageWriter: writes patch data back to a zarr array at world coordinates.""" + +from __future__ import annotations + import logging import os from functools import cached_property -from typing import Mapping, Optional, Sequence, Union +from typing import Mapping, Optional, Sequence import numpy as np -import tensorstore import torch -import xarray -import xarray_tensorstore as xt +import zarr from numpy.typing import ArrayLike -from pydantic_ome_ngff.v04.axis import Axis -from pydantic_ome_ngff.v04.transform import VectorScale, VectorTranslation from upath import UPath -from xarray_ome_ngff.v04.multiscale import coords_from_transforms -from cellmap_data.utils import create_multiscale_metadata +from .utils.metadata import create_multiscale_metadata logger = logging.getLogger(__name__) class ImageWriter: - """ - This class is used to write image data to a single-resolution zarr. - ...existing docstring... + """Write patches of a single class to a single-resolution zarr array. + + Parameters + ---------- + path: + Base path of the zarr group (e.g. ``/out/predictions.zarr/mito``). + target_class: + Semantic label (e.g. ``"mito"``). + scale: + Voxel size in nm per spatial axis (dict or sequence in *axis_order*). + bounding_box: + World bounding box ``{axis: (min_nm, max_nm)}``. + write_voxel_shape: + Patch size in voxels (dict or sequence in *axis_order*). + scale_level: + Scale level index to write (default 0 = full resolution). + axis_order: + Spatial axis names. + overwrite: + If ``True``, existing data at *path* is overwritten. + dtype: + Output dtype (default ``float32``). + fill_value: + Value to pre-fill the output array with (default ``0``). """ def __init__( @@ -30,311 +50,199 @@ def __init__( path: str | UPath, target_class: str, scale: Mapping[str, float] | Sequence[float], - bounding_box: Mapping[str, list[float]], + bounding_box: Mapping[str, Sequence[float]], write_voxel_shape: Mapping[str, int] | Sequence[int], scale_level: int = 0, axis_order: str = "zyx", - context: Optional[tensorstore.Context] = None, + context: Optional[object] = None, # ignored – kept for API compat overwrite: bool = False, dtype: np.dtype = np.float32, fill_value: float | int = 0, ) -> None: self.base_path = str(path) - self.path = (UPath(path) / f"s{scale_level}").path self.label_class = self.target_class = target_class + self.scale_level = scale_level + self.overwrite = overwrite + self.dtype = dtype + self.fill_value = fill_value + if isinstance(scale, Sequence): - scale = {c: s for c, s in zip(axis_order[::-1], scale[::-1])} - self.scale = scale + scale = {c: float(s) for c, s in zip(axis_order, scale)} + self.scale: dict[str, float] = dict(scale) + + self.axes: str = axis_order + self.spatial_axes: list[str] = list(axis_order[-len(self.scale) :]) + if isinstance(write_voxel_shape, Sequence): - if len(axis_order) > len(write_voxel_shape): # TODO: This might be a bug + if len(axis_order) > len(write_voxel_shape): write_voxel_shape = [1] * ( len(axis_order) - len(write_voxel_shape) ) + list(write_voxel_shape) - elif ( - len(axis_order) + 1 == len(write_voxel_shape) and "c" not in axis_order - ): - axis_order = "c" + axis_order - write_voxel_shape = {c: t for c, t in zip(axis_order, write_voxel_shape)} - self.axes = axis_order - # Assume axes correspond to last dimensions of voxel shape - self.spatial_axes = axis_order[-len(scale) :] - self.bounding_box = bounding_box - self.write_voxel_shape = write_voxel_shape - self.write_world_shape = { - c: write_voxel_shape[c] * scale[c] for c in self.spatial_axes - } - self.scale_level = scale_level - self.context = context - self.overwrite = overwrite - self.dtype = dtype - self.fill_value = fill_value - self.metadata = { - "offset": list(self.offset.values()), - "axes": [c for c in axis_order], - "voxel_size": list(self.scale.values()), - "shape": list(self.shape.values()), - "units": "nanometer", - "chunk_shape": list(write_voxel_shape.values()), + write_voxel_shape = { + c: int(t) for c, t in zip(axis_order, write_voxel_shape) + } + self.write_voxel_shape: dict[str, int] = dict(write_voxel_shape) + self.write_world_shape: dict[str, float] = { + c: self.write_voxel_shape[c] * self.scale[c] for c in self.spatial_axes } - @cached_property - def array(self) -> xarray.DataArray: - os.makedirs(UPath(self.base_path), exist_ok=True) - group_path = str(self.base_path).split(".zarr")[0] + ".zarr" - for group in [""] + list(UPath(str(self.base_path).split(".zarr")[-1]).parts)[ - 1: - ]: - group_path = UPath(group_path) / group - with open(group_path / ".zgroup", "w") as f: - f.write('{"zarr_format": 2}') - create_multiscale_metadata( - ds_name=str(self.base_path), - voxel_size=self.metadata["voxel_size"], - translation=self.metadata["offset"], - units=self.metadata["units"], - axes=self.metadata["axes"], - base_scale_level=self.scale_level, - levels_to_add=0, - out_path=str(UPath(self.base_path) / ".zattrs"), - ) - spec = { - "driver": "zarr", - "kvstore": {"driver": "file", "path": self.path}, - } - open_kwargs = { - "read": True, - "write": True, - "create": True, - "delete_existing": self.overwrite, - "dtype": self.dtype, - "shape": list(self.shape.values()), - "fill_value": self.fill_value, - "chunk_layout": tensorstore.ChunkLayout(write_chunk_shape=self.chunk_shape), - "context": self.context, + self.bounding_box: dict[str, tuple[float, float]] = { + c: (float(bounding_box[c][0]), float(bounding_box[c][1])) + for c in self.spatial_axes } - array_future = tensorstore.open( - spec, - **open_kwargs, - ) - try: - array = array_future.result() - except ValueError as e: - if "ALREADY_EXISTS" in str(e): - raise FileExistsError( - f"Image already exists at {self.path}. Set overwrite=True to overwrite the image." - ) - logger.warning("Error opening with zarr driver: %s", e) - logger.warning("Falling back to zarr3 driver") - spec["driver"] = "zarr3" - array_future = tensorstore.open(spec, **open_kwargs) - array = array_future.result() - data = xarray.DataArray( - data=xt._TensorStoreAdapter(array), - coords=coords_from_transforms( - axes=[ - Axis( - name=c, - type="space" if c != "c" else "channel", - unit="nm" if c != "c" else "", - ) - for c in self.axes - ], - transforms=( - VectorScale(scale=tuple(self.scale.values())), - VectorTranslation(translation=tuple(self.offset.values())), - ), - shape=tuple(self.shape.values()), - ), - ) - with open(UPath(self.path) / ".zattrs", "w") as f: - f.write("{}") - return data + + # ------------------------------------------------------------------ + # Cached properties + # ------------------------------------------------------------------ @cached_property - def chunk_shape(self) -> Sequence[int]: - return list(self.write_voxel_shape.values()) + def offset(self) -> dict[str, float]: + return {c: self.bounding_box[c][0] for c in self.spatial_axes} @cached_property - def world_shape(self) -> Mapping[str, float]: + def world_shape(self) -> dict[str, float]: return { c: self.bounding_box[c][1] - self.bounding_box[c][0] for c in self.spatial_axes } @cached_property - def shape(self) -> Mapping[str, int]: + def shape(self) -> dict[str, int]: return { c: int(np.ceil(self.world_shape[c] / self.scale[c])) for c in self.spatial_axes } @cached_property - def center(self) -> Mapping[str, float]: - return {str(k): float(np.mean(v)) for k, v in self.array.coords.items()} + def chunk_shape(self) -> list[int]: + return [self.write_voxel_shape[c] for c in self.spatial_axes] @cached_property - def offset(self) -> Mapping[str, float]: - return {c: self.bounding_box[c][0] for c in self.spatial_axes} + def array_path(self) -> str: + return str(UPath(self.base_path) / f"s{self.scale_level}") @cached_property - def full_coords(self) -> tuple[xarray.DataArray, ...]: - return coords_from_transforms( - axes=[ - Axis( - name=c, - type="space" if c != "c" else "channel", - unit="nm" if c != "c" else "", - ) - for c in self.axes - ], - transforms=( - VectorScale(scale=tuple(self.scale.values())), - VectorTranslation(translation=tuple(self.offset.values())), - ), - shape=tuple(self.shape.values()), + def _zarr_array(self) -> zarr.Array: + """Open (creating if necessary) the output zarr array.""" + os.makedirs(str(UPath(self.base_path)), exist_ok=True) + + # Ensure every ancestor group has a .zgroup + group_path = str(self.base_path).split(".zarr")[0] + ".zarr" + inner = UPath(str(self.base_path).split(".zarr")[-1]) + for part in [""] + list(inner.parts)[1:]: + gp = str(UPath(group_path) / part) + zgroup = UPath(gp) / ".zgroup" + if not zgroup.exists(): + os.makedirs(gp, exist_ok=True) + zgroup.write_text('{"zarr_format": 2}') + group_path = gp + + # Write OME-NGFF multiscale metadata + create_multiscale_metadata( + ds_name=self.base_path, + voxel_size=[self.scale[c] for c in self.spatial_axes], + translation=[self.offset[c] for c in self.spatial_axes], + units="nanometer", + axes=self.spatial_axes, + base_scale_level=self.scale_level, + levels_to_add=0, + out_path=str(UPath(self.base_path) / ".zattrs"), ) - def align_coords( - self, coords: Mapping[str, tuple[Sequence, np.ndarray]] - ) -> Mapping[str, tuple[Sequence, np.ndarray]]: - aligned_coords = {} - for c in self.spatial_axes: - aligned_coords[c] = np.array( - self.array.coords[c][ - np.abs(np.array(self.array.coords[c])[:, None] - coords[c]).argmin( - axis=0 - ) - ] - ).squeeze() - return aligned_coords - - def aligned_coords_from_center(self, center: Mapping[str, float]): - coords = {} - for c in self.axes: - # Use center-of-voxel alignment - start_requested = ( - center[c] - self.write_world_shape[c] / 2 + self.scale[c] / 2 - ) - start_aligned_idx = int( - np.abs(self.array.coords[c] - start_requested).argmin() - ) - coords[c] = self.array.coords[c][ - start_aligned_idx : start_aligned_idx + self.write_voxel_shape[c] - ] - return coords + total_shape = [self.shape[c] for c in self.spatial_axes] + arr = zarr.open_array( + self.array_path, + mode="w" if self.overwrite else "a", + shape=total_shape, + dtype=self.dtype, + chunks=self.chunk_shape, + fill_value=self.fill_value, + ) + # Empty attrs for scale-level array + with open(str(UPath(self.array_path) / ".zattrs"), "w") as f: + f.write("{}") + return arr + + # ------------------------------------------------------------------ + # Write + # ------------------------------------------------------------------ def __setitem__( self, - coords: Union[Mapping[str, float], Mapping[str, tuple[Sequence, np.ndarray]]], - data: Union[torch.Tensor, ArrayLike, float, int], + coords: Mapping[str, float] | Mapping[str, Sequence], + data: torch.Tensor | ArrayLike | float | int, ) -> None: - """ - Set data at the specified coordinates. - - This method handles two types of coordinate inputs: - 1. Center coordinates: mapping axis names to float values - 2. Batch coordinates: mapping axis names to sequences of coordinates + """Write *data* at the location given by *coords*. - Args: - ---- - coords: Either center coordinates or batch coordinates - data: Data to write at the coordinates + *coords* can be: + - ``{axis: float}`` centre coordinates — single patch. + - ``{axis: Sequence[float]}`` centres — batch. """ - first_coord_value = next(iter(coords.values())) - - if isinstance(first_coord_value, (int, float)): - # Handle single item with center coordinates - self._write_single_item(coords, data) # type: ignore + first = next(iter(coords.values())) + if isinstance(first, (int, float)): + self._write_single(coords, data) # type: ignore[arg-type] else: - # Handle batch of items with coordinate sequences - self._write_batch_items(coords, data) # type: ignore + self._write_batch(coords, data) # type: ignore[arg-type] - def _write_single_item( + def _write_single( self, - center_coords: Mapping[str, float], - data: Union[torch.Tensor, ArrayLike], + center: Mapping[str, float], + data: torch.Tensor | ArrayLike, ) -> None: - """Write a single data item using center coordinates.""" - # Convert center coordinates to aligned array coordinates - aligned_coords = self.aligned_coords_from_center(center_coords) + arr = self._zarr_array + arr_shape = [self.shape[c] for c in self.spatial_axes] + + slices: list[slice] = [] + for i, c in enumerate(self.spatial_axes): + start_nm = center[c] - self.write_world_shape[c] / 2.0 + start_vox = int(round((start_nm - self.offset[c]) / self.scale[c])) + end_vox = start_vox + self.write_voxel_shape[c] + clamp_start = max(0, start_vox) + clamp_end = min(arr_shape[i], end_vox) + slices.append(slice(clamp_start, clamp_end)) - # Convert data to numpy array with correct dtype if isinstance(data, torch.Tensor): - data = data.cpu().numpy() - data_array = np.array(data).astype(self.dtype) - - # Remove batch dimension if present - if data_array.ndim == len(self.axes) + 1 and data_array.shape[0] == 1: - data_array = np.squeeze(data_array, axis=0) - - # Check for shape mismatches - expected_shape = tuple(self.write_voxel_shape[c] for c in self.axes) - if data_array.shape != expected_shape: - if len(data_array.shape) < len(expected_shape) and 1 in expected_shape: - # Try to expand dimensions to fit expected shape - for axis, size in enumerate(expected_shape): - if size == 1: - data_array = np.expand_dims(data_array, axis=axis) - else: - raise ValueError( - f"Data shape {data_array.shape} does not match expected shape {expected_shape}." - ) - coord_shape = tuple(len(aligned_coords[c]) for c in self.axes) - if coord_shape != expected_shape: - # Try to crop data to fit within bounds if necessary - min_shape = tuple(min(c, e) for c, e in zip(coord_shape, expected_shape)) - slices = tuple(slice(0, s) for s in min_shape) - data_array = data_array[slices] - if data_array.shape != coord_shape: - raise ValueError( - f"Aligned coordinates shape {coord_shape} does not match expected shape {expected_shape}." - ) - UserWarning( - f"Data shape cropped to {data_array.shape} to fit within bounds." - ) - - # Write to array - self.array.loc[aligned_coords] = data_array - - def _write_batch_items( - self, - batch_coords: Mapping[str, tuple[Sequence, np.ndarray]], - data: Union[torch.Tensor, ArrayLike], - ) -> None: - """Write multiple data items by iterating through coordinate batches.""" - # Do for each item in the batch - for i in range(data.shape[0]): - # Extract center coordinates for this item - item_coords = {axis: batch_coords[axis][i] for axis in self.axes} - - # Extract data for this item - item_data = data[i] # type: ignore + data_np = data.detach().cpu().numpy() + else: + data_np = np.asarray(data) + data_np = data_np.astype(self.dtype) - # Write this single item using center coordinates - self._write_single_item(item_coords, item_data) + # Strip batch / channel leading dims of size 1 + while data_np.ndim > len(self.spatial_axes) and data_np.shape[0] == 1: + data_np = data_np.squeeze(0) - def __repr__(self) -> str: - return f"ImageWriter({self.path}: {self.label_class} @ {list(self.scale.values())} {self.metadata['units']})" + # Crop data to clamped region (near array edges) + actual = tuple(s.stop - s.start for s in slices) + if data_np.shape != actual: + data_np = data_np[tuple(slice(0, e) for e in actual)] - def __getitem__( - self, coords: Mapping[str, float] | Mapping[str, tuple[Sequence, np.ndarray]] - ) -> torch.Tensor: - """ - Get the image data at the specified center coordinates. + arr[tuple(slices)] = data_np - Args: - ---- - coords (Mapping[str, float] | Mapping[str, tuple[Sequence, np.ndarray]]): The center coordinates or aligned coordinates. + def _write_batch( + self, + batch_coords: Mapping[str, Sequence], + data: torch.Tensor | ArrayLike, + ) -> None: + n = len(next(iter(batch_coords.values()))) + for i in range(n): + center = {ax: float(batch_coords[ax][i]) for ax in self.spatial_axes} + item = data[i] if hasattr(data, "__getitem__") else data # type: ignore[index] + self._write_single(center, item) + + def __getitem__(self, coords: Mapping[str, float]) -> torch.Tensor: + """Read the patch centred at *coords*.""" + arr = self._zarr_array + arr_shape = [self.shape[c] for c in self.spatial_axes] + slices: list[slice] = [] + for i, c in enumerate(self.spatial_axes): + start_nm = coords[c] - self.write_world_shape[c] / 2.0 + start_vox = int(round((start_nm - self.offset[c]) / self.scale[c])) + end_vox = start_vox + self.write_voxel_shape[c] + slices.append(slice(max(0, start_vox), min(arr_shape[i], end_vox))) + return torch.from_numpy(np.array(arr[tuple(slices)])) - Returns: - ------- - torch.Tensor: The image data at the specified center. - """ - # Check if center or coords are provided - if isinstance(list(coords.values())[0], int | float): - center = coords - aligned_coords = self.aligned_coords_from_center(center) # type: ignore - else: - # If coords are provided, align them - aligned_coords = self.align_coords(coords) # type: ignore - return torch.tensor(self.array.loc[aligned_coords].data).squeeze() + def __repr__(self) -> str: + return ( + f"ImageWriter({self.base_path!r}: {self.label_class!r} " + f"@ {list(self.scale.values())} nm)" + ) diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py index f999b25..a9cc49c 100644 --- a/src/cellmap_data/multidataset.py +++ b/src/cellmap_data/multidataset.py @@ -1,426 +1,148 @@ -import functools -from functools import cached_property +"""CellMapMultiDataset: combines multiple CellMapDataset instances.""" + +from __future__ import annotations + import logging +from functools import cached_property from typing import Any, Callable, Mapping, Optional, Sequence import numpy as np import torch -from torch.utils.data import ConcatDataset, WeightedRandomSampler +from torch.utils.data import ConcatDataset from tqdm import tqdm -from .base_dataset import CellMapBaseDataset from .dataset import CellMapDataset -from .mutable_sampler import MutableSubsetRandomSampler -from .utils.sampling import min_redundant_inds logger = logging.getLogger(__name__) -class CellMapMultiDataset(CellMapBaseDataset, ConcatDataset): - """ - This class is used to combine multiple datasets into a single dataset. It is a subclass of PyTorch's ConcatDataset. It maintains the same API as the ConcatDataset class. It retrieves raw and groundtruth data from multiple CellMapDataset objects. See the CellMapDataset class for more information on the dataset object. - - Attributes - ---------- - classes: Sequence[str] - The classes in the dataset. - input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]] - The input arrays for each dataset in the multi-dataset. - target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]] - The target arrays for each dataset in the multi-dataset. - datasets: Sequence[CellMapDataset] - The datasets to be combined into the multi-dataset. - - Methods - ------- - to(device: str | torch.device) -> "CellMapMultiDataset": - Moves the multi-dataset to the specified device. - get_weighted_sampler(batch_size: int = 1, rng: Optional[torch.Generator] = None) -> WeightedRandomSampler: - Returns a weighted random sampler for the multi-dataset. - get_subset_random_sampler(num_samples: int, weighted: bool = True, rng: Optional[torch.Generator] = None) -> torch.utils.data.SubsetRandomSampler: - Returns a random sampler that samples num_samples from the multi-dataset. - get_indices(chunk_size: Mapping[str, int]) -> Sequence[int]: - Returns the indices of the multi-dataset that will tile all of the datasets according to the requested chunk_size. - set_raw_value_transforms(transforms: Callable) -> None: - Sets the raw value transforms for each dataset in the multi-dataset. - set_target_value_transforms(transforms: Callable) -> None: - Sets the target value transforms for each dataset in the multi-dataset. - set_spatial_transforms(spatial_transforms: Mapping[str, Any] | None) -> None: - Sets the spatial transforms for each dataset in the multi-dataset. +class CellMapMultiDataset(ConcatDataset): + """Concatenates multiple :class:`CellMapDataset` instances. - Properties: - class_counts: Mapping[str, float] - Returns a nested dictionary containing the number of samples in each class for each dataset in the multi-dataset, with class-specific counts nested under a 'totals' key. - class_weights: Mapping[str, float] - Returns the class weights for the multi-dataset based on the number of samples in each class. - dataset_weights: Mapping[CellMapDataset, float] - Returns the weights for each dataset in the multi-dataset based on the number of samples of each class in each dataset. - sample_weights: Sequence[float] - Returns the weights for each sample in the multi-dataset based on the number of samples in each dataset. - validation_indices: Sequence[int] - Returns the indices of the validation set for each dataset in the multi-dataset. + Provides aggregate ``class_counts``, ``get_crop_class_matrix``, and + ``validation_indices`` over all constituent datasets, which are required + by :class:`~cellmap_data.sampler.ClassBalancedSampler` and + :class:`~cellmap_data.dataloader.CellMapDataLoader`. + Parameters + ---------- + datasets: + List of :class:`CellMapDataset` objects to concatenate. + classes: + Shared segmentation classes (must match each dataset's ``classes``). + input_arrays: + Shared input array specs. + target_arrays: + Shared target array specs. """ def __init__( self, - classes: Sequence[str] | None, - input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], - target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]] | None, datasets: Sequence[CellMapDataset], + classes: Sequence[str], + input_arrays: Mapping[str, Mapping[str, Any]], + target_arrays: Mapping[str, Mapping[str, Any]], ) -> None: - super().__init__(datasets) - self.input_arrays = input_arrays - self.target_arrays = target_arrays if target_arrays is not None else {} - self.classes = classes if classes is not None else [] + super().__init__(datasets) # initialises ConcatDataset + self.classes = list(classes) + self.input_arrays = dict(input_arrays) + self.target_arrays = dict(target_arrays) - def __repr__(self) -> str: - out_string = "CellMapMultiDataset([" - for dataset in self.datasets: - out_string += f"\n\t{dataset}," - out_string += "\n])" - return out_string - - def __reduce__(self): - """ - Support pickling for multiprocessing DataLoader and spawned processes. - """ - # These are the args __init__ needs: - args = (self.classes, self.input_arrays, self.target_arrays, self.datasets) - # Return: (callable, args_for_constructor, state_dict) - return (self.__class__, args, self.__dict__) + # ------------------------------------------------------------------ + # Class weights / sampling + # ------------------------------------------------------------------ @property - def has_data(self) -> bool: - """ - Returns True if the multi-dataset has data, i.e., if it contains any datasets. - """ - return len(self) > 0 + def class_counts(self) -> dict[str, Any]: + """Aggregate foreground voxel counts across all datasets. - @cached_property - def class_counts(self) -> dict[str, dict[str, float]]: + Sequential scan (parallelism offers no benefit over NFS; see + project MEMORY.md notes on ``CellMapMultiDataset.class_counts``). """ - Returns the number of samples in each class for each dataset in the multi-dataset, as well as the total number of samples in each class. - """ - classes: list[str] = list(self.classes or []) - class_counts: dict[str, dict[str, float]] = { - "totals": {c: 0.0 for c in classes} - } - class_counts["totals"].update({c + "_bg": 0.0 for c in classes}) - n_datasets = len(self.datasets) - - # Short-circuit if no classes or no datasets to avoid unnecessary computation - if not classes: - logger.info("No classes configured; returning empty totals dict") - return class_counts - if n_datasets == 0: - logger.info( - "No datasets to gather counts for; returning zero-initialized totals for configured classes" - ) - return class_counts + totals: dict[str, int] = {cls: 0 for cls in self.classes} + for ds in tqdm(self.datasets, desc="Counting class voxels", leave=False): + ds_counts = ds.class_counts.get("totals", {}) + for cls in self.classes: + totals[cls] += ds_counts.get(cls, 0) + return {"totals": totals} - logger.info("Gathering class counts for %d datasets...", n_datasets) - - # Sequential scan: class_counts is now a fast metadata read (two JSON - # files per image via the zarr store). A thread pool offered no - # meaningful speedup and caused all workers to block simultaneously on - # NFS hard-mounts, making progress stall at 0/N indefinitely. - for ds in tqdm(self.datasets, desc="Gathering class counts"): - ds_counts = ds.class_counts - for c in classes: - if c in ds_counts["totals"]: - class_counts["totals"][c] += ds_counts["totals"][c] - class_counts["totals"][c + "_bg"] += ds_counts["totals"][c + "_bg"] - return class_counts - - @cached_property + @property def class_weights(self) -> dict[str, float]: - """ - Returns the class weights for the multi-dataset based on the number of samples in each class. - """ - if self.classes is None: - return {} - return { - c: ( - self.class_counts["totals"][c + "_bg"] / self.class_counts["totals"][c] - if self.class_counts["totals"][c] != 0 - else 1 - ) - for c in self.classes - } + """Per-class sampling weight: ``bg_voxels / fg_voxels``.""" + counts = self.class_counts["totals"] + total_voxels = sum(counts.values()) + weights: dict[str, float] = {} + for cls in self.classes: + fg = counts.get(cls, 0) + bg = total_voxels - fg + weights[cls] = float(bg) / float(max(fg, 1)) + return weights + + def get_crop_class_matrix(self) -> np.ndarray: + """Stack ``[n_crops, n_classes]`` bool matrix from all datasets.""" + return np.vstack([ds.get_crop_class_matrix() for ds in self.datasets]) + + # ------------------------------------------------------------------ + # Validation + # ------------------------------------------------------------------ @cached_property - def dataset_weights(self) -> Mapping[CellMapDataset, float]: - """ - Returns the weights for each dataset in the multi-dataset based on the number of samples in each dataset. - """ - dataset_weights = {} - for dataset in self.datasets: - if len(self.classes) == 0: - # If no classes are defined, assign equal weight to all datasets - dataset_weight = 1.0 - else: - dataset_weight = np.sum( - [ - dataset.class_counts["totals"][c] * self.class_weights[c] # type: ignore - for c in self.classes - ] + def validation_indices(self) -> list[int]: + """Non-overlapping tile indices across all datasets (for validation).""" + indices: list[int] = [] + offset = 0 + for ds in self.datasets: + # Use the output size of the first target array as the tile size + first_target_spec = next(iter(ds.target_arrays.values())) + scale = { + c: float(s) + for c, s in zip( + next(iter(ds.target_sources.values())).axes, + first_target_spec["scale"], ) - dataset_weight *= (1 / len(dataset)) if len(dataset) > 0 else 0 # type: ignore - dataset_weights[dataset] = dataset_weight - return dataset_weights - - @cached_property - def sample_weights(self) -> Sequence[float]: - """ - Returns the weights for each sample in the multi-dataset based on the number of samples in each dataset. - """ - sample_weights = [] - for dataset, dataset_weight in self.dataset_weights.items(): - sample_weights += [dataset_weight] * len(dataset) - return sample_weights - - @cached_property - def validation_indices(self) -> Sequence[int]: - """ - Returns the indices of the validation set for each dataset in the multi-dataset. - """ - indices = [] - for i, dataset in enumerate(self.datasets): - try: - offset = self.cumulative_sizes[i - 1] if i > 0 else 0 - sample_indices = np.array(dataset.validation_indices) + offset # type: ignore - indices.extend(list(sample_indices)) - except AttributeError: - logger.warning( - "Unable to get validation indices for dataset %r; skipping this " - "dataset when building validation_indices.", - dataset, + } + shape = { + c: int(t) + for c, t in zip( + next(iter(ds.target_sources.values())).axes, + first_target_spec["shape"], ) + } + chunk_size = {ax: scale[ax] * shape[ax] for ax in scale} + local_indices = ds.get_indices(chunk_size) + indices.extend(i + offset for i in local_indices) + offset += len(ds) return indices def verify(self) -> bool: - """ - Verifies that all datasets in the multi-dataset have the same classes and input/target array keys. - """ - if len(self.datasets) == 0: - return False - - n_verified_datasets = 0 - for dataset in self.datasets: - n_verified_datasets += int(dataset.verify()) # type: ignore - try: - assert ( - dataset.classes == self.classes # type: ignore - ), "All datasets must have the same classes." - assert set(dataset.input_arrays.keys()) == set( # type: ignore - self.input_arrays.keys() - ), "All datasets must have the same input arrays." - if self.target_arrays is not None: - assert set(dataset.target_arrays.keys()) == set( # type: ignore - self.target_arrays.keys() - ), "All datasets must have the same target arrays." - except AssertionError as e: - logger.error( - f"Dataset {dataset} does not match the expected structure: {e}" - ) - return False - return n_verified_datasets > 0 - - def to( - self, device: str | torch.device, non_blocking: bool = True - ) -> "CellMapMultiDataset": - for dataset in self.datasets: - dataset.to(device, non_blocking=non_blocking) # type: ignore - return self - - def get_weighted_sampler( - self, batch_size: int = 1, rng: Optional[torch.Generator] = None - ) -> WeightedRandomSampler: - return WeightedRandomSampler( - self.sample_weights, batch_size, replacement=False, generator=rng - ) - - def get_random_subset_indices( - self, - num_samples: int, - weighted: bool = True, - rng: Optional[torch.Generator] = None, - ) -> Sequence[int]: - if not weighted: - return min_redundant_inds(len(self), num_samples, rng=rng).tolist() - else: - # 1) Draw raw counts per dataset - dataset_weights = torch.tensor( - [self.dataset_weights[ds] for ds in self.datasets], dtype=torch.double # type: ignore - ) - dataset_weights[dataset_weights < 0.1] = 0.1 - - raw_choice = torch.multinomial( - dataset_weights, - num_samples, - replacement=num_samples > len(dataset_weights), - generator=rng, - ) - raw_counts = [ - (raw_choice == i).sum().item() for i in range(len(self.datasets)) - ] - - # 2) Clamp counts at each dataset's size and accumulate overflow - final_counts = [] - overflow = 0 - for i, ds in enumerate(self.datasets): - size_i = len(ds) # type: ignore - c = raw_counts[i] - if c > size_i: - overflow += c - size_i - c = size_i - final_counts.append(c) - - # 3) Distribute overflow via recursion, using dataset_weights - capacity = [len(ds) - final_counts[i] for i, ds in enumerate(self.datasets)] # type: ignore - weights = dataset_weights.clone() - - def redistribute(counts, caps, free_weights, over): - """ - Recursively assign `over` extra samples to datasets in proportion to `free_weights`, - but never exceed capacities in `caps`. - - Args: - ---- - counts (List[int]): current final_counts per dataset - caps (List[int]): remaining capacity per dataset - free_weights (torch.Tensor): clone of dataset_weights - over (int): number of overflow samples to distribute - - Returns: - ------- - (new_counts, new_caps) after assigning as many as possible; - any leftover overflow will be handled by deeper recursion. - """ - if over <= 0: - return counts, caps - - # Zero out weights where capacity == 0 - prob = free_weights.clone() - for idx, cap_i in enumerate(caps): - if cap_i <= 0: - prob[idx] = 0.0 - - total = prob.sum().item() - if total <= 0: - # no capacity left to assign any overflow - return counts, caps - - prob = prob / total - - # Draw all `over` picks at once - picks = torch.multinomial( - prob, - over, - replacement=True, - generator=rng, - ) - freq = torch.bincount(picks, minlength=len(self.datasets)).tolist() - - new_counts = [] - new_caps = [] - leftover = 0 - for j, f_j in enumerate(freq): - cap_j = caps[j] - if f_j <= cap_j: - assigned = f_j - rem = 0 - else: - assigned = cap_j - rem = f_j - cap_j - - new_counts.append(counts[j] + assigned) - new_caps.append(cap_j - assigned) - leftover += rem - - # Recurse only if there’s leftover overflow - return redistribute(new_counts, new_caps, free_weights, leftover) - - # Call the recursive allocator once - final_counts, capacity = redistribute( - final_counts, capacity, weights, overflow - ) - - # 4) Now that final_counts sums to num_samples (and each ≤ its dataset size), - # draw without replacement from each dataset: - indices = [] - index_offset = 0 - for i, ds in enumerate(self.datasets): - c = final_counts[i] - size_i = len(ds) # type: ignore - if c == 0: - index_offset += size_i - continue - ds_indices = min_redundant_inds(size_i, c, rng=rng) - indices.append(ds_indices + index_offset) - index_offset += size_i - - all_indices = torch.cat(indices).flatten() - all_indices = all_indices[ - min_redundant_inds(len(all_indices), num_samples, rng) - ].tolist() - return all_indices - - def get_subset_random_sampler( - self, - num_samples: int, - weighted: bool = True, - rng: Optional[torch.Generator] = None, - ) -> MutableSubsetRandomSampler: - indices_generator = functools.partial( - self.get_random_subset_indices, num_samples, weighted, rng - ) + return len(self) > 0 - return MutableSubsetRandomSampler( - indices_generator, - rng=rng, - ) + # ------------------------------------------------------------------ + # Transform setters (delegate to all datasets) + # ------------------------------------------------------------------ - def get_indices(self, chunk_size: Mapping[str, int]) -> Sequence[int]: - """Returns the indices of the dataset that will tile all of the datasets according to the chunk_size.""" - indices = [] - for i, dataset in enumerate(self.datasets): - if i == 0: - offset = 0 - else: - offset = self.cumulative_sizes[i - 1] - sample_indices = np.array(dataset.get_indices(chunk_size)) + offset # type: ignore - indices.extend(list(sample_indices)) - return indices + def set_raw_value_transforms(self, transforms: Optional[Callable]) -> None: + for ds in self.datasets: + ds.set_raw_value_transforms(transforms) + self.__dict__.pop("validation_indices", None) - def set_raw_value_transforms(self, transforms: Callable) -> None: - """Sets the raw value transforms for each dataset in the multi-dataset.""" - for dataset in self.datasets: - dataset.set_raw_value_transforms(transforms) # type: ignore + def set_target_value_transforms( + self, transforms: Optional[Callable | Mapping[str, Callable]] + ) -> None: + for ds in self.datasets: + ds.set_target_value_transforms(transforms) - def set_target_value_transforms(self, transforms: Callable) -> None: - """Sets the target value transforms for each dataset in the multi-dataset.""" - for dataset in self.datasets: - dataset.set_target_value_transforms(transforms) # type: ignore + def set_spatial_transforms(self, transforms: Optional[Mapping[str, Any]]) -> None: + for ds in self.datasets: + ds.set_spatial_transforms(transforms) - def set_spatial_transforms( - self, spatial_transforms: Mapping[str, Any] | None - ) -> None: - """Sets the raw value transforms for each dataset in the training multi-dataset.""" - for dataset in self.datasets: - dataset.spatial_transforms = spatial_transforms # type: ignore + def to(self, device: str | torch.device) -> "CellMapMultiDataset": + for ds in self.datasets: + ds.to(device) + return self - @staticmethod - def empty() -> "CellMapMultiDataset": - """Creates an empty dataset.""" - empty_dataset = CellMapMultiDataset([], {}, {}, [CellMapDataset.empty()]) - empty_dataset.classes = [] - # Pre-populate the cached_property values via instance dict to avoid recomputation - vars(empty_dataset).update( - class_counts={}, - class_weights={}, - validation_indices=[], + def __repr__(self) -> str: + return ( + f"CellMapMultiDataset({len(self.datasets)} datasets, " + f"classes={self.classes}, len={len(self)})" ) - - return empty_dataset diff --git a/src/cellmap_data/mutable_sampler.py b/src/cellmap_data/mutable_sampler.py deleted file mode 100644 index 67a505e..0000000 --- a/src/cellmap_data/mutable_sampler.py +++ /dev/null @@ -1,39 +0,0 @@ -from collections.abc import Iterator, Sequence -from typing import Callable, Optional - -import torch - - -class MutableSubsetRandomSampler(torch.utils.data.Sampler[int]): - """A mutable version of SubsetRandomSampler that allows changing the indices after initialization. - - Args: - ---- - indices_generator (Callable[[], Sequence[int]]): A callable that returns a sequence of indices to sample from. - rng (Optional[torch.Generator]): Generator used in sampling. - """ - - indices: Sequence[int] - indices_generator: Callable - rng: Optional[torch.Generator] - - def __init__( - self, indices_generator: Callable, rng: Optional[torch.Generator] = None - ): - self.indices_generator = indices_generator - if callable(self.indices_generator): - self.indices = list(self.indices_generator()) - else: - self.indices = list(self.indices_generator) - self.rng = rng - - def __iter__(self) -> Iterator[int]: - for i in torch.randperm(len(self.indices), generator=self.rng): - yield self.indices[i] - - def __len__(self) -> int: - return len(self.indices) - - def refresh(self) -> None: - """Redraw the indices used by the sampler by calling the indices generator.""" - self.indices = list(self.indices_generator()) diff --git a/src/cellmap_data/sampler.py b/src/cellmap_data/sampler.py new file mode 100644 index 0000000..426d829 --- /dev/null +++ b/src/cellmap_data/sampler.py @@ -0,0 +1,92 @@ +"""Class-balanced sampler for CellMap datasets. + +Problem: Some classes appear in 200+ crops, others in <10. +Uniform crop sampling means rare classes barely appear during training. + +Solution: At each step, pick the least-seen class so far, then sample +a crop that annotates it. All active classes receive roughly equal +representation over an epoch regardless of annotation frequency. +""" + +from __future__ import annotations + +from typing import Iterator, Optional + +import numpy as np +from torch.utils.data import Sampler + + +class ClassBalancedSampler(Sampler): + """Greedy class-balanced sampler. + + Algorithm: + 1. Build a crop-class matrix from the dataset: + ``dataset.get_crop_class_matrix()`` → ``bool[n_crops, n_classes]``. + 2. Maintain running counts of how many times each class has been seen. + 3. At each step: pick the class with the lowest count (ties broken + randomly), sample a crop annotating it, yield that crop index, then + increment counts for *all* classes that crop annotates. + + This guarantees rare classes get sampled as often as common ones. + + The sampler resets ``class_counts`` to zero at the start of each + ``__iter__`` call, so no ``refresh()`` is needed between epochs. + + Parameters + ---------- + dataset: + Must implement ``get_crop_class_matrix() -> np.ndarray``. + samples_per_epoch: + Number of samples to yield per epoch. Defaults to ``len(dataset)``. + seed: + Random seed for reproducibility. + """ + + def __init__( + self, + dataset, + samples_per_epoch: Optional[int] = None, + seed: int = 42, + ) -> None: + self.dataset = dataset + self.samples_per_epoch = samples_per_epoch or len(dataset) + self.rng = np.random.default_rng(seed) + + # [n_crops × n_classes] boolean matrix + self.crop_class_matrix: np.ndarray = dataset.get_crop_class_matrix() + self.n_crops, self.n_classes = self.crop_class_matrix.shape + + # Pre-compute per-class crop lists + self.class_to_crops: dict[int, np.ndarray] = {} + for c in range(self.n_classes): + indices = np.where(self.crop_class_matrix[:, c])[0] + if len(indices) > 0: + self.class_to_crops[c] = indices + + self.active_classes: list[int] = sorted(self.class_to_crops.keys()) + + def __iter__(self) -> Iterator[int]: + class_counts = np.zeros(self.n_classes, dtype=np.float64) + + for _ in range(self.samples_per_epoch): + # Find least-seen active class; break ties randomly + active_counts = np.array([class_counts[c] for c in self.active_classes]) + min_count = active_counts.min() + tied = [ + self.active_classes[i] + for i, v in enumerate(active_counts) + if v == min_count + ] + target_class = int(self.rng.choice(tied)) + + # Sample a crop that annotates this class + crop_idx = int(self.rng.choice(self.class_to_crops[target_class])) + + # Increment counts for all classes this crop annotates + annotated = np.where(self.crop_class_matrix[crop_idx])[0] + class_counts[annotated] += 1.0 + + yield crop_idx + + def __len__(self) -> int: + return self.samples_per_epoch diff --git a/src/cellmap_data/subdataset.py b/src/cellmap_data/subdataset.py deleted file mode 100644 index c2eaf80..0000000 --- a/src/cellmap_data/subdataset.py +++ /dev/null @@ -1,102 +0,0 @@ -import functools -from typing import Any, Callable, Optional, Sequence - -import torch -from torch.utils.data import Subset - -from .base_dataset import CellMapBaseDataset -from .dataset import CellMapDataset -from .dataset_writer import CellMapDatasetWriter -from .multidataset import CellMapMultiDataset -from .mutable_sampler import MutableSubsetRandomSampler -from .utils.sampling import min_redundant_inds - - -class CellMapSubset(CellMapBaseDataset, Subset): - """ - This subclasses PyTorch Subset to wrap a CellMapDataset or CellMapMultiDataset object under a common API, which can be used for dataloading. It maintains the same API as the Subset class. It retrieves raw and groundtruth data from a CellMapDataset or CellMapMultiDataset object. - """ - - def __init__( - self, - dataset: CellMapDataset | CellMapMultiDataset | CellMapDatasetWriter, - indices: Sequence[int], - ) -> None: - """ - Args: - ---- - dataset: CellMapDataset | CellMapMultiDataset - The dataset to be subsetted. - indices: Sequence[int] - The indices of the dataset to be used as the subset. - """ - super().__init__(dataset, indices) - - @property - def input_arrays(self) -> dict[str, dict[str, Any]]: - """The input arrays in the dataset.""" - return self.dataset.input_arrays - - @property - def target_arrays(self) -> dict[str, dict[str, Any]]: - """The target arrays in the dataset.""" - return self.dataset.target_arrays - - @property - def classes(self) -> Sequence[str]: - """The classes in the dataset.""" - return self.dataset.classes - - @property - def class_counts(self) -> dict[str, float]: - """The number of samples in each class in the dataset normalized by resolution.""" - return self.dataset.class_counts - - @property - def class_weights(self) -> dict[str, float]: - """The class weights for the dataset based on the number of samples in each class.""" - return self.dataset.class_weights - - @property - def validation_indices(self) -> Sequence[int]: - """The indices of the validation set.""" - return self.dataset.validation_indices - - def to(self, device, non_blocking: bool = True) -> "CellMapSubset": - """Move the dataset to the specified device.""" - self.dataset.to(device, non_blocking=non_blocking) - return self - - def set_raw_value_transforms(self, transforms: Callable) -> None: - """Sets the raw value transforms for the subset dataset.""" - self.dataset.set_raw_value_transforms(transforms) - - def set_target_value_transforms(self, transforms: Callable) -> None: - """Sets the target value transforms for the subset dataset.""" - self.dataset.set_target_value_transforms(transforms) - - def get_random_subset_indices( - self, num_samples: int, rng: Optional[torch.Generator] = None, **kwargs: Any - ) -> Sequence[int]: - inds = min_redundant_inds(len(self.indices), num_samples, rng=rng) - return torch.tensor(self.indices, dtype=torch.long)[inds].tolist() - - def get_subset_random_sampler( - self, - num_samples: int, - rng: Optional[torch.Generator] = None, - **kwargs: Any, - ) -> MutableSubsetRandomSampler: - """ - Returns a random sampler that yields exactly `num_samples` indices from this subset. - - If `num_samples` ≤ total number of available indices, samples without replacement. - - If `num_samples` > total number of available indices, samples with replacement using repeated shuffles to minimize duplicates. - """ - indices_generator = functools.partial( - self.get_random_subset_indices, num_samples, rng, **kwargs - ) - - return MutableSubsetRandomSampler( - indices_generator, - rng=rng, - ) diff --git a/src/cellmap_data/utils/__init__.py b/src/cellmap_data/utils/__init__.py index 39444b1..68ef6a9 100644 --- a/src/cellmap_data/utils/__init__.py +++ b/src/cellmap_data/utils/__init__.py @@ -5,6 +5,7 @@ get_image_grid, get_image_grid_numpy, ) +from .geometry import box_intersection, box_shape, box_union from .metadata import ( add_multiscale_metadata_levels, create_multiscale_metadata, @@ -17,12 +18,11 @@ get_sliced_shape, is_array_2D, longest_common_substring, + min_redundant_inds, permute_singleton_dimension, split_target_path, torch_max_value, ) -from .sampling import min_redundant_inds -from .view import get_neuroglancer_link, open_neuroglancer __all__ = [ "fig_to_image", @@ -30,6 +30,9 @@ "get_image_dict", "get_image_grid", "get_image_grid_numpy", + "box_intersection", + "box_shape", + "box_union", "add_multiscale_metadata_levels", "create_multiscale_metadata", "find_level", @@ -39,10 +42,8 @@ "get_sliced_shape", "is_array_2D", "longest_common_substring", + "min_redundant_inds", "permute_singleton_dimension", "split_target_path", "torch_max_value", - "min_redundant_inds", - "get_neuroglancer_link", - "open_neuroglancer", ] diff --git a/src/cellmap_data/utils/geometry.py b/src/cellmap_data/utils/geometry.py new file mode 100644 index 0000000..cfcc85a --- /dev/null +++ b/src/cellmap_data/utils/geometry.py @@ -0,0 +1,48 @@ +"""Spatial bounding box utilities for world-coordinate arithmetic.""" + +from __future__ import annotations + + +def box_intersection(a: dict, b: dict) -> dict | None: + """Intersection of two bounding boxes. + + Each box is ``{axis: (min, max), ...}`` in world coordinates (nm). + Returns ``None`` if there is no overlap on any shared axis. + """ + result = {} + for ax in a: + if ax not in b: + continue + lo = max(a[ax][0], b[ax][0]) + hi = min(a[ax][1], b[ax][1]) + if lo >= hi: + return None + result[ax] = (lo, hi) + return result if result else None + + +def box_union(a: dict, b: dict) -> dict: + """Bounding box that contains both *a* and *b*.""" + axes = set(a) | set(b) + result = {} + for ax in axes: + if ax in a and ax in b: + result[ax] = (min(a[ax][0], b[ax][0]), max(a[ax][1], b[ax][1])) + elif ax in a: + result[ax] = a[ax] + else: + result[ax] = b[ax] + return result + + +def box_shape(box: dict, scale: dict) -> dict: + """Convert a world bounding box to a voxel count per axis. + + Args: + box: ``{axis: (min, max)}`` in nm. + scale: ``{axis: voxel_size}`` in nm/voxel. + + Returns: + ``{axis: int}`` — number of voxels per axis (at least 1). + """ + return {ax: max(1, int(round((box[ax][1] - box[ax][0]) / scale[ax]))) for ax in box} diff --git a/src/cellmap_data/utils/read_limiter.py b/src/cellmap_data/utils/read_limiter.py deleted file mode 100644 index bf67bfe..0000000 --- a/src/cellmap_data/utils/read_limiter.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -Global TensorStore read limiter for Windows crash prevention. - -On Windows, concurrent TensorStore materializations from multiple threads -(triggered by source[center], .interp, ._TensorStoreAdapter.__array__, etc.) -cause native hard crashes / aborts. This module provides a semaphore-backed -context manager that serializes those reads on Windows+TensorStore while -acting as a no-op on all other platforms. - -Configuration -------------- -CELLMAP_DATA_BACKEND : str - Set to "tensorstore" (default) to enable the limiter on Windows. - Set to anything else (e.g. "dask") to disable it entirely. - -CELLMAP_MAX_CONCURRENT_READS : int - Maximum concurrent TensorStore reads allowed on Windows. - Defaults to 1 (fully serialized). Increase cautiously. - -Notes ------ -Both environment variables must be set **before** this module is imported, -as the semaphore is created once at import time. -""" - -import os -import platform -import threading -from contextlib import contextmanager - -_IS_WINDOWS = platform.system() == "Windows" -_IS_TENSORSTORE = ( - os.environ.get("CELLMAP_DATA_BACKEND", "tensorstore").lower() == "tensorstore" -) - -MAX_CONCURRENT_READS: int | None -_read_semaphore: threading.Semaphore | None - -if _IS_WINDOWS and _IS_TENSORSTORE: - MAX_CONCURRENT_READS = int(os.environ.get("CELLMAP_MAX_CONCURRENT_READS", "1")) - _read_semaphore = threading.Semaphore(MAX_CONCURRENT_READS) -else: - MAX_CONCURRENT_READS = None - _read_semaphore = None - - -@contextmanager -def limit_tensorstore_reads(): - """Context manager that gates TensorStore reads on Windows. - - On Windows with the TensorStore backend, at most ``MAX_CONCURRENT_READS`` - threads may be inside this context at once. On all other platforms (or - when using the Dask backend) this is a true no-op with zero overhead. - - Usage - ----- - :: - - with limit_tensorstore_reads(): - array = source[center] # the unsafe read - # torch-only work continues here unconstrained - """ - if _read_semaphore is not None: - _read_semaphore.acquire() - try: - yield - finally: - _read_semaphore.release() - else: - yield diff --git a/src/cellmap_data/utils/sampling.py b/src/cellmap_data/utils/sampling.py deleted file mode 100644 index 8dcfe2d..0000000 --- a/src/cellmap_data/utils/sampling.py +++ /dev/null @@ -1,39 +0,0 @@ -import warnings -from typing import Optional - -import torch - -MAX_SIZE = ( - 512 * 1024 * 1024 -) # 512 million - increased from 64M to handle larger datasets efficiently - - -def min_redundant_inds( - size: int, num_samples: int, rng: Optional[torch.Generator] = None -) -> torch.Tensor: - """ - Returns a list of indices that will sample `num_samples` from a dataset of size `size` with minimal redundancy. - If `num_samples` is greater than `size`, it will sample with replacement. - """ - if size <= 0: - raise ValueError("Size must be a positive integer.") - elif size > MAX_SIZE: - warnings.warn( - f"Size={size} exceeds MAX_SIZE={MAX_SIZE}. Using faster sampling strategy that doesn't ensure minimal redundancy." - ) - return torch.randint(0, size, (num_samples,), generator=rng) - if num_samples > size: - warnings.warn( - f"Requested num_samples={num_samples} exceeds available samples={size}. " - "Sampling with replacement using repeated permutations to minimize duplicates." - ) - # Determine how many full permutations and remainder are needed - full_iters = num_samples // size - remainder = num_samples % size - - inds_list = [] - for _ in range(full_iters): - inds_list.append(torch.randperm(size, generator=rng)) - if remainder > 0: - inds_list.append(torch.randperm(size, generator=rng)[:remainder]) - return torch.cat(inds_list, dim=0) diff --git a/src/cellmap_data/utils/view.py b/src/cellmap_data/utils/view.py deleted file mode 100644 index e377650..0000000 --- a/src/cellmap_data/utils/view.py +++ /dev/null @@ -1,516 +0,0 @@ -import json -import logging -import operator -import os -import re -import time -import urllib.parse -import webbrowser -from multiprocessing.pool import ThreadPool - -import numpy as np -import zarr - -logger = logging.getLogger(__name__) - -# S3 bucket names and paths for Janelia COSEM datasets -GT_S3_BUCKET = "janelia-cosem-datasets" -RAW_S3_BUCKET = "janelia-cosem-datasets" -S3_SEARCH_PATH = "{dataset}/{dataset}.zarr/recon-1/{name}" -S3_CROP_NAME = "labels/groundtruth/{crop}/{label}" -S3_RAW_NAME = "em/fibsem-uint8" - - -def get_multiscale_voxel_sizes(path: str): - if "s3://" in path: - import s3fs - - # Use s3fs to read the zarr metadata - fs = s3fs.S3FileSystem(anon=True) - store = s3fs.S3Map( - root=path.removeprefix("zarr://s3://"), - s3=fs, - check=False, # skip consistency checks for speed - ) - ds = zarr.open(store, mode="r") - else: - # Use local zarr store - ds = zarr.open(path, mode="r") - voxel_sizes = {} - for scale_ds in ds.attrs["multiscales"][0]["datasets"]: - for transform in scale_ds["coordinateTransformations"]: - if transform["type"] == "scale": - voxel_sizes[scale_ds["path"]] = transform["scale"] - break - if not voxel_sizes: - raise ValueError( - f"No scale transformations found in the zarr metadata at {path}" - ) - return voxel_sizes - - -def get_neuroglancer_link(metadata): - # extract dataset name from raw_path - m = re.search(r"/([^/]+)/\\1\\.zarr", metadata["raw_path"]) - if m: - dataset = m.group(1) - else: - # fallback: take parent folder name before .zarr - dataset = os.path.basename(metadata["raw_path"].split(".zarr")[0]) - # build raw EM layer source - raw_key = S3_SEARCH_PATH.format(dataset=dataset, name=S3_RAW_NAME) - raw_source = f"zarr://s3://{RAW_S3_BUCKET}/{raw_key}" - voxel_sizes = [get_multiscale_voxel_sizes(raw_source)] - layers = {"raw": {"type": "image", "source": raw_source}} - # segmentation layers - # extract crop identifier from target_path_str - m2 = re.search(r"labels/groundtruth/([^/]+)/", metadata["target_path_str"]) - crop = m2.group(1) if m2 else "" - for class_name in metadata["class_weights"].keys(): - seg_path = S3_CROP_NAME.format(crop=crop, label=class_name) - seg_key = S3_SEARCH_PATH.format(dataset=dataset, name=seg_path) - seg_source = f"zarr://s3://{GT_S3_BUCKET}/{seg_key}" - # get voxel size for this segmentation layer - voxel_sizes.append(get_multiscale_voxel_sizes(seg_source)) - layers[class_name] = {"type": "segmentation", "source": seg_source} - - # find the minimum voxel size across all layers - voxel_size = np.min( - [ - np.min([np.array(vs) for vs in ds_vs.values()], axis=0) - for ds_vs in voxel_sizes - ], - axis=0, - ).tolist() - # navigation pose (x, y, z) - position = [ - metadata["current_center"]["z"] / voxel_size[0], - metadata["current_center"]["y"] / voxel_size[1], - metadata["current_center"]["x"] / voxel_size[2], - ] - state = { - "layers": layers, - "navigation": {"pose": {"position": {"voxelCoordinates": position}}}, - } - fragment = urllib.parse.quote(json.dumps(state), safe='/:,"{}[]') - return f"https://neuroglancer-demo.appspot.com/#!{fragment}" - - -def open_neuroglancer(metadata): - """ - Launch a Neuroglancer viewer showing raw data and labels, - centered on the point in metadata['current_center']. - - metadata: dict with keys - - 'raw_path': path to your raw Zarr (no .zarr extension in source) - - 'current_center': {'x':…, 'y':…, 'z':…} - - 'target_path_str': format string with '{label}' for each class - - 'class_weights': dict mapping class names to weights - - Returns the Neuroglancer.Viewer object. - """ - import neuroglancer - from IPython.core.getipython import get_ipython - from IPython.display import IFrame, display - - # 1) bind to localhost on a random port - neuroglancer.set_server_bind_address("localhost", 0) - viewer = neuroglancer.Viewer() - - # 2) build layer sources - raw_source = get_layer( - metadata["raw_path"], - layer_type="image", - ) - label_layers = {} - for class_name in metadata["class_weights"].keys(): - # fill in the placeholder - label_path = metadata["target_path_str"].format(label=class_name) - label_layers[class_name] = get_layer( - label_path, - layer_type="segmentation", - ) - - # 3) push state in one atomic txn - with viewer.txn() as s: - # raw intensity volume - s.layers["raw"] = raw_source - # one layer per class - for class_name, layer in label_layers.items(): - s.layers[class_name] = layer - - # 4) display inline or print URL - url = viewer.get_viewer_url() - print(f"Neuroglancer viewer URL: {url}") - if get_ipython() is not None: - # If running in Jupyter, display the viewer inline - viewer_iframe = IFrame(url, width=1000, height=600) - display(viewer_iframe) - else: - webbrowser.open(url) - - # 5) center the view on the current center when it is available - # by starting a background thread - def _center_view(): - while len(viewer.state.dimensions.to_json()) < 3: - time.sleep(0.1) # wait for dimensions to be set - with viewer.txn() as s: - # jump to the stored center (x, y, z) - cx = float(metadata["current_center"]["x"]) / ( - viewer.state.dimensions["x"].scale * 10**9 - ) - cy = float(metadata["current_center"]["y"]) / ( - viewer.state.dimensions["y"].scale * 10**9 - ) - cz = float(metadata["current_center"]["z"]) / ( - viewer.state.dimensions["z"].scale * 10**9 - ) - # (z is the first dimension in Neuroglancer) - s.position = [cz, cy, cx] - - pool = ThreadPool(processes=1) - pool.apply_async(_center_view) - return viewer - - -def get_layer( - data_path: str, - layer_type: str = "image", - multiscale: bool = True, -): - """ - Get a Neuroglancer layer from a zarr data path for a LocalVolume. - - Parameters - ---------- - data_path : str - The path to the zarr data. - layer_type : str - The type of layer to get. Can be "image" or "segmentation". Default is "image". - multiscale : bool - Whether the metadata is OME-NGFF multiscale. Default is True. - - Returns - ------- - neuroglancer.Layer - The Neuroglancer layer. - """ - import neuroglancer - from upath import UPath - - # Construct an xarray with Tensorstore backend - # Get metadata - if multiscale: - # Add all scales - layers = [] - scales, metadata = parse_multiscale_metadata(data_path) - for scale in scales: - this_path = (UPath(data_path) / scale).path - image = get_image(this_path) - - layers.append( - neuroglancer.LocalVolume( - data=image, - dimensions=neuroglancer.CoordinateSpace( - scales=metadata[scale]["voxel_size"], - units=metadata[scale]["units"], - names=metadata[scale]["names"], - ), - voxel_offset=metadata[scale]["voxel_offset"], - ) - ) - - class ScalePyramid(neuroglancer.LocalVolume): - """A neuroglancer layer that provides volume data on different scales. - Mimics a LocalVolume. - From https://github.com/funkelab/funlib.show.neuroglancer/blob/master/funlib/show/neuroglancer/scale_pyramid.py - - Args: - ---- - volume_layers (``list`` of ``LocalVolume``): - - One ``LocalVolume`` per provided resolution. - """ - - def __init__(self, volume_layers): - volume_layers = volume_layers - - super(neuroglancer.LocalVolume, self).__init__() - - logger.debug("Creating scale pyramid...") - - self.min_voxel_size = min( - [tuple(layer.dimensions.scales) for layer in volume_layers] - ) - self.max_voxel_size = max( - [tuple(layer.dimensions.scales) for layer in volume_layers] - ) - - self.dims = len(volume_layers[0].dimensions.scales) - self.volume_layers = { - tuple( - int(x) - for x in map( - operator.truediv, - layer.dimensions.scales, - self.min_voxel_size, - ) - ): layer - for layer in volume_layers - } - - logger.debug("min_voxel_size: %s", self.min_voxel_size) - logger.debug("scale keys: %s", self.volume_layers.keys()) - logger.debug(self.info()) - - @property - def volume_type(self): - return self.volume_layers[(1,) * self.dims].volume_type - - @property - def token(self): - return self.volume_layers[(1,) * self.dims].token - - def info(self): - reference_layer = self.volume_layers[(1,) * self.dims] - reference_info = reference_layer.info() - - info = { - "dataType": reference_info["dataType"], - "encoding": reference_info["encoding"], - "generation": reference_info["generation"], - "coordinateSpace": reference_info["coordinateSpace"], - "shape": reference_info["shape"], - "volumeType": reference_info["volumeType"], - "voxelOffset": reference_info["voxelOffset"], - "chunkLayout": reference_info["chunkLayout"], - "downsamplingLayout": reference_info["downsamplingLayout"], - "maxDownsampling": int( - np.prod( - np.array(self.max_voxel_size) - // np.array(self.min_voxel_size) - ) - ), - "maxDownsampledSize": reference_info["maxDownsampledSize"], - "maxDownsamplingScales": reference_info["maxDownsamplingScales"], - } - - return info - - def get_encoded_subvolume(self, data_format, start, end, scale_key=None): - if scale_key is None: - scale_key = ",".join(("1",) * self.dims) - - scale = tuple(int(s) for s in scale_key.split(",")) - closest_scale = None - min_diff = np.inf - for volume_scales in self.volume_layers.keys(): - scale_diff = np.array(scale) // np.array(volume_scales) - if any(scale_diff < 1): - continue - scale_diff = scale_diff.max() - if scale_diff < min_diff: - min_diff = scale_diff - closest_scale = volume_scales - - assert closest_scale is not None - relative_scale = np.array(scale) // np.array(closest_scale) - - return self.volume_layers[closest_scale].get_encoded_subvolume( - data_format, - start, - end, - scale_key=",".join(map(str, relative_scale)), - ) - - def get_object_mesh(self, object_id): - return self.volume_layers[(1,) * self.dims].get_object_mesh(object_id) - - def invalidate(self): - return self.volume_layers[(1,) * self.dims].invalidate() - - volume = ScalePyramid(layers) - - else: - # Handle single scale zarr files - names = ["z", "y", "x"] - units = ["nm", "nm", "nm"] - attrs = zarr.open(data_path, mode="r").attrs.asdict() - if "voxel_size" in attrs: - voxel_size = attrs["voxel_size"] - elif "resolution" in attrs: - voxel_size = attrs["resolution"] - elif "scale" in attrs: - voxel_size = attrs["scale"] - else: - voxel_size = [1, 1, 1] - - if "translation" in attrs: - translation = attrs["translation"] - elif "offset" in attrs: - translation = attrs["offset"] - else: - translation = [0, 0, 0] - - voxel_offset = np.array(translation) / np.array(voxel_size) - - image = open_ds_tensorstore(data_path) - # image = get_image(data_path) - - volume = neuroglancer.LocalVolume( - data=image, - dimensions=neuroglancer.CoordinateSpace( - scales=voxel_size, - units=units, - names=names, - ), - voxel_offset=voxel_offset, - ) - - if layer_type == "segmentation": - return neuroglancer.SegmentationLayer(source=volume) - else: - return neuroglancer.ImageLayer(source=volume) - - -def get_image(data_path: str): - import tensorstore - import xarray_tensorstore as xt - - try: - return open_ds_tensorstore(data_path) - except ValueError: - spec = xt._zarr_spec_from_path(data_path, zarr_format=2) - array_future = tensorstore.open(spec, read=True, write=False) - try: - array = array_future.result() - except ValueError: - UserWarning("Falling back to zarr3 driver") - spec["driver"] = "zarr3" - array_future = ts_open(spec, read=True, write=False) - array = array_future.result() - return array - - -def parse_multiscale_metadata(data_path: str): - metadata = zarr.open(data_path, mode="r").attrs.asdict()["multiscales"][0] - scales = [] - parsed = {} - for ds in metadata["datasets"]: - scales.append(ds["path"]) - - names = [] - units = [] - translation = [] - voxel_size = [] - - for axis in metadata["axes"]: - if axis["name"] == "c": - names.append("c^") - voxel_size.append(1) - translation.append(0) - units.append("") - else: - names.append(axis["name"]) - units.append("nm") - - for transform in ds["coordinateTransformations"]: - if transform["type"] == "scale": - voxel_size.extend(transform["scale"]) - elif transform["type"] == "translation": - translation.extend(transform["translation"]) - - parsed[ds["path"]] = { - "names": names, - "units": units, - "voxel_size": voxel_size, - "translation": translation, - "voxel_offset": np.array(translation) / np.array(voxel_size), - } - scales.sort(key=lambda x: int(x[1:])) - return scales, parsed - - -def open_ds_tensorstore(dataset_path: str, mode="r", concurrency_limit=None): - from tensorstore import d as ts_d - from tensorstore import open as ts_open - - # open with zarr or n5 depending on extension - filetype = ( - "zarr" if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") else "n5" - ) - extra_args = {} - - if dataset_path.startswith("s3://"): - kvstore = { - "driver": "s3", - "bucket": dataset_path.split("/")[2], - "path": "/".join(dataset_path.split("/")[3:]), - "aws_credentials": { - "anonymous": True, - }, - } - elif dataset_path.startswith("gs://"): - # check if path ends with s#int - if ends_with_scale(dataset_path): - scale_index = int(dataset_path.rsplit("/s")[1]) - dataset_path = dataset_path.rsplit("/s")[0] - else: - scale_index = 0 - filetype = "neuroglancer_precomputed" - kvstore = dataset_path - extra_args = {"scale_index": scale_index} - else: - kvstore = { - "driver": "file", - "path": os.path.normpath(dataset_path), - } - - if concurrency_limit: - spec = { - "driver": filetype, - "context": { - "data_copy_concurrency": {"limit": concurrency_limit}, - "file_io_concurrency": {"limit": concurrency_limit}, - }, - "kvstore": kvstore, - **extra_args, - } - else: - spec = {"driver": filetype, "kvstore": kvstore, **extra_args} - - if mode == "r": - dataset_future = ts_open(spec, read=True, write=False) - else: - dataset_future = ts_open(spec, read=False, write=True) - - if dataset_path.startswith("gs://"): - # NOTE: Currently a hack since google store is for some reason - # stored as mutlichannel - ts_dataset = dataset_future.result()[ts_d["channel"][0]] - else: - ts_dataset = dataset_future.result() - - # return ts_dataset - return LazyNormalization(ts_dataset) - - -def ends_with_scale(string): - pattern = ( - r"s\d+$" # Matches 's' followed by one or more digits at the end of the string - ) - return bool(re.search(pattern, string)) - - -class LazyNormalization: - def __init__(self, ts_dataset): - self.ts_dataset = ts_dataset - self.input_norms = [] - - def __getitem__(self, ind): - g = self.ts_dataset[ind].read().result() - self.input_norms.append((g.min(), g.max())) - return g - - def __getattr__(self, name): - return getattr(self.ts_dataset, name) diff --git a/tests/README.md b/tests/README.md deleted file mode 100644 index 716a8e4..0000000 --- a/tests/README.md +++ /dev/null @@ -1,326 +0,0 @@ -# CellMap-Data Test Suite - -Comprehensive test coverage for the cellmap-data library using pytest with real implementations (no mocks). - -## Overview - -This test suite provides extensive coverage of all core components: - -- **test_helpers.py**: Utilities for creating real Zarr/OME-NGFF test data -- **test_cellmap_image.py**: CellMapImage initialization and configuration -- **test_transforms.py**: All augmentation transforms with real tensors -- **test_cellmap_dataset.py**: CellMapDataset configuration and parameters -- **test_dataloader.py**: CellMapDataLoader setup and optimizations -- **test_multidataset_datasplit.py**: Multi-dataset and train/val splits -- **test_dataset_writer.py**: CellMapDatasetWriter for predictions -- **test_empty_image_writer.py**: EmptyImage and ImageWriter utilities -- **test_mutable_sampler.py**: MutableSubsetRandomSampler functionality -- **test_utils.py**: Utility function tests -- **test_integration.py**: End-to-end workflow integration tests -- **test_windows_stress.py**: TensorStore read-limiter unit tests, executor lifecycle, and concurrent-read stress tests - -## Running Tests - -### Prerequisites - -Install the package with test dependencies: - -```bash -pip install -e ".[test]" -``` - -Or install dependencies individually: - -```bash -pip install pytest pytest-cov pytest-timeout -pip install torch torchvision tensorstore xarray zarr numpy -pip install pydantic-ome-ngff xarray-ome-ngff xarray-tensorstore -``` - -### Run All Tests - -```bash -# Run all tests -pytest tests/ - -# Run with coverage -pytest tests/ --cov=cellmap_data --cov-report=html - -# Run with verbose output -pytest tests/ -v - -# Run specific test file -pytest tests/test_cellmap_dataset.py -v -``` - -### Run Specific Test Categories - -```bash -# Core component tests -pytest tests/test_cellmap_image.py tests/test_cellmap_dataset.py - -# Transform tests -pytest tests/test_transforms.py - -# DataLoader tests -pytest tests/test_dataloader.py - -# Integration tests -pytest tests/test_integration.py - -# Utility tests -pytest tests/test_utils.py tests/test_mutable_sampler.py -``` - -### Run Tests by Pattern - -```bash -# Run all initialization tests -pytest tests/ -k "test_initialization" - -# Run all configuration tests -pytest tests/ -k "test.*config" - -# Run all integration tests -pytest tests/ -k "integration" -``` - -## Test Design Principles - -### No Mocks - Real Implementations - -All tests use real implementations: -- **Real Zarr arrays** with OME-NGFF metadata -- **Real TensorStore** backend for array access -- **Real PyTorch tensors** for data and transforms -- **Real file I/O** using temporary directories - -This ensures tests validate actual behavior, not mocked interfaces. - -### Test Data Generation - -The `test_helpers.py` module provides utilities to create realistic test data: - -```python -from tests.test_helpers import create_test_dataset - -# Create a complete test dataset -config = create_test_dataset( - tmp_path, - raw_shape=(64, 64, 64), - num_classes=3, - raw_scale=(8.0, 8.0, 8.0), -) -# Returns paths, shapes, scales, and class names -``` - -### Fixtures and Reusability - -Tests use pytest fixtures for common setups: - -```python -@pytest.fixture -def test_dataset(self, tmp_path): - """Create a test dataset for loader tests.""" - config = create_test_dataset(tmp_path, ...) - return create_dataset_from_config(config) -``` - -## Test Coverage - -### Core Components - -- ✅ **CellMapImage**: Initialization, device selection, transforms, 2D/3D, dtypes -- ✅ **CellMapDataset**: Configuration, arrays, transforms, parameters, `close()` lifecycle -- ✅ **CellMapDataLoader**: Batching, workers, sampling, optimization -- ✅ **CellMapMultiDataset**: Combining datasets, multi-scale -- ✅ **CellMapDataSplit**: Train/val splits, configuration -- ✅ **CellMapDatasetWriter**: Prediction writing, bounds, multiple outputs -- ✅ **EmptyImage/ImageWriter**: Placeholders and writing utilities -- ✅ **MutableSubsetRandomSampler**: Weighted sampling, reproducibility -- ✅ **read_limiter**: Semaphore state, context manager correctness, deadlock safety, stress reads - -### Transforms - -- ✅ **Normalize**: Scaling, mean subtraction -- ✅ **GaussianNoise**: Noise addition, different std values -- ✅ **RandomContrast**: Contrast adjustment, ranges -- ✅ **RandomGamma**: Gamma correction, ranges -- ✅ **NaNtoNum**: NaN/inf replacement -- ✅ **Binarize**: Thresholding, different values -- ✅ **GaussianBlur**: Blur with different sigmas -- ✅ **Transform Composition**: Sequential application - -### Utilities - -- ✅ **Array operations**: Shape utilities, 2D detection -- ✅ **Coordinate transforms**: Scaling, translation -- ✅ **Dtype utilities**: Torch/numpy conversion, max values -- ✅ **Sampling utilities**: Weights, balancing -- ✅ **Path utilities**: Path splitting, class extraction - -### Integration Tests - -- ✅ **Training workflows**: Complete pipelines, transforms -- ✅ **Multi-dataset training**: Combining datasets, loaders -- ✅ **Train/val splits**: Complete workflows -- ✅ **Transform pipelines**: Complex augmentation sequences -- ✅ **Edge cases**: Small datasets, single class, anisotropic, 2D - -## Test Organization - -``` -tests/ -├── conftest.py # Pytest configuration -├── __init__.py # Test package init -├── README.md # This file -├── test_helpers.py # Test data generation utilities -├── test_cellmap_image.py # CellMapImage tests -├── test_cellmap_dataset.py # CellMapDataset tests -├── test_dataloader.py # CellMapDataLoader tests -├── test_multidataset_datasplit.py # MultiDataset/DataSplit tests -├── test_dataset_writer.py # DatasetWriter tests -├── test_empty_image_writer.py # EmptyImage/ImageWriter tests -├── test_mutable_sampler.py # MutableSubsetRandomSampler tests -├── test_transforms.py # Transform tests -├── test_utils.py # Utility function tests -├── test_integration.py # Integration tests -└── test_windows_stress.py # TensorStore read-limiter & concurrent stress tests -``` - -## Continuous Integration - -Tests are designed to run in CI environments: - -- **No GPU required**: Tests use CPU by default (configured in `conftest.py`) -- **Fast execution**: Tests use small datasets for speed -- **Isolated**: Each test uses temporary directories -- **Parallel-safe**: Tests can run in parallel with pytest-xdist - -### CI Configuration - -```yaml -# Example GitHub Actions workflow -- name: Run tests - run: | - pytest tests/ --cov=cellmap_data --cov-report=xml - -- name: Upload coverage - uses: codecov/codecov-action@v3 -``` - -## Extending Tests - -### Adding New Test Files - -1. Create new file: `tests/test_new_component.py` -2. Import test helpers: `from .test_helpers import create_test_dataset` -3. Use pytest fixtures for setup -4. Follow existing patterns for consistency - -### Adding New Test Cases - -```python -class TestNewComponent: - """Test suite for new component.""" - - @pytest.fixture - def test_config(self, tmp_path): - """Create test configuration.""" - return create_test_dataset(tmp_path, ...) - - def test_basic_functionality(self, test_config): - """Test basic functionality.""" - # Use real data from test_config - component = NewComponent(**test_config) - assert component is not None -``` - -## Debugging Tests - -### Run Single Test with Output - -```bash -pytest tests/test_cellmap_dataset.py::TestCellMapDataset::test_initialization_basic -v -s -``` - -### Run with Debugger - -```bash -pytest tests/test_cellmap_dataset.py --pdb -``` - -### Check Test Coverage - -```bash -pytest tests/ --cov=cellmap_data --cov-report=term-missing -``` - -### Generate HTML Coverage Report - -```bash -pytest tests/ --cov=cellmap_data --cov-report=html -# Open htmlcov/index.html in browser -``` - -## Known Limitations - -### GPU Tests - -GPU-specific tests are limited because: -- CI environments typically don't have GPUs -- GPU availability varies across systems -- Tests focus on CPU to ensure broad compatibility - -GPU functionality can be tested manually: -```bash -# Run tests with GPU if available -CUDA_VISIBLE_DEVICES=0 pytest tests/ -``` - -### Large-Scale Tests - -Tests use small datasets for speed. For large-scale testing: -- Manually test with production-sized data -- Use integration tests with larger configurations -- Monitor memory usage and performance - -### Windows Crash Regression Tests - -`test_windows_stress.py::TestConcurrentGetitem::test_windows_high_concurrency_no_crash` is -skipped on non-Windows platforms (via `@pytest.mark.skipif`). To run it on Windows CI: - -```yaml -# GitHub Actions — add a Windows runner -runs-on: windows-latest -steps: - - run: pytest tests/test_windows_stress.py -v -``` - -A native TensorStore abort caused by concurrent reads will appear as a **non-zero process exit -code** rather than a Python exception; pytest will report the job as failed, which is the -correct CI signal. - -The cross-platform deadlock and semaphore tests (`TestReadLimiterUnit`, -`TestExecutorLifecycle`, serial and multi-worker `TestConcurrentGetitem` tests) run on all -platforms and are included in the normal `pytest tests/` run. - -## Contributing - -When adding tests: - -1. **Use real implementations** - no mocks unless absolutely necessary -2. **Use test helpers** - leverage existing test data generation -3. **Add docstrings** - explain what each test validates -4. **Keep tests fast** - use minimal datasets -5. **Test edge cases** - include boundary conditions -6. **Follow patterns** - maintain consistency with existing tests - -## Questions or Issues - -If you have questions about the tests or find issues: - -1. Check this README for guidance -2. Look at existing tests for patterns -3. Review test helper utilities -4. Open an issue with specific questions diff --git a/tests/demo_memory_fix.py b/tests/demo_memory_fix.py deleted file mode 100755 index 762942c..0000000 --- a/tests/demo_memory_fix.py +++ /dev/null @@ -1,294 +0,0 @@ -#!/usr/bin/env python -""" -Memory profiling demo for the CellMapImage array cache fix. - -Demonstrates two levels of profiling: - 1. Mock class — fast, no real data needed, shows the principle. - 2. Real CellMapImage — uses a temporary Zarr dataset to profile - actual xarray/TensorStore allocations. - -Profiling tools used: - - tracemalloc (built-in): snapshot comparison shows *what* is growing, - not just peak usage. - - objgraph (optional, pip install objgraph): counts live Python objects - by type, confirming whether xarray DataArrays accumulate. - -Usage: - python tests/demo_memory_fix.py - DEMO_ITERS=200 python tests/demo_memory_fix.py -""" - -import gc -import io -import os -import sys -import tempfile -import tracemalloc -from pathlib import Path - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) - -import numpy as np - -try: - import objgraph - - HAS_OBJGRAPH = True -except ImportError: - HAS_OBJGRAPH = False - - -# --------------------------------------------------------------------------- -# Profiling helpers -# --------------------------------------------------------------------------- - - -def profile_iters(label, call_fn, iterations=100, snapshot_every=25): - """ - Run call_fn(i) for `iterations` steps and track memory growth. - - Prints a tracemalloc snapshot diff every `snapshot_every` steps showing - which allocation sites are growing (not just peak usage). If objgraph is - available, also prints live object-type counts so you can confirm whether - xarray DataArrays or numpy arrays are accumulating. - - Args: - label: Description printed as the section header. - call_fn: Callable taking iteration index, e.g. lambda i: img[center]. - iterations: Total number of iterations to run. - snapshot_every: How often to print an intermediate snapshot diff. - """ - print(f"\n{'─' * 64}") - print(f" {label}") - print(f"{'─' * 64}") - - gc.collect() - tracemalloc.start() - baseline = tracemalloc.take_snapshot() - - # objgraph.show_growth() tracks growth relative to the previous call; - # calling it once here establishes the baseline object counts. - if HAS_OBJGRAPH: - objgraph.show_growth( - limit=10, file=io.StringIO() - ) # prime state, discard output - - for i in range(iterations): - call_fn(i) - - if (i + 1) % snapshot_every == 0: - gc.collect() - snap = tracemalloc.take_snapshot() - stats = snap.compare_to(baseline, "lineno") - growing = [s for s in stats if s.size_diff > 0] - - print(f"\n [iter {i + 1}/{iterations}] Allocations grown vs. baseline:") - if growing: - for s in growing[:6]: - kb = s.size_diff / 1024 - loc = s.traceback[0] - print(f" {kb:+8.1f} KB {loc.filename}:{loc.lineno}") - else: - print(" (none — memory is stable)") - - if HAS_OBJGRAPH: - print( - f"\n [iter {i + 1}/{iterations}] New object types since last check:" - ) - objgraph.show_growth(limit=5, shortnames=False) - - current, peak = tracemalloc.get_traced_memory() - tracemalloc.stop() - - print(f"\n Summary — current: {current / 1024:.1f} KB, peak: {peak / 1024:.1f} KB") - - -# --------------------------------------------------------------------------- -# Section 1: Mock demo (no real data needed) -# --------------------------------------------------------------------------- - - -class _MockCacheUser: - """ - Minimal stand-in simulating CellMapImage's cached_property array pattern. - - Each __getitem__ allocates a new array into self._array_cache, mirroring - how CellMapImage builds an xarray DataArray on every access. With - clear_cache=True the cache is dropped immediately (the fix); without it - the reference accumulates. - """ - - def __init__(self, shape=(512, 512)): - self.shape = shape - self._array_cache = None - - def _clear_array_cache(self): - self._array_cache = None - - def __getitem__(self, idx, clear_cache=True): - self._array_cache = np.ones(self.shape, dtype=np.float32) - result = self._array_cache - if clear_cache: - self._clear_array_cache() - return result - - -def run_mock_demo(iterations): - print("\n" + "=" * 64) - print("SECTION 1: Mock demo (no real data, illustrates the principle)") - print("=" * 64) - - leaky = _MockCacheUser() - fixed = _MockCacheUser() - - profile_iters( - "WITHOUT cache clearing (leaky)", - lambda i: leaky.__getitem__(i, clear_cache=False), - iterations=iterations, - ) - profile_iters( - "WITH cache clearing (fixed)", - lambda i: fixed.__getitem__(i, clear_cache=True), - iterations=iterations, - ) - - -# --------------------------------------------------------------------------- -# Section 2: Real CellMapImage with a temporary Zarr store -# --------------------------------------------------------------------------- - - -def _build_test_zarr(root_path: Path, shape=(32, 32, 32), scale=(4.0, 4.0, 4.0)): - """Create a minimal OME-NGFF Zarr array for profiling.""" - import zarr - from pydantic_ome_ngff.v04.axis import Axis - from pydantic_ome_ngff.v04.multiscale import ( - Dataset as MultiscaleDataset, - MultiscaleMetadata, - ) - from pydantic_ome_ngff.v04.transform import VectorScale - - root_path.mkdir(parents=True, exist_ok=True) - data = np.random.rand(*shape).astype(np.float32) - store = zarr.DirectoryStore(str(root_path)) - root = zarr.group(store=store, overwrite=True) - chunks = tuple(min(16, s) for s in shape) - root.create_dataset("s0", data=data, chunks=chunks, overwrite=True) - - axes = [Axis(name=n, type="space", unit="nanometer") for n in ["z", "y", "x"]] - datasets = ( - MultiscaleDataset( - path="s0", - coordinateTransformations=(VectorScale(type="scale", scale=scale),), - ), - ) - root.attrs["multiscales"] = [ - MultiscaleMetadata( - version="0.4", name="test", axes=axes, datasets=datasets - ).model_dump(mode="json", exclude_none=True) - ] - return str(root_path) - - -def run_real_demo(iterations): - print("\n" + "=" * 64) - print("SECTION 2: Real CellMapImage with a temporary Zarr dataset") - print("=" * 64) - - try: - from cellmap_data.image import CellMapImage - except ImportError as e: - print(f"\n Skipping — could not import CellMapImage: {e}") - return - - try: - with tempfile.TemporaryDirectory() as tmp: - # Larger array so each DataArray is meaningfully sized (~2 MB) - shape = (64, 64, 64) - scale = [4.0, 4.0, 4.0] - voxel_shape = [16, 16, 16] - img_path = _build_test_zarr( - Path(tmp) / "raw", shape=shape, scale=tuple(scale) - ) - - # Volume spans 0–256 nm per axis; vary centers to exercise interp/reindex - rng = np.random.default_rng(42) - half = voxel_shape[0] * scale[0] / 2 # 32 nm margin - lo, hi = half, shape[0] * scale[0] - half # 32 to 224 nm - - def random_center(i): - coords = rng.uniform(lo, hi, size=3) - return { - "z": float(coords[0]), - "y": float(coords[1]), - "x": float(coords[2]), - } - - def make_image(): - return CellMapImage( - path=img_path, - target_class="raw", - target_scale=scale, - target_voxel_shape=voxel_shape, - device="cpu", - ) - - # Warmup: load all heavy imports and initialize TensorStore context - # before profiling either mode, so the comparison is not confounded - # by import costs. - print("\n Warming up (pre-loading imports and TensorStore context)...") - _warmup = make_image() - for _ in range(5): - _warmup[{"z": 128.0, "y": 128.0, "x": 128.0}] - del _warmup - gc.collect() - print(" Done.\n") - - # Leaky first (no imports to pay), then fixed — equal footing. - img_leaky = make_image() - img_leaky._clear_array_cache = lambda: None - profile_iters( - "CellMapImage — WITHOUT cache clearing (leaky)", - lambda i: img_leaky[random_center(i)], - iterations=iterations, - ) - - img_fixed = make_image() - profile_iters( - "CellMapImage — WITH cache clearing (fixed)", - lambda i: img_fixed[random_center(i)], - iterations=iterations, - ) - - except Exception as e: - print(f"\n Error during real demo: {e}") - raise - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - - -def main(): - iterations = int(os.environ.get("DEMO_ITERS", "100")) - - print("=" * 64) - print("CellMapImage Memory Profiling Demo") - print("=" * 64) - print( - f"\n iterations : {iterations} (set DEMO_ITERS env var to change)\n" - f" tracemalloc: built-in\n" - f" objgraph : {'available' if HAS_OBJGRAPH else 'not installed — pip install objgraph'}" - ) - - run_mock_demo(iterations=iterations) - run_real_demo(iterations=iterations) - - print("\n" + "=" * 64) - print("Done.") - print("=" * 64) - - -if __name__ == "__main__": - main() diff --git a/tests/test_api_contract.py b/tests/test_api_contract.py new file mode 100644 index 0000000..73da9d9 --- /dev/null +++ b/tests/test_api_contract.py @@ -0,0 +1,569 @@ +"""Tests that validate every API call documented in API_TO_PRESERVE.md. + +These tests mirror the exact constructor signatures, attribute accesses, and +call patterns used in cellmap-segmentation-challenge. +""" + +from __future__ import annotations + +import csv +import os + +import torch +import pytest + +from cellmap_data import ( + CellMapDataLoader, + CellMapDataSplit, + CellMapDatasetWriter, + CellMapImage, +) +from cellmap_data.transforms.augment import Binarize, NaNtoNum +from cellmap_data.utils import ( + array_has_singleton_dim, + get_fig_dict, + is_array_2D, + longest_common_substring, + permute_singleton_dimension, +) + +from .test_helpers import create_test_dataset, create_test_zarr + +import torchvision.transforms.v2 as T + +# Default transforms used throughout cellmap-segmentation-challenge +_RAW_TX = T.Compose( + [ + T.ToDtype(torch.float, scale=True), + NaNtoNum({"nan": 0, "posinf": None, "neginf": None}), + ] +) +_TARGET_TX = T.Compose([T.ToDtype(torch.float), Binarize()]) + +INPUT_ARRAYS = {"raw": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} +TARGET_ARRAYS = {"labels": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} +CLASSES = ["mito", "er"] + + +# --------------------------------------------------------------------------- +# CellMapDataSplit — full API +# --------------------------------------------------------------------------- + + +class TestCellMapDataSplitAPI: + """Mirrors utils/dataloader.py usage in cellmap-segmentation-challenge.""" + + def _make_csv(self, tmp_path): + train_info = create_test_dataset(tmp_path / "train", classes=CLASSES) + val_info = create_test_dataset(tmp_path / "val", classes=CLASSES) + csv_path = str(tmp_path / "split.csv") + with open(csv_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["train", train_info["raw_path"], train_info["gt_path"]]) + w.writerow(["validate", val_info["raw_path"], val_info["gt_path"]]) + return csv_path + + def test_constructor_with_csv(self, tmp_path): + csv_path = self._make_csv(tmp_path) + split = CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + pad=True, + csv_path=csv_path, + train_raw_value_transforms=_RAW_TX, + val_raw_value_transforms=_RAW_TX, + target_value_transforms=_TARGET_TX, + spatial_transforms=None, + device="cpu", + class_relation_dict=None, + force_has_data=True, + ) + assert split is not None + + def test_validation_datasets_is_list(self, tmp_path): + csv_path = self._make_csv(tmp_path) + split = CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + pad=True, + csv_path=csv_path, + force_has_data=True, + ) + assert isinstance(split.validation_datasets, list) + assert len(split.validation_datasets) == 1 + + def test_validation_blocks_is_subset(self, tmp_path): + from torch.utils.data import Subset + + csv_path = self._make_csv(tmp_path) + split = CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + pad=True, + csv_path=csv_path, + force_has_data=True, + ) + blocks = split.validation_blocks + assert isinstance(blocks, Subset) + assert len(blocks) > 0 + + def test_train_datasets_combined(self, tmp_path): + from cellmap_data.multidataset import CellMapMultiDataset + + csv_path = self._make_csv(tmp_path) + split = CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + pad=True, + csv_path=csv_path, + force_has_data=True, + ) + combined = split.train_datasets_combined + assert isinstance(combined, CellMapMultiDataset) + assert len(combined) > 0 + + def test_to_device_noop(self, tmp_path): + """split.to(device) should not raise (no-op on CPU datasets).""" + csv_path = self._make_csv(tmp_path) + split = CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + csv_path=csv_path, + force_has_data=True, + ) + split.to("cpu") # should not raise + + +# --------------------------------------------------------------------------- +# CellMapDataLoader — full API +# --------------------------------------------------------------------------- + + +class TestCellMapDataLoaderAPI: + """Mirrors utils/dataloader.py lines 188, 204 in cellmap-segmentation-challenge.""" + + def _make_split(self, tmp_path): + train_info = create_test_dataset(tmp_path / "train", classes=CLASSES) + val_info = create_test_dataset(tmp_path / "val", classes=CLASSES) + ds_dict = { + "train": [{"raw": train_info["raw_path"], "gt": train_info["gt_path"]}], + "validate": [{"raw": val_info["raw_path"], "gt": val_info["gt_path"]}], + } + return CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + pad=True, + dataset_dict=ds_dict, + force_has_data=True, + ) + + def test_validation_loader_api(self, tmp_path): + """Replicates: CellMapDataLoader(blocks, classes, batch_size, is_train=False, device).""" + split = self._make_split(tmp_path) + blocks = split.validation_blocks + loader = CellMapDataLoader( + blocks, + classes=CLASSES, + batch_size=2, + is_train=False, + device="cpu", + ) + batches = list(loader) + assert len(batches) > 0 + + def test_training_loader_api(self, tmp_path): + """Replicates: CellMapDataLoader(combined, ..., is_train=True, weighted_sampler=True, iterations_per_epoch).""" + split = self._make_split(tmp_path) + combined = split.train_datasets_combined + loader = CellMapDataLoader( + combined, + classes=CLASSES, + batch_size=2, + is_train=True, + device="cpu", + iterations_per_epoch=4, + weighted_sampler=True, + ) + # Should yield exactly ceil(4 / batch_size) batches + batches = list(loader) + assert len(batches) > 0 + + def test_batch_dict_has_idx(self, tmp_path): + """All batches must contain the 'idx' key (needed for writer[batch['idx']] = outputs).""" + split = self._make_split(tmp_path) + loader = CellMapDataLoader( + split.validation_blocks, + classes=CLASSES, + batch_size=2, + is_train=False, + ) + for batch in loader: + assert "idx" in batch + assert isinstance(batch["idx"], torch.Tensor) + break + + def test_loader_is_iterable(self, tmp_path): + split = self._make_split(tmp_path) + loader = CellMapDataLoader( + split.validation_blocks, classes=CLASSES, batch_size=1, is_train=False + ) + assert hasattr(loader, "__iter__") + assert hasattr(loader, "__len__") + + def test_blocks_to_device_before_loader(self, tmp_path): + """split.validation_blocks.to(device) is called before passing to loader.""" + split = self._make_split(tmp_path) + blocks = split.validation_blocks + # .to(device) on a Subset delegates to its dataset + if hasattr(blocks.dataset, "to"): + blocks.dataset.to("cpu") + loader = CellMapDataLoader( + blocks, classes=CLASSES, batch_size=1, is_train=False + ) + batches = list(loader) + assert len(batches) > 0 + + +# --------------------------------------------------------------------------- +# CellMapDatasetWriter — full API +# --------------------------------------------------------------------------- + + +class TestCellMapDatasetWriterAPI: + """Mirrors predict.py and process.py usage.""" + + def _make_writer(self, tmp_path, model_classes=None): + raw_path = create_test_zarr( + tmp_path, name="raw", shape=(32, 32, 32), voxel_size=[8.0, 8.0, 8.0] + ) + out_path = str(tmp_path / "predictions.zarr") + bounds = {"pred": {"z": (0.0, 256.0), "y": (0.0, 256.0), "x": (0.0, 256.0)}} + import torchvision.transforms.v2 as T + + raw_tx = T.Compose( + [ + T.ToDtype(torch.float, scale=True), + NaNtoNum({"nan": 0, "posinf": None, "neginf": None}), + ] + ) + return CellMapDatasetWriter( + raw_path=raw_path, + target_path=out_path, + classes=CLASSES, + input_arrays=INPUT_ARRAYS, + target_arrays={"pred": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}}, + target_bounds=bounds, + overwrite=False, + device="cuda", + raw_value_transforms=raw_tx, + model_classes=model_classes or CLASSES, + ) + + def test_constructor_full_signature(self, tmp_path): + writer = self._make_writer(tmp_path) + assert writer is not None + + def test_loader_method(self, tmp_path): + """writer.loader(batch_size) returns an iterable DataLoader.""" + writer = self._make_writer(tmp_path) + loader = writer.loader(batch_size=2) + assert hasattr(loader, "__iter__") + batches = list(loader) + assert len(batches) > 0 + + def test_loader_batch_has_idx(self, tmp_path): + writer = self._make_writer(tmp_path) + for batch in writer.loader(batch_size=2): + assert "idx" in batch + break + + def test_setitem_with_batch_idx(self, tmp_path): + """writer[batch['idx']] = outputs — the main write pattern.""" + writer = self._make_writer(tmp_path) + loader = writer.loader(batch_size=2) + for batch in loader: + idx = batch["idx"] + # Model outputs: one tensor per class + outputs = {cls: torch.zeros(len(idx), 4, 4, 4) for cls in CLASSES} + writer[idx] = outputs # should not raise + break + + def test_setitem_scalar_idx(self, tmp_path): + writer = self._make_writer(tmp_path) + idx = writer.writer_indices[0] + outputs = {"mito": torch.zeros(4, 4, 4), "er": torch.zeros(4, 4, 4)} + writer[idx] = outputs # should not raise + + def test_model_classes_superset(self, tmp_path): + """model_classes may be a superset of classes (write subset only).""" + writer = self._make_writer(tmp_path, model_classes=CLASSES + ["nucleus"]) + assert writer.model_classes == CLASSES + ["nucleus"] + + def test_bounding_box_exposed(self, tmp_path): + writer = self._make_writer(tmp_path) + bb = writer.bounding_box + assert bb is not None + assert "z" in bb + + +# --------------------------------------------------------------------------- +# CellMapImage — full API +# --------------------------------------------------------------------------- + + +class TestCellMapImageAPI: + """Mirrors predict.py, process.py, and utils/matched_crop.py usage.""" + + def test_constructor_full_signature(self, tmp_path): + """Replicates: CellMapImage(path, target_class, target_scale, target_voxel_shape, pad, pad_value, interpolation).""" + path = create_test_zarr(tmp_path, shape=(32, 32, 32)) + img = CellMapImage( + path=path, + target_class="label", + target_scale=(8.0, 8.0, 8.0), + target_voxel_shape=(4, 4, 4), + pad=True, + pad_value=0, + interpolation="linear", + ) + assert img is not None + + def test_scale_level_is_int(self, tmp_path): + """matched_crop.py:293 — img.scale_level.""" + path = create_test_zarr(tmp_path) + img = CellMapImage(path, "label", (8.0, 8.0, 8.0), (4, 4, 4)) + assert isinstance(img.scale_level, int) + assert img.scale_level >= 0 + + def test_bounding_box_is_dict(self, tmp_path): + """process.py — img.bounding_box.""" + path = create_test_zarr(tmp_path, shape=(32, 32, 32)) + img = CellMapImage(path, "raw", (8.0, 8.0, 8.0), (4, 4, 4)) + bb = img.bounding_box + assert isinstance(bb, dict) + assert all(len(v) == 2 for v in bb.values()) + + def test_get_center_returns_dict(self, tmp_path): + """predict.py — img.get_center(idx).""" + path = create_test_zarr(tmp_path, shape=(32, 32, 32)) + img = CellMapImage(path, "raw", (8.0, 8.0, 8.0), (4, 4, 4)) + center = img.get_center(0) + assert isinstance(center, dict) + assert all(isinstance(v, float) for v in center.values()) + + def test_array_indexing(self, tmp_path): + """predict.py / process.py — img[...] to load data.""" + path = create_test_zarr(tmp_path, shape=(32, 32, 32)) + img = CellMapImage(path, "raw", (8.0, 8.0, 8.0), (4, 4, 4), pad=True) + center = img.get_center(0) + patch = img[center] + assert isinstance(patch, torch.Tensor) + assert patch.shape == torch.Size([4, 4, 4]) + + def test_nearest_interpolation(self, tmp_path): + path = create_test_zarr(tmp_path, shape=(32, 32, 32)) + img = CellMapImage( + path, "label", (8.0, 8.0, 8.0), (4, 4, 4), interpolation="nearest" + ) + center = img.get_center(0) + patch = img[center] + assert patch.shape == torch.Size([4, 4, 4]) + + def test_linear_interpolation(self, tmp_path): + path = create_test_zarr(tmp_path, shape=(32, 32, 32)) + img = CellMapImage( + path, "raw", (8.0, 8.0, 8.0), (4, 4, 4), interpolation="linear", pad=True + ) + center = img.get_center(0) + patch = img[center] + assert patch.shape == torch.Size([4, 4, 4]) + + +# --------------------------------------------------------------------------- +# NaNtoNum and Binarize — exact import paths and usage patterns +# --------------------------------------------------------------------------- + + +class TestTransformImportPaths: + def test_nan_to_num_import(self): + from cellmap_data.transforms.augment import NaNtoNum + + t = NaNtoNum({"nan": 0, "posinf": None, "neginf": None}) + x = torch.tensor([float("nan")]) + out = t(x) + assert out[0] == 0.0 + + def test_binarize_import(self): + from cellmap_data.transforms.augment import Binarize + + t = Binarize() + x = torch.tensor([0.0, 1.0]) + assert torch.allclose(t(x), torch.tensor([0.0, 1.0])) + + def test_binarize_explicit_threshold(self): + from cellmap_data.transforms.augment import Binarize + + t = Binarize(0.5) + x = torch.tensor([0.3, 0.7]) + out = t(x) + assert out[0] == 0.0 + assert out[1] == 1.0 + + +# --------------------------------------------------------------------------- +# Utility functions — exact signatures and behaviors from API_TO_PRESERVE.md +# --------------------------------------------------------------------------- + + +class TestUtilFunctions: + # --- longest_common_substring --- + + def test_lcs_basic(self): + result = longest_common_substring("raw_input", "raw_target") + assert result == "raw_" + + def test_lcs_identical(self): + assert longest_common_substring("abc", "abc") == "abc" + + def test_lcs_no_common(self): + assert longest_common_substring("aaa", "bbb") == "" + + def test_lcs_import_path(self): + from cellmap_data.utils import longest_common_substring + + assert longest_common_substring("in_key", "target_key") == "_key" + + # --- array_has_singleton_dim --- + + def test_singleton_dim_true(self): + info = {"shape": (1, 64, 64)} + assert array_has_singleton_dim(info) is True + + def test_singleton_dim_false(self): + info = {"shape": (32, 64, 64)} + assert array_has_singleton_dim(info) is False + + def test_singleton_dim_none_input(self): + assert array_has_singleton_dim(None) is False + + def test_singleton_dim_empty(self): + assert array_has_singleton_dim({}) is False + + def test_singleton_dim_nested(self): + info = { + "a": {"shape": (1, 32, 32)}, + "b": {"shape": (8, 8, 8)}, + } + # summary=True (default) → any() of inner results + assert array_has_singleton_dim(info) is True + + def test_singleton_dim_nested_no_summary(self): + info = { + "a": {"shape": (1, 32, 32)}, + "b": {"shape": (8, 8, 8)}, + } + result = array_has_singleton_dim(info, summary=False) + assert isinstance(result, dict) + assert result["a"] is True + assert result["b"] is False + + # --- is_array_2D --- + + def test_is_2d_true(self): + info = {"shape": (64, 64)} + assert is_array_2D(info) is True + + def test_is_2d_false_3d(self): + info = {"shape": (32, 64, 64)} + assert is_array_2D(info) is False + + def test_is_2d_singleton_is_not_2d(self): + """A (1, 64, 64) shape has 3 dims, so is_array_2D returns False.""" + info = {"shape": (1, 64, 64)} + assert is_array_2D(info) is False + + def test_is_2d_none_input(self): + assert is_array_2D(None) is False + + def test_is_2d_nested_with_summary(self): + info = { + "a": {"shape": (64, 64)}, + "b": {"shape": (32, 64, 64)}, + } + # summary=any → True (at least one 2D) + result = is_array_2D(info, summary=any) + assert result is True + + def test_is_2d_nested_no_summary(self): + info = { + "a": {"shape": (64, 64)}, + "b": {"shape": (32, 64, 64)}, + } + result = is_array_2D(info) + assert isinstance(result, dict) + assert result["a"] is True + assert result["b"] is False + + # --- permute_singleton_dimension --- + + def test_permute_adds_singleton_if_none(self): + arr_dict = {"shape": [8, 8], "scale": [8.0, 8.0]} + permute_singleton_dimension(arr_dict, axis=0) + assert arr_dict["shape"][0] == 1 + assert len(arr_dict["shape"]) == 3 + + def test_permute_moves_existing_singleton(self): + arr_dict = {"shape": [1, 64, 64], "scale": [8.0, 8.0, 8.0]} + permute_singleton_dimension(arr_dict, axis=2) + # Singleton should now be at axis 2 + assert arr_dict["shape"][2] == 1 + + def test_permute_nested_dict(self): + arr_dict = { + "a": {"shape": [8, 8], "scale": [8.0, 8.0]}, + } + permute_singleton_dimension(arr_dict, axis=0) + assert arr_dict["a"]["shape"][0] == 1 + + def test_permute_scale_expanded(self): + """2D scale → 3D scale after permute.""" + arr_dict = {"shape": [8, 8], "scale": [8.0, 8.0]} + permute_singleton_dimension(arr_dict, axis=0) + assert len(arr_dict["scale"]) == 3 + + # --- get_fig_dict --- + # Actual signature: get_fig_dict(input_data: Tensor, target_data: Tensor, + # outputs: Tensor, classes: list) -> dict + # input_data shape: [batch, channels, *spatial] + # target_data shape: [batch, n_classes, *spatial] + # outputs shape: [batch, n_classes, *spatial] + + def test_get_fig_dict_returns_dict(self): + """get_fig_dict returns a dict of matplotlib figures.""" + import matplotlib + + matplotlib.use("Agg") # non-interactive backend for CI + # 1 batch item, 1 input channel, 1 class, 8x8 2D slices + input_data = torch.rand(1, 1, 8, 8) + target_data = torch.rand(1, 1, 8, 8) + outputs = torch.rand(1, 1, 8, 8) + result = get_fig_dict(input_data, target_data, outputs, ["mito"]) + assert isinstance(result, dict) + assert len(result) >= 1 + + def test_get_fig_dict_key_per_class(self): + import matplotlib + + matplotlib.use("Agg") + input_data = torch.rand(1, 1, 8, 8) + target_data = torch.rand(1, 2, 8, 8) + outputs = torch.rand(1, 2, 8, 8) + result = get_fig_dict(input_data, target_data, outputs, ["mito", "er"]) + # One figure per class + assert len(result) == 2 diff --git a/tests/test_base_classes.py b/tests/test_base_classes.py deleted file mode 100644 index 353a87e..0000000 --- a/tests/test_base_classes.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Tests for base abstract classes.""" - -from abc import ABC - -import pytest -import torch - -from cellmap_data.base_dataset import CellMapBaseDataset -from cellmap_data.base_image import CellMapImageBase - - -class TestCellMapBaseDataset: - """Test the CellMapBaseDataset abstract base class.""" - - def test_cannot_instantiate_abstract_class(self): - """Test that CellMapBaseDataset cannot be instantiated directly.""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - CellMapBaseDataset() - - def test_incomplete_implementation_raises_error(self): - """Test that incomplete implementations cannot be instantiated.""" - - # Missing all abstract methods - class IncompleteDataset(CellMapBaseDataset): - pass - - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - IncompleteDataset() - - # Missing some abstract methods - class PartialDataset(CellMapBaseDataset): - @property - def class_counts(self): - return {} - - @property - def class_weights(self): - return {} - - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - PartialDataset() - - def test_complete_implementation_can_be_instantiated(self): - """Test that complete implementations can be instantiated.""" - - class CompleteDataset(CellMapBaseDataset): - def __init__(self): - self.classes = ["class1", "class2"] - self.input_arrays = {"raw": {}} - self.target_arrays = {"labels": {}} - - @property - def class_counts(self): - return {"class1": 100.0, "class2": 200.0} - - @property - def class_weights(self): - return {"class1": 0.67, "class2": 0.33} - - @property - def validation_indices(self): - return [0, 1, 2] - - def to(self, device, non_blocking=True): - return self - - def set_raw_value_transforms(self, transforms): - pass - - def set_target_value_transforms(self, transforms): - pass - - # Should not raise - dataset = CompleteDataset() - assert isinstance(dataset, CellMapBaseDataset) - assert dataset.classes == ["class1", "class2"] - assert dataset.class_counts == {"class1": 100.0, "class2": 200.0} - assert dataset.class_weights == {"class1": 0.67, "class2": 0.33} - assert dataset.validation_indices == [0, 1, 2] - assert dataset.to("cpu") is dataset - dataset.set_raw_value_transforms(lambda x: x) - dataset.set_target_value_transforms(lambda x: x) - - def test_attributes_are_defined(self): - """Test that expected attributes are defined in the base class.""" - # Check type annotations exist - assert hasattr(CellMapBaseDataset, "__annotations__") - annotations = CellMapBaseDataset.__annotations__ - assert "classes" in annotations - assert "input_arrays" in annotations - assert "target_arrays" in annotations - - -class TestCellMapImageBase: - """Test the CellMapImageBase abstract base class.""" - - def test_cannot_instantiate_abstract_class(self): - """Test that CellMapImageBase cannot be instantiated directly.""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - CellMapImageBase() - - def test_incomplete_implementation_raises_error(self): - """Test that incomplete implementations cannot be instantiated.""" - - # Missing all abstract methods - class IncompleteImage(CellMapImageBase): - pass - - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - IncompleteImage() - - # Missing some abstract methods - class PartialImage(CellMapImageBase): - @property - def bounding_box(self): - return {"x": (0, 100), "y": (0, 100)} - - @property - def sampling_box(self): - return {"x": (10, 90), "y": (10, 90)} - - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - PartialImage() - - def test_complete_implementation_can_be_instantiated(self): - """Test that complete implementations can be instantiated.""" - - class CompleteImage(CellMapImageBase): - def __getitem__(self, center): - return torch.zeros((1, 64, 64)) - - @property - def bounding_box(self): - return {"x": (0.0, 100.0), "y": (0.0, 100.0)} - - @property - def sampling_box(self): - return {"x": (10.0, 90.0), "y": (10.0, 90.0)} - - @property - def class_counts(self): - return 1000.0 - - def to(self, device, non_blocking=True): - pass - - def set_spatial_transforms(self, transforms): - pass - - # Should not raise - image = CompleteImage() - assert isinstance(image, CellMapImageBase) - center = {"x": 50.0, "y": 50.0} - result = image[center] - assert isinstance(result, torch.Tensor) - assert result.shape == (1, 64, 64) - assert image.bounding_box == {"x": (0.0, 100.0), "y": (0.0, 100.0)} - assert image.sampling_box == {"x": (10.0, 90.0), "y": (10.0, 90.0)} - assert image.class_counts == 1000.0 - image.to("cpu") - image.set_spatial_transforms(None) - - def test_class_counts_supports_dict_return_type(self): - """Test that class_counts can return a dictionary.""" - - class MultiClassImage(CellMapImageBase): - def __getitem__(self, center): - return torch.zeros((1, 64, 64)) - - @property - def bounding_box(self): - return {"x": (0.0, 100.0)} - - @property - def sampling_box(self): - return {"x": (10.0, 90.0)} - - @property - def class_counts(self): - return {"class1": 500.0, "class2": 300.0, "class3": 200.0} - - def to(self, device, non_blocking=True): - pass - - def set_spatial_transforms(self, transforms): - pass - - image = MultiClassImage() - counts = image.class_counts - assert isinstance(counts, dict) - assert counts == {"class1": 500.0, "class2": 300.0, "class3": 200.0} - - def test_bounding_box_can_be_none(self): - """Test that bounding_box property can return None.""" - - class UnboundedImage(CellMapImageBase): - def __getitem__(self, center): - return torch.zeros((1, 64, 64)) - - @property - def bounding_box(self): - return None - - @property - def sampling_box(self): - return None - - @property - def class_counts(self): - return 1000.0 - - def to(self, device, non_blocking=True): - pass - - def set_spatial_transforms(self, transforms): - pass - - image = UnboundedImage() - assert image.bounding_box is None - assert image.sampling_box is None diff --git a/tests/test_cellmap_dataset.py b/tests/test_cellmap_dataset.py deleted file mode 100644 index c5e0ed8..0000000 --- a/tests/test_cellmap_dataset.py +++ /dev/null @@ -1,617 +0,0 @@ -""" -Tests for CellMapDataset class. - -Tests dataset creation, data loading, and transformations using real data. -""" - -import pytest -import torch -import torchvision.transforms.v2 as T - -from cellmap_data import CellMapDataset -from cellmap_data.transforms import Binarize - -from .test_helpers import create_minimal_test_dataset, create_test_dataset - - -class TestCellMapDataset: - """Test suite for CellMapDataset class.""" - - @pytest.fixture - def minimal_dataset_config(self, tmp_path): - """Create a minimal dataset configuration.""" - return create_minimal_test_dataset(tmp_path) - - @pytest.fixture - def standard_dataset_config(self, tmp_path): - """Create a standard dataset configuration.""" - return create_test_dataset( - tmp_path, - raw_shape=(32, 32, 32), - num_classes=3, - raw_scale=(8.0, 8.0, 8.0), - ) - - def test_initialization_basic(self, minimal_dataset_config): - """Test basic dataset initialization.""" - config = minimal_dataset_config - - input_arrays = { - "raw": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - target_arrays = { - "gt": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - is_train=True, - force_has_data=True, - ) - - assert dataset.raw_path == config["raw_path"] - assert dataset.classes == config["classes"] - assert dataset.is_train is True - assert len(dataset.classes) == 2 - - def test_initialization_without_classes(self, minimal_dataset_config): - """Test dataset initialization without classes (raw data only).""" - config = minimal_dataset_config - - input_arrays = { - "raw": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=None, - input_arrays=input_arrays, - is_train=False, - force_has_data=True, - ) - - assert dataset.raw_only is True - assert dataset.classes == [] - - def test_input_arrays_configuration(self, minimal_dataset_config): - """Test input arrays configuration.""" - config = minimal_dataset_config - - input_arrays = { - "raw_4nm": { - "shape": (16, 16, 16), - "scale": (4.0, 4.0, 4.0), - }, - "raw_8nm": { - "shape": (8, 8, 8), - "scale": (8.0, 8.0, 8.0), - }, - } - - target_arrays = { - "gt": { - "shape": (8, 8, 8), - "scale": (8.0, 8.0, 8.0), - } - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - ) - - assert "raw_4nm" in dataset.input_arrays - assert "raw_8nm" in dataset.input_arrays - assert dataset.input_arrays["raw_4nm"]["shape"] == (16, 16, 16) - - def test_target_arrays_configuration(self, minimal_dataset_config): - """Test target arrays configuration.""" - config = minimal_dataset_config - - input_arrays = { - "raw": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - target_arrays = { - "labels": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - }, - "distances": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - }, - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - ) - - assert "labels" in dataset.target_arrays - assert "distances" in dataset.target_arrays - - def test_spatial_transforms_configuration(self, minimal_dataset_config): - """Test spatial transforms configuration.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - spatial_transforms = { - "mirror": {"axes": {"x": 0.5, "y": 0.5, "z": 0.2}}, - "rotate": {"axes": {"z": [-30, 30]}}, - "transpose": {"axes": ["x", "y"]}, - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - spatial_transforms=spatial_transforms, - is_train=True, - force_has_data=True, - ) - - assert dataset.spatial_transforms is not None - assert "mirror" in dataset.spatial_transforms - assert "rotate" in dataset.spatial_transforms - - def test_value_transforms_configuration(self, minimal_dataset_config): - """Test value transforms configuration.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - raw_transforms = T.Compose( - [ - T.ToDtype(torch.float, scale=True), - ] - ) - - target_transforms = T.Compose( - [ - Binarize(threshold=0.5), - ] - ) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - raw_value_transforms=raw_transforms, - target_value_transforms=target_transforms, - ) - - assert dataset.raw_value_transforms is not None - assert dataset.target_value_transforms is not None - - def test_class_relation_dict(self, minimal_dataset_config): - """Test class relationship dictionary.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - class_relation_dict = { - "class_0": ["class_1"], - "class_1": ["class_0"], - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - class_relation_dict=class_relation_dict, - ) - - assert dataset.class_relation_dict is not None - assert "class_0" in dataset.class_relation_dict - - def test_axis_order_parameter(self, minimal_dataset_config): - """Test different axis orders.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - for axis_order in ["zyx", "xyz", "yxz"]: - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - axis_order=axis_order, - ) - assert dataset.axis_order == axis_order - - def test_is_train_parameter(self, minimal_dataset_config): - """Test is_train parameter.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - # Training dataset - train_dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - is_train=True, - force_has_data=True, - ) - assert train_dataset.is_train is True - - # Validation dataset - val_dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - is_train=False, - force_has_data=True, - ) - assert val_dataset.is_train is False - - def test_pad_parameter(self, minimal_dataset_config): - """Test pad parameter.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - # With padding - dataset_pad = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - pad=True, - ) - assert dataset_pad.pad is True - - # Without padding - dataset_no_pad = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - pad=False, - ) - assert dataset_no_pad.pad is False - - def test_empty_value_parameter(self, minimal_dataset_config): - """Test empty_value parameter.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - # Test with NaN - dataset_nan = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - empty_value=torch.nan, - ) - assert torch.isnan(torch.tensor(dataset_nan.empty_value)) - - # Test with numeric value - dataset_zero = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - empty_value=0.0, - ) - assert dataset_zero.empty_value == 0.0 - - def test_device_parameter(self, minimal_dataset_config): - """Test device parameter.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - # CPU device - dataset_cpu = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - device="cpu", - ) - # Device should be set (exact value checked in image tests) - assert dataset_cpu is not None - - def test_force_has_data_parameter(self, minimal_dataset_config): - """Test force_has_data parameter.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - - assert dataset.force_has_data is True - - def test_rng_parameter(self, minimal_dataset_config): - """Test random number generator parameter.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - # Create custom RNG - rng = torch.Generator() - rng.manual_seed(42) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - rng=rng, - ) - - assert dataset._rng is rng - - def test_context_parameter(self, minimal_dataset_config): - """Test TensorStore context parameter.""" - import tensorstore as ts - - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - context = ts.Context() - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - context=context, - ) - - assert dataset.context is context - - def test_max_workers_parameter(self, minimal_dataset_config): - """Test max_workers parameter.""" - config = minimal_dataset_config - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - max_workers=4, - ) - - # Dataset should be created successfully - assert dataset is not None - - def test_empty_dataset_creation(self): - """Test CellMapDataset.empty() static method.""" - from cellmap_data import CellMapDataset - - # Create an empty dataset - empty_dataset = CellMapDataset.empty() - - # Verify basic properties - assert empty_dataset is not None - assert isinstance(empty_dataset, CellMapDataset) - assert empty_dataset.has_data is False - assert len(empty_dataset) == 0 - - # Verify the newly initialized attributes - assert hasattr(empty_dataset, "sampling_box_shape") - assert isinstance(empty_dataset.sampling_box_shape, dict) - assert all(v == 0 for v in empty_dataset.sampling_box_shape.values()) - - # Verify axis_order is set (should have default value) - assert hasattr(empty_dataset, "axis_order") - assert len(empty_dataset.sampling_box_shape) == len(empty_dataset.axis_order) - - def test_empty_datasetsampling_box_shape(self): - """Test that empty dataset has correct sampling_box_shape initialization.""" - from cellmap_data import CellMapDataset - - empty_dataset = CellMapDataset.empty() - - # Verify sampling_box_shape keys match axis_order - assert set(empty_dataset.sampling_box_shape.keys()) == set( - empty_dataset.axis_order - ) - - # Verify all dimensions are 0 - for axis in empty_dataset.axis_order: - assert empty_dataset.sampling_box_shape[axis] == 0 - - def test_bounding_box_shape(self, minimal_dataset_config): - """Test that bounding_box_shape property correctly computes shape from bounding box.""" - config = minimal_dataset_config - - input_arrays = { - "raw": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - target_arrays = { - "gt": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - - # Access bounding_box_shape - bbox_shape = dataset.bounding_box_shape - - # Verify it's a dict with expected keys - assert isinstance(bbox_shape, dict) - assert set(bbox_shape.keys()) == set(dataset.axis_order) - - # Verify all values are positive integers - for axis, size in bbox_shape.items(): - assert isinstance(size, (int, float)) - assert size > 0 - - def test_size_property(self, minimal_dataset_config): - """Test that size property correctly computes dataset size from bounding box. - - This test ensures the bug fix in PR #61 is covered: size must use - .values() not .items() to properly unpack bounding box numeric bounds. - """ - config = minimal_dataset_config - - input_arrays = { - "raw": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - target_arrays = { - "gt": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - - # Access the size property - this would have raised TypeError before the fix - size = dataset.size - - # Verify it's a positive integer - assert isinstance(size, int) - assert size > 0 - - # Verify it matches the product of bounding box dimensions - import numpy as np - - expected_size = int( - np.prod([stop - start for start, stop in dataset.bounding_box.values()]) - ) - assert size == expected_size - - def test_size_property_with_known_dimensions(self, tmp_path): - """Test size property with specific known bounding box dimensions.""" - # Create a dataset with known dimensions - config = create_test_dataset( - tmp_path, - raw_shape=(40, 30, 20), # Known shape - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - ) - - input_arrays = { - "raw": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - target_arrays = { - "gt": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - - # Get the size - size = dataset.size - - # Verify it's the product of the bounding box dimensions - # The bounding box should correspond to the raw data shape - bbox = dataset.bounding_box - import numpy as np - - # Calculate expected size from bounding box - dims = [stop - start for start, stop in bbox.values()] - expected_size = int(np.prod(dims)) - - assert size == expected_size - assert size > 0 diff --git a/tests/test_cellmap_image.py b/tests/test_cellmap_image.py deleted file mode 100644 index 1f238bc..0000000 --- a/tests/test_cellmap_image.py +++ /dev/null @@ -1,282 +0,0 @@ -""" -Tests for CellMapImage class. - -Tests image loading, spatial transformations, and value transformations -using real Zarr data without mocks. -""" - -import numpy as np -import pytest -import torch - -from cellmap_data import CellMapImage - -from .test_helpers import create_test_image_data, create_test_zarr_array - - -class TestCellMapImage: - """Test suite for CellMapImage class.""" - - @pytest.fixture - def test_zarr_image(self, tmp_path): - """Create a test Zarr image.""" - data = create_test_image_data((32, 32, 32), pattern="gradient") - path = tmp_path / "test_image.zarr" - create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0)) - return str(path), data - - def test_initialization(self, test_zarr_image): - """Test basic initialization of CellMapImage.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - axis_order="zyx", - ) - - assert image.path == path - assert image.label_class == "test_class" - assert image.scale == {"z": 4.0, "y": 4.0, "x": 4.0} - assert image.output_shape == {"z": 16, "y": 16, "x": 16} - assert image.axes == "zyx" - - def test_device_selection(self, test_zarr_image): - """Test device selection logic.""" - path, _ = test_zarr_image - - # Test explicit device - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - device="cpu", - ) - assert image.device == "cpu" - - # Test automatic device selection - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - # Should select cuda if available, otherwise mps, otherwise cpu - assert image.device in ["cuda", "mps", "cpu"] - - def test_scale_and_shape_mismatch(self, test_zarr_image): - """Test handling of mismatched axis order, scale, and shape.""" - path, _ = test_zarr_image - - # Test with more axes in axis_order than in scale - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0), - target_voxel_shape=(8, 8), - axis_order="zyx", - ) - # Should pad scale with first value - assert image.scale == {"z": 4.0, "y": 4.0, "x": 4.0} - - # Test with more axes in axis_order than in voxel_shape - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8), - axis_order="zyx", - ) - # Should pad voxel_shape with 1s - assert image.output_shape == {"z": 1, "y": 8, "x": 8} - - def test_output_size_calculation(self, test_zarr_image): - """Test that output size is correctly calculated.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(8.0, 8.0, 8.0), - target_voxel_shape=(16, 16, 16), - ) - - # Output size should be voxel_shape * scale - expected_size = {"z": 128.0, "y": 128.0, "x": 128.0} - assert image.output_size == expected_size - - def test_value_transform(self, test_zarr_image): - """Test value transform application.""" - path, _ = test_zarr_image - - # Create a simple transform that multiplies by 2 - def multiply_by_2(x): - return x * 2 - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - value_transform=multiply_by_2, - ) - - assert image.value_transform is not None - # Test the transform works - test_tensor = torch.tensor([1.0, 2.0, 3.0]) - result = image.value_transform(test_tensor) - expected = torch.tensor([2.0, 4.0, 6.0]) - assert torch.allclose(result, expected) - - def test_2d_image(self, tmp_path): - """Test handling of 2D images.""" - # Create a 2D image - data = create_test_image_data((32, 32), pattern="checkerboard") - path = tmp_path / "test_2d.zarr" - create_test_zarr_array(path, data, axes=("y", "x"), scale=(4.0, 4.0)) - - image = CellMapImage( - path=str(path), - target_class="test_2d", - target_scale=(4.0, 4.0), - target_voxel_shape=(16, 16), - axis_order="yx", - ) - - assert image.axes == "yx" - assert image.scale == {"y": 4.0, "x": 4.0} - - def test_pad_parameter(self, test_zarr_image): - """Test pad parameter.""" - path, _ = test_zarr_image - - image_with_pad = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - pad=True, - ) - assert image_with_pad.pad is True - - image_without_pad = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - pad=False, - ) - assert image_without_pad.pad is False - - def test_pad_value(self, test_zarr_image): - """Test pad value parameter.""" - path, _ = test_zarr_image - - # Test with NaN pad value - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - pad=True, - pad_value=np.nan, - ) - assert np.isnan(image.pad_value) - - # Test with numeric pad value - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - pad=True, - pad_value=0.0, - ) - assert image.pad_value == 0.0 - - def test_interpolation_modes(self, test_zarr_image): - """Test different interpolation modes.""" - path, _ = test_zarr_image - - for interp in ["nearest", "linear"]: - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - interpolation=interp, - ) - assert image.interpolation == interp - - def test_different_axis_orders(self, tmp_path): - """Test different axis orderings.""" - for axis_order in ["xyz", "zyx", "yxz"]: - data = create_test_image_data((16, 16, 16), pattern="random") - path = tmp_path / f"test_{axis_order}.zarr" - create_test_zarr_array( - path, data, axes=tuple(axis_order), scale=(4.0, 4.0, 4.0) - ) - - image = CellMapImage( - path=str(path), - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - axis_order=axis_order, - ) - assert image.axes == axis_order - assert len(image.scale) == 3 - - def test_different_dtypes(self, tmp_path): - """Test handling of different data types.""" - dtypes = [np.float32, np.float64, np.uint8, np.uint16, np.int32] - - for dtype in dtypes: - data = create_test_image_data((16, 16, 16), dtype=dtype, pattern="constant") - path = tmp_path / f"test_{dtype.__name__}.zarr" - create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0)) - - image = CellMapImage( - path=str(path), - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - # Image should be created successfully - assert image.path == str(path) - - def test_context_parameter(self, test_zarr_image): - """Test TensorStore context parameter.""" - import tensorstore as ts - - path, _ = test_zarr_image - - # Create a custom context - context = ts.Context() - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - context=context, - ) - - assert image.context is context - - def test_without_context(self, test_zarr_image): - """Test that image works without explicit context.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - context=None, - ) - - assert image.context is None diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 177765a..e56c8a1 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -1,759 +1,92 @@ -""" -Tests for CellMapDataLoader class. +"""Tests for CellMapDataLoader.""" -Tests data loading, batching, and optimization features using real data. -""" +from __future__ import annotations -import pytest -import tensorstore as ts import torch -from cellmap_data import CellMapDataLoader, CellMapDataset, CellMapMultiDataset -from .test_helpers import create_test_dataset - - -class TestCellMapDataLoader: - """Test suite for CellMapDataLoader class.""" - - @pytest.fixture - def test_dataset(self, tmp_path): - """Create a test dataset for loader tests.""" - config = create_test_dataset( - tmp_path, - raw_shape=(32, 32, 32), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - ) - - input_arrays = { - "raw": { - "shape": (16, 16, 16), - "scale": (4.0, 4.0, 4.0), - } - } - - target_arrays = { - "gt": { - "shape": (16, 16, 16), - "scale": (4.0, 4.0, 4.0), - } - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - is_train=True, - force_has_data=True, - # Force dataset to have data for testing - ) - - return dataset - - def test_initialization_basic(self, test_dataset): - """Test basic DataLoader initialization.""" - loader = CellMapDataLoader( - test_dataset, - batch_size=2, - num_workers=0, # Use 0 for testing - ) - - assert loader is not None - assert loader.batch_size == 2 - - def test_batch_size_parameter(self, test_dataset): - """Test different batch sizes.""" - for batch_size in [1, 2, 4, 8]: - loader = CellMapDataLoader( - test_dataset, - batch_size=batch_size, - num_workers=0, - ) - assert loader.batch_size == batch_size - - def test_num_workers_parameter(self, test_dataset): - """Test num_workers parameter.""" - for num_workers in [0, 1, 2]: - loader = CellMapDataLoader( - test_dataset, - batch_size=2, - num_workers=num_workers, - ) - # Loader should be created successfully - assert loader is not None - - def test_weighted_sampler_parameter(self, test_dataset): - """Test weighted sampler option.""" - # With weighted sampler - loader_weighted = CellMapDataLoader( - test_dataset, - batch_size=2, - weighted_sampler=True, - num_workers=0, - ) - assert loader_weighted is not None - - # Without weighted sampler - loader_no_weight = CellMapDataLoader( - test_dataset, - batch_size=2, - weighted_sampler=False, - num_workers=0, - ) - assert loader_no_weight is not None - - def test_is_train_parameter(self, test_dataset): - """Test is_train parameter.""" - # Training loader - train_loader = CellMapDataLoader( - test_dataset, - batch_size=2, - is_train=True, - force_has_data=True, - num_workers=0, - ) - assert train_loader is not None - - # Validation loader - val_loader = CellMapDataLoader( - test_dataset, - batch_size=2, - is_train=False, - force_has_data=True, - num_workers=0, - ) - assert val_loader is not None - - def test_device_parameter(self, test_dataset): - """Test device parameter.""" - loader_cpu = CellMapDataLoader( - test_dataset, - batch_size=2, - device="cpu", - num_workers=0, - ) - assert loader_cpu is not None - - def test_pin_memory_parameter(self, test_dataset): - """Test pin_memory parameter.""" - loader = CellMapDataLoader( - test_dataset, - batch_size=2, - pin_memory=True, - num_workers=0, - ) - assert loader is not None - - def test_persistent_workers_parameter(self, test_dataset): - """Test persistent_workers parameter.""" - # Only works with num_workers > 0 - loader = CellMapDataLoader( - test_dataset, - batch_size=2, - num_workers=1, - persistent_workers=True, - ) - assert loader is not None - - def test_prefetch_factor_parameter(self, test_dataset): - """Test prefetch_factor parameter.""" - # Only works with num_workers > 0 - for prefetch in [2, 4, 8]: - loader = CellMapDataLoader( - test_dataset, - batch_size=2, - num_workers=1, - prefetch_factor=prefetch, - ) - assert loader is not None - - def test_iterations_per_epoch_parameter(self, test_dataset): - """Test iterations_per_epoch parameter.""" - loader = CellMapDataLoader( - test_dataset, - batch_size=2, - iterations_per_epoch=10, - num_workers=0, - ) - assert loader is not None - - def test_shuffle_parameter(self, test_dataset): - """Test shuffle parameter.""" - # With shuffle - loader_shuffle = CellMapDataLoader( - test_dataset, - batch_size=2, - shuffle=True, - num_workers=0, - ) - assert loader_shuffle is not None - - # Without shuffle - loader_no_shuffle = CellMapDataLoader( - test_dataset, - batch_size=2, - shuffle=False, - num_workers=0, - ) - assert loader_no_shuffle is not None - - def test_drop_last_parameter(self, test_dataset): - """Test drop_last parameter.""" - loader = CellMapDataLoader( - test_dataset, - batch_size=3, - drop_last=True, - num_workers=0, - ) - assert loader is not None - - def test_timeout_parameter(self, test_dataset): - """Test timeout parameter.""" - loader = CellMapDataLoader( - test_dataset, - batch_size=2, - num_workers=1, - timeout=30, - ) - assert loader is not None - - -class TestDataLoaderOperations: - """Test DataLoader operations and functionality.""" - - @pytest.fixture - def simple_loader(self, tmp_path): - """Create a simple loader for operation tests.""" - config = create_test_dataset( - tmp_path, - raw_shape=(24, 24, 24), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - ) - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - ) - - print(config) - assert len(dataset) > 0 - - return CellMapDataLoader(dataset, batch_size=2, num_workers=0) - - def test_length(self, simple_loader): - """Test that loader has a length.""" - # Loader should implement __len__ - length = len(simple_loader) - assert isinstance(length, int) - assert length > 0 - - def test_device_transfer(self, simple_loader): - """Test transferring loader to device.""" - # Test CPU transfer - loader_cpu = simple_loader.to("cpu") - assert loader_cpu is not None - - def test_non_blocking_transfer(self, simple_loader): - """Test non-blocking device transfer.""" - loader = simple_loader.to("cpu", non_blocking=True) - assert loader is not None - - -class TestDataLoaderIntegration: - """Integration tests for DataLoader with datasets.""" - - def test_loader_with_transforms(self, tmp_path): - """Test loader with dataset that has transforms.""" - import torchvision.transforms.v2 as T - - from cellmap_data.transforms import Binarize - - config = create_test_dataset( - tmp_path, - raw_shape=(32, 32, 32), - num_classes=2, - ) - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - raw_transforms = T.Compose([T.ToDtype(torch.float, scale=True)]) - target_transforms = T.Compose([Binarize(threshold=0.5)]) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - raw_value_transforms=raw_transforms, - target_value_transforms=target_transforms, - ) - - loader = CellMapDataLoader(dataset, batch_size=2, num_workers=0) - assert loader is not None - - def test_loader_with_spatial_transforms(self, tmp_path): - """Test loader with spatial transforms.""" - config = create_test_dataset( - tmp_path, - raw_shape=(32, 32, 32), - num_classes=2, - ) - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - spatial_transforms = { - "mirror": {"axes": {"x": 0.5}}, - "rotate": {"axes": {"z": [-30, 30]}}, - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - spatial_transforms=spatial_transforms, - is_train=True, - force_has_data=True, - ) - - loader = CellMapDataLoader(dataset, batch_size=2, num_workers=0) - assert loader is not None - - def test_loader_reproducibility(self, tmp_path): - """Test loader reproducibility with fixed seed.""" - config = create_test_dataset( - tmp_path, - raw_shape=(24, 24, 24), - num_classes=2, - seed=42, - ) - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - # Create two loaders with same seed - torch.manual_seed(42) - dataset1 = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - ) - loader1 = CellMapDataLoader(dataset1, batch_size=2, num_workers=0) - - torch.manual_seed(42) - dataset2 = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - ) - loader2 = CellMapDataLoader(dataset2, batch_size=2, num_workers=0) - - # Both loaders should be created successfully - assert loader1 is not None - assert loader2 is not None - - def test_multiple_loaders_same_dataset(self, tmp_path): - """Test multiple loaders for same dataset.""" - config = create_test_dataset( - tmp_path, - raw_shape=(32, 32, 32), - num_classes=2, - ) - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - ) - - # Create multiple loaders - loader1 = CellMapDataLoader(dataset, batch_size=2, num_workers=0) - loader2 = CellMapDataLoader(dataset, batch_size=4, num_workers=0) - - assert loader1.batch_size == 2 - assert loader2.batch_size == 4 - - def test_loader_memory_optimization(self, tmp_path): - """Test memory optimization settings.""" - config = create_test_dataset( - tmp_path, - raw_shape=(32, 32, 32), - num_classes=2, - ) - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - ) - - # Test with memory optimization settings - loader = CellMapDataLoader( - dataset, - batch_size=2, - num_workers=1, - pin_memory=True, - prefetch_factor=2, - persistent_workers=True, - ) - - assert loader is not None +from cellmap_data import CellMapDataLoader, CellMapDataset +from .test_helpers import create_test_dataset -# --------------------------------------------------------------------------- -# Helper: collect every CellMapImage from a CellMapDataset's sources -# --------------------------------------------------------------------------- -def _all_images(dataset: CellMapDataset): - """Yield every CellMapImage in a dataset's input and target sources.""" - from cellmap_data.image import CellMapImage +INPUT_ARRAYS = {"raw": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} +TARGET_ARRAYS = {"labels": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} +CLASSES = ["mito", "er"] - for source in list(dataset.input_sources.values()) + list( - dataset.target_sources.values() - ): - if isinstance(source, CellMapImage): - yield source - elif isinstance(source, dict): - for sub in source.values(): - if isinstance(sub, CellMapImage): - yield sub +def _make_ds(tmp_path): + info = create_test_dataset(tmp_path, classes=CLASSES) + return CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=CLASSES, + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + pad=True, + ) -class TestTensorStoreCacheBounding: - """Tests for the tensorstore_cache_bytes cache-bounding feature.""" - @pytest.fixture - def dataset(self, tmp_path): - config = create_test_dataset( - tmp_path, - raw_shape=(24, 24, 24), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - ) - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - return CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - - # -- parameter stored on loader ------------------------------------------ - - def test_cache_bytes_stored_on_loader(self, dataset): - """tensorstore_cache_bytes is stored as an attribute.""" +class TestCellMapDataLoader: + def test_basic_iteration(self, tmp_path): + ds = _make_ds(tmp_path) + loader = CellMapDataLoader(ds, classes=CLASSES, batch_size=2, is_train=False) + batches = list(loader) + assert len(batches) > 0 + + def test_batch_contains_idx(self, tmp_path): + ds = _make_ds(tmp_path) + loader = CellMapDataLoader(ds, classes=CLASSES, batch_size=2, is_train=False) + for batch in loader: + assert "idx" in batch + break + + def test_batch_raw_shape(self, tmp_path): + ds = _make_ds(tmp_path) + loader = CellMapDataLoader(ds, classes=CLASSES, batch_size=2, is_train=False) + for batch in loader: + raw = batch["raw"] + assert isinstance(raw, torch.Tensor) + # batch_size is 2 but last batch may be smaller + assert raw.shape[1:] == torch.Size([4, 4, 4]) + break + + def test_len(self, tmp_path): + ds = _make_ds(tmp_path) + loader = CellMapDataLoader(ds, classes=CLASSES, batch_size=1, is_train=False) + assert len(loader) > 0 + + def test_weighted_sampler_train(self, tmp_path): + from cellmap_data.sampler import ClassBalancedSampler + + ds = _make_ds(tmp_path) loader = CellMapDataLoader( - dataset, num_workers=0, tensorstore_cache_bytes=100_000_000 - ) - assert loader.tensorstore_cache_bytes == 100_000_000 - - def test_default_limit(self, dataset): - """Without the parameter (or env var), cache bytes is set by default.""" - from cellmap_data.dataloader import _DEFAULT_TENSORSTORE_CACHE_BYTES - - loader = CellMapDataLoader(dataset, num_workers=0) - assert loader.tensorstore_cache_bytes == _DEFAULT_TENSORSTORE_CACHE_BYTES - - # -- per-worker byte math ------------------------------------------------ - - def test_per_worker_division(self, dataset): - """per_worker = total // num_workers is applied to every CellMapImage.""" - total = 400_000_000 # 400 MB - num_workers = 3 - CellMapDataLoader( - dataset, num_workers=num_workers, tensorstore_cache_bytes=total - ) - # 133_333_333 each, if total = 400_000_000 and num_workers = 3 - expected = total // num_workers - for img in _all_images(dataset): - assert isinstance(img.context, ts.Context) - assert img.context["cache_pool"].to_json() == { - "total_bytes_limit": expected - } - - def test_single_process_uses_full_budget(self, dataset): - """With num_workers=0 the whole budget goes to the single process (÷ 1).""" - total = 200_000_000 - CellMapDataLoader(dataset, num_workers=0, tensorstore_cache_bytes=total) - for img in _all_images(dataset): - assert img.context["cache_pool"].to_json() == {"total_bytes_limit": total} - - def test_context_set_on_target_images(self, dataset): - """Cache limit is applied to target-source images, not just input-source images.""" - from cellmap_data.image import CellMapImage - - CellMapDataLoader(dataset, num_workers=2, tensorstore_cache_bytes=200_000_000) - expected = 200_000_000 // 2 - for sources in dataset.target_sources.values(): - if isinstance(sources, dict): - for src in sources.values(): - if isinstance(src, CellMapImage): - assert src.context["cache_pool"].to_json() == { - "total_bytes_limit": expected - } - - # -- env var fallback ---------------------------------------------------- - - def test_env_var_sets_cache_limit(self, dataset, monkeypatch): - """CELLMAP_TENSORSTORE_CACHE_BYTES env var is used when parameter is not set.""" - monkeypatch.setenv("CELLMAP_TENSORSTORE_CACHE_BYTES", "300000000") - loader = CellMapDataLoader(dataset, num_workers=3) - assert loader.tensorstore_cache_bytes == 300_000_000 - expected = 300_000_000 // 3 # 100 MB per worker - for img in _all_images(dataset): - assert img.context["cache_pool"].to_json() == { - "total_bytes_limit": expected - } - - def test_param_overrides_env_var(self, dataset, monkeypatch): - """Explicit parameter takes precedence over the env var.""" - monkeypatch.setenv("CELLMAP_TENSORSTORE_CACHE_BYTES", "999999999") - CellMapDataLoader(dataset, num_workers=2, tensorstore_cache_bytes=200_000_000) - expected = 200_000_000 // 2 # 100 MB — param wins - for img in _all_images(dataset): - assert img.context["cache_pool"].to_json() == { - "total_bytes_limit": expected - } - - # -- validation ---------------------------------------------------------- - - def test_negative_cache_bytes_raises_error(self, dataset): - """Negative tensorstore_cache_bytes values are rejected.""" - with pytest.raises(ValueError, match="must be >= 0"): - CellMapDataLoader(dataset, num_workers=1, tensorstore_cache_bytes=-100) - - def test_negative_env_var_raises_error(self, dataset, monkeypatch): - """Negative values in CELLMAP_TENSORSTORE_CACHE_BYTES are rejected.""" - monkeypatch.setenv("CELLMAP_TENSORSTORE_CACHE_BYTES", "-500") - with pytest.raises(ValueError, match="must be >= 0"): - CellMapDataLoader(dataset, num_workers=1) - - def test_invalid_env_var_raises_error(self, dataset, monkeypatch): - """Non-integer values in CELLMAP_TENSORSTORE_CACHE_BYTES are rejected.""" - monkeypatch.setenv("CELLMAP_TENSORSTORE_CACHE_BYTES", "not_a_number") - with pytest.raises(ValueError, match="Invalid value for environment variable"): - CellMapDataLoader(dataset, num_workers=1) - - def test_warning_when_cache_less_than_workers(self, dataset, caplog): - """A warning is logged when tensorstore_cache_bytes < num_workers.""" - import logging - - with caplog.at_level(logging.WARNING, logger="cellmap_data.dataloader"): - CellMapDataLoader(dataset, num_workers=3, tensorstore_cache_bytes=2) - - # Check that warning was emitted - assert any( - "per-worker cache limit of 0 bytes" in r.message for r in caplog.records + ds, classes=CLASSES, batch_size=2, is_train=True, weighted_sampler=True ) - # Each worker gets 1 byte (the minimum) - for img in _all_images(dataset): - assert img.context["cache_pool"].to_json() == {"total_bytes_limit": 1} - - # -- CellMapMultiDataset traversal --------------------------------------- - - def test_multidataset_all_images_bounded(self, tmp_path): - """Recursive traversal reaches images in every sub-dataset.""" - datasets = [] - for i in range(2): - config = create_test_dataset( - tmp_path / f"ds{i}", - raw_shape=(24, 24, 24), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - ) - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - datasets.append( - CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - ) - - multi = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=datasets, - ) - - CellMapDataLoader(multi, num_workers=2, tensorstore_cache_bytes=200_000_000) - expected = 200_000_000 // 2 - - for ds in datasets: - for img in _all_images(ds): - assert img.context["cache_pool"].to_json() == { - "total_bytes_limit": expected - } + assert isinstance(loader._sampler, ClassBalancedSampler) - # -- warning when array already open ------------------------------------ - - def test_warning_when_array_already_open(self, dataset, caplog): - """A warning is logged when _array is already cached on an image.""" - import logging - - img = next(iter(dataset.input_sources.values())) - _ = img.array # force-open the TensorStore array - - with caplog.at_level(logging.WARNING, logger="cellmap_data.dataloader"): - CellMapDataLoader( - dataset, num_workers=1, tensorstore_cache_bytes=100_000_000 - ) - - assert any( - "cache_pool limit will not apply" in r.message for r in caplog.records - ) - # context is still updated on the image object (even though the open array isn't affected) - assert img.context["cache_pool"].to_json() == {"total_bytes_limit": 100_000_000} - - # -- functional: data still loads ---------------------------------------- - - def test_data_loads_with_bounded_cache(self, dataset): - """A bounded-cache loader can still produce a valid batch.""" + def test_no_weighted_sampler_val(self, tmp_path): + ds = _make_ds(tmp_path) loader = CellMapDataLoader( - dataset, batch_size=2, num_workers=0, tensorstore_cache_bytes=50_000_000 - ) - batch = next(iter(loader)) - assert "raw" in batch - assert isinstance(batch["raw"], torch.Tensor) - assert batch["raw"].shape[0] == 2 - - -class TestRefreshCleanup: - """Tests for refresh() explicitly releasing the old DataLoader. - - Before the fix, calling refresh() while persistent_workers=True left old - worker processes alive until GC, causing process accumulation across epochs. - After the fix, the old loader is explicitly deleted before creating a new one. - """ - - @pytest.fixture - def loader(self, tmp_path): - config = create_test_dataset( - tmp_path, - raw_shape=(24, 24, 24), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - ) - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - return CellMapDataLoader(dataset, batch_size=2, num_workers=0) - - def test_refresh_replaces_loader(self, loader): - """refresh() must replace _pytorch_loader with a new object.""" - old_loader = loader._pytorch_loader - loader.refresh() - new_loader = loader._pytorch_loader - - assert new_loader is not None - assert new_loader is not old_loader - - def test_refresh_multiple_times_does_not_accumulate(self, loader): - """Calling refresh() repeatedly must not leave stale references.""" - import gc - import weakref - - # Hold a weak reference to the first loader - old_loader = loader._pytorch_loader - ref = weakref.ref(old_loader) - del old_loader - - loader.refresh() - - # Force GC to collect any reference cycles - gc.collect() - - # The old loader should have been released - assert ref() is None, ( - "Old DataLoader was not released after refresh(); " - "worker processes may be leaking." - ) - - def test_refresh_with_persistent_workers_releases_old_loader(self, tmp_path): - """With persistent_workers=True, refresh() must still release the old loader.""" - import gc - import weakref - - config = create_test_dataset( - tmp_path, - raw_shape=(24, 24, 24), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - ) - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - # num_workers=1 makes persistent_workers meaningful - persistent_loader = CellMapDataLoader( - dataset, + ds, + classes=CLASSES, batch_size=2, - num_workers=1, - persistent_workers=True, - ) - - old_pytorch_loader = persistent_loader._pytorch_loader - ref = weakref.ref(old_pytorch_loader) - del old_pytorch_loader - - persistent_loader.refresh() - gc.collect() - - assert ref() is None, ( - "Old DataLoader with persistent_workers=True was not released after " - "refresh(); worker processes are leaking." + is_train=False, + weighted_sampler=True, ) - - def test_refresh_loader_is_functional(self, loader): - """After refresh(), the new loader must still produce valid batches.""" - loader.refresh() - batch = next(iter(loader)) - assert "raw" in batch - assert isinstance(batch["raw"], torch.Tensor) + # weighted_sampler is only used when is_train=True + assert loader._sampler is None + + def test_collate_fn_stacks_tensors(self): + batch = [ + {"idx": torch.tensor(0), "raw": torch.zeros(4, 4, 4)}, + {"idx": torch.tensor(1), "raw": torch.ones(4, 4, 4)}, + ] + result = CellMapDataLoader.collate_fn(batch) + assert result["raw"].shape == torch.Size([2, 4, 4, 4]) + assert result["idx"].shape == torch.Size([2]) + + def test_repr(self, tmp_path): + ds = _make_ds(tmp_path) + loader = CellMapDataLoader(ds, classes=CLASSES, batch_size=1, is_train=False) + r = repr(loader) + assert "CellMapDataLoader" in r diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..da1e8a9 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,217 @@ +"""Tests for CellMapDataset.""" + +from __future__ import annotations + +import numpy as np +import torch + +from cellmap_data import CellMapDataset +from cellmap_data.empty_image import EmptyImage +from cellmap_data.image import CellMapImage + +from .test_helpers import create_test_dataset + +INPUT_ARRAYS = {"raw": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} +TARGET_ARRAYS = {"labels": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} + + +class TestCellMapDataset: + def test_init(self, tmp_path): + info = create_test_dataset(tmp_path, classes=["mito", "er"]) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + assert ds.classes == ["mito", "er"] + assert "raw" in ds.input_sources + assert "mito" in ds.target_sources + assert "er" in ds.target_sources + + def test_missing_class_is_empty_image(self, tmp_path): + info = create_test_dataset(tmp_path, classes=["mito"]) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=["mito", "er"], # er not annotated + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + assert isinstance(ds.target_sources["mito"], CellMapImage) + assert isinstance(ds.target_sources["er"], EmptyImage) + + def test_len_positive(self, tmp_path): + info = create_test_dataset(tmp_path) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + assert len(ds) > 0 + + def test_getitem_returns_dict_with_idx(self, tmp_path): + info = create_test_dataset(tmp_path) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + pad=True, + ) + item = ds[0] + assert "idx" in item + assert item["idx"].item() == 0 + + def test_getitem_raw_is_tensor(self, tmp_path): + info = create_test_dataset(tmp_path) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + pad=True, + ) + item = ds[0] + assert isinstance(item["raw"], torch.Tensor) + assert item["raw"].shape == torch.Size([4, 4, 4]) + + def test_getitem_missing_class_nan(self, tmp_path): + info = create_test_dataset(tmp_path, classes=["mito"]) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=["mito", "er"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + pad=True, + ) + item = ds[0] + # unannotated class → NaN + assert torch.isnan(item["er"]).all() + # annotated class → not all NaN + assert not torch.isnan(item["mito"]).all() + + def test_get_crop_class_matrix_shape(self, tmp_path): + info = create_test_dataset(tmp_path, classes=["mito"]) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=["mito", "er"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + mat = ds.get_crop_class_matrix() + assert mat.shape == (1, 2) + # mito is annotated (True), er is not (False) + assert mat[0, 0] == True + assert mat[0, 1] == False + + def test_get_indices_non_empty(self, tmp_path): + info = create_test_dataset(tmp_path) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + chunk_size = {"z": 32.0, "y": 32.0, "x": 32.0} + indices = ds.get_indices(chunk_size) + assert len(indices) > 0 + + def test_verify(self, tmp_path): + info = create_test_dataset(tmp_path) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + assert ds.verify() + + def test_bounding_box(self, tmp_path): + info = create_test_dataset( + tmp_path, shape=(32, 32, 32), voxel_size=[8.0, 8.0, 8.0] + ) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + bb = ds.bounding_box + assert bb is not None + assert set(bb.keys()) == {"z", "y", "x"} + + def test_repr(self, tmp_path): + info = create_test_dataset(tmp_path) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + r = repr(ds) + assert "CellMapDataset" in r + + def test_class_counts(self, tmp_path): + info = create_test_dataset(tmp_path) + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + counts = ds.class_counts + assert "totals" in counts + assert all(c in counts["totals"] for c in info["classes"]) + + def test_spatial_transforms_mirror(self, tmp_path): + """Mirror spatial transform → item differs from un-transformed.""" + info = create_test_dataset(tmp_path, shape=(32, 32, 32)) + ds_plain = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + pad=True, + ) + ds_mirror = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + pad=True, + spatial_transforms={"mirror": {"z": True, "y": False, "x": False}}, + ) + # Mirror is random; with always-true z mirror, result differs from original + raw_plain = ds_plain[0]["raw"] + raw_mirrored = ds_mirror[0]["raw"] + # They may or may not match depending on RNG, but shapes must match + assert raw_plain.shape == raw_mirrored.shape diff --git a/tests/test_dataset_edge_cases.py b/tests/test_dataset_edge_cases.py deleted file mode 100644 index 147ad27..0000000 --- a/tests/test_dataset_edge_cases.py +++ /dev/null @@ -1,471 +0,0 @@ -"""Tests for CellMapDataset edge cases and special methods.""" - -import pickle - -import numpy as np -import pytest -import torch - -from cellmap_data import CellMapDataset, CellMapMultiDataset - -from .test_helpers import create_minimal_test_dataset - - -class TestCellMapDatasetEdgeCases: - """Test edge cases and special methods in CellMapDataset.""" - - @pytest.fixture - def minimal_dataset(self, tmp_path): - """Create a minimal dataset for testing.""" - config = create_minimal_test_dataset(tmp_path) - - input_arrays = { - "raw": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - target_arrays = { - "gt": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - dataset = CellMapDataset( - raw_path=str(config["raw_path"]), - target_path=str(config["gt_path"]), - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - - return dataset, config - - def test_pickle_support(self, minimal_dataset): - """Test that dataset can be pickled and unpickled.""" - dataset, _ = minimal_dataset - - # Pickle the dataset - pickled = pickle.dumps(dataset) - - # Unpickle the dataset - unpickled = pickle.loads(pickled) - - # Verify properties are preserved - assert unpickled.raw_path == dataset.raw_path - assert unpickled.target_path == dataset.target_path - assert unpickled.classes == dataset.classes - assert unpickled.input_arrays == dataset.input_arrays - assert unpickled.target_arrays == dataset.target_arrays - - def test_del_method_cleanup(self, minimal_dataset): - """Test that __del__ properly cleans up the executor.""" - dataset, _ = minimal_dataset - - # Access executor to force initialization - _ = dataset.executor - - # Verify executor exists - assert dataset._executor is not None - - # Delete dataset should trigger cleanup - del dataset - - # No exception should be raised - assert True - - def test_executor_property_lazy_init(self, minimal_dataset): - """Test that executor is lazily initialized.""" - dataset, _ = minimal_dataset - - # Initially, executor should not be initialized - assert dataset._executor is None - - # Access executor property - executor = dataset.executor - - # Now it should be initialized - assert executor is not None - assert dataset._executor is not None - - # Accessing again should return same instance - executor2 = dataset.executor - assert executor is executor2 - - def test_executor_handles_fork(self, minimal_dataset): - """Test that executor is recreated after process fork.""" - dataset, _ = minimal_dataset - - # Access executor - _ = dataset.executor - original_pid = dataset._executor_pid - - # Simulate a fork by changing the PID tracking - import os - - dataset._executor_pid = os.getpid() + 1 - - # Access executor again - should create new one - _ = dataset.executor - - # PID should be updated - assert dataset._executor_pid == os.getpid() - - def test_center_property(self, minimal_dataset): - """Test the center property calculation.""" - dataset, _ = minimal_dataset - - center = dataset.center - - # Center should be a dict with axis keys - assert isinstance(center, dict) - for axis in dataset.axis_order: - assert axis in center - assert isinstance(center[axis], (int, float)) - - def test_largest_voxel_sizes_property(self, minimal_dataset): - """Test the largest_voxel_sizes property.""" - dataset, _ = minimal_dataset - - voxel_sizes = dataset.largest_voxel_sizes - - # Should be a dict with axis keys - assert isinstance(voxel_sizes, dict) - for axis in dataset.axis_order: - assert axis in voxel_sizes - assert voxel_sizes[axis] > 0 - - def test_bounding_box_property(self, minimal_dataset): - """Test the bounding_box property.""" - dataset, _ = minimal_dataset - - bbox = dataset.bounding_box - - # Should be a dict mapping axes to [min, max] pairs - assert isinstance(bbox, dict) - for axis in dataset.axis_order: - assert axis in bbox - assert len(bbox[axis]) == 2 - assert bbox[axis][0] <= bbox[axis][1] - - def test_sampling_box_property(self, minimal_dataset): - """Test the sampling_box property.""" - dataset, _ = minimal_dataset - - sbox = dataset.sampling_box - - # Should be a dict mapping axes to [min, max] pairs - assert isinstance(sbox, dict) - for axis in dataset.axis_order: - assert axis in sbox - assert len(sbox[axis]) == 2 - - def test_sampling_box_shape_property(self, minimal_dataset): - """Test the sampling_box_shape property.""" - dataset, _ = minimal_dataset - - shape = dataset.sampling_box_shape - - # Should be a dict mapping axes to integer sizes - assert isinstance(shape, dict) - for axis in dataset.axis_order: - assert axis in shape - assert isinstance(shape[axis], int) - assert shape[axis] > 0 - - def test_device_property_auto_selection(self, minimal_dataset): - """Test device property auto-selects appropriate device.""" - dataset, _ = minimal_dataset - - device = dataset.device - - # Should be a torch device - assert isinstance(device, torch.device) - # Should be one of the expected types - assert device.type in ["cpu", "cuda", "mps"] - - def test_negative_index_handling(self, minimal_dataset): - """Test that negative indices are handled correctly.""" - dataset, _ = minimal_dataset - - # Try to get item with negative index - item = dataset[-1] - - # Should return a valid item - assert isinstance(item, dict) - assert "raw" in item - - def test_out_of_bounds_index_handling(self, minimal_dataset): - """Test that out of bounds indices are handled gracefully.""" - dataset, _ = minimal_dataset - - # Try an index way out of bounds - large_idx = len(dataset) * 10 - - # Should not raise, but may log warning - item = dataset[large_idx] - - # Should still return a valid item (clamped to bounds) - assert isinstance(item, dict) - - def test_class_counts_property(self, minimal_dataset): - """Test the class_counts property.""" - dataset, _ = minimal_dataset - - counts = dataset.class_counts - - # Should be a dict - assert isinstance(counts, dict) - # class_counts structure has changed - it's now nested with 'totals' - # Check that the totals key exists and has class entries - if "totals" in counts: - for cls in dataset.classes: - # Class names might have _bg suffix - assert any(cls in key for key in counts["totals"].keys()) - else: - # Old structure - direct class keys - for cls in dataset.classes: - assert cls in counts - - def test_class_weights_property(self, minimal_dataset): - """Test the class_weights property.""" - dataset, _ = minimal_dataset - - weights = dataset.class_weights - - # Should be a dict - assert isinstance(weights, dict) - # Should have entries for each class - for cls in dataset.classes: - assert cls in weights - assert isinstance(weights[cls], (int, float)) - assert 0 <= weights[cls] <= 1 - - def test_validation_indices_property(self, minimal_dataset): - """Test the validation_indices property.""" - dataset, _ = minimal_dataset - - indices = dataset.validation_indices - - # Should be a sequence - assert hasattr(indices, "__iter__") - - def test_2d_array_creates_multidataset(self, tmp_path): - """Test that 2D array without slicing axis triggers special handling.""" - config = create_minimal_test_dataset(tmp_path) - - # Create 2D array configuration (shape has a 1 in it) - # Note: The actual behavior may depend on how is_array_2D is implemented - input_arrays = { - "raw": { - "shape": (1, 8, 8), # 2D array - "scale": (4.0, 4.0, 4.0), - } - } - - target_arrays = { - "gt": { - "shape": (1, 8, 8), # 2D array - "scale": (4.0, 4.0, 4.0), - } - } - - # Creating dataset with 2D arrays may create multidataset or regular dataset - # depending on implementation details - dataset = CellMapDataset( - raw_path=str(config["raw_path"]), - target_path=str(config["gt_path"]), - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - - # Should create some kind of dataset (either regular or multi) - # The key is that it doesn't raise an error - assert dataset is not None - assert hasattr(dataset, "__getitem__") - - def test_set_raw_value_transforms(self, minimal_dataset): - """Test setting raw value transforms.""" - dataset, _ = minimal_dataset - - transform = lambda x: x * 2 - dataset.set_raw_value_transforms(transform) - - # Should not raise - assert True - - def test_set_target_value_transforms(self, minimal_dataset): - """Test setting target value transforms.""" - dataset, _ = minimal_dataset - - transform = lambda x: x * 0.5 - dataset.set_target_value_transforms(transform) - - # Should not raise - assert True - - def test_to_device_method(self, minimal_dataset): - """Test moving dataset to device.""" - dataset, _ = minimal_dataset - - # Move to CPU explicitly - result = dataset.to("cpu") - - # Should return self - assert result is dataset - assert dataset.device.type == "cpu" - - def test_get_random_subset_indices(self, minimal_dataset): - """Test getting random subset indices.""" - dataset, _ = minimal_dataset - - num_samples = 5 - indices = dataset.get_random_subset_indices(num_samples) - - # Should return list of indices - assert len(indices) == num_samples - for idx in indices: - assert 0 <= idx < len(dataset) - - def test_get_subset_random_sampler(self, minimal_dataset): - """Test creating a subset random sampler.""" - dataset, _ = minimal_dataset - - num_samples = 5 - sampler = dataset.get_subset_random_sampler(num_samples) - - # Should create a sampler - assert sampler is not None - # Should be iterable - indices = list(sampler) - assert len(indices) == num_samples - - -class TestProcessExecutorSingleton: - """Tests for the per-process shared ThreadPoolExecutor. - - Before the fix, each CellMapDataset created its own ThreadPoolExecutor, - causing thread explosion when many datasets exist inside DataLoader workers. - After the fix, all datasets in a process share one pool. - """ - - @pytest.fixture - def two_datasets(self, tmp_path): - """Create two independent datasets in the same process.""" - configs = [] - datasets = [] - for i in range(2): - cfg = create_minimal_test_dataset(tmp_path / f"ds{i}") - configs.append(cfg) - datasets.append( - CellMapDataset( - raw_path=str(cfg["raw_path"]), - target_path=str(cfg["gt_path"]), - classes=cfg["classes"], - input_arrays={ - "raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)} - }, - target_arrays={ - "gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)} - }, - force_has_data=True, - ) - ) - return datasets - - def test_executor_is_shared_across_datasets(self, two_datasets): - """Two datasets in the same process must return the exact same executor object.""" - ds0, ds1 = two_datasets - assert ds0.executor is ds1.executor - - def test_executor_is_module_level_singleton(self, two_datasets): - """The executor must live in the module-level _PROCESS_EXECUTORS dict.""" - import os - - from cellmap_data.dataset import _PROCESS_EXECUTORS - - ds0, _ = two_datasets - pid = os.getpid() - assert pid in _PROCESS_EXECUTORS - assert _PROCESS_EXECUTORS[pid] is ds0.executor - - def test_close_does_not_shut_down_shared_pool(self, two_datasets): - """close() on one dataset must not prevent other datasets from using the pool.""" - ds0, ds1 = two_datasets - - # Trigger executor creation on both - _ = ds0.executor - _ = ds1.executor - - # Close the first dataset - ds0.close() - assert ds0._executor is None - - # The second dataset must still be able to submit work - future = ds1.executor.submit(lambda: 42) - assert future.result() == 42 - - def test_del_does_not_shut_down_shared_pool(self, two_datasets): - """__del__ on one dataset must not prevent other datasets from using the pool.""" - ds0, ds1 = two_datasets - _ = ds0.executor - executor_ref = ds1.executor - - ds0.__del__() - assert ds0._executor is None - - # Pool is still operational - future = executor_ref.submit(lambda: "alive") - assert future.result() == "alive" - - def test_executor_lazy_init(self, tmp_path): - """Executor must not be created until first access via the property.""" - from cellmap_data.dataset import _PROCESS_EXECUTORS - - import os - - cfg = create_minimal_test_dataset(tmp_path / "lazy") - # Clear any existing entry so the laziness is observable - _PROCESS_EXECUTORS.pop(os.getpid(), None) - - ds = CellMapDataset( - raw_path=str(cfg["raw_path"]), - target_path=str(cfg["gt_path"]), - classes=cfg["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - - assert ds._executor is None # not yet created - _ = ds.executor # trigger lazy init - assert os.getpid() in _PROCESS_EXECUTORS - - def test_pid_change_triggers_new_executor(self, two_datasets): - """Simulating a PID change (post-fork child) causes a fresh executor lookup.""" - import os - from unittest.mock import patch - - from cellmap_data.dataset import _PROCESS_EXECUTORS - - ds0, _ = two_datasets - original_executor = ds0.executor # ensure cached - fake_pid = os.getpid() + 99999 # a PID that isn't in the dict - - with patch("cellmap_data.dataset.os.getpid", return_value=fake_pid): - # Force re-evaluation by clearing the cached pid - ds0._executor_pid = None - new_executor = ds0.executor - - # A new entry was created for the fake PID - assert fake_pid in _PROCESS_EXECUTORS - # The new executor is different from the original process's executor - assert new_executor is not original_executor - - # Cleanup: remove the fake PID entry - _PROCESS_EXECUTORS.pop(fake_pid, None) diff --git a/tests/test_dataset_writer.py b/tests/test_dataset_writer.py deleted file mode 100644 index 2bd8dcb..0000000 --- a/tests/test_dataset_writer.py +++ /dev/null @@ -1,580 +0,0 @@ -""" -Tests for CellMapDatasetWriter class. - -Tests writing predictions and outputs using real data. -""" - -import pytest -import torchvision.transforms.v2 as T -import torch - -from cellmap_data import CellMapDatasetWriter - -from .test_helpers import create_test_dataset - - -class TestCellMapDatasetWriter: - """Test suite for CellMapDatasetWriter class.""" - - @pytest.fixture - def writer_config(self, tmp_path): - """Create configuration for writer tests.""" - # Create input data - input_config = create_test_dataset( - tmp_path / "input", - raw_shape=(64, 64, 64), - num_classes=2, - raw_scale=(8.0, 8.0, 8.0), - ) - - # Output path - output_path = tmp_path / "output" / "predictions.zarr" - - return { - "input_config": input_config, - "output_path": str(output_path), - } - - def test_initialization_basic(self, writer_config): - """Test basic DatasetWriter initialization.""" - config = writer_config["input_config"] - - input_arrays = {"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}} - target_arrays = { - "predictions": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)} - } - - target_bounds = { - "predictions": { - "x": [0, 256], - "y": [0, 256], - "z": [0, 256], - } - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0", "class_1"], - input_arrays=input_arrays, - target_arrays=target_arrays, - target_bounds=target_bounds, - ) - - assert writer is not None - assert writer.raw_path == config["raw_path"] - assert writer.target_path == writer_config["output_path"] - - def test_classes_parameter(self, writer_config): - """Test classes parameter.""" - config = writer_config["input_config"] - - classes = ["class_0", "class_1", "class_2"] - - target_bounds = { - "pred": { - "x": [0, 128], - "y": [0, 128], - "z": [0, 128], - } - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=classes, - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_bounds=target_bounds, - ) - - assert writer.classes == classes - - def test_input_arrays_configuration(self, writer_config): - """Test input arrays configuration.""" - config = writer_config["input_config"] - - input_arrays = { - "raw_4nm": {"shape": (32, 32, 32), "scale": (4.0, 4.0, 4.0)}, - "raw_8nm": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}, - } - - target_bounds = { - "pred": { - "x": [0, 128], - "y": [0, 128], - "z": [0, 128], - } - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0"], - input_arrays=input_arrays, - target_arrays={"pred": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_bounds=target_bounds, - ) - - assert "raw_4nm" in writer.input_arrays - assert "raw_8nm" in writer.input_arrays - - def test_target_arrays_configuration(self, writer_config): - """Test target arrays configuration.""" - config = writer_config["input_config"] - - target_arrays = { - "predictions": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}, - "confidences": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}, - } - - target_bounds = { - "predictions": { - "x": [0, 256], - "y": [0, 256], - "z": [0, 256], - }, - "confidences": { - "x": [0, 256], - "y": [0, 256], - "z": [0, 256], - }, - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0"], - input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_arrays=target_arrays, - target_bounds=target_bounds, - ) - - assert "predictions" in writer.target_arrays - assert "confidences" in writer.target_arrays - - def test_target_bounds_parameter(self, writer_config): - """Test target bounds parameter.""" - config = writer_config["input_config"] - - target_bounds = { - "pred": { - "x": [0, 512], - "y": [0, 512], - "z": [0, 64], - } - } - - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_bounds=target_bounds, - ) - - assert writer is not None - - def test_axis_order_parameter(self, writer_config): - """Test axis order parameter.""" - config = writer_config["input_config"] - - target_bounds = { - "pred": { - "x": [0, 128], - "y": [0, 128], - "z": [0, 128], - } - } - for axis_order in ["zyx", "xyz", "yxz"]: - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={ - "pred": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)} - }, - axis_order=axis_order, - target_bounds=target_bounds, - ) - assert writer.axis_order == axis_order - - def test_pad_parameter(self, writer_config): - """Test pad parameter.""" - config = writer_config["input_config"] - - target_bounds = { - "pred": { - "x": [0, 128], - "y": [0, 128], - "z": [0, 128], - } - } - writer_pad = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_bounds=target_bounds, - ) - assert writer_pad.input_sources["raw"].pad is True - - writer_no_pad = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_bounds=target_bounds, - ) - assert writer_no_pad.input_sources["raw"].pad is True - - def test_device_parameter(self, writer_config): - """Test device parameter.""" - config = writer_config["input_config"] - - target_bounds = { - "pred": { - "x": [0, 128], - "y": [0, 128], - "z": [0, 128], - } - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - device="cpu", - target_bounds=target_bounds, - ) - - assert writer is not None - - def test_context_parameter(self, writer_config): - """Test TensorStore context parameter.""" - import tensorstore as ts - - config = writer_config["input_config"] - context = ts.Context() - - target_bounds = { - "pred": { - "x": [0, 128], - "y": [0, 128], - "z": [0, 128], - } - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - context=context, - target_bounds=target_bounds, - ) - - assert writer.context is context - - def test_n_channels_in_target_arrays(self, writer_config): - """Test n_channels parameter in target arrays configuration.""" - config = writer_config["input_config"] - - # Test with n_channels to create multi-channel output - target_arrays = { - "predictions": { - "shape": (32, 32, 32), - "scale": (8.0, 8.0, 8.0), - "n_channels": 3, - } - } - - target_bounds = { - "predictions": { - "x": [0, 256], - "y": [0, 256], - "z": [0, 256], - } - } - - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0"], - input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_arrays=target_arrays, - target_bounds=target_bounds, - ) - - # Verify that the channel axis was added - assert "c" in writer.axis_order - assert writer.axis_order.startswith("c") - - def test_n_channels_with_existing_channel_axis(self, writer_config): - """Test n_channels parameter when channel axis already exists.""" - config = writer_config["input_config"] - - target_arrays = { - "predictions": { - "shape": (32, 32, 32), - "scale": (8.0, 8.0, 8.0), - "n_channels": 4, - } - } - - target_bounds = { - "predictions": { - "x": [0, 256], - "y": [0, 256], - "z": [0, 256], - } - } - - # Test with explicit channel axis in axis_order - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=writer_config["output_path"], - classes=["class_0"], - input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_arrays=target_arrays, - axis_order="czyx", - target_bounds=target_bounds, - ) - - # Verify channel axis is present and not duplicated - assert writer.axis_order == "czyx" - assert writer.axis_order.count("c") == 1 - - -class TestWriterOperations: - """Test writer operations and functionality.""" - - def test_writer_with_value_transforms(self, tmp_path): - """Test writer with value transforms.""" - config = create_test_dataset( - tmp_path / "input", - raw_shape=(32, 32, 32), - num_classes=2, - ) - - output_path = tmp_path / "output.zarr" - - raw_transform = T.ToDtype(torch.float, scale=True) - - target_bounds = { - "pred": { - "x": [0, 128], - "y": [0, 128], - "z": [0, 128], - } - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=str(output_path), - classes=["class_0"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - raw_value_transforms=raw_transform, - target_bounds=target_bounds, - ) - - assert writer.raw_value_transforms is not None - - def test_writer_different_input_output_shapes(self, tmp_path): - """Test writer with different input and output shapes.""" - config = create_test_dataset( - tmp_path / "input", - raw_shape=(64, 64, 64), - num_classes=2, - ) - - output_path = tmp_path / "output.zarr" - - # Input larger than output - target_bounds = { - "pred": { - "x": [0, 128], - "y": [0, 128], - "z": [0, 128], - } - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=str(output_path), - classes=["class_0"], - input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_bounds=target_bounds, - ) - - assert writer.input_arrays["raw"]["shape"] == (32, 32, 32) - assert writer.target_arrays["pred"]["shape"] == (16, 16, 16) - - def test_writer_anisotropic_resolution(self, tmp_path): - """Test writer with anisotropic voxel sizes.""" - config = create_test_dataset( - tmp_path / "input", - raw_shape=(32, 64, 64), - raw_scale=(16.0, 4.0, 4.0), - num_classes=2, - ) - - output_path = tmp_path / "output.zarr" - - target_bounds = { - "pred": { - "x": [0, 128], - "y": [0, 256], - "z": [0, 512], - } - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=str(output_path), - classes=["class_0"], - input_arrays={"raw": {"shape": (16, 32, 32), "scale": (16.0, 4.0, 4.0)}}, - target_arrays={"pred": {"shape": (16, 32, 32), "scale": (16.0, 4.0, 4.0)}}, - target_bounds=target_bounds, - ) - - assert writer.input_arrays["raw"]["scale"] == (16.0, 4.0, 4.0) - - -class TestWriterIntegration: - """Integration tests for writer functionality.""" - - def test_writer_prediction_workflow(self, tmp_path): - """Test complete prediction writing workflow.""" - # Create input data - config = create_test_dataset( - tmp_path / "input", - raw_shape=(64, 64, 64), - num_classes=2, - ) - - output_path = tmp_path / "predictions.zarr" - - # Create writer - target_bounds = { - "pred": { - "x": [0, 512], - "y": [0, 512], - "z": [0, 512], - } - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=str(output_path), - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_bounds=target_bounds, - ) - - # Writer should be ready - assert writer is not None - - def test_writer_with_bounds(self, tmp_path): - """Test writer with specific spatial bounds.""" - config = create_test_dataset( - tmp_path / "input", - raw_shape=(128, 128, 128), - num_classes=2, - ) - - output_path = tmp_path / "predictions.zarr" - - # Only write to specific region - target_bounds = { - "pred": { - "x": [32, 96], - "y": [32, 96], - "z": [0, 64], - } - } - - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=str(output_path), - classes=["class_0"], - input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_bounds=target_bounds, - ) - - assert writer is not None - - def test_multi_output_writer(self, tmp_path): - """Test writer with multiple output arrays.""" - config = create_test_dataset( - tmp_path / "input", - raw_shape=(64, 64, 64), - num_classes=3, - ) - - output_path = tmp_path / "predictions.zarr" - - # Multiple outputs - target_arrays = { - "predictions": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}, - "uncertainties": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}, - "embeddings": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}, - } - - target_bounds = { - "predictions": { - "x": [0, 512], - "y": [0, 512], - "z": [0, 512], - }, - "uncertainties": { - "x": [0, 512], - "y": [0, 512], - "z": [0, 512], - }, - "embeddings": { - "x": [0, 512], - "y": [0, 512], - "z": [0, 512], - }, - } - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=str(output_path), - classes=["class_0", "class_1", "class_2"], - input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_arrays=target_arrays, - target_bounds=target_bounds, - ) - - assert len(writer.target_arrays) == 3 - - def test_writer_2d_output(self, tmp_path): - """Test writer for 2D outputs.""" - # Create 2D input data - from .test_helpers import create_test_image_data, create_test_zarr_array - - input_path = tmp_path / "input_2d.zarr" - data_2d = create_test_image_data((128, 128), pattern="gradient") - create_test_zarr_array(input_path, data_2d, axes=("y", "x"), scale=(4.0, 4.0)) - - output_path = tmp_path / "output_2d.zarr" - - target_bounds = { - "pred": { - "x": [0, 512], - "y": [0, 512], - } - } - writer = CellMapDatasetWriter( - raw_path=str(input_path), - target_path=str(output_path), - classes=["class_0"], - input_arrays={"raw": {"shape": (64, 64), "scale": (4.0, 4.0)}}, - target_arrays={"pred": {"shape": (64, 64), "scale": (4.0, 4.0)}}, - axis_order="yx", - target_bounds=target_bounds, - ) - - assert writer.axis_order == "yx" diff --git a/tests/test_dataset_writer_batch.py b/tests/test_dataset_writer_batch.py deleted file mode 100644 index fc36e44..0000000 --- a/tests/test_dataset_writer_batch.py +++ /dev/null @@ -1,209 +0,0 @@ -""" -Tests for CellMapDatasetWriter batch operations. - -Tests that the writer correctly handles batched write operations. -""" - -import numpy as np -import pytest -import torch - -from cellmap_data import CellMapDatasetWriter - -from .test_helpers import create_test_dataset - - -class TestDatasetWriterBatchOperations: - """Test suite for batch write operations in DatasetWriter.""" - - @pytest.fixture - def writer_setup(self, tmp_path): - """Create writer and config for batch write tests. - - Returns a tuple of (writer, config) where writer is a CellMapDatasetWriter - configured for testing batch operations. - """ - # Create input data - config = create_test_dataset( - tmp_path / "input", - raw_shape=(64, 64, 64), - num_classes=2, - raw_scale=(8.0, 8.0, 8.0), - ) - - # Output path - output_path = tmp_path / "output" / "predictions.zarr" - - target_bounds = { - "pred": { - "x": [0, 512], - "y": [0, 512], - "z": [0, 512], - } - } - - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=str(output_path), - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"pred": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_bounds=target_bounds, - ) - - return writer, config - - def test_batch_write_with_tensor_indices(self, writer_setup): - """Test writing with a batch of tensor indices.""" - writer, config = writer_setup - - # Simulate batch predictions - batch_size = 8 - indices = torch.tensor(list(range(batch_size))) - - # Create predictions with shape (batch_size, num_classes, *spatial_dims) - predictions = torch.randn(batch_size, 2, 32, 32, 32) - - # This should not raise an error - writer[indices] = {"pred": predictions} - - def test_batch_write_with_numpy_indices(self, writer_setup): - """Test writing with a batch of numpy indices.""" - writer, config = writer_setup - - # Simulate batch predictions - batch_size = 4 - indices = np.array(list(range(batch_size))) - - # Create predictions - predictions = np.random.randn(batch_size, 2, 32, 32, 32).astype(np.float32) - - # This should not raise an error - writer[indices] = {"pred": predictions} - - def test_batch_write_with_list_indices(self, writer_setup): - """Test writing with a batch of list indices.""" - writer, config = writer_setup - - # Simulate batch predictions - batch_size = 4 - indices = [0, 1, 2, 3] - - # Create predictions - predictions = torch.randn(batch_size, 2, 32, 32, 32) - - # This should not raise an error - writer[indices] = {"pred": predictions} - - def test_batch_write_large_batch(self, writer_setup): - """Test writing with a large batch size (as in the error case).""" - writer, config = writer_setup - - # Simulate the error case: batch_size=32 - batch_size = 32 - indices = torch.tensor(list(range(batch_size))) - - # Create predictions with shape (32, 2, 32, 32, 32) - predictions = torch.randn(batch_size, 2, 32, 32, 32) - - # This should not raise ValueError about shape mismatch - writer[indices] = {"pred": predictions} - - def test_batch_write_with_dict_arrays(self, writer_setup): - """Test writing with dictionary of arrays per class.""" - writer, config = writer_setup - - batch_size = 4 - indices = torch.tensor(list(range(batch_size))) - - # Create predictions as dictionary - predictions_dict = { - "class_0": torch.randn(batch_size, 32, 32, 32), - "class_1": torch.randn(batch_size, 32, 32, 32), - } - - # This should not raise an error - writer[indices] = {"pred": predictions_dict} - - def test_batch_write_2d_data(self, tmp_path): - """Test batch writing for 2D data (3D with singleton z dimension).""" - # Import kept at module level; reuse create_test_dataset here - - # Create test dataset with thin Z dimension to simulate 2D - config = create_test_dataset( - tmp_path / "input", - raw_shape=(1, 128, 128), # Thin z dimension - num_classes=1, - raw_scale=(8.0, 4.0, 4.0), - ) - - output_path = tmp_path / "output_2d.zarr" - - target_bounds = { - "pred": { - "z": [0, 8], - "y": [0, 512], - "x": [0, 512], - } - } - - writer = CellMapDatasetWriter( - raw_path=config["raw_path"], - target_path=str(output_path), - classes=["class_0"], - input_arrays={"raw": {"shape": (1, 64, 64), "scale": (8.0, 4.0, 4.0)}}, - target_arrays={"pred": {"shape": (1, 64, 64), "scale": (8.0, 4.0, 4.0)}}, - axis_order="zyx", - target_bounds=target_bounds, - ) - - # Test batch write with thin-z 3D data - batch_size = 4 - indices = torch.tensor(list(range(batch_size))) - predictions = torch.randn(batch_size, 1, 1, 64, 64) - - # This should not raise an error - writer[indices] = {"pred": predictions} - - def test_single_item_write_still_works(self, writer_setup): - """Test that single item writes still work correctly.""" - writer, config = writer_setup - - # Single item write - idx = 0 - predictions = torch.randn(2, 32, 32, 32) - - # This should work as before - writer[idx] = {"pred": predictions} - - def test_batch_write_with_scalar_values(self, writer_setup): - """Test batch writing with scalar values fills all spatial dims.""" - writer, config = writer_setup - - batch_size = 4 - indices = torch.tensor(list(range(batch_size))) - - # Scalar values should be broadcast to full arrays - # Create proper shaped arrays filled with the scalar value - scalar_val = 0.5 - predictions = torch.full((batch_size, 2, 32, 32, 32), scalar_val) - writer[indices] = {"pred": predictions} - - def test_batch_write_mixed_data_types(self, writer_setup): - """Test batch writing preserves data types.""" - writer, config = writer_setup - - batch_size = 4 - indices = torch.tensor(list(range(batch_size))) - - # Test with different dtypes - predictions_float32 = torch.randn( - batch_size, 2, 32, 32, 32, dtype=torch.float32 - ) - writer[indices] = {"pred": predictions_float32} - - predictions_float64 = torch.randn( - batch_size, 2, 32, 32, 32, dtype=torch.float64 - ) - indices2 = torch.tensor(list(range(batch_size, batch_size * 2))) - writer[indices2] = {"pred": predictions_float64} diff --git a/tests/test_datasplit.py b/tests/test_datasplit.py new file mode 100644 index 0000000..2113710 --- /dev/null +++ b/tests/test_datasplit.py @@ -0,0 +1,139 @@ +"""Tests for CellMapDataSplit.""" + +from __future__ import annotations + +import csv +import os + +import torch + +from cellmap_data import CellMapDataSplit +from cellmap_data.multidataset import CellMapMultiDataset + +from .test_helpers import create_test_dataset + +INPUT_ARRAYS = {"raw": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} +TARGET_ARRAYS = {"labels": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} +CLASSES = ["mito", "er"] + + +def _make_split_from_dict(tmp_path): + train_info = create_test_dataset(tmp_path / "train", classes=CLASSES) + val_info = create_test_dataset(tmp_path / "val", classes=CLASSES) + dataset_dict = { + "train": [{"raw": train_info["raw_path"], "gt": train_info["gt_path"]}], + "validate": [{"raw": val_info["raw_path"], "gt": val_info["gt_path"]}], + } + return CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + dataset_dict=dataset_dict, + force_has_data=True, + ) + + +class TestCellMapDataSplit: + def test_init_from_dict(self, tmp_path): + split = _make_split_from_dict(tmp_path) + assert len(split.train_datasets) == 1 + assert len(split._validation_datasets) == 1 + + def test_train_datasets_combined_type(self, tmp_path): + split = _make_split_from_dict(tmp_path) + combined = split.train_datasets_combined + assert isinstance(combined, CellMapMultiDataset) + + def test_validation_datasets_combined_type(self, tmp_path): + split = _make_split_from_dict(tmp_path) + combined = split.validation_datasets_combined + assert isinstance(combined, CellMapMultiDataset) + + def test_validation_datasets_property(self, tmp_path): + split = _make_split_from_dict(tmp_path) + assert len(split.validation_datasets) == 1 + + def test_validation_blocks(self, tmp_path): + from torch.utils.data import Subset + + split = _make_split_from_dict(tmp_path) + blocks = split.validation_blocks + assert isinstance(blocks, Subset) + assert len(blocks) > 0 + + def test_class_counts_keys(self, tmp_path): + split = _make_split_from_dict(tmp_path) + counts = split.class_counts + assert "train" in counts + assert "validate" in counts + + def test_init_from_csv(self, tmp_path): + train_info = create_test_dataset(tmp_path / "train", classes=CLASSES) + val_info = create_test_dataset(tmp_path / "val", classes=CLASSES) + csv_path = str(tmp_path / "split.csv") + with open(csv_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["train", train_info["raw_path"], train_info["gt_path"]]) + w.writerow(["validate", val_info["raw_path"], val_info["gt_path"]]) + split = CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + csv_path=csv_path, + force_has_data=True, + ) + assert len(split.train_datasets) == 1 + assert len(split._validation_datasets) == 1 + + def test_set_raw_value_transforms(self, tmp_path): + split = _make_split_from_dict(tmp_path) + import torchvision.transforms.v2 as T + + tx = T.ToDtype(torch.float, scale=True) + split.set_raw_value_transforms(train_transforms=tx, val_transforms=tx) + # Check that train datasets have the new transform + for ds in split.train_datasets: + assert ds.raw_value_transforms is tx + + def test_invalidate_clears_combined(self, tmp_path): + split = _make_split_from_dict(tmp_path) + combined1 = split.train_datasets_combined + split._invalidate() + combined2 = split.train_datasets_combined + # After invalidation, a new CellMapMultiDataset is created + assert combined1 is not combined2 + + def test_repr(self, tmp_path): + split = _make_split_from_dict(tmp_path) + r = repr(split) + assert "CellMapDataSplit" in r + + def test_init_empty(self): + split = CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + ) + assert len(split.train_datasets) == 0 + assert len(split._validation_datasets) == 0 + + def test_init_from_datasets(self, tmp_path): + from cellmap_data import CellMapDataset + + train_info = create_test_dataset(tmp_path / "d1", classes=CLASSES) + ds = CellMapDataset( + raw_path=train_info["raw_path"], + target_path=train_info["gt_path"], + classes=CLASSES, + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + split = CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + datasets={"train": [ds], "validate": []}, + ) + assert len(split.train_datasets) == 1 + assert len(split._validation_datasets) == 0 diff --git a/tests/test_empty_image.py b/tests/test_empty_image.py new file mode 100644 index 0000000..b299406 --- /dev/null +++ b/tests/test_empty_image.py @@ -0,0 +1,48 @@ +"""Tests for EmptyImage.""" + +from __future__ import annotations + +import torch + +from cellmap_data.empty_image import EmptyImage + + +def test_empty_image_returns_nan(): + img = EmptyImage("fake/path", "mito", [8.0, 8.0, 8.0], [4, 4, 4]) + patch = img[{"z": 0.0, "y": 0.0, "x": 0.0}] + assert isinstance(patch, torch.Tensor) + assert patch.shape == torch.Size([4, 4, 4]) + assert torch.isnan(patch).all() + + +def test_empty_image_bounding_box_none(): + img = EmptyImage("fake/path", "er", [8.0, 8.0, 8.0], [4, 4, 4]) + assert img.bounding_box is None + assert img.sampling_box is None + + +def test_empty_image_class_counts_zero(): + img = EmptyImage("fake/path", "nucleus", [8.0, 8.0, 8.0], [4, 4, 4]) + assert img.class_counts == {"nucleus": 0} + + +def test_empty_image_set_spatial_transforms_noop(): + img = EmptyImage("fake/path", "mito", [8.0, 8.0, 8.0], [4, 4, 4]) + img.set_spatial_transforms({"mirror": {"z": True}}) # should not raise + patch = img[{"z": 0.0, "y": 0.0, "x": 0.0}] + assert torch.isnan(patch).all() + + +def test_empty_image_repr(): + img = EmptyImage("fake/path", "mito", [8.0, 8.0, 8.0], [4, 4, 4]) + r = repr(img) + assert "EmptyImage" in r + assert "mito" in r + + +def test_empty_image_clone(): + """Each call returns a fresh clone (not the same tensor).""" + img = EmptyImage("fake/path", "mito", [8.0, 8.0, 8.0], [4, 4, 4]) + p1 = img[{"z": 0.0, "y": 0.0, "x": 0.0}] + p2 = img[{"z": 0.0, "y": 0.0, "x": 0.0}] + assert p1 is not p2 diff --git a/tests/test_empty_image_writer.py b/tests/test_empty_image_writer.py deleted file mode 100644 index 2d733bf..0000000 --- a/tests/test_empty_image_writer.py +++ /dev/null @@ -1,393 +0,0 @@ -""" -Tests for EmptyImage and ImageWriter classes. - -Tests empty image handling and image writing functionality. -""" - -import os -from pathlib import Path - -import pytest -from upath import UPath - -from cellmap_data import EmptyImage, ImageWriter - -from .test_helpers import create_test_image_data, create_test_zarr_array - - -@pytest.fixture -def tmp_upath(tmp_path: Path): - """Return a temporary directory (as :class:`upathlib.UPath` object) - which is unique to each test function invocation. - The temporary directory is created as a subdirectory - of the base temporary directory, with configurable retention, - as discussed in :ref:`temporary directory location and retention`. - """ - return UPath(tmp_path) - - -class TestEmptyImage: - """Test suite for EmptyImage class.""" - - def test_initialization_basic(self): - """Test basic EmptyImage initialization.""" - empty_image = EmptyImage( - label_class="test_class", - scale=(8.0, 8.0, 8.0), - voxel_shape=(16, 16, 16), - axis_order="zyx", - ) - - assert empty_image.label_class == "test_class" - assert empty_image.scale == {"z": 8.0, "y": 8.0, "x": 8.0} - assert empty_image.output_shape == {"z": 16, "y": 16, "x": 16} - - def test_empty_image_shape(self): - """Test that EmptyImage has correct shape.""" - shape = (32, 32, 32) - empty_image = EmptyImage( - label_class="empty", - scale=(4.0, 4.0, 4.0), - voxel_shape=shape, - axis_order="zyx", - ) - - assert empty_image.output_shape == {"z": 32, "y": 32, "x": 32} - - def test_empty_image_2d(self): - """Test EmptyImage with 2D shape.""" - empty_image = EmptyImage( - label_class="empty_2d", - scale=(4.0, 4.0), - voxel_shape=(64, 64), - axis_order="yx", - ) - - assert empty_image.axes == "yx" - assert len(empty_image.output_shape) == 2 - - def test_empty_image_different_scales(self): - """Test EmptyImage with different scales per axis.""" - empty_image = EmptyImage( - label_class="anisotropic", - scale=(16.0, 4.0, 4.0), - voxel_shape=(16, 32, 32), - axis_order="zyx", - ) - - assert empty_image.scale == {"z": 16.0, "y": 4.0, "x": 4.0} - assert empty_image.output_size == {"z": 256.0, "y": 128.0, "x": 128.0} - - def test_empty_image_value_transform(self): - """Test EmptyImage with value transform.""" - - def dummy_transform(x): - return x * 2 - - empty_image = EmptyImage( - label_class="test", - scale=(4.0, 4.0, 4.0), - voxel_shape=(8, 8, 8), - ) - empty_image.value_transform = dummy_transform - - assert empty_image.value_transform is not None - - def test_empty_image_device(self): - """Test EmptyImage device assignment.""" - empty_image = EmptyImage( - label_class="test", - scale=(4.0, 4.0, 4.0), - voxel_shape=(8, 8, 8), - ) - empty_image.to("cpu") - - assert empty_image.store.device.type == "cpu" - - def test_empty_image_pad_parameter(self): - """Test EmptyImage with pad parameter.""" - empty_image = EmptyImage( - label_class="test", - scale=(4.0, 4.0, 4.0), - voxel_shape=(8, 8, 8), - ) - empty_image.pad = True - empty_image.pad_value = 0.0 - - assert empty_image.pad is True - assert empty_image.pad_value == 0.0 - - -class TestImageWriter: - """Test suite for ImageWriter class.""" - - @pytest.fixture - def output_path(self, tmp_upath): - """Create output path for writing.""" - return tmp_upath / "output.zarr" - - def test_image_writer_initialization(self, output_path): - """Test ImageWriter initialization.""" - writer = ImageWriter( - path=output_path.path, - target_class="output_class", - scale=(8.0, 8.0, 8.0), - write_voxel_shape=(32, 32, 32), - axis_order="zyx", - bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, - ) - - assert os.path.normpath(writer.path).endswith( - os.path.normpath(output_path.path + os.path.sep + "s0") - ) - assert writer.target_class == "output_class" - - def test_image_writer_with_existing_data(self, tmp_upath): - """Test ImageWriter with pre-existing data.""" - # Create existing zarr array - data = create_test_image_data((32, 32, 32), pattern="gradient") - path = tmp_upath / "existing.zarr" - create_test_zarr_array(path, data) - - # Create writer for same path - writer = ImageWriter( - path=path.path, - target_class="test", - scale=(4.0, 4.0, 4.0), - write_voxel_shape=(16, 16, 16), - bounding_box={"z": [0, 128], "y": [0, 128], "x": [0, 128]}, - ) - - assert os.path.normpath(writer.path).endswith( - os.path.normpath(path.path + os.path.sep + "s0") - ) - - def test_image_writer_different_shapes(self, tmp_upath): - """Test ImageWriter with different output shapes.""" - shapes = [(16, 16, 16), (32, 32, 32), (64, 32, 16)] - - for i, shape in enumerate(shapes): - path = tmp_upath / f"output_{i}.zarr" - writer = ImageWriter( - path=str(path), - target_class="test", - scale=(4.0, 4.0, 4.0), - write_voxel_shape=shape, - bounding_box={"z": [0, 256], "y": [0, 128], "x": [0, 64]}, - ) - - assert writer.write_voxel_shape == { - "z": shape[0], - "y": shape[1], - "x": shape[2], - } - - def test_image_writer_2d(self, tmp_upath): - """Test ImageWriter for 2D images.""" - path = tmp_upath / "output_2d.zarr" - writer = ImageWriter( - path=str(path), - target_class="test_2d", - scale=(4.0, 4.0), - write_voxel_shape=(64, 64), - axis_order="yx", - bounding_box={"y": [0, 256], "x": [0, 256]}, - ) - - assert writer.axes == "yx" - assert len(writer.write_voxel_shape) == 2 - - def test_image_writer_value_transform(self, tmp_upath): - """Test ImageWriter with value transform.""" - - def normalize(x): - return x / 255.0 - - path = tmp_upath / "output.zarr" - writer = ImageWriter( - path=str(path), - target_class="test", - scale=(4.0, 4.0, 4.0), - write_voxel_shape=(16, 16, 16), - bounding_box={"z": [0, 64], "y": [0, 64], "x": [0, 64]}, - ) - writer.value_transform = normalize - - assert writer.value_transform is not None - - def test_image_writer_interpolation(self, tmp_upath): - """Test ImageWriter with different interpolation modes.""" - for interp in ["nearest", "linear"]: - path = tmp_upath / f"output_{interp}.zarr" - writer = ImageWriter( - path=str(path), - target_class="test", - scale=(4.0, 4.0, 4.0), - write_voxel_shape=(16, 16, 16), - bounding_box={"z": [0, 64], "y": [0, 64], "x": [0, 64]}, - ) - writer.interpolation = interp - - assert writer.interpolation == interp - - def test_image_writer_anisotropic_scale(self, tmp_upath): - """Test ImageWriter with anisotropic voxel sizes.""" - path = tmp_upath / "anisotropic.zarr" - writer = ImageWriter( - path=str(path), - target_class="test", - scale=(16.0, 4.0, 4.0), # Anisotropic - write_voxel_shape=(16, 32, 32), - axis_order="zyx", - bounding_box={"z": [0, 256], "y": [0, 128], "x": [0, 128]}, - ) - - assert writer.scale == {"z": 16.0, "y": 4.0, "x": 4.0} - # Output size should account for scale - assert writer.write_world_shape == {"z": 256.0, "y": 128.0, "x": 128.0} - - def test_image_writer_context(self, tmp_upath): - """Test ImageWriter with TensorStore context.""" - import tensorstore as ts - - path = tmp_upath / "output.zarr" - context = ts.Context() - - writer = ImageWriter( - path=str(path), - target_class="test", - scale=(4.0, 4.0, 4.0), - write_voxel_shape=(16, 16, 16), - context=context, - bounding_box={"z": [0, 64], "y": [0, 64], "x": [0, 64]}, - ) - - assert writer.context is context - - -class TestEmptyImageIntegration: - """Integration tests for EmptyImage with dataset operations.""" - - def test_empty_image_as_placeholder(self): - """Test using EmptyImage as placeholder in dataset.""" - # EmptyImage can be used when data is missing - empty = EmptyImage( - label_class="missing_class", - scale=(8.0, 8.0, 8.0), - voxel_shape=(32, 32, 32), - ) - - # Should have proper attributes - assert empty.label_class == "missing_class" - assert empty.output_shape is not None - - def test_empty_image_collection(self): - """Test collection of EmptyImages.""" - # Create multiple empty images for different classes - empty_images = [] - for i in range(3): - empty = EmptyImage( - label_class=f"class_{i}", - scale=(4.0, 4.0, 4.0), - voxel_shape=(16, 16, 16), - ) - empty_images.append(empty) - - assert len(empty_images) == 3 - assert all(img.label_class.startswith("class_") for img in empty_images) - - -class TestImageWriterIntegration: - """Integration tests for ImageWriter functionality.""" - - def test_writer_output_preparation(self, tmp_upath): - """Test preparing outputs for writing.""" - path = tmp_upath / "predictions.zarr" - - writer = ImageWriter( - path=path.path, - target_class="predictions", - scale=(8.0, 8.0, 8.0), - write_voxel_shape=(32, 32, 32), - bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, - ) - - # Writer should be ready to write - assert os.path.normpath(writer.path).endswith( - os.path.normpath(path.path + os.path.sep + "s0") - ) - assert writer.write_voxel_shape is not None - - def test_multiple_writers_different_classes(self, tmp_upath): - """Test multiple writers for different classes.""" - classes = ["class_0", "class_1", "class_2"] - writers = [] - - for class_name in classes: - path = tmp_upath / f"{class_name}.zarr" - writer = ImageWriter( - path=str(path), - target_class=class_name, - scale=(4.0, 4.0, 4.0), - write_voxel_shape=(16, 16, 16), - bounding_box={"z": [0, 64], "y": [0, 64], "x": [0, 64]}, - ) - writers.append(writer) - - assert len(writers) == 3 - assert all(w.target_class in classes for w in writers) - - def test_image_writer_channel_axis_detection(self, tmp_upath): - """Test automatic channel axis detection when write_voxel_shape has extra dimension.""" - path = tmp_upath / "output_channels.zarr" - - # Test with 4D shape but 3D axis_order (should add channel axis) - writer = ImageWriter( - path=str(path), - target_class="multichannel", - scale=(4.0, 4.0, 4.0), - write_voxel_shape=(3, 32, 32, 32), # 4D shape with channels - axis_order="zyx", # 3D axis order - bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, - ) - - # Verify channel axis was added - assert "c" in writer.axes - assert writer.axes.startswith("c") - assert len(writer.axes) == 4 - - def test_image_writer_with_explicit_channel_axis(self, tmp_upath): - """Test ImageWriter with explicit channel axis in axis_order.""" - path = tmp_upath / "output_explicit_channels.zarr" - - # Test with explicit channel axis - writer = ImageWriter( - path=str(path), - target_class="multichannel", - scale=(4.0, 4.0, 4.0), - write_voxel_shape=(5, 32, 32, 32), # 4D shape with 5 channels - axis_order="czyx", # Explicit channel axis - bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, - ) - - # Verify channel axis is present - assert writer.axes == "czyx" - assert writer.write_voxel_shape["c"] == 5 - - def test_image_writer_no_channel_detection_when_not_needed(self, tmp_upath): - """Test that channel axis is not added when dimensions match.""" - path = tmp_upath / "output_no_channels.zarr" - - # Test with matching dimensions (no channel detection needed) - writer = ImageWriter( - path=str(path), - target_class="test", - scale=(4.0, 4.0, 4.0), - write_voxel_shape=(32, 32, 32), # 3D shape matching 3D axis order - axis_order="zyx", - bounding_box={"z": [0, 256], "y": [0, 256], "x": [0, 256]}, - ) - - # Verify no channel axis was added - assert "c" not in writer.axes - assert writer.axes == "zyx" diff --git a/tests/test_geometry.py b/tests/test_geometry.py new file mode 100644 index 0000000..ff1d840 --- /dev/null +++ b/tests/test_geometry.py @@ -0,0 +1,83 @@ +"""Tests for geometry utilities.""" + +from __future__ import annotations + +import pytest + +from cellmap_data.utils.geometry import box_intersection, box_shape, box_union + + +class TestBoxIntersection: + def test_overlap(self): + a = {"z": (0.0, 100.0), "y": (0.0, 100.0)} + b = {"z": (50.0, 150.0), "y": (50.0, 150.0)} + result = box_intersection(a, b) + assert result == {"z": (50.0, 100.0), "y": (50.0, 100.0)} + + def test_no_overlap_returns_none(self): + a = {"z": (0.0, 50.0)} + b = {"z": (60.0, 100.0)} + assert box_intersection(a, b) is None + + def test_touching_returns_none(self): + a = {"z": (0.0, 50.0)} + b = {"z": (50.0, 100.0)} + assert box_intersection(a, b) is None # lo >= hi + + def test_one_contains_other(self): + a = {"z": (0.0, 200.0)} + b = {"z": (50.0, 150.0)} + result = box_intersection(a, b) + assert result == {"z": (50.0, 150.0)} + + def test_missing_axis_skipped(self): + a = {"z": (0.0, 100.0), "y": (0.0, 100.0)} + b = {"z": (10.0, 90.0)} # no y key + result = box_intersection(a, b) + assert result == {"z": (10.0, 90.0)} + + def test_empty_result_returns_none(self): + # No shared axes at all + a = {"z": (0.0, 100.0)} + b = {"y": (0.0, 100.0)} + assert box_intersection(a, b) is None + + +class TestBoxUnion: + def test_same_boxes(self): + a = {"z": (0.0, 100.0)} + result = box_union(a, a) + assert result == a + + def test_disjoint(self): + a = {"z": (0.0, 50.0)} + b = {"z": (70.0, 120.0)} + result = box_union(a, b) + assert result == {"z": (0.0, 120.0)} + + def test_missing_axis_in_one(self): + a = {"z": (0.0, 100.0)} + b = {"y": (5.0, 50.0)} + result = box_union(a, b) + assert result["z"] == (0.0, 100.0) + assert result["y"] == (5.0, 50.0) + + +class TestBoxShape: + def test_basic(self): + box = {"z": (0.0, 160.0), "y": (0.0, 160.0), "x": (0.0, 160.0)} + scale = {"z": 8.0, "y": 8.0, "x": 8.0} + result = box_shape(box, scale) + assert result == {"z": 20, "y": 20, "x": 20} + + def test_min_one(self): + box = {"z": (0.0, 4.0)} + scale = {"z": 8.0} + # 4/8 = 0.5 → rounds to 1 (at least 1) + assert box_shape(box, scale)["z"] == 1 + + def test_non_integer_rounds(self): + box = {"z": (0.0, 12.0)} + scale = {"z": 8.0} + # 12/8 = 1.5 → rounds to 2 + assert box_shape(box, scale)["z"] == 2 diff --git a/tests/test_helpers.py b/tests/test_helpers.py index fb2ec74..89500dc 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,313 +1,133 @@ -""" -Test helpers for creating real test data without mocks. +"""Helpers for creating minimal zarr test fixtures.""" -This module provides utilities to create real Zarr/OME-NGFF datasets -for testing purposes. -""" +from __future__ import annotations +import json +import os from pathlib import Path -from typing import Any, Dict, Optional, Sequence import numpy as np import zarr -from pydantic_ome_ngff.v04.axis import Axis -from pydantic_ome_ngff.v04.multiscale import Dataset as MultiscaleDataset -from pydantic_ome_ngff.v04.multiscale import MultiscaleMetadata -from pydantic_ome_ngff.v04.transform import VectorScale -def create_test_zarr_array( - path: Path, +def _write_ome_ngff( + path: str, data: np.ndarray, - axes: Sequence[str] = ("z", "y", "x"), - scale: Sequence[float] = (1.0, 1.0, 1.0), - chunks: Optional[Sequence[int]] = None, - multiscale: bool = True, - absent: int = 0, -) -> zarr.Array: - """ - Create a test Zarr array with OME-NGFF metadata. - - Args: - path: Path to create the Zarr array - data: Numpy array data - axes: Axis names - scale: Scale for each axis in physical units - chunks: Chunk size for Zarr array - multiscale: Whether to create multiscale metadata - - Returns: - Created zarr.Array - """ - path.mkdir(parents=True, exist_ok=True) - - if chunks is None: - chunks = tuple(min(32, s) for s in data.shape) - - # Create zarr group - store = zarr.DirectoryStore(str(path)) - root = zarr.group(store=store, overwrite=True) - - if multiscale: - # Create multiscale group with s0 level - s0 = root.create_dataset( - "s0", - data=data, - chunks=chunks, - dtype=data.dtype, - overwrite=True, - ) - - # Create OME-NGFF multiscale metadata - axis_list = tuple( - Axis( - name=name, - type="space" if name in ["x", "y", "z"] else "channel", - unit="nanometer" if name in ["x", "y", "z"] else None, - ) - for name in axes - ) - - datasets = ( - MultiscaleDataset( - path="s0", - coordinateTransformations=( - VectorScale(type="scale", scale=tuple(scale)), - ), - ), - ) - - multiscale_metadata = MultiscaleMetadata( - version="0.4", - name="test_data", - axes=axis_list, - datasets=datasets, - ) - - root.attrs["multiscales"] = [ - multiscale_metadata.model_dump(mode="json", exclude_none=True) + voxel_size: list[float], + *, + axes: list[str] | None = None, + origin: list[float] | None = None, + level: str = "s0", +) -> None: + """Write a single-level OME-NGFF zarr group.""" + if axes is None: + axes = ["z", "y", "x"][-data.ndim :] + if origin is None: + origin = [0.0] * len(axes) + + os.makedirs(path, exist_ok=True) + z_attrs = { + "multiscales": [ + { + "axes": [ + {"name": ax, "type": "space", "unit": "nanometer"} for ax in axes + ], + "datasets": [ + { + "path": level, + "coordinateTransformations": [ + {"type": "scale", "scale": voxel_size}, + {"type": "translation", "translation": origin}, + ], + } + ], + "version": "0.4", + } ] - - s0.attrs["cellmap"] = {"annotation": {"complement_counts": {"absent": absent}}} - - return s0 - else: - # Create simple array without multiscale - arr = root.create_dataset( - name="data", - data=data, - chunks=chunks, - dtype=data.dtype, - overwrite=True, - ) - return arr - - -def create_test_image_data( - shape: Sequence[int], - dtype: np.dtype = np.float32, - pattern: str = "gradient", - seed: int = 42, -) -> np.ndarray: - """ - Create test image data with various patterns. - - Args: - shape: Shape of the array - dtype: Data type - pattern: Type of pattern ("gradient", "checkerboard", "random", "constant", "sphere") - seed: Random seed - - Returns: - Generated numpy array - """ - rng = np.random.default_rng(seed) - - if pattern == "gradient": - # Create a gradient along the last axis - data = np.zeros(shape, dtype=dtype) - for i in range(shape[-1]): - data[..., i] = i / shape[-1] - elif pattern == "checkerboard": - # Create checkerboard pattern - indices = np.indices(shape) - data = np.sum(indices, axis=0) % 2 - data = data.astype(dtype) - elif pattern == "random": - # Random values between 0 and 1 - data = rng.random(shape, dtype=np.float32).astype(dtype) - elif pattern == "constant": - # Constant value - data = np.ones(shape, dtype=dtype) - elif pattern == "sphere": - # Create a sphere in the center - data = np.zeros(shape, dtype=dtype) - center = tuple(s // 2 for s in shape) - radius = min(shape) // 4 - - indices = np.indices(shape) - distances = np.sqrt( - sum((indices[i] - center[i]) ** 2 for i in range(len(shape))) - ) - data[distances <= radius] = 1.0 - else: - raise ValueError(f"Unknown pattern: {pattern}") - - return data - - -def create_test_label_data( - shape: Sequence[int], - num_classes: int = 3, - pattern: str = "regions", - seed: int = 42, -) -> Dict[str, np.ndarray]: - """ - Create test label data for multiple classes. - - Args: - shape: Shape of the arrays - num_classes: Number of classes to generate - pattern: Type of pattern ("regions", "random", "stripes") - seed: Random seed - - Returns: - Dictionary mapping class names to label arrays + } + with open(os.path.join(path, ".zattrs"), "w") as f: + json.dump(z_attrs, f) + with open(os.path.join(path, ".zgroup"), "w") as f: + f.write('{"zarr_format": 2}') + + arr_path = os.path.join(path, level) + zarr.open_array( + arr_path, + mode="w", + shape=data.shape, + dtype=data.dtype, + chunks=data.shape, + )[:] = data + + +def create_test_zarr( + tmp_path: Path, + name: str = "test", + shape: tuple[int, ...] = (20, 20, 20), + voxel_size: list[float] | None = None, + origin: list[float] | None = None, + data: np.ndarray | None = None, + axes: list[str] | None = None, +) -> str: + """Create a minimal OME-NGFF zarr group under *tmp_path*. + + Returns the path to the zarr group (the directory). """ - rng = np.random.default_rng(seed) - labels = {} - - if pattern == "regions": - # Divide the volume into regions for different classes - for i in range(num_classes): - class_label = np.zeros(shape, dtype=np.uint8) - # Create regions along first axis - start = (i * shape[0]) // num_classes - end = ((i + 1) * shape[0]) // num_classes - class_label[start:end] = 1 - labels[f"class_{i}"] = class_label - elif pattern == "random": - # Random labels - for i in range(num_classes): - labels[f"class_{i}"] = (rng.random(shape) > 0.5).astype(np.uint8) - elif pattern == "stripes": - # Create stripes along last axis - for i in range(num_classes): - class_label = np.zeros(shape, dtype=np.uint8) - # Create stripes - for j in range(shape[-1]): - if j % num_classes == i: - class_label[..., j] = 1 - if np.sum(class_label) == 0 and shape[-1] > 0: - class_label[..., 0] = 1 # Ensure at least one pixel - labels[f"class_{i}"] = class_label - else: - raise ValueError(f"Unknown pattern: {pattern}") - - return labels + ndim = len(shape) + if axes is None: + axes = ["z", "y", "x"][-ndim:] + if voxel_size is None: + voxel_size = [8.0] * ndim + if origin is None: + origin = [0.0] * ndim + if data is None: + rng = np.random.default_rng(0) + data = (rng.random(shape) * 255).astype(np.uint8) + + path = str(tmp_path / f"{name}.zarr") + _write_ome_ngff(path, data, voxel_size, axes=axes, origin=origin) + return path def create_test_dataset( tmp_path: Path, - raw_shape: Sequence[int] = (64, 64, 64), - gt_shape: Optional[Sequence[int]] = None, - num_classes: int = 3, - raw_scale: Sequence[float] = (4.0, 4.0, 4.0), - gt_scale: Optional[Sequence[float]] = None, - seed: int = 0, - raw_pattern: str = "random", - label_pattern: str = "regions", -) -> Dict[str, Any]: - """ - Create a test dataset with raw and ground truth Zarr arrays. + classes: list[str] | None = None, + shape: tuple[int, ...] = (32, 32, 32), + voxel_size: list[float] | None = None, +) -> dict: + """Create a minimal raw + label zarr dataset for testing. - Args: - tmp_path: Path to create the dataset - raw_shape: Shape of the raw data - gt_shape: Shape of the ground truth data - num_classes: Number of classes in ground truth - raw_scale: Scale of the raw data - gt_scale: Scale of the ground truth data - seed: Random seed for data generation - raw_pattern: Pattern for raw data - label_pattern: Pattern for label data - - Returns: - Dictionary with paths and parameters of the created dataset + Returns a dict with keys ``raw_path``, ``gt_path``, ``classes``. """ - dataset_path = tmp_path / "dataset.zarr" - raw_data = create_test_image_data( - raw_shape, dtype=np.dtype(np.uint8), pattern=raw_pattern, seed=seed - ) - create_test_zarr_array(dataset_path / "raw", raw_data, scale=raw_scale) - - classes = [f"class_{i}" for i in range(num_classes)] - if gt_shape is None: - gt_shape = raw_shape - if gt_scale is None: - gt_scale = raw_scale - - label_data = create_test_label_data( - gt_shape, num_classes, pattern=label_pattern, seed=seed - ) - - for class_name, gt_data in label_data.items(): - class_path = dataset_path / class_name - create_test_zarr_array( - class_path, - gt_data, - scale=gt_scale, - absent=np.count_nonzero(gt_data == 0), - ) + if classes is None: + classes = ["mito", "er"] + ndim = len(shape) + if voxel_size is None: + voxel_size = [8.0] * ndim + axes = ["z", "y", "x"][-ndim:] + + rng = np.random.default_rng(42) + raw_data = (rng.random(shape) * 255).astype(np.uint8) + raw_path = str(tmp_path / "raw.zarr") + _write_ome_ngff(raw_path, raw_data, voxel_size, axes=axes) + + gt_base = str(tmp_path / "gt.zarr") + os.makedirs(gt_base, exist_ok=True) + with open(os.path.join(gt_base, ".zgroup"), "w") as f: + f.write('{"zarr_format": 2}') + + for cls in classes: + label_data = rng.integers(0, 2, size=shape).astype(np.uint8) + cls_path = os.path.join(gt_base, cls) + _write_ome_ngff(cls_path, label_data, voxel_size, axes=axes) + + class_str = ",".join(classes) + gt_path = f"{gt_base}/[{class_str}]" return { - "raw_path": str(dataset_path / "raw"), - "gt_path": str(dataset_path / f"[{','.join(classes)}]"), + "raw_path": raw_path, + "gt_path": gt_path, "classes": classes, - "raw_shape": raw_shape, - "gt_shape": gt_shape, - "raw_scale": raw_scale, - "gt_scale": gt_scale, + "shape": shape, + "voxel_size": voxel_size, } - - -def create_minimal_test_dataset(tmp_path: Path) -> Dict[str, Any]: - """ - Create a minimal test dataset for quick tests. - - Args: - tmp_path: Temporary directory path - - Returns: - Dictionary with paths and metadata - """ - return create_test_dataset( - tmp_path, - raw_shape=(16, 16, 16), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - ) - - -def check_device_transfer(loader, device): - """ - Check if data transfer between CPU and GPU works as expected. - - Args: - loader: Data loader providing the data - device: Device to transfer the data to (e.g., "cuda" or "cpu") - - Returns: - None - """ - # Iterate through the data loader - for batch in loader: - # Transfer the batch to the specified device - batch = {k: v.to(device) for k, v in batch.items()} - - # Check if the transfer was successful - for k, v in batch.items(): - assert v.device == device - - # Break after the first batch to avoid transferring all data - break diff --git a/tests/test_image.py b/tests/test_image.py new file mode 100644 index 0000000..6ced03d --- /dev/null +++ b/tests/test_image.py @@ -0,0 +1,190 @@ +"""Tests for CellMapImage.""" + +from __future__ import annotations + +import numpy as np +import pytest +import torch + +from cellmap_data import CellMapImage + +from .test_helpers import create_test_zarr + + +class TestCellMapImageBasics: + def test_init(self, tmp_path): + path = create_test_zarr(tmp_path, shape=(20, 20, 20)) + img = CellMapImage( + path=path, + target_class="raw", + target_scale=[8.0, 8.0, 8.0], + target_voxel_shape=[4, 4, 4], + ) + assert img.label_class == "raw" + assert img.axes == ["z", "y", "x"] + + def test_bounding_box(self, tmp_path): + path = create_test_zarr( + tmp_path, shape=(20, 20, 20), voxel_size=[8.0, 8.0, 8.0] + ) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4]) + bb = img.bounding_box + assert set(bb.keys()) == {"z", "y", "x"} + # 20 voxels * 8 nm = 160 nm + assert bb["z"] == pytest.approx((0.0, 160.0)) + assert bb["x"] == pytest.approx((0.0, 160.0)) + + def test_sampling_box(self, tmp_path): + path = create_test_zarr(tmp_path, shape=(20, 20, 20)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4]) + sb = img.sampling_box + # output is 4*8=32 nm → half=16 nm shrink on each side + assert sb["z"][0] == pytest.approx(16.0) + assert sb["z"][1] == pytest.approx(144.0) + + def test_sampling_box_none_when_too_small(self, tmp_path): + # Patch (100 voxels) larger than array (10 voxels * 8nm = 80nm) + path = create_test_zarr(tmp_path, shape=(10, 10, 10)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [100, 100, 100]) + assert img.sampling_box is None + + def test_scale_level_best_match(self, tmp_path): + path = create_test_zarr(tmp_path, voxel_size=[8.0, 8.0, 8.0]) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4]) + assert img.scale_level == 0 + + def test_getitem_returns_tensor(self, tmp_path): + path = create_test_zarr(tmp_path, shape=(20, 20, 20)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4]) + center = {"z": 80.0, "y": 80.0, "x": 80.0} + patch = img[center] + assert isinstance(patch, torch.Tensor) + assert patch.shape == torch.Size([4, 4, 4]) + + def test_getitem_shape_correct(self, tmp_path): + path = create_test_zarr(tmp_path, shape=(40, 40, 40)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [8, 8, 8]) + center = {"z": 160.0, "y": 160.0, "x": 160.0} + patch = img[center] + assert patch.shape == torch.Size([8, 8, 8]) + + def test_padding_with_nan(self, tmp_path): + """Reading near edge with pad=True → NaN in OOB regions.""" + path = create_test_zarr(tmp_path, shape=(8, 8, 8)) + img = CellMapImage( + path, "raw", [8.0, 8.0, 8.0], [4, 4, 4], pad=True, pad_value=float("nan") + ) + # Center near corner: some region will be outside bounds + center = {"z": 4.0, "y": 4.0, "x": 4.0} # origin + 0.5 voxel + patch = img[center] + assert patch.shape == torch.Size([4, 4, 4]) + # Should have some NaN in the padded region + assert torch.isnan(patch).any() + + def test_no_padding_clamps(self, tmp_path): + """Reading near edge with pad=False → no NaN, just smaller or clamped data.""" + path = create_test_zarr(tmp_path, shape=(8, 8, 8)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4], pad=False) + center = {"z": 4.0, "y": 4.0, "x": 4.0} + patch = img[center] + # No NaN expected (clamped read, may be smaller shape) + assert not torch.isnan(patch).any() + + def test_get_center(self, tmp_path): + path = create_test_zarr(tmp_path, shape=(20, 20, 20)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4]) + center = img.get_center(0) + assert set(center.keys()) == {"z", "y", "x"} + # First centre should be at sampling_box lower bound + 0.5*scale + sb = img.sampling_box + assert center["z"] == pytest.approx(sb["z"][0] + 0.5 * 8.0) + + def test_value_transform_applied(self, tmp_path): + """A value_transform that negates values should change the output.""" + path = create_test_zarr(tmp_path, shape=(20, 20, 20)) + img = CellMapImage( + path, + "raw", + [8.0, 8.0, 8.0], + [4, 4, 4], + value_transform=lambda x: x * -1.0, + ) + img_plain = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4]) + center = {"z": 80.0, "y": 80.0, "x": 80.0} + assert torch.allclose(img[center], -img_plain[center]) + + def test_repr(self, tmp_path): + path = create_test_zarr(tmp_path) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4]) + r = repr(img) + assert "CellMapImage" in r + assert "raw" in r + + +class TestCellMapImageSpatialTransforms: + def test_mirror_z(self, tmp_path): + data = np.arange(8 * 8 * 8, dtype=np.float32).reshape(8, 8, 8) + path = create_test_zarr(tmp_path, shape=(8, 8, 8), data=data) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4], pad=True) + center = {"z": 32.0, "y": 32.0, "x": 32.0} + img.set_spatial_transforms(None) + patch_orig = img[center].clone() + img.set_spatial_transforms({"mirror": {"z": True, "y": False, "x": False}}) + patch_mirrored = img[center].clone() + img.set_spatial_transforms(None) + assert not torch.allclose(patch_orig, patch_mirrored) + # Mirroring z twice should give back original + assert torch.allclose(patch_orig, patch_mirrored.flip(0)) + + def test_set_spatial_transforms_none(self, tmp_path): + path = create_test_zarr(tmp_path, shape=(20, 20, 20)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4]) + img.set_spatial_transforms(None) + assert img._current_spatial_transforms is None + + def test_rotation_read_shape_larger(self, tmp_path): + """With rotation, read_shape should be larger than output_shape.""" + path = create_test_zarr(tmp_path, shape=(40, 40, 40)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [8, 8, 8], pad=True) + # 45° rotation about z + theta = np.deg2rad(45) + R = np.eye(3) + R[1, 1] = np.cos(theta) + R[1, 2] = -np.sin(theta) + R[2, 1] = np.sin(theta) + R[2, 2] = np.cos(theta) + img.set_spatial_transforms({"rotation_matrix": R}) + read_shape = img._compute_read_shape() + # y and x dims should be larger (by cos+sin ≈ 1.41) + assert read_shape[1] > 8 + assert read_shape[2] > 8 + img.set_spatial_transforms(None) + + def test_rotation_output_shape_preserved(self, tmp_path): + """After rotation+crop, output shape must equal target_voxel_shape.""" + path = create_test_zarr(tmp_path, shape=(60, 60, 60)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [8, 8, 8], pad=True) + theta = np.deg2rad(30) + R = np.eye(3) + R[1, 1] = np.cos(theta) + R[1, 2] = -np.sin(theta) + R[2, 1] = np.sin(theta) + R[2, 2] = np.cos(theta) + center = {"z": 240.0, "y": 240.0, "x": 240.0} + img.set_spatial_transforms({"rotation_matrix": R}) + patch = img[center] + img.set_spatial_transforms(None) + assert patch.shape == torch.Size([8, 8, 8]) + + +class TestCellMapImageClassCounts: + def test_class_counts_keys(self, tmp_path): + import zarr as z + + data = np.zeros((10, 10, 10), dtype=np.uint8) + data[2:5, 2:5, 2:5] = 1 # some foreground + path = create_test_zarr(tmp_path, shape=(10, 10, 10), data=data) + img = CellMapImage(path, "mito", [8.0, 8.0, 8.0], [4, 4, 4]) + counts = img.class_counts + assert "mito" in counts + assert counts["mito"] >= 0 diff --git a/tests/test_image_edge_cases.py b/tests/test_image_edge_cases.py deleted file mode 100644 index 632d912..0000000 --- a/tests/test_image_edge_cases.py +++ /dev/null @@ -1,744 +0,0 @@ -"""Tests for CellMapImage edge cases and special methods.""" - -import numpy as np -import pytest -import torch - -from cellmap_data import CellMapImage - -from .test_helpers import create_test_image_data, create_test_zarr_array - - -class TestCellMapImageEdgeCases: - """Test edge cases and special methods in CellMapImage.""" - - @pytest.fixture - def test_zarr_image(self, tmp_path): - """Create a test Zarr image.""" - data = create_test_image_data((32, 32, 32), pattern="gradient") - path = tmp_path / "test_image.zarr" - create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0)) - return str(path), data - - def test_axis_order_longer_than_scale(self, test_zarr_image): - """Test handling when axis_order has more axes than target_scale.""" - path, _ = test_zarr_image - - # Provide fewer scale values than axes - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0), # Only 2 values for 3 axes - target_voxel_shape=(16, 16, 16), - axis_order="zyx", # 3 axes - ) - - # Should pad scale with first value - assert len(image.scale) == 3 - assert image.scale["z"] == 4.0 # Padded value - assert image.scale["y"] == 4.0 - assert image.scale["x"] == 4.0 - - def test_axis_order_longer_than_shape(self, test_zarr_image): - """Test handling when axis_order has more axes than target_voxel_shape.""" - path, _ = test_zarr_image - - # Provide fewer shape values than axes - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16), # Only 2 values for 3 axes - axis_order="zyx", # 3 axes - ) - - # Should pad shape with 1s - assert len(image.output_shape) == 3 - assert image.output_shape["z"] == 1 # Padded value - assert image.output_shape["y"] == 16 - assert image.output_shape["x"] == 16 - - def test_device_auto_selection_cuda(self, test_zarr_image): - """Test device auto-selection when no device specified.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - ) - - # Should select an appropriate device - assert image.device in ["cuda", "mps", "cpu"] - - def test_explicit_device_selection(self, test_zarr_image): - """Test explicit device selection.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - device="cpu", - ) - - assert image.device == "cpu" - - def test_to_device_method(self, test_zarr_image): - """Test moving image to different device.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - ) - - # Move to CPU - image.to("cpu") - assert image.device == "cpu" - - def test_set_spatial_transforms_none(self, test_zarr_image): - """Test setting spatial transforms to None.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - ) - - # Set to None - image.set_spatial_transforms(None) - assert image._current_spatial_transforms is None - - def test_set_spatial_transforms_with_values(self, test_zarr_image): - """Test setting spatial transforms with actual transform dict.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - ) - - # Set transforms - transforms = {"mirror": {"axes": {"x": 0.5}}} - image.set_spatial_transforms(transforms) - assert image._current_spatial_transforms == transforms - - def test_bounding_box_property(self, test_zarr_image): - """Test the bounding_box property.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - ) - - bbox = image.bounding_box - - # Should be a dict with axis keys - assert isinstance(bbox, dict) - for axis in "zyx": - assert axis in bbox - assert len(bbox[axis]) == 2 - assert bbox[axis][0] <= bbox[axis][1] - - def test_sampling_box_property(self, test_zarr_image): - """Test the sampling_box property.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - ) - - sbox = image.sampling_box - - # Should be a dict with axis keys - assert isinstance(sbox, dict) - for axis in "zyx": - assert axis in sbox - assert len(sbox[axis]) == 2 - - def test_class_counts_fast_path_value(self, tmp_path): - """class_counts fast path reads complement_counts from s0/.zattrs and - returns the number of foreground voxels normalized to training resolution. - - With s0_scale == training_scale the ratio is 1, so counts equal the - raw foreground voxel count at s0 resolution. - """ - shape = (8, 8, 8) - total = int(np.prod(shape)) - absent = 100 # background voxels at s0 - expected_fg = total - absent - - data = np.zeros(shape, dtype=np.uint8) - path = tmp_path / "label.zarr" - create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0), absent=absent) - - image = CellMapImage( - path=str(path), - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), # same as s0 scale → ratio 1 - target_voxel_shape=(4, 4, 4), - ) - - assert image.class_counts == pytest.approx(expected_fg) - assert image.bg_count == pytest.approx(absent) - - def test_class_counts_normalized_by_training_scale(self, tmp_path): - """When training scale differs from s0 scale, counts are expressed in - training-resolution voxels, not physical volume. - - s0=4nm, training=8nm → each s0 voxel is (4/8)^3 = 0.125 training voxels. - """ - shape = (8, 8, 8) - total = int(np.prod(shape)) - absent = 100 - expected_fg_training = (total - absent) * (4.0**3 / 8.0**3) - expected_bg_training = absent * (4.0**3 / 8.0**3) - - data = np.zeros(shape, dtype=np.uint8) - path = tmp_path / "label.zarr" - create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0), absent=absent) - - image = CellMapImage( - path=str(path), - target_class="test_class", - target_scale=(8.0, 8.0, 8.0), # coarser training scale - target_voxel_shape=(4, 4, 4), - ) - - assert image.class_counts == pytest.approx(expected_fg_training) - assert image.bg_count == pytest.approx(expected_bg_training) - - def _make_zarr_without_cellmap_attrs(self, path, data, scale=(4.0, 4.0, 4.0)): - """Create a zarr group at *path* with OME-NGFF metadata but no cellmap attrs.""" - import zarr as _zarr - from pydantic_ome_ngff.v04.axis import Axis - from pydantic_ome_ngff.v04.multiscale import ( - Dataset as MultiscaleDataset, - MultiscaleMetadata, - ) - from pydantic_ome_ngff.v04.transform import VectorScale as _VectorScale - - path.mkdir(parents=True, exist_ok=True) - store = _zarr.DirectoryStore(str(path)) - root = _zarr.group(store=store, overwrite=True) - root.create_dataset("s0", data=data, chunks=data.shape, overwrite=True) - axes = tuple(Axis(name=n, type="space", unit="nanometer") for n in "zyx") - ms = MultiscaleMetadata( - version="0.4", - name="no_meta", - axes=axes, - datasets=( - MultiscaleDataset( - path="s0", - coordinateTransformations=( - _VectorScale(type="scale", scale=scale), - ), - ), - ), - ) - root.attrs["multiscales"] = [ms.model_dump(mode="json", exclude_none=True)] - # s0 has NO "cellmap" attrs — triggers fallback - - def test_class_counts_fallback_without_metadata(self, tmp_path): - """When complement_counts metadata is absent the fallback reads the - array directly and returns the raw voxel count (no scale factor). - """ - shape = (8, 8, 8) - data = create_test_image_data(shape, pattern="sphere").astype(np.uint8) - path = tmp_path / "label_no_meta.zarr" - self._make_zarr_without_cellmap_attrs(path, data, scale=(4.0, 4.0, 4.0)) - - expected_fg = int(np.count_nonzero(data)) - expected_bg = int(data.size - np.count_nonzero(data)) - - image = CellMapImage( - path=str(path), - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(4, 4, 4), - ) - - assert image.class_counts == pytest.approx(expected_fg) - assert image.bg_count == pytest.approx(expected_bg) - - def test_class_counts_fallback_writes_metadata(self, tmp_path): - """After the fallback fires, complement_counts should be written back to - s0/.zattrs so that subsequent calls use the fast path. - """ - import json - - shape = (8, 8, 8) - data = create_test_image_data(shape, pattern="sphere").astype(np.uint8) - path = tmp_path / "label_writeback.zarr" - self._make_zarr_without_cellmap_attrs(path, data, scale=(4.0, 4.0, 4.0)) - - image = CellMapImage( - path=str(path), - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(4, 4, 4), - ) - _ = image.class_counts # triggers fallback + write-back - - # The s0/.zattrs file should now contain the complement_counts entry. - zattrs_path = path / "s0" / ".zattrs" - assert zattrs_path.exists(), "s0/.zattrs was not written" - with open(zattrs_path) as f: - s0_attrs = json.load(f) - - bg_s0 = s0_attrs["cellmap"]["annotation"]["complement_counts"]["absent"] - expected_bg_s0 = int(data.size - np.count_nonzero(data)) - assert bg_s0 == expected_bg_s0 - - def test_class_counts_fallback_writeback_converts_scale(self, tmp_path): - """When training scale differs from s0 scale the written absent count - must be in s0 voxels, not training voxels. - - s0=4nm, training=8nm → training voxel covers 8x the s0 voxel volume, - so fg_s0 = fg_training * (8^3 / 4^3) = fg_training * 8. - """ - import json - - shape = (8, 8, 8) # s0 shape; 512 voxels total - data = create_test_image_data(shape, pattern="sphere").astype(np.uint8) - path = tmp_path / "label_scale_writeback.zarr" - self._make_zarr_without_cellmap_attrs(path, data, scale=(4.0, 4.0, 4.0)) - - image = CellMapImage( - path=str(path), - target_class="test_class", - target_scale=(8.0, 8.0, 8.0), # coarser training resolution - target_voxel_shape=(4, 4, 4), - ) - _ = image.class_counts # triggers fallback + write-back - - zattrs_path = path / "s0" / ".zattrs" - with open(zattrs_path) as f: - s0_attrs = json.load(f) - - bg_s0 = s0_attrs["cellmap"]["annotation"]["complement_counts"]["absent"] - total_s0 = int(np.prod(shape)) - - # fg_training = count_nonzero at training resolution (same data since - # training scale is just a resampling spec; the actual array loaded is - # the training-resolution crop). For this test the array IS the s0 - # data, so fg_training = count_nonzero(data). - fg_training = int(np.count_nonzero(data)) - # scale_ratio (training→s0) = training_voxel_vol / s0_voxel_vol = 8^3/4^3 = 8 - fg_s0_expected = int(round(fg_training * (8.0**3 / 4.0**3))) - fg_s0_expected = min(fg_s0_expected, total_s0) - bg_s0_expected = total_s0 - fg_s0_expected - - assert bg_s0 == bg_s0_expected - - def test_class_counts_property(self, test_zarr_image): - """class_counts returns a non-negative float.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - ) - - counts = image.class_counts - assert isinstance(counts, float) - assert counts >= 0.0 - - def test_pad_parameter_true(self, test_zarr_image): - """Test padding when pad=True.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - pad=True, - pad_value=0, - ) - - assert image.pad is True - assert image.pad_value == 0 - - def test_pad_parameter_false(self, test_zarr_image): - """Test when pad=False.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - pad=False, - ) - - assert image.pad is False - - def test_interpolation_nearest(self, test_zarr_image): - """Test interpolation mode nearest.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - interpolation="nearest", - ) - - assert image.interpolation == "nearest" - - def test_interpolation_linear(self, test_zarr_image): - """Test interpolation mode linear.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - interpolation="linear", - ) - - assert image.interpolation == "linear" - - def test_value_transform_none(self, test_zarr_image): - """Test when no value transform is provided.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - value_transform=None, - ) - - assert image.value_transform is None - - def test_value_transform_provided(self, test_zarr_image): - """Test when value transform is provided.""" - path, _ = test_zarr_image - - transform = lambda x: x * 2 - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - value_transform=transform, - ) - - assert image.value_transform is transform - - def test_output_size_calculation(self, test_zarr_image): - """Test that output_size is correctly calculated.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 8.0, 2.0), - target_voxel_shape=(10, 20, 30), - axis_order="zyx", - ) - - # output_size = voxel_shape * scale - assert image.output_size["z"] == 10 * 4.0 - assert image.output_size["y"] == 20 * 8.0 - assert image.output_size["x"] == 30 * 2.0 - - def test_axes_property(self, test_zarr_image): - """Test that axes property is correctly set.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - axis_order="zyx", - ) - - assert image.axes == "zyx" - - def test_context_parameter_none(self, test_zarr_image): - """Test when no context is provided.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - context=None, - ) - - assert image.context is None - - def test_path_attribute(self, test_zarr_image): - """Test that path attribute is correctly set.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - ) - - assert image.path == path - - def test_label_class_attribute(self, test_zarr_image): - """Test that label_class attribute is correctly set.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="my_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - ) - - assert image.label_class == "my_class" - - def test_getitem_returns_tensor(self, test_zarr_image): - """Test that __getitem__ returns a PyTorch tensor.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - - center = {"z": 64.0, "y": 64.0, "x": 64.0} - result = image[center] - - assert isinstance(result, torch.Tensor) - assert result.ndim >= 3 - - def test_nan_pad_value(self, test_zarr_image): - """Test using NaN as pad value.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(16, 16, 16), - pad=True, - pad_value=np.nan, - ) - - assert np.isnan(image.pad_value) - - # ----------------------------------------------------------------------- - # coord_offsets caching - # ----------------------------------------------------------------------- - - def test_coord_offsets_is_cached_property(self, test_zarr_image): - """coord_offsets must use @cached_property, not a manual null-check pattern. - - Verifies: (a) the returned dict has the expected axes, (b) successive - accesses return the exact same objects (cached, not recomputed), and - (c) the offsets are symmetric around zero for each axis. - """ - path, _ = test_zarr_image - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - - offsets1 = image.coord_offsets - offsets2 = image.coord_offsets - - # Cached: same object returned on every access - assert offsets1 is offsets2 - - # Stored in __dict__ (cached_property, not regular property) - assert "coord_offsets" in image.__dict__ - - # Correct axes present - for axis in image.axes: - assert axis in offsets1 - arr = offsets1[axis] - assert len(arr) == image.output_shape[axis] - # Symmetric around zero within float tolerance - assert abs(arr[0] + arr[-1]) < 1e-9 - - def test_coord_offsets_not_cleared_by_array_cache_clear(self, test_zarr_image): - """_clear_array_cache must only clear 'array', leaving coord_offsets intact.""" - path, _ = test_zarr_image - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - - offsets_before = image.coord_offsets # populate cache - assert "coord_offsets" in image.__dict__ - - image._clear_array_cache() - - # coord_offsets must still be cached after cache clear - assert "coord_offsets" in image.__dict__ - assert image.coord_offsets is offsets_before - - def test_coord_offsets_values_match_output_size_and_scale(self, test_zarr_image): - """coord_offsets values must span exactly [-output_size/2+scale/2, output_size/2-scale/2].""" - path, _ = test_zarr_image - image = CellMapImage( - path=path, - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - - for axis in image.axes: - arr = image.coord_offsets[axis] - expected_lo = -image.output_size[axis] / 2 + image.scale[axis] / 2 - expected_hi = image.output_size[axis] / 2 - image.scale[axis] / 2 - assert abs(arr[0] - expected_lo) < 1e-9 - assert abs(arr[-1] - expected_hi) < 1e-9 - - -# --------------------------------------------------------------------------- -# full_coords memory fix: property (not cached) + _array_shape -# --------------------------------------------------------------------------- - - -class TestFullCoordsMemoryFix: - """Verify the full_coords / _array_shape memory-reduction change. - - full_coords was changed from @cached_property to @property so that the - large per-axis coordinate arrays are not held in memory between - __getitem__ calls. _array_shape replaces the per-call zarr shape read - with a compact cached tuple. - """ - - @pytest.fixture - def image(self, tmp_path): - data = create_test_image_data((32, 32, 32), pattern="gradient") - path = tmp_path / "test_image.zarr" - create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0)) - return CellMapImage( - path=str(path), - target_class="test_class", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - - # ------------------------------------------------------------------ - # _array_shape: cached, compact - # ------------------------------------------------------------------ - - def test_array_shape_is_cached(self, image): - """_array_shape must be a @cached_property (stored in __dict__).""" - _ = image._array_shape - assert "_array_shape" in image.__dict__ - - def test_array_shape_is_tuple_of_ints(self, image): - """_array_shape must be a tuple of plain Python ints.""" - shape = image._array_shape - assert isinstance(shape, tuple) - assert all(isinstance(s, int) for s in shape) - - def test_array_shape_matches_source_array(self, image): - """_array_shape must match the underlying zarr array dimensions.""" - shape = image._array_shape - assert shape == (32, 32, 32) - - def test_array_shape_same_object_on_repeated_access(self, image): - """_array_shape returns the same tuple object (cached, not recomputed).""" - s1 = image._array_shape - s2 = image._array_shape - assert s1 is s2 - - # ------------------------------------------------------------------ - # full_coords: NOT cached between calls - # ------------------------------------------------------------------ - - def test_full_coords_not_in_dict_after_bounding_box(self, image): - """After bounding_box initialises, full_coords must NOT be in __dict__. - - bounding_box is the primary consumer of full_coords during setup; - the fix requires that the large coord arrays are freed immediately - after bounding_box is cached. - """ - _ = image.bounding_box # triggers full_coords access internally - assert "full_coords" not in image.__dict__ - - def test_full_coords_not_cached_after_getitem(self, image): - """After a __getitem__ call, full_coords must not be in __dict__.""" - center = {"z": 64.0, "y": 64.0, "x": 64.0} - _ = image[center] - assert "full_coords" not in image.__dict__ - - def test_full_coords_returns_new_object_each_access(self, image): - """full_coords must produce a new tuple on every call (not cached).""" - fc1 = image.full_coords - fc2 = image.full_coords - assert fc1 is not fc2 - - def test_full_coords_values_consistent(self, image): - """Repeated calls to full_coords must return equivalent coordinate values.""" - import numpy as np - - fc1 = image.full_coords - fc2 = image.full_coords - assert len(fc1) == len(fc2) - for da1, da2 in zip(fc1, fc2): - np.testing.assert_array_equal(da1.values, da2.values) - assert da1.dims == da2.dims - - # ------------------------------------------------------------------ - # shape property still correct (now delegates to _array_shape) - # ------------------------------------------------------------------ - - def test_shape_property_correct(self, image): - """shape must still return the correct axis→size mapping.""" - shape = image.shape - assert isinstance(shape, dict) - for axis in image.axes: - assert axis in shape - assert shape[axis] == 32 - - def test_shape_uses_array_shape(self, image): - """shape values must match _array_shape elements.""" - arr_shape = image._array_shape - for s, axis in zip(arr_shape, image.axes): - assert image.shape[axis] == s - - # ------------------------------------------------------------------ - # bounding_box correctness preserved after the refactor - # ------------------------------------------------------------------ - - def test_bounding_box_still_correct_after_refactor(self, image): - """bounding_box must still return valid min/max per axis.""" - bbox = image.bounding_box - assert set(bbox.keys()) == set(image.axes) - for axis in image.axes: - lo, hi = bbox[axis] - assert lo <= hi diff --git a/tests/test_init_optimizations.py b/tests/test_init_optimizations.py deleted file mode 100644 index 1835a6a..0000000 --- a/tests/test_init_optimizations.py +++ /dev/null @@ -1,532 +0,0 @@ -""" -Tests for initialization optimizations added to CellMapDataset and -CellMapMultiDataset. - -Covers: - - force_has_data=True sets has_data immediately (no class_counts read) - - bounding_box / sampling_box parallel computation: correctness and cleanup - - CellMapMultiDataset.class_counts parallel execution: correct aggregation, - exception propagation, CELLMAP_MAX_WORKERS env-var respected - - _ImmediateExecutor: submit/map correctness (Windows+TensorStore drop-in) - - Immediate executor code paths in bounding_box, sampling_box, and - CellMapMultiDataset.class_counts (simulated via monkeypatching) - - Consistency: dataset.py and multidataset.py share the same - _USE_IMMEDIATE_EXECUTOR flag -""" - -from unittest.mock import PropertyMock, patch - -import pytest - -from cellmap_data import CellMapDataset, CellMapMultiDataset -from cellmap_data.image import CellMapImage - -from .test_helpers import create_test_dataset - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def single_dataset_config(tmp_path): - return create_test_dataset( - tmp_path, - raw_shape=(32, 32, 32), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - seed=0, - ) - - -@pytest.fixture -def multi_source_dataset(tmp_path): - """Dataset with two input arrays and two target arrays (four CellMapImage - objects), so the parallel bounding_box / sampling_box paths receive more - than one source to map over.""" - config = create_test_dataset( - tmp_path, - raw_shape=(32, 32, 32), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - seed=7, - ) - return CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={ - "raw_4nm": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}, - "raw_8nm": {"shape": (8, 8, 8), "scale": (8.0, 8.0, 8.0)}, - }, - target_arrays={ - "gt_4nm": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}, - "gt_8nm": {"shape": (8, 8, 8), "scale": (8.0, 8.0, 8.0)}, - }, - force_has_data=True, - ) - - -@pytest.fixture -def three_datasets(tmp_path): - datasets = [] - for i in range(3): - config = create_test_dataset( - tmp_path / f"ds_{i}", - raw_shape=(32, 32, 32), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - seed=i, - ) - ds = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - ) - datasets.append(ds) - return datasets - - -# --------------------------------------------------------------------------- -# force_has_data -# --------------------------------------------------------------------------- - - -class TestForceHasData: - """force_has_data=True should set has_data=True at construction time - without ever accessing CellMapImage.class_counts.""" - - def test_has_data_true_when_force_set(self, single_dataset_config): - """has_data is True immediately after __init__ when force_has_data=True.""" - config = single_dataset_config - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - assert dataset.has_data is True - - def test_class_counts_not_accessed_when_force_has_data(self, single_dataset_config): - """CellMapImage.class_counts must never be accessed during __init__ - when force_has_data=True.""" - config = single_dataset_config - with patch.object( - CellMapImage, "class_counts", new_callable=PropertyMock, return_value=100.0 - ) as mock_counts: - CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - mock_counts.assert_not_called() - - def test_class_counts_accessed_without_force_has_data(self, single_dataset_config): - """Without force_has_data, class_counts IS accessed (inverse check).""" - config = single_dataset_config - with patch.object( - CellMapImage, "class_counts", new_callable=PropertyMock, return_value=100.0 - ) as mock_counts: - CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=False, - ) - mock_counts.assert_called() - - def test_has_data_false_without_force_for_empty_data(self, tmp_path): - """Without force_has_data and with all-zero target data, has_data=False.""" - import numpy as np - from .test_helpers import create_test_zarr_array, create_test_image_data - - # Raw array - raw_data = create_test_image_data((16, 16, 16), pattern="random") - create_test_zarr_array(tmp_path / "dataset.zarr" / "raw", raw_data) - - # All-zero target → class_counts == 0 → has_data stays False - zero_data = np.zeros((16, 16, 16), dtype=np.uint8) - create_test_zarr_array( - tmp_path / "dataset.zarr" / "class_0", - zero_data, - absent=zero_data.size, # all absent - ) - - dataset = CellMapDataset( - raw_path=str(tmp_path / "dataset.zarr" / "raw"), - target_path=str(tmp_path / "dataset.zarr" / "[class_0]"), - classes=["class_0"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (1.0, 1.0, 1.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (1.0, 1.0, 1.0)}}, - force_has_data=False, - ) - assert not dataset.has_data - - -# --------------------------------------------------------------------------- -# bounding_box / sampling_box parallel computation -# --------------------------------------------------------------------------- - - -class TestParallelBoundingBox: - """bounding_box and sampling_box must give correct results when computed - in parallel across multiple CellMapImage sources.""" - - def test_bounding_box_correct_with_multiple_sources(self, multi_source_dataset): - bbox = multi_source_dataset.bounding_box - assert isinstance(bbox, dict) - for axis in multi_source_dataset.axis_order: - assert axis in bbox - lo, hi = bbox[axis] - assert lo <= hi - - def test_sampling_box_correct_with_multiple_sources(self, multi_source_dataset): - sbox = multi_source_dataset.sampling_box - assert isinstance(sbox, dict) - for axis in multi_source_dataset.axis_order: - assert axis in sbox - assert len(sbox[axis]) == 2 - - def test_bounding_box_consistent_with_single_source(self, single_dataset_config): - """Sequential vs. parallel should yield the same bounding box.""" - config = single_dataset_config - - def make_dataset(): - return CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - - ds1 = make_dataset() - ds2 = make_dataset() - - bbox1 = ds1.bounding_box - bbox2 = ds2.bounding_box - - for axis in ds1.axis_order: - assert pytest.approx(bbox1[axis][0]) == bbox2[axis][0] - assert pytest.approx(bbox1[axis][1]) == bbox2[axis][1] - - def test_sampling_box_inside_bounding_box(self, multi_source_dataset): - """The sampling box must be a sub-region of (or equal to) the bounding box.""" - bbox = multi_source_dataset.bounding_box - sbox = multi_source_dataset.sampling_box - for axis in multi_source_dataset.axis_order: - assert sbox[axis][0] >= bbox[axis][0] - 1e-9 - assert sbox[axis][1] <= bbox[axis][1] + 1e-9 - - def test_bounding_box_pool_does_not_leak_threads(self, single_dataset_config): - """Accessing bounding_box twice on fresh datasets should not raise even - if the pool from the first call was already shut down.""" - config = single_dataset_config - - for _ in range(2): - ds = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - bbox = ds.bounding_box - assert bbox is not None - - -# --------------------------------------------------------------------------- -# CellMapMultiDataset.class_counts parallel execution -# --------------------------------------------------------------------------- - - -class TestMultiDatasetClassCountsParallel: - """Parallel class_counts must aggregate correctly and behave robustly.""" - - def test_totals_equal_sum_of_individual_datasets(self, three_datasets): - """Aggregated totals must equal the element-wise sum of each dataset's - class_counts["totals"].""" - classes = ["class_0", "class_1"] - multi = CellMapMultiDataset( - classes=classes, - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=three_datasets, - ) - - # Compute expected totals by summing across individual datasets. - expected: dict[str, float] = {c: 0.0 for c in classes} - expected.update({c + "_bg": 0.0 for c in classes}) - for ds in three_datasets: - for key in expected: - expected[key] += ds.class_counts["totals"].get(key, 0.0) - - actual = multi.class_counts["totals"] - for key, val in expected.items(): - assert pytest.approx(actual[key], rel=1e-6) == val - - def test_class_counts_has_totals_key(self, three_datasets): - multi = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=three_datasets, - ) - counts = multi.class_counts - assert "totals" in counts - for c in ["class_0", "class_1", "class_0_bg", "class_1_bg"]: - assert c in counts["totals"] - - def test_exception_from_dataset_propagates(self, three_datasets): - """If any dataset's class_counts raises, the exception must propagate - out of CellMapMultiDataset.class_counts (via future.result()).""" - multi = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=three_datasets, - ) - - with patch.object( - CellMapDataset, - "class_counts", - new_callable=PropertyMock, - side_effect=RuntimeError("simulated failure"), - ): - with pytest.raises(RuntimeError, match="simulated failure"): - _ = multi.class_counts - - def test_max_workers_env_var_respected(self, three_datasets, monkeypatch): - """CELLMAP_MAX_WORKERS is the cap on the number of worker threads.""" - monkeypatch.setenv("CELLMAP_MAX_WORKERS", "1") - - multi = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=three_datasets, - ) - # Should still produce correct results with a single worker - counts = multi.class_counts - assert "totals" in counts - - def test_single_dataset_multidataset(self, tmp_path): - """Edge case: a multi-dataset with one child returns that child's counts.""" - config = create_test_dataset( - tmp_path, raw_shape=(32, 32, 32), num_classes=2, raw_scale=(4.0, 4.0, 4.0) - ) - ds = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - ) - multi = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=[ds], - ) - multi_totals = multi.class_counts["totals"] - ds_totals = ds.class_counts["totals"] - for c in ["class_0", "class_1"]: - assert pytest.approx(multi_totals[c]) == ds_totals.get(c, 0.0) - assert pytest.approx(multi_totals[c + "_bg"]) == ds_totals.get( - c + "_bg", 0.0 - ) - - def test_empty_classes_list(self, three_datasets): - """An empty classes list produces an empty totals dict without error.""" - multi = CellMapMultiDataset( - classes=[], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={}, - datasets=three_datasets, - ) - counts = multi.class_counts - assert counts["totals"] == {} - - -# --------------------------------------------------------------------------- -# _ImmediateExecutor unit tests -# --------------------------------------------------------------------------- - - -class TestImmediateExecutor: - """Unit tests for _ImmediateExecutor. - - _ImmediateExecutor is the Windows+TensorStore drop-in that runs every - submitted callable synchronously in the calling thread. It must satisfy - the same interface as ThreadPoolExecutor so all existing call sites - (submit, map, as_completed, shutdown) work without modification. - """ - - @pytest.fixture - def executor(self): - from cellmap_data.dataset import _ImmediateExecutor - - return _ImmediateExecutor() - - def test_submit_executes_synchronously(self, executor): - """submit() runs the callable before returning; the future is already done.""" - calls = [] - future = executor.submit(calls.append, 99) - assert future.done(), "Future should be resolved immediately" - assert calls == [99], "Callable should have run synchronously" - - def test_submit_returns_correct_result(self, executor): - """submit() stores the return value in the future.""" - future = executor.submit(lambda x, y: x + y, 3, 4) - assert future.result() == 7 - - def test_submit_captures_exception(self, executor): - """Exceptions raised by the callable are stored, not propagated.""" - future = executor.submit(lambda: 1 / 0) - assert future.exception() is not None - assert isinstance(future.exception(), ZeroDivisionError) - - def test_map_returns_results_in_order(self, executor): - """map() returns results in the same order as the input iterable.""" - results = list(executor.map(lambda x: x * 2, [1, 2, 3, 4])) - assert results == [2, 4, 6, 8] - - def test_map_with_lambda(self, executor): - """map() works with lambda functions, matching the bounding_box usage.""" - items = [{"v": i} for i in range(5)] - results = list(executor.map(lambda d: d["v"], items)) - assert results == list(range(5)) - - def test_map_propagates_exception(self, executor): - """Exceptions from map() propagate when the result is consumed.""" - with pytest.raises(ZeroDivisionError): - list(executor.map(lambda x: 1 / x, [1, 0, 2])) - - def test_shutdown_is_noop(self, executor): - """shutdown() must not raise even when called multiple times.""" - executor.shutdown(wait=True) - executor.shutdown(wait=False, cancel_futures=True) - - def test_as_completed_works_with_submit(self, executor): - """Futures from submit() are compatible with as_completed().""" - from concurrent.futures import as_completed - - futures = [executor.submit(lambda i=i: i * 3, i) for i in range(5)] - results = {f.result() for f in as_completed(futures)} - assert results == {0, 3, 6, 9, 12} - - def test_is_executor_subclass(self): - """_ImmediateExecutor must be a subclass of concurrent.futures.Executor - so it satisfies the Executor interface including map().""" - from concurrent.futures import Executor - - from cellmap_data.dataset import _ImmediateExecutor - - assert issubclass(_ImmediateExecutor, Executor) - - -# --------------------------------------------------------------------------- -# Immediate executor code paths (simulated via monkeypatching) -# --------------------------------------------------------------------------- - - -class TestImmediateExecutorPaths: - """Verify that bounding_box, sampling_box, and CellMapMultiDataset.class_counts - work correctly when _USE_IMMEDIATE_EXECUTOR is True. - - These tests simulate the Windows+TensorStore environment on any platform - by monkeypatching the module-level flag and singleton executor. - """ - - @pytest.fixture - def patched_immediate(self, monkeypatch): - """Patch dataset module to act as if running on Windows+TensorStore.""" - import cellmap_data.dataset as ds_module - from cellmap_data.dataset import _ImmediateExecutor - - monkeypatch.setattr(ds_module, "_USE_IMMEDIATE_EXECUTOR", True) - monkeypatch.setattr(ds_module, "_IMMEDIATE_EXECUTOR", _ImmediateExecutor()) - - def test_bounding_box_uses_immediate_executor( - self, single_dataset_config, patched_immediate - ): - """bounding_box must work via executor.map() when using _ImmediateExecutor.""" - config = single_dataset_config - ds = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - from cellmap_data.dataset import _ImmediateExecutor - - assert isinstance(ds.executor, _ImmediateExecutor) - bbox = ds.bounding_box - assert isinstance(bbox, dict) - for axis in ds.axis_order: - assert axis in bbox - lo, hi = bbox[axis] - assert lo <= hi - - def test_sampling_box_uses_immediate_executor( - self, single_dataset_config, patched_immediate - ): - """sampling_box must work via executor.map() when using _ImmediateExecutor.""" - config = single_dataset_config - ds = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - sbox = ds.sampling_box - assert isinstance(sbox, dict) - for axis in ds.axis_order: - assert axis in sbox - assert len(sbox[axis]) == 2 - - def test_getitem_uses_immediate_executor( - self, single_dataset_config, patched_immediate - ): - """__getitem__ must work when _ImmediateExecutor is active.""" - config = single_dataset_config - ds = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - result = ds[0] - assert isinstance(result, dict) - assert "raw" in result - - def test_multidataset_class_counts_is_sequential(self, three_datasets): - """CellMapMultiDataset.class_counts runs sequentially (no thread pool) - and returns well-formed totals.""" - multi = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=three_datasets, - ) - counts = multi.class_counts - assert "totals" in counts - for c in ["class_0", "class_1", "class_0_bg", "class_1_bg"]: - assert c in counts["totals"] diff --git a/tests/test_integration.py b/tests/test_integration.py deleted file mode 100644 index 9c6354f..0000000 --- a/tests/test_integration.py +++ /dev/null @@ -1,447 +0,0 @@ -""" -Integration tests for complete workflows. - -Tests end-to-end workflows combining multiple components. -""" - -import torch -import torchvision.transforms.v2 as T - -from cellmap_data import ( - CellMapDataLoader, - CellMapDataset, - CellMapDataSplit, - CellMapMultiDataset, -) -from cellmap_data.transforms import Binarize, GaussianNoise - -from .test_helpers import create_test_dataset - - -class TestTrainingWorkflow: - """Integration tests for complete training workflows.""" - - def test_basic_training_setup(self, tmp_path): - """Test basic training pipeline setup.""" - # Create dataset - config = create_test_dataset( - tmp_path, - raw_shape=(64, 64, 64), - num_classes=3, - raw_scale=(8.0, 8.0, 8.0), - ) - - # Configure arrays - input_arrays = {"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}} - target_arrays = {"gt": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}} - - # Configure transforms - spatial_transforms = { - "mirror": {"axes": {"x": 0.5, "y": 0.5}}, - "rotate": {"axes": {"z": [-45, 45]}}, - } - - raw_transforms = T.Compose( - [ - T.ToDtype(torch.float, scale=True), - GaussianNoise(std=0.05), - ] - ) - - target_transforms = T.Compose( - [ - Binarize(threshold=0.5), - ] - ) - - # Create dataset - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - spatial_transforms=spatial_transforms, - raw_value_transforms=raw_transforms, - target_value_transforms=target_transforms, - is_train=True, - force_has_data=True, - ) - - # Create loader - loader = CellMapDataLoader( - dataset, - batch_size=4, - num_workers=0, - weighted_sampler=True, - ) - - assert dataset is not None - assert loader is not None - - def test_train_validation_split_workflow(self, tmp_path): - """Test complete train/validation split workflow.""" - # Create training and validation datasets - train_config = create_test_dataset( - tmp_path / "train", - raw_shape=(64, 64, 64), - num_classes=2, - seed=42, - ) - - val_config = create_test_dataset( - tmp_path / "val", - raw_shape=(64, 64, 64), - num_classes=2, - seed=100, - ) - - # Configure dataset split - dataset_dict = { - "train": [{"raw": train_config["raw_path"], "gt": train_config["gt_path"]}], - "validate": [{"raw": val_config["raw_path"], "gt": val_config["gt_path"]}], - } - - input_arrays = {"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}} - target_arrays = {"gt": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}} - - # Training transforms - spatial_transforms = { - "mirror": {"axes": {"x": 0.5}}, - } - - datasplit = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays=input_arrays, - target_arrays=target_arrays, - spatial_transforms=spatial_transforms, - pad=True, - ) - - assert datasplit is not None - - def test_multi_dataset_training(self, tmp_path): - """Test training with multiple datasets.""" - # Create multiple datasets - configs = [] - datasets = [] - - for i in range(3): - config = create_test_dataset( - tmp_path / f"dataset_{i}", - raw_shape=(48, 48, 48), - num_classes=2, - seed=42 + i, - ) - configs.append(config) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"gt": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - is_train=True, - force_has_data=True, - ) - datasets.append(dataset) - - # Combine into multi-dataset - multi_dataset = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"gt": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - datasets=datasets, - ) - - # Create loader - loader = CellMapDataLoader( - multi_dataset, - batch_size=4, - num_workers=0, - weighted_sampler=True, - ) - - assert len(multi_dataset.datasets) == 3 - assert loader is not None - - def test_multiscale_training_setup(self, tmp_path): - """Test training with multiscale inputs.""" - config = create_test_dataset( - tmp_path, - raw_shape=(64, 64, 64), - num_classes=2, - ) - - # Multiple scales - input_arrays = { - "raw_4nm": {"shape": (32, 32, 32), "scale": (4.0, 4.0, 4.0)}, - "raw_8nm": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}, - } - - target_arrays = {"gt": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}} - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - ) - - loader = CellMapDataLoader(dataset, batch_size=2, num_workers=0) - - assert "raw_4nm" in dataset.input_arrays - assert "raw_8nm" in dataset.input_arrays - assert loader is not None - - -class TestTransformPipeline: - """Integration tests for transform pipelines.""" - - def test_complete_augmentation_pipeline(self, tmp_path): - """Test complete augmentation pipeline.""" - from cellmap_data.transforms import ( - Binarize, - GaussianNoise, - NaNtoNum, - RandomContrast, - RandomGamma, - ) - - config = create_test_dataset( - tmp_path, - raw_shape=(48, 48, 48), - num_classes=2, - ) - - # Complex transform pipeline - raw_transforms = T.Compose( - [ - NaNtoNum({"nan": 0.0}), - T.ToDtype(torch.float, scale=True), - GaussianNoise(std=0.05), - RandomContrast(contrast_range=(0.8, 1.2)), - RandomGamma(gamma_range=(0.8, 1.2)), - ] - ) - - target_transforms = T.Compose( - [ - Binarize(threshold=0.5), - T.ToDtype(torch.float32), - ] - ) - - # Spatial transforms must come first - spatial_transforms = { - "mirror": {"axes": {"x": 0.5, "y": 0.5, "z": 0.2}}, - "rotate": {"axes": {"z": [-180, 180]}}, - "transpose": {"axes": ["x", "y"]}, - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"gt": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - spatial_transforms=spatial_transforms, - raw_value_transforms=raw_transforms, - target_value_transforms=target_transforms, - is_train=True, - force_has_data=True, - ) - - loader = CellMapDataLoader(dataset, batch_size=2, num_workers=0) - - assert dataset.spatial_transforms is not None - assert dataset.raw_value_transforms is not None - assert loader is not None - - def test_per_target_transforms(self, tmp_path): - """Test different transforms per target array.""" - config = create_test_dataset( - tmp_path, - raw_shape=(48, 48, 48), - num_classes=2, - ) - - # Different transforms for different targets - target_transforms = { - "labels": T.Compose([Binarize(threshold=0.5)]), - "distances": T.Compose([T.ToDtype(torch.float, scale=True)]), - } - - target_arrays = { - "labels": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}, - "distances": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}, - } - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays=target_arrays, - target_value_transforms=target_transforms, - ) - - assert dataset.target_value_transforms is not None - - -class TestDataLoaderOptimization: - """Integration tests for data loader optimizations.""" - - def test_memory_optimization_settings(self, tmp_path): - """Test memory-optimized loader configuration.""" - config = create_test_dataset( - tmp_path, - raw_shape=(64, 64, 64), - num_classes=2, - ) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"gt": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - ) - - # Optimized loader settings - loader = CellMapDataLoader( - dataset, - batch_size=8, - num_workers=2, - pin_memory=True, - persistent_workers=True, - prefetch_factor=4, - ) - - assert loader is not None - - def test_weighted_sampling_integration(self, tmp_path): - """Test weighted sampling for class balance.""" - config = create_test_dataset( - tmp_path, - raw_shape=(64, 64, 64), - num_classes=3, - label_pattern="regions", # Creates imbalanced classes - ) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"gt": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - is_train=True, - force_has_data=True, - ) - - # Use weighted sampler to balance classes - loader = CellMapDataLoader( - dataset, - batch_size=4, - num_workers=0, - weighted_sampler=True, - ) - - assert loader is not None - - def test_iterations_per_epoch_large_dataset(self, tmp_path): - """Test limited iterations for large datasets.""" - config = create_test_dataset( - tmp_path, - raw_shape=(128, 128, 128), # Larger dataset - num_classes=2, - ) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"gt": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)}}, - ) - - # Limit iterations per epoch - loader = CellMapDataLoader( - dataset, - batch_size=4, - num_workers=0, - iterations_per_epoch=50, # Only 50 batches per epoch - ) - - assert loader is not None - - -class TestEdgeCases: - """Integration tests for edge cases and special scenarios.""" - - def test_small_dataset(self, tmp_path): - """Test with very small dataset.""" - config = create_test_dataset( - tmp_path, - raw_shape=(16, 16, 16), # Small - num_classes=2, - ) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - pad=True, # Need padding for small dataset - ) - - loader = CellMapDataLoader(dataset, batch_size=1, num_workers=0) - - assert dataset.pad is True - assert loader is not None - - def test_single_class(self, tmp_path): - """Test with single class.""" - config = create_test_dataset( - tmp_path, - raw_shape=(32, 32, 32), - num_classes=1, - ) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - target_arrays={"gt": {"shape": (16, 16, 16), "scale": (8.0, 8.0, 8.0)}}, - ) - - loader = CellMapDataLoader(dataset, batch_size=2, num_workers=0) - - assert len(dataset.classes) == 1 - assert loader is not None - - def test_anisotropic_data(self, tmp_path): - """Test with anisotropic voxel sizes.""" - config = create_test_dataset( - tmp_path, - raw_shape=(32, 64, 64), - raw_scale=(16.0, 4.0, 4.0), # Anisotropic - num_classes=2, - ) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (16, 32, 32), "scale": (16.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (16, 32, 32), "scale": (16.0, 4.0, 4.0)}}, - ) - - loader = CellMapDataLoader(dataset, batch_size=2, num_workers=0) - - assert dataset.input_arrays["raw"]["scale"] == (16.0, 4.0, 4.0) - assert loader is not None diff --git a/tests/test_memory_management.py b/tests/test_memory_management.py deleted file mode 100644 index 7cf889b..0000000 --- a/tests/test_memory_management.py +++ /dev/null @@ -1,217 +0,0 @@ -""" -Tests for memory management in CellMapImage. - -Specifically tests the array cache clearing mechanism to prevent memory leaks. -""" - -import pytest -from cellmap_data import CellMapImage -from .test_helpers import create_test_image_data, create_test_zarr_array - - -class TestMemoryManagement: - """Test memory management features.""" - - @pytest.fixture - def test_zarr_image(self, tmp_path): - """Create a test Zarr image.""" - data = create_test_image_data((32, 32, 32), pattern="gradient") - path = tmp_path / "test_image.zarr" - create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0)) - return str(path), data - - def test_array_cache_cleared_after_getitem(self, test_zarr_image): - """Test that array cache is cleared after __getitem__ to prevent memory leaks.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - axis_order="zyx", - ) - - # Access array to populate cache - _ = image.array - assert "array" in image.__dict__, "Array should be cached after first access" - - # Call __getitem__ which should clear the cache - center = {"z": 64.0, "y": 64.0, "x": 64.0} - _ = image[center] - - # Check that cache was cleared - assert ( - "array" not in image.__dict__ - ), "Array cache should be cleared after __getitem__" - - def test_array_cache_repopulates_after_clearing(self, test_zarr_image): - """Test that array cache can be repopulated after being cleared.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - axis_order="zyx", - ) - - # First access - center = {"z": 64.0, "y": 64.0, "x": 64.0} - data1 = image[center] - - # Array cache should be cleared - assert "array" not in image.__dict__ - - # Second access - should work without errors (cache will be repopulated) - data2 = image[center] - - # Both should produce valid tensors - assert data1.shape == data2.shape - assert data1.dtype == data2.dtype - - def test_clear_array_cache_method(self, test_zarr_image): - """Test the _clear_array_cache method directly.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - - # Populate cache - _ = image.array - assert "array" in image.__dict__ - - # Clear cache - image._clear_array_cache() - assert "array" not in image.__dict__ - - # Clearing when not cached should not raise an error - image._clear_array_cache() # Should be a no-op - - def test_multiple_getitem_calls_clear_cache_each_time(self, test_zarr_image): - """Test that cache is cleared on every __getitem__ call.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - - centers = [ - {"z": 48.0, "y": 48.0, "x": 48.0}, - {"z": 64.0, "y": 64.0, "x": 64.0}, - {"z": 80.0, "y": 80.0, "x": 80.0}, - ] - - for center in centers: - _ = image[center] - # Cache should be cleared after each call - assert ( - "array" not in image.__dict__ - ), f"Array cache should be cleared after accessing center {center}" - - def test_cache_clearing_with_spatial_transforms(self, test_zarr_image): - """Test that cache is cleared even with spatial transforms.""" - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - - # Set spatial transforms - image.set_spatial_transforms({"mirror": {"x": True}, "rotate": {"z": 15}}) - - center = {"z": 64.0, "y": 64.0, "x": 64.0} - _ = image[center] - - # Cache should still be cleared - assert "array" not in image.__dict__ - - def test_cache_clearing_with_value_transforms(self, test_zarr_image): - """Test that cache is cleared when value transforms are applied.""" - path, _ = test_zarr_image - - def normalize(x): - return x / 255.0 - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - value_transform=normalize, - ) - - center = {"z": 64.0, "y": 64.0, "x": 64.0} - _ = image[center] - - # Cache should be cleared - assert "array" not in image.__dict__ - - def test_simulated_training_loop_memory(self, test_zarr_image): - """ - Simulate a training loop to verify cache is cleared on each iteration. - - This test simulates the memory leak scenario described in the issue: - repeated calls to __getitem__ should not accumulate memory from cached arrays. - """ - path, _ = test_zarr_image - - image = CellMapImage( - path=path, - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - ) - - # Simulate multiple training iterations - centers = [ - {"z": 48.0 + i * 4.0, "y": 48.0 + i * 4.0, "x": 48.0 + i * 4.0} - for i in range(10) - ] - - for i, center in enumerate(centers): - _ = image[center] - - # After each iteration, array cache should be cleared - assert ( - "array" not in image.__dict__ - ), f"Iteration {i}: Array cache should be cleared to prevent memory leak" - - def test_cache_clearing_with_interpolation(self, tmp_path): - """ - Test cache clearing when interpolation is used (the main memory leak scenario). - - When coords require interpolation (not simple float/int), the array.interp() - method creates intermediate arrays that could accumulate memory. - """ - data = create_test_image_data((32, 32, 32), pattern="gradient") - path = tmp_path / "test_interp.zarr" - create_test_zarr_array(path, data, scale=(4.0, 4.0, 4.0)) - - image = CellMapImage( - path=str(path), - target_class="test", - target_scale=(4.0, 4.0, 4.0), - target_voxel_shape=(8, 8, 8), - interpolation="linear", # Use linear interpolation to trigger interp() - ) - - # Use spatial transforms to trigger the interpolation code path - image.set_spatial_transforms({"rotate": {"z": 15}}) - - center = {"z": 64.0, "y": 64.0, "x": 64.0} - _ = image[center] - - # Cache should be cleared even after interpolation - assert "array" not in image.__dict__ diff --git a/tests/test_metadata.py b/tests/test_metadata.py deleted file mode 100644 index ddb1f74..0000000 --- a/tests/test_metadata.py +++ /dev/null @@ -1,291 +0,0 @@ -""" -Tests for utils/metadata.py. - -Tests OME-NGFF metadata generation, writing, and scale-level lookup. -""" - -import json -import os - -import numpy as np -import pytest -import zarr - -from cellmap_data.utils.metadata import ( - add_multiscale_metadata_levels, - create_multiscale_metadata, - find_level, - generate_base_multiscales_metadata, - write_metadata, -) - - -class TestGenerateBaseMultiscalesMetadata: - """Tests for generate_base_multiscales_metadata.""" - - def test_basic_structure(self): - z_attrs = generate_base_multiscales_metadata( - ds_name="my_dataset", - scale_level=0, - voxel_size=[4.0, 4.0, 4.0], - translation=[0.0, 0.0, 0.0], - units="nanometer", - axes=["z", "y", "x"], - ) - assert "multiscales" in z_attrs - assert len(z_attrs["multiscales"]) == 1 - ms = z_attrs["multiscales"][0] - assert ms["version"] == "0.4" - - def test_axes_populated(self): - z_attrs = generate_base_multiscales_metadata( - ds_name="test", - scale_level=0, - voxel_size=[8.0, 8.0, 8.0], - translation=[0.0, 0.0, 0.0], - units="nanometer", - axes=["z", "y", "x"], - ) - axes = z_attrs["multiscales"][0]["axes"] - assert len(axes) == 3 - axis_names = [a["name"] for a in axes] - assert axis_names == ["z", "y", "x"] - for a in axes: - assert a["type"] == "space" - assert a["unit"] == "nanometer" - - def test_dataset_path_uses_scale_level(self): - z_attrs = generate_base_multiscales_metadata( - ds_name="test", - scale_level=2, - voxel_size=[16.0, 16.0, 16.0], - translation=[0.0, 0.0, 0.0], - units="nanometer", - axes=["z", "y", "x"], - ) - datasets = z_attrs["multiscales"][0]["datasets"] - assert datasets[0]["path"] == "s2" - - def test_voxel_size_stored(self): - voxel_size = [4.0, 8.0, 16.0] - z_attrs = generate_base_multiscales_metadata( - ds_name="test", - scale_level=0, - voxel_size=voxel_size, - translation=[0.0, 0.0, 0.0], - units="nanometer", - axes=["z", "y", "x"], - ) - datasets = z_attrs["multiscales"][0]["datasets"] - transforms = datasets[0]["coordinateTransformations"] - scale_transform = next(t for t in transforms if t.get("type") == "scale") - assert scale_transform["scale"] == voxel_size - - def test_zarr_suffix_stripped_from_name(self): - z_attrs = generate_base_multiscales_metadata( - ds_name="some_path/dataset.zarr/subgroup", - scale_level=0, - voxel_size=[4.0, 4.0, 4.0], - translation=[0.0, 0.0, 0.0], - units="nanometer", - axes=["z", "y", "x"], - ) - name = z_attrs["multiscales"][0]["name"] - assert ".zarr" not in name - - def test_name_stored(self): - z_attrs = generate_base_multiscales_metadata( - ds_name="my_group", - scale_level=0, - voxel_size=[4.0, 4.0, 4.0], - translation=[0.0, 0.0, 0.0], - units="nanometer", - axes=["z", "y", "x"], - ) - assert z_attrs["multiscales"][0]["name"] == "my_group" - - def test_2d_axes(self): - z_attrs = generate_base_multiscales_metadata( - ds_name="2d_test", - scale_level=0, - voxel_size=[4.0, 4.0], - translation=[0.0, 0.0], - units="nanometer", - axes=["y", "x"], - ) - axes = z_attrs["multiscales"][0]["axes"] - assert len(axes) == 2 - - -class TestAddMultiscaleMetadataLevels: - """Tests for add_multiscale_metadata_levels.""" - - @pytest.fixture - def base_metadata(self): - return generate_base_multiscales_metadata( - ds_name="test", - scale_level=0, - voxel_size=[4.0, 4.0, 4.0], - translation=[0.0, 0.0, 0.0], - units="nanometer", - axes=["z", "y", "x"], - ) - - def test_adds_correct_number_of_levels(self, base_metadata): - result = add_multiscale_metadata_levels(base_metadata, 0, 3) - datasets = result["multiscales"][0]["datasets"] - # Started with 1 level (s0), added 3 more (s1, s2, s3) - assert len(datasets) == 4 - - def test_added_paths_sequential(self, base_metadata): - result = add_multiscale_metadata_levels(base_metadata, 0, 2) - datasets = result["multiscales"][0]["datasets"] - paths = [d["path"] for d in datasets] - assert "s1" in paths - assert "s2" in paths - - def test_scale_formula(self, base_metadata): - # With base_scale_level=1, the added level uses pow(2, 1)=2, so scale doubles - result = add_multiscale_metadata_levels(base_metadata, 1, 1) - datasets = result["multiscales"][0]["datasets"] - s0_scale = datasets[0]["coordinateTransformations"][0]["scale"] - s1_scale = datasets[1]["coordinateTransformations"][0]["scale"] - # Formula: sn = dim * pow(2, level) where level=1 - for i in range(len(s0_scale)): - assert s1_scale[i] == pytest.approx(s0_scale[i] * 2, rel=1e-5) - - def test_zero_levels_adds_nothing(self, base_metadata): - original_count = len(base_metadata["multiscales"][0]["datasets"]) - result = add_multiscale_metadata_levels(base_metadata, 0, 0) - assert len(result["multiscales"][0]["datasets"]) == original_count - - -class TestCreateMultiscaleMetadata: - """Tests for create_multiscale_metadata.""" - - def test_returns_metadata_without_outpath(self): - result = create_multiscale_metadata( - ds_name="test", - voxel_size=[4.0, 4.0, 4.0], - translation=[0.0, 0.0, 0.0], - units="nanometer", - axes=["z", "y", "x"], - ) - assert result is not None - assert "multiscales" in result - - def test_with_extra_levels(self): - result = create_multiscale_metadata( - ds_name="test", - voxel_size=[4.0, 4.0, 4.0], - translation=[0.0, 0.0, 0.0], - units="nanometer", - axes=["z", "y", "x"], - levels_to_add=2, - ) - datasets = result["multiscales"][0]["datasets"] - assert len(datasets) == 3 - - def test_writes_to_file(self, tmp_path): - out_path = str(tmp_path / "zattrs.json") - result = create_multiscale_metadata( - ds_name="test", - voxel_size=[4.0, 4.0, 4.0], - translation=[0.0, 0.0, 0.0], - units="nanometer", - axes=["z", "y", "x"], - out_path=out_path, - ) - # When out_path given, should return None and write file - assert result is None - assert os.path.exists(out_path) - with open(out_path) as f: - data = json.load(f) - assert "multiscales" in data - - -class TestWriteMetadata: - """Tests for write_metadata.""" - - def test_writes_valid_json(self, tmp_path): - z_attrs = {"multiscales": [{"version": "0.4", "name": "test"}]} - out_path = str(tmp_path / "metadata.json") - write_metadata(z_attrs, out_path) - assert os.path.exists(out_path) - with open(out_path) as f: - loaded = json.load(f) - assert loaded == z_attrs - - def test_overwrites_existing_file(self, tmp_path): - out_path = str(tmp_path / "metadata.json") - write_metadata({"version": "old"}, out_path) - write_metadata({"version": "new"}, out_path) - with open(out_path) as f: - loaded = json.load(f) - assert loaded["version"] == "new" - - def test_indented_output(self, tmp_path): - z_attrs = {"multiscales": [{"version": "0.4"}]} - out_path = str(tmp_path / "indented.json") - write_metadata(z_attrs, out_path) - with open(out_path) as f: - content = f.read() - # Should be pretty-printed (indented) - assert "\n" in content - - -class TestFindLevel: - """Tests for find_level.""" - - @pytest.fixture - def multiscale_zarr(self, tmp_path): - """Create a Zarr group with multiple scale levels.""" - store = zarr.DirectoryStore(str(tmp_path / "test.zarr")) - root = zarr.group(store=store, overwrite=True) - - # Create two scale levels - root.create_dataset("s0", data=np.zeros((64, 64, 64), dtype=np.float32)) - root.create_dataset("s1", data=np.zeros((32, 32, 32), dtype=np.float32)) - - root.attrs["multiscales"] = [ - { - "version": "0.4", - "axes": [ - {"name": "z", "type": "space", "unit": "nanometer"}, - {"name": "y", "type": "space", "unit": "nanometer"}, - {"name": "x", "type": "space", "unit": "nanometer"}, - ], - "datasets": [ - { - "path": "s0", - "coordinateTransformations": [ - {"type": "scale", "scale": [4.0, 4.0, 4.0]}, - {"type": "translation", "translation": [0.0, 0.0, 0.0]}, - ], - }, - { - "path": "s1", - "coordinateTransformations": [ - {"type": "scale", "scale": [8.0, 8.0, 8.0]}, - {"type": "translation", "translation": [2.0, 2.0, 2.0]}, - ], - }, - ], - } - ] - return str(tmp_path / "test.zarr") - - def test_find_fine_level(self, multiscale_zarr): - # Target scale smaller than s0 -> should return s0 - level = find_level(multiscale_zarr, {"z": 2.0, "y": 2.0, "x": 2.0}) - assert level == "s0" - - def test_find_coarse_level(self, multiscale_zarr): - # Target scale between s0 and s1 -> should return s0 (last level not exceeding target) - level = find_level(multiscale_zarr, {"z": 6.0, "y": 6.0, "x": 6.0}) - assert level == "s0" - - def test_find_last_level(self, multiscale_zarr): - # Target scale larger than all levels -> should return last level - level = find_level(multiscale_zarr, {"z": 100.0, "y": 100.0, "x": 100.0}) - assert level == "s1" diff --git a/tests/test_multidataset.py b/tests/test_multidataset.py new file mode 100644 index 0000000..bb4b95c --- /dev/null +++ b/tests/test_multidataset.py @@ -0,0 +1,96 @@ +"""Tests for CellMapMultiDataset.""" + +from __future__ import annotations + +import numpy as np +import torch + +from cellmap_data import CellMapDataset, CellMapMultiDataset + +from .test_helpers import create_test_dataset + +INPUT_ARRAYS = {"raw": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} +TARGET_ARRAYS = {"labels": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} +CLASSES = ["mito", "er"] + + +def _make_ds(tmp_path, suffix="", **kwargs): + import tempfile, pathlib + + sub = tmp_path / suffix if suffix else tmp_path / "ds0" + sub.mkdir(parents=True, exist_ok=True) + info = create_test_dataset(sub, classes=CLASSES, **kwargs) + return CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=CLASSES, + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + pad=True, + ) + + +class TestCellMapMultiDataset: + def test_len_sum(self, tmp_path): + ds1 = _make_ds(tmp_path, "d1") + ds2 = _make_ds(tmp_path, "d2") + multi = CellMapMultiDataset([ds1, ds2], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + assert len(multi) == len(ds1) + len(ds2) + + def test_getitem_returns_dict(self, tmp_path): + ds1 = _make_ds(tmp_path, "d1") + multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + item = multi[0] + assert "raw" in item + assert "idx" in item + + def test_getitem_index_mapping(self, tmp_path): + ds1 = _make_ds(tmp_path, "d1") + ds2 = _make_ds(tmp_path, "d2") + n1 = len(ds1) + multi = CellMapMultiDataset([ds1, ds2], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + # Index 0 should come from ds1 + item0 = multi[0] + assert item0["idx"].item() == 0 + # Index n1 should come from ds2 with local idx 0 + item_n1 = multi[n1] + assert item_n1["idx"].item() == 0 + + def test_class_counts_keys(self, tmp_path): + ds1 = _make_ds(tmp_path, "d1") + multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + counts = multi.class_counts + assert "totals" in counts + assert all(c in counts["totals"] for c in CLASSES) + + def test_get_crop_class_matrix_shape(self, tmp_path): + ds1 = _make_ds(tmp_path, "d1") + ds2 = _make_ds(tmp_path, "d2") + multi = CellMapMultiDataset([ds1, ds2], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + mat = multi.get_crop_class_matrix() + assert mat.shape == (2, len(CLASSES)) # 1 row per dataset + + def test_validation_indices_non_empty(self, tmp_path): + ds1 = _make_ds(tmp_path, "d1") + multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + indices = multi.validation_indices + assert len(indices) > 0 + + def test_validation_indices_within_bounds(self, tmp_path): + ds1 = _make_ds(tmp_path, "d1") + ds2 = _make_ds(tmp_path, "d2") + multi = CellMapMultiDataset([ds1, ds2], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + for idx in multi.validation_indices: + assert 0 <= idx < len(multi) + + def test_repr(self, tmp_path): + ds1 = _make_ds(tmp_path, "d1") + multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + r = repr(multi) + assert "CellMapMultiDataset" in r + + def test_verify(self, tmp_path): + ds1 = _make_ds(tmp_path, "d1") + multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + assert multi.verify() diff --git a/tests/test_multidataset_datasplit.py b/tests/test_multidataset_datasplit.py deleted file mode 100644 index 649a455..0000000 --- a/tests/test_multidataset_datasplit.py +++ /dev/null @@ -1,821 +0,0 @@ -""" -Tests for CellMapMultiDataset and CellMapDataSplit classes. - -Tests combining multiple datasets and train/validation splits. -""" - -import csv -import os - -import pytest -import torch -import torchvision.transforms.v2 as T - -from cellmap_data import CellMapDataset, CellMapDataSplit, CellMapMultiDataset - -from .test_helpers import create_test_dataset - - -class TestCellMapMultiDataset: - """Test suite for CellMapMultiDataset class.""" - - @pytest.fixture - def multiple_datasets(self, tmp_path): - """Create multiple test datasets.""" - datasets = [] - - for i in range(3): - config = create_test_dataset( - tmp_path / f"dataset_{i}", - raw_shape=(32, 32, 32), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - seed=42 + i, - ) - - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - ) - datasets.append(dataset) - - return datasets - - def test_initialization_basic(self, multiple_datasets): - """Test basic MultiDataset initialization.""" - multi_dataset = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=multiple_datasets, - ) - - assert multi_dataset is not None - assert len(multi_dataset.datasets) == 3 - - def test_classes_parameter(self, multiple_datasets): - """Test classes parameter.""" - classes = ["class_0", "class_1", "class_2"] - - multi_dataset = CellMapMultiDataset( - classes=classes, - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=multiple_datasets, - ) - - assert multi_dataset.classes == classes - - def test_input_arrays_configuration(self, multiple_datasets): - """Test input arrays configuration.""" - input_arrays = { - "raw_4nm": {"shape": (16, 16, 16), "scale": (4.0, 4.0, 4.0)}, - "raw_8nm": {"shape": (8, 8, 8), "scale": (8.0, 8.0, 8.0)}, - } - - multi_dataset = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays=input_arrays, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=multiple_datasets, - ) - - assert "raw_4nm" in multi_dataset.input_arrays - assert "raw_8nm" in multi_dataset.input_arrays - - def test_target_arrays_configuration(self, multiple_datasets): - """Test target arrays configuration.""" - target_arrays = { - "labels": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}, - "distances": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}, - } - - multi_dataset = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays=target_arrays, - datasets=multiple_datasets, - ) - - assert "labels" in multi_dataset.target_arrays - assert "distances" in multi_dataset.target_arrays - - def test_empty_datasets_list(self): - """Test with empty datasets list.""" - with pytest.raises(ValueError): - CellMapDataSplit( - classes=["class_0"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets={"train": []}, - ) - - def test_single_dataset(self, multiple_datasets): - """Test with single dataset.""" - multi_dataset = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=[multiple_datasets[0]], - ) - - assert len(multi_dataset.datasets) == 1 - - def test_spatial_transforms(self, multiple_datasets): - """Test spatial transforms configuration.""" - spatial_transforms = { - "mirror": {"axes": {"x": 0.5, "y": 0.5}}, - "rotate": {"axes": {"z": [-45, 45]}}, - } - - datasplit = CellMapDataSplit( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets={"train": multiple_datasets}, - spatial_transforms=spatial_transforms, - force_has_data=True, - ) - - assert datasplit.spatial_transforms is not None - - -class TestCellMapDataSplit: - """Test suite for CellMapDataSplit class.""" - - @pytest.fixture - def datasplit_paths(self, tmp_path): - """Create paths for train and validation datasets.""" - # Create training datasets - train_configs = [] - for i in range(2): - config = create_test_dataset( - tmp_path / f"train_{i}", - raw_shape=(32, 32, 32), - num_classes=2, - seed=42 + i, - ) - train_configs.append(config) - - # Create validation datasets - val_configs = [] - for i in range(1): - config = create_test_dataset( - tmp_path / f"val_{i}", - raw_shape=(32, 32, 32), - num_classes=2, - seed=100 + i, - ) - val_configs.append(config) - - return train_configs, val_configs - - def test_initialization_with_dict(self, datasplit_paths): - """Test DataSplit initialization with dictionary.""" - train_configs, val_configs = datasplit_paths - - dataset_dict = { - "train": [ - {"raw": tc["raw_path"], "gt": tc["gt_path"]} for tc in train_configs - ], - "validate": [ - {"raw": vc["raw_path"], "gt": vc["gt_path"]} for vc in val_configs - ], - } - - datasplit = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - ) - - assert datasplit is not None - - def test_train_validation_split(self, datasplit_paths): - """Test accessing train and validation datasets.""" - train_configs, val_configs = datasplit_paths - - dataset_dict = { - "train": [ - {"raw": tc["raw_path"], "gt": tc["gt_path"]} for tc in train_configs - ], - "validate": [ - {"raw": vc["raw_path"], "gt": vc["gt_path"]} for vc in val_configs - ], - } - - datasplit = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - - # Should have train and validation datasets - assert hasattr(datasplit, "train_datasets") or hasattr( - datasplit, "train_datasets_combined" - ) - assert hasattr(datasplit, "validation_datasets") or hasattr( - datasplit, "validation_datasets_combined" - ) - - def test_classes_parameter(self, datasplit_paths): - """Test classes parameter.""" - train_configs, val_configs = datasplit_paths - - dataset_dict = { - "train": [ - {"raw": tc["raw_path"], "gt": tc["gt_path"]} for tc in train_configs - ], - "validate": [ - {"raw": vc["raw_path"], "gt": vc["gt_path"]} for vc in val_configs - ], - } - - classes = ["class_0", "class_1", "class_2"] - - datasplit = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=classes, - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - - assert datasplit.classes == classes - - def test_input_arrays_configuration(self, datasplit_paths): - """Test input arrays configuration.""" - train_configs, val_configs = datasplit_paths - - dataset_dict = { - "train": [ - {"raw": tc["raw_path"], "gt": tc["gt_path"]} for tc in train_configs - ], - "validate": [ - {"raw": vc["raw_path"], "gt": vc["gt_path"]} for vc in val_configs - ], - } - - input_arrays = { - "raw_4nm": {"shape": (16, 16, 16), "scale": (4.0, 4.0, 4.0)}, - "raw_8nm": {"shape": (8, 8, 8), "scale": (8.0, 8.0, 8.0)}, - } - - datasplit = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays=input_arrays, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - - assert datasplit.input_arrays is not None - - def test_spatial_transforms_configuration(self, datasplit_paths): - """Test spatial transforms configuration.""" - train_configs, val_configs = datasplit_paths - - dataset_dict = { - "train": [ - {"raw": tc["raw_path"], "gt": tc["gt_path"]} for tc in train_configs - ], - "validate": [ - {"raw": vc["raw_path"], "gt": vc["gt_path"]} for vc in val_configs - ], - } - - spatial_transforms = { - "mirror": {"axes": {"x": 0.5}}, - "rotate": {"axes": {"z": [-30, 30]}}, - } - - datasplit = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - spatial_transforms=spatial_transforms, - force_has_data=True, - ) - - assert datasplit is not None - - def test_only_train_split(self, datasplit_paths): - """Test with only training data.""" - train_configs, _ = datasplit_paths - - dataset_dict = { - "train": [ - {"raw": tc["raw_path"], "gt": tc["gt_path"]} for tc in train_configs - ], - } - - datasplit = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - - assert datasplit is not None - - def test_only_validation_split(self, datasplit_paths): - """Test with only validation data.""" - _, val_configs = datasplit_paths - - dataset_dict = { - "validate": [ - {"raw": vc["raw_path"], "gt": vc["gt_path"]} for vc in val_configs - ], - } - - datasplit = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - - assert datasplit is not None - - -class TestMultiDatasetIntegration: - """Integration tests for multi-dataset scenarios.""" - - def test_multi_dataset_with_loader(self, tmp_path): - """Test MultiDataset with DataLoader.""" - from cellmap_data import CellMapDataLoader - - # Create multiple datasets - datasets = [] - for i in range(2): - config = create_test_dataset( - tmp_path / f"dataset_{i}", - raw_shape=(24, 24, 24), - num_classes=2, - seed=42 + i, - ) - - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - ) - datasets.append(dataset) - - # Create MultiDataset - multi_dataset = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=datasets, - ) - - # Create loader - loader = CellMapDataLoader(multi_dataset, batch_size=2, num_workers=0) - - assert loader is not None - - def test_datasplit_with_loaders(self, tmp_path): - """Test DataSplit with separate train/val loaders.""" - - # Create datasets - train_config = create_test_dataset( - tmp_path / "train", - raw_shape=(24, 24, 24), - num_classes=2, - ) - val_config = create_test_dataset( - tmp_path / "val", - raw_shape=(24, 24, 24), - num_classes=2, - ) - - dataset_dict = { - "train": [{"raw": train_config["raw_path"], "gt": train_config["gt_path"]}], - "validate": [{"raw": val_config["raw_path"], "gt": val_config["gt_path"]}], - } - - datasplit = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - ) - - # DataSplit should be created successfully - assert datasplit is not None - - def test_different_resolution_datasets(self, tmp_path): - """Test combining datasets with different resolutions.""" - # Create datasets with different scales - config1 = create_test_dataset( - tmp_path / "dataset_4nm", - raw_shape=(32, 32, 32), - raw_scale=(4.0, 4.0, 4.0), - num_classes=2, - ) - - config2 = create_test_dataset( - tmp_path / "dataset_8nm", - raw_shape=(32, 32, 32), - raw_scale=(8.0, 8.0, 8.0), - num_classes=2, - ) - - datasets = [] - for config in [config1, config2]: - dataset = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - ) - datasets.append(dataset) - - # Create MultiDataset - multi_dataset = CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=datasets, - ) - - assert len(multi_dataset.datasets) == 2 - - -class TestCellMapMultiDatasetProperties: - """Tests for CellMapMultiDataset properties and methods not yet covered.""" - - @pytest.fixture - def multi_dataset(self, tmp_path): - """Build a CellMapMultiDataset from two real datasets.""" - datasets = [] - for i in range(2): - config = create_test_dataset( - tmp_path / f"ds_{i}", - raw_shape=(32, 32, 32), - num_classes=2, - raw_scale=(4.0, 4.0, 4.0), - seed=i, - ) - ds = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - ) - datasets.append(ds) - - return CellMapMultiDataset( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets=datasets, - ) - - def test_has_data_true(self, multi_dataset): - assert multi_dataset.has_data is True - - def test_class_counts_structure(self, multi_dataset): - counts = multi_dataset.class_counts - assert "totals" in counts - assert "class_0" in counts["totals"] - assert "class_1" in counts["totals"] - - def test_class_weights_keys(self, multi_dataset): - weights = multi_dataset.class_weights - assert "class_0" in weights - assert "class_1" in weights - for w in weights.values(): - assert w >= 0 - - def test_dataset_weights_keys(self, multi_dataset): - dw = multi_dataset.dataset_weights - # Should have one entry per dataset - assert len(dw) == len(multi_dataset.datasets) - for w in dw.values(): - assert w >= 0 - - def test_sample_weights_length(self, multi_dataset): - sw = multi_dataset.sample_weights - assert len(sw) == len(multi_dataset) - - def test_validation_indices_nonempty(self, multi_dataset): - indices = multi_dataset.validation_indices - assert isinstance(indices, list) - assert len(indices) > 0 - assert all(0 <= i < len(multi_dataset) for i in indices) - - def test_verify_true(self, multi_dataset): - assert multi_dataset.verify() is True - - def test_get_weighted_sampler(self, multi_dataset): - sampler = multi_dataset.get_weighted_sampler(batch_size=4) - assert sampler is not None - - def test_get_random_subset_indices(self, multi_dataset): - indices = multi_dataset.get_random_subset_indices(4, weighted=False) - assert len(indices) == 4 - - def test_get_random_subset_indices_weighted(self, multi_dataset): - indices = multi_dataset.get_random_subset_indices(4, weighted=True) - assert len(indices) == 4 - - def test_get_subset_random_sampler(self, multi_dataset): - sampler = multi_dataset.get_subset_random_sampler(4) - assert sampler is not None - - def test_get_indices(self, multi_dataset): - indices = multi_dataset.get_indices({"x": 8, "y": 8, "z": 8}) - assert isinstance(indices, list) - assert len(indices) > 0 - - def test_set_raw_value_transforms(self, multi_dataset): - new_transforms = T.Compose([T.ToDtype(torch.float, scale=True)]) - multi_dataset.set_raw_value_transforms(new_transforms) - - def test_set_target_value_transforms(self, multi_dataset): - new_transforms = T.Compose([T.ToDtype(torch.float)]) - multi_dataset.set_target_value_transforms(new_transforms) - - def test_set_spatial_transforms(self, multi_dataset): - transforms = {"mirror": {"axes": {"x": 0.5}}} - multi_dataset.set_spatial_transforms(transforms) - - def test_repr(self, multi_dataset): - r = repr(multi_dataset) - assert "CellMapMultiDataset" in r - - def test_empty_class_method(self): - empty = CellMapMultiDataset.empty() - assert empty is not None - assert empty.has_data is False - assert empty.classes == [] - assert empty.validation_indices == [] - - def test_verify_empty_returns_false(self): - empty = CellMapMultiDataset.empty() - assert empty.verify() is False - - def test_no_classes_dataset_weights(self, tmp_path): - """Dataset weights with no classes should give equal weights.""" - config = create_test_dataset(tmp_path / "ds", raw_shape=(32, 32, 32)) - ds = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - ) - multi = CellMapMultiDataset( - classes=[], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={}, - datasets=[ds], - ) - dw = multi.dataset_weights - assert list(dw.values())[0] == 1.0 - - -class TestCellMapDataSplitExtended: - """Extended tests for CellMapDataSplit.""" - - @pytest.fixture - def train_val_configs(self, tmp_path): - train = [] - for i in range(2): - train.append( - create_test_dataset( - tmp_path / f"train_{i}", - raw_shape=(32, 32, 32), - num_classes=2, - seed=i, - ) - ) - val = [ - create_test_dataset( - tmp_path / "val_0", - raw_shape=(32, 32, 32), - num_classes=2, - seed=99, - ) - ] - return train, val - - @pytest.fixture - def datasplit(self, train_val_configs): - train, val = train_val_configs - dataset_dict = { - "train": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in train], - "validate": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in val], - } - return CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - - def test_from_csv(self, tmp_path, train_val_configs): - """Test CellMapDataSplit.from_csv loads the dataset_dict correctly.""" - train, val = train_val_configs - - csv_path = str(tmp_path / "splits.csv") - rows = [] - for c in train: - raw_dir, raw_file = os.path.split(c["raw_path"]) - gt_dir, gt_file = os.path.split(c["gt_path"]) - rows.append(["train", raw_dir, raw_file, gt_dir, gt_file]) - for c in val: - raw_dir, raw_file = os.path.split(c["raw_path"]) - gt_dir, gt_file = os.path.split(c["gt_path"]) - rows.append(["validate", raw_dir, raw_file, gt_dir, gt_file]) - - with open(csv_path, "w", newline="") as f: - writer = csv.writer(f) - writer.writerows(rows) - - # Use from_csv via the constructor - split = CellMapDataSplit( - csv_path=csv_path, - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - ) - assert len(split.train_datasets) == 2 - assert len(split.validation_datasets) == 1 - - def test_from_csv_no_gt(self, tmp_path, train_val_configs): - """Test CSV rows without gt columns.""" - train, _ = train_val_configs - - csv_path = str(tmp_path / "splits_no_gt.csv") - rows = [] - for c in train: - raw_dir, raw_file = os.path.split(c["raw_path"]) - rows.append(["train", raw_dir, raw_file]) - - with open(csv_path, "w", newline="") as f: - writer = csv.writer(f) - writer.writerows(rows) - - # Direct call to from_csv - split = CellMapDataSplit( - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - force_has_data=True, - datasets={"train": []}, - ) - # Read CSV manually using the method - result = split.from_csv(csv_path) - assert "train" in result - assert len(result["train"]) == 2 - for entry in result["train"]: - assert entry["gt"] == "" - - def test_train_datasets_combined_property(self, datasplit): - combined = datasplit.train_datasets_combined - assert combined is not None - assert len(combined) > 0 - - def test_validation_datasets_combined_property(self, datasplit): - combined = datasplit.validation_datasets_combined - assert combined is not None - - def test_class_counts_property(self, datasplit): - counts = datasplit.class_counts - assert "train" in counts - assert "validate" in counts - - def test_repr(self, datasplit): - r = repr(datasplit) - assert "CellMapDataSplit" in r - - def test_no_source_raises(self): - """Providing no data source should raise ValueError.""" - with pytest.raises(ValueError, match="One of"): - CellMapDataSplit( - classes=["class_0"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - ) - - def test_set_raw_value_transforms(self, datasplit): - new_transform = T.Compose([T.ToDtype(torch.float, scale=True)]) - datasplit.set_raw_value_transforms( - train_transforms=new_transform, val_transforms=new_transform - ) - - def test_set_target_value_transforms(self, datasplit): - new_transform = T.Compose([T.ToDtype(torch.float)]) - datasplit.set_target_value_transforms(new_transform) - - def test_set_spatial_transforms(self, datasplit): - train_transforms = {"mirror": {"axes": {"x": 0.5}}} - datasplit.set_spatial_transforms(train_transforms=train_transforms) - - def test_set_raw_value_transforms_after_combined(self, datasplit): - """Test set_raw_value_transforms after train_datasets_combined is cached.""" - _ = datasplit.train_datasets_combined - new_transform = T.Compose([T.ToDtype(torch.float, scale=True)]) - datasplit.set_raw_value_transforms(train_transforms=new_transform) - - def test_set_target_value_transforms_after_combined(self, datasplit): - """Test set_target_value_transforms after combined datasets are cached.""" - _ = datasplit.train_datasets_combined - _ = datasplit.validation_datasets_combined - new_transform = T.Compose([T.ToDtype(torch.float)]) - datasplit.set_target_value_transforms(new_transform) - - def test_set_spatial_transforms_after_combined(self, datasplit): - """Test set_spatial_transforms after train_datasets_combined is cached.""" - _ = datasplit.train_datasets_combined - _ = datasplit.validation_datasets_combined - transforms = {"mirror": {"axes": {"x": 0.5}}} - datasplit.set_spatial_transforms( - train_transforms=transforms, val_transforms=transforms - ) - - def test_to_device(self, datasplit): - datasplit.to("cpu") - assert datasplit.device == "cpu" - - def test_to_device_after_combined(self, datasplit): - _ = datasplit.train_datasets_combined - _ = datasplit.validation_datasets_combined - datasplit.to("cpu") - - def test_pad_string_train(self, train_val_configs): - train, val = train_val_configs - dataset_dict = { - "train": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in train], - "validate": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in val], - } - split = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - pad="train", - force_has_data=True, - ) - assert split.pad_training is True - assert split.pad_validation is False - - def test_pad_string_validate(self, train_val_configs): - train, val = train_val_configs - dataset_dict = { - "train": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in train], - "validate": [{"raw": c["raw_path"], "gt": c["gt_path"]} for c in val], - } - split = CellMapDataSplit( - dataset_dict=dataset_dict, - classes=["class_0", "class_1"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - pad="validate", - force_has_data=True, - ) - assert split.pad_training is False - assert split.pad_validation is True - - def test_initialization_with_datasets_no_validate(self, tmp_path): - """Test providing datasets dict without validate key.""" - config = create_test_dataset(tmp_path / "ds", raw_shape=(32, 32, 32)) - ds = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - ) - split = CellMapDataSplit( - classes=config["classes"], - input_arrays={"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - target_arrays={"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}}, - datasets={"train": [ds]}, - force_has_data=True, - ) - assert split.validation_datasets == [] - - def test_validation_blocks_property(self, datasplit): - blocks = datasplit.validation_blocks - assert blocks is not None diff --git a/tests/test_mutable_sampler.py b/tests/test_mutable_sampler.py deleted file mode 100644 index e1220d1..0000000 --- a/tests/test_mutable_sampler.py +++ /dev/null @@ -1,279 +0,0 @@ -""" -Tests for MutableSubsetRandomSampler class. - -Tests weighted sampling and mutable subset functionality. -""" - -import numpy as np -import torch -from torch.utils.data import Dataset - -from cellmap_data import MutableSubsetRandomSampler - - -class DummyDataset(Dataset): - """Simple dummy dataset for testing samplers.""" - - def __init__(self, size=100): - self.size = size - self.data = torch.arange(size) - - def __len__(self): - return self.size - - def __getitem__(self, idx): - return self.data[idx] - - -class TestMutableSubsetRandomSampler: - """Test suite for MutableSubsetRandomSampler.""" - - def test_initialization_basic(self): - """Test basic sampler initialization.""" - indices = list(range(100)) - sampler = MutableSubsetRandomSampler(lambda: indices) - - assert sampler is not None - assert len(list(sampler)) > 0 - - def test_initialization_with_generator(self): - """Test sampler with custom generator.""" - indices = list(range(100)) - generator = torch.Generator() - generator.manual_seed(42) - - sampler = MutableSubsetRandomSampler(lambda: indices, rng=generator) - - assert sampler is not None - # Sample some indices - sample1 = list(sampler) - assert len(sample1) > 0 - - def test_reproducibility_with_seed(self): - """Test that same seed produces same sequence.""" - indices = list(range(100)) - - # First sampler - gen1 = torch.Generator() - gen1.manual_seed(42) - sampler1 = MutableSubsetRandomSampler(lambda: indices, rng=gen1) - samples1 = list(sampler1) - - # Second sampler with same seed - gen2 = torch.Generator() - gen2.manual_seed(42) - sampler2 = MutableSubsetRandomSampler(lambda: indices, rng=gen2) - samples2 = list(sampler2) - - # Should produce same sequence - assert samples1 == samples2 - - def test_different_seeds_produce_different_sequences(self): - """Test that different seeds produce different sequences.""" - indices = list(range(100)) - - # First sampler - gen1 = torch.Generator() - gen1.manual_seed(42) - sampler1 = MutableSubsetRandomSampler(lambda: indices, rng=gen1) - samples1 = list(sampler1) - - # Second sampler with different seed - gen2 = torch.Generator() - gen2.manual_seed(123) - sampler2 = MutableSubsetRandomSampler(lambda: indices, rng=gen2) - samples2 = list(sampler2) - - # Should produce different sequences - assert samples1 != samples2 - - def test_length(self): - """Test sampler length.""" - indices = list(range(50)) - sampler = MutableSubsetRandomSampler(lambda: indices) - - assert len(sampler) == 50 - - def test_iteration(self): - """Test iterating through sampler.""" - indices = list(range(20)) - sampler = MutableSubsetRandomSampler(lambda: indices) - - samples = list(sampler) - - # Should return all indices (in random order) - assert len(samples) == 20 - assert set(samples) == set(indices) - - def test_multiple_iterations(self): - """Test multiple iterations produce different orders.""" - indices = list(range(50)) - generator = torch.Generator() - generator.manual_seed(42) - sampler = MutableSubsetRandomSampler(lambda: indices, rng=generator) - - samples1 = list(sampler) - samples2 = list(sampler) - - # Each iteration should produce results - assert len(samples1) == 50 - assert len(samples2) == 50 - - # Orders may differ between iterations - # (depends on implementation) - - def test_subset_of_indices(self): - """Test sampler with subset of indices.""" - # Only sample from subset - all_indices = list(range(100)) - num_samples = 50 - subset_ind_gen = lambda: np.random.choice( - all_indices, num_samples, replace=False - ) - - sampler = MutableSubsetRandomSampler(subset_ind_gen) - samples = list(sampler) - - # All samples should be from subset - assert all(s in all_indices for s in samples) - assert len(samples) == num_samples - - def test_empty_indices(self): - """Test sampler with empty indices.""" - sampler = MutableSubsetRandomSampler(lambda: []) - samples = list(sampler) - - assert len(samples) == 0 - - def test_single_index(self): - """Test sampler with single index.""" - sampler = MutableSubsetRandomSampler(lambda: [42]) - samples = list(sampler) - - assert len(samples) == 1 - assert samples[0] == 42 - - def test_indices_mutation(self): - """Test that indices can be mutated.""" - indices = list(range(10)) - sampler = MutableSubsetRandomSampler(lambda: indices) - - # Get initial samples - samples1 = list(sampler) - assert len(samples1) == 10 - - # Mutate indices - new_indices = list(range(10, 20)) - sampler.indices_generator = lambda: new_indices - sampler.refresh() - - # New samples should be from new indices - samples2 = list(sampler) - assert all(s in new_indices for s in samples2) - - def test_use_with_dataloader(self): - """Test sampler integration with DataLoader.""" - from torch.utils.data import DataLoader - - dataset = DummyDataset(size=50) - indices = list(range(25)) # Only use first half - sampler = MutableSubsetRandomSampler(lambda: indices) - - loader = DataLoader(dataset, batch_size=5, sampler=sampler) - - # Should be able to iterate - batches = list(loader) - assert len(batches) > 0 - - # Should only see indices from sampler - all_indices = [] - for batch in batches: - all_indices.extend(batch.tolist()) - - assert all(idx in indices for idx in all_indices) - - def test_weighted_sampling_setup(self): - """Test setup for weighted sampling.""" - # Create indices with weights - indices = list(range(100)) - - # Could be used with weights (implementation specific) - sampler = MutableSubsetRandomSampler(lambda: indices) - - # Sampler should work - samples = list(sampler) - assert len(samples) == 100 - - def test_deterministic_ordering_with_seed(self): - """Test that seed makes ordering deterministic.""" - indices = list(range(30)) - - results = [] - for _ in range(3): - gen = torch.Generator() - gen.manual_seed(42) - sampler = MutableSubsetRandomSampler(indices, rng=gen) - results.append(list(sampler)) - - # All should be identical - assert results[0] == results[1] == results[2] - - def test_refresh_capability(self): - """Test that sampler can be refreshed.""" - indices = list(range(50)) - gen = torch.Generator() - sampler = MutableSubsetRandomSampler(indices, rng=gen) - - # Get first sampling - samples1 = list(sampler) - - # Get second sampling (may or may not be different) - samples2 = list(sampler) - - # Both should have correct length - assert len(samples1) == 50 - assert len(samples2) == 50 - - # Both should contain all indices - assert set(samples1) == set(indices) - assert set(samples2) == set(indices) - - -class TestWeightedSampling: - """Test weighted sampling scenarios.""" - - def test_balanced_sampling(self): - """Test balanced sampling across classes.""" - # Simulate class-balanced sampling - class_0_indices = list(range(0, 30)) # 30 samples - class_1_indices = list(range(30, 100)) # 70 samples - - # To balance, we might oversample class_0 - # For simplicity, just test that we can sample from both - all_indices = class_0_indices + class_1_indices - sampler = MutableSubsetRandomSampler(all_indices) - - samples = list(sampler) - - # Should include samples from both classes - assert any(s in class_0_indices for s in samples) - assert any(s in class_1_indices for s in samples) - - def test_stratified_indices(self): - """Test stratified sampling indices.""" - # Create stratified indices - strata = [ - list(range(0, 25)), # Stratum 1 - list(range(25, 50)), # Stratum 2 - list(range(50, 75)), # Stratum 3 - list(range(75, 100)), # Stratum 4 - ] - - # Sample from each stratum - for stratum_indices in strata: - sampler = MutableSubsetRandomSampler(stratum_indices) - samples = list(sampler) - - # All samples should be from this stratum - assert all(s in stratum_indices for s in samples) - assert len(samples) == len(stratum_indices) diff --git a/tests/test_sampler.py b/tests/test_sampler.py new file mode 100644 index 0000000..6b7f4b5 --- /dev/null +++ b/tests/test_sampler.py @@ -0,0 +1,88 @@ +"""Tests for ClassBalancedSampler.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from cellmap_data.sampler import ClassBalancedSampler + + +class FakeDataset: + """Minimal dataset with a known crop-class matrix.""" + + def __init__(self, matrix: np.ndarray): + self._matrix = matrix + + def get_crop_class_matrix(self) -> np.ndarray: + return self._matrix + + def __len__(self) -> int: + return self._matrix.shape[0] + + +class TestClassBalancedSampler: + def _make_sampler(self, matrix, samples_per_epoch=None, seed=42): + ds = FakeDataset(matrix) + return ClassBalancedSampler(ds, samples_per_epoch=samples_per_epoch, seed=seed) + + def test_basic_iteration(self): + matrix = np.array([[True, False], [False, True], [True, True]], dtype=bool) + sampler = self._make_sampler(matrix, samples_per_epoch=10) + indices = list(sampler) + assert len(indices) == 10 + assert all(0 <= i < 3 for i in indices) + + def test_len(self): + matrix = np.eye(4, dtype=bool) + sampler = self._make_sampler(matrix, samples_per_epoch=20) + assert len(sampler) == 20 + + def test_default_samples_per_epoch(self): + matrix = np.eye(3, dtype=bool) + ds = FakeDataset(matrix) + sampler = ClassBalancedSampler(ds) + assert len(sampler) == 3 + + def test_reset_between_epochs(self): + """Each __iter__ call resets counts → different random sequence.""" + matrix = np.array([[True, False], [False, True]], dtype=bool) + sampler = self._make_sampler(matrix, samples_per_epoch=6, seed=123) + epoch1 = list(sampler) + epoch2 = list(sampler) + # With small samples_per_epoch, deterministic resets should produce same result + # (counts reset → same greedy order from same RNG state if RNG is re-seeded each iter) + # We just check both are valid indices + assert all(0 <= i < 2 for i in epoch1) + assert all(0 <= i < 2 for i in epoch2) + + def test_rare_class_sampled(self): + """Class appearing in only 1 of 10 crops must still be sampled.""" + # Class 0 appears in all 10; class 1 appears only in crop 0 + matrix = np.zeros((10, 2), dtype=bool) + matrix[:, 0] = True + matrix[0, 1] = True + sampler = self._make_sampler(matrix, samples_per_epoch=20) + indices = list(sampler) + # Crop 0 must appear (it's the only way to see class 1) + assert 0 in indices + + def test_crop_class_matrix_stored(self): + matrix = np.eye(3, dtype=bool) + ds = FakeDataset(matrix) + sampler = ClassBalancedSampler(ds, samples_per_epoch=5) + assert np.array_equal(sampler.crop_class_matrix, matrix) + + def test_active_classes_only_annotated(self): + """Classes with zero crops must not be in active_classes.""" + # class 2 has no annotated crops + matrix = np.array([[True, False, False], [False, True, False]], dtype=bool) + sampler = self._make_sampler(matrix, samples_per_epoch=4) + assert 2 not in sampler.active_classes + + def test_yields_valid_indices_for_single_class(self): + matrix = np.ones((5, 1), dtype=bool) + sampler = self._make_sampler(matrix, samples_per_epoch=10) + indices = list(sampler) + assert len(indices) == 10 + assert all(0 <= i < 5 for i in indices) diff --git a/tests/test_subdataset.py b/tests/test_subdataset.py deleted file mode 100644 index ba638bb..0000000 --- a/tests/test_subdataset.py +++ /dev/null @@ -1,252 +0,0 @@ -"""Tests for CellMapSubset class.""" - -import pytest -import torch - -from cellmap_data import CellMapDataset, CellMapSubset -from cellmap_data.mutable_sampler import MutableSubsetRandomSampler - -from .test_helpers import create_minimal_test_dataset - - -class TestCellMapSubset: - """Test suite for CellMapSubset class.""" - - @pytest.fixture - def dataset_with_indices(self, tmp_path): - """Create a dataset and indices for subsetting.""" - config = create_minimal_test_dataset(tmp_path) - - input_arrays = { - "raw": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - target_arrays = { - "gt": { - "shape": (8, 8, 8), - "scale": (4.0, 4.0, 4.0), - } - } - - dataset = CellMapDataset( - raw_path=str(config["raw_path"]), - target_path=str(config["gt_path"]), - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - force_has_data=True, - ) - - # Create indices for subset - indices = [0, 2, 4, 6, 8] - return dataset, indices - - def test_initialization(self, dataset_with_indices): - """Test basic initialization of CellMapSubset.""" - dataset, indices = dataset_with_indices - - subset = CellMapSubset(dataset, indices) - - assert isinstance(subset, CellMapSubset) - assert subset.dataset is dataset - assert list(subset.indices) == indices - assert len(subset) == len(indices) - - def test_input_arrays_property(self, dataset_with_indices): - """Test that input_arrays property delegates to parent dataset.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - assert subset.input_arrays == dataset.input_arrays - assert "raw" in subset.input_arrays - - def test_target_arrays_property(self, dataset_with_indices): - """Test that target_arrays property delegates to parent dataset.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - assert subset.target_arrays == dataset.target_arrays - assert "gt" in subset.target_arrays - - def test_classes_property(self, dataset_with_indices): - """Test that classes property delegates to parent dataset.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - assert subset.classes == dataset.classes - assert len(subset.classes) > 0 - - def test_class_counts_property(self, dataset_with_indices): - """Test that class_counts property delegates to parent dataset.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - assert subset.class_counts == dataset.class_counts - assert isinstance(subset.class_counts, dict) - - def test_class_weights_property(self, dataset_with_indices): - """Test that class_weights property delegates to parent dataset.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - assert subset.class_weights == dataset.class_weights - assert isinstance(subset.class_weights, dict) - - def test_validation_indices_property(self, dataset_with_indices): - """Test that validation_indices property delegates to parent dataset.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - assert subset.validation_indices == dataset.validation_indices - - def test_to_device(self, dataset_with_indices): - """Test moving subset to different device.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - # Test moving to CPU - result = subset.to("cpu") - assert result is subset # Should return self - assert dataset.device.type == "cpu" - - def test_set_raw_value_transforms(self, dataset_with_indices): - """Test setting raw value transforms.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - transform = lambda x: x * 2 - subset.set_raw_value_transforms(transform) - - # Verify it was set on the parent dataset - # We can't directly test if it worked, but we can verify no error was raised - assert True - - def test_set_target_value_transforms(self, dataset_with_indices): - """Test setting target value transforms.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - transform = lambda x: x * 0.5 - subset.set_target_value_transforms(transform) - - # Verify it was set on the parent dataset - # We can't directly test if it worked, but we can verify no error was raised - assert True - - def test_get_random_subset_indices_without_replacement(self, dataset_with_indices): - """Test getting random subset indices when num_samples <= len(indices).""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - # Request fewer samples than available - num_samples = 3 - result_indices = subset.get_random_subset_indices(num_samples) - - assert len(result_indices) == num_samples - # All returned indices should be from the original subset indices - for idx in result_indices: - assert idx in indices - - def test_get_random_subset_indices_with_replacement(self, dataset_with_indices): - """Test getting random subset indices when num_samples > len(indices).""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - # Request more samples than available (requires replacement) - num_samples = 10 - with pytest.warns(UserWarning, match="Sampling with replacement"): - result_indices = subset.get_random_subset_indices(num_samples) - - assert len(result_indices) == num_samples - # All returned indices should be from the original subset indices - for idx in result_indices: - assert idx in indices - - def test_get_random_subset_indices_with_rng(self, dataset_with_indices): - """Test that get_random_subset_indices respects the RNG for reproducibility.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - rng1 = torch.Generator().manual_seed(42) - rng2 = torch.Generator().manual_seed(42) - - num_samples = 5 - result1 = subset.get_random_subset_indices(num_samples, rng=rng1) - result2 = subset.get_random_subset_indices(num_samples, rng=rng2) - - assert result1 == result2 # Same seed should give same results - - def test_get_subset_random_sampler(self, dataset_with_indices): - """Test creating a MutableSubsetRandomSampler from subset.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - num_samples = 5 - sampler = subset.get_subset_random_sampler(num_samples) - - assert isinstance(sampler, MutableSubsetRandomSampler) - # Sample from the sampler - sampled_indices = list(sampler) - assert len(sampled_indices) == num_samples - - def test_get_subset_random_sampler_with_rng(self, dataset_with_indices): - """Test that sampler respects RNG.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - rng1 = torch.Generator().manual_seed(123) - rng2 = torch.Generator().manual_seed(123) - - num_samples = 5 - sampler1 = subset.get_subset_random_sampler(num_samples, rng=rng1) - sampler2 = subset.get_subset_random_sampler(num_samples, rng=rng2) - - result1 = list(sampler1) - result2 = list(sampler2) - - assert result1 == result2 # Same seed should give same results - - def test_getitem_delegates_to_parent(self, dataset_with_indices): - """Test that __getitem__ properly delegates to parent dataset with mapped indices.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - # Get first item from subset (should be index 0 from original dataset) - item = subset[0] - - # Should return a dictionary with 'raw' and 'gt' keys - assert isinstance(item, dict) - assert "raw" in item - # The gt might not be present if force_has_data doesn't work as expected, - # but raw should always be there - - def test_subset_length(self, dataset_with_indices): - """Test that len() returns correct subset length.""" - dataset, indices = dataset_with_indices - subset = CellMapSubset(dataset, indices) - - assert len(subset) == len(indices) - assert len(subset) < len(dataset) - - def test_empty_subset(self, dataset_with_indices): - """Test creating a subset with no indices.""" - dataset, _ = dataset_with_indices - empty_indices = [] - - subset = CellMapSubset(dataset, empty_indices) - - assert len(subset) == 0 - assert list(subset.indices) == [] - - def test_single_index_subset(self, dataset_with_indices): - """Test creating a subset with a single index.""" - dataset, _ = dataset_with_indices - single_index = [0] - - subset = CellMapSubset(dataset, single_index) - - assert len(subset) == 1 - assert list(subset.indices) == single_index diff --git a/tests/test_transforms.py b/tests/test_transforms.py index f7d193c..4c0eada 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,13 +1,13 @@ -""" -Tests for augmentation transforms. +"""Tests for cellmap_data.transforms.augment — all transform classes.""" -Tests all augmentation transforms using real tensors without mocks. -""" +from __future__ import annotations + +import math import torch -import torchvision.transforms.v2 as T +import pytest -from cellmap_data.transforms import ( +from cellmap_data.transforms.augment import ( Binarize, GaussianBlur, GaussianNoise, @@ -16,365 +16,335 @@ RandomGamma, ) +# --------------------------------------------------------------------------- +# NaNtoNum +# --------------------------------------------------------------------------- -class TestGaussianNoise: - """Test suite for GaussianNoise transform.""" - - def test_gaussian_noise_basic(self): - """Test basic Gaussian noise addition.""" - torch.manual_seed(42) - transform = GaussianNoise(std=0.1) - x = torch.zeros(100, 100) - result = transform(x) +class TestNaNtoNum: + def test_import_path(self): + from cellmap_data.transforms.augment import NaNtoNum # noqa: F401 + + def test_nan_replaced_by_zero(self): + t = NaNtoNum({"nan": 0, "posinf": None, "neginf": None}) + x = torch.tensor([float("nan"), 1.0, 2.0]) + out = t(x) + assert not torch.isnan(out).any() + assert out[0] == 0.0 + + def test_nan_replaced_by_custom_value(self): + t = NaNtoNum({"nan": -1.0}) + x = torch.tensor([float("nan"), 5.0]) + out = t(x) + assert out[0] == pytest.approx(-1.0) + + def test_posinf_replaced(self): + t = NaNtoNum({"nan": 0, "posinf": 99.0, "neginf": None}) + x = torch.tensor([float("inf"), 1.0]) + out = t(x) + assert out[0] == pytest.approx(99.0) + + def test_neginf_replaced(self): + t = NaNtoNum({"nan": 0, "posinf": None, "neginf": -99.0}) + x = torch.tensor([float("-inf"), 1.0]) + out = t(x) + assert out[0] == pytest.approx(-99.0) + + def test_no_nans_unchanged(self): + t = NaNtoNum({"nan": 0, "posinf": None, "neginf": None}) + x = torch.tensor([1.0, 2.0, 3.0]) + out = t(x) + assert torch.allclose(out, x) + + def test_callable(self): + t = NaNtoNum({"nan": 0}) + x = torch.full((4, 4), float("nan")) + out = t(x) + assert (out == 0.0).all() + + def test_repr(self): + t = NaNtoNum({"nan": 0}) + assert "NaNtoNum" in repr(t) + + def test_transform_method_alias(self): + """Both __call__ and .transform() should work.""" + t = NaNtoNum({"nan": 42.0}) + x = torch.tensor([float("nan")]) + assert t.transform(x)[0] == pytest.approx(42.0) + + def test_used_as_in_api(self): + """Replicates the exact usage from API_TO_PRESERVE.md.""" + t = NaNtoNum({"nan": 0, "posinf": None, "neginf": None}) + x = torch.tensor([float("nan"), float("inf"), float("-inf"), 0.5]) + out = t(x) + assert out[0] == 0.0 + assert not torch.isnan(out).any() + + def test_3d_tensor(self): + t = NaNtoNum({"nan": 0}) + x = torch.full((4, 4, 4), float("nan")) + out = t(x) + assert not torch.isnan(out).any() + assert out.shape == torch.Size([4, 4, 4]) + + +# --------------------------------------------------------------------------- +# Binarize +# --------------------------------------------------------------------------- - # Result should be different from input - assert not torch.allclose(result, x) - # Noise should have approximately the right std - assert result.std() < 0.15 # Allow some tolerance - def test_gaussian_noise_preserves_shape(self): - """Test that Gaussian noise preserves shape.""" - transform = GaussianNoise(std=0.1) +class TestBinarize: + def test_import_path(self): + from cellmap_data.transforms.augment import Binarize # noqa: F401 + + def test_default_threshold_zero(self): + t = Binarize() + x = torch.tensor([-1.0, 0.0, 0.5, 1.0]) + out = t(x) + expected = torch.tensor([0.0, 0.0, 1.0, 1.0]) + assert torch.allclose(out, expected) + + def test_custom_threshold(self): + t = Binarize(0.5) + x = torch.tensor([0.0, 0.49, 0.5, 0.51, 1.0]) + out = t(x) + # > 0.5 → 1 + assert out[0] == 0.0 + assert out[1] == 0.0 + assert out[2] == 0.0 # 0.5 is NOT > 0.5 + assert out[3] == 1.0 + assert out[4] == 1.0 + + def test_nan_preserved_after_binarize(self): + """NaN in input → NaN in output (unknown class, not zero).""" + t = Binarize() + x = torch.tensor([float("nan"), 1.0, 0.0]) + out = t(x) + assert torch.isnan(out[0]) + assert out[1] == 1.0 + assert out[2] == 0.0 + + def test_repr(self): + t = Binarize(0.5) + assert "Binarize" in repr(t) + assert "0.5" in repr(t) + + def test_integer_input_binarize(self): + t = Binarize() + x = torch.tensor([0, 1, 2, -1], dtype=torch.int32) + out = t(x) + # > 0 → 1 + assert out[0] == 0 + assert out[1] == 1 + assert out[2] == 1 + + def test_transform_method_alias(self): + t = Binarize(0.5) + x = torch.tensor([0.0, 1.0]) + out = t.transform(x) + assert torch.allclose(out, torch.tensor([0.0, 1.0])) + + def test_used_as_in_api(self): + """Replicates the exact usage from API_TO_PRESERVE.md / datasplit default.""" + t = Binarize() + x = torch.tensor([0.0, 0.001, 0.5, 0.9, float("nan")]) + out = t(x) + assert out[0] == 0.0 # 0.0 not > 0 + assert out[1] == 1.0 # 0.001 > 0 + assert out[2] == 1.0 + assert out[3] == 1.0 + assert torch.isnan(out[4]) + + def test_shape_preserved(self): + t = Binarize() + x = torch.rand(3, 4, 5) + out = t(x) + assert out.shape == x.shape + + def test_with_import_compose(self): + """Compose(Binarize()) used as target_value_transforms in CellMapDataSplit.""" + import torchvision.transforms.v2 as T - shapes = [(10,), (10, 10), (5, 10, 10), (2, 5, 10, 10)] - for shape in shapes: - x = torch.rand(shape) - result = transform(x) - assert result.shape == x.shape + pipeline = T.Compose([T.ToDtype(torch.float), Binarize()]) + x = torch.tensor([0, 1, 2], dtype=torch.int32) + out = pipeline(x) + assert torch.allclose(out, torch.tensor([0.0, 1.0, 1.0])) - def test_gaussian_noise_zero_std(self): - """Test that zero std produces no change.""" - transform = GaussianNoise(std=0.0) - x = torch.rand(10, 10) - result = transform(x) - assert torch.allclose(result, x) +# --------------------------------------------------------------------------- +# GaussianNoise +# --------------------------------------------------------------------------- - def test_gaussian_noise_different_stds(self): - """Test different standard deviations.""" - torch.manual_seed(42) - x = torch.zeros(1000, 1000) - for std in [0.01, 0.1, 0.5, 1.0]: - transform = GaussianNoise(std=std) - result = transform(x.clone()) - # Empirical std should be close to specified std - assert abs(result.std().item() - std) < std * 0.2 # 20% tolerance +class TestGaussianNoise: + def test_output_shape(self): + t = GaussianNoise(mean=0.0, std=0.1) + x = torch.zeros(4, 4, 4) + out = t(x) + assert out.shape == x.shape + + def test_adds_noise(self): + """Output should differ from input (with high probability for nonzero std).""" + t = GaussianNoise(mean=0.0, std=1.0) + x = torch.zeros(100) + out = t(x) + assert not torch.allclose(out, x) + + def test_zero_std_unchanged(self): + t = GaussianNoise(mean=0.0, std=0.0) + x = torch.ones(10) + out = t(x) + assert torch.allclose(out, x) + + def test_mean_offset(self): + """High mean with large tensor → output mean close to input_mean + noise_mean.""" + t = GaussianNoise(mean=10.0, std=0.0) + x = torch.zeros(1000) + out = t(x) + assert out.mean().item() == pytest.approx(10.0, abs=0.5) + + +# --------------------------------------------------------------------------- +# RandomContrast +# --------------------------------------------------------------------------- class TestRandomContrast: - """Test suite for RandomContrast transform.""" - - def test_random_contrast_basic(self): - """Test basic random contrast adjustment.""" - torch.manual_seed(42) - transform = RandomContrast(contrast_range=(0.5, 1.5)) - - x = torch.linspace(0, 1, 100).reshape(10, 10) - result = transform(x) - - # Result should be different (with high probability) - assert result.shape == x.shape - - def test_random_contrast_preserves_shape(self): - """Test that random contrast preserves shape.""" - transform = RandomContrast(contrast_range=(0.8, 1.2)) - - shapes = [(10,), (10, 10), (5, 10, 10), (2, 5, 10, 10)] - for shape in shapes: - x = torch.rand(shape) - result = transform(x) - assert result.shape == x.shape - - def test_random_contrast_identity(self): - """Test that (1.0, 1.0) range produces identity.""" - transform = RandomContrast(contrast_range=(1.0, 1.0)) - + def test_output_shape(self): + t = RandomContrast((0.8, 1.2)) + x = torch.rand(4, 4, 4) + out = t(x) + assert out.shape == x.shape + + def test_output_dtype_preserved(self): + t = RandomContrast((0.9, 1.1)) + x = torch.rand(8, 8).float() + out = t(x) + assert out.dtype == x.dtype + + def test_no_nan_output(self): + t = RandomContrast((0.5, 1.5)) x = torch.rand(10, 10) - result = transform(x) - # With factor=1.0, output should be close to input - assert torch.allclose(result, x, atol=1e-5) + out = t(x) + assert not torch.isnan(out).any() - def test_random_contrast_range(self): - """Test that contrast is within specified range.""" - torch.manual_seed(42) - transform = RandomContrast(contrast_range=(0.5, 2.0)) + def test_clamped_to_dtype_max(self): + """Output should not exceed the max value for the dtype.""" + from cellmap_data.utils import torch_max_value - x = torch.linspace(0, 1, 100).reshape(10, 10) + t = RandomContrast((1.0, 1.0)) # identity contrast ratio + x = torch.rand(8, 8).float() + out = t(x) + assert (out <= torch_max_value(x.dtype) + 1e-6).all() - # Test multiple times to check randomness - results = [] - for _ in range(10): - result = transform(x.clone()) - results.append(result) - # Results should vary - assert not all(torch.allclose(results[0], r) for r in results[1:]) +# --------------------------------------------------------------------------- +# RandomGamma +# --------------------------------------------------------------------------- class TestRandomGamma: - """Test suite for RandomGamma transform.""" - - def test_random_gamma_basic(self): - """Test basic random gamma adjustment.""" - torch.manual_seed(42) - transform = RandomGamma(gamma_range=(0.5, 1.5)) - - x = torch.linspace(0, 1, 100).reshape(10, 10) - result = transform(x) - - assert result.shape == x.shape - assert result.min() >= 0.0 - assert result.max() <= 1.0 - - def test_random_gamma_preserves_shape(self): - """Test that random gamma preserves shape.""" - transform = RandomGamma(gamma_range=(0.8, 1.2)) - - shapes = [(10,), (10, 10), (5, 10, 10), (2, 5, 10, 10)] - for shape in shapes: - x = torch.rand(shape) - result = transform(x) - assert result.shape == x.shape - - def test_random_gamma_identity(self): - """Test that gamma=1.0 produces identity.""" - transform = RandomGamma(gamma_range=(1.0, 1.0)) - + def test_output_shape(self): + t = RandomGamma((0.5, 1.5)) + x = torch.rand(4, 4, 4) + out = t(x) + assert out.shape == x.shape + + def test_output_in_01(self): + """After gamma, float input in [0,1] → output in [0,1].""" + t = RandomGamma((0.5, 2.0)) + x = torch.rand(64) + out = t(x) + assert (out >= 0.0).all() + assert (out <= 1.0 + 1e-5).all() + + def test_no_nan_output(self): + t = RandomGamma((0.8, 1.2)) x = torch.rand(10, 10) - result = transform(x) - assert torch.allclose(result, x, atol=1e-5) - - def test_random_gamma_values(self): - """Test gamma effect on values.""" - torch.manual_seed(42) - x = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]) - - # Gamma < 1 should brighten mid-tones - transform_bright = RandomGamma(gamma_range=(0.5, 0.5)) - result_bright = transform_bright(x.clone()) - assert result_bright[2] > x[2] # Mid-tone should be brighter - - # Gamma > 1 should darken mid-tones - transform_dark = RandomGamma(gamma_range=(2.0, 2.0)) - result_dark = transform_dark(x.clone()) - assert result_dark[2] < x[2] # Mid-tone should be darker - - -class TestNaNtoNum: - """Test suite for NaNtoNum transform.""" - - def test_nan_to_num_basic(self): - """Test basic NaN replacement.""" - transform = NaNtoNum({"nan": 0.0}) - - x = torch.tensor([1.0, float("nan"), 3.0, float("nan"), 5.0]) - result = transform(x) - - expected = torch.tensor([1.0, 0.0, 3.0, 0.0, 5.0]) - assert torch.allclose(result, expected, equal_nan=False) - assert not torch.isnan(result).any() + out = t(x) + assert not torch.isnan(out).any() - def test_nan_to_num_inf(self): - """Test infinity replacement.""" - transform = NaNtoNum({"posinf": 1e6, "neginf": -1e6}) + def test_integer_input_converted(self): + """Integer input should be converted to float without error.""" + t = RandomGamma((0.9, 1.1)) + x = torch.randint(0, 256, (10,), dtype=torch.uint8) + out = t(x) # should not raise + assert out.dtype == torch.float32 - x = torch.tensor([1.0, float("inf"), -float("inf"), 3.0]) - result = transform(x) - expected = torch.tensor([1.0, 1e6, -1e6, 3.0]) - assert torch.allclose(result, expected) +# --------------------------------------------------------------------------- +# GaussianBlur +# --------------------------------------------------------------------------- - def test_nan_to_num_all_replacements(self): - """Test all replacements at once.""" - transform = NaNtoNum({"nan": 0.0, "posinf": 100.0, "neginf": -100.0}) - x = torch.tensor([float("nan"), float("inf"), -float("inf"), 1.0]) - result = transform(x) - - expected = torch.tensor([0.0, 100.0, -100.0, 1.0]) - assert torch.allclose(result, expected) - - def test_nan_to_num_preserves_valid_values(self): - """Test that valid values are preserved.""" - transform = NaNtoNum({"nan": 0.0}) - - x = torch.rand(10, 10) - result = transform(x) - assert torch.allclose(result, x) - - def test_nan_to_num_multidimensional(self): - """Test NaN replacement in multidimensional arrays.""" - transform = NaNtoNum({"nan": -1.0}) - - x = torch.rand(5, 10, 10) - x[2, 5, 5] = float("nan") - x[3, 7, 3] = float("nan") - - result = transform(x) - assert not torch.isnan(result).any() - assert result[2, 5, 5] == -1.0 - assert result[3, 7, 3] == -1.0 - - -class TestBinarize: - """Test suite for Binarize transform.""" - - def test_binarize_basic(self): - """Test basic binarization.""" - transform = Binarize(threshold=0.5) - - x = torch.tensor([0.0, 0.3, 0.5, 0.7, 1.0]) - result = transform(x) - - # Binarize uses > not >=, so 0.5 is NOT included - expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0]) - assert torch.allclose(result, expected) - - def test_binarize_different_thresholds(self): - """Test different threshold values.""" - x = torch.linspace(0, 1, 11) - - for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]: - transform = Binarize(threshold=threshold) - result = transform(x) - - # Check that values below or equal to threshold are 0, above are 1 - assert torch.all(result[x <= threshold] == 0.0) - assert torch.all(result[x > threshold] == 1.0) +class TestGaussianBlur: + def test_2d_output_shape(self): + t = GaussianBlur(kernel_size=3, sigma=1.0, dim=2) + x = torch.rand(8, 8) + out = t(x) + assert out.shape == x.shape - def test_binarize_preserves_shape(self): - """Test that binarize preserves shape.""" - transform = Binarize(threshold=0.5) + def test_3d_output_shape(self): + t = GaussianBlur(kernel_size=3, sigma=1.0, dim=3) + x = torch.rand(8, 8, 8) + out = t(x) + assert out.shape == x.shape - shapes = [(10,), (10, 10), (5, 10, 10), (2, 5, 10, 10)] - for shape in shapes: - x = torch.rand(shape) - result = transform(x) - assert result.shape == x.shape + def test_blurred_differs_from_input(self): + t = GaussianBlur(kernel_size=5, sigma=2.0, dim=2) + x = torch.rand(16, 16) + out = t(x) + assert not torch.allclose(out, x) - def test_binarize_output_values(self): - """Test that output only contains 0 and 1.""" - transform = Binarize(threshold=0.5) + def test_constant_input_unchanged(self): + """Blurring a constant field should return the same constant (approximately).""" + t = GaussianBlur(kernel_size=3, sigma=1.0, dim=2) + x = torch.ones(16, 16) + out = t(x) + assert torch.allclose(out, x, atol=1e-4) - x = torch.rand(100, 100) - result = transform(x) + def test_even_kernel_raises(self): + with pytest.raises(AssertionError): + GaussianBlur(kernel_size=4, dim=2) - unique_values = torch.unique(result) - assert len(unique_values) <= 2 - assert all(v in [0.0, 1.0] for v in unique_values.tolist()) + def test_invalid_dim_raises(self): + with pytest.raises(AssertionError): + GaussianBlur(dim=1) -class TestGaussianBlur: - """Test suite for GaussianBlur transform.""" - - def test_gaussian_blur_basic(self): - """Test basic Gaussian blur.""" - transform = GaussianBlur(sigma=1.0) - - # Create image with a single bright pixel - x = torch.zeros(21, 21) - x[10, 10] = 1.0 - - result = transform(x) - - # Blur should spread the value - assert result[10, 10] < 1.0 # Center should be less bright - assert result[9, 10] > 0.0 # Neighbors should have some value - assert result.sum() > 0.0 - - def test_gaussian_blur_preserves_shape(self): - """Test that Gaussian blur preserves shape.""" - # Test 2D - transform_2d = GaussianBlur(sigma=1.0, dim=2, channels=1) - x_2d = torch.rand(1, 10, 10) # Need channel dimension - result_2d = transform_2d(x_2d) - assert result_2d.shape == x_2d.shape - - # Test 3D - transform_3d = GaussianBlur(sigma=1.0, dim=3, channels=1) - x_3d = torch.rand(1, 5, 10, 10) # Need channel dimension - result_3d = transform_3d(x_3d) - assert result_3d.shape == x_3d.shape - - def test_gaussian_blur_different_sigmas(self): - """Test different sigma values.""" - x = torch.zeros(21, 21) - x[10, 10] = 1.0 - - results = [] - for sigma in [0.5, 1.0, 2.0, 3.0]: - transform = GaussianBlur(sigma=sigma) - result = transform(x.clone()) - results.append(result) - - # Larger sigma should produce more blur (lower peak) - peaks = [r[10, 10].item() for r in results] - assert peaks[0] > peaks[1] > peaks[2] > peaks[3] - - def test_gaussian_blur_smoothing(self): - """Test that blur reduces high frequencies.""" - # Create checkerboard pattern - x = torch.zeros(20, 20) - x[::2, ::2] = 1.0 - x[1::2, 1::2] = 1.0 - - transform = GaussianBlur(sigma=2.0) - result = transform(x) - - # Blurred result should have less variance - assert result.var() < x.var() +# --------------------------------------------------------------------------- +# Integration: transforms compose with torchvision +# --------------------------------------------------------------------------- class TestTransformComposition: - """Test composing multiple transforms together.""" - - def test_sequential_transforms(self): - """Test applying transforms sequentially.""" + def test_nan_to_num_in_compose(self): import torchvision.transforms.v2 as T - transforms = T.Compose( + pipeline = T.Compose( [ - T.ToDtype(torch.float32, scale=True), - GaussianNoise(std=0.01), - RandomContrast(contrast_range=(0.9, 1.1)), + NaNtoNum({"nan": 0, "posinf": None, "neginf": None}), ] ) + x = torch.tensor([float("nan"), 1.0]) + out = pipeline(x) + assert out[0] == 0.0 - x = torch.randint(0, 256, (1, 10, 10), dtype=torch.float32) - result = transforms(x) - - assert result.shape == x.shape - assert result.min() >= -0.5 # Noise might push slightly negative - assert result.max() <= 1.5 # Contrast might push slightly above 1 - - def test_transform_pipeline(self): - """Test a realistic transform pipeline.""" + def test_binarize_after_dtype_conversion(self): import torchvision.transforms.v2 as T - # Realistic preprocessing pipeline - raw_transforms = T.Compose( - [ - T.ToDtype(torch.float32, scale=True), - GaussianNoise(std=0.05), - RandomContrast(contrast_range=(0.8, 1.2)), - ] - ) + pipeline = T.Compose([T.ToDtype(torch.float), Binarize()]) + x = torch.tensor([0, 1, 2], dtype=torch.int32) + out = pipeline(x) + assert torch.allclose(out, torch.tensor([0.0, 1.0, 1.0])) - target_transforms = T.Compose( - [ - Binarize(threshold=0.5), - T.ToDtype(torch.float32), - ] - ) - - raw = torch.randint(0, 256, (1, 32, 32), dtype=torch.float32) - target = torch.rand(32, 32) - - raw_out = raw_transforms(raw) - target_out = target_transforms(target) + def test_nan_preserved_through_binarize(self): + """NaN labels must survive Binarize so loss can ignore them.""" + import torchvision.transforms.v2 as T - assert raw_out.shape == raw.shape - assert target_out.shape == target.shape - assert target_out.unique().numel() <= 2 # Should be binary + pipeline = T.Compose([T.ToDtype(torch.float), Binarize()]) + x = torch.tensor([float("nan"), 1.0, 0.0]) + out = pipeline(x) + assert torch.isnan(out[0]) + assert out[1] == 1.0 + assert out[2] == 0.0 diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 9052276..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,454 +0,0 @@ -""" -Tests for utility functions. - -Tests dtype utilities, sampling utilities, and miscellaneous utilities. -""" - -import numpy as np -import torch - -from cellmap_data.utils.misc import ( - array_has_singleton_dim, - expand_scale, - get_sliced_shape, - longest_common_substring, - permute_singleton_dimension, - split_target_path, - torch_max_value, -) -from cellmap_data.utils.sampling import min_redundant_inds - - -class TestUtilsMisc: - """Test suite for miscellaneous utility functions.""" - - def test_get_sliced_shape_basic(self): - """Test get_sliced_shape with axis parameter.""" - shape = (64, 64) - # Add singleton at axis 0 - sliced_shape = get_sliced_shape(shape, 0) - assert isinstance(sliced_shape, list) - assert 1 in sliced_shape - - def test_get_sliced_shape_different_axes(self): - """Test get_sliced_shape with different axes.""" - shape = (64, 64) - for axis in [0, 1, 2]: - sliced_shape = get_sliced_shape(shape, axis) - assert isinstance(sliced_shape, list) - - def test_torch_max_value_float32(self): - """Test torch_max_value for float32.""" - max_val = torch_max_value(torch.float32) - assert isinstance(max_val, int) - assert max_val > 0 - - def test_torch_max_value_uint8(self): - """Test torch_max_value for uint8.""" - max_val = torch_max_value(torch.uint8) - assert max_val == 255 - - def test_torch_max_value_int16(self): - """Test torch_max_value for int16.""" - max_val = torch_max_value(torch.int16) - assert max_val == 32767 - - def test_torch_max_value_int32(self): - """Test torch_max_value for int32.""" - max_val = torch_max_value(torch.int32) - assert max_val == 2147483647 - - def test_torch_max_value_bool(self): - """Test torch_max_value for bool.""" - max_val = torch_max_value(torch.bool) - assert max_val == 1 - - -class TestSamplingUtils: - """Test suite for sampling utilities.""" - - def test_sampling_weights_basic(self): - """Test basic sampling weight calculation.""" - # Create simple class distributions - class_counts = { - "class_0": 100, - "class_1": 200, - "class_2": 300, - } - - # Weights should be inversely proportional to counts - weights = [] - for count in class_counts.values(): - weight = 1.0 / count if count > 0 else 0.0 - weights.append(weight) - - # Check that smaller classes get higher weights - assert weights[0] > weights[1] > weights[2] - - def test_sampling_with_zero_counts(self): - """Test sampling when some classes have zero counts.""" - class_counts = { - "class_0": 100, - "class_1": 0, # No samples - "class_2": 300, - } - - # Zero-count classes should get zero weight - for name, count in class_counts.items(): - weight = 1.0 / count if count > 0 else 0.0 - if count == 0: - assert weight == 0.0 - else: - assert weight > 0.0 - - def test_normalized_weights(self): - """Test that weights can be normalized.""" - class_counts = [100, 200, 300, 400] - - # Calculate unnormalized weights - weights = [1.0 / count for count in class_counts] - - # Normalize - total = sum(weights) - normalized = [w / total for w in weights] - - # Should sum to 1 - assert abs(sum(normalized) - 1.0) < 1e-6 - - # Should preserve relative ordering - assert normalized[0] > normalized[1] > normalized[2] > normalized[3] - - -class TestArrayOperations: - """Test suite for array operation utilities.""" - - def test_array_2d_detection(self): - """Test detection of 2D arrays.""" - from cellmap_data.utils.misc import is_array_2D - - # is_array_2D takes a mapping of array info, not arrays directly - # Test with dict format - arr_2d_info = {"raw": {"shape": (64, 64)}} - result_2d = is_array_2D(arr_2d_info) - assert isinstance(result_2d, (bool, dict)) - - # 3D array info - arr_3d_info = {"raw": {"shape": (64, 64, 64)}} - result_3d = is_array_2D(arr_3d_info) - assert isinstance(result_3d, (bool, dict)) - - def test_2d_array_with_singleton(self): - """Test 2D detection with singleton dimensions.""" - from cellmap_data.utils.misc import is_array_2D - - # Shape with singleton - arr_info = {"raw": {"shape": (1, 64, 64)}} - result = is_array_2D(arr_info) - assert isinstance(result, (bool, dict)) - - # Tests for min_redundant_inds removed - function doesn't exist in current implementation - - -class TestPathUtilities: - """Test suite for path utility functions.""" - - def test_split_target_path_basic(self): - """Test basic target path splitting.""" - from cellmap_data.utils.misc import split_target_path - - # Path without embedded classes - path = "/path/to/dataset.zarr" - base_path, classes = split_target_path(path) - - assert isinstance(base_path, str) - assert isinstance(classes, list) - - def test_split_target_path_with_classes(self): - """Test target path splitting with embedded classes.""" - from cellmap_data.utils.misc import split_target_path - - # Path with class specification in brackets - path = "/path/to/dataset[class1,class2].zarr" - base_path, classes = split_target_path(path) - - assert isinstance(base_path, str) - assert isinstance(classes, list) - assert "{label}" in base_path # Should have placeholder - - def test_split_target_path_multiple_classes(self): - """Test with multiple classes in path.""" - from cellmap_data.utils.misc import split_target_path - - path = "/path/to/dataset.zarr" - base_path, classes = split_target_path(path) - - # Should handle standard case - assert base_path is not None - assert classes is not None - assert isinstance(classes, list) - - -class TestCoordinateTransforms: - """Test suite for coordinate transformation utilities.""" - - def test_coordinate_scaling(self): - """Test coordinate scaling transformations.""" - # Physical coordinates to voxel coordinates - physical_coord = np.array([80.0, 80.0, 80.0]) # nm - scale = np.array([8.0, 8.0, 8.0]) # nm/voxel - - voxel_coord = physical_coord / scale - - expected = np.array([10.0, 10.0, 10.0]) - assert np.allclose(voxel_coord, expected) - - def test_coordinate_translation(self): - """Test coordinate translation.""" - coord = np.array([10, 10, 10]) - offset = np.array([5, 5, 5]) - - translated = coord + offset - - expected = np.array([15, 15, 15]) - assert np.allclose(translated, expected) - - def test_coordinate_rounding(self): - """Test coordinate rounding to nearest voxel.""" - physical_coord = np.array([83.5, 87.2, 91.9]) - scale = np.array([8.0, 8.0, 8.0]) - - voxel_coord = np.round(physical_coord / scale).astype(int) - - # Should round to nearest integer voxel - assert voxel_coord.dtype == np.int64 or voxel_coord.dtype == np.int32 - assert np.all(voxel_coord >= 0) - - -class TestDtypeUtilities: - """Test suite for dtype utility functions.""" - - def test_torch_to_numpy_dtype(self): - """Test torch to numpy dtype conversion.""" - # Common dtype mappings - torch_dtypes = [ - torch.float32, - torch.float64, - torch.int32, - torch.int64, - torch.uint8, - ] - - for torch_dtype in torch_dtypes: - # Create tensor and convert to numpy - t = torch.tensor([1, 2, 3], dtype=torch_dtype) - arr = t.numpy() - - # Should have compatible numpy dtype - assert arr.dtype is not None - - def test_numpy_to_torch_dtype(self): - """Test numpy to torch dtype conversion.""" - # Common dtype mappings - numpy_dtypes = [ - np.float32, - np.float64, - np.int32, - np.int64, - np.uint8, - ] - - for numpy_dtype in numpy_dtypes: - # Create numpy array and convert to torch - arr = np.array([1, 2, 3], dtype=numpy_dtype) - t = torch.from_numpy(arr) - - # Should have compatible torch dtype - assert t.dtype is not None - - def test_dtype_max_values(self): - """Test max values for different dtypes.""" - # Test a few common dtypes - assert torch_max_value(torch.uint8) == 255 - assert torch_max_value(torch.int16) == 32767 - assert torch_max_value(torch.bool) == 1 - - # Float types return 1 (normalized) - assert torch_max_value(torch.float32) == 1 - assert torch_max_value(torch.float64) == 1 - - -class TestLongestCommonSubstring: - """Tests for longest_common_substring utility.""" - - def test_identical_strings(self): - result = longest_common_substring("abcdef", "abcdef") - assert result == "abcdef" - - def test_partial_overlap(self): - result = longest_common_substring("abcXYZ", "XYZdef") - assert result == "XYZ" - - def test_no_overlap(self): - result = longest_common_substring("abc", "xyz") - assert result == "" - - def test_substring_at_start(self): - result = longest_common_substring("hello world", "hello there") - assert result == "hello " - - def test_single_char_overlap(self): - result = longest_common_substring("abc", "cde") - assert result == "c" - - def test_empty_string(self): - result = longest_common_substring("", "abc") - assert result == "" - - def test_path_like_strings(self): - a = "/data/train/dataset_0/raw" - b = "/data/train/dataset_1/raw" - result = longest_common_substring(a, b) - assert len(result) > 0 - assert result in a and result in b - - -class TestExpandScale: - """Tests for expand_scale utility.""" - - def test_2d_scale_expanded(self): - scale = [4.0, 8.0] - result = expand_scale(scale) - assert len(result) == 3 - assert result[0] == 4.0 # first element duplicated at front - - def test_3d_scale_unchanged(self): - scale = [4.0, 8.0, 16.0] - result = expand_scale(scale) - assert result == [4.0, 8.0, 16.0] - - def test_isotropic_2d(self): - scale = [4.0, 4.0] - result = expand_scale(scale) - assert len(result) == 3 - assert result == [4.0, 4.0, 4.0] - - def test_single_element(self): - scale = [8.0] - result = expand_scale(scale) - assert len(result) == 1 # no change for 1D - - -class TestArrayHasSingletonDim: - """Tests for array_has_singleton_dim utility.""" - - def test_with_singleton(self): - arr_info = {"shape": (1, 64, 64)} - assert array_has_singleton_dim(arr_info) is True - - def test_without_singleton(self): - arr_info = {"shape": (8, 64, 64)} - assert array_has_singleton_dim(arr_info) is False - - def test_none_input(self): - assert array_has_singleton_dim(None) is False - - def test_empty_dict(self): - assert array_has_singleton_dim({}) is False - - def test_nested_dict_any(self): - arr_info = { - "raw": {"shape": (1, 64, 64)}, - "labels": {"shape": (8, 64, 64)}, - } - # summary=True (default) returns True if any has singleton - assert array_has_singleton_dim(arr_info, summary=True) is True - - def test_nested_dict_none_singleton(self): - arr_info = { - "raw": {"shape": (4, 64, 64)}, - "labels": {"shape": (8, 64, 64)}, - } - assert array_has_singleton_dim(arr_info, summary=True) is False - - def test_nested_dict_per_key(self): - arr_info = { - "raw": {"shape": (1, 64, 64)}, - "labels": {"shape": (8, 64, 64)}, - } - result = array_has_singleton_dim(arr_info, summary=False) - assert isinstance(result, dict) - assert result["raw"] is True - assert result["labels"] is False - - -class TestPermutesSingletonDimension: - """Tests for permute_singleton_dimension utility.""" - - def test_single_array_dict(self): - arr_dict = {"shape": (64, 64), "scale": (4.0, 4.0)} - permute_singleton_dimension(arr_dict, axis=0) - assert len(arr_dict["shape"]) == 3 - assert arr_dict["shape"][0] == 1 - assert len(arr_dict["scale"]) == 3 - - def test_nested_array_dict(self): - arr_dict = { - "raw": {"shape": (64, 64), "scale": (4.0, 4.0)}, - "labels": {"shape": (64, 64), "scale": (4.0, 4.0)}, - } - permute_singleton_dimension(arr_dict, axis=1) - assert len(arr_dict["raw"]["shape"]) == 3 - assert len(arr_dict["labels"]["shape"]) == 3 - - def test_axis_placement(self): - arr_dict = {"shape": (64, 64), "scale": (4.0, 8.0)} - permute_singleton_dimension(arr_dict, axis=2) - assert arr_dict["shape"][2] == 1 - - def test_existing_singleton_moved(self): - # shape already has a singleton, but at wrong position - arr_dict = {"shape": (1, 64, 64), "scale": (4.0, 4.0, 4.0)} - permute_singleton_dimension(arr_dict, axis=2) - assert arr_dict["shape"][2] == 1 - - -class TestMinRedundantInds: - """Tests for min_redundant_inds from utils.sampling.""" - - def test_basic_sampling_under_size(self): - result = min_redundant_inds(10, 5) - assert len(result) == 5 - assert result.max() < 10 - - def test_exact_size(self): - result = min_redundant_inds(10, 10) - assert len(result) == 10 - # Should be a permutation - assert set(result.tolist()) == set(range(10)) - - def test_oversample(self): - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - result = min_redundant_inds(5, 12) - assert len(result) == 12 - assert result.max() < 5 - - def test_with_rng(self): - rng = torch.Generator() - rng.manual_seed(42) - result1 = min_redundant_inds(10, 5, rng=rng) - rng.manual_seed(42) - result2 = min_redundant_inds(10, 5, rng=rng) - assert torch.equal(result1, result2) - - def test_invalid_size_raises(self): - import pytest - - with pytest.raises(ValueError): - min_redundant_inds(0, 5) - - def test_returns_tensor(self): - result = min_redundant_inds(10, 5) - assert isinstance(result, torch.Tensor) diff --git a/tests/test_windows_stress.py b/tests/test_windows_stress.py deleted file mode 100644 index 28fdcaf..0000000 --- a/tests/test_windows_stress.py +++ /dev/null @@ -1,415 +0,0 @@ -""" -Stress test for concurrent TensorStore reads and the read limiter. - -Verifies that the TensorStore read limiter (read_limiter.py) prevents crashes -under high concurrency on Windows and does not cause deadlocks or correctness -issues on any other platform. - -On Windows, running many concurrent __getitem__ calls without the limiter -triggers native hard-crashes (abort / SEH) inside TensorStore. This test -catches those regressions as non-zero exit codes in CI. - -On Linux/macOS the limiter is a no-op, but the concurrency tests still -exercise the same code paths and would expose deadlocks introduced in the -limiter itself. -""" - -import os -import platform -import threading -from concurrent.futures import ThreadPoolExecutor, as_completed -from contextlib import contextmanager -from typing import List - -import pytest - -from cellmap_data import CellMapDataset -from cellmap_data.utils.read_limiter import ( - MAX_CONCURRENT_READS, - _read_semaphore, - limit_tensorstore_reads, -) - -from .test_helpers import create_minimal_test_dataset, create_test_dataset - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -@contextmanager -def collect_worker_errors(): - """Context manager for collecting errors from concurrent workers. - - Yields: - List[Exception]: Empty list that workers can append errors to. - - Raises: - AssertionError: If any errors were collected during execution. - """ - errors: List[Exception] = [] - try: - yield errors - finally: - assert not errors, f"{len(errors)} errors occurred: {errors[:3]}" - - -_IS_WINDOWS = platform.system() == "Windows" -_IS_TENSORSTORE = ( - os.environ.get("CELLMAP_DATA_BACKEND", "tensorstore").lower() == "tensorstore" -) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def stress_dataset_config(tmp_path): - """Dataset with multiple classes, suitable for heavy concurrent access.""" - return create_test_dataset( - tmp_path, - raw_shape=(32, 32, 32), - num_classes=3, - raw_scale=(4.0, 4.0, 4.0), - ) - - -@pytest.fixture -def stress_dataset(stress_dataset_config): - """CellMapDataset configured for concurrent stress testing.""" - config = stress_dataset_config - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - return CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - is_train=True, - force_has_data=True, - ) - - -@pytest.fixture -def raw_only_stress_dataset(tmp_path): - """raw_only CellMapDataset for testing the raw-only target read path.""" - config = create_minimal_test_dataset(tmp_path) - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - return CellMapDataset( - raw_path=config["raw_path"], - target_path=config["raw_path"], # raw as target too → raw_only path - classes=None, - input_arrays=input_arrays, - target_arrays=target_arrays, - is_train=True, - force_has_data=True, - ) - - -# --------------------------------------------------------------------------- -# Unit tests for the read_limiter module itself -# --------------------------------------------------------------------------- - - -class TestReadLimiterUnit: - """Unit tests for cellmap_data.read_limiter.""" - - def test_semaphore_state_matches_platform(self): - """Semaphore is active on Windows+tensorstore; None elsewhere.""" - if _IS_WINDOWS and _IS_TENSORSTORE: - assert ( - _read_semaphore is not None - ), "Expected a semaphore on Windows+TensorStore" - assert isinstance(MAX_CONCURRENT_READS, int) - assert MAX_CONCURRENT_READS >= 1 - else: - assert ( - _read_semaphore is None - ), "Expected no semaphore on non-Windows or non-TensorStore" - assert MAX_CONCURRENT_READS is None - - def test_env_override_respected(self): - """CELLMAP_MAX_CONCURRENT_READS is reflected in MAX_CONCURRENT_READS.""" - if _IS_WINDOWS and _IS_TENSORSTORE: - # The env var was read at import time; just verify the value is sane. - expected = int(os.environ.get("CELLMAP_MAX_CONCURRENT_READS", "1")) - assert MAX_CONCURRENT_READS == expected - - def test_context_manager_completes_without_error(self): - """A single entry/exit of limit_tensorstore_reads() does not raise.""" - with limit_tensorstore_reads(): - pass - - def test_context_manager_reraises_exceptions(self): - """Exceptions inside the context manager propagate and release the lock.""" - with pytest.raises(RuntimeError, match="boom"): - with limit_tensorstore_reads(): - raise RuntimeError("boom") - - # Semaphore must be released: a second entry must not block. - acquired = threading.Event() - - def try_acquire(): - with limit_tensorstore_reads(): - acquired.set() - - t = threading.Thread(target=try_acquire) - t.start() - t.join(timeout=5) - assert acquired.is_set(), "Semaphore was not released after exception" - - def test_concurrent_access_does_not_deadlock(self): - """50 threads entering the context manager concurrently must all finish. - - Uses t.join(timeout=...) as the deadlock detector rather than a Barrier - that requires simultaneous occupancy. The Barrier approach was broken: - with MAX_CONCURRENT_READS=1, only 1 thread can be inside the context at - a time, so a 50-party barrier inside the context can never be satisfied. - """ - errors: list[Exception] = [] - - def task(): - try: - with limit_tensorstore_reads(): - pass # just verify acquire + release works under concurrency - except Exception as exc: - errors.append(exc) - - threads = [threading.Thread(target=task) for _ in range(50)] - for t in threads: - t.start() - for t in threads: - t.join(timeout=60) - - alive = [t for t in threads if t.is_alive()] - assert not alive, f"{len(alive)} threads still alive (possible deadlock)" - assert not errors, f"Errors during concurrent access: {errors}" - - -# --------------------------------------------------------------------------- -# Dataset close() and atexit integration -# --------------------------------------------------------------------------- - - -class TestExecutorLifecycle: - """Tests for the close() method and atexit registration.""" - - def test_close_shuts_down_executor(self, stress_dataset): - """close() shuts down the executor and sets it to None.""" - # Force executor creation - _ = stress_dataset.executor - assert stress_dataset._executor is not None - - stress_dataset.close() - assert stress_dataset._executor is None - - def test_close_is_idempotent(self, stress_dataset): - """Calling close() multiple times does not raise.""" - stress_dataset.close() - stress_dataset.close() # second call must be safe - - def test_getitem_after_close_recreates_executor(self, stress_dataset): - """After close(), __getitem__ can still run (executor is re-created).""" - stress_dataset.close() - assert stress_dataset._executor is None - - # __getitem__ internally accesses .executor which lazily re-creates it - result = stress_dataset[0] - assert result is not None - assert stress_dataset._executor is not None - - -# --------------------------------------------------------------------------- -# __getitem__ stress tests -# --------------------------------------------------------------------------- - - -def _make_stress_dataset(config: dict) -> CellMapDataset: - """Create a fresh CellMapDataset from a config dict (for per-thread use).""" - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - return CellMapDataset( - raw_path=config["raw_path"], - target_path=config["gt_path"], - classes=config["classes"], - input_arrays=input_arrays, - target_arrays=target_arrays, - is_train=True, - force_has_data=True, - ) - - -class TestConcurrentGetitem: - """Stress tests for dataset.__getitem__ under sustained load. - - Design note on concurrency model - --------------------------------- - CellMapDataset uses an *internal* ThreadPoolExecutor to parallelize - per-array and per-label reads within a single ``__getitem__`` call. - The real DataLoader usage pattern is: - - * ``num_workers=0`` → the main process calls ``__getitem__`` sequentially; - the dataset's internal pool handles within-item parallelism. - * ``num_workers>0`` → each worker *process* gets its own pickle-restored - copy of the dataset (and therefore its own executor); calls are still - sequential within each worker. - - Sharing one dataset instance across multiple threads that each call - ``__getitem__`` simultaneously is NOT how DataLoader uses the dataset and - creates a deadlock: ``get_target_array`` (running in a worker slot) blocks - waiting for ``get_label_array`` sub-futures on the same pool, starving - those sub-futures of worker slots. - - The concurrent tests below therefore give each outer thread its own dataset - instance, accurately simulating ``num_workers>0`` DataLoader workers. The - TensorStore read limiter is still exercised because ``limit_tensorstore_reads`` - uses a *process-wide* semaphore, shared across all threads (and datasets) in - the same process. - """ - - NUM_ITERATIONS_PER_THREAD = 50 # iterations each "worker" runs - NUM_OUTER_THREADS = 4 # simulated DataLoader num_workers - - # ------------------------------------------------------------------ - # serial baseline - # ------------------------------------------------------------------ - - def test_serial_getitem_with_classes(self, stress_dataset): - """Sequential __getitem__ calls (multi-class) complete without error.""" - n = min(200, len(stress_dataset)) - for i in range(n): - result = stress_dataset[i % len(stress_dataset)] - assert result is not None - assert "raw" in result - assert "gt" in result - - def test_serial_getitem_raw_only(self, raw_only_stress_dataset): - """Sequential __getitem__ calls (raw-only) complete without error.""" - ds = raw_only_stress_dataset - n = min(200, len(ds)) - for i in range(n): - result = ds[i % len(ds)] - assert result is not None - - # ------------------------------------------------------------------ - # concurrent stress — each thread owns its dataset (mirrors DataLoader workers) - # ------------------------------------------------------------------ - - def test_concurrent_workers_with_classes(self, stress_dataset_config): - """Multiple simulated DataLoader workers (each with its own dataset) run concurrently.""" - with collect_worker_errors() as errors: - - def worker(thread_id: int) -> None: - ds = _make_stress_dataset(stress_dataset_config) - try: - n = min(self.NUM_ITERATIONS_PER_THREAD, len(ds)) - for i in range(n): - result = ds[i % len(ds)] - if result is None: - errors.append( - RuntimeError(f"thread {thread_id}: got None at idx {i}") - ) - except Exception as exc: - errors.append(exc) - finally: - ds.close() - - threads = [ - threading.Thread(target=worker, args=(tid,)) - for tid in range(self.NUM_OUTER_THREADS) - ] - for t in threads: - t.start() - for t in threads: - t.join(timeout=120) - - alive = [t for t in threads if t.is_alive()] - assert ( - not alive - ), f"{len(alive)} threads are still alive (possible deadlock)" - - def test_concurrent_workers_raw_only(self, tmp_path): - """Multiple simulated workers with raw-only datasets run concurrently.""" - with collect_worker_errors() as errors: - # Build a shared config for the raw-only variant - config = create_minimal_test_dataset(tmp_path) - input_arrays = {"raw": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - target_arrays = {"gt": {"shape": (8, 8, 8), "scale": (4.0, 4.0, 4.0)}} - - def worker(thread_id: int) -> None: - ds = CellMapDataset( - raw_path=config["raw_path"], - target_path=config["raw_path"], - classes=None, - input_arrays=input_arrays, - target_arrays=target_arrays, - is_train=True, - force_has_data=True, - ) - try: - n = min(self.NUM_ITERATIONS_PER_THREAD, len(ds)) - for i in range(n): - ds[i % len(ds)] - except Exception as exc: - errors.append(exc) - finally: - ds.close() - - threads = [ - threading.Thread(target=worker, args=(tid,)) - for tid in range(self.NUM_OUTER_THREADS) - ] - for t in threads: - t.start() - for t in threads: - t.join(timeout=120) - - alive = [t for t in threads if t.is_alive()] - assert ( - not alive - ), f"{len(alive)} threads are still alive (possible deadlock)" - - @pytest.mark.skipif( - not _IS_WINDOWS, - reason="Windows-specific crash regression test; skipped on non-Windows", - ) - def test_windows_high_concurrency_no_crash(self, stress_dataset_config): - """ - Windows-specific: many simulated workers must not hard-crash the process. - - Each worker has its own dataset (matching DataLoader num_workers behavior). - A native TensorStore abort appears as a non-zero pytest exit code, which - CI catches even without a Python exception being raised. - """ - num_workers = 8 - iters = 100 - - with collect_worker_errors() as errors: - - def worker(thread_id: int) -> None: - ds = _make_stress_dataset(stress_dataset_config) - try: - for i in range(iters): - ds[i % len(ds)] - except Exception as exc: - errors.append(exc) - finally: - ds.close() - - threads = [ - threading.Thread(target=worker, args=(tid,)) - for tid in range(num_workers) - ] - for t in threads: - t.start() - for t in threads: - t.join(timeout=300) - - alive = [t for t in threads if t.is_alive()] - assert not alive, f"{len(alive)} threads still alive" diff --git a/tests/test_writer.py b/tests/test_writer.py new file mode 100644 index 0000000..9395197 --- /dev/null +++ b/tests/test_writer.py @@ -0,0 +1,124 @@ +"""Tests for CellMapDatasetWriter and ImageWriter.""" + +from __future__ import annotations + +import numpy as np +import torch + +from cellmap_data import CellMapDatasetWriter +from cellmap_data.image_writer import ImageWriter + +from .test_helpers import create_test_zarr + +INPUT_ARRAYS = {"raw": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} +TARGET_ARRAYS = {"pred": {"shape": (4, 4, 4), "scale": (8.0, 8.0, 8.0)}} + + +class TestImageWriter: + def test_write_and_read_back(self, tmp_path): + out_path = str(tmp_path / "out.zarr" / "mito") + bounding_box = {"z": (0.0, 128.0), "y": (0.0, 128.0), "x": (0.0, 128.0)} + writer = ImageWriter( + path=out_path, + target_class="mito", + scale={"z": 8.0, "y": 8.0, "x": 8.0}, + bounding_box=bounding_box, + write_voxel_shape={"z": 4, "y": 4, "x": 4}, + overwrite=True, + ) + data = torch.ones(4, 4, 4) * 0.5 + center = {"z": 16.0, "y": 16.0, "x": 16.0} + writer[center] = data + # Read back + readback = writer[center] + assert torch.allclose(readback, torch.ones(4, 4, 4) * 0.5, atol=1e-4) + + def test_shape_property(self, tmp_path): + out_path = str(tmp_path / "out.zarr" / "mito") + writer = ImageWriter( + path=out_path, + target_class="mito", + scale={"z": 8.0, "y": 8.0, "x": 8.0}, + bounding_box={"z": (0.0, 128.0), "y": (0.0, 128.0), "x": (0.0, 128.0)}, + write_voxel_shape={"z": 4, "y": 4, "x": 4}, + overwrite=True, + ) + # 128 nm / 8 nm/voxel = 16 voxels per axis + assert writer.shape == {"z": 16, "y": 16, "x": 16} + + def test_repr(self, tmp_path): + writer = ImageWriter( + path=str(tmp_path / "out.zarr" / "mito"), + target_class="mito", + scale={"z": 8.0, "y": 8.0, "x": 8.0}, + bounding_box={"z": (0.0, 64.0), "y": (0.0, 64.0), "x": (0.0, 64.0)}, + write_voxel_shape={"z": 4, "y": 4, "x": 4}, + ) + assert "ImageWriter" in repr(writer) + + +class TestCellMapDatasetWriter: + def _make_writer(self, tmp_path): + raw_path = create_test_zarr( + tmp_path, name="raw", shape=(32, 32, 32), voxel_size=[8.0, 8.0, 8.0] + ) + out_path = str(tmp_path / "predictions.zarr") + bounds = {"pred": {"z": (0.0, 256.0), "y": (0.0, 256.0), "x": (0.0, 256.0)}} + writer = CellMapDatasetWriter( + raw_path=raw_path, + target_path=out_path, + classes=["mito"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + target_bounds=bounds, + overwrite=True, + ) + return writer + + def test_len_positive(self, tmp_path): + writer = self._make_writer(tmp_path) + assert len(writer) > 0 + + def test_bounding_box(self, tmp_path): + writer = self._make_writer(tmp_path) + bb = writer.bounding_box + assert bb is not None + assert "z" in bb + + def test_getitem_returns_dict_with_idx(self, tmp_path): + writer = self._make_writer(tmp_path) + item = writer[0] + assert "idx" in item + assert isinstance(item["raw"], torch.Tensor) + + def test_writer_indices_non_empty(self, tmp_path): + writer = self._make_writer(tmp_path) + assert len(writer.writer_indices) > 0 + + def test_setitem_scalar(self, tmp_path): + """Writing a single prediction should not raise.""" + writer = self._make_writer(tmp_path) + idx = writer.writer_indices[0] + output = {"mito": torch.zeros(4, 4, 4)} + writer[idx] = output # should not raise + + def test_setitem_batch(self, tmp_path): + """Writing a batch (tensor of indices) should not raise.""" + writer = self._make_writer(tmp_path) + indices = writer.writer_indices[:2] + idx_tensor = torch.tensor(indices) + # Batch of predictions: [batch, *spatial] + output = {"mito": torch.zeros(2, 4, 4, 4)} + writer[idx_tensor] = output # should not raise + + def test_loader_iterable(self, tmp_path): + writer = self._make_writer(tmp_path) + loader = writer.loader(batch_size=2) + batches = list(loader) + assert len(batches) > 0 + assert "idx" in batches[0] + + def test_repr(self, tmp_path): + writer = self._make_writer(tmp_path) + r = repr(writer) + assert "CellMapDatasetWriter" in r From 754dddb4b07ff702bb05ff2239ae1dd2f40f9b5e Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 5 Mar 2026 10:52:35 -0500 Subject: [PATCH 02/33] feat: support 5-column CSV format in CellMapDataSplit for challenge datasets --- src/cellmap_data/datasplit.py | 18 ++++++++++--- tests/test_datasplit.py | 51 +++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/src/cellmap_data/datasplit.py b/src/cellmap_data/datasplit.py index e1902ae..c94f57c 100644 --- a/src/cellmap_data/datasplit.py +++ b/src/cellmap_data/datasplit.py @@ -129,8 +129,12 @@ def _parse_csv( ) -> dict[str, list[dict[str, str]]]: """Parse the dataset CSV into a ``dataset_dict``. - Expected CSV columns: ``split, raw_path, gt_path`` (and optionally - ``raw_name``, ``gt_name`` which are ignored). + Supports two formats: + + * **3-column**: ``split, raw_path, gt_path`` + * **5-column** (cellmap-segmentation-challenge): ``split, zarr_path, + raw_ds_name, zarr_path, gt_ds_name`` — the full raw/gt paths are + constructed by joining columns 1+2 and 3+4 respectively. """ result: dict[str, list[dict[str, str]]] = { "train": [], @@ -145,8 +149,14 @@ def _parse_csv( logger.warning("Skipping malformed CSV row: %s", row) continue split = row[0].strip() - raw_path = row[1].strip() - gt_path = row[2].strip() + if len(row) >= 5: + # 5-column challenge format: + # split, zarr_path, raw_ds_name, zarr_path, gt_ds_name + raw_path = os.path.join(row[1].strip(), row[2].strip()) + gt_path = os.path.join(row[3].strip(), row[4].strip()) + else: + raw_path = row[1].strip() + gt_path = row[2].strip() if split not in result: result[split] = [] result[split].append({"raw": raw_path, "gt": gt_path}) diff --git a/tests/test_datasplit.py b/tests/test_datasplit.py index 2113710..a1dbe60 100644 --- a/tests/test_datasplit.py +++ b/tests/test_datasplit.py @@ -117,6 +117,57 @@ def test_init_empty(self): assert len(split.train_datasets) == 0 assert len(split._validation_datasets) == 0 + def test_init_from_csv_5col(self, tmp_path): + """5-column challenge CSV format: split, zarr_path, raw_ds, zarr_path, gt_ds.""" + train_info = create_test_dataset(tmp_path / "train5", classes=CLASSES) + val_info = create_test_dataset(tmp_path / "val5", classes=CLASSES) + + # Simulate challenge CSV: split zarr_path and sub-path across columns 1+2 and 3+4 + train_zarr = str(tmp_path / "train5") + train_raw_ds = os.path.relpath(train_info["raw_path"], train_zarr) + train_gt_ds = os.path.relpath( + train_info["gt_path"].split("[")[0].rstrip(os.sep), train_zarr + ) + train_classes = ",".join(CLASSES) + + val_zarr = str(tmp_path / "val5") + val_raw_ds = os.path.relpath(val_info["raw_path"], val_zarr) + val_gt_ds = os.path.relpath( + val_info["gt_path"].split("[")[0].rstrip(os.sep), val_zarr + ) + + csv_path = str(tmp_path / "split5.csv") + with open(csv_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow( + [ + "train", + train_zarr, + train_raw_ds, + train_zarr, + f"{train_gt_ds}[{train_classes}]", + ] + ) + w.writerow( + [ + "validate", + val_zarr, + val_raw_ds, + val_zarr, + f"{val_gt_ds}[{','.join(CLASSES)}]", + ] + ) + + split = CellMapDataSplit( + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + classes=CLASSES, + csv_path=csv_path, + force_has_data=True, + ) + assert len(split.train_datasets) == 1 + assert len(split._validation_datasets) == 1 + def test_init_from_datasets(self, tmp_path): from cellmap_data import CellMapDataset From b567f51761dedac5bf057357fce912b8758d2bfe Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 5 Mar 2026 14:07:51 -0500 Subject: [PATCH 03/33] feat: enhance CellMapDataset and CellMapImage for better handling of small crops and padding scenarios - Update sampling_box logic in CellMapDataset to skip EmptyImage sources and handle cases where crops are smaller than output patches. - Modify CellMapImage to return appropriate sampling boxes based on the size of the array relative to the output patch, with support for padding. - Adjust tests to validate new behavior for small crops with padding options, ensuring correct tensor shapes and handling of NaN values. --- src/cellmap_data/dataset.py | 80 ++++++++++++++++++++------- src/cellmap_data/image.py | 107 ++++++++++++++++++++++++++++++------ tests/test_dataloader.py | 2 +- tests/test_dataset.py | 95 ++++++++++++++++++++++++++++++-- tests/test_image.py | 87 ++++++++++++++++++++++++++--- 5 files changed, 320 insertions(+), 51 deletions(-) diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index 91a6aff..a0de11a 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -191,14 +191,23 @@ def bounding_box(self) -> dict[str, tuple[float, float]] | None: @cached_property def sampling_box(self) -> dict[str, tuple[float, float]] | None: - """Intersection of all source sampling boxes.""" + """Intersection of all source sampling boxes. + + ``EmptyImage`` sources (``bounding_box is None``) are skipped. + A ``CellMapImage`` with a crop smaller than the output patch returns + a single-centre ``sampling_box`` when ``pad=True`` (so + ``len(dataset)`` becomes 1), or ``None`` when ``pad=False`` (which + causes this method to return ``None`` and exclude the dataset). + """ box = None for src in list(self.input_sources.values()) + list( self.target_sources.values() ): sb = src.sampling_box if sb is None: - continue + if src.bounding_box is None: + continue # EmptyImage — no spatial constraint + return None # pad=False and crop too small → exclude box = sb if box is None else box_intersection(box, sb) if box is None: return None @@ -240,10 +249,29 @@ def __getitem__(self, idx: int) -> dict[str, Any]: result: dict[str, Any] = {"idx": torch.tensor(idx)} for name, src in self.input_sources.items(): - result[name] = src[center] - - for cls, src in self.target_sources.items(): - result[cls] = src[center] + tensor = src[center] + # Drop any singleton spatial dims (e.g. Z=1 for flat-3D inputs), + # then prepend C=1 so the batch has shape [N, C, *spatial] as + # expected by PyTorch convolutions. + if 1 in tensor.shape: + tensor = tensor.squeeze() + result[name] = tensor.unsqueeze(0) # [C=1, *spatial] + + # Stack per-class tensors under each target array name. + # The challenge (and train.py) accesses targets via target_arrays keys + # (e.g. batch["output"]), not individual class names. + if self.target_arrays: + class_tensors = [] + for cls in self.classes: + t = self.target_sources[cls][center] + # Match the same singleton-dim squeeze applied to inputs so + # spatial dims are consistent between inputs and targets. + if 1 in t.shape: + t = t.squeeze() + class_tensors.append(t) + stacked = torch.stack(class_tensors, dim=0) # [n_classes, *spatial] + for arr_name in self.target_arrays: + result[arr_name] = stacked # Reset spatial transforms for src in list(self.input_sources.values()) + list( @@ -279,10 +307,18 @@ def _generate_spatial_transforms(self) -> dict | None: mirror_cfg = cfg.get("mirror") if mirror_cfg: if isinstance(mirror_cfg, dict): - result["mirror"] = { - ax: bool(self._rng.random() < 0.5) if enabled else False - for ax, enabled in mirror_cfg.items() - } + # Support {"axes": {"x": 0.5, ...}} wrapper or flat {"x": 0.5, ...} + axis_probs = mirror_cfg.get("axes", mirror_cfg) + if isinstance(axis_probs, dict): + result["mirror"] = { + ax: bool(self._rng.random() < prob) + for ax, prob in axis_probs.items() + } + else: + axes = next(iter(self.input_sources.values())).axes + result["mirror"] = { + ax: bool(self._rng.random() < 0.5) for ax in axes + } else: axes = next(iter(self.input_sources.values())).axes result["mirror"] = {ax: bool(self._rng.random() < 0.5) for ax in axes} @@ -299,17 +335,19 @@ def _generate_spatial_transforms(self) -> dict | None: if rotate_cfg: axes = next(iter(self.input_sources.values())).axes if isinstance(rotate_cfg, dict): - # e.g. {"z": 45} → random angle in [-45, 45] degrees - angle_dict: dict[str, float] = {} - for ax, max_angle in rotate_cfg.items(): - if isinstance(max_angle, (list, tuple)): - lo, hi = max_angle - else: - lo, hi = -float(max_angle), float(max_angle) - angle_dict[ax] = float(self._rng.uniform(lo, hi)) - R = _make_rotation_matrix(axes, angle_dict) - if R is not None: - result["rotation_matrix"] = R + # Support {"axes": {"x": 45, ...}} wrapper or flat {"x": 45, ...} + axis_angles = rotate_cfg.get("axes", rotate_cfg) + if isinstance(axis_angles, dict): + angle_dict: dict[str, float] = {} + for ax, max_angle in axis_angles.items(): + if isinstance(max_angle, (list, tuple)): + lo, hi = max_angle + else: + lo, hi = -float(max_angle), float(max_angle) + angle_dict[ax] = float(self._rng.uniform(lo, hi)) + R = _make_rotation_matrix(axes, angle_dict) + if R is not None: + result["rotation_matrix"] = R return result if result else None diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 43f44ff..24e1af3 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -199,9 +199,17 @@ def bounding_box(self) -> dict[str, tuple[float, float]]: @cached_property def sampling_box(self) -> dict[str, tuple[float, float]] | None: - """Shrunk bounding box where patch centres can be drawn without going OOB. + """Bounding box where patch centres can be drawn. - Returns ``None`` if the array is smaller than the requested patch. + When the array is large enough to contain the full output patch, the + box is the bounding box shrunk by ``output_size / 2`` on each side so + that reads never extend outside the array. + + When the array is *smaller* than the output patch: + - If ``pad=True``: returns a single-centre box positioned at the + midpoint of the bounding box. ``len(dataset)`` will be 1 and + out-of-bounds voxels are filled with ``pad_value``. + - If ``pad=False``: returns ``None`` (dataset excluded). """ bb = self.bounding_box result: dict[str, tuple[float, float]] = {} @@ -210,8 +218,16 @@ def sampling_box(self) -> dict[str, tuple[float, float]] | None: lo = bb[ax][0] + half hi = bb[ax][1] - half if lo >= hi: - return None - result[ax] = (lo, hi) + if not self.pad: + return None + # Crop smaller than output: expose a single sample at the centre + bb_center = (bb[ax][0] + bb[ax][1]) / 2.0 + result[ax] = ( + bb_center - self.scale[ax] / 2.0, + bb_center + self.scale[ax] / 2.0, + ) + else: + result[ax] = (lo, hi) return result def get_center(self, idx: int) -> dict[str, float]: @@ -248,8 +264,25 @@ def set_spatial_transforms(self, transforms: dict | None) -> None: # ------------------------------------------------------------------ def _compute_read_shape(self) -> list[int]: - """Read shape large enough to accommodate the current rotation.""" - base = [self.output_shape[ax] for ax in self.axes] + """Number of zarr voxels to read to cover ``output_shape * target_scale`` + nm of world space. + + When the selected zarr level has a different voxel size from the target + scale, reading ``output_shape`` zarr voxels covers the wrong world + extent. The correct count is:: + + zarr_voxels = round(output_voxels * target_scale / zarr_voxel_size) + + The rotation case further enlarges the read to accommodate the + oversized pre-rotation patch. + """ + base = [ + max(1, int(round( + self.output_shape[ax] * self.scale[ax] + / self._voxel_size.get(ax, self.scale[ax]) + ))) + for ax in self.axes + ] if self._current_spatial_transforms is None: return base R = self._current_spatial_transforms.get("rotation_matrix") @@ -280,14 +313,32 @@ def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor: pad_widths: list[tuple[int, int]] = [] for i, ax in enumerate(self.axes): - vs = self._voxel_size.get(ax, 1.0) + vs = self._voxel_size.get(ax, self.scale[ax]) vox_center = (center[ax] - self._origin.get(ax, 0.0)) / vs start = int(np.floor(vox_center - read_shape[i] / 2.0)) - end = start + read_shape[i] - pad_lo = max(0, -start) - pad_hi = max(0, end - arr_shape[i]) - slices.append(slice(max(0, start), min(arr_shape[i], end))) + # Compute pad and slice such that pad_lo + valid_len + pad_hi == read_shape[i] + # even when the read window is partially or fully outside the array. + pad_lo = max(0, min(read_shape[i], -start)) + remaining = read_shape[i] - pad_lo + arr_start = max(0, start) + valid_len = max(0, min(arr_shape[i], arr_start + remaining) - arr_start) + pad_hi = remaining - valid_len + + if valid_len == 0: + logger.warning( + "Fully out-of-bounds read for %r axis %r: centre=%.1f is " + "outside array extent [%.1f, %.1f] nm. This should not " + "happen in normal usage — check that the centre was drawn " + "from within sampling_box.", + self.path, + ax, + center[ax], + self._origin.get(ax, 0.0), + self._origin.get(ax, 0.0) + arr_shape[i] * vs, + ) + + slices.append(slice(arr_start, arr_start + valid_len)) pad_widths.append((pad_lo, pad_hi)) # Prepend slices for leading non-spatial dims (e.g. channel) @@ -310,16 +361,13 @@ def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor: # Resample if voxel size differs from target scale needs_resample = any( - abs(self._voxel_size.get(ax, 1.0) - self.scale[ax]) + abs(self._voxel_size.get(ax, self.scale[ax]) - self.scale[ax]) / max(self.scale[ax], 1e-9) > 0.01 for ax in self.axes ) if needs_resample: - zoom = [self._voxel_size.get(ax, 1.0) / self.scale[ax] for ax in self.axes] - out_spatial = [ - max(1, int(round(read_shape[i] * zoom[i]))) for i in range(spatial_ndim) - ] + out_spatial = [self.output_shape[ax] for ax in self.axes] # Bring data to [N, C, *spatial] for interpolate orig_ndim = data.ndim while data.ndim < spatial_ndim + 2: @@ -373,6 +421,31 @@ def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor: full_perm = list(range(n_lead)) + [n_lead + p for p in perm] data = data.permute(*full_perm).contiguous() + # Enforce exact output_shape: centre-crop oversized dims, pad undersized + # dims. Undersized dims arise from near-boundary reads when pad=False; + # padding them here makes the output shape predictable regardless of the + # pad setting. + target_shape = [self.output_shape[ax] for ax in self.axes] + actual = [data.shape[data.ndim - spatial_ndim + i] for i in range(spatial_ndim)] + if any(actual[i] != target_shape[i] for i in range(spatial_ndim)): + # Centre-crop any oversized spatial dims + crop_sl: list[Any] = [slice(None)] * (data.ndim - spatial_ndim) + for i in range(spatial_ndim): + curr, tgt = actual[i], target_shape[i] + lo = max(0, (curr - tgt) // 2) + crop_sl.append(slice(lo, lo + min(curr, tgt))) + data = data[tuple(crop_sl)] + + # Pad any undersized spatial dims (symmetric, pad_value fill) + actual = [data.shape[data.ndim - spatial_ndim + i] for i in range(spatial_ndim)] + pad_needed = [target_shape[i] - actual[i] for i in range(spatial_ndim)] + if any(p > 0 for p in pad_needed): + size_pad: list[int] = [] + for p in reversed(pad_needed): + size_pad += [p // 2, p - p // 2] + size_pad += [0, 0] * (data.ndim - spatial_ndim) + data = F.pad(data, size_pad, mode="constant", value=self.pad_value) + if self.value_transform is not None: data = self.value_transform(data) @@ -417,7 +490,7 @@ def _apply_rotation(self, data: torch.Tensor, R: np.ndarray) -> torch.Tensor: ) # Replace zero-padded corners with pad_value (only matters for continuous data) - if not np.isnan(self.pad_value) and self.pad_value != 0.0: + if self.pad_value != 0.0: # grid_sample fills OOB with 0; patch with pad_value oob_mask = (grid[..., 0].abs() > 1) | (grid[..., 1].abs() > 1) if spatial_ndim == 3: diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index e56c8a1..5edcd89 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -47,7 +47,7 @@ def test_batch_raw_shape(self, tmp_path): raw = batch["raw"] assert isinstance(raw, torch.Tensor) # batch_size is 2 but last batch may be smaller - assert raw.shape[1:] == torch.Size([4, 4, 4]) + assert raw.shape[1:] == torch.Size([1, 4, 4, 4]) break def test_len(self, tmp_path): diff --git a/tests/test_dataset.py b/tests/test_dataset.py index da1e8a9..3ca1042 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -84,7 +84,7 @@ def test_getitem_raw_is_tensor(self, tmp_path): ) item = ds[0] assert isinstance(item["raw"], torch.Tensor) - assert item["raw"].shape == torch.Size([4, 4, 4]) + assert item["raw"].shape == torch.Size([1, 4, 4, 4]) def test_getitem_missing_class_nan(self, tmp_path): info = create_test_dataset(tmp_path, classes=["mito"]) @@ -98,10 +98,13 @@ def test_getitem_missing_class_nan(self, tmp_path): pad=True, ) item = ds[0] - # unannotated class → NaN - assert torch.isnan(item["er"]).all() - # annotated class → not all NaN - assert not torch.isnan(item["mito"]).all() + # Target classes are stacked under the target array key ("labels"). + # classes=["mito", "er"] → index 0=mito, 1=er + target = item["labels"] # shape [2, z, y, x] + # unannotated class (er, index 1) → NaN + assert torch.isnan(target[1]).all() + # annotated class (mito, index 0) → not all NaN + assert not torch.isnan(target[0]).all() def test_get_crop_class_matrix_shape(self, tmp_path): info = create_test_dataset(tmp_path, classes=["mito"]) @@ -174,6 +177,88 @@ def test_repr(self, tmp_path): r = repr(ds) assert "CellMapDataset" in r + def test_small_crop_pad_true_len_one(self, tmp_path): + """Label crop smaller than output patch with pad=True → len=1, valid sample.""" + # raw: 100³ at 8nm = 800nm (large); output: 4³ at 8nm = 32nm + # raw sampling_box: [16, 784] nm in each axis + # label: 2³ at 8nm = 16nm, origin at 50nm → bb=[50,66], centre=58nm (inside raw sb) + from .test_helpers import _write_ome_ngff + import numpy as np, os + + large_raw = (np.random.default_rng(0).random((100, 100, 100)) * 255).astype(np.uint8) + raw_path = str(tmp_path / "raw.zarr") + _write_ome_ngff(raw_path, large_raw, [8.0, 8.0, 8.0]) + + gt_base = str(tmp_path / "gt.zarr") + os.makedirs(gt_base, exist_ok=True) + import json + with open(os.path.join(gt_base, ".zgroup"), "w") as f: + f.write('{"zarr_format": 2}') + + small_data = np.ones((2, 2, 2), dtype=np.uint8) + classes = ["mito", "er"] + for cls in classes: + _write_ome_ngff( + os.path.join(gt_base, cls), + small_data, + [8.0, 8.0, 8.0], + origin=[50.0, 50.0, 50.0], + ) + + gt_path = f"{gt_base}/[mito,er]" + ds = CellMapDataset( + raw_path=raw_path, + target_path=gt_path, + classes=classes, + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + pad=True, + ) + assert len(ds) == 1 + sample = ds[0] + assert sample["raw"].shape == torch.Size([1, 4, 4, 4]) + assert sample["labels"].shape == torch.Size([2, 4, 4, 4]) + # 2³=8 valid voxels per class inside a 4³=64-voxel patch → NaN outside + nan_count = torch.isnan(sample["labels"]).sum().item() + assert nan_count > 0 + + def test_small_crop_pad_false_excluded(self, tmp_path): + """Label crop smaller than output patch with pad=False → dataset excluded (len=0).""" + from .test_helpers import _write_ome_ngff + import numpy as np, os + + large_raw = (np.random.default_rng(0).random((100, 100, 100)) * 255).astype(np.uint8) + raw_path = str(tmp_path / "raw.zarr") + _write_ome_ngff(raw_path, large_raw, [8.0, 8.0, 8.0]) + + gt_base = str(tmp_path / "gt.zarr") + os.makedirs(gt_base, exist_ok=True) + import json + with open(os.path.join(gt_base, ".zgroup"), "w") as f: + f.write('{"zarr_format": 2}') + + small_data = np.ones((2, 2, 2), dtype=np.uint8) + classes = ["mito", "er"] + for cls in classes: + _write_ome_ngff( + os.path.join(gt_base, cls), + small_data, + [8.0, 8.0, 8.0], + origin=[50.0, 50.0, 50.0], + ) + + gt_path = f"{gt_base}/[mito,er]" + ds = CellMapDataset( + raw_path=raw_path, + target_path=gt_path, + classes=classes, + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + pad=False, + ) + assert ds.sampling_box is None + assert len(ds) == 0 + def test_class_counts(self, tmp_path): info = create_test_dataset(tmp_path) ds = CellMapDataset( diff --git a/tests/test_image.py b/tests/test_image.py index 6ced03d..da7be0a 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -42,12 +42,47 @@ def test_sampling_box(self, tmp_path): assert sb["z"][0] == pytest.approx(16.0) assert sb["z"][1] == pytest.approx(144.0) - def test_sampling_box_none_when_too_small(self, tmp_path): - # Patch (100 voxels) larger than array (10 voxels * 8nm = 80nm) + def test_sampling_box_none_when_too_small_no_pad(self, tmp_path): + """Array smaller than output patch with pad=False → sampling_box is None.""" path = create_test_zarr(tmp_path, shape=(10, 10, 10)) - img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [100, 100, 100]) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [100, 100, 100], pad=False) assert img.sampling_box is None + def test_sampling_box_single_centre_when_too_small_with_pad(self, tmp_path): + """Array smaller than output patch with pad=True → single-centre sampling_box.""" + # 10 voxels * 8nm = 80nm array, output 100 voxels * 8nm = 800nm patch + path = create_test_zarr(tmp_path, shape=(10, 10, 10)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [100, 100, 100], pad=True) + sb = img.sampling_box + assert sb is not None + # Single centre: box width == scale == 8nm + for ax in img.axes: + assert sb[ax][1] - sb[ax][0] == pytest.approx(8.0) + # Centre of bounding box is midpoint of [0, 80] = 40nm + # → sampling_box centre = 40nm → lo = 40 - 4 = 36, hi = 40 + 4 = 44 + assert sb["z"][0] == pytest.approx(36.0) + assert sb["z"][1] == pytest.approx(44.0) + + def test_sampling_box_single_centre_yields_len_one(self, tmp_path): + """get_center(0) for a single-centre image returns the bounding box midpoint.""" + path = create_test_zarr(tmp_path, shape=(10, 10, 10)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [100, 100, 100], pad=True) + center = img.get_center(0) + # bounding_box midpoint is 40nm in each axis + for ax in img.axes: + assert center[ax] == pytest.approx(40.0) + + def test_small_crop_read_shape_and_nan(self, tmp_path): + """Reading a small crop (pad=True) returns the full output shape with NaN padding.""" + path = create_test_zarr(tmp_path, shape=(10, 10, 10)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [100, 100, 100], pad=True) + center = img.get_center(0) + patch = img[center] + assert patch.shape == torch.Size([100, 100, 100]) + # 10*10*10 = 1000 valid voxels; rest are NaN + valid = (~torch.isnan(patch)).sum().item() + assert valid == 1000 + def test_scale_level_best_match(self, tmp_path): path = create_test_zarr(tmp_path, voxel_size=[8.0, 8.0, 8.0]) img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4]) @@ -81,13 +116,51 @@ def test_padding_with_nan(self, tmp_path): # Should have some NaN in the padded region assert torch.isnan(patch).any() - def test_no_padding_clamps(self, tmp_path): - """Reading near edge with pad=False → no NaN, just smaller or clamped data.""" + def test_partial_oob_left_correct_shape(self, tmp_path): + """Partial OOB on the left: output shape must equal target, left region is NaN.""" + path = create_test_zarr(tmp_path, shape=(8, 8, 8)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4], pad=True) + # Centre near edge: 4nm = 0.5 voxel, so half the patch extends before origin + center = {"z": 4.0, "y": 32.0, "x": 32.0} + patch = img[center] + assert patch.shape == torch.Size([4, 4, 4]) + # z-slices before origin should be NaN; interior slices should not be all-NaN + assert torch.isnan(patch[0]).all() # first z slice is OOB + assert not torch.isnan(patch[-1]).all() # last z slice is in-bounds + + def test_partial_oob_right_correct_shape(self, tmp_path): + """Partial OOB on the right: output shape must equal target, right region is NaN.""" + path = create_test_zarr(tmp_path, shape=(8, 8, 8)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4], pad=True) + # 8 voxels * 8nm = 64nm; centre at 60nm = 7.5th voxel + center = {"z": 60.0, "y": 32.0, "x": 32.0} + patch = img[center] + assert patch.shape == torch.Size([4, 4, 4]) + assert not torch.isnan(patch[0]).all() # first z slice is in-bounds + assert torch.isnan(patch[-1]).all() # last z slice is OOB + + def test_fully_oob_returns_all_nan_with_warning(self, tmp_path, caplog): + """Fully OOB read returns all-pad_value tensor and emits a logger warning.""" + import logging + + path = create_test_zarr(tmp_path, shape=(8, 8, 8)) + img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4], pad=True) + # Centre far outside bounding_box [0, 64] nm + center = {"z": 10000.0, "y": 32.0, "x": 32.0} + with caplog.at_level(logging.WARNING, logger="cellmap_data.image"): + patch = img[center] + assert patch.shape == torch.Size([4, 4, 4]) + assert torch.isnan(patch).all() + assert any("out-of-bounds" in msg for msg in caplog.messages) + + def test_no_padding_within_sampling_box(self, tmp_path): + """Reading a centre within sampling_box with pad=False → no NaN, shape == output.""" path = create_test_zarr(tmp_path, shape=(8, 8, 8)) img = CellMapImage(path, "raw", [8.0, 8.0, 8.0], [4, 4, 4], pad=False) - center = {"z": 4.0, "y": 4.0, "x": 4.0} + # sampling_box is [16, 48] nm; use centre of the array (32 nm) + center = {"z": 32.0, "y": 32.0, "x": 32.0} patch = img[center] - # No NaN expected (clamped read, may be smaller shape) + assert patch.shape == torch.Size([4, 4, 4]) assert not torch.isnan(patch).any() def test_get_center(self, tmp_path): From bf9de7ef45172ee438f3e3e71df91f7537aa9a8a Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 5 Mar 2026 14:48:13 -0500 Subject: [PATCH 04/33] feat: improve handling of 2D array specs in CellMapDataset and EmptyImage tests --- src/cellmap_data/dataset.py | 5 +---- src/cellmap_data/empty_image.py | 4 ++++ tests/test_dataset.py | 17 +++++++++++++++++ tests/test_empty_image.py | 8 ++++++++ 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index a0de11a..1543018 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -216,11 +216,8 @@ def sampling_box(self) -> dict[str, tuple[float, float]] | None: @cached_property def _target_scale(self) -> dict[str, float]: """Scale of the first target array spec.""" - first = next(iter(self.target_arrays.values())) - scale_seq = first["scale"] first_target_src = next(iter(self.target_sources.values())) - axes = first_target_src.axes - return {c: float(s) for c, s in zip(axes, scale_seq)} + return dict(first_target_src.scale) # ------------------------------------------------------------------ # Dataset interface diff --git a/src/cellmap_data/empty_image.py b/src/cellmap_data/empty_image.py index 32ea702..3869efc 100644 --- a/src/cellmap_data/empty_image.py +++ b/src/cellmap_data/empty_image.py @@ -34,6 +34,10 @@ def __init__( self.path = path self.label_class = target_class axis_order = list(axis_order) + if len(axis_order) > len(target_scale): + target_scale = [target_scale[0]] * ( + len(axis_order) - len(target_scale) + ) + list(target_scale) if len(axis_order) > len(target_voxel_shape): ndim_fix = len(axis_order) - len(target_voxel_shape) target_voxel_shape = [1] * ndim_fix + list(target_voxel_shape) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 3ca1042..2f0c464 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -56,6 +56,23 @@ def test_len_positive(self, tmp_path): ) assert len(ds) > 0 + def test_len_2d_arrays_no_keyerror(self, tmp_path): + """Regression: 2D array specs (scale/shape with 2 values) on 3D zarr data + must not raise KeyError when computing __len__ via _target_scale.""" + info = create_test_dataset(tmp_path, shape=(32, 32, 32)) + input_2d = {"raw": {"shape": (4, 4), "scale": (8.0, 8.0)}} + target_2d = {"labels": {"shape": (4, 4), "scale": (8.0, 8.0)}} + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=input_2d, + target_arrays=target_2d, + force_has_data=True, + pad=True, + ) + assert len(ds) > 0 + def test_getitem_returns_dict_with_idx(self, tmp_path): info = create_test_dataset(tmp_path) ds = CellMapDataset( diff --git a/tests/test_empty_image.py b/tests/test_empty_image.py index b299406..87773b8 100644 --- a/tests/test_empty_image.py +++ b/tests/test_empty_image.py @@ -46,3 +46,11 @@ def test_empty_image_clone(): p1 = img[{"z": 0.0, "y": 0.0, "x": 0.0}] p2 = img[{"z": 0.0, "y": 0.0, "x": 0.0}] assert p1 is not p2 + + +def test_empty_image_2d_scale_has_all_axes(): + """Regression: 2D scale/shape with default axis_order='zyx' must produce a + scale dict covering all three axes (z, y, x), not just two.""" + img = EmptyImage("fake/path", "mito", [8.0, 8.0], [4, 4]) + assert set(img.scale.keys()) == {"z", "y", "x"} + assert set(img.axes) == {"z", "y", "x"} From d0c0480de0b871f15cd7413a080aace9eca3ec62 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 5 Mar 2026 17:15:25 -0500 Subject: [PATCH 05/33] feat: implement seeding for reproducible data augmentation in CellMapDataLoader and CellMapDataset --- src/cellmap_data/dataloader.py | 53 ++++++++++++++++- src/cellmap_data/dataset.py | 3 +- tests/test_dataloader.py | 105 +++++++++++++++++++++++++++++++++ 3 files changed, 158 insertions(+), 3 deletions(-) diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py index 25313b4..1c604ed 100644 --- a/src/cellmap_data/dataloader.py +++ b/src/cellmap_data/dataloader.py @@ -5,6 +5,7 @@ import logging from typing import Any, Callable, Iterator, Optional, Sequence, Union +import numpy as np import torch import torch.utils.data from torch.utils.data import DataLoader, Dataset, Subset @@ -14,6 +15,37 @@ logger = logging.getLogger(__name__) +def _collect_datasets(dataset) -> list: + """Recursively collect all leaf datasets that own an ``_rng``.""" + if hasattr(dataset, "_rng"): + return [dataset] + result = [] + for attr in ("datasets", "dataset"): + child = getattr(dataset, attr, None) + if child is None: + continue + if isinstance(child, (list, tuple)): + for ds in child: + result.extend(_collect_datasets(ds)) + else: + result.extend(_collect_datasets(child)) + return result + + +def _worker_init_fn(worker_id: int) -> None: + """Seed each dataset's numpy RNG from the per-worker torch seed. + + PyTorch derives a unique seed per worker from the DataLoader's base seed + (which respects ``torch.manual_seed``). This function propagates that + seed to every constituent ``CellMapDataset._rng`` so that spatial + augmentation transforms are reproducible given the same global seed. + """ + worker_info = torch.utils.data.get_worker_info() + seed = worker_info.seed % (2**32) + for i, ds in enumerate(_collect_datasets(worker_info.dataset)): + ds._rng = np.random.default_rng(seed + i) + + class CellMapDataLoader: """PyTorch-compatible DataLoader for CellMap datasets. @@ -100,8 +132,24 @@ def __init__( else: self._sampler = None - # pin_memory: use on CUDA, skip otherwise to avoid issues - pin = kwargs.pop("pin_memory", str(device).startswith("cuda")) + # Seed numpy RNGs so augmentation is reproducible when a torch seed is set. + # Derive a base seed from the provided generator or the global torch seed. + base_seed = ( + rng.initial_seed() if rng is not None else torch.initial_seed() + ) % (2**32) + if num_workers == 0: + # Single-process: seed directly now. + for i, ds in enumerate(_collect_datasets(dataset)): + ds._rng = np.random.default_rng(base_seed + i) + # Multi-process workers each get a unique seed via worker_init_fn. + # Respect any caller-supplied worker_init_fn by not overwriting it. + if num_workers > 0 and "worker_init_fn" not in kwargs: + kwargs["worker_init_fn"] = _worker_init_fn + + # pin_memory: opt-in only — auto-enabling it based on CUDA availability + # causes OOM failures on memory-constrained GPUs. Pass pin_memory=True + # explicitly if you want the performance benefit. + pin = kwargs.pop("pin_memory", False) self.loader = DataLoader( dataset, @@ -111,6 +159,7 @@ def __init__( num_workers=num_workers, collate_fn=self.collate_fn, pin_memory=pin, + generator=rng, **self._kwargs, ) diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index 1543018..c0b3630 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -104,6 +104,7 @@ def __init__( class_relation_dict: Optional[Mapping[str, Sequence[str]]] = None, force_has_data: bool = False, device: Optional[str | torch.device] = None, + seed: Optional[int] = None, ) -> None: self.raw_path = raw_path self.target_path = target_path @@ -116,7 +117,7 @@ def __init__( self.target_value_transforms = target_value_transforms self.class_relation_dict = class_relation_dict self.force_has_data = force_has_data - self._rng = np.random.default_rng() + self._rng = np.random.default_rng(seed) # Parse target path to get template and annotated classes gt_path_template, annotated_classes = split_target_path(target_path) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 5edcd89..bc4970e 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -2,9 +2,11 @@ from __future__ import annotations +import numpy as np import torch from cellmap_data import CellMapDataLoader, CellMapDataset +from cellmap_data.dataloader import _collect_datasets from .test_helpers import create_test_dataset @@ -90,3 +92,106 @@ def test_repr(self, tmp_path): loader = CellMapDataLoader(ds, classes=CLASSES, batch_size=1, is_train=False) r = repr(loader) assert "CellMapDataLoader" in r + + # ------------------------------------------------------------------ + # Seeding / determinism + # ------------------------------------------------------------------ + + def test_loader_seeds_dataset_rng(self, tmp_path): + """CellMapDataLoader must seed _rng from torch.initial_seed().""" + ds = _make_ds(tmp_path) + torch.manual_seed(7) + CellMapDataLoader(ds, classes=CLASSES, batch_size=1, is_train=False, + device="cpu") + rng_state_7 = ds._rng.random() + + ds2 = _make_ds(tmp_path) + torch.manual_seed(7) + CellMapDataLoader(ds2, classes=CLASSES, batch_size=1, is_train=False, + device="cpu") + rng_state_7b = ds2._rng.random() + + ds3 = _make_ds(tmp_path) + torch.manual_seed(99) + CellMapDataLoader(ds3, classes=CLASSES, batch_size=1, is_train=False, + device="cpu") + rng_state_99 = ds3._rng.random() + + assert rng_state_7 == rng_state_7b, "same seed must yield same rng state" + assert rng_state_7 != rng_state_99, "different seeds must yield different rng state" + + def test_augmentation_reproducible_same_seed(self, tmp_path): + """Same torch seed → identical augmented batches across two loader runs.""" + SPATIAL = {"mirror": {"z": 0.5, "y": 0.5, "x": 0.5}} + info = create_test_dataset(tmp_path, classes=["mito"]) + + def get_first_batch(seed): + ds = CellMapDataset( + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + spatial_transforms=SPATIAL, + force_has_data=True, + pad=True, + ) + torch.manual_seed(seed) + loader = CellMapDataLoader(ds, classes=info["classes"], batch_size=1, + is_train=False, device="cpu") + return next(iter(loader))["raw"] + + b1 = get_first_batch(42) + b2 = get_first_batch(42) + b3 = get_first_batch(99) + assert torch.allclose(b1, b2), "same seed must produce identical augmentation" + assert not torch.allclose(b1, b3), "different seeds must produce different augmentation" + + def test_dataset_seed_param(self, tmp_path): + """CellMapDataset(seed=N) seeds _rng at construction.""" + info = create_test_dataset(tmp_path) + ds_a = CellMapDataset( + raw_path=info["raw_path"], target_path=info["gt_path"], + classes=info["classes"], input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, force_has_data=True, seed=123, + ) + ds_b = CellMapDataset( + raw_path=info["raw_path"], target_path=info["gt_path"], + classes=info["classes"], input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, force_has_data=True, seed=123, + ) + ds_c = CellMapDataset( + raw_path=info["raw_path"], target_path=info["gt_path"], + classes=info["classes"], input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, force_has_data=True, seed=456, + ) + v_a = ds_a._rng.random() + v_b = ds_b._rng.random() + v_c = ds_c._rng.random() + assert v_a == v_b, "same seed must give same first draw" + assert v_a != v_c, "different seeds must give different first draw" + + def test_collect_datasets_flat(self, tmp_path): + """_collect_datasets on a single CellMapDataset returns that dataset.""" + ds = _make_ds(tmp_path) + collected = _collect_datasets(ds) + assert collected == [ds] + + def test_collect_datasets_multidataset(self, tmp_path): + """_collect_datasets traverses CellMapMultiDataset.""" + from cellmap_data import CellMapMultiDataset + + info = create_test_dataset(tmp_path) + ds1 = CellMapDataset( + raw_path=info["raw_path"], target_path=info["gt_path"], + classes=info["classes"], input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, force_has_data=True, + ) + ds2 = CellMapDataset( + raw_path=info["raw_path"], target_path=info["gt_path"], + classes=info["classes"], input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, force_has_data=True, + ) + multi = CellMapMultiDataset([ds1, ds2], info["classes"], INPUT_ARRAYS, TARGET_ARRAYS) + collected = _collect_datasets(multi) + assert set(id(d) for d in collected) == {id(ds1), id(ds2)} From ab59d78ad00d9dc9d13b46973c61fa4ab23ba6bf Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 13 Mar 2026 14:06:02 -0400 Subject: [PATCH 06/33] fix: remove unused import of Subset from DataLoader in dataloader.py & black format --- src/cellmap_data/dataloader.py | 2 +- src/cellmap_data/image.py | 18 +++++--- tests/test_dataloader.py | 80 +++++++++++++++++++++++----------- tests/test_dataset.py | 10 ++++- 4 files changed, 76 insertions(+), 34 deletions(-) diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py index 1c604ed..20a58e0 100644 --- a/src/cellmap_data/dataloader.py +++ b/src/cellmap_data/dataloader.py @@ -8,7 +8,7 @@ import numpy as np import torch import torch.utils.data -from torch.utils.data import DataLoader, Dataset, Subset +from torch.utils.data import DataLoader, Dataset from .sampler import ClassBalancedSampler diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 24e1af3..a2ae5bc 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -277,10 +277,16 @@ def _compute_read_shape(self) -> list[int]: oversized pre-rotation patch. """ base = [ - max(1, int(round( - self.output_shape[ax] * self.scale[ax] - / self._voxel_size.get(ax, self.scale[ax]) - ))) + max( + 1, + int( + round( + self.output_shape[ax] + * self.scale[ax] + / self._voxel_size.get(ax, self.scale[ax]) + ) + ), + ) for ax in self.axes ] if self._current_spatial_transforms is None: @@ -437,7 +443,9 @@ def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor: data = data[tuple(crop_sl)] # Pad any undersized spatial dims (symmetric, pad_value fill) - actual = [data.shape[data.ndim - spatial_ndim + i] for i in range(spatial_ndim)] + actual = [ + data.shape[data.ndim - spatial_ndim + i] for i in range(spatial_ndim) + ] pad_needed = [target_shape[i] - actual[i] for i in range(spatial_ndim)] if any(p > 0 for p in pad_needed): size_pad: list[int] = [] diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index bc4970e..9f091dd 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -101,24 +101,29 @@ def test_loader_seeds_dataset_rng(self, tmp_path): """CellMapDataLoader must seed _rng from torch.initial_seed().""" ds = _make_ds(tmp_path) torch.manual_seed(7) - CellMapDataLoader(ds, classes=CLASSES, batch_size=1, is_train=False, - device="cpu") + CellMapDataLoader( + ds, classes=CLASSES, batch_size=1, is_train=False, device="cpu" + ) rng_state_7 = ds._rng.random() ds2 = _make_ds(tmp_path) torch.manual_seed(7) - CellMapDataLoader(ds2, classes=CLASSES, batch_size=1, is_train=False, - device="cpu") + CellMapDataLoader( + ds2, classes=CLASSES, batch_size=1, is_train=False, device="cpu" + ) rng_state_7b = ds2._rng.random() ds3 = _make_ds(tmp_path) torch.manual_seed(99) - CellMapDataLoader(ds3, classes=CLASSES, batch_size=1, is_train=False, - device="cpu") + CellMapDataLoader( + ds3, classes=CLASSES, batch_size=1, is_train=False, device="cpu" + ) rng_state_99 = ds3._rng.random() assert rng_state_7 == rng_state_7b, "same seed must yield same rng state" - assert rng_state_7 != rng_state_99, "different seeds must yield different rng state" + assert ( + rng_state_7 != rng_state_99 + ), "different seeds must yield different rng state" def test_augmentation_reproducible_same_seed(self, tmp_path): """Same torch seed → identical augmented batches across two loader runs.""" @@ -137,33 +142,48 @@ def get_first_batch(seed): pad=True, ) torch.manual_seed(seed) - loader = CellMapDataLoader(ds, classes=info["classes"], batch_size=1, - is_train=False, device="cpu") + loader = CellMapDataLoader( + ds, classes=info["classes"], batch_size=1, is_train=False, device="cpu" + ) return next(iter(loader))["raw"] b1 = get_first_batch(42) b2 = get_first_batch(42) b3 = get_first_batch(99) assert torch.allclose(b1, b2), "same seed must produce identical augmentation" - assert not torch.allclose(b1, b3), "different seeds must produce different augmentation" + assert not torch.allclose( + b1, b3 + ), "different seeds must produce different augmentation" def test_dataset_seed_param(self, tmp_path): """CellMapDataset(seed=N) seeds _rng at construction.""" info = create_test_dataset(tmp_path) ds_a = CellMapDataset( - raw_path=info["raw_path"], target_path=info["gt_path"], - classes=info["classes"], input_arrays=INPUT_ARRAYS, - target_arrays=TARGET_ARRAYS, force_has_data=True, seed=123, + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + seed=123, ) ds_b = CellMapDataset( - raw_path=info["raw_path"], target_path=info["gt_path"], - classes=info["classes"], input_arrays=INPUT_ARRAYS, - target_arrays=TARGET_ARRAYS, force_has_data=True, seed=123, + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + seed=123, ) ds_c = CellMapDataset( - raw_path=info["raw_path"], target_path=info["gt_path"], - classes=info["classes"], input_arrays=INPUT_ARRAYS, - target_arrays=TARGET_ARRAYS, force_has_data=True, seed=456, + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + seed=456, ) v_a = ds_a._rng.random() v_b = ds_b._rng.random() @@ -183,15 +203,23 @@ def test_collect_datasets_multidataset(self, tmp_path): info = create_test_dataset(tmp_path) ds1 = CellMapDataset( - raw_path=info["raw_path"], target_path=info["gt_path"], - classes=info["classes"], input_arrays=INPUT_ARRAYS, - target_arrays=TARGET_ARRAYS, force_has_data=True, + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, ) ds2 = CellMapDataset( - raw_path=info["raw_path"], target_path=info["gt_path"], - classes=info["classes"], input_arrays=INPUT_ARRAYS, - target_arrays=TARGET_ARRAYS, force_has_data=True, + raw_path=info["raw_path"], + target_path=info["gt_path"], + classes=info["classes"], + input_arrays=INPUT_ARRAYS, + target_arrays=TARGET_ARRAYS, + force_has_data=True, + ) + multi = CellMapMultiDataset( + [ds1, ds2], info["classes"], INPUT_ARRAYS, TARGET_ARRAYS ) - multi = CellMapMultiDataset([ds1, ds2], info["classes"], INPUT_ARRAYS, TARGET_ARRAYS) collected = _collect_datasets(multi) assert set(id(d) for d in collected) == {id(ds1), id(ds2)} diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 2f0c464..cada76b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -202,13 +202,16 @@ def test_small_crop_pad_true_len_one(self, tmp_path): from .test_helpers import _write_ome_ngff import numpy as np, os - large_raw = (np.random.default_rng(0).random((100, 100, 100)) * 255).astype(np.uint8) + large_raw = (np.random.default_rng(0).random((100, 100, 100)) * 255).astype( + np.uint8 + ) raw_path = str(tmp_path / "raw.zarr") _write_ome_ngff(raw_path, large_raw, [8.0, 8.0, 8.0]) gt_base = str(tmp_path / "gt.zarr") os.makedirs(gt_base, exist_ok=True) import json + with open(os.path.join(gt_base, ".zgroup"), "w") as f: f.write('{"zarr_format": 2}') @@ -244,13 +247,16 @@ def test_small_crop_pad_false_excluded(self, tmp_path): from .test_helpers import _write_ome_ngff import numpy as np, os - large_raw = (np.random.default_rng(0).random((100, 100, 100)) * 255).astype(np.uint8) + large_raw = (np.random.default_rng(0).random((100, 100, 100)) * 255).astype( + np.uint8 + ) raw_path = str(tmp_path / "raw.zarr") _write_ome_ngff(raw_path, large_raw, [8.0, 8.0, 8.0]) gt_base = str(tmp_path / "gt.zarr") os.makedirs(gt_base, exist_ok=True) import json + with open(os.path.join(gt_base, ".zgroup"), "w") as f: f.write('{"zarr_format": 2}') From 675b614f3a045a01ce222d9d169403c080099a77 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Fri, 13 Mar 2026 14:13:15 -0400 Subject: [PATCH 07/33] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/cellmap_data/image_writer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index 9145954..f699724 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -193,13 +193,17 @@ def _write_single( arr_shape = [self.shape[c] for c in self.spatial_axes] slices: list[slice] = [] + src_starts: list[int] = [] for i, c in enumerate(self.spatial_axes): start_nm = center[c] - self.write_world_shape[c] / 2.0 start_vox = int(round((start_nm - self.offset[c]) / self.scale[c])) end_vox = start_vox + self.write_voxel_shape[c] clamp_start = max(0, start_vox) clamp_end = min(arr_shape[i], end_vox) + # Where the visible region starts inside the source patch along this axis + src_start = clamp_start - start_vox slices.append(slice(clamp_start, clamp_end)) + src_starts.append(src_start) if isinstance(data, torch.Tensor): data_np = data.detach().cpu().numpy() @@ -214,7 +218,13 @@ def _write_single( # Crop data to clamped region (near array edges) actual = tuple(s.stop - s.start for s in slices) if data_np.shape != actual: - data_np = data_np[tuple(slice(0, e) for e in actual)] + # Use per-axis offsets so that when start_vox < 0, we skip the out-of-bounds prefix + data_np = data_np[ + tuple( + slice(src_starts[i], src_starts[i] + actual[i]) + for i in range(len(self.spatial_axes)) + ) + ] arr[tuple(slices)] = data_np From 057f91cf0d32511b5e5649b460654c75dbac85c2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:22:43 +0000 Subject: [PATCH 08/33] Initial plan From b5dd7ae64e7fe8bd6a4e5fca7b06c5cd024c70a2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:23:44 +0000 Subject: [PATCH 09/33] Initial plan From 911d9a6aa5c11db1ff57baf5afcb01899f54713d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:34:33 +0000 Subject: [PATCH 10/33] Initial plan From a1883ec9e772f17490f14dd808a47d7be782b980 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:40:53 +0000 Subject: [PATCH 11/33] fix: remove scalar write support from ImageWriter and CellMapDatasetWriter Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> --- src/cellmap_data/dataset_writer.py | 4 +--- src/cellmap_data/image_writer.py | 7 ++++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 780dbfa..263bce5 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -252,9 +252,7 @@ def __setitem__( for key, val in arrays.items(): if key in _SKIP_KEYS: continue - if isinstance(val, (int, float)): - item[key] = val - elif isinstance(val, dict): + if isinstance(val, dict): item[key] = {k: v[batch_i] for k, v in val.items()} else: item[key] = val[batch_i] diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index f699724..e8e5b43 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -170,7 +170,7 @@ def _zarr_array(self) -> zarr.Array: def __setitem__( self, coords: Mapping[str, float] | Mapping[str, Sequence], - data: torch.Tensor | ArrayLike | float | int, + data: torch.Tensor | ArrayLike, ) -> None: """Write *data* at the location given by *coords*. @@ -178,6 +178,11 @@ def __setitem__( - ``{axis: float}`` centre coordinates — single patch. - ``{axis: Sequence[float]}`` centres — batch. """ + if np.isscalar(data): + raise TypeError( + "Scalar writes are not supported. " + "Pass an array or tensor with shape matching the patch." + ) first = next(iter(coords.values())) if isinstance(first, (int, float)): self._write_single(coords, data) # type: ignore[arg-type] From 78bef7a57895993f5b5093090eacac2eb1e3304f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:47:39 +0000 Subject: [PATCH 12/33] Remove scalar write support from ImageWriter Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> --- src/cellmap_data/image_writer.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index f699724..f5fd726 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -170,14 +170,24 @@ def _zarr_array(self) -> zarr.Array: def __setitem__( self, coords: Mapping[str, float] | Mapping[str, Sequence], - data: torch.Tensor | ArrayLike | float | int, + data: torch.Tensor | ArrayLike, ) -> None: """Write *data* at the location given by *coords*. *coords* can be: - ``{axis: float}`` centre coordinates — single patch. - ``{axis: Sequence[float]}`` centres — batch. + + Raises + ------ + TypeError + If *data* is a scalar (int or float). Use an array or tensor instead. """ + if isinstance(data, (int, float)): + raise TypeError( + "Scalar writes are not supported. " + "Provide an array or tensor with the patch shape instead." + ) first = next(iter(coords.values())) if isinstance(first, (int, float)): self._write_single(coords, data) # type: ignore[arg-type] @@ -209,6 +219,13 @@ def _write_single( data_np = data.detach().cpu().numpy() else: data_np = np.asarray(data) + + if data_np.ndim == 0: + raise TypeError( + "Scalar writes are not supported. " + "Provide an array or tensor with the patch shape instead." + ) + data_np = data_np.astype(self.dtype) # Strip batch / channel leading dims of size 1 @@ -236,8 +253,7 @@ def _write_batch( n = len(next(iter(batch_coords.values()))) for i in range(n): center = {ax: float(batch_coords[ax][i]) for ax in self.spatial_axes} - item = data[i] if hasattr(data, "__getitem__") else data # type: ignore[index] - self._write_single(center, item) + self._write_single(center, data[i]) # type: ignore[index] def __getitem__(self, coords: Mapping[str, float]) -> torch.Tensor: """Read the patch centred at *coords*.""" From ab1bf11151f8998c808283f83b6fc2f5b97381ac Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:58:48 +0000 Subject: [PATCH 13/33] Fix total_voxels to use actual data volume spatial shape Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> --- src/cellmap_data/dataset.py | 13 +++++++++++-- src/cellmap_data/empty_image.py | 4 ++++ src/cellmap_data/image.py | 18 +++++++++++++++++ src/cellmap_data/multidataset.py | 29 ++++++++++++++++++++-------- tests/test_image.py | 8 ++++++++ tests/test_multidataset.py | 33 ++++++++++++++++++++++++++++++++ 6 files changed, 95 insertions(+), 10 deletions(-) diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index c0b3630..e31423d 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -408,16 +408,25 @@ def get_crop_class_matrix(self) -> np.ndarray: @property def class_counts(self) -> dict[str, Any]: - """Aggregate per-class foreground voxel counts from all target sources.""" + """Aggregate per-class foreground voxel counts from all target sources. + + Returns a dict with: + - ``"totals"``: per-class foreground voxel counts at training resolution. + - ``"totals_total"``: per-class total voxel counts (full array size) at + training resolution. + """ totals: dict[str, int] = {} + totals_total: dict[str, int] = {} for cls in self.classes: src = self.target_sources.get(cls) if src is not None: counts = src.class_counts totals[cls] = counts.get(cls, 0) + totals_total[cls] = src.total_voxels else: totals[cls] = 0 - return {"totals": totals} + totals_total[cls] = 0 + return {"totals": totals, "totals_total": totals_total} # ------------------------------------------------------------------ # Misc diff --git a/src/cellmap_data/empty_image.py b/src/cellmap_data/empty_image.py index 3869efc..062f268 100644 --- a/src/cellmap_data/empty_image.py +++ b/src/cellmap_data/empty_image.py @@ -62,6 +62,10 @@ def bounding_box(self) -> None: def sampling_box(self) -> None: return None + @property + def total_voxels(self) -> int: + return 0 + @property def class_counts(self) -> dict[str, int]: return {self.label_class: 0} diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index a2ae5bc..99c180a 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -553,6 +553,24 @@ def class_counts(self) -> dict[str, int]: logger.warning("class_counts failed for %s: %s", self.path, exc) return {self.label_class: 0} + @property + def total_voxels(self) -> int: + """Total number of voxels in the data volume at training resolution. + + Computed as the product of the s0 array's spatial dimensions scaled + to the training-resolution voxel size via :meth:`_scale_count`. + """ + try: + s0_path = self._level_info[0][0] + s0_arr = zarr.open_array(f"{self.path}/{s0_path}", mode="r") + n_spatial = len(self.axes) + spatial_shape = s0_arr.shape[-n_spatial:] + total_s0 = int(np.prod(spatial_shape)) + return self._scale_count(total_s0, s0_idx=0) + except Exception as exc: + logger.warning("total_voxels failed for %s: %s", self.path, exc) + return 0 + def _scale_count(self, s0_count: int, s0_idx: int = 0) -> int: """Scale a voxel count from s0 resolution to training resolution.""" try: diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py index a9cc49c..41170da 100644 --- a/src/cellmap_data/multidataset.py +++ b/src/cellmap_data/multidataset.py @@ -58,23 +58,36 @@ def class_counts(self) -> dict[str, Any]: Sequential scan (parallelism offers no benefit over NFS; see project MEMORY.md notes on ``CellMapMultiDataset.class_counts``). + + Returns a dict with: + - ``"totals"``: per-class foreground voxel counts. + - ``"totals_total"``: per-class total voxel counts (full array sizes). """ totals: dict[str, int] = {cls: 0 for cls in self.classes} + totals_total: dict[str, int] = {cls: 0 for cls in self.classes} for ds in tqdm(self.datasets, desc="Counting class voxels", leave=False): - ds_counts = ds.class_counts.get("totals", {}) + ds_counts = ds.class_counts for cls in self.classes: - totals[cls] += ds_counts.get(cls, 0) - return {"totals": totals} + totals[cls] += ds_counts.get("totals", {}).get(cls, 0) + totals_total[cls] += ds_counts.get("totals_total", {}).get(cls, 0) + return {"totals": totals, "totals_total": totals_total} @property def class_weights(self) -> dict[str, float]: - """Per-class sampling weight: ``bg_voxels / fg_voxels``.""" - counts = self.class_counts["totals"] - total_voxels = sum(counts.values()) + """Per-class sampling weight: ``bg_voxels / fg_voxels``. + + Background voxels for each class are derived from the actual data + volume size (``totals_total``) minus foreground voxels, so the ratio + correctly reflects the class imbalance within each volume. + """ + counts = self.class_counts + fg_counts = counts["totals"] + total_counts = counts["totals_total"] weights: dict[str, float] = {} for cls in self.classes: - fg = counts.get(cls, 0) - bg = total_voxels - fg + fg = fg_counts.get(cls, 0) + total = total_counts.get(cls, 0) + bg = max(total - fg, 0) weights[cls] = float(bg) / float(max(fg, 1)) return weights diff --git a/tests/test_image.py b/tests/test_image.py index da7be0a..92bd769 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -261,3 +261,11 @@ def test_class_counts_keys(self, tmp_path): counts = img.class_counts assert "mito" in counts assert counts["mito"] >= 0 + + def test_total_voxels_equals_array_size(self, tmp_path): + shape = (10, 10, 10) + data = np.zeros(shape, dtype=np.uint8) + data[2:5, 2:5, 2:5] = 1 + path = create_test_zarr(tmp_path, shape=shape, data=data) + img = CellMapImage(path, "mito", [8.0, 8.0, 8.0], [4, 4, 4]) + assert img.total_voxels == int(np.prod(shape)) diff --git a/tests/test_multidataset.py b/tests/test_multidataset.py index bb4b95c..acaf98a 100644 --- a/tests/test_multidataset.py +++ b/tests/test_multidataset.py @@ -62,7 +62,40 @@ def test_class_counts_keys(self, tmp_path): multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) counts = multi.class_counts assert "totals" in counts + assert "totals_total" in counts assert all(c in counts["totals"] for c in CLASSES) + assert all(c in counts["totals_total"] for c in CLASSES) + + def test_class_counts_total_equals_volume_size(self, tmp_path): + """totals_total should reflect the actual array volume, not sum of fg counts.""" + shape = (8, 8, 8) + voxel_size = [8.0, 8.0, 8.0] + ds1 = _make_ds(tmp_path, "d1", shape=shape, voxel_size=voxel_size) + multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + counts = multi.class_counts + expected_total = int(np.prod(shape)) + for cls in CLASSES: + assert counts["totals_total"][cls] == expected_total, ( + f"totals_total[{cls!r}] should equal array volume {expected_total}, " + f"got {counts['totals_total'][cls]}" + ) + + def test_class_weights_bg_uses_volume_size(self, tmp_path): + """bg in class_weights should be total_voxels - fg, not sum(fg) - fg.""" + shape = (8, 8, 8) + voxel_size = [8.0, 8.0, 8.0] + ds1 = _make_ds(tmp_path, "d1", shape=shape, voxel_size=voxel_size) + multi = CellMapMultiDataset([ds1], CLASSES, INPUT_ARRAYS, TARGET_ARRAYS) + counts = multi.class_counts + weights = multi.class_weights + total = int(np.prod(shape)) + for cls in CLASSES: + fg = counts["totals"][cls] + expected_bg = total - fg + expected_weight = float(max(expected_bg, 0)) / float(max(fg, 1)) + assert abs(weights[cls] - expected_weight) < 1e-6, ( + f"class_weights[{cls!r}] mismatch: expected {expected_weight}, got {weights[cls]}" + ) def test_get_crop_class_matrix_shape(self, tmp_path): ds1 = _make_ds(tmp_path, "d1") From b2db1b6a90602e4c7f707d6a54d5e37503b519e1 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:10:03 -0400 Subject: [PATCH 14/33] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/cellmap_data/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 99c180a..27e71b6 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -553,7 +553,7 @@ def class_counts(self) -> dict[str, int]: logger.warning("class_counts failed for %s: %s", self.path, exc) return {self.label_class: 0} - @property + @cached_property def total_voxels(self) -> int: """Total number of voxels in the data volume at training resolution. From 357f2e017ca5226bc5402172df03989ba7ce2404 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:11:38 -0400 Subject: [PATCH 15/33] Update src/cellmap_data/dataset.py --- src/cellmap_data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index e31423d..6b19a6d 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -406,7 +406,7 @@ def get_crop_class_matrix(self) -> np.ndarray: # Class counts # ------------------------------------------------------------------ - @property + @cached_property def class_counts(self) -> dict[str, Any]: """Aggregate per-class foreground voxel counts from all target sources. From 9fa89531cbd269fa29abfbf4d945837e6412df96 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 13 Mar 2026 15:13:21 -0400 Subject: [PATCH 16/33] black format --- tests/test_multidataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_multidataset.py b/tests/test_multidataset.py index acaf98a..29f02ee 100644 --- a/tests/test_multidataset.py +++ b/tests/test_multidataset.py @@ -93,9 +93,9 @@ def test_class_weights_bg_uses_volume_size(self, tmp_path): fg = counts["totals"][cls] expected_bg = total - fg expected_weight = float(max(expected_bg, 0)) / float(max(fg, 1)) - assert abs(weights[cls] - expected_weight) < 1e-6, ( - f"class_weights[{cls!r}] mismatch: expected {expected_weight}, got {weights[cls]}" - ) + assert ( + abs(weights[cls] - expected_weight) < 1e-6 + ), f"class_weights[{cls!r}] mismatch: expected {expected_weight}, got {weights[cls]}" def test_get_crop_class_matrix_shape(self, tmp_path): ds1 = _make_ds(tmp_path, "d1") From 3ec2dab23f5894bd1adff8d8cd98503521b707cd Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:14:41 -0400 Subject: [PATCH 17/33] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/cellmap_data/image_writer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index a5a6b8d..e59f593 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -181,7 +181,10 @@ def __setitem__( Raises ------ TypeError - If *data* is a scalar (int or float). Use an array or tensor instead. + If *data* is a scalar (i.e. ``np.isscalar(data)`` is ``True``, including + Python and NumPy scalar types). Use a non-scalar array or tensor with + shape matching the patch instead. Zero-dimensional arrays/tensors are + also not supported for writes. """ if np.isscalar(data): raise TypeError( From 99e0dd7d8b510777c1504de75aa992a6f775b325 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:15:39 -0400 Subject: [PATCH 18/33] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/cellmap_data/image_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index e59f593..82e7afb 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -226,7 +226,7 @@ def _write_single( if data_np.ndim == 0: raise TypeError( "Scalar writes are not supported. " - "Provide an array or tensor with the patch shape instead." + "Pass an array or tensor with shape matching the patch." ) data_np = data_np.astype(self.dtype) From 0d4bf5a7bcbb3c888e6c5fba39ed2da02f29a81c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 19:19:51 +0000 Subject: [PATCH 19/33] Improve total_voxels to use cached bounding_box instead of reopening s0 Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> --- src/cellmap_data/image.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 27e71b6..cf7914b 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -557,16 +557,15 @@ def class_counts(self) -> dict[str, int]: def total_voxels(self) -> int: """Total number of voxels in the data volume at training resolution. - Computed as the product of the s0 array's spatial dimensions scaled - to the training-resolution voxel size via :meth:`_scale_count`. + Derived from the cached :attr:`bounding_box` (world-space extent of the + dataset) divided by the training-resolution voxel size, so no additional + zarr I/O is needed beyond what is already cached. """ try: - s0_path = self._level_info[0][0] - s0_arr = zarr.open_array(f"{self.path}/{s0_path}", mode="r") - n_spatial = len(self.axes) - spatial_shape = s0_arr.shape[-n_spatial:] - total_s0 = int(np.prod(spatial_shape)) - return self._scale_count(total_s0, s0_idx=0) + total = 1 + for ax, (start, end) in self.bounding_box.items(): + total *= int(round((end - start) / self.scale[ax])) + return total except Exception as exc: logger.warning("total_voxels failed for %s: %s", self.path, exc) return 0 From 7554d4c19932d28f946454731c1958a25f696279 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 19:22:50 +0000 Subject: [PATCH 20/33] test: add test for scalar-in-batch TypeError and explicit guard in dataset_writer Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> --- src/cellmap_data/dataset_writer.py | 5 +++++ tests/test_writer.py | 12 ++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 263bce5..271c493 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -252,6 +252,11 @@ def __setitem__( for key, val in arrays.items(): if key in _SKIP_KEYS: continue + if np.isscalar(val): + raise TypeError( + f"Scalar writes are not supported (key={key!r}). " + "Pass an array or tensor with a leading batch dimension." + ) if isinstance(val, dict): item[key] = {k: v[batch_i] for k, v in val.items()} else: diff --git a/tests/test_writer.py b/tests/test_writer.py index 9395197..27d9568 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -111,6 +111,18 @@ def test_setitem_batch(self, tmp_path): output = {"mito": torch.zeros(2, 4, 4, 4)} writer[idx_tensor] = output # should not raise + def test_setitem_batch_scalar_raises(self, tmp_path): + """Passing a scalar value in a batch write must raise TypeError.""" + import pytest + + writer = self._make_writer(tmp_path) + indices = writer.writer_indices[:2] + idx_tensor = torch.tensor(indices) + # Scalar instead of a batched array — should raise + output = {"mito": 1.0} + with pytest.raises(TypeError, match="Scalar writes are not supported"): + writer[idx_tensor] = output + def test_loader_iterable(self, tmp_path): writer = self._make_writer(tmp_path) loader = writer.loader(batch_size=2) From d26ace84411742f172254b44f22d4b34b235d798 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 13 Mar 2026 15:27:38 -0400 Subject: [PATCH 21/33] fix: add type check for written values in CellMapDatasetWriter --- src/cellmap_data/dataset_writer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 271c493..48129c7 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -259,8 +259,12 @@ def __setitem__( ) if isinstance(val, dict): item[key] = {k: v[batch_i] for k, v in val.items()} - else: + elif hasattr(val, "__getitem__") and not isinstance(val, str): item[key] = val[batch_i] + else: + raise TypeError( + "Written values should be ArrayLike, or dictionaries of ArrayLike objects." + ) self.__setitem__(int(i), item) return From 8d446ccf56559040a610122071f212d658650cff Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:42:08 +0000 Subject: [PATCH 22/33] refactor: apply rhoadesScholar suggestion - drop else branch, keep np.isscalar + hasattr guards Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> --- src/cellmap_data/dataset_writer.py | 4 ---- tests/test_writer.py | 3 +-- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 48129c7..f1d11cb 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -261,10 +261,6 @@ def __setitem__( item[key] = {k: v[batch_i] for k, v in val.items()} elif hasattr(val, "__getitem__") and not isinstance(val, str): item[key] = val[batch_i] - else: - raise TypeError( - "Written values should be ArrayLike, or dictionaries of ArrayLike objects." - ) self.__setitem__(int(i), item) return diff --git a/tests/test_writer.py b/tests/test_writer.py index 27d9568..de96ed3 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -3,6 +3,7 @@ from __future__ import annotations import numpy as np +import pytest import torch from cellmap_data import CellMapDatasetWriter @@ -113,8 +114,6 @@ def test_setitem_batch(self, tmp_path): def test_setitem_batch_scalar_raises(self, tmp_path): """Passing a scalar value in a batch write must raise TypeError.""" - import pytest - writer = self._make_writer(tmp_path) indices = writer.writer_indices[:2] idx_tensor = torch.tensor(indices) From ef331536c547bbcdc4b285a3eed26c4cc81d7707 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Tue, 17 Mar 2026 14:23:48 -0400 Subject: [PATCH 23/33] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/cellmap_data/image.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index cf7914b..8fb113d 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -564,7 +564,10 @@ def total_voxels(self) -> int: try: total = 1 for ax, (start, end) in self.bounding_box.items(): - total *= int(round((end - start) / self.scale[ax])) + axis_voxels = int(round((end - start) / self.scale[ax])) + if axis_voxels < 1: + axis_voxels = 1 + total *= axis_voxels return total except Exception as exc: logger.warning("total_voxels failed for %s: %s", self.path, exc) From 02f8b6babc73ede286c88ae7eb06f184837158b3 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Wed, 18 Mar 2026 13:44:59 -0400 Subject: [PATCH 24/33] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/cellmap_data/dataset_writer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index f1d11cb..6a35f60 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -261,6 +261,13 @@ def __setitem__( item[key] = {k: v[batch_i] for k, v in val.items()} elif hasattr(val, "__getitem__") and not isinstance(val, str): item[key] = val[batch_i] + else: + raise TypeError( + "Unsupported batched value type for key " + f"{key!r}: {type(val).__name__}. Expected a dict of " + "batch-indexable values or a batch-indexable " + "array/tensor/sequence." + ) self.__setitem__(int(i), item) return From c6d1338cdb636303e4ef1c31f61618e68535c106 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Wed, 18 Mar 2026 13:45:43 -0400 Subject: [PATCH 25/33] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/cellmap_data/image_writer.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/cellmap_data/image_writer.py b/src/cellmap_data/image_writer.py index 82e7afb..7efb0d2 100644 --- a/src/cellmap_data/image_writer.py +++ b/src/cellmap_data/image_writer.py @@ -191,6 +191,18 @@ def __setitem__( "Scalar writes are not supported. " "Pass an array or tensor with shape matching the patch." ) + # Explicitly reject zero-dimensional arrays/tensors, which are not caught + # by np.isscalar and are documented as unsupported for writes. + if isinstance(data, np.ndarray) and data.ndim == 0: + raise TypeError( + "Zero-dimensional NumPy arrays are not supported for writes. " + "Pass a non-scalar array or tensor with shape matching the patch." + ) + if torch.is_tensor(data) and data.dim() == 0: + raise TypeError( + "Zero-dimensional torch.Tensors are not supported for writes. " + "Pass a non-scalar tensor or array with shape matching the patch." + ) first = next(iter(coords.values())) if isinstance(first, (int, float)): self._write_single(coords, data) # type: ignore[arg-type] From c7141ced9e7e3bb1ee2ac0f47a77b5524c69ba4c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 18 Mar 2026 18:05:34 +0000 Subject: [PATCH 26/33] test: add coverage for 0-D array/tensor rejection and unsupported batch type Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> --- tests/test_writer.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_writer.py b/tests/test_writer.py index de96ed3..5fca73b 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -57,6 +57,30 @@ def test_repr(self, tmp_path): ) assert "ImageWriter" in repr(writer) + def _make_image_writer(self, tmp_path): + return ImageWriter( + path=str(tmp_path / "out.zarr" / "mito"), + target_class="mito", + scale={"z": 8.0, "y": 8.0, "x": 8.0}, + bounding_box={"z": (0.0, 128.0), "y": (0.0, 128.0), "x": (0.0, 128.0)}, + write_voxel_shape={"z": 4, "y": 4, "x": 4}, + overwrite=True, + ) + + def test_setitem_zero_dim_ndarray_raises(self, tmp_path): + """A 0-D NumPy array must raise TypeError with a clear message.""" + writer = self._make_image_writer(tmp_path) + center = {"z": 16.0, "y": 16.0, "x": 16.0} + with pytest.raises(TypeError, match="Zero-dimensional NumPy arrays"): + writer[center] = np.array(1.0) + + def test_setitem_zero_dim_tensor_raises(self, tmp_path): + """A 0-D torch.Tensor must raise TypeError with a clear message.""" + writer = self._make_image_writer(tmp_path) + center = {"z": 16.0, "y": 16.0, "x": 16.0} + with pytest.raises(TypeError, match="Zero-dimensional torch.Tensors"): + writer[center] = torch.tensor(1.0) + class TestCellMapDatasetWriter: def _make_writer(self, tmp_path): @@ -122,6 +146,14 @@ def test_setitem_batch_scalar_raises(self, tmp_path): with pytest.raises(TypeError, match="Scalar writes are not supported"): writer[idx_tensor] = output + def test_setitem_batch_unsupported_type_raises(self, tmp_path): + """A non-dict, non-indexable value in a batch write must raise TypeError.""" + writer = self._make_writer(tmp_path) + indices = writer.writer_indices[:2] + idx_tensor = torch.tensor(indices) + with pytest.raises(TypeError, match="Unsupported batched value type"): + writer[idx_tensor] = {"mito": object()} + def test_loader_iterable(self, tmp_path): writer = self._make_writer(tmp_path) loader = writer.loader(batch_size=2) From bc8792c809d455eff5001347e95b049885df6266 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 20 Mar 2026 13:17:13 -0400 Subject: [PATCH 27/33] fix: adjust bounding box handling and improve target array writing logic --- src/cellmap_data/dataset_writer.py | 42 +++++++++++++++--------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 6a35f60..7682bb7 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -149,7 +149,9 @@ def sampling_box(self) -> dict[str, tuple[float, float]] | None: lo = bb[ax][0] + h hi = bb[ax][1] - h if lo >= hi: - return None + # Bounding box is smaller than one write patch; center a single tile + center = (bb[ax][0] + bb[ax][1]) / 2.0 + lo = hi = center result[ax] = (lo, hi) return result @@ -276,28 +278,26 @@ def __setitem__( for key, val in arrays.items(): if key in _SKIP_KEYS: continue - # Find which target array and class this key maps to - for arr_name, writers in self.target_array_writers.items(): - if key in writers: - writers[key][center] = val - elif key in self.classes: - # Flat class key — write to first matching target array + if key in self.target_array_writers: + # key is an array name — val is either a class dict or a multi-class tensor + writers = self.target_array_writers[key] + if isinstance(val, dict): + for cls, tensor in val.items(): + if cls in writers: + writers[cls][center] = tensor + else: + # tensor shape (C, ...) — split channels by class order + for i, cls in enumerate(self.model_classes): + if cls in writers: + # Use slice i:i+1 to preserve the leading dim so that + # _write_single can strip singleton dims correctly + writers[cls][center] = val[i:i+1] if val.ndim > 0 and val.shape[0] > i else val + elif key in self.classes: + # key is a class name — write to matching writer in any target array + for writers in self.target_array_writers.values(): if key in writers: writers[key][center] = val - else: - # Write per channel if val is multi-channel - cls_idx = ( - self.model_classes.index(key) - if key in self.model_classes - else None - ) - if key in writers: - writers[key][center] = ( - val[cls_idx] - if cls_idx is not None and val.ndim > 0 - else val - ) - break + break # ------------------------------------------------------------------ # DataLoader helper From c0cad73556fc2e344ebb9aad5a2ade5e43096990 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 20 Mar 2026 13:18:46 -0400 Subject: [PATCH 28/33] black format --- src/cellmap_data/dataset_writer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py index 7682bb7..f3aebd0 100644 --- a/src/cellmap_data/dataset_writer.py +++ b/src/cellmap_data/dataset_writer.py @@ -291,7 +291,11 @@ def __setitem__( if cls in writers: # Use slice i:i+1 to preserve the leading dim so that # _write_single can strip singleton dims correctly - writers[cls][center] = val[i:i+1] if val.ndim > 0 and val.shape[0] > i else val + writers[cls][center] = ( + val[i : i + 1] + if val.ndim > 0 and val.shape[0] > i + else val + ) elif key in self.classes: # key is a class name — write to matching writer in any target array for writers in self.target_array_writers.values(): From 08fd903690c1cc14be784eb7cbd5d9c694cabd77 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Fri, 20 Mar 2026 13:25:59 -0400 Subject: [PATCH 29/33] Update tests/test_dataset.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tests/test_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index cada76b..c89f0e5 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -210,7 +210,6 @@ def test_small_crop_pad_true_len_one(self, tmp_path): gt_base = str(tmp_path / "gt.zarr") os.makedirs(gt_base, exist_ok=True) - import json with open(os.path.join(gt_base, ".zgroup"), "w") as f: f.write('{"zarr_format": 2}') From 295b90da2e054dc0c7b5a7d6119b4d49570c843a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 17:30:47 +0000 Subject: [PATCH 30/33] Initial plan From 6c895fa81a24c11c6ec261b6669e8f2401ddee1a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 17:38:39 +0000 Subject: [PATCH 31/33] Fix review comments: unused imports, min_redundant_inds, ClassBalancedSampler Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> Agent-Logs-Url: https://github.com/janelia-cellmap/cellmap-data/sessions/1e323ffb-2a14-4aae-99e1-8e70c1d393d3 --- src/cellmap_data/sampler.py | 45 +++++++++++++++++++++++++++++++++- src/cellmap_data/utils/misc.py | 9 +++++-- tests/test_dataset.py | 1 - tests/test_geometry.py | 2 -- tests/test_image.py | 2 -- tests/test_multidataset.py | 3 --- tests/test_sampler.py | 1 - 7 files changed, 51 insertions(+), 12 deletions(-) diff --git a/src/cellmap_data/sampler.py b/src/cellmap_data/sampler.py index 426d829..4de777a 100644 --- a/src/cellmap_data/sampler.py +++ b/src/cellmap_data/sampler.py @@ -64,6 +64,12 @@ def __init__( self.class_to_crops[c] = indices self.active_classes: list[int] = sorted(self.class_to_crops.keys()) + if not self.active_classes: + raise ValueError( + "ClassBalancedSampler: no active classes found in crop-class " + "matrix. This can occur when all requested classes are only " + "represented by empty crops (e.g., EmptyImage)." + ) def __iter__(self) -> Iterator[int]: class_counts = np.zeros(self.n_classes, dtype=np.float64) @@ -86,7 +92,44 @@ def __iter__(self) -> Iterator[int]: annotated = np.where(self.crop_class_matrix[crop_idx])[0] class_counts[annotated] += 1.0 - yield crop_idx + # Map crop index (dataset-level row) to an actual sample index. + # If n_crops equals len(dataset), the crop index IS the sample index. + if self.n_crops == len(self.dataset): + sample_idx = crop_idx + elif hasattr(self.dataset, "datasets") and hasattr( + self.dataset, "cumulative_sizes" + ): + # ConcatDataset / CellMapMultiDataset: each crop row corresponds + # to one sub-dataset; pick a random sample within that sub-dataset. + cumulative_sizes = self.dataset.cumulative_sizes + if crop_idx < len(cumulative_sizes): + start = int(cumulative_sizes[crop_idx - 1]) if crop_idx > 0 else 0 + end = int(cumulative_sizes[crop_idx]) + else: + start, end = 0, len(self.dataset) + if start >= end or end > len(self.dataset): + start, end = 0, len(self.dataset) + sample_idx = int(self.rng.integers(start, end)) + else: + # Generic fallback: partition [0, len(dataset)) into n_crops + # contiguous segments and sample within this crop's segment. + total = len(self.dataset) + if self.n_crops <= 1 or total <= 0: + start, end = 0, max(total, 1) + else: + base = total // self.n_crops + remainder = total % self.n_crops + if crop_idx < remainder: + start = crop_idx * (base + 1) + end = start + (base + 1) + else: + start = remainder * (base + 1) + (crop_idx - remainder) * base + end = start + base + if start >= end or end > total: + start, end = 0, total + sample_idx = int(self.rng.integers(start, end)) + + yield sample_idx def __len__(self) -> int: return self.samples_per_epoch diff --git a/src/cellmap_data/utils/misc.py b/src/cellmap_data/utils/misc.py index 42a523b..8341a71 100644 --- a/src/cellmap_data/utils/misc.py +++ b/src/cellmap_data/utils/misc.py @@ -140,7 +140,12 @@ def min_redundant_inds( return torch.randint(n, (k,), generator=rng) else: if k > n: - # Repeat the unique indices until we have k indices - return torch.cat([torch.randperm(n, generator=rng) for _ in range(k // n)]) + # Repeat unique indices until we have k indices (handle remainder) + full_perms = k // n + remainder = k % n + parts = [torch.randperm(n, generator=rng) for _ in range(full_perms)] + if remainder > 0: + parts.append(torch.randperm(n, generator=rng)[:remainder]) + return torch.cat(parts) else: return torch.randperm(n, generator=rng)[:k] diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c89f0e5..2ee1166 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -254,7 +254,6 @@ def test_small_crop_pad_false_excluded(self, tmp_path): gt_base = str(tmp_path / "gt.zarr") os.makedirs(gt_base, exist_ok=True) - import json with open(os.path.join(gt_base, ".zgroup"), "w") as f: f.write('{"zarr_format": 2}') diff --git a/tests/test_geometry.py b/tests/test_geometry.py index ff1d840..dd1110f 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from cellmap_data.utils.geometry import box_intersection, box_shape, box_union diff --git a/tests/test_image.py b/tests/test_image.py index 92bd769..9b555c6 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -252,8 +252,6 @@ def test_rotation_output_shape_preserved(self, tmp_path): class TestCellMapImageClassCounts: def test_class_counts_keys(self, tmp_path): - import zarr as z - data = np.zeros((10, 10, 10), dtype=np.uint8) data[2:5, 2:5, 2:5] = 1 # some foreground path = create_test_zarr(tmp_path, shape=(10, 10, 10), data=data) diff --git a/tests/test_multidataset.py b/tests/test_multidataset.py index 29f02ee..44ca43b 100644 --- a/tests/test_multidataset.py +++ b/tests/test_multidataset.py @@ -3,7 +3,6 @@ from __future__ import annotations import numpy as np -import torch from cellmap_data import CellMapDataset, CellMapMultiDataset @@ -15,8 +14,6 @@ def _make_ds(tmp_path, suffix="", **kwargs): - import tempfile, pathlib - sub = tmp_path / suffix if suffix else tmp_path / "ds0" sub.mkdir(parents=True, exist_ok=True) info = create_test_dataset(sub, classes=CLASSES, **kwargs) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 6b7f4b5..d23b294 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -3,7 +3,6 @@ from __future__ import annotations import numpy as np -import pytest from cellmap_data.sampler import ClassBalancedSampler From ddb6246890d877a5317b776def5dcb381d80a039 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Fri, 20 Mar 2026 14:36:21 -0400 Subject: [PATCH 32/33] Update src/cellmap_data/sampler.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/cellmap_data/sampler.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/cellmap_data/sampler.py b/src/cellmap_data/sampler.py index 4de777a..7ed2e57 100644 --- a/src/cellmap_data/sampler.py +++ b/src/cellmap_data/sampler.py @@ -106,9 +106,18 @@ def __iter__(self) -> Iterator[int]: start = int(cumulative_sizes[crop_idx - 1]) if crop_idx > 0 else 0 end = int(cumulative_sizes[crop_idx]) else: - start, end = 0, len(self.dataset) + raise ValueError( + "ClassBalancedSampler: crop index out of range for " + "ConcatDataset/CellMapMultiDataset mapping. " + f"crop_idx={crop_idx}, n_subdatasets={len(cumulative_sizes)}" + ) if start >= end or end > len(self.dataset): - start, end = 0, len(self.dataset) + raise ValueError( + "ClassBalancedSampler: invalid sub-dataset slice computed " + "from cumulative_sizes for crop index " + f"{crop_idx}: start={start}, end={end}, " + f"len(dataset)={len(self.dataset)}" + ) sample_idx = int(self.rng.integers(start, end)) else: # Generic fallback: partition [0, len(dataset)) into n_crops From 4448a5642b6af723ddbe914df49ccaf002fcf33e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 18:43:54 +0000 Subject: [PATCH 33/33] Add tests for ClassBalancedSampler edge cases and min_redundant_inds; fix naming and docstrings Co-authored-by: rhoadesScholar <37990507+rhoadesScholar@users.noreply.github.com> Agent-Logs-Url: https://github.com/janelia-cellmap/cellmap-data/sessions/b1ce1f6e-fdc4-47b9-aa82-562c86c4800d --- src/cellmap_data/sampler.py | 41 ++++++++-------- src/cellmap_data/utils/misc.py | 6 ++- tests/test_sampler.py | 88 ++++++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 22 deletions(-) diff --git a/src/cellmap_data/sampler.py b/src/cellmap_data/sampler.py index 7ed2e57..477fa73 100644 --- a/src/cellmap_data/sampler.py +++ b/src/cellmap_data/sampler.py @@ -24,8 +24,9 @@ class ClassBalancedSampler(Sampler): ``dataset.get_crop_class_matrix()`` → ``bool[n_crops, n_classes]``. 2. Maintain running counts of how many times each class has been seen. 3. At each step: pick the class with the lowest count (ties broken - randomly), sample a crop annotating it, yield that crop index, then - increment counts for *all* classes that crop annotates. + randomly), sample a matrix row (crop) annotating it, map that row to + an actual dataset sample index, and yield the sample index. Then + increment counts for *all* classes that row annotates. This guarantees rare classes get sampled as often as common ones. @@ -85,54 +86,54 @@ def __iter__(self) -> Iterator[int]: ] target_class = int(self.rng.choice(tied)) - # Sample a crop that annotates this class - crop_idx = int(self.rng.choice(self.class_to_crops[target_class])) + # Sample a matrix row (crop) that annotates this class + row_idx = int(self.rng.choice(self.class_to_crops[target_class])) - # Increment counts for all classes this crop annotates - annotated = np.where(self.crop_class_matrix[crop_idx])[0] + # Increment counts for all classes this row annotates + annotated = np.where(self.crop_class_matrix[row_idx])[0] class_counts[annotated] += 1.0 - # Map crop index (dataset-level row) to an actual sample index. - # If n_crops equals len(dataset), the crop index IS the sample index. + # Map matrix row (dataset-level row) to an actual sample index. + # If n_crops equals len(dataset), the row index IS the sample index. if self.n_crops == len(self.dataset): - sample_idx = crop_idx + sample_idx = row_idx elif hasattr(self.dataset, "datasets") and hasattr( self.dataset, "cumulative_sizes" ): - # ConcatDataset / CellMapMultiDataset: each crop row corresponds + # ConcatDataset / CellMapMultiDataset: each row corresponds # to one sub-dataset; pick a random sample within that sub-dataset. cumulative_sizes = self.dataset.cumulative_sizes - if crop_idx < len(cumulative_sizes): - start = int(cumulative_sizes[crop_idx - 1]) if crop_idx > 0 else 0 - end = int(cumulative_sizes[crop_idx]) + if row_idx < len(cumulative_sizes): + start = int(cumulative_sizes[row_idx - 1]) if row_idx > 0 else 0 + end = int(cumulative_sizes[row_idx]) else: raise ValueError( "ClassBalancedSampler: crop index out of range for " "ConcatDataset/CellMapMultiDataset mapping. " - f"crop_idx={crop_idx}, n_subdatasets={len(cumulative_sizes)}" + f"row_idx={row_idx}, n_subdatasets={len(cumulative_sizes)}" ) if start >= end or end > len(self.dataset): raise ValueError( "ClassBalancedSampler: invalid sub-dataset slice computed " - "from cumulative_sizes for crop index " - f"{crop_idx}: start={start}, end={end}, " + "from cumulative_sizes for row index " + f"{row_idx}: start={start}, end={end}, " f"len(dataset)={len(self.dataset)}" ) sample_idx = int(self.rng.integers(start, end)) else: # Generic fallback: partition [0, len(dataset)) into n_crops - # contiguous segments and sample within this crop's segment. + # contiguous segments and sample within this row's segment. total = len(self.dataset) if self.n_crops <= 1 or total <= 0: start, end = 0, max(total, 1) else: base = total // self.n_crops remainder = total % self.n_crops - if crop_idx < remainder: - start = crop_idx * (base + 1) + if row_idx < remainder: + start = row_idx * (base + 1) end = start + (base + 1) else: - start = remainder * (base + 1) + (crop_idx - remainder) * base + start = remainder * (base + 1) + (row_idx - remainder) * base end = start + base if start >= end or end > total: start, end = 0, total diff --git a/src/cellmap_data/utils/misc.py b/src/cellmap_data/utils/misc.py index 8341a71..4ac7782 100644 --- a/src/cellmap_data/utils/misc.py +++ b/src/cellmap_data/utils/misc.py @@ -124,7 +124,9 @@ def min_redundant_inds( ) -> torch.Tensor: """Returns k indices from 0 to n-1 with minimum redundancy. - If replacement is False, the indices are unique. + If replacement is False and k <= n, the indices are unique. + If replacement is False and k > n, duplicates are unavoidable; indices + are unique within each block of size n (minimum redundancy overall). If replacement is True, the indices can have duplicates. Args: @@ -134,7 +136,7 @@ def min_redundant_inds( rng (torch.Generator, optional): The random number generator. Defaults to None. Returns: - torch.Tensor: A tensor of k indices. + torch.Tensor: A tensor of exactly k indices. """ if replacement: return torch.randint(n, (k,), generator=rng) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index d23b294..e4d168d 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -3,8 +3,10 @@ from __future__ import annotations import numpy as np +import pytest from cellmap_data.sampler import ClassBalancedSampler +from cellmap_data.utils.misc import min_redundant_inds class FakeDataset: @@ -20,6 +22,27 @@ def __len__(self) -> int: return self._matrix.shape[0] +class FakeConcatDataset: + """Minimal ConcatDataset-like dataset with datasets + cumulative_sizes.""" + + def __init__(self, sub_lengths: list[int], matrix: np.ndarray): + self._matrix = matrix + self._sub_lengths = sub_lengths + # Cumulative sizes mirrors torch.utils.data.ConcatDataset behaviour + self.cumulative_sizes: list[int] = [] + total = 0 + for length in sub_lengths: + total += length + self.cumulative_sizes.append(total) + self.datasets = [None] * len(sub_lengths) # placeholders + + def get_crop_class_matrix(self) -> np.ndarray: + return self._matrix + + def __len__(self) -> int: + return sum(self._sub_lengths) + + class TestClassBalancedSampler: def _make_sampler(self, matrix, samples_per_epoch=None, seed=42): ds = FakeDataset(matrix) @@ -85,3 +108,68 @@ def test_yields_valid_indices_for_single_class(self): indices = list(sampler) assert len(indices) == 10 assert all(0 <= i < 5 for i in indices) + + def test_raises_when_no_active_classes(self): + """All-False crop-class matrix must raise ValueError immediately.""" + matrix = np.zeros((4, 3), dtype=bool) + with pytest.raises(ValueError, match="no active classes"): + self._make_sampler(matrix, samples_per_epoch=5) + + def test_concat_dataset_indices_in_correct_subdataset(self): + """ConcatDataset path: each yielded index falls in the expected sub-dataset range.""" + # Two sub-datasets: first has 10 samples, second has 20 samples + sub_lengths = [10, 20] + # Row 0 → only class 0 annotated; Row 1 → only class 1 annotated + matrix = np.array([[True, False], [False, True]], dtype=bool) + ds = FakeConcatDataset(sub_lengths, matrix) + sampler = ClassBalancedSampler(ds, samples_per_epoch=40, seed=0) + indices = list(sampler) + assert len(indices) == 40 + # All indices must be valid dataset indices + assert all(0 <= i < len(ds) for i in indices) + # Indices from class-0 crops (row 0 → sub-dataset 0) must be in [0, 10) + # Indices from class-1 crops (row 1 → sub-dataset 1) must be in [10, 30) + # Because the sampler alternates classes, roughly half go to each sub-dataset + indices_set = set(indices) + assert any(i < 10 for i in indices_set), "No index from sub-dataset 0" + assert any(10 <= i < 30 for i in indices_set), "No index from sub-dataset 1" + + def test_concat_dataset_all_indices_in_range(self): + """ConcatDataset path: all yielded indices are within [0, len(dataset)).""" + sub_lengths = [5, 5, 5] + matrix = np.eye(3, dtype=bool) + ds = FakeConcatDataset(sub_lengths, matrix) + sampler = ClassBalancedSampler(ds, samples_per_epoch=30, seed=7) + indices = list(sampler) + assert all(0 <= i < len(ds) for i in indices) + + +class TestMinRedundantInds: + def test_replacement_returns_k(self): + result = min_redundant_inds(5, 12, replacement=True) + assert len(result) == 12 + + def test_no_replacement_k_leq_n(self): + result = min_redundant_inds(10, 4, replacement=False) + assert len(result) == 4 + assert len(set(result.tolist())) == 4 # all unique + + def test_no_replacement_k_equals_n(self): + result = min_redundant_inds(5, 5, replacement=False) + assert len(result) == 5 + assert sorted(result.tolist()) == list(range(5)) + + def test_no_replacement_k_gt_n_exact_multiple(self): + """k=6, n=3: two full permutations, exactly 6 indices returned.""" + result = min_redundant_inds(3, 6, replacement=False) + assert len(result) == 6 + + def test_no_replacement_k_gt_n_with_remainder(self): + """k=7, n=3: must return exactly 7 indices, not 6.""" + result = min_redundant_inds(3, 7, replacement=False) + assert len(result) == 7 + + def test_no_replacement_all_values_in_range(self): + result = min_redundant_inds(4, 11, replacement=False) + assert len(result) == 11 + assert all(0 <= v < 4 for v in result.tolist())