diff --git a/brainles_preprocessing/brain_extraction/brain_extractor.py b/brainles_preprocessing/brain_extraction/brain_extractor.py index 03a716b..fb662d6 100644 --- a/brainles_preprocessing/brain_extraction/brain_extractor.py +++ b/brainles_preprocessing/brain_extraction/brain_extractor.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Optional, Union from enum import Enum +import numpy as np from auxiliary.io import read_image, write_image from brainles_hd_bet import run_hd_bet @@ -14,7 +15,23 @@ class Mode(Enum): ACCURATE = "accurate" -class BrainExtractor: +class BrainExtractor(ABC): + def __init__( + self, + masking_value: Optional[Union[int, float]] = None, + ): + """ + Base class for skull stripping medical images using brain masks. + + Subclasses should implement the `extract` method to generate a skull stripped image + based on the provided input image and mask. + """ + # Just as in the defacer, masking value is a global value defined across all images and modalities + # If no value is passed, the minimum of a given input image is chosen + # TODO: Consider extending this to modality-specific masking values in the future, this should + # probably be implemented as a property of the specific modality + self.masking_value = masking_value + @abstractmethod def extract( self, @@ -63,8 +80,17 @@ def apply_mask( if input_data.shape != mask_data.shape: raise ValueError("Input image and mask must have the same dimensions.") - # Mask and save it - masked_data = input_data * mask_data + # check whether a global masking value was passed, otherwise choose minimum + if self.masking_value is None: + current_masking_value = np.min(input_data) + else: + current_masking_value = ( + np.array(self.masking_value).astype(input_data.dtype).item() + ) + # Apply mask (element-wise either input or masking value) + masked_data = np.where( + mask_data.astype(bool), input_data, current_masking_value + ) try: write_image( @@ -78,6 +104,15 @@ def apply_mask( class HDBetExtractor(BrainExtractor): + def __init__(self, masking_value: Optional[Union[int, float]] = None): + """ + Brain extraction HDBet implementation. + + Args: + masking_value (Optional[Union[int, float]], optional): global value to be inserted in the masked areas. Default is None which leads to the minimum of each respective image. + """ + super().__init__(masking_value=masking_value) + def extract( self, input_image_path: Union[str, Path], diff --git a/brainles_preprocessing/brain_extraction/synthstrip.py b/brainles_preprocessing/brain_extraction/synthstrip.py index 0ea3135..eed97d6 100644 --- a/brainles_preprocessing/brain_extraction/synthstrip.py +++ b/brainles_preprocessing/brain_extraction/synthstrip.py @@ -21,7 +21,9 @@ class SynthStripExtractor(BrainExtractor): - def __init__(self, border: int = 1): + def __init__( + self, border: int = 1, masking_value: Optional[Union[int, float]] = None + ): """ Brain extraction using SynthStrip with preprocessing conforming to model requirements. @@ -31,9 +33,10 @@ def __init__(self, border: int = 1): Args: border (int): Mask border threshold in mm. Defaults to 1. - """ + masking_value (Optional[Union[int, float]], optional): global value to be inserted in the masked areas. Default is None which leads to the minimum of each respective image. - super().__init__() + """ + super().__init__(masking_value=masking_value) self.border = border def _setup_model(self, device: torch.device) -> StripModel: