diff --git a/hi-ml-cpath/.vscode/launch.json b/hi-ml-cpath/.vscode/launch.json index 8946691de..3cee000e3 100644 --- a/hi-ml-cpath/.vscode/launch.json +++ b/hi-ml-cpath/.vscode/launch.json @@ -41,16 +41,18 @@ "name": "Python: Run SlidesPandaImageNetMIL locally", "type": "python", "request": "launch", + "justMyCode": false, "program": "${workspaceFolder}/../hi-ml/src/health_ml/runner.py", "args": [ "--model=health_cpath.SlidesPandaImageNetMIL", "--pl_fast_dev_run=10", "--crossval_count=0", "--batch_size=2", - "--max_bag_size=4", - "--max_bag_size_inf=4", + "--max_bag_size=5", + "--max_bag_size_inf=5", "--num_top_slides=2", - "--num_top_tiles=2" + "--num_top_tiles=2", + "--max_num_gpus=1", ], "console": "integratedTerminal", }, diff --git a/hi-ml-cpath/primary_deps.yml b/hi-ml-cpath/primary_deps.yml index e1e84889e..8eb3193fa 100644 --- a/hi-ml-cpath/primary_deps.yml +++ b/hi-ml-cpath/primary_deps.yml @@ -33,6 +33,47 @@ dependencies: - ruamel.yaml==0.16.12 - tensorboard==2.6.0 # Histopathology requirements + - coloredlogs==15.0.1 + - cucim==22.04.00 + - flake8==4.0.1 + - girder-client==3.1.14 + - joblib==0.16.0 + - jupyter==1.0.0 + - jupyter-client==7.3.4 + - lightning-bolts==0.4.0 + - mlflow==1.17.0 + # commit of dev branch containing transform with coordinates + # - git+https://github.com/Project-MONAI/MONAI.git@df4a7d72e1d231b898f88d92cf981721c49ceaeb + # commit of dev branch including latest fixed to GridPatch 22/06 + # - git+https://github.com/Project-MONAI/MONAI.git@669bddf581201f994d1bcc0cb780854901605d9b + # commit of dev branch that includes latest fix to GridPatch 06/07 + - git+https://github.com/Project-MONAI/MONAI.git@4ddd2bc3870a86fb0a300c20e680de48886bbfc1 + - more-itertools==8.10.0 + - mypy==0.961 + - mypy-extensions==0.4.3 + - numba==0.51.2 + - numpy==1.19.1 + - pillow==9.0.0 + - psutil==5.7.2 + - pydicom==2.0.0 + - pyflakes==2.4.0 + - PyJWT==1.7.1 + - rich==12.4.4 + - runstats==1.8.0 + - scikit-image==0.17.2 + - scipy==1.5.2 + - seaborn==0.10.1 + - 
simpleitk==1.2.4 + - six==1.15.0 + - stopit==1.1.2 + - tabulate==0.8.7 + - tifffile==2021.11.2 + - torch==1.10.0 + - torchmetrics==0.6.0 + - torchvision==0.11.1 + - types-python-dateutil==2.8.9 + - umap-learn==0.5.2 + - yacs==0.1.8 - -r requirements_run.txt # Test requirements - -r requirements_test.txt diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py index f624c47e5..3a83dd0d6 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py @@ -151,10 +151,8 @@ def get_data_module(self) -> PandaSlidesDataModule: max_bag_size=self.max_bag_size, max_bag_size_inf=self.max_bag_size_inf, tile_size=self.tile_size, - step=self.step, random_offset=self.random_offset, seed=self.get_effective_random_seed(), - pad_full=self.pad_full, background_val=self.background_val, filter_mode=self.filter_mode, transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN), @@ -166,6 +164,11 @@ def get_data_module(self) -> PandaSlidesDataModule: def get_slides_dataset(self) -> PandaDataset: return PandaDataset(root=self.local_datasets[0]) # type: ignore + def get_test_plot_options(self) -> Set[PlotOption]: + plot_options = super().get_test_plot_options() + plot_options.add(PlotOption.SLIDE_THUMBNAIL_HEATMAP) + return plot_options + class SlidesPandaImageNetMIL(DeepSMILESlidesPanda): def __init__(self, **kwargs: Any) -> None: diff --git a/hi-ml-cpath/src/health_cpath/datamodules/base_module.py b/hi-ml-cpath/src/health_cpath/datamodules/base_module.py index 4eff4b79f..9fb397709 100644 --- a/hi-ml-cpath/src/health_cpath/datamodules/base_module.py +++ b/hi-ml-cpath/src/health_cpath/datamodules/base_module.py @@ -15,14 +15,14 @@ from health_ml.utils.bag_utils import BagDataset, multibag_collate from health_ml.utils.common_utils import _create_generator -from health_cpath.utils.wsi_utils 
import image_collate +from health_cpath.utils.wsi_utils import array_collate from health_cpath.models.transforms import LoadTilesBatchd from health_cpath.datasets.base_dataset import SlidesDataset, TilesDataset from health_cpath.utils.naming import ModelKey from monai.transforms.compose import Compose from monai.transforms.io.dictionary import LoadImaged -from monai.apps.pathology.transforms import TileOnGridd +from monai.transforms import RandGridPatchd, GridPatchd from monai.data.image_reader import WSIReader _SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset) @@ -245,11 +245,12 @@ def __init__( self, level: Optional[int] = 1, tile_size: Optional[int] = 224, - step: Optional[int] = None, random_offset: Optional[bool] = True, - pad_full: Optional[bool] = False, background_val: Optional[int] = 255, - filter_mode: Optional[str] = "min", + filter_mode: Optional[str] = "max", + overlap: Optional[float] = 0, + intensity_threshold: Optional[float] = 0, + pad_mode: Optional[str] = "constant", **kwargs: Any, ) -> None: """ @@ -257,60 +258,72 @@ def __init__( this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend :param tile_size: size of the square tile, defaults to 224 this param is passed to TileOnGridd monai transform for tiling on the fly. - :param step: step size to create overlapping tiles, defaults to None (same as tile_size) - Use a step < tile_size to create overlapping tiles, analogousely a step > tile_size will skip some chunks in - the wsi. This param is passed to TileOnGridd monai transform for tiling on the fly. :param random_offset: randomize position of the grid, instead of starting from the top-left corner, defaults to True. This param is passed to TileOnGridd monai transform for tiling on the fly. - :param pad_full: pad image to the size evenly divisible by tile_size, defaults to False - This param is passed to TileOnGridd monai transform for tiling on the fly. 
:param background_val: the background constant to ignore background tiles (e.g. 255 for white background), defaults to 255. This param is passed to TileOnGridd monai transform for tiling on the fly. - :param filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is greater than - tile_count, then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for - random) subset, defaults to "min" (which assumes background is high value). This param is passed to TileOnGridd - monai transform for tiling on the fly. + :param filter_mode: when `num_patches` is provided, it determines whether to keep patches with highest values + (`"max"`), lowest values (`"min"`), or in their default order (`None`). Defaults to `"max"`. + :param overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). + If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. + :param intensity_threshold: a value to keep only the patches whose sum of intensities are less than the + threshold. Defaults to no filtering. + :param pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. 
""" super().__init__(**kwargs) self.level = level self.tile_size = tile_size - self.step = step self.random_offset = random_offset - self.pad_full = pad_full self.background_val = background_val self.filter_mode = filter_mode - # TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and + # Tiling transform expects None to select all foreground tile so we hardcode max_bag_size and # max_bag_size_inf to None if set to 0 self.max_bag_size = None if self.max_bag_size == 0 else self.max_bag_size # type: ignore self.max_bag_size_inf = None if self.max_bag_size_inf == 0 else self.max_bag_size_inf # type: ignore + self.overlap = overlap + self.intensity_threshold = intensity_threshold + self.pad_mode = pad_mode def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset: - base_transform = Compose( - [ - LoadImaged( - keys=slides_dataset.IMAGE_COLUMN, - reader=WSIReader, - backend="cuCIM", - dtype=np.uint8, - level=self.level, - image_only=True, - ), - TileOnGridd( - keys=slides_dataset.IMAGE_COLUMN, - tile_count=self.max_bag_size if stage == ModelKey.TRAIN else self.max_bag_size_inf, - tile_size=self.tile_size, - step=self.step, - random_offset=self.random_offset if stage == ModelKey.TRAIN else False, - pad_full=self.pad_full, - background_val=self.background_val, - filter_mode=self.filter_mode, - return_list_of_dicts=True, - ), - ] + load_image_transform = LoadImaged( + keys=slides_dataset.IMAGE_COLUMN, + reader=WSIReader, # type: ignore + backend="cuCIM", + dtype=np.uint8, + level=self.level, + image_only=True, ) - if self.transforms_dict and self.transforms_dict[stage]: + max_offset = None if (self.random_offset and stage == ModelKey.TRAIN) else 0 + + if stage != ModelKey.TRAIN: + grid_transform = RandGridPatchd( + keys=[slides_dataset.IMAGE_COLUMN], + patch_size=[self.tile_size, self.tile_size], # type: ignore + num_patches=self.max_bag_size, + sort_fn=self.filter_mode, + pad_mode=self.pad_mode, # type: 
ignore + constant_values=self.background_val, + overlap=self.overlap, # type: ignore + threshold=self.intensity_threshold, + max_offset=max_offset, + ) + else: + grid_transform = GridPatchd( + keys=[slides_dataset.IMAGE_COLUMN], + patch_size=[self.tile_size, self.tile_size], # type: ignore + num_patches=self.max_bag_size_inf, + sort_fn=self.filter_mode, + pad_mode=self.pad_mode, # type: ignore + constant_values=self.background_val, + overlap=self.overlap, # type: ignore + threshold=self.intensity_threshold, + offset=max_offset, + ) + + base_transform = Compose([load_image_transform, grid_transform]) - transforms = Compose([base_transform, self.transforms_dict[stage]]).flatten() + if self.transforms_dict and self.transforms_dict[stage]: + transforms = Compose([base_transform, self.transforms_dict[stage]]).flatten() # type: ignore else: transforms = base_transform # The tiling transform is randomized. Make them deterministic. This call needs to be @@ -325,7 +338,7 @@ def _get_dataloader(self, dataset: SlidesDataset, stage: ModelKey, shuffle: bool return DataLoader( transformed_slides_dataset, batch_size=self.batch_size, - collate_fn=image_collate, + collate_fn=array_collate, shuffle=shuffle, generator=generator, **dataloader_kwargs, diff --git a/hi-ml-cpath/src/health_cpath/models/deepmil.py b/hi-ml-cpath/src/health_cpath/models/deepmil.py index a4138660c..de4c847c1 100644 --- a/hi-ml-cpath/src/health_cpath/models/deepmil.py +++ b/hi-ml-cpath/src/health_cpath/models/deepmil.py @@ -3,7 +3,7 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------ import torch -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from pytorch_lightning.utilities.warnings import rank_zero_warn from pathlib import Path @@ -353,20 +353,82 @@ def get_bag_label(labels: Tensor) -> Tensor: # SlidesDataModule attributes a single label to a bag of tiles already no need to do majority voting return labels + @staticmethod + def get_empty_lists(shape: int, n: int) -> List: + ll = [] + for _ in range(n): + ll.append([None] * shape) + return ll + + @staticmethod + def get_patch_coordinate(slide_offset: List, patch_location: List[int], patch_size: List[int] + ) -> Tuple[int, int, int, int]: + """ computing absolute patch coordinate """ + # PATCH_LOCATION is expected to have shape [y, x] + top = slide_offset[0] + patch_location[0] + bottom = slide_offset[0] + patch_location[0] + patch_size[0] + left = slide_offset[1] + patch_location[1] + right = slide_offset[1] + patch_location[1] + patch_size[1] + return top, bottom, left, right + + @staticmethod + def expand_slide_constant_metadata(id: str, path: str, n_patches: int, top: List[int], + bottom: List[int], left: List[int], right: List[int]) -> Tuple[List, List, List]: + """Duplicate metadata that is patch invariant to match the shape of other arrays""" + slide_id = [id] * n_patches + image_paths = [path] * n_patches + tile_id = [f"{id}_left_{left[i]}_top_{top[i]}_right_{right[i]}_bottom_{bottom[i]}" for i in range(n_patches)] + return slide_id, image_paths, tile_id + + def get_slide_patch_coordinates(self, slide_offset: List, patches_location: List, patch_size: List + ) -> Tuple[List, List, List, List]: + """ computing absolute coordinates for all patches in a slide""" + top, bottom, left, right = self.get_empty_lists(len(patches_location), 4) + for i, location in enumerate(patches_location): + top[i], 
bottom[i], left[i], right[i] = self.get_patch_coordinate(slide_offset, location, patch_size) + return top, bottom, left, right + + def compute_slide_metadata(self, batch: Dict, index: int, metadata_dict: Dict) -> Dict: + """compute patch-dependent and patch-invariant metadata for a single slide """ + offset = batch[SlideKey.OFFSET.value][index] + patches_location = batch[SlideKey.TILE_LOCATION.value][index] + patch_size = batch[SlideKey.TILE_SIZE.value][index] + n_patches = len(patches_location) + id = batch[SlideKey.SLIDE_ID][index] + path = batch[SlideKey.IMAGE_PATH][index] + + top, bottom, left, right = self.get_slide_patch_coordinates(offset, patches_location, patch_size) + slide_id, image_paths, tile_id = self.expand_slide_constant_metadata( + id, path, n_patches, top, bottom, left, right + ) + + metadata_dict[ResultsKey.TILE_TOP] = top + metadata_dict[ResultsKey.TILE_BOTTOM] = bottom + metadata_dict[ResultsKey.TILE_LEFT] = left + metadata_dict[ResultsKey.TILE_RIGHT] = right + metadata_dict[ResultsKey.SLIDE_ID] = slide_id + metadata_dict[ResultsKey.TILE_ID] = tile_id + metadata_dict[ResultsKey.IMAGE_PATH] = image_paths + return metadata_dict + def update_results_with_data_specific_info(self, batch: Dict, results: Dict) -> None: - # WARNING: This is a dummy input until we figure out tiles coordinates retrieval in the next iteration. 
- bag_sizes = [tiles.shape[0] for tiles in batch[SlideKey.IMAGE]] - results.update( - { - ResultsKey.SLIDE_ID: [ - [slide_id] * bag_sizes[i] for i, slide_id in enumerate(batch[SlideKey.SLIDE_ID]) - ], - ResultsKey.TILE_ID: [ - [f"{slide_id}_{tile_id}" for tile_id in range(bag_sizes[i])] - for i, slide_id in enumerate(batch[SlideKey.SLIDE_ID]) - ], - ResultsKey.IMAGE_PATH: [ - [img_path] * bag_sizes[i] for i, img_path in enumerate(batch[SlideKey.IMAGE_PATH]) - ], + if all(key.value in batch.keys() for key in [SlideKey.OFFSET, SlideKey.TILE_LOCATION, SlideKey.TILE_SIZE]): + n_slides = len(batch[SlideKey.SLIDE_ID]) + metadata_dict: Dict[str, List[Union[int, str]]] = { + ResultsKey.TILE_TOP: [], + ResultsKey.TILE_BOTTOM: [], + ResultsKey.TILE_LEFT: [], + ResultsKey.TILE_RIGHT: [], + ResultsKey.SLIDE_ID: [], + ResultsKey.TILE_ID: [], + ResultsKey.IMAGE_PATH: [], } - ) + results.update(metadata_dict) + # each slide can have a different number of patches + for i in range(n_slides): + updated_metadata_dict = self.compute_slide_metadata(batch, i, metadata_dict) + for key in metadata_dict.keys(): + results[key].append(updated_metadata_dict[key]) + else: + rank_zero_warn(message="Offset, patch location or patch size are not found in the batch" + "make sure to use RandGridPatch.") diff --git a/hi-ml-cpath/src/health_cpath/utils/naming.py b/hi-ml-cpath/src/health_cpath/utils/naming.py index d6bae1e03..ac511370e 100644 --- a/hi-ml-cpath/src/health_cpath/utils/naming.py +++ b/hi-ml-cpath/src/health_cpath/utils/naming.py @@ -4,6 +4,7 @@ # ------------------------------------------------------------------------------------------ from enum import Enum +from monai.utils import WSIPatchKeys class SlideKey(str, Enum): @@ -19,6 +20,10 @@ class SlideKey(str, Enum): FOREGROUND_THRESHOLD = 'foreground_threshold' METADATA = 'metadata' LOCATION = 'location' + TILE_SIZE = WSIPatchKeys.SIZE.value # 'patch_size' + TILE_LOCATION = WSIPatchKeys.LOCATION.value # 'patch_location' + OFFSET = 
'offset' + SHAPE = 'original_spatial_shape' class TileKey(str, Enum): diff --git a/hi-ml-cpath/src/health_cpath/utils/output_utils.py b/hi-ml-cpath/src/health_cpath/utils/output_utils.py index af210508b..0c387c70e 100644 --- a/hi-ml-cpath/src/health_cpath/utils/output_utils.py +++ b/hi-ml-cpath/src/health_cpath/utils/output_utils.py @@ -72,6 +72,8 @@ def normalize_dict_for_df(dict_old: Dict[ResultsKey, Any]) -> Dict[str, Any]: value = value.squeeze(0).cpu().numpy() if value.ndim == 0: value = np.full(bag_size, fill_value=value) + if isinstance(value, List) and isinstance(value[0], torch.Tensor): + value = [value[i].item() for i in range(len(value))] dict_new[key] = value elif key == ResultsKey.CLASS_PROBS: if isinstance(value, torch.Tensor): @@ -134,11 +136,17 @@ def save_outputs_csv(results: ResultsType, outputs_dir: Path) -> None: # Collect the list of dictionaries in a list of pandas dataframe and save df_list = [] + skipped_slides = 0 for slide_dict in list_slide_dicts: slide_dict = normalize_dict_for_df(slide_dict) # type: ignore - df_list.append(pd.DataFrame.from_dict(slide_dict)) + try: + df_list.append(pd.DataFrame.from_dict(slide_dict)) + except ValueError: + skipped_slides += 1 + logging.warning(f"something wrong in the dimension of slide {slide_dict[ResultsKey.SLIDE_ID][0]}") df = pd.concat(df_list, ignore_index=True) df.to_csv(csv_filename, mode='w+', header=True) + logging.warning(f"{skipped_slides} slides have not been included in the ouputs because of issues with the outputs") def save_features(results: ResultsType, outputs_dir: Path) -> None: diff --git a/hi-ml-cpath/src/health_cpath/utils/wsi_utils.py b/hi-ml-cpath/src/health_cpath/utils/wsi_utils.py index aea7bf4f8..823acad1a 100644 --- a/hi-ml-cpath/src/health_cpath/utils/wsi_utils.py +++ b/hi-ml-cpath/src/health_cpath/utils/wsi_utils.py @@ -1,22 +1,58 @@ import torch import numpy as np +import logging from typing import Any, List from health_cpath.utils.naming import SlideKey from 
health_ml.utils.bag_utils import multibag_collate +from monai.utils import WSIPatchKeys +from monai.data import MetaTensor +slide_metadata_keys = [ + SlideKey.IMAGE_PATH, + SlideKey.LABEL, + SlideKey.MASK, + SlideKey.METADATA, + SlideKey.SLIDE_ID, + SlideKey.MASK_PATH, + WSIPatchKeys.COUNT, + SlideKey.TILE_SIZE, # TODO: remove in case we want to allow patches of different sizes from the same slide + SlideKey.SHAPE, + SlideKey.OFFSET +] -def image_collate(batch: List) -> Any: + +def array_collate(batch: List) -> Any: """ - Combine instances from a list of dicts into a single dict, by stacking them along first dim + Combine instances from a list of dicts into a single dict, by stacking arrays along first dim [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW} - followed by the default collate which will form a batch BxNx3xHxW. - The list of dicts refers to the the list of tiles produced by the TileOnGridd transform applied on a WSI. + followed by the default collate which will form a batch BxNx3xHxW. It also converts some values to tensors. + The list of dicts refers to the list of tiles produced by GridPatch transform applied on a WSI. 
""" + collate_keys = [] + constant_keys = slide_metadata_keys + for key in batch[0][0].keys(): + if key not in slide_metadata_keys: + if isinstance(batch[0][0][key], np.ndarray) or isinstance(batch[0][0][key], MetaTensor): + collate_keys.append(key) + else: + logging.warning("Only np.ndarray and MetaTensors are collated -" + f"{key} value will be taken from first patch") + constant_keys.append(key) + tensor_keys = collate_keys + [SlideKey.LABEL] - for i, item in enumerate(batch): - data = item[0] - data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item])) - data[SlideKey.LABEL] = torch.tensor(data[SlideKey.LABEL]) - batch[i] = data + new_batch: List[dict] = [] + for patch_data in batch: + # we assume all patches are dictionaries with the same keys + data = patch_data[0] + for key in collate_keys: + if isinstance(data[key], np.ndarray): + data[key] = np.array([ix[key] for ix in patch_data]) + elif isinstance(data[key], MetaTensor): + # TODO change how this collation happens if we have list of tensors + data[key] = np.array([ix[key].as_tensor().numpy() for ix in patch_data]) + for key in tensor_keys: + data[key] = torch.tensor(data[key]) + new_batch.append(data) + batch = new_batch return multibag_collate(batch) diff --git a/hi-ml-cpath/testhisto/testhisto/utils/test_wsi_utils.py b/hi-ml-cpath/testhisto/testhisto/utils/test_wsi_utils.py index 8f20a6f4e..81db4639b 100644 --- a/hi-ml-cpath/testhisto/testhisto/utils/test_wsi_utils.py +++ b/hi-ml-cpath/testhisto/testhisto/utils/test_wsi_utils.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List from typing import Sequence from health_cpath.utils.naming import SlideKey -from health_cpath.utils.wsi_utils import image_collate +from health_cpath.utils.wsi_utils import array_collate from torch.utils.data import Dataset @@ -39,7 +39,7 @@ def __getitem__(self, index: int) -> List[Dict[SlideKey, Any]]: @pytest.mark.parametrize("random_n_tiles", [False, True]) -def test_image_collate(random_n_tiles: 
bool) -> None: +def test_array_collate(random_n_tiles: bool) -> None: # random_n_tiles accounts for both train and inference settings where the number of tiles is fixed (during # training) and None during inference (validation and test) dataset = MockTiledWSIDataset(n_tiles=20, @@ -51,7 +51,7 @@ def test_image_collate(random_n_tiles: bool) -> None: batch_size = 5 samples_list = [dataset[idx] for idx in range(batch_size)] - batch: dict = image_collate(samples_list) + batch: dict = array_collate(samples_list) assert isinstance(batch, Dict) assert batch.keys() == samples_list[0].keys() # type: ignore