diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7e8a554..51bad8b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,7 +44,7 @@ jobs: allow-prereleases: true - name: Install package - run: python -m pip install .[test] + run: python -m pip install -e .[test] - name: Test package run: >- diff --git a/.gitignore b/.gitignore index 25cf9a4..80265b8 100644 --- a/.gitignore +++ b/.gitignore @@ -156,3 +156,7 @@ Thumbs.db # Common editor files *~ *.swp + + +# IDE specific files +.vscode/ diff --git a/pyproject.toml b/pyproject.toml index 56eb38d..9e45763 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ "Typing :: Typed", ] -dependencies = ["torch", "numpy", "pandas", "mrcfile", "torchvision", "scipy", "pyarrow"] +dependencies = ["torch", "numpy", "pandas", "mrcfile", "torchvision", "scipy ~= 1.9.3", "pyarrow", "ccpem-utils", "h5py", "psutil", "pillow ~= 9.3"] [project.optional-dependencies] test = [ @@ -60,7 +60,7 @@ minversion = "6.0" addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] xfail_strict = true filterwarnings = [ - "error", + "ignore::DeprecationWarning", ] log_cli_level = "INFO" testpaths = [ diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..969267e --- /dev/null +++ b/setup.cfg @@ -0,0 +1,53 @@ +# Setup configuration for the package +[metadata] +name = caked + + +# Options for the package + +[options] + +packages = find: +python_requires = >=3.8 +package_dir = + = src + + +# where to add pip dependencies + +install_requires = + torch + numpy + pandas + mrcfile + torchvision + scipy + pyarrow + ccpem-utils + h5py + psutil + + +[options.packages.find] +where = + src + src/Transforms + src/Wrappers + +exclude = + tests + .github + .gitignore + .gitattributes + .pytest_cache + .git + .vscode + .history + *.egg + *.egg-info + docs + site + mkdocs.yml + *.ipynb + .mypy_cache + .ruff_cache diff --git a/setup.py b/setup.py new file mode 100644 
index 0000000..a03590f --- /dev/null +++ b/setup.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from setuptools import setup + +setup() diff --git a/src/caked/Transforms/augments.py b/src/caked/Transforms/augments.py new file mode 100644 index 0000000..4dace35 --- /dev/null +++ b/src/caked/Transforms/augments.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import random +from enum import Enum + +import numpy as np +from ccpem_utils.map.array_utils import rotate_array +from ccpem_utils.map.parse_mrcmapobj import MapObjHandle + +from .base import AugmentBase + + +class Augments(Enum): + """ """ + + RANDOMROT = "randrot" + ROT90 = "rot90" + + +def get_augment(augment: str, random_seed) -> AugmentBase: + """ """ + + if augment == Augments.RANDOMROT.value: + return RandomRotationAugment(random_seed=random_seed) + if augment == Augments.ROT90.value: + return Rotation90Augment(random_seed=random_seed) + + msg = f"Unknown Augmentation: {augment}, please choose from {Augments.__members__}" + raise ValueError(msg) + + +class ComposeAugment: + """ + Compose multiple Augments together. + + :param augments: (list) list of augments to compose + + :return: (np.ndarrry) transformed array + """ + + def __init__(self, augments: list[str], random_seed: int = 42): + self.random_seed = random_seed + self.augments = augments + + def __call__(self, data: np.ndarray, **kwargs) -> MapObjHandle: + for augment in self.augments: + data, augment_kwargs = get_augment(augment, random_seed=self.random_seed)( + data, **kwargs + ) + + kwargs.update(augment_kwargs) + + return data, kwargs + + +class RandomRotationAugment(AugmentBase): + """ + Random or controlled rotation (if ax and an kwargs provided). 
+ + :param data: (np.ndarray) 3d volume + :param return_all: (bool) if True, will parameters of the rotation (ax, an) + :param interp: (bool) if True, will interpolate the rotation + :param ax: (int) 0 for yaw, 1 for pitch, 2 for roll + :param an: (int) number of times to rotate, between <1 and 3> + + :return: (np.ndarray) rotated volume or (np.ndarray, int, int) rotated volume and rotation parameters + """ + + def __init__(self, random_seed: int = 42): + super().__init__(random_seed) + + def __call__( + self, + data: np.ndarray, + **kwargs, + ) -> np.ndarray | tuple[np.ndarray, int, int]: + ax = kwargs.get("ax", None) + an = kwargs.get("an", None) + interp = kwargs.get("interp", True) + + if (ax is not None and an is None) or (ax is None and an is not None): + msg = "When specifying rotation, please use both arguments to specify the axis and angle." + raise RuntimeError(msg) + rotations = [(0, 1), (0, 2), (1, 2)] # yaw, pitch, roll + if ax is None and an is None: + axes = random.randint(0, 2) + set_angles = [30, 60, 90] + angler = random.randint(0, 2) + angle = set_angles[angler] + else: + axes = ax + angle = an + + r = rotations[axes] + data = rotate_array(data, angle, axes=r, interpolate=interp, reshape=False) + + return data, {"ax": axes, "an": angle} + + +class Rotation90Augment(AugmentBase): + """ + Rotate the volume by 90 degrees. + + :param data: (np.ndarray) 3d volume + :param return_all: (bool) if True, will parameters of the rotation (ax, an) + :param interp: (bool) if True, will interpolate the rotation + :param ax: (int) 0 for yaw, 1 for pitch, 2 for roll + :param an: (int) number of times to rotate, between <1 and 3> + + :return: (np.ndarray) rotated volume or (np.ndarray, int, int) rotated volume and rotation parameters + """ + + def __init__(self, random_seed: int = 42): + super().__init__(random_seed) + + def __call__( + self, + data: np.ndarray, + **kwargs, + ) -> np.ndarray: + _ = data + _ = kwargs + msg = "Rotation90Augment not implemented yet." 
+ raise NotImplementedError(msg) diff --git a/src/caked/Transforms/base.py b/src/caked/Transforms/base.py new file mode 100644 index 0000000..5e91863 --- /dev/null +++ b/src/caked/Transforms/base.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +import numpy as np + + +class TransformBase(ABC): + """ + Base class for transformations. + + """ + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def __call__(self, mapobj, **kwargs): + msg = "The __call__ method must be implemented in the subclass" + raise NotImplementedError(msg) + + +class AugmentBase(ABC): + """ + Base class for augmentations. + """ + + # This will need to take the hyper parameters for the augmentations + + @abstractmethod + def __init__(self, random_seed: int = 42): + self.random_state = np.random.RandomState(random_seed) + + @abstractmethod + def __call__(self, data, **kwargs): + msg = "The __call__ method must be implemented in the subclass" + raise NotImplementedError(msg) diff --git a/src/caked/Transforms/transforms.py b/src/caked/Transforms/transforms.py new file mode 100644 index 0000000..f91293e --- /dev/null +++ b/src/caked/Transforms/transforms.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +from enum import Enum + +import numpy as np +from ccpem_utils.map.mrc_map_utils import ( + interpolate_to_grid, + normalise_mapobj, + pad_map_grid_split_distribution, +) +from ccpem_utils.map.parse_mrcmapobj import MapObjHandle + +from .base import TransformBase +from .utils import divx, mask_from_labelobj + + +class Transforms(Enum): + """ + Enum class for transformations. + + """ + + VOXNORM = "voxnorm" + NORM = "norm" + MASKCROP = "maskcrop" + PADDING = "padding" + + +def get_transform(transform: str) -> TransformBase: + """ + Get the transformation object. 
class DecomposeToSlices:
    """
    Decompose a 3D volume shape into a regular grid of cubic tiles.

    Tiles are cubes of edge ``cshape`` whose origins are spaced ``step`` voxels
    apart along each of the first three axes of ``map_shape``. Only tiles that
    fit completely inside the volume are kept.

    :param map_shape: (tuple) shape of the volume; the first three entries are used
    :param step: (int, keyword) stride between tile origins, defaults to 1
    :param cshape: (int, keyword) edge length of each cubic tile, defaults to 1

    Any other keyword (callers in this package also pass ``margin``) is accepted
    and ignored.

    After construction:

    * ``slices`` is a list of ``(slice, slice, slice)`` tuples, one per tile
    * ``slice_indicies`` is the matching list of ``(i, j, k)`` tile origins

    :raises ValueError: if ``step`` or ``cshape`` is not positive, or if no tile
        fits inside ``map_shape``
    """

    def __init__(self, map_shape: tuple, **kwargs):
        step = kwargs.get("step", 1)
        cshape = kwargs.get("cshape", 1)
        # Callers forward ``dict.get(...)`` results, so an explicit None means
        # "not configured" and falls back to the default.
        step = 1 if step is None else step
        cshape = 1 if cshape is None else cshape

        if step <= 0 or cshape <= 0:
            msg = (
                "step and cshape must be positive integers, "
                f"got step={step} and cshape={cshape}."
            )
            raise ValueError(msg)

        # A tile starting at ``s`` fits iff ``s + cshape <= dim``; bounding the
        # range per axis avoids testing every (i, j, k) combination.
        x_starts = range(0, map_shape[0] - cshape + 1, step)
        y_starts = range(0, map_shape[1] - cshape + 1, step)
        z_starts = range(0, map_shape[2] - cshape + 1, step)

        # Last axis varies fastest, matching a nested i/j/k loop.
        slice_indicies = [
            (i, j, k) for i in x_starts for j in y_starts for k in z_starts
        ]

        if not slice_indicies:
            msg = "No slices were generated, please check the step and cshape values."
            raise ValueError(msg)

        self.slices = [
            tuple(slice(start, start + cshape) for start in origin)
            for origin in slice_indicies
        ]
        self.slice_indicies = slice_indicies
+ + """ + + def __init__(self): + super().__init__() + + def __call__( + self, + mapobj: MapObjHandle, + **kwargs, + ) -> tuple[MapObjHandle, dict]: + # This is needed to do the normalisation but I need to check if label obj is affected by this + + vox = kwargs.get("vox", 1.0) + vox_min = kwargs.get("vox_min", 0.95) + vox_max = kwargs.get("vox_max", 1.05) + + if not vox_min < vox < vox_max: + msg = f"Voxel size must be within the range of {vox_min} and {vox_max}." + raise ValueError(msg) + + voxx, voxy, voxz = mapobj.apix + sample = np.array(mapobj.shape) + if voxx > vox_max or voxx < vox_min: + sample[2] = int(mapobj.dim[0] / vox) + if voxy > vox_max or voxy < vox_min: + sample[1] = int(mapobj.dim[1] / vox) + if voxz > vox_max or voxz < vox_min: + sample[0] = int(mapobj.dim[2] / vox) + sample = tuple(sample) + interpolate_to_grid( + mapobj, + sample, + (vox, vox, vox), + mapobj.origin, + inplace=True, + prefilter_input=mapobj.all_transforms, + ) + + mapobj.update_header_by_data() + + return mapobj, kwargs + + +class MapObjectNormalisation(TransformBase): + """ + Normalise the voxel values of a Map Object. + + """ + + def __init__(self): + super().__init__() + + def __call__( + self, + mapobj: MapObjHandle, + **kwargs, + ) -> tuple[MapObjHandle, dict]: + if not mapobj.all_transforms: + return mapobj, kwargs + normalise_mapobj( + mapobj, + inplace=True, + ) + + return mapobj, kwargs + + +class MapObjectMaskCrop(TransformBase): + """ + Crop a Map Object using a mask. + """ + + def __init__(self): + super().__init__() + + def __call__( + self, + mapobj: MapObjHandle, + **kwargs, + ) -> tuple[MapObjHandle, dict]: + mask = kwargs.get("mask", None) + if mask is None: + msg = "Please provide a mask to crop the map object." + raise ValueError(msg) + + mask = mask_from_labelobj(mask) + + return mapobj, kwargs + + +class MapObjectPadding(TransformBase): + """ + Pad a Map Object. 
+ """ + + def __init__(self): + super().__init__() + + def __call__( + self, + mapobj: MapObjHandle, + **kwargs, + ) -> tuple[MapObjHandle, dict]: + ext_dim = [divx(d, kwargs.get("step", 1)) - d for d in mapobj.shape] + + left = kwargs.get("left", True) + pad_map_grid_split_distribution( + mapobj, + ext_dim=ext_dim, + fill_padding=0.0, + left=left, + inplace=True, + ) + return mapobj, kwargs + + +# def data_scale(mapobj: MapObjHandle, desired_shape: tuple, inplace=True): +# """ +# Resamples image to desired shape. + +# :param mapobj: (MapObjHandle) map object +# :param desired_shape: (tuple(int, int, int)) desired shape +# :param inplace: (bool) perform operation in place +# :return: mapobj: (MapObjHandle) updated map object +# """ +# interpolate_to_grid(mapobj, desired_shape, mapobj.apix, mapobj.origin, inplace=True) +# if not inplace: +# return mapobj + +# mapobj.update_header_by_data() diff --git a/src/caked/Transforms/utils.py b/src/caked/Transforms/utils.py new file mode 100644 index 0000000..78aae11 --- /dev/null +++ b/src/caked/Transforms/utils.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import math + +from ccpem_utils.map.parse_mrcmapobj import MapObjHandle + + +def mask_from_labelobj(label_mapobj: MapObjHandle): + """ + Create a mask from a label object, where the mask is a boolean array + where the values are 1 for the labels and 0 for the background. 
def none_return_none(func):
    """
    Decorator that short-circuits a call when its first positional argument is None.

    If the wrapped function is called with ``None`` as its first positional
    argument, ``None`` is returned without calling the function. In every other
    case (including calls with no positional arguments at all) the call is
    forwarded unchanged.

    :param func: the callable to wrap
    :return: the wrapped callable, carrying ``func``'s metadata via ``functools.wraps``
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Guard on ``args`` first: a keyword-only or argument-less call has no
        # ``args[0]`` and must not fail with IndexError.
        if args and args[0] is None:
            return None
        return func(*args, **kwargs)

    return wrapper
+ diff --git a/src/caked/dataloader.py b/src/caked/dataloader.py index 968ed1f..1392acb 100644 --- a/src/caked/dataloader.py +++ b/src/caked/dataloader.py @@ -14,11 +14,29 @@ import mrcfile import numpy as np import torch +from ccpem_utils.map.parse_mrcmapobj import MapObjHandle, get_mapobjhandle from scipy.ndimage import zoom -from torch.utils.data import DataLoader, Subset +from torch.utils.data import ConcatDataset, DataLoader, Subset from torchvision import transforms -from .base import AbstractDataLoader, AbstractDataset +from caked.base import AbstractDataLoader, AbstractDataset, DatasetConfig +from caked.hdf5 import HDF5DataStore, LRUCache +from caked.Transforms.augments import ComposeAugment +from caked.Transforms.transforms import ComposeTransform, DecomposeToSlices, Transforms +from caked.utils import ( + get_max_memory, + get_sorted_paths, + process_datasets, + find_background_slices_to_skip, +) + +try: + from ccpem_utils.other.utils import set_gpu +except ImportError: + + def set_gpu(): + pass + np.random.seed(42) TRANSFORM_OPTIONS = ["normalise", "gaussianblur", "shiftmin"] @@ -102,9 +120,7 @@ def load(self, datapath, datatype) -> None: else: class_check = np.isin(self.classes, ids) if not np.all(class_check): - msg = "Not all classes in the list are present in the directory. Missing classes: {}".format( - np.asarray(self.classes)[~class_check] - ) + msg = f"Not all classes in the list are present in the directory. Missing classes: {np.asarray(self.classes)[~class_check]}" raise RuntimeError(msg) class_check = np.isin(ids, self.classes) if not np.all(class_check): @@ -215,9 +231,227 @@ def get_loader( s = int(np.ceil(len(self.dataset) * int(split_size) / 100)) if s < 2: - msg = "Train and validation sets must be larger than 1 sample, train: {}, val: {}.".format( - len(idx[:-s]), len(idx[-s:]) - ) + msg = f"Train and validation sets must be larger than 1 sample, train: {len(idx[:-s])}, val: {len(idx[-s:])}." 
+ raise RuntimeError(msg) + train_data = Subset(self.dataset, indices=idx[:-s]) + val_data = Subset(self.dataset, indices=idx[-s:]) + + loader_train = DataLoader( + train_data, + batch_size=batch_size, + num_workers=0, + shuffle=True, + drop_last=True, + ) + loader_val = DataLoader( + val_data, + batch_size=batch_size, + num_workers=0, + shuffle=True, + drop_last=(not no_val_drop), + ) + return loader_train, loader_val + + return DataLoader( + self.dataset, + batch_size=batch_size, + num_workers=0, + shuffle=True, + ) + + +class MapDataLoader(AbstractDataLoader): + def __init__( + self, + dataset_size: int | None = None, + save_to_disk: bool = False, + training: bool = True, + classes: list[str] | None = None, + pipeline: str = "disk", + transformations: list[str] | None = None, + augmentations: list[str] | None = None, + decompose: bool = True, + ) -> None: + """ + DataLoader implementation for loading map data from disk and saving them to a internal HDF5 store. + + + """ + self.dataset_size = dataset_size + self.save_to_disk = save_to_disk + self.training = training + self.pipeline = pipeline + self.transformations = transformations + self.augmentations = augmentations + self.decompose = decompose + self.debug = False + self.classes = classes + + if self.classes is None: + self.classes = [] + if self.transformations is None: + self.transformations = [] + if self.augmentations is None: + self.augmentations = [] + + def load( + self, + datapath: str | Path, + datatype: str, + cache_size: int | None = None, + label_path: str | Path | None = None, + weight_path: str | Path | None = None, + use_gpu: bool = False, + num_workers: int = 1, + background_filter: float | None = None, + **kwargs, + ) -> None: + """ + Load the data from the specified path and data type. + + Args: + datapath (str | Path): The path to the directory containing the data. + datatype (str): The type of data to load. 
+ label_path (str | Path, optional): The path to the directory containing the labels. Defaults to None. + weight_path (str | Path, optional): The path to the directory containing the weights. Defaults to None. + multi_process (bool, optional): Whether to use multi-processing. Defaults to False. + use_gpu (bool, optional): Whether to use the GPU. Defaults to False. + kwargs: Additional keyword arguments used for MapDataSet + + Returns: + None + """ + datasets = [] + + if use_gpu and num_workers > 1: + msg = "Cannot use GPU and multi-process at the same time." + raise ValueError(msg) + if use_gpu: + set_gpu() + + datapath = Path(datapath) + label_path = Path(label_path) if label_path is not None else None + weight_path = Path(weight_path) if weight_path is not None else None + + cache_size = get_max_memory() if cache_size is None else cache_size + cache = LRUCache(cache_size) + + map_hdf5_store = HDF5DataStore( + datapath.joinpath("raw_map_data.h5"), + cache=cache, + ) + + label_hdf5_store = ( + HDF5DataStore(label_path.joinpath("label_data.h5"), cache=cache) + if label_path is not None + else None + ) + + paths = get_sorted_paths(datapath, datatype, self.dataset_size) + label_paths = get_sorted_paths(label_path, datatype, self.dataset_size) + weight_paths = get_sorted_paths(weight_path, datatype, self.dataset_size) + + if self.dataset_size is not None: + paths = paths[: self.dataset_size] + label_paths = ( + label_paths[: self.dataset_size] if label_paths is not None else None + ) + weight_paths = ( + weight_paths[: self.dataset_size] if weight_paths is not None else None + ) + + if label_paths is not None and len(label_paths) != len(paths): + msg = "Label paths and data paths do not match." + raise RuntimeError(msg) + + if weight_paths is not None and len(weight_paths) != len(paths): + msg = "Weight paths and data paths do not match." 
def get_hdf5_store(
    self,
) -> tuple[HDF5DataStore, HDF5DataStore | None]:
    """
    Return the HDF5 stores used by the loaded datasets.

    The stores are read from the first dataset of the concatenated dataset
    built by ``load``.

    Returns:
        tuple: ``(map_hdf5_store, label_hdf5_store)``; the label store is
        whatever the first dataset holds and may be None.

    Raises:
        RuntimeError: If no dataset has been loaded yet.
    """
    # ``dataset`` is only assigned by ``load``; use getattr so that calling this
    # before ``load`` raises the intended RuntimeError, not an AttributeError.
    dataset = getattr(self, "dataset", None)
    if dataset is None:
        msg = "The dataset has not been loaded yet."
        raise RuntimeError(msg)

    first_dataset = dataset.datasets[0]
    return (
        first_dataset.map_hdf5_store,
        first_dataset.label_hdf5_store,
    )
+ + Raises: + RuntimeError: If split_size is None and the method is called for training. + RuntimeError: If the train and validation sets are smaller than 2 samples. + + """ + if self.training and split: + if split_size is None: + msg = "Split size must be provided for training. " + raise RuntimeError(msg) + # split into train / val sets + idx = np.random.permutation(len(self.dataset)) + + if split_size < 1: + split_size = split_size * 100 + + s = int(np.ceil(len(self.dataset) * int(split_size) / 100)) + if s < 2: + msg = f"Train and validation sets must be larger than 1 sample, train: {len(idx[:-s])}, val: {len(idx[-s:])}." raise RuntimeError(msg) train_data = Subset(self.dataset, indices=idx[:-s]) val_data = Subset(self.dataset, indices=idx[-s:]) @@ -245,6 +479,41 @@ def get_loader( shuffle=True, ) + def filter_slices_under_background_limit( + self, + class_labels, + background_limit: float = 0.3, + ): + """ + Find the slices in the dataloader that contain only background and remove them. 
+ + :param class_label_handler: Class label handler + + :return: Empty tile + """ + to_skip = find_background_slices_to_skip( + self, + class_labels, + background_limit=background_limit, + ) + + for dataset in self.dataset.datasets: + if dataset.id in to_skip: + dataset.slice_indicies = [ + tile + for i, tile in enumerate(dataset.slice_indicies) + if i not in to_skip[dataset.id] + ] + + dataset.slices = [ + slice_ + for i, slice_ in enumerate(dataset.slices) + if i not in to_skip[dataset.id] + ] + dataset.slices_count = len(dataset.slice_indicies) + + self.dataset.cumulative_sizes = self.dataset.cumsum(self.dataset.datasets) + class DiskDataset(AbstractDataset): """ @@ -360,4 +629,509 @@ def transformation(self, x): return x def augment(self, augment): - raise NotImplementedError + raise NotImplementedError() + + +class MapDataset(AbstractDataset): + def __init__( + self, + path: str | Path, + **kwargs, + ) -> None: + """ + A dataset class for loading map data, alongside the corresponding class labels and weights. + The map data is loaded from the disk and is decomposed into a set of slice_indicies. These slice_indicies are + then returned when indexing the dataset. + + Args: + path (Union[str, Path]): The path to the map data. + label_path (Optional[Union[str, Path]]): The path to the label data. Defaults to None. + weight_path (Optional[Union[str, Path]]): The path to the weight data. Defaults to None. + map_hdf5_store (Optional[HDF5DataStore]): The HDF5 store for the map data. Defaults to None. + label_hdf5_store (Optional[HDF5DataStore]): The HDF5 store for the label data. Defaults to None. + transforms (Optional[List[str]]): The transformations to apply to the data. + augments (Optional[List[str]]): The augmentations to apply to the data. + decompose (bool): Whether to decompose the data into slices and slice_indicies. Defaults to True. + decompose_kwargs (Optional[Dict[str, int]]): The decomposition parameters. Defaults to None. 
+ transform_kwargs (Optional[Dict]): The transformation parameters. Defaults to None. + + + Attributes: + data_shape (Optional[Tuple]): The shape of the map data. Defaults to None. + mapobj (Optional[MapObjHandle]): The map object handle for the map data. Defaults to None. + label_mapobj (Optional[MapObjHandle]): The map object handle for the label data. Defaults to None. + weight_mapobj (Optional[MapObjHandle]): The map object handle for the weight data. Defaults to None. + slices (Optional[List[Tuple]]): The slices of the data. Defaults to None. + slice_indicies (Optional): The slice_indicies of the data. Defaults to None. + slices_count (int): The number of slice_indicies. Defaults to 0. + + """ + config = DatasetConfig() + + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + + self.path = Path(path) + self.id = self.path.stem + self.label_path = ( + Path(config.label_path) if config.label_path is not None else None + ) + self.weight_path = ( + Path(config.weight_path) if config.weight_path is not None else None + ) + + self.map_hdf5_store: HDF5DataStore = kwargs.get( + "map_hdf5_store", config.map_hdf5_store + ) + self.label_hdf5_store: HDF5DataStore | None = kwargs.get( + "label_hdf5_store", config.label_hdf5_store + ) + self.slices: list = kwargs.get("slices", []) + self.slice_indicies: list = kwargs.get("slice_indicies", []) + self.slices_count = kwargs.get("slices_count", config.slices_count) + self.transforms = kwargs.get("transforms", config.transforms) + if not config.train: + self.augments = None + self.augments = kwargs.get("augments", config.augments) + self.decompose_kwargs = kwargs.get("decompose_kwargs", config.decompose_kwargs) + self.transform_kwargs = kwargs.get("transform_kwargs", config.transform_kwargs) + self.decompose = kwargs.get("decompose", config.decompose) + self.data_shape: tuple | None = None + + self.mapobj: MapObjHandle | None = None + self.label_mapobj: MapObjHandle | None = None + 
self.weight_mapobj: MapObjHandle | None = None + + cshape = kwargs.get("cshape", 32) + margin = kwargs.get("margin", 8) + if self.decompose_kwargs is None: + self.decompose_kwargs = {"cshape": cshape, "margin": margin} + + if self.transform_kwargs is None: + self.transform_kwargs = {} + + self.augments = [] if self.augments is None else self.augments + + self.transforms = [] if self.transforms is None else self.transforms + + if not self.decompose_kwargs.get("step", False): + step = self.decompose_kwargs.get("cshape", 1) - ( + 2 * self.decompose_kwargs.get("margin") + ) + self.decompose_kwargs["step"] = step if step != 0 else 1 + + def __len__(self): + if self.slices_count == 0 and self.decompose: + self.generate_tile_indicies() + elif self.slices_count == 0: + self.slices_count = 1 + + return self.slices_count + + def __getitem__( + self, idx + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + if (not self.slices or not self.slice_indicies) and self.decompose: + self.generate_tile_indicies() + elif (not self.slices or not self.slice_indicies) and not self.decompose: + self.slices = [(slice(None), slice(None), slice(None))] + else: + self.slices = self.slices + self.slice_indicies = self.slice_indicies + + map_array = self.map_hdf5_store.get(f"{self.id}_map", to_torch=True) + + if map_array.ndim == 4: + x_slice, y_slice, z_slice = self.slices[idx] + map_slice = map_array[:, x_slice, y_slice, z_slice] + else: + map_slice = map_array[self.slices[idx]] + + label_slice = ( + self.label_hdf5_store.get(f"{self.id}_label", to_torch=True)[ + self.slices[idx] + ] + if self.label_hdf5_store is not None + else None + ) + + if not isinstance(map_slice, torch.Tensor): + map_tensor = torch.tensor(map_slice) + else: + map_tensor = map_slice + + if not isinstance(label_slice, torch.Tensor): + label_tensor = ( + torch.tensor(label_slice) if label_slice is not None else None + ) + else: + label_tensor = label_slice + + return tuple( + tensor for tensor in 
(map_tensor, label_tensor) if tensor is not None + ) + + def load_map_objects( + self, + ) -> None: + """ + Load the map objects from the specified paths. + """ + self.mapobj = get_mapobjhandle(self.path) + self.mapobj.all_transforms = True + if self.label_path is not None: + if not self.label_path.exists(): + msg = f"Label file {self.label_path} not found." + raise FileNotFoundError(msg) + self.label_mapobj = get_mapobjhandle(self.label_path) + self.label_mapobj.all_transforms = False + if self.weight_path is not None: + if not self.weight_path.exists(): + msg = f"Weight file {self.weight_path} not found." + raise FileNotFoundError(msg) + self.weight_mapobj = get_mapobjhandle(self.weight_path) + self.weight_mapobj.all_transforms = False + + def close_map_objects(self, *args): + """ + Close the map objects. + + Args: + *args: The map objects to close. + + + """ + for arg in args: + if arg is not None: + arg.close() + + def augment(self, close_map_objects) -> dict: + """ + Apply augmentations to the map data. + + Args: + close_map_objects (bool): Whether to close the map objects after transformation. + + Returns: + dict: The augmentation keywords + """ + augment_kwargs = self._augment_keywords_builder() + if len(self.augments) == 0: + return {} + + self.mapobj, extra_kwargs = ComposeAugment(self.augments)( + self.mapobj, **augment_kwargs + ) + augment_kwargs.update(extra_kwargs) + + self.label_mapobj = ComposeAugment(self.augments)( + self.label_mapobj, **augment_kwargs + ) + self.weight_mapobj = ComposeAugment(self.augments)( + self.weight_mapobj, **augment_kwargs + ) + + if close_map_objects: + self.close_map_objects(self.mapobj, self.label_mapobj, self.weight_mapobj) + + return augment_kwargs + + def transform(self, close_map_objects: bool = True): + """ + Perform the transformations on the map data. + + Note: The final map shape is calculated here, + + Args: + close_map_objects (bool, optional): Whether to close the map objects after transformation. 
Defaults to True. + + """ + if self.mapobj is None: + self.load_map_objects() + transform_kwargs = self._transform_keywords_builder() + if len(self.transforms) == 0: + self.transform_kwargs = transform_kwargs + + self.transform_kwargs = ComposeTransform(self.transforms)( + self.mapobj, self.label_mapobj, self.weight_mapobj, **transform_kwargs + ) + self.get_data_shape(close_map_objects=False) + + if close_map_objects: + self.close_map_objects(self.mapobj, self.label_mapobj, self.weight_mapobj) + + def get_data_shape(self, close_map_objects: bool = True): + """ + Get the shape of the map data, label data, and weight data. + + + Args: + close_map_objects (bool, optional): Whether to close the map objects after transformation. Defaults to True. + + """ + if self.data_shape is not None: + return + + if (self.mapobj is None) or (self.mapobj.data) is None: + self.load_map_objects() + if self.mapobj is not None and self.mapobj.data is not None: + # MyPy shenanigans + self.data_shape = self.mapobj.data.shape + if self.label_mapobj is not None: + assert ( + self.label_mapobj.data.shape == self.data_shape + ), f"Map and label shapes do not match for {self.id}." + if self.weight_mapobj is not None: + assert ( + self.weight_mapobj.data.shape == self.data_shape + ), f"Map and weight shapes do not match for {self.id}." + + if close_map_objects: + self.close_map_objects(self.mapobj, self.label_mapobj, self.weight_mapobj) + + def generate_tile_indicies(self): + """ + Generate the tile indices for the map data using the decomposition parameters. 
def _augment_keywords_builder(self):
    """
    Build the keyword arguments that are forwarded to the augmentations.

    Random rotation is the only augmentation that takes keywords here: the
    rotation axis (``ax``) and angle (``an``) stored on the dataset.

    Returns:
        dict: Keywords to pass to the composed augmentation.
    """
    keywords = {}
    for augment in self.augments:
        if augment.__class__.__name__ == "RandomRotationAugment":
            # Augmentation supplied as an object: original behaviour.
            keywords["ax"] = self.ax
            keywords["an"] = self.an
        elif augment == "randrot":
            # Augmentations are normally configured by name, in which case the
            # class-name test above can never match and the rotation
            # parameters were silently dropped. Forward only values that are
            # actually set, so an unset axis/angle still means "random".
            for name in ("ax", "an"):
                value = getattr(self, name, None)
                if value is not None:
                    keywords[name] = value

    return keywords
def __getitem__(self, idx):
    """
    Return the tile at position ``idx`` as a tuple of tensors.

    Tile slices are generated on first use when decomposition is enabled;
    otherwise the whole volume is treated as a single tile. The arrays are
    fetched through ``get_data`` when they are not in memory and released
    again with ``close_data`` before returning.

    Args:
        idx (int): Index into ``self.slices``.

    Returns:
        tuple: ``(map_tensor, label_tensor)``, or ``(map_tensor,)`` when the
            dataset has no label array.

    Raises:
        RuntimeError: If no data array is available after ``get_data``.
    """
    if not self.slices or not self.slice_indicies:
        if self.decompose:
            self.generate_tile_indicies()
        else:
            # Without decomposition the full volume is the only tile.
            self.slices = [(slice(None), slice(None), slice(None))]

    if self.data_array is None:
        self.get_data()
    if self.data_array is None:
        # Fail with a clear message instead of torch.tensor(None).
        msg = f"No data available for dataset {self.id}."
        raise RuntimeError(msg)

    tile = self.slices[idx]
    if self.data_array.ndim == 4:
        # The leading axis is kept whole and only the last three axes are
        # tiled (presumably the stacked map/weight channels - confirm).
        x_slice, y_slice, z_slice = tile
        map_slice = self.data_array[:, x_slice, y_slice, z_slice]
    else:
        map_slice = self.data_array[tile]

    label_slice = self.label_array[tile] if self.label_array is not None else None

    if isinstance(map_slice, torch.Tensor):
        map_tensor = map_slice
    else:
        map_tensor = torch.tensor(map_slice)

    if label_slice is None or isinstance(label_slice, torch.Tensor):
        label_tensor = label_slice
    else:
        label_tensor = torch.tensor(label_slice)

    self.close_data()

    return tuple(
        tensor for tensor in (map_tensor, label_tensor) if tensor is not None
    )
def get_data_shape(self, close_data: bool = True):
    """
    Record the shape of the data array and check label/weight arrays against it.

    Returns immediately when the shape is already known. The data is fetched
    with ``get_data`` when it is not in memory.

    Args:
        close_data (bool, optional): Whether to release the arrays with
            ``close_data`` afterwards. Defaults to True.
    """
    if self.data_shape is not None:
        return

    if self.data_array is None:
        self.get_data()

    if self.data_array is None:
        self.data_shape = None
    else:
        self.data_shape = self.data_array.shape

    companions = (
        (self.label_array, "Map and label shapes do not match."),
        (self.weight_array, "Map and weight shapes do not match."),
    )
    for array, message in companions:
        if array is None:
            continue
        assert array.shape == self.data_shape, message

    if close_data:
        self.close_data()
+ """ + if self.data_shape is None: + self.get_data_shape() + + decompose = DecomposeToSlices( + self.data_shape, + step=self.decompose_kwargs.get("step"), + cshape=self.decompose_kwargs.get("cshape"), + margin=self.decompose_kwargs.get("margin"), + ) + + self.slices = decompose.slices + self.slice_indicies = decompose.slice_indicies + self.slices_count = len(self.slice_indicies) diff --git a/src/caked/hdf5.py b/src/caked/hdf5.py new file mode 100644 index 0000000..f4c395e --- /dev/null +++ b/src/caked/hdf5.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import tempfile +from collections import OrderedDict +from pathlib import Path + +import h5py +import numpy as np +import torch + + +class HDF5DataStore: + def __init__( + self, + save_path: str | Path, + use_temp_dir: bool = True, + cache: LRUCache | None = None, + batch_size: int = 10, + cache_size: int = 5, + ): + """ + Object to store data in HDF5 format. If use_temp_dir is True, the file is saved + in a temporary directory and deleted when the object is deleted. This is useful + for temporary storage of data. If use_temp_dir is False, the file is + saved in the save_path provided. The file is not deleted when the object is deleted. 
def get(self, key: str, default=None, to_torch: bool = False):
    """
    Return the dataset stored under ``key``, using the cache when possible.

    :param key: (str) name of the dataset in the HDF5 file
    :param default: value returned when ``key`` is not in the file
    :param to_torch: (bool) return a torch tensor instead of a numpy array

    :return: the stored data as a numpy array or torch tensor, or ``default``
        when the key does not exist
    """
    try:
        if key in self.cache:
            cached = self.cache.get(key)
            # The cache is keyed by name only, so the entry may have been
            # stored in the other representation; convert it so the caller
            # always gets the type requested through ``to_torch``.
            if to_torch and isinstance(cached, np.ndarray):
                return torch.from_numpy(cached).clone()
            if not to_torch and isinstance(cached, torch.Tensor):
                return cached.detach().cpu().numpy()
            return cached

        with h5py.File(self.save_path, "r") as f:
            data = np.array(f[key])

        arr = torch.from_numpy(data).clone().detach() if to_torch else data
        self.cache.put(key, arr)
        return arr
    except KeyError:
        return default
def _add_number_to_dataset_name(self, dataset_name: str, delimiter: str = "--"):
    """
    Return a numbered variant of ``dataset_name`` that is not yet in the store.

    Numbered names have the form ``"<number><delimiter><dataset_name>"``. The
    highest number already used for this exact dataset name is incremented;
    the first duplicate gets number 1.

    :param dataset_name: (str) name that already exists in the store
    :param delimiter: (str) separator between the number and the name

    :return: (str) the numbered dataset name
    """
    suffix = f"{delimiter}{dataset_name}"
    used_numbers = []
    for name in self.keys():
        # Only names of the exact numbered form count. Other names that
        # merely contain ``dataset_name`` or have a non-numeric prefix are
        # ignored instead of breaking ``int()`` / ``max()``.
        if not name.endswith(suffix):
            continue
        prefix = name[: -len(suffix)]
        if prefix.isdecimal():
            used_numbers.append(int(prefix))

    last_number = max(used_numbers, default=0)
    return f"{last_number + 1}{delimiter}{dataset_name}"
def process_datasets(
    num_workers: int,
    paths: list[str],
    label_paths: list[str] | None,
    weight_paths: list[str] | None,
    transformations,
    augmentations,
    decompose: bool,
    raw_map_HDF5: HDF5DataStore,
    label_HDF5: HDF5DataStore | None = None,
    **kwargs,
):
    """
    Process multiple datasets in parallel and store the results in HDF5.

    Each map is processed by ``process_map_dataset`` in a worker thread. The
    resulting arrays are written to the HDF5 stores from the calling thread
    as the workers finish.

    Args:
        num_workers: Number of workers to use.
        paths: List of paths to the map files.
        label_paths: List of paths to the label files, or None for no labels.
        weight_paths: List of paths to the weight files, or None for no weights.
        transformations: Transformations to apply to each dataset.
        augmentations: Augmentations to apply to each dataset.
        decompose: Whether the datasets are decomposed into tiles.
        raw_map_HDF5: Instance of HDF5DataStore to store map data.
        label_HDF5: Instance of HDF5DataStore to store label data.
        **kwargs: Extra keyword arguments forwarded to ``process_map_dataset``.

    Returns:
        list: The processed datasets, in the order in which the workers
            finished (not necessarily the order of ``paths``).

    """
    # A missing label/weight list means "none for any map"; zip() would
    # otherwise fail on None.
    if label_paths is None:
        label_paths = [None] * len(paths)
    if weight_paths is None:
        weight_paths = [None] * len(paths)

    datasets = []

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [
            executor.submit(
                process_map_dataset,
                path,
                label_path,
                weight_path,
                transformations,
                augmentations,
                decompose,
                raw_map_HDF5,
                label_HDF5,
                **kwargs,
            )
            for path, label_path, weight_path in zip(paths, label_paths, weight_paths)
        ]

        for future in as_completed(futures):
            result, dataset = future.result()
            # Read the arrays by key rather than relying on dict ordering.
            add_dataset_to_HDF5(
                result["map_data"],
                result["label_data"],
                result["weight_data"],
                dataset.id,
                raw_map_HDF5,
                label_HDF5=label_HDF5,
            )
            datasets.append(dataset)  # Collect processed datasets

    return datasets
def add_dataset_to_HDF5(
    map_data: np.ndarray,
    label_data: np.ndarray | None,
    weight_data: np.ndarray | None,
    name: str,
    raw_map_HDF5: HDF5DataStore,
    label_HDF5: HDF5DataStore | None = None,
) -> tuple[str, str]:
    """
    Add map data (stacked with its weights) and label data to the HDF5 stores.

    The weights are taken from ``weight_data`` when given. Otherwise they are
    derived from the labels (1 where the label is non-zero, 0 elsewhere), or
    set to all ones when there are no labels either. When the weights have
    the same shape as the map, both are stacked along a new first axis, with
    the map at index 0 and the weights at index 1.

    Args:

        map_data: (np.ndarray) map data
        label_data: (np.ndarray | None) label data
        weight_data: (np.ndarray | None) weight data
        name: (str) name of the dataset
        raw_map_HDF5: (HDF5DataStore) instance of HDF5DataStore to store map data
        label_HDF5: (HDF5DataStore | None) instance of HDF5DataStore to store label data

    Returns:
        tuple[str, str]: map_id, label_id as returned by the stores. label_id
            is the default ``"<name>_label"`` when no label was stored.
    """
    map_id = f"{name}_map"
    label_id = f"{name}_label"

    map_tensor = torch.as_tensor(map_data, dtype=torch.float32)
    label_tensor = (
        torch.as_tensor(label_data, dtype=torch.float32)
        if label_data is not None
        else None
    )

    if weight_data is not None:
        # Weights supplied by the caller take precedence; previously they
        # were discarded and replaced by ones.
        weight_tensor = torch.as_tensor(weight_data, dtype=torch.float32)
    elif label_tensor is not None:
        weight_tensor = (label_tensor != 0).to(torch.float32)
    else:
        weight_tensor = torch.ones_like(map_tensor)

    if weight_tensor.shape == map_tensor.shape:
        # Add weight values to the first dimension of the map tensor
        map_tensor = torch.stack((map_tensor, weight_tensor), dim=0)
    # Weights of a different shape cannot be stacked and are not stored.

    map_id = raw_map_HDF5.add_array(map_tensor.numpy(), map_id)
    if label_HDF5 is not None and label_tensor is not None:
        label_id = label_HDF5.add_array(label_tensor.numpy(), label_id)

    return map_id, label_id
# NOTE(review): none_return_none presumably short-circuits to None when an
# argument is None - confirm against caked.Wrappers.
@none_return_none
def get_sorted_paths(
    path: Path,
    datatype: str,
    dataset_size: int | None = None,
):
    """
    Collect the files of a given type below ``path`` in a deterministic order.

    Files are ordered by the part of the file stem before the first
    underscore. Files sharing that prefix are further ordered by their full
    path, so that the result (and therefore the truncation to
    ``dataset_size``) does not depend on filesystem traversal order.

    Args:
        path (Path): Directory that is searched recursively.
        datatype (str): File extension to look for, without the leading dot.
        dataset_size (int | None): Maximum number of paths to return; None
            returns all of them.

    Returns:
        list[Path]: The sorted (and possibly truncated) paths.
    """
    paths = sorted(
        path.rglob(f"*.{datatype}"),
        key=lambda found: (found.stem.split("_")[0], str(found)),
    )
    return paths[:dataset_size] if dataset_size is not None else paths
def find_background_slices_to_skip(
    dataloader,
    class_labels,
    background_limit: float = 0.3,
    device: torch.device | None = None,
) -> dict:
    """
    Find the tiles whose labels are dominated by background (label 0).

    Every tile of every dataset in ``dataloader.dataset.datasets`` is read,
    and the fraction of voxels with label 0 is compared with
    ``background_limit``.

    Args:
        dataloader: Loader whose ``dataset.datasets`` are indexable datasets
            yielding ``(map, label)`` pairs and exposing an ``id``.
        class_labels: The class labels; its length sets the minimum number of
            label bins that are counted.
        background_limit (float): Tiles with a background fraction strictly
            above this value are reported. Defaults to 0.3.
        device (torch.device | None): Device used for counting. Defaults to
            the module-level ``DEVICE``.

    Returns:
        dict: Maps a dataset id to the list of tile indices to skip. Datasets
            without such tiles are not included.
    """
    if device is None:
        device = DEVICE

    to_skip = {}

    for dataset in dataloader.dataset.datasets:
        for index in range(len(dataset)):
            _, label_tensor = dataset[index]
            labels = label_tensor.to(device).flatten().type(torch.int64)

            total = labels.numel()
            if total == 0:
                continue

            counts = torch.bincount(labels, minlength=len(class_labels))
            background_fraction = (counts[0] / total).item()

            if background_fraction > background_limit:
                to_skip.setdefault(dataset.id, []).append(index)

    # The mapping was previously built and then discarded.
    return to_skip
test_data_single_mrc_temp_dir(): + with TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + test_data_single_mrc_dir = Path( + Path(__file__).parent.joinpath("testdata_mrc", "mrc") + ) + for file in test_data_single_mrc_dir.glob("*"): + shutil.copy(file, temp_dir_path) + yield temp_dir_path diff --git a/tests/test_disk_io.py b/tests/test_disk_io.py index f9a4a7e..b5729d9 100644 --- a/tests/test_disk_io.py +++ b/tests/test_disk_io.py @@ -5,13 +5,11 @@ import numpy as np import pytest import torch -from tests import testdata_mrc, testdata_npy from caked.dataloader import DiskDataLoader, DiskDataset ORIG_DIR = Path.cwd() -TEST_DATA_MRC = Path(testdata_mrc.__file__).parent -TEST_DATA_NPY = Path(testdata_npy.__file__).parent + TEST_CORRUPT = Path(__file__).parent / "corrupt.mrc" DISK_PIPELINE = "disk" DATASET_SIZE_ALL = None @@ -47,36 +45,36 @@ def test_class_instantiation(): assert test_loader.pipeline == DISK_PIPELINE -def test_dataset_instantiation_mrc(): +def test_dataset_instantiation_mrc(test_data_mrc_dir): """ Test case for instantiating a DiskDataset with MRC data. """ - test_dataset = DiskDataset(paths=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_dataset = DiskDataset(paths=test_data_mrc_dir, datatype=DATATYPE_MRC) assert isinstance(test_dataset, DiskDataset) -def test_dataset_instantiation_npy(): +def test_dataset_instantiation_npy(test_data_npy_dir): """ Test case for instantiating a DiskDataset with npy datatype. """ - test_dataset = DiskDataset(paths=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_dataset = DiskDataset(paths=test_data_npy_dir, datatype=DATATYPE_MRC) assert isinstance(test_dataset, DiskDataset) -def test_load_dataset_no_classes(): +def test_load_dataset_no_classes(test_data_mrc_dir): """ Test case for loading dataset without specifying classes. 
""" test_loader = DiskDataLoader( pipeline=DISK_PIPELINE, classes=DISK_CLASSES_NONE, dataset_size=DATASET_SIZE_ALL ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) assert isinstance(test_loader.dataset, DiskDataset) assert len(test_loader.classes) == len(DISK_CLASSES_FULL_MRC) assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL_MRC)) -def test_load_dataset_all_classes_mrc(): +def test_load_dataset_all_classes_mrc(test_data_mrc_dir): """ Test case for loading a dataset with all classes using DiskDataLoader. """ @@ -85,13 +83,13 @@ def test_load_dataset_all_classes_mrc(): classes=DISK_CLASSES_FULL_MRC, dataset_size=DATASET_SIZE_ALL, ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) assert isinstance(test_loader.dataset, DiskDataset) assert len(test_loader.classes) == len(DISK_CLASSES_FULL_MRC) assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL_MRC)) -def test_load_dataset_all_classes_npy(): +def test_load_dataset_all_classes_npy(test_data_npy_dir): """ Test case for loading a dataset with all classes using npy files. @@ -106,13 +104,13 @@ def test_load_dataset_all_classes_npy(): classes=DISK_CLASSES_FULL_NPY, dataset_size=DATASET_SIZE_ALL, ) - test_loader.load(datapath=TEST_DATA_NPY, datatype=DATATYPE_NPY) + test_loader.load(datapath=test_data_npy_dir, datatype=DATATYPE_NPY) assert isinstance(test_loader.dataset, DiskDataset) assert len(test_loader.classes) == len(DISK_CLASSES_FULL_NPY) assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_FULL_NPY)) -def test_load_dataset_some_classes(): +def test_load_dataset_some_classes(test_data_mrc_dir): """ Test case for loading a dataset with some specific classes using DiskDataLoader. 
""" @@ -121,13 +119,13 @@ def test_load_dataset_some_classes(): classes=DISK_CLASSES_SOME_MRC, dataset_size=DATASET_SIZE_ALL, ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) assert isinstance(test_loader.dataset, DiskDataset) assert len(test_loader.classes) == len(DISK_CLASSES_SOME_MRC) assert all(a == b for a, b in zip(test_loader.classes, DISK_CLASSES_SOME_MRC)) -def test_load_dataset_missing_class(): +def test_load_dataset_missing_class(test_data_mrc_dir): """ Test case for loading dataset with missing classes. """ @@ -137,10 +135,10 @@ def test_load_dataset_missing_class(): dataset_size=DATASET_SIZE_ALL, ) with pytest.raises(Exception, match=r".*Missing classes: .*"): - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) -def test_one_image(): +def test_one_image(test_data_mrc_dir): """ Test case for loading one image using DiskDataLoader. @@ -149,14 +147,14 @@ def test_one_image(): test_loader = DiskDataLoader( pipeline=DISK_PIPELINE, classes=DISK_CLASSES_NONE, dataset_size=DATASET_SIZE_ALL ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) test_dataset = test_loader.dataset test_item_image, test_item_name = test_dataset.__getitem__(1) assert test_item_name in DISK_CLASSES_FULL_MRC assert isinstance(test_item_image, torch.Tensor) -def test_get_loader_training_false(): +def test_get_loader_training_false(test_data_mrc_dir): """ Test case for the `get_loader` method of the `DiskDataLoader` class when `training` is set to False. 
""" @@ -166,12 +164,12 @@ def test_get_loader_training_false(): dataset_size=DATASET_SIZE_ALL, training=False, ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) torch_loader = test_loader.get_loader(batch_size=64) assert isinstance(torch_loader, torch.utils.data.DataLoader) -def test_get_loader_training_true(): +def test_get_loader_training_true(test_data_mrc_dir): """ Test case for the `get_loader` method of the `DiskDataLoader` class when training is set to True. """ @@ -181,7 +179,7 @@ def test_get_loader_training_true(): dataset_size=DATASET_SIZE_ALL, training=True, ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) torch_loader_train, torch_loader_val = test_loader.get_loader( split_size=0.8, batch_size=64 ) @@ -189,7 +187,7 @@ def test_get_loader_training_true(): assert isinstance(torch_loader_val, torch.utils.data.DataLoader) -def test_get_loader_training_fail(): +def test_get_loader_training_fail(test_data_mrc_dir): """ Test case for the `get_loader` method of the `DiskDataLoader` class when training fails. @@ -201,14 +199,14 @@ def test_get_loader_training_fail(): dataset_size=DATASET_SIZE_ALL, training=True, ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) with pytest.raises(Exception, match=r".* sets must be larger than .*"): torch_loader_train, torch_loader_val = test_loader.get_loader( split_size=1, batch_size=64 ) -def test_processing_data_all_transforms(): +def test_processing_data_all_transforms(test_data_mrc_dir): """ Test the processing of data with all transforms applied. 
@@ -225,7 +223,7 @@ def test_processing_data_all_transforms(): training=True, transformations=TRANSFORM_ALL, ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) assert test_loader.dataset.normalise assert test_loader.dataset.shiftmin assert test_loader.dataset.gaussianblur @@ -235,7 +233,7 @@ def test_processing_data_all_transforms(): assert label in DISK_CLASSES_FULL_MRC -def test_processing_data_some_transforms_npy(): +def test_processing_data_some_transforms_npy(test_data_npy_dir): """ Test case for processing data with some transformations using the DiskDataLoader class. @@ -256,8 +254,8 @@ def test_processing_data_some_transforms_npy(): dataset_size=DATASET_SIZE_ALL, training=True, ) - test_loader_none.load(datapath=TEST_DATA_NPY, datatype=DATATYPE_NPY) - test_loader_transf.load(datapath=TEST_DATA_NPY, datatype=DATATYPE_NPY) + test_loader_none.load(datapath=test_data_npy_dir, datatype=DATATYPE_NPY) + test_loader_transf.load(datapath=test_data_npy_dir, datatype=DATATYPE_NPY) assert test_loader_transf.dataset.normalise assert not test_loader_transf.dataset.shiftmin assert test_loader_transf.dataset.gaussianblur @@ -273,7 +271,7 @@ def test_processing_data_some_transforms_npy(): assert len(image_none[1]) == len(image_transf[1]) -def test_processing_data_rescale(): +def test_processing_data_rescale(test_data_mrc_dir): """ Test the processing of data with rescaling. 
@@ -288,7 +286,7 @@ def test_processing_data_rescale(): training=True, transformations=TRANSFORM_ALL_RESCALE, ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) assert test_loader.dataset.normalise assert test_loader.dataset.shiftmin assert test_loader.dataset.gaussianblur @@ -305,7 +303,7 @@ def test_processing_data_rescale(): training=True, transformations=TRANSFORM_RESCALE, ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) assert not test_loader.dataset.normalise assert not test_loader.dataset.shiftmin assert not test_loader.dataset.gaussianblur @@ -316,7 +314,7 @@ def test_processing_data_rescale(): assert label in DISK_CLASSES_FULL_MRC -def test_processing_after_load(): +def test_processing_after_load(test_data_mrc_dir): """ Test the processing steps after loading data using DiskDataLoader. """ @@ -327,14 +325,14 @@ def test_processing_after_load(): training=False, ) test_loader.debug = True - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) assert test_loader.transformations is None assert not test_loader.dataset.normalise assert not test_loader.dataset.shiftmin assert not test_loader.dataset.gaussianblur test_loader.transformations = TRANSFORM_ALL_RESCALE pre_dataset = test_loader.dataset - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) post_dataset = test_loader.dataset assert test_loader.dataset.normalise assert test_loader.dataset.shiftmin @@ -346,7 +344,7 @@ def test_processing_after_load(): assert not torch.equal(pre_image, post_image) -def test_drop_last(): +def test_drop_last(test_data_mrc_dir): """ Test the drop_last parameter in the get_loader method of the DiskDataLoader class. 
""" @@ -356,7 +354,7 @@ def test_drop_last(): dataset_size=DATASET_SIZE_ALL, training=True, ) - test_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_loader.load(datapath=test_data_mrc_dir, datatype=DATATYPE_MRC) loader_train_true, loader_val_true = test_loader.get_loader( split_size=0.7, batch_size=64, no_val_drop=True ) @@ -369,11 +367,11 @@ def test_drop_last(): assert loader_val_false.drop_last -def test_corrupt_mrcfile(): +def test_corrupt_mrcfile(test_data_mrc_dir): """ Test that corrupt mrcfiles are not loaded and throw an exception. """ - test_dataset = DiskDataset(paths=TEST_DATA_MRC, datatype=DATATYPE_MRC) + test_dataset = DiskDataset(paths=test_data_mrc_dir, datatype=DATATYPE_MRC) assert isinstance(test_dataset, DiskDataset) with pytest.raises(Exception, match=r".* corrupted."): test_dataset.read(TEST_CORRUPT) diff --git a/tests/test_map_io.py b/tests/test_map_io.py new file mode 100644 index 0000000..37230fa --- /dev/null +++ b/tests/test_map_io.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from pathlib import Path + +import torch + +from caked.dataloader import MapDataLoader, MapDataset +from caked.hdf5 import HDF5DataStore, LRUCache +from caked.utils import add_dataset_to_HDF5, duplicate_and_augment_from_hdf5 + +ORIG_DIR = Path.cwd() + + +DISK_CLASSES_NONE = None +DATATYPE_MRC = "mrc" +VOXNORM = "voxnorm" +NORM = "norm" +MASKCROP = "maskcrop" +PADDING = "padding" +ROTATION = "randrot" +TRANSFORM_ALL = [VOXNORM, NORM, PADDING] +AUGMENT_ALL = [ROTATION] + + +def test_map_dataloader(): + test_loader = MapDataLoader() + + assert test_loader is not None + assert isinstance(test_loader, MapDataLoader) + + +def test_map_dataset(test_data_single_mrc_dir): + test_map_dataset = MapDataset( + path=next(test_data_single_mrc_dir.glob(f"*{DATATYPE_MRC}")) + ) + assert test_map_dataset is not None + assert isinstance(test_map_dataset, MapDataset) + + +def test_slices(test_data_single_mrc_dir): + hdf5_store = HDF5DataStore("test.hdf5", 
cache_size=1) + + test_map_dataset = MapDataset( + path=next(test_data_single_mrc_dir.glob(f"*{DATATYPE_MRC}")), + transforms=[], + augments=[], + map_hdf5_store=hdf5_store, + ) + test_map_dataset.load_map_objects() + + add_dataset_to_HDF5( + test_map_dataset.mapobj.data, + None, + None, + "realmap", + hdf5_store, + ) + slice_ = test_map_dataset.__getitem__(0)[0] + + assert isinstance(slice_, torch.Tensor) + assert len(test_map_dataset) == 4 + assert slice_.shape == (2, 32, 32, 32) + + +def test_transforms(test_data_single_mrc_dir): + hdf5_store = HDF5DataStore("test.hdf5", cache_size=1) + test_map_dataset = MapDataset( + path=next(test_data_single_mrc_dir.glob(f"*{DATATYPE_MRC}")), + map_hdf5_store=hdf5_store, + transforms=TRANSFORM_ALL, + augments=[], + ) + test_map_dataset.load_map_objects() + test_map_dataset.transform() + add_dataset_to_HDF5( + test_map_dataset.mapobj.data, + None, + None, + "realmap", + hdf5_store, + ) + slice_ = test_map_dataset.__getitem__(0)[0] + + assert len(test_map_dataset) == 64 + assert slice_.shape == (2, 32, 32, 32) + + +def test_dataloader_load_to_HDF5_file(test_data_single_mrc_temp_dir): + test_map_dataloader = MapDataLoader() + test_map_dataloader.load( + datapath=test_data_single_mrc_temp_dir, + datatype=DATATYPE_MRC, + ) + + assert test_map_dataloader is not None + assert isinstance(test_map_dataloader, MapDataLoader) + assert test_map_dataloader.dataset is not None + assert test_map_dataloader.dataset.datasets[0].map_hdf5_store.save_path.exists() + + +def test_dataloader_load_and_decompose(test_data_single_mrc_temp_dir): + test_map_dataloader = MapDataLoader() + test_map_dataloader.load( + datapath=test_data_single_mrc_temp_dir, + datatype=DATATYPE_MRC, + cshape=16, + ) + + assert test_map_dataloader is not None + assert isinstance(test_map_dataloader, MapDataLoader) + test_map_dataset = test_map_dataloader.dataset.datasets[0] + slice_ = test_map_dataset.__getitem__(0)[0] + + assert slice_.shape == (2, 16, 16, 16) + + +def 
test_dataloader_load_to_HDF5_file_with_transforms(test_data_single_mrc_temp_dir): + test_map_dataloader = MapDataLoader( + transformations=TRANSFORM_ALL, + ) + test_map_dataloader.load( + datapath=test_data_single_mrc_temp_dir, + datatype=DATATYPE_MRC, + ) + + assert test_map_dataloader is not None + assert isinstance(test_map_dataloader, MapDataLoader) + assert test_map_dataloader.dataset is not None + assert test_map_dataloader.dataset.datasets[0].map_hdf5_store.save_path.exists() + + +def test_add_duplicate_dataset_to_dataloader(test_data_single_mrc_temp_dir): + test_map_dataloader = MapDataLoader( + transformations=TRANSFORM_ALL, + ) + test_map_dataloader.load( + datapath=test_data_single_mrc_temp_dir, + datatype=DATATYPE_MRC, + ) + + duplicate_and_augment_from_hdf5( + test_map_dataloader, + ids=[ + next(test_data_single_mrc_temp_dir.glob(f"*{DATATYPE_MRC}")).stem, + next(test_data_single_mrc_temp_dir.glob(f"*{DATATYPE_MRC}")).stem, + ], + ) + hdf5_store = test_map_dataloader.dataset.datasets[0].map_hdf5_store + + assert len(hdf5_store.keys()) == 3 + assert "realmap_map" in hdf5_store + assert "1--realmap_map" in hdf5_store + assert "2--realmap_map" in hdf5_store + + +def test_add_duplicate_dataset_to_dataloader_with_augments( + test_data_single_mrc_temp_dir, +): + test_map_dataloader = MapDataLoader( + transformations=TRANSFORM_ALL, + ) + test_map_dataloader.load( + datapath=test_data_single_mrc_temp_dir, + datatype=DATATYPE_MRC, + ) + duplicate_and_augment_from_hdf5( + ids=[next(test_data_single_mrc_temp_dir.glob(f"*{DATATYPE_MRC}")).stem], + map_data_loader=test_map_dataloader, + augmentations=AUGMENT_ALL, + ) + hdf5_store = test_map_dataloader.dataset.datasets[0].map_hdf5_store + assert len(hdf5_store.keys()) == 2 + assert "realmap_map" in hdf5_store + assert "1--realmap_map" in hdf5_store + + assert len(test_map_dataloader.dataset.datasets[0]) == 64 + assert len(test_map_dataloader.dataset.datasets[1]) == 64 + + assert len(test_map_dataloader.dataset) == 
128 + + +def test_dataloader_load_multi_process(test_data_single_mrc_temp_dir): + test_map_dataloader = MapDataLoader() + test_map_dataloader.load( + datapath=test_data_single_mrc_temp_dir, + datatype=DATATYPE_MRC, + num_workers=2, + ) + + assert test_map_dataloader is not None + assert isinstance(test_map_dataloader, MapDataLoader) + assert test_map_dataloader.dataset is not None + assert test_map_dataloader.dataset.datasets[0].map_hdf5_store.save_path.exists() + + # test_map_dataloader. + + +def test_lru_cache(test_data_single_mrc_dir): + cache = LRUCache(1) + hdf5_store = HDF5DataStore("test.hdf5", cache=cache) + + test_map_dataset = MapDataset( + path=next(test_data_single_mrc_dir.glob(f"*{DATATYPE_MRC}")), + transforms=[], + augments=[], + map_hdf5_store=hdf5_store, + ) + test_map_dataset.load_map_objects() + + add_dataset_to_HDF5( + test_map_dataset.mapobj.data, + None, + None, + "realmap", + hdf5_store, + ) + + assert "realmap_map" not in hdf5_store.cache + + _ = test_map_dataset.__getitem__(0)[0] + + assert "realmap_map" in hdf5_store.cache diff --git a/tests/testdata_mrc/mrc/realmap.mrc b/tests/testdata_mrc/mrc/realmap.mrc new file mode 100644 index 0000000..f1f9266 Binary files /dev/null and b/tests/testdata_mrc/mrc/realmap.mrc differ