Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions brainles_preprocessing/brain_extraction/brain_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Optional, Union
from enum import Enum
import numpy as np

from auxiliary.io import read_image, write_image
from brainles_hd_bet import run_hd_bet
Expand All @@ -14,7 +15,23 @@ class Mode(Enum):
ACCURATE = "accurate"


class BrainExtractor:
class BrainExtractor(ABC):
def __init__(
self,
masking_value: Optional[Union[int, float]] = None,
):
"""
Base class for skull stripping medical images using brain masks.

Subclasses should implement the `extract` method to generate a skull stripped image
based on the provided input image and mask.
"""
# Just as in the defacer, masking value is a global value defined across all images and modalities
# If no value is passed, the minimum of a given input image is chosen
# TODO: Consider extending this to modality-specific masking values in the future, this should
# probably be implemented as a property of the specific modality
self.masking_value = masking_value

@abstractmethod
def extract(
self,
Expand Down Expand Up @@ -63,8 +80,17 @@ def apply_mask(
if input_data.shape != mask_data.shape:
raise ValueError("Input image and mask must have the same dimensions.")

# Mask and save it
masked_data = input_data * mask_data
# check whether a global masking value was passed, otherwise choose minimum
if self.masking_value is None:
current_masking_value = np.min(input_data)
else:
current_masking_value = (
np.array(self.masking_value).astype(input_data.dtype).item()
)
# Apply mask (element-wise either input or masking value)
masked_data = np.where(
mask_data.astype(bool), input_data, current_masking_value
)
Comment on lines +83 to +93
Copy link

Copilot AI Jan 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apply_mask() now supports a global masking_value and falls back to per-image minimum when unset, but the existing unit test only asserts that an output file is created. Please extend tests to assert that voxels outside the mask are set to the expected value for both cases (default min-per-image and a custom masking_value).

Copilot uses AI. Check for mistakes.

try:
write_image(
Expand All @@ -78,6 +104,15 @@ def apply_mask(


class HDBetExtractor(BrainExtractor):
def __init__(self, masking_value: Optional[Union[int, float]] = None):
"""
Brain extraction HDBet implementation.

Args:
masking_value (Optional[Union[int, float]], optional): global value to be inserted in the masked areas. Default is None which leads to the minimum of each respective image.
"""
super().__init__(masking_value=masking_value)

def extract(
self,
input_image_path: Union[str, Path],
Expand Down
9 changes: 6 additions & 3 deletions brainles_preprocessing/brain_extraction/synthstrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

class SynthStripExtractor(BrainExtractor):

def __init__(self, border: int = 1):
def __init__(
self, border: int = 1, masking_value: Optional[Union[int, float]] = None
):
Comment on lines +24 to +26
Copy link

Copilot AI Jan 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

masking_value is added to the constructor and forwarded to BrainExtractor, but SynthStripExtractor.extract() never uses self.masking_value when writing the masked output (it still hard-codes bg = np.min([0, img_data.min()])). This makes the new parameter ineffective and the docstring misleading. Update extract() to fill background using the same masking_value/per-image-min logic as BrainExtractor.apply_mask (or drop the parameter if not supported).

Copilot uses AI. Check for mistakes.
"""
Brain extraction using SynthStrip with preprocessing conforming to model requirements.

Expand All @@ -31,9 +33,10 @@ def __init__(self, border: int = 1):

Args:
border (int): Mask border threshold in mm. Defaults to 1.
"""
masking_value (Optional[Union[int, float]], optional): global value to be inserted in the masked areas. Default is None which leads to the minimum of each respective image.

super().__init__()
"""
super().__init__(masking_value=masking_value)
self.border = border

def _setup_model(self, device: torch.device) -> StripModel:
Expand Down