From f179a227906e30df35d9d345c5ffd8493f906df3 Mon Sep 17 00:00:00 2001 From: Sathiesh Date: Tue, 17 Feb 2026 14:59:40 +0100 Subject: [PATCH] feat: add GPU patch augmentation, dataset preprocessing, and augmentation refactoring - Add GpuPatchAugmentation for GPU-batched augmentation on patch tensors - Add preprocess_dataset utility for offline dataset preprocessing with parallel workers - Refactor suggest_patch_augmentations to share logic via _compute_patch_aug_params - Add gpu_augmentation parameter to MedPatchDataLoader (mutually exclusive with patch_tfms) - Add tutorial documentation for GPU augmentation workflow --- fastMONAI/_modidx.py | 30 ++ fastMONAI/dataset_info.py | 143 ++++++- fastMONAI/vision_augmentation.py | 516 ++++++++++++++++++++++++-- fastMONAI/vision_patch.py | 74 ++-- nbs/03_vision_augment.ipynb | 58 ++- nbs/08_dataset_info.ipynb | 18 +- nbs/10_vision_patch.ipynb | 470 ++++++++++++++++++++++- nbs/12a_tutorial_patch_training.ipynb | 31 ++ 8 files changed, 1282 insertions(+), 58 deletions(-) diff --git a/fastMONAI/_modidx.py b/fastMONAI/_modidx.py index 9d150fd..67305a8 100644 --- a/fastMONAI/_modidx.py +++ b/fastMONAI/_modidx.py @@ -38,6 +38,8 @@ 'fastMONAI/dataset_info.py'), 'fastMONAI.dataset_info.get_class_weights': ( 'dataset_info.html#get_class_weights', 'fastMONAI/dataset_info.py'), + 'fastMONAI.dataset_info.preprocess_dataset': ( 'dataset_info.html#preprocess_dataset', + 'fastMONAI/dataset_info.py'), 'fastMONAI.dataset_info.suggest_patch_size': ( 'dataset_info.html#suggest_patch_size', 'fastMONAI/dataset_info.py')}, 'fastMONAI.external_data': { 'fastMONAI.external_data.MURLs': ('external_data.html#murls', 'fastMONAI/external_data.py'), @@ -139,6 +141,28 @@ 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.CustomDictTransform.tio_transform': ( 'vision_augment.html#customdicttransform.tio_transform', 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation': ( 
'vision_augment.html#gpupatchaugmentation', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation.__call__': ( 'vision_augment.html#gpupatchaugmentation.__call__', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation.__init__': ( 'vision_augment.html#gpupatchaugmentation.__init__', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation.__repr__': ( 'vision_augment.html#gpupatchaugmentation.__repr__', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_affine': ( 'vision_augment.html#gpupatchaugmentation._apply_affine', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_anisotropy': ( 'vision_augment.html#gpupatchaugmentation._apply_anisotropy', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_blur': ( 'vision_augment.html#gpupatchaugmentation._apply_blur', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_flip': ( 'vision_augment.html#gpupatchaugmentation._apply_flip', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_gamma': ( 'vision_augment.html#gpupatchaugmentation._apply_gamma', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_intensity_scale': ( 'vision_augment.html#gpupatchaugmentation._apply_intensity_scale', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_noise': ( 'vision_augment.html#gpupatchaugmentation._apply_noise', + 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.NormalizeIntensity': ( 'vision_augment.html#normalizeintensity', 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.NormalizeIntensity.__init__': ( 
'vision_augment.html#normalizeintensity.__init__', @@ -289,12 +313,18 @@ 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation._TioRandomIntensityScale.apply_transform': ( 'vision_augment.html#_tiorandomintensityscale.apply_transform', 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation._build_rotation_matrix_3d': ( 'vision_augment.html#_build_rotation_matrix_3d', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation._compute_patch_aug_params': ( 'vision_augment.html#_compute_patch_aug_params', + 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation._create_ellipsoid_mask': ( 'vision_augment.html#_create_ellipsoid_mask', 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation._foreground_masking': ( 'vision_augment.html#_foreground_masking', 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.do_pad_or_crop': ( 'vision_augment.html#do_pad_or_crop', 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.gpu_patch_augmentations': ( 'vision_augment.html#gpu_patch_augmentations', + 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.suggest_patch_augmentations': ( 'vision_augment.html#suggest_patch_augmentations', 'fastMONAI/vision_augmentation.py')}, 'fastMONAI.vision_core': { 'fastMONAI.vision_core.MedBase': ('vision_core.html#medbase', 'fastMONAI/vision_core.py'), diff --git a/fastMONAI/dataset_info.py b/fastMONAI/dataset_info.py index fad03e5..6414935 100644 --- a/fastMONAI/dataset_info.py +++ b/fastMONAI/dataset_info.py @@ -1,15 +1,17 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/08_dataset_info.ipynb. 
# %% ../nbs/08_dataset_info.ipynb #mbn5svtmzkh
def preprocess_dataset(df, img_col, mask_col=None, output_dir='preprocessed',
                       target_spacing=None, apply_reorder=True, transforms=None,
                       max_workers=4, skip_existing=True):
    """Preprocess dataset to disk and update DataFrame path columns in-place.

    Processes images (and optionally masks) through a transform pipeline,
    saves to output_dir, then updates df[img_col] and df[mask_col] in-place
    to point to the preprocessed files.

    Transform pipeline order:
        CopyAffine (if masks) -> ToCanonical (if apply_reorder)
        -> Resample (if target_spacing) -> user transforms

    Args:
        df: DataFrame with file paths.
        img_col: Column name for image paths.
        mask_col: Optional column name for mask paths.
        output_dir: Output directory. Creates images/ and masks/ subdirectories.
        target_spacing: Target voxel spacing for resampling (e.g., [1.0, 1.0, 1.0]).
        apply_reorder: Whether to reorder to RAS+ canonical orientation.
        transforms: Additional TorchIO or fastMONAI transforms to apply after
            reordering and resampling.
        max_workers: Number of parallel workers. Each worker loads a full 3D
            volume into memory, so reduce for large volumes.
        skip_existing: Skip files that already exist on disk (with size > 0).

    Raises:
        ValueError: If df is empty, a column is missing, or output file names
            (derived from input file names) would collide.

    Note:
        Path columns are rewritten even for cases that failed to process;
        failed case names are printed so the caller can drop those rows.
    """
    import warnings  # local import: the module header does not import warnings

    # --- Input validation -------------------------------------------------
    if len(df) == 0:
        raise ValueError("DataFrame is empty")
    if img_col not in df.columns:
        raise ValueError(f"Column '{img_col}' not found in DataFrame")
    if mask_col is not None and mask_col not in df.columns:
        raise ValueError(f"Column '{mask_col}' not found in DataFrame")

    # Outputs are keyed by file name only, so names must be unique.
    img_names = [Path(p).name for p in df[img_col]]
    if len(set(img_names)) != len(img_names):
        dupes = set(n for n in img_names if img_names.count(n) > 1)
        raise ValueError(f"Duplicate image file names: {dupes}")

    if mask_col is not None:
        mask_names = [Path(p).name for p in df[mask_col]]
        if len(set(mask_names)) != len(mask_names):
            dupes = set(n for n in mask_names if mask_names.count(n) > 1)
            raise ValueError(f"Duplicate mask file names: {dupes}")

    # --- Build transform pipeline (canonical order) -----------------------
    all_tfms = []
    if mask_col is not None:
        all_tfms.append(tio.CopyAffine(target='image'))
    if apply_reorder:
        all_tfms.append(tio.ToCanonical())
    if target_spacing is not None:
        all_tfms.append(tio.Resample(target_spacing))
    if transforms:
        # fastMONAI wrappers expose the raw TorchIO transform as .tio_transform
        all_tfms.extend([getattr(t, 'tio_transform', t) for t in transforms])
    pipeline = tio.Compose(all_tfms) if all_tfms else None

    # --- Create output directories ----------------------------------------
    output_dir = Path(output_dir)
    img_dir = output_dir / 'images'
    img_dir.mkdir(parents=True, exist_ok=True)
    if mask_col is not None:
        mask_dir = output_dir / 'masks'
        mask_dir.mkdir(parents=True, exist_ok=True)

    # --- Build work items, honoring skip_existing -------------------------
    work_items = []
    skipped = 0
    for idx in range(len(df)):
        img_path = df[img_col].iloc[idx]
        out_img = img_dir / Path(img_path).name

        mask_path = df[mask_col].iloc[idx] if mask_col is not None else None
        out_mask = (mask_dir / Path(mask_path).name) if mask_col is not None else None

        if skip_existing:
            # Size check guards against empty files left by interrupted runs.
            img_ok = out_img.exists() and out_img.stat().st_size > 0
            mask_ok = out_mask is None or (out_mask.exists() and out_mask.stat().st_size > 0)
            if img_ok and mask_ok:
                skipped += 1
                continue

        work_items.append({
            'idx': idx, 'img_path': img_path, 'mask_path': mask_path,
            'out_img': out_img, 'out_mask': out_mask,
        })

    # --- Process cases ----------------------------------------------------
    processed = 0
    failed = 0
    failed_cases = []

    def _process_case(item):
        """Load, transform, and atomically write one image (and mask)."""
        subject_dict = {'image': tio.ScalarImage(item['img_path'])}
        if item['mask_path'] is not None:
            subject_dict['mask'] = tio.LabelMap(item['mask_path'])

        subject = tio.Subject(**subject_dict)
        if pipeline is not None:
            subject = pipeline(subject)

        # Atomic write: save to a temp file that keeps the NIfTI extension,
        # then Path.replace() it into place (atomic on POSIX and, unlike
        # os.rename, also overwrites an existing target on Windows).
        out_img = item['out_img']
        tmp_img = out_img.parent / f'.tmp_{out_img.name}'
        subject['image'].save(str(tmp_img))
        tmp_img.replace(out_img)

        if item['out_mask'] is not None:
            out_mask = item['out_mask']
            tmp_mask = out_mask.parent / f'.tmp_{out_mask.name}'
            subject['mask'].save(str(tmp_mask))
            tmp_mask.replace(out_mask)

    if work_items:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(_process_case, item): item for item in work_items}
            for future in tqdm(as_completed(futures), total=len(futures),
                               desc='Preprocessing'):
                item = futures[future]
                try:
                    future.result()
                    processed += 1
                except Exception as e:
                    # Best-effort: record and warn, keep processing the rest.
                    failed += 1
                    failed_cases.append(Path(item['img_path']).name)
                    warnings.warn(f"Failed to process {item['img_path']}: {e}")

    # --- Update DataFrame in-place ----------------------------------------
    df[img_col] = [str(img_dir / Path(p).name) for p in df[img_col]]
    if mask_col is not None:
        df[mask_col] = [str(mask_dir / Path(p).name) for p in df[mask_col]]

    print(f"Preprocessing complete: {processed} processed, {skipped} skipped, {failed} failed")
    if failed_cases:
        print(f"Failed cases: {failed_cases}")
diff --git a/fastMONAI/vision_augmentation.py b/fastMONAI/vision_augmentation.py index 1193322..54e5fcd 100644 --- a/fastMONAI/vision_augmentation.py +++ b/fastMONAI/vision_augmentation.py @@ -4,11 +4,14 @@ __all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization', 'RescaleIntensity', 'NormalizeIntensity', 'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField', 'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomAnisotropy', 'RandomCutout', - 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf', 'suggest_patch_augmentations'] + 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf', 'GpuPatchAugmentation', + 'gpu_patch_augmentations', 'suggest_patch_augmentations'] # %% ../nbs/03_vision_augment.ipynb #2d6694aa from fastai.data.all import * from .vision_core import * +import torch.nn.functional as F +import math import torchio as tio from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity @@ -791,19 +794,19 @@ class OneOf(CustomDictTransform): def __init__(self, transform_dict, p=1): super().__init__(tio.OneOf(transform_dict, p=p)) -# %% ../nbs/03_vision_augment.ipynb #t6hak044rc -def suggest_patch_augmentations(patch_size, target_spacing, - anisotropy_threshold=3.0, - translation_fraction=0.15): - """Suggest patch-based augmentations with nnU-Net-inspired defaults. +# %% ../nbs/03_vision_augment.ipynb #lqet5pabzy +def _compute_patch_aug_params(patch_size, target_spacing, + anisotropy_threshold=3.0, + translation_fraction=0.15): + """Compute geometry-aware augmentation parameters from patch/spacing metadata. - Derives rotation degrees, translation, and RandomAnisotropy axes from - patch geometry and voxel spacing. Returns a list of fastMONAI transform - instances ready for the ``patch_tfms`` parameter in MedPatchDataLoaders. 
+ Shared logic used by both suggest_patch_augmentations (CPU/TorchIO) and + gpu_patch_augmentations (GPU-batched). Extracts rotation degrees, translation + offsets, and RandomAnisotropy axes from the spatial configuration. Anisotropy detection: if max(spacing)/min(spacing) >= threshold, rotation is restricted to 5 deg out-of-plane and 30 deg in-plane. Otherwise 30 deg - symmetric. Translation is patch_size * fraction per axis. + symmetric. Args: patch_size: List/tuple of 3 ints -- patch dimensions. @@ -812,38 +815,507 @@ def suggest_patch_augmentations(patch_size, target_spacing, translation_fraction: Fraction of patch_size for translation (default 0.15). Returns: - list: fastMONAI transform instances (7 normally, 6 if RandomAnisotropy omitted). - - Example:: - - >>> patch_tfms = suggest_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5]) - >>> dls = MedPatchDataLoaders.from_config(..., patch_tfms=patch_tfms) + dict with keys: + 'degrees': tuple of 3 ints (per-axis max rotation in degrees) + 'translation': tuple of 3 ints (per-axis translation in voxels) + 'aniso_axes': tuple of ints (axes where patch_size > 1) + 'is_aniso': bool (whether spacing is anisotropic) """ if len(patch_size) != 3: raise ValueError(f"patch_size must have 3 elements, got {len(patch_size)}") if len(target_spacing) != 3: raise ValueError(f"target_spacing must have 3 elements, got {len(target_spacing)}") - # Determine anisotropy spacing = list(target_spacing) ratio = max(spacing) / min(spacing) is_aniso = ratio >= anisotropy_threshold aniso_axis = spacing.index(max(spacing)) if is_aniso else None - # Rotation degrees if is_aniso: degrees = [5, 5, 5] degrees[aniso_axis] = 30 degrees = tuple(degrees) else: - degrees = 30 + degrees = (30, 30, 30) - # Translation translation = tuple(round(p * translation_fraction) for p in patch_size) - - # RandomAnisotropy axes: all axes where patch_size > 1 aniso_axes = tuple(i for i in range(3) if patch_size[i] > 1) + return { + 'degrees': degrees, + 
'translation': translation, + 'aniso_axes': aniso_axes, + 'is_aniso': is_aniso, + } + +# %% ../nbs/03_vision_augment.ipynb #oef9rtzvbw +def _build_rotation_matrix_3d(angles_rad): + """Build [N, 3, 3] rotation matrices from [N, 3] Euler angles (XYZ extrinsic). + + Computes R = Rz @ Ry @ Rx for each sample in the batch. + + Args: + angles_rad: Tensor of shape [N, 3] with rotation angles in radians + for each axis (x, y, z). + + Returns: + Tensor of shape [N, 3, 3] -- rotation matrices. + """ + cos = torch.cos(angles_rad) + sin = torch.sin(angles_rad) + cx, cy, cz = cos[:, 0], cos[:, 1], cos[:, 2] + sx, sy, sz = sin[:, 0], sin[:, 1], sin[:, 2] + + # R = Rz @ Ry @ Rx (combined formula) + r00 = cy * cz; r01 = sx * sy * cz - cx * sz; r02 = cx * sy * cz + sx * sz + r10 = cy * sz; r11 = sx * sy * sz + cx * cz; r12 = cx * sy * sz - sx * cz + r20 = -sy; r21 = sx * cy; r22 = cx * cy + + R = torch.stack([ + torch.stack([r00, r01, r02], dim=-1), + torch.stack([r10, r11, r12], dim=-1), + torch.stack([r20, r21, r22], dim=-1), + ], dim=-2) + return R + +# %% ../nbs/03_vision_augment.ipynb #ts2kolv83z +class GpuPatchAugmentation: + """GPU-batched augmentation for patch-based training. + + Operates on [B, C, D, H, W] tensors already on GPU. All operations run + under torch.no_grad() since augmentation does not need gradient tracking. + + Transform order: spatial (affine, anisotropy, flip) then intensity + (gamma, intensity_scale, noise, blur). Spatial transforms apply the + same parameters to both image and mask. Intensity transforms skip the mask. + + Each transform is controlled by a parameter dict with at minimum a 'p' key + for per-sample probability. Pass None to disable a transform. + + Args: + affine: dict with keys 'scales', 'degrees', 'translation', + 'default_pad_value', 'p'. None to disable. + anisotropy: dict with keys 'axes', 'downsampling', 'p'. None to disable. + flip: dict with keys 'axes', 'p'. None to disable. + gamma: dict with keys 'log_gamma', 'p'. 
# %% ../nbs/03_vision_augment.ipynb #ts2kolv83z
class GpuPatchAugmentation:
    """GPU-batched augmentation for patch-based training.

    Operates on [B, C, D, H, W] tensors already on GPU. All operations run
    under torch.no_grad() since augmentation does not need gradient tracking.
    Some transforms (anisotropy, flip, gamma) modify the input tensors in
    place; pass clones if the originals must be preserved.

    NOTE(review): MedPatchDataLoader comments describe its batches as
    [B, C, H, W, D]. Every transform here is layout-agnostic except the
    per-axis 'degrees', 'translation', and 'axes' parameters, which follow
    the spatial-dim order of the tensor actually passed in -- confirm the
    intended order against the data loader.

    Transform order: spatial (affine, anisotropy, flip) then intensity
    (gamma, intensity_scale, noise, blur). Spatial transforms apply the
    same parameters to both image and mask. Intensity transforms skip the mask.

    Each transform is controlled by a parameter dict with at minimum a 'p' key
    for per-sample probability. Pass None to disable a transform.

    Args:
        affine: dict with keys 'scales', 'degrees', 'translation',
            'default_pad_value', 'p'. None to disable.
        anisotropy: dict with keys 'axes', 'downsampling', 'p'. None to disable.
        flip: dict with keys 'axes', 'p'. None to disable.
        gamma: dict with keys 'log_gamma', 'p'. None to disable.
        intensity_scale: dict with keys 'scale_range', 'p'. None to disable.
        noise: dict with keys 'std', 'p'. None to disable.
        blur: dict with keys 'std', 'p'. None to disable.

    Example::

        >>> gpu_aug = GpuPatchAugmentation(
        ...     affine={'scales': (0.7, 1.4), 'degrees': (30, 30, 30),
        ...             'translation': (25, 25, 10), 'default_pad_value': 0., 'p': 0.2},
        ...     gamma={'log_gamma': (-0.3, 0.3), 'p': 0.3},
        ...     flip={'axes': (0, 1, 2), 'p': 0.5},
        ... )
        >>> img_aug, mask_aug = gpu_aug(img_gpu, mask_gpu)
    """

    def __init__(self, affine=None, anisotropy=None, flip=None,
                 gamma=None, intensity_scale=None, noise=None, blur=None):
        self.affine = affine
        self.anisotropy = anisotropy
        self.flip = flip
        self.gamma = gamma
        self.intensity_scale = intensity_scale
        self.noise = noise
        self.blur = blur

    def __call__(self, img, mask=None):
        """Apply GPU augmentations to a batch.

        Args:
            img: Tensor [B, C, D, H, W] (float).
            mask: Tensor [B, C, D, H, W] (float), or None.

        Returns:
            Tuple (img, mask). mask is None if input was None.
        """
        with torch.no_grad():
            # Spatial transforms (same params for img and mask)
            if self.affine is not None:
                img, mask = self._apply_affine(img, mask)
            if self.anisotropy is not None:
                img, mask = self._apply_anisotropy(img, mask)
            if self.flip is not None:
                img, mask = self._apply_flip(img, mask)
            # Intensity transforms (img only)
            if self.gamma is not None:
                img = self._apply_gamma(img)
            if self.intensity_scale is not None:
                img = self._apply_intensity_scale(img)
            if self.noise is not None:
                img = self._apply_noise(img)
            if self.blur is not None:
                img = self._apply_blur(img)
        return img, mask

    def _apply_affine(self, img, mask):
        """Batched random affine via F.affine_grid + F.grid_sample.

        Builds one theta per sample (identity for inactive samples) and
        resamples the whole batch in a single grid_sample call. The mask is
        resampled with nearest interpolation using the same grid.
        """
        cfg = self.affine
        B = img.shape[0]
        device = img.device
        dtype = img.dtype

        # Per-sample probability
        do_tfm = torch.rand(B, device=device) < cfg['p']
        if not do_tfm.any():
            return img, mask

        # Identity theta [B, 3, 4]; only active samples get random params.
        theta = torch.zeros(B, 3, 4, device=device, dtype=dtype)
        theta[:, 0, 0] = 1.0
        theta[:, 1, 1] = 1.0
        theta[:, 2, 2] = 1.0

        idx = do_tfm.nonzero(as_tuple=True)[0]
        n = idx.shape[0]

        # Random scales per axis
        s_lo, s_hi = cfg['scales']
        scales = torch.empty(n, 3, device=device, dtype=dtype).uniform_(s_lo, s_hi)

        degrees = cfg['degrees']
        if not isinstance(degrees, (list, tuple)):
            degrees = (degrees, degrees, degrees)
        translation = cfg['translation']
        if not isinstance(translation, (list, tuple)):
            translation = (translation, translation, translation)

        # FIX: F.affine_grid theta operates in (x, y, z) coordinates where x
        # indexes the LAST spatial dim and z the FIRST, while per-axis
        # parameters are given in spatial-dim order. Reverse them so that
        # degrees[i]/translation[i] apply to spatial dim i (otherwise the
        # anisotropic (5, 5, 30) budget rotates about the wrong axis and
        # translations are normalized by the wrong extent).
        deg_xyz = (degrees[2], degrees[1], degrees[0])
        angles_rad = torch.stack([
            torch.empty(n, device=device, dtype=dtype).uniform_(-d, d)
            for d in deg_xyz
        ], dim=1) * (math.pi / 180.0)  # [n, 3], degrees -> radians

        # Combined linear part: scale then rotate, [n, 3, 3]
        R = _build_rotation_matrix_3d(angles_rad)
        SR = torch.diag_embed(scales) @ R

        # Random translation: voxels -> normalized [-1, 1] coords (2*t/size),
        # with the same spatial-dim -> (x, y, z) reversal as above.
        spatial_size = img.shape[2:]  # (D, H, W)
        trans_xyz = (translation[2], translation[1], translation[0])
        size_xyz = (spatial_size[2], spatial_size[1], spatial_size[0])
        t_norm = torch.stack([
            torch.empty(n, device=device, dtype=dtype).uniform_(
                -trans_xyz[i], trans_xyz[i]
            ) * 2.0 / size_xyz[i]
            for i in range(3)
        ], dim=1)  # [n, 3]

        # Assemble theta for active samples
        theta[idx, :3, :3] = SR
        theta[idx, :3, 3] = t_norm

        grid = F.affine_grid(theta, img.shape, align_corners=False)
        pad_val = cfg.get('default_pad_value', 0.)
        # grid_sample only pads with zeros: shift values so padding == pad_val.
        if pad_val != 0.:
            img = img - pad_val
        img = F.grid_sample(img, grid, mode='bilinear',
                            padding_mode='zeros', align_corners=False)
        if pad_val != 0.:
            img = img + pad_val

        if mask is not None:
            mask = F.grid_sample(mask.float(), grid, mode='nearest',
                                 padding_mode='zeros', align_corners=False)
        return img, mask

    def _apply_anisotropy(self, img, mask):
        """Per-sample anisotropy simulation via F.interpolate.

        Downsample along a random axis with nearest interpolation, then
        upsample back with trilinear (matches TorchIO RandomAnisotropy).
        Only degrades the image in place; the mask is returned untouched.
        """
        cfg = self.anisotropy
        B = img.shape[0]
        device = img.device
        ds_lo, ds_hi = cfg['downsampling']
        axes = cfg['axes']

        for i in range(B):
            if torch.rand(1, device=device).item() >= cfg['p']:
                continue
            # Pick a random spatial axis and downsampling factor.
            axis_idx = torch.randint(len(axes), (1,), device=device).item()
            axis = axes[axis_idx]
            factor = torch.empty(1, device=device).uniform_(ds_lo, ds_hi).item()

            sample = img[i:i+1]  # [1, C, D, H, W]
            orig_size = list(sample.shape[2:])
            down_size = list(orig_size)
            down_size[axis] = max(1, round(orig_size[axis] / factor))

            # Downsample with nearest, upsample back with trilinear.
            down = F.interpolate(sample, size=down_size, mode='nearest')
            up = F.interpolate(down, size=orig_size, mode='trilinear',
                               align_corners=False)
            img[i:i+1] = up

        return img, mask

    def _apply_flip(self, img, mask):
        """Per-sample random flip; each configured axis flips independently
        with probability p. Flips img (and mask) in place."""
        cfg = self.flip
        B = img.shape[0]
        device = img.device

        for i in range(B):
            # Spatial axis k corresponds to dim k+1 of the [C, D, H, W] sample.
            flip_dims = [axis + 1 for axis in cfg['axes']
                         if torch.rand(1, device=device).item() < cfg['p']]
            if flip_dims:
                img[i] = torch.flip(img[i], dims=flip_dims)
                if mask is not None:
                    mask[i] = torch.flip(mask[i], dims=flip_dims)

        return img, mask

    def _apply_gamma(self, img):
        """Batched gamma correction with per-sample random gamma.

        Active samples are clamped to >= 0 before pow (a negative base with a
        fractional exponent is NaN). NOTE(review): the clamp zeroes negative
        values from z-normalized inputs -- confirm acceptable. Modifies img
        in place for the active samples.
        """
        cfg = self.gamma
        B = img.shape[0]
        device = img.device
        dtype = img.dtype
        log_lo, log_hi = cfg['log_gamma']

        active = torch.rand(B, device=device) < cfg['p']
        if not active.any():
            return img

        # Only apply clamp + pow to active samples (clamp destroys negatives)
        active_idx = active.nonzero(as_tuple=True)[0]
        log_gamma = torch.empty(active_idx.shape[0], device=device, dtype=dtype).uniform_(log_lo, log_hi)
        gamma = torch.exp(log_gamma).view(-1, 1, 1, 1, 1)
        img[active_idx] = img[active_idx].clamp(min=0).pow(gamma)
        return img

    def _apply_intensity_scale(self, img):
        """Batched intensity scaling with per-sample random factors."""
        cfg = self.intensity_scale
        B = img.shape[0]
        device = img.device
        dtype = img.dtype
        s_lo, s_hi = cfg['scale_range']

        # Per-sample scale (inactive samples get scale=1)
        scale = torch.empty(B, device=device, dtype=dtype).uniform_(s_lo, s_hi)
        active = torch.rand(B, device=device) < cfg['p']
        scale = torch.where(active, scale, torch.ones_like(scale))

        img = img * scale.view(B, 1, 1, 1, 1)
        return img

    def _apply_noise(self, img):
        """Batched additive Gaussian noise with per-sample random std."""
        cfg = self.noise
        B = img.shape[0]
        device = img.device
        dtype = img.dtype

        std_val = cfg['std']
        if isinstance(std_val, (list, tuple)):
            std_lo, std_hi = std_val
            per_std = torch.empty(B, device=device, dtype=dtype).uniform_(std_lo, std_hi)
        else:
            per_std = torch.full((B,), std_val, device=device, dtype=dtype)

        # Zero std for inactive samples
        active = torch.rand(B, device=device) < cfg['p']
        per_std = torch.where(active, per_std, torch.zeros_like(per_std))

        noise = torch.randn_like(img) * per_std.view(B, 1, 1, 1, 1)
        img = img + noise
        return img

    def _apply_blur(self, img):
        """Batched separable 3D Gaussian blur via F.conv3d with groups trick."""
        cfg = self.blur
        B, C, D, H, W = img.shape
        device = img.device
        dtype = img.dtype

        std_val = cfg['std']
        if isinstance(std_val, (list, tuple)):
            std_lo, std_hi = std_val
        else:
            std_lo, std_hi = 0.0, std_val

        # Per-sample sigma
        sigma = torch.empty(B, device=device, dtype=dtype).uniform_(std_lo, std_hi)
        active = torch.rand(B, device=device) < cfg['p']
        if not active.any():
            return img

        # Fixed kernel size from max sigma
        max_sigma = max(std_hi, 0.01)
        kernel_radius = int(math.ceil(3 * max_sigma))
        kernel_size = 2 * kernel_radius + 1

        # Build per-sample 1D Gaussian kernels [B, kernel_size]
        x = torch.arange(-kernel_radius, kernel_radius + 1,
                         device=device, dtype=dtype)
        # FIX: guard against sigma ~ 0 for ACTIVE samples too (0/0 -> NaN);
        # a tiny sigma is indistinguishable from a delta kernel anyway.
        safe_sigma = torch.where(active, sigma.clamp_min(1e-3),
                                 torch.ones_like(sigma))
        kernels = torch.exp(-x.unsqueeze(0)**2 / (2 * safe_sigma.unsqueeze(1)**2))
        kernels = kernels / kernels.sum(dim=1, keepdim=True)

        # For inactive samples, use a delta kernel (no-op blur)
        delta = torch.zeros(B, kernel_size, device=device, dtype=dtype)
        delta[:, kernel_radius] = 1.0
        kernels = torch.where(active.unsqueeze(1), kernels, delta)

        # Expand kernels for all channels: [B*C, kernel_size]
        kernels_bc = kernels.unsqueeze(1).expand(B, C, kernel_size).reshape(B * C, kernel_size)

        # Reshape img for grouped convolution: [1, B*C, D, H, W]
        img_grouped = img.reshape(1, B * C, D, H, W)

        # Separable 3D convolution: D-axis, H-axis, W-axis
        pad = kernel_radius

        # D-axis: kernel shape [B*C, 1, K, 1, 1]
        k_d = kernels_bc.reshape(B * C, 1, kernel_size, 1, 1)
        img_grouped = F.pad(img_grouped, (0, 0, 0, 0, pad, pad), mode='replicate')
        img_grouped = F.conv3d(img_grouped, k_d, groups=B * C)

        # H-axis: kernel shape [B*C, 1, 1, K, 1]
        k_h = kernels_bc.reshape(B * C, 1, 1, kernel_size, 1)
        img_grouped = F.pad(img_grouped, (0, 0, pad, pad, 0, 0), mode='replicate')
        img_grouped = F.conv3d(img_grouped, k_h, groups=B * C)

        # W-axis: kernel shape [B*C, 1, 1, 1, K]
        k_w = kernels_bc.reshape(B * C, 1, 1, 1, kernel_size)
        img_grouped = F.pad(img_grouped, (pad, pad, 0, 0, 0, 0), mode='replicate')
        img_grouped = F.conv3d(img_grouped, k_w, groups=B * C)

        return img_grouped.reshape(B, C, D, H, W)

    def __repr__(self):
        parts = []
        for name in ['affine', 'anisotropy', 'flip', 'gamma',
                     'intensity_scale', 'noise', 'blur']:
            cfg = getattr(self, name)
            if cfg is not None:
                parts.append(f"{name}(p={cfg['p']})")
        return f"GpuPatchAugmentation({', '.join(parts)})"
# %% ../nbs/03_vision_augment.ipynb #pdbh1nqo0j7
def gpu_patch_augmentations(patch_size, target_spacing,
                            anisotropy_threshold=3.0,
                            translation_fraction=0.15,
                            affine_p=0.2, anisotropy_p=0.25,
                            gamma_p=0.3, intensity_scale_p=0.1,
                            noise_p=0.1, blur_p=0.2, flip_p=0.5):
    """Create GpuPatchAugmentation with nnU-Net-inspired defaults.

    Factory mirroring suggest_patch_augmentations, but producing the
    GPU-batched GpuPatchAugmentation. Geometry-dependent values (rotation
    degrees, translation, anisotropy axes) come from the shared
    _compute_patch_aug_params helper.

    Args:
        patch_size: List/tuple of 3 ints -- patch dimensions.
        target_spacing: List/tuple of 3 floats -- voxel spacing.
        anisotropy_threshold: Ratio threshold for anisotropy detection (default 3.0).
        translation_fraction: Fraction of patch_size for translation (default 0.15).
        affine_p: Probability for RandomAffine (default 0.2).
        anisotropy_p: Probability for RandomAnisotropy (default 0.25).
        gamma_p: Probability for RandomGamma (default 0.3).
        intensity_scale_p: Probability for RandomIntensityScale (default 0.1).
        noise_p: Probability for RandomNoise (default 0.1).
        blur_p: Probability for RandomBlur (default 0.2).
        flip_p: Probability for RandomFlip per axis (default 0.5).

    Returns:
        GpuPatchAugmentation instance.

    Example::

        >>> gpu_aug = gpu_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5])
        >>> dls = MedPatchDataLoaders.from_df(..., gpu_augmentation=gpu_aug)
    """
    geom = _compute_patch_aug_params(
        patch_size, target_spacing, anisotropy_threshold, translation_fraction
    )

    # Anisotropy simulation is only configured when at least one axis is
    # non-singleton; otherwise it is disabled entirely.
    aniso_cfg = None
    if geom['aniso_axes']:
        aniso_cfg = {
            'axes': geom['aniso_axes'],
            'downsampling': (1.5, 4),
            'p': anisotropy_p,
        }

    return GpuPatchAugmentation(
        affine={
            'scales': (0.7, 1.4),
            'degrees': geom['degrees'],
            'translation': geom['translation'],
            'default_pad_value': 0.,
            'p': affine_p,
        },
        anisotropy=aniso_cfg,
        flip={'axes': (0, 1, 2), 'p': flip_p},
        gamma={'log_gamma': (-0.3, 0.3), 'p': gamma_p},
        intensity_scale={'scale_range': (0.75, 1.25), 'p': intensity_scale_p},
        noise={'std': 0.1, 'p': noise_p},
        blur={'std': (0.5, 1.0), 'p': blur_p},
    )
+ + Returns: + list: fastMONAI transform instances (7 normally, 6 if RandomAnisotropy omitted). + + Example:: + + >>> patch_tfms = suggest_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5]) + >>> dls = MedPatchDataLoaders.from_df(..., patch_tfms=patch_tfms) + """ + params = _compute_patch_aug_params( + patch_size, target_spacing, anisotropy_threshold, translation_fraction + ) + degrees = params['degrees'] + translation = params['translation'] + aniso_axes = params['aniso_axes'] + + # For TorchIO: pass scalar 30 when isotropic (TorchIO expands to symmetric) + if not params['is_aniso']: + degrees = 30 + transforms = [ RandomAffine(scales=(0.7, 1.4), degrees=degrees, translation=translation, default_pad_value=0., p=0.2), diff --git a/fastMONAI/vision_patch.py b/fastMONAI/vision_patch.py index 40031c7..55d0d61 100644 --- a/fastMONAI/vision_patch.py +++ b/fastMONAI/vision_patch.py @@ -403,7 +403,10 @@ class MedPatchDataLoader: patch_tfms: Transforms to apply to extracted patches (training only). Accepts both fastMONAI wrappers (e.g., RandomAffine, RandomGamma) and raw TorchIO transforms. fastMONAI wrappers are automatically normalized - to raw TorchIO for internal use. + to raw TorchIO for internal use. Mutually exclusive with gpu_augmentation. + gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation. + Operates on [B,C,D,H,W] tensors already on GPU, avoiding per-sample CPU + overhead. Mutually exclusive with patch_tfms. Training only. shuffle: Whether to shuffle subjects and patches. drop_last: Whether to drop last incomplete batch. 
""" @@ -414,18 +417,20 @@ def __init__( config: PatchConfig, batch_size: int = 4, patch_tfms: list = None, + gpu_augmentation=None, shuffle: bool = True, drop_last: bool = False ): if batch_size <= 0: raise ValueError(f"batch_size must be positive, got {batch_size}") - + self.subjects_dataset = subjects_dataset self.config = config self.bs = batch_size self.shuffle = shuffle self.drop_last = drop_last self._device = _get_default_device() + self.gpu_augmentation = gpu_augmentation # Create sampler self.sampler = create_patch_sampler(config) @@ -464,14 +469,12 @@ def __iter__(self): img = batch['image'][tio.DATA] # [B, C, H, W, D] has_mask = 'mask' in batch - # Apply patch transforms if provided + # Apply CPU patch transforms if provided (per-sample TorchIO loop) if self.patch_tfms is not None: - # Apply transforms to each sample in batch transformed_imgs = [] transformed_masks = [] if has_mask else None for i in range(img.shape[0]): - # Build subject dict with image, and mask if available subject_dict = {'image': tio.ScalarImage(tensor=batch['image'][tio.DATA][i])} if has_mask: subject_dict['mask'] = tio.LabelMap(tensor=batch['mask'][tio.DATA][i]) @@ -487,10 +490,19 @@ def __iter__(self): else: mask = batch['mask'][tio.DATA] if has_mask else None - # Convert to MedImage/MedMask and move to device - img = MedImage(img).to(self._device) + # Move to device + img = img.to(self._device) if mask is not None: - mask = MedMask(mask).to(self._device) + mask = mask.to(self._device) + + # Apply GPU augmentation if provided (batched, on-device) + if self.gpu_augmentation is not None: + img, mask = self.gpu_augmentation(img, mask) + + # Wrap as MedImage/MedMask + img = MedImage(img) + if mask is not None: + mask = MedMask(mask) yield img, mask @@ -613,6 +625,7 @@ def from_df( patch_config: PatchConfig = None, pre_patch_tfms: list = None, patch_tfms: list = None, + gpu_augmentation=None, apply_reorder: bool = None, target_spacing: list = None, bs: int = 4, @@ -641,6 +654,9 @@ def 
from_df( (after reorder/resample). Example: [tio.ZNormalization()]. Accepts both fastMONAI wrappers and raw TorchIO transforms. patch_tfms: TorchIO transforms applied to extracted patches (training only). + Mutually exclusive with gpu_augmentation. + gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation + (training only). Mutually exclusive with patch_tfms. apply_reorder: If True, reorder to RAS+ orientation. If None, uses patch_config.apply_reorder. Explicit value overrides config. target_spacing: Target voxel spacing [x, y, z]. If None, uses @@ -656,22 +672,32 @@ def from_df( MedPatchDataLoaders instance. Example: - >>> # New pattern: config contains preprocessing params - >>> config = PatchConfig( - ... patch_size=[96, 96, 96], - ... apply_reorder=True, - ... target_spacing=[0.5, 0.5, 0.5], - ... label_probabilities={0: 0.1, 1: 0.9} - ... ) + >>> # CPU augmentation path (existing) >>> dls = MedPatchDataLoaders.from_df( ... df, img_col='image', mask_col='label', ... patch_config=config, - ... pre_patch_tfms=[tio.ZNormalization()], ... patch_tfms=[tio.RandomAffine(degrees=10), tio.RandomFlip()], ... bs=4 ... ) - >>> # Memory: ~150 MB (queue buffer only) + >>> + >>> # GPU augmentation path (new, faster for long training runs) + >>> from fastMONAI.vision_augmentation import gpu_patch_augmentations + >>> gpu_aug = gpu_patch_augmentations(config.patch_size, config.target_spacing) + >>> dls = MedPatchDataLoaders.from_df( + ... df, img_col='image', mask_col='label', + ... patch_config=config, + ... gpu_augmentation=gpu_aug, + ... bs=4 + ... ) """ + # Validate mutual exclusivity + if gpu_augmentation is not None and patch_tfms is not None: + raise ValueError( + "Cannot use both gpu_augmentation and patch_tfms. " + "gpu_augmentation operates on GPU tensors batch-wise, while " + "patch_tfms uses per-sample CPU TorchIO transforms. Choose one." 
+ ) + if patch_config is None: patch_config = PatchConfig() @@ -726,11 +752,14 @@ def from_df( # Create DataLoaders (both use same patch_config for consistent sampling) train_dl = MedPatchDataLoader( train_subjects, patch_config, bs, - patch_tfms=patch_tfms, shuffle=True, drop_last=True + patch_tfms=patch_tfms, + gpu_augmentation=gpu_augmentation, + shuffle=True, drop_last=True ) valid_dl = MedPatchDataLoader( valid_subjects, patch_config, bs, patch_tfms=None, # No augmentation for validation + gpu_augmentation=None, # No augmentation for validation shuffle=False, drop_last=False ) @@ -847,7 +876,8 @@ def cpu(self): return self.to(torch.device('cpu')) def show_batch(self, dl_idx=0, max_n=6, figsize=None, channel=0, - slice_index=None, anatomical_plane=0, overlay=False, **kwargs): + slice_index=None, anatomical_plane=0, overlay=False, + voxel_size=None, **kwargs): """Show a batch of patch samples for visualization.""" dl = self[dl_idx] @@ -886,15 +916,15 @@ def show_batch(self, dl_idx=0, max_n=6, figsize=None, channel=0, imgs.extend(im_channels) slice_idxs.extend([idx] * len(im_channels)) - voxel_size = self.target_spacing + _voxel_size = voxel_size if voxel_size is not None else self.target_spacing ctxs = [im.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane, - voxel_size=voxel_size) + voxel_size=_voxel_size) for im, ax, idx in zip(imgs, flat_axs, slice_idxs)] if overlay and has_mask: for mask, ax, idx in zip(masks_for_overlay, flat_axs, slice_idxs): mask.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane, - voxel_size=voxel_size) + voxel_size=_voxel_size) plt.tight_layout() plt.show() diff --git a/nbs/03_vision_augment.ipynb b/nbs/03_vision_augment.ipynb index f1f7c94..463e763 100644 --- a/nbs/03_vision_augment.ipynb +++ b/nbs/03_vision_augment.ipynb @@ -27,7 +27,7 @@ "id": "2d6694aa", "metadata": {}, "outputs": [], - "source": "#| export\nfrom fastai.data.all import *\nfrom fastMONAI.vision_core import *\nimport torchio as tio\nfrom 
monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity" + "source": "#| export\nfrom fastai.data.all import *\nfrom fastMONAI.vision_core import *\nimport torch.nn.functional as F\nimport math\nimport torchio as tio\nfrom monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity" }, { "cell_type": "markdown", @@ -754,13 +754,51 @@ "metadata": {}, "source": "## Augmentation suggestion" }, + { + "cell_type": "code", + "execution_count": null, + "id": "lqet5pabzy", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef _compute_patch_aug_params(patch_size, target_spacing,\n anisotropy_threshold=3.0,\n translation_fraction=0.15):\n \"\"\"Compute geometry-aware augmentation parameters from patch/spacing metadata.\n\n Shared logic used by both suggest_patch_augmentations (CPU/TorchIO) and\n gpu_patch_augmentations (GPU-batched). Extracts rotation degrees, translation\n offsets, and RandomAnisotropy axes from the spatial configuration.\n\n Anisotropy detection: if max(spacing)/min(spacing) >= threshold, rotation\n is restricted to 5 deg out-of-plane and 30 deg in-plane. 
Otherwise 30 deg\n symmetric.\n\n Args:\n patch_size: List/tuple of 3 ints -- patch dimensions.\n target_spacing: List/tuple of 3 floats -- voxel spacing.\n anisotropy_threshold: Ratio threshold for anisotropy detection (default 3.0).\n translation_fraction: Fraction of patch_size for translation (default 0.15).\n\n Returns:\n dict with keys:\n 'degrees': tuple of 3 ints (per-axis max rotation in degrees)\n 'translation': tuple of 3 ints (per-axis translation in voxels)\n 'aniso_axes': tuple of ints (axes where patch_size > 1)\n 'is_aniso': bool (whether spacing is anisotropic)\n \"\"\"\n if len(patch_size) != 3:\n raise ValueError(f\"patch_size must have 3 elements, got {len(patch_size)}\")\n if len(target_spacing) != 3:\n raise ValueError(f\"target_spacing must have 3 elements, got {len(target_spacing)}\")\n\n spacing = list(target_spacing)\n ratio = max(spacing) / min(spacing)\n is_aniso = ratio >= anisotropy_threshold\n aniso_axis = spacing.index(max(spacing)) if is_aniso else None\n\n if is_aniso:\n degrees = [5, 5, 5]\n degrees[aniso_axis] = 30\n degrees = tuple(degrees)\n else:\n degrees = (30, 30, 30)\n\n translation = tuple(round(p * translation_fraction) for p in patch_size)\n aniso_axes = tuple(i for i in range(3) if patch_size[i] > 1)\n\n return {\n 'degrees': degrees,\n 'translation': translation,\n 'aniso_axes': aniso_axes,\n 'is_aniso': is_aniso,\n }" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "oef9rtzvbw", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef _build_rotation_matrix_3d(angles_rad):\n \"\"\"Build [N, 3, 3] rotation matrices from [N, 3] Euler angles (XYZ extrinsic).\n\n Computes R = Rz @ Ry @ Rx for each sample in the batch.\n\n Args:\n angles_rad: Tensor of shape [N, 3] with rotation angles in radians\n for each axis (x, y, z).\n\n Returns:\n Tensor of shape [N, 3, 3] -- rotation matrices.\n \"\"\"\n cos = torch.cos(angles_rad)\n sin = torch.sin(angles_rad)\n cx, cy, cz = cos[:, 0], cos[:, 1], cos[:, 
2]\n sx, sy, sz = sin[:, 0], sin[:, 1], sin[:, 2]\n\n # R = Rz @ Ry @ Rx (combined formula)\n r00 = cy * cz; r01 = sx * sy * cz - cx * sz; r02 = cx * sy * cz + sx * sz\n r10 = cy * sz; r11 = sx * sy * sz + cx * cz; r12 = cx * sy * sz - sx * cz\n r20 = -sy; r21 = sx * cy; r22 = cx * cy\n\n R = torch.stack([\n torch.stack([r00, r01, r02], dim=-1),\n torch.stack([r10, r11, r12], dim=-1),\n torch.stack([r20, r21, r22], dim=-1),\n ], dim=-2)\n return R" + }, + { + "cell_type": "markdown", + "id": "qjc9baqro1", + "metadata": {}, + "source": "## GPU patch augmentation" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ts2kolv83z", + "metadata": {}, + "outputs": [], + "source": "#| export\nclass GpuPatchAugmentation:\n \"\"\"GPU-batched augmentation for patch-based training.\n\n Operates on [B, C, D, H, W] tensors already on GPU. All operations run\n under torch.no_grad() since augmentation does not need gradient tracking.\n\n Transform order: spatial (affine, anisotropy, flip) then intensity\n (gamma, intensity_scale, noise, blur). Spatial transforms apply the\n same parameters to both image and mask. Intensity transforms skip the mask.\n\n Each transform is controlled by a parameter dict with at minimum a 'p' key\n for per-sample probability. Pass None to disable a transform.\n\n Args:\n affine: dict with keys 'scales', 'degrees', 'translation',\n 'default_pad_value', 'p'. None to disable.\n anisotropy: dict with keys 'axes', 'downsampling', 'p'. None to disable.\n flip: dict with keys 'axes', 'p'. None to disable.\n gamma: dict with keys 'log_gamma', 'p'. None to disable.\n intensity_scale: dict with keys 'scale_range', 'p'. None to disable.\n noise: dict with keys 'std', 'p'. None to disable.\n blur: dict with keys 'std', 'p'. None to disable.\n\n Example::\n\n >>> gpu_aug = GpuPatchAugmentation(\n ... affine={'scales': (0.7, 1.4), 'degrees': (30, 30, 30),\n ... 'translation': (25, 25, 10), 'default_pad_value': 0., 'p': 0.2},\n ... 
class GpuPatchAugmentation:
    """GPU-batched augmentation for patch-based training.

    Operates on [B, C, D, H, W] tensors already on GPU. All operations run
    under torch.no_grad() since augmentation does not need gradient tracking.

    Transform order: spatial (affine, anisotropy, flip) then intensity
    (gamma, intensity_scale, noise, blur). Spatial transforms apply the
    same parameters to both image and mask. Intensity transforms skip the mask.

    Each transform is controlled by a parameter dict with at minimum a 'p' key
    for per-sample probability. Pass None to disable a transform.

    NOTE(review): several sub-transforms (_apply_anisotropy, _apply_flip,
    _apply_gamma) write into the input tensor in place, so the caller's
    tensor may be mutated -- confirm this is intended for all call sites.

    Args:
        affine: dict with keys 'scales', 'degrees', 'translation',
            'default_pad_value', 'p'. None to disable.
        anisotropy: dict with keys 'axes', 'downsampling', 'p'. None to disable.
        flip: dict with keys 'axes', 'p'. None to disable.
        gamma: dict with keys 'log_gamma', 'p'. None to disable.
        intensity_scale: dict with keys 'scale_range', 'p'. None to disable.
        noise: dict with keys 'std', 'p'. None to disable.
        blur: dict with keys 'std', 'p'. None to disable.

    Example::

        >>> gpu_aug = GpuPatchAugmentation(
        ...     affine={'scales': (0.7, 1.4), 'degrees': (30, 30, 30),
        ...             'translation': (25, 25, 10), 'default_pad_value': 0., 'p': 0.2},
        ...     gamma={'log_gamma': (-0.3, 0.3), 'p': 0.3},
        ...     flip={'axes': (0, 1, 2), 'p': 0.5},
        ... )
        >>> img_aug, mask_aug = gpu_aug(img_gpu, mask_gpu)
    """

    def __init__(self, affine=None, anisotropy=None, flip=None,
                 gamma=None, intensity_scale=None, noise=None, blur=None):
        # Each attribute is either a config dict (with 'p') or None (disabled).
        self.affine = affine
        self.anisotropy = anisotropy
        self.flip = flip
        self.gamma = gamma
        self.intensity_scale = intensity_scale
        self.noise = noise
        self.blur = blur

    def __call__(self, img, mask=None):
        """Apply GPU augmentations to a batch.

        Args:
            img: Tensor [B, C, D, H, W] (float).
            mask: Tensor [B, C, D, H, W] (float), or None.

        Returns:
            Tuple (img, mask). mask is None if input was None.
        """
        with torch.no_grad():
            # Spatial transforms (same params for img and mask)
            if self.affine is not None:
                img, mask = self._apply_affine(img, mask)
            if self.anisotropy is not None:
                img, mask = self._apply_anisotropy(img, mask)
            if self.flip is not None:
                img, mask = self._apply_flip(img, mask)
            # Intensity transforms (img only)
            if self.gamma is not None:
                img = self._apply_gamma(img)
            if self.intensity_scale is not None:
                img = self._apply_intensity_scale(img)
            if self.noise is not None:
                img = self._apply_noise(img)
            if self.blur is not None:
                img = self._apply_blur(img)
        return img, mask

    def _apply_affine(self, img, mask):
        """Batched random affine via F.affine_grid + F.grid_sample.

        NOTE(review): theta maps output coordinates to input coordinates, so
        the applied image transform is the inverse of S@R -- acceptable for
        symmetric random augmentation, but not for reproducing a forward affine.
        """
        cfg = self.affine
        B = img.shape[0]
        device = img.device
        dtype = img.dtype

        # Per-sample probability
        do_tfm = torch.rand(B, device=device) < cfg['p']
        if not do_tfm.any():
            return img, mask

        # Start with identity theta [B, 3, 4]
        theta = torch.zeros(B, 3, 4, device=device, dtype=dtype)
        theta[:, 0, 0] = 1.0
        theta[:, 1, 1] = 1.0
        theta[:, 2, 2] = 1.0

        idx = do_tfm.nonzero(as_tuple=True)[0]
        n = idx.shape[0]

        # Random scales per axis
        s_lo, s_hi = cfg['scales']
        scales = torch.empty(n, 3, device=device, dtype=dtype).uniform_(s_lo, s_hi)

        # Random rotation angles (degrees -> radians)
        degrees = cfg['degrees']
        if not isinstance(degrees, (list, tuple)):
            degrees = (degrees, degrees, degrees)
        angles_deg = torch.stack([
            torch.empty(n, device=device, dtype=dtype).uniform_(-degrees[0], degrees[0]),
            torch.empty(n, device=device, dtype=dtype).uniform_(-degrees[1], degrees[1]),
            torch.empty(n, device=device, dtype=dtype).uniform_(-degrees[2], degrees[2]),
        ], dim=1)  # [n, 3]
        angles_rad = angles_deg * (math.pi / 180.0)

        # Build rotation matrices [n, 3, 3]
        R = _build_rotation_matrix_3d(angles_rad)

        # Scale matrix: S @ R -> [n, 3, 3]
        S = torch.diag_embed(scales)  # [n, 3, 3]
        SR = S @ R  # [n, 3, 3]

        # Random translation (voxels -> normalized [-1, 1] coords)
        translation = cfg['translation']
        if not isinstance(translation, (list, tuple)):
            translation = (translation, translation, translation)
        spatial_size = img.shape[2:]  # (D, H, W)
        # NOTE(review): F.affine_grid's grid coordinate order is (x, y, z)
        # which grid_sample maps to (W, H, D), while spatial_size and
        # translation here are indexed as (D, H, W). translation[i] is
        # normalized by spatial_size[i] but lands on grid coordinate i, so
        # for non-cubic patches the per-axis translation (and degrees)
        # restrictions may be applied to the wrong physical axis -- verify
        # the intended axis correspondence, possibly by reversing the order.
        t_norm = torch.stack([
            torch.empty(n, device=device, dtype=dtype).uniform_(
                -translation[i], translation[i]
            ) * 2.0 / spatial_size[i]
            for i in range(3)
        ], dim=1)  # [n, 3]

        # Assemble theta for active samples
        theta[idx, :3, :3] = SR
        theta[idx, :3, 3] = t_norm

        # Apply grid_sample
        grid = F.affine_grid(theta, img.shape, align_corners=False)
        pad_val = cfg.get('default_pad_value', 0.)
        # For non-zero padding, shift values, sample, shift back
        if pad_val != 0.:
            img = img - pad_val
        img = F.grid_sample(img, grid, mode='bilinear',
                            padding_mode='zeros', align_corners=False)
        if pad_val != 0.:
            img = img + pad_val

        if mask is not None:
            # Nearest sampling keeps label values discrete.
            mask = F.grid_sample(mask.float(), grid, mode='nearest',
                                 padding_mode='zeros', align_corners=False)
        return img, mask

    def _apply_anisotropy(self, img, mask):
        """Per-sample anisotropy simulation via F.interpolate.

        Downsample along a random axis with nearest interpolation,
        then upsample back with trilinear (matches TorchIO behavior).
        Only affects img, not mask (anisotropy is intensity degradation).
        Mutates img in place via slice assignment.
        """
        cfg = self.anisotropy
        B = img.shape[0]
        device = img.device
        ds_lo, ds_hi = cfg['downsampling']
        axes = cfg['axes']

        for i in range(B):
            if torch.rand(1, device=device).item() >= cfg['p']:
                continue
            # Pick random axis and downsampling factor
            axis_idx = torch.randint(len(axes), (1,), device=device).item()
            axis = axes[axis_idx]
            factor = torch.empty(1, device=device).uniform_(ds_lo, ds_hi).item()

            sample = img[i:i+1]  # [1, C, D, H, W]
            orig_size = list(sample.shape[2:])
            down_size = list(orig_size)
            down_size[axis] = max(1, round(orig_size[axis] / factor))

            # Downsample with nearest, upsample with trilinear
            down = F.interpolate(sample, size=down_size, mode='nearest')
            up = F.interpolate(down, size=orig_size, mode='trilinear',
                               align_corners=False)
            img[i:i+1] = up

        return img, mask

    def _apply_flip(self, img, mask):
        """Per-sample random flip along configured axes (in place).

        Each configured axis is flipped independently with probability p,
        and the same flips are applied to the mask so image/mask stay aligned.
        """
        cfg = self.flip
        B = img.shape[0]
        device = img.device
        axes = cfg['axes']
        p = cfg['p']

        for i in range(B):
            # Each axis is independently flipped with probability p
            flip_dims = []
            for axis in axes:
                if torch.rand(1, device=device).item() < p:
                    # axis 0 -> dim 2 (D), axis 1 -> dim 3 (H), axis 2 -> dim 4 (W)
                    # but img[i] is [C, D, H, W], so axis 0 -> dim 1, etc.
                    flip_dims.append(axis + 2)  # +2 for batch and channel dims
            if flip_dims:
                img[i] = torch.flip(img[i], dims=[d - 1 for d in flip_dims])  # -1 since no batch dim
                if mask is not None:
                    mask[i] = torch.flip(mask[i], dims=[d - 1 for d in flip_dims])

        return img, mask

    def _apply_gamma(self, img):
        """Batched gamma correction with per-sample random gamma (in place).

        NOTE(review): active samples are clamped to min=0 before pow, so any
        negative intensities (e.g. z-normalized input) are destroyed for
        those samples -- confirm inputs are non-negative or rescale first.
        """
        cfg = self.gamma
        B = img.shape[0]
        device = img.device
        dtype = img.dtype
        log_lo, log_hi = cfg['log_gamma']

        active = torch.rand(B, device=device) < cfg['p']
        if not active.any():
            return img

        # Only apply clamp + pow to active samples (clamp destroys negatives)
        active_idx = active.nonzero(as_tuple=True)[0]
        log_gamma = torch.empty(active_idx.shape[0], device=device, dtype=dtype).uniform_(log_lo, log_hi)
        gamma = torch.exp(log_gamma).view(-1, 1, 1, 1, 1)
        img[active_idx] = img[active_idx].clamp(min=0).pow(gamma)
        return img

    def _apply_intensity_scale(self, img):
        """Batched intensity scaling with per-sample random factors.

        Inactive samples get a scale of exactly 1, so the whole batch can be
        multiplied in one vectorized op (returns a new tensor).
        """
        cfg = self.intensity_scale
        B = img.shape[0]
        device = img.device
        dtype = img.dtype
        s_lo, s_hi = cfg['scale_range']

        # Per-sample scale (inactive get scale=1)
        scale = torch.empty(B, device=device, dtype=dtype).uniform_(s_lo, s_hi)
        active = torch.rand(B, device=device) < cfg['p']
        scale = torch.where(active, scale, torch.ones_like(scale))

        img = img * scale.view(B, 1, 1, 1, 1)
        return img

    def _apply_noise(self, img):
        """Batched additive Gaussian noise with per-sample random std.

        'std' may be a scalar (fixed std) or a (lo, hi) range sampled per
        sample; inactive samples get std=0 so they receive no noise.
        """
        cfg = self.noise
        B = img.shape[0]
        device = img.device
        dtype = img.dtype

        std_val = cfg['std']
        if isinstance(std_val, (list, tuple)):
            std_lo, std_hi = std_val
            per_std = torch.empty(B, device=device, dtype=dtype).uniform_(std_lo, std_hi)
        else:
            per_std = torch.full((B,), std_val, device=device, dtype=dtype)

        # Zero std for inactive samples
        active = torch.rand(B, device=device) < cfg['p']
        per_std = torch.where(active, per_std, torch.zeros_like(per_std))

        noise = torch.randn_like(img) * per_std.view(B, 1, 1, 1, 1)
        img = img + noise
        return img

    def _apply_blur(self, img):
        """Batched separable 3D Gaussian blur via F.conv3d with groups trick.

        Per-sample kernels are packed into a grouped convolution with
        groups=B*C so one conv call blurs every sample/channel with its own
        kernel; inactive samples get a delta kernel (identity).

        NOTE(review): when 'std' is a scalar, sigma is drawn from
        uniform(0, std); an active sample drawn with sigma ~ 0 can underflow
        the Gaussian kernel to all-zeros and yield NaNs after normalization --
        consider a small sigma floor.
        """
        cfg = self.blur
        B, C, D, H, W = img.shape
        device = img.device
        dtype = img.dtype

        std_val = cfg['std']
        if isinstance(std_val, (list, tuple)):
            std_lo, std_hi = std_val
        else:
            std_lo, std_hi = 0.0, std_val

        # Per-sample sigma
        sigma = torch.empty(B, device=device, dtype=dtype).uniform_(std_lo, std_hi)
        active = torch.rand(B, device=device) < cfg['p']
        if not active.any():
            return img

        # Fixed kernel size from max sigma
        max_sigma = max(std_hi, 0.01)
        kernel_radius = int(math.ceil(3 * max_sigma))
        kernel_size = 2 * kernel_radius + 1

        # Build per-sample 1D Gaussian kernels [B, kernel_size]
        x = torch.arange(-kernel_radius, kernel_radius + 1,
                         device=device, dtype=dtype)
        # Avoid division by zero for sigma=0
        safe_sigma = torch.where(active, sigma, torch.ones_like(sigma))
        kernels = torch.exp(-x.unsqueeze(0)**2 / (2 * safe_sigma.unsqueeze(1)**2))
        kernels = kernels / kernels.sum(dim=1, keepdim=True)

        # For inactive samples, use delta kernel
        delta = torch.zeros(B, kernel_size, device=device, dtype=dtype)
        delta[:, kernel_radius] = 1.0
        kernels = torch.where(active.unsqueeze(1), kernels, delta)

        # Expand kernels for all channels: [B*C, kernel_size]
        kernels_bc = kernels.unsqueeze(1).expand(B, C, kernel_size).reshape(B * C, kernel_size)

        # Reshape img for grouped convolution: [1, B*C, D, H, W]
        img_grouped = img.reshape(1, B * C, D, H, W)

        # Separable 3D convolution: D-axis, H-axis, W-axis
        pad = kernel_radius

        # D-axis: kernel shape [B*C, 1, K, 1, 1]
        k_d = kernels_bc.reshape(B * C, 1, kernel_size, 1, 1)
        img_grouped = F.pad(img_grouped, (0, 0, 0, 0, pad, pad), mode='replicate')
        img_grouped = F.conv3d(img_grouped, k_d, groups=B * C)

        # H-axis: kernel shape [B*C, 1, 1, K, 1]
        k_h = kernels_bc.reshape(B * C, 1, 1, kernel_size, 1)
        img_grouped = F.pad(img_grouped, (0, 0, pad, pad, 0, 0), mode='replicate')
        img_grouped = F.conv3d(img_grouped, k_h, groups=B * C)

        # W-axis: kernel shape [B*C, 1, 1, 1, K]
        k_w = kernels_bc.reshape(B * C, 1, 1, 1, kernel_size)
        img_grouped = F.pad(img_grouped, (pad, pad, 0, 0, 0, 0), mode='replicate')
        img_grouped = F.conv3d(img_grouped, k_w, groups=B * C)

        return img_grouped.reshape(B, C, D, H, W)

    def __repr__(self):
        # Summarize only the enabled transforms and their probabilities.
        parts = []
        for name in ['affine', 'anisotropy', 'flip', 'gamma',
                     'intensity_scale', 'noise', 'blur']:
            cfg = getattr(self, name)
            if cfg is not None:
                parts.append(f"{name}(p={cfg['p']})")
        return f"GpuPatchAugmentation({', '.join(parts)})"
affine_p,\n }\n\n aniso_cfg = None\n if len(params['aniso_axes']) > 0:\n aniso_cfg = {\n 'axes': params['aniso_axes'],\n 'downsampling': (1.5, 4),\n 'p': anisotropy_p,\n }\n\n return GpuPatchAugmentation(\n affine=affine_cfg,\n anisotropy=aniso_cfg,\n flip={'axes': (0, 1, 2), 'p': flip_p},\n gamma={'log_gamma': (-0.3, 0.3), 'p': gamma_p},\n intensity_scale={'scale_range': (0.75, 1.25), 'p': intensity_scale_p},\n noise={'std': 0.1, 'p': noise_p},\n blur={'std': (0.5, 1.0), 'p': blur_p},\n )" + }, { "cell_type": "code", "execution_count": null, "id": "t6hak044rc", "metadata": {}, "outputs": [], - "source": "#| export\ndef suggest_patch_augmentations(patch_size, target_spacing,\n anisotropy_threshold=3.0,\n translation_fraction=0.15):\n \"\"\"Suggest patch-based augmentations with nnU-Net-inspired defaults.\n\n Derives rotation degrees, translation, and RandomAnisotropy axes from\n patch geometry and voxel spacing. Returns a list of fastMONAI transform\n instances ready for the ``patch_tfms`` parameter in MedPatchDataLoaders.\n\n Anisotropy detection: if max(spacing)/min(spacing) >= threshold, rotation\n is restricted to 5 deg out-of-plane and 30 deg in-plane. Otherwise 30 deg\n symmetric. 
def suggest_patch_augmentations(patch_size, target_spacing,
                                anisotropy_threshold=3.0,
                                translation_fraction=0.15):
    """Suggest patch-based augmentations with nnU-Net-inspired defaults.

    Translates patch geometry and voxel spacing into rotation degrees,
    translation offsets, and RandomAnisotropy axes, returning fastMONAI
    transform instances ready for the ``patch_tfms`` parameter in
    MedPatchDataLoaders.

    Anisotropy detection: if max(spacing)/min(spacing) >= threshold, rotation
    is restricted to 5 deg out-of-plane and 30 deg in-plane. Otherwise 30 deg
    symmetric. Translation is patch_size * fraction per axis.

    Args:
        patch_size: List/tuple of 3 ints -- patch dimensions.
        target_spacing: List/tuple of 3 floats -- voxel spacing.
        anisotropy_threshold: Ratio threshold for anisotropy detection (default 3.0).
        translation_fraction: Fraction of patch_size for translation (default 0.15).

    Returns:
        list: fastMONAI transform instances (7 normally, 6 if RandomAnisotropy omitted).

    Example::

        >>> patch_tfms = suggest_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5])
        >>> dls = MedPatchDataLoaders.from_df(..., patch_tfms=patch_tfms)
    """
    geom = _compute_patch_aug_params(
        patch_size, target_spacing, anisotropy_threshold, translation_fraction
    )

    # TorchIO expands a scalar into symmetric per-axis ranges, so collapse
    # the isotropic (30, 30, 30) tuple back to the scalar 30.
    rotation = geom['degrees'] if geom['is_aniso'] else 30

    tfms = [
        RandomAffine(scales=(0.7, 1.4), degrees=rotation,
                     translation=geom['translation'],
                     default_pad_value=0., p=0.2),
    ]

    # Skip anisotropy simulation when every patch axis is a singleton.
    if geom['aniso_axes']:
        tfms.append(RandomAnisotropy(axes=geom['aniso_axes'],
                                     downsampling=(1.5, 4), p=0.25))

    tfms += [
        RandomGamma(log_gamma=(-0.3, 0.3), p=0.3),
        RandomIntensityScale(scale_range=(0.75, 1.25), p=0.1),
        RandomNoise(std=0.1, p=0.1),
        RandomBlur(std=(0.5, 1.0), p=0.2),
        RandomFlip(p=0.5),
    ]

    return tfms
"from fastcore.test import test_eq, test_fail\n\n# Isotropic case\ntfms = suggest_patch_augmentations([128, 128, 128], [1.0, 1.0, 1.0])\ntest_eq(len(tfms), 7)\ntest_eq(type(tfms[0]), RandomAffine)\ntest_eq(type(tfms[-1]), RandomFlip)\n\n# Anisotropic case (axis 2 thick): degrees=(5, 5, 30) -> (-5, 5, -5, 5, -30, 30)\ntfms = suggest_patch_augmentations([128, 128, 32], target_spacing=[0.5, 0.5, 1.5])\ntest_eq(len(tfms), 7)\naff = tfms[0].tio_transform\ntest_eq(aff.degrees, (-5, 5, -5, 5, -30, 30))\n\n# Anisotropic case (axis 0 thick): degrees=(30, 5, 5) -> (-30, 30, -5, 5, -5, 5)\ntfms = suggest_patch_augmentations([32, 128, 128], target_spacing=[3.0, 0.5, 0.5])\naff = tfms[0].tio_transform\ntest_eq(aff.degrees, (-30, 30, -5, 5, -5, 5))\n\n# Isotropic spacing -> symmetric degrees: 30 -> (-30, 30, -30, 30, -30, 30)\ntfms = suggest_patch_augmentations([64, 64, 64], [1.0, 1.0, 1.0])\naff = tfms[0].tio_transform\ntest_eq(aff.degrees, (-30, 30, -30, 30, -30, 30))\n\n# 2D-like patch [128, 128, 1]\ntfms = suggest_patch_augmentations([128, 128, 1], [1.0, 1.0, 1.0])\naniso_tfm = tfms[1]\ntest_eq(type(aniso_tfm), RandomAnisotropy)\ntest_eq(aniso_tfm.add_anisotropy.axes, (0, 1))\n\n# All dims 1 -> RandomAnisotropy omitted\ntfms = suggest_patch_augmentations([1, 1, 1], [1.0, 1.0, 1.0])\ntest_eq(len(tfms), 6)\ntest_eq(all(not isinstance(t, RandomAnisotropy) for t in tfms), True)\n\n# Wrong input lengths\ntest_fail(lambda: suggest_patch_augmentations([128, 128], [1.0, 1.0, 1.0]))\ntest_fail(lambda: suggest_patch_augmentations([128, 128, 128], [1.0, 1.0]))\n\n# All returned transforms have .tio_transform\ntfms = suggest_patch_augmentations([128, 128, 64], [1.0, 1.0, 1.0])\nfor t in tfms:\n assert hasattr(t, 'tio_transform'), f\"{type(t).__name__} missing .tio_transform\"" }, + { + "cell_type": "code", + "execution_count": null, + "id": "bcc6ucwhfa", + "metadata": {}, + "outputs": [], + "source": "# Tests for _compute_patch_aug_params\nparams = _compute_patch_aug_params([128, 128, 
# Notebook test cell: covers the shared param helper, the rotation-matrix
# builder, and GpuPatchAugmentation end-to-end at p=0 (identity), p=1
# (shape preservation), intensity-only (mask untouched), and mask=None.

# Tests for _compute_patch_aug_params
params = _compute_patch_aug_params([128, 128, 128], [1.0, 1.0, 1.0])
test_eq(params['degrees'], (30, 30, 30))
test_eq(params['is_aniso'], False)
test_eq(params['aniso_axes'], (0, 1, 2))

params = _compute_patch_aug_params([128, 128, 32], [0.5, 0.5, 1.5])
test_eq(params['degrees'], (5, 5, 30))
test_eq(params['is_aniso'], True)

params = _compute_patch_aug_params([128, 128, 1], [1.0, 1.0, 1.0])
test_eq(params['aniso_axes'], (0, 1))

test_fail(lambda: _compute_patch_aug_params([128, 128], [1.0, 1.0, 1.0]))
test_fail(lambda: _compute_patch_aug_params([128, 128, 128], [1.0, 1.0]))

# Tests for _build_rotation_matrix_3d
zero_angles = torch.zeros(3, 3)
R = _build_rotation_matrix_3d(zero_angles)
test_eq(R.shape, (3, 3, 3))
for i in range(3):
    assert torch.allclose(R[i], torch.eye(3), atol=1e-6), "Zero angles should give identity"

# Tests for GpuPatchAugmentation: p=0 -> identity
gpu_aug_noop = GpuPatchAugmentation(
    affine={'scales': (0.7, 1.4), 'degrees': (30, 30, 30),
            'translation': (10, 10, 10), 'default_pad_value': 0., 'p': 0.0},
    anisotropy={'axes': (0, 1, 2), 'downsampling': (1.5, 4), 'p': 0.0},
    flip={'axes': (0, 1, 2), 'p': 0.0},
    gamma={'log_gamma': (-0.3, 0.3), 'p': 0.0},
    intensity_scale={'scale_range': (0.75, 1.25), 'p': 0.0},
    noise={'std': 0.1, 'p': 0.0},
    blur={'std': (0.5, 1.0), 'p': 0.0},
)
test_img = torch.randn(2, 1, 16, 16, 16)
test_mask = torch.zeros(2, 1, 16, 16, 16)
test_mask[:, :, 4:12, 4:12, 4:12] = 1.0
out_img, out_mask = gpu_aug_noop(test_img, test_mask)
test_eq(torch.equal(out_img, test_img), True)
test_eq(torch.equal(out_mask, test_mask), True)

# Tests for GpuPatchAugmentation: p=1 -> shapes preserved
gpu_aug_all = GpuPatchAugmentation(
    affine={'scales': (0.7, 1.4), 'degrees': (30, 30, 30),
            'translation': (2, 2, 2), 'default_pad_value': 0., 'p': 1.0},
    anisotropy={'axes': (0, 1, 2), 'downsampling': (1.5, 4), 'p': 1.0},
    flip={'axes': (0, 1, 2), 'p': 1.0},
    gamma={'log_gamma': (-0.3, 0.3), 'p': 1.0},
    intensity_scale={'scale_range': (0.75, 1.25), 'p': 1.0},
    noise={'std': 0.1, 'p': 1.0},
    blur={'std': (0.5, 1.0), 'p': 1.0},
)
out_img, out_mask = gpu_aug_all(test_img.clone(), test_mask.clone())
test_eq(out_img.shape, test_img.shape)
test_eq(out_mask.shape, test_mask.shape)

# Intensity-only augmentation: mask unchanged
gpu_aug_intensity = GpuPatchAugmentation(
    gamma={'log_gamma': (-0.3, 0.3), 'p': 1.0},
    intensity_scale={'scale_range': (0.75, 1.25), 'p': 1.0},
    noise={'std': 0.1, 'p': 1.0},
    blur={'std': (0.5, 1.0), 'p': 1.0},
)
mask_copy = test_mask.clone()
_, out_mask = gpu_aug_intensity(test_img.clone(), mask_copy)
test_eq(torch.equal(out_mask, test_mask), True)

# None mask handling
out_img, out_mask = gpu_aug_all(test_img.clone(), None)
test_eq(out_mask, None)
test_eq(out_img.shape, test_img.shape)
asym_mask.clone())\ntest_eq(out_img.shape, asym_img.shape)\n# After flip on axis 0 (D-axis), first slice -> last slice\ntest_eq(out_img[0, 0, -1, :, :].sum() > 0, True)\ntest_eq(out_mask[0, 0, -1, :, :].sum() > 0, True)" + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/08_dataset_info.ipynb b/nbs/08_dataset_info.ipynb index 5be7e82..1f27b71 100644 --- a/nbs/08_dataset_info.ipynb +++ b/nbs/08_dataset_info.ipynb @@ -27,7 +27,7 @@ "id": "027f016a-a80c-4842-b9dc-0bddb358a00c", "metadata": {}, "outputs": [], - "source": "#| export\nfrom fastMONAI.vision_core import *\nfrom fastMONAI.vision_plot import find_max_slice\n\nfrom sklearn.utils.class_weight import compute_class_weight\nfrom concurrent.futures import ThreadPoolExecutor\nfrom pathlib import Path\nimport pandas as pd\nimport numpy as np\nimport torch\nimport glob\nimport hashlib\nimport os\nimport pickle\nimport matplotlib.pyplot as plt" + "source": "#| export\nfrom fastMONAI.vision_core import *\nfrom fastMONAI.vision_plot import find_max_slice\n\nfrom sklearn.utils.class_weight import compute_class_weight\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\nfrom tqdm.auto import tqdm\nfrom pathlib import Path\nimport torchio as tio\nimport pandas as pd\nimport numpy as np\nimport torch\nimport glob\nimport hashlib\nimport os\nimport pickle\nimport matplotlib.pyplot as plt" }, { "cell_type": "code", @@ -140,6 +140,22 @@ " return patch_size" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "mbn5svtmzkh", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef preprocess_dataset(df, img_col, mask_col=None, output_dir='preprocessed',\n target_spacing=None, apply_reorder=True, transforms=None,\n max_workers=4, skip_existing=True):\n \"\"\"Preprocess dataset to disk and update DataFrame path columns in-place.\n\n Processes images (and optionally masks) through a transform pipeline,\n saves to output_dir, then updates df[img_col] and df[mask_col] in-place\n to 
point to the preprocessed files.\n\n Transform pipeline order:\n CopyAffine (if masks) -> ToCanonical (if apply_reorder)\n -> Resample (if target_spacing) -> user transforms\n\n Args:\n df: DataFrame with file paths.\n img_col: Column name for image paths.\n mask_col: Optional column name for mask paths.\n output_dir: Output directory. Creates images/ and masks/ subdirectories.\n target_spacing: Target voxel spacing for resampling (e.g., [1.0, 1.0, 1.0]).\n apply_reorder: Whether to reorder to RAS+ canonical orientation.\n transforms: Additional TorchIO or fastMONAI transforms to apply after\n reordering and resampling.\n max_workers: Number of parallel workers. Each worker loads a full 3D\n volume into memory, so reduce for large volumes.\n skip_existing: Skip files that already exist on disk (with size > 0).\n \"\"\"\n # Input validation\n if len(df) == 0:\n raise ValueError(\"DataFrame is empty\")\n if img_col not in df.columns:\n raise ValueError(f\"Column '{img_col}' not found in DataFrame\")\n if mask_col is not None and mask_col not in df.columns:\n raise ValueError(f\"Column '{mask_col}' not found in DataFrame\")\n\n img_names = [Path(p).name for p in df[img_col]]\n if len(set(img_names)) != len(img_names):\n dupes = set(n for n in img_names if img_names.count(n) > 1)\n raise ValueError(f\"Duplicate image file names: {dupes}\")\n\n if mask_col is not None:\n mask_names = [Path(p).name for p in df[mask_col]]\n if len(set(mask_names)) != len(mask_names):\n dupes = set(n for n in mask_names if mask_names.count(n) > 1)\n raise ValueError(f\"Duplicate mask file names: {dupes}\")\n\n # Build transform pipeline (canonical order)\n all_tfms = []\n if mask_col is not None:\n all_tfms.append(tio.CopyAffine(target='image'))\n if apply_reorder:\n all_tfms.append(tio.ToCanonical())\n if target_spacing is not None:\n all_tfms.append(tio.Resample(target_spacing))\n if transforms:\n all_tfms.extend([getattr(t, 'tio_transform', t) for t in transforms])\n pipeline = 
tio.Compose(all_tfms) if all_tfms else None\n\n # Create output directories\n output_dir = Path(output_dir)\n img_dir = output_dir / 'images'\n img_dir.mkdir(parents=True, exist_ok=True)\n if mask_col is not None:\n mask_dir = output_dir / 'masks'\n mask_dir.mkdir(parents=True, exist_ok=True)\n\n # Build work items, filtering skip_existing\n work_items = []\n skipped = 0\n for idx in range(len(df)):\n img_path = df[img_col].iloc[idx]\n out_img = img_dir / Path(img_path).name\n\n mask_path = df[mask_col].iloc[idx] if mask_col is not None else None\n out_mask = (mask_dir / Path(mask_path).name) if mask_col is not None else None\n\n if skip_existing:\n img_ok = out_img.exists() and out_img.stat().st_size > 0\n mask_ok = out_mask is None or (out_mask.exists() and out_mask.stat().st_size > 0)\n if img_ok and mask_ok:\n skipped += 1\n continue\n\n work_items.append({\n 'idx': idx, 'img_path': img_path, 'mask_path': mask_path,\n 'out_img': out_img, 'out_mask': out_mask,\n })\n\n # Process cases\n processed = 0\n failed = 0\n failed_cases = []\n\n def _process_case(item):\n subject_dict = {'image': tio.ScalarImage(item['img_path'])}\n if item['mask_path'] is not None:\n subject_dict['mask'] = tio.LabelMap(item['mask_path'])\n\n subject = tio.Subject(**subject_dict)\n if pipeline is not None:\n subject = pipeline(subject)\n\n # Atomic write: save to temp file (with valid NIfTI extension), then rename\n out_img = item['out_img']\n tmp_img = out_img.parent / f'.tmp_{out_img.name}'\n subject['image'].save(str(tmp_img))\n os.rename(str(tmp_img), str(out_img))\n\n if item['out_mask'] is not None:\n out_mask = item['out_mask']\n tmp_mask = out_mask.parent / f'.tmp_{out_mask.name}'\n subject['mask'].save(str(tmp_mask))\n os.rename(str(tmp_mask), str(out_mask))\n\n if work_items:\n with ThreadPoolExecutor(max_workers=max_workers) as executor:\n futures = {executor.submit(_process_case, item): item for item in work_items}\n for future in tqdm(as_completed(futures), 
total=len(futures),\n desc='Preprocessing'):\n item = futures[future]\n try:\n future.result()\n processed += 1\n except Exception as e:\n failed += 1\n failed_cases.append(Path(item['img_path']).name)\n warnings.warn(f\"Failed to process {item['img_path']}: {e}\")\n\n # Update DataFrame in-place\n df[img_col] = [str(img_dir / Path(p).name) for p in df[img_col]]\n if mask_col is not None:\n df[mask_col] = [str(mask_dir / Path(p).name) for p in df[mask_col]]\n\n print(f\"Preprocessing complete: {processed} processed, {skipped} skipped, {failed} failed\")\n if failed_cases:\n print(f\"Failed cases: {failed_cases}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "rkoedtvhegm", + "metadata": {}, + "outputs": [], + "source": "import tempfile, shutil\nfrom fastcore.test import test_eq, test_fail\n\n_tmp = tempfile.mkdtemp()\n\n# Create synthetic NIfTI files\nfor i in range(3):\n tio.ScalarImage(tensor=torch.randn(1, 10, 10, 10)).save(f'{_tmp}/img_{i}.nii.gz')\n tio.LabelMap(tensor=torch.randint(0, 2, (1, 10, 10, 10))).save(f'{_tmp}/mask_{i}.nii.gz')\n\n# Test 1: Image-only preprocessing\n_df1 = pd.DataFrame({'img': [f'{_tmp}/img_{i}.nii.gz' for i in range(3)]})\n_out1 = f'{_tmp}/out1'\npreprocess_dataset(_df1, img_col='img', output_dir=_out1, apply_reorder=False)\ntest_eq(all(Path(p).exists() for p in _df1['img']), True)\ntest_eq(all('out1/images/' in p for p in _df1['img']), True)\n\n# Test 2: Skip-existing (rerun with original paths pointing to same filenames)\n_df2 = pd.DataFrame({'img': [f'{_tmp}/img_{i}.nii.gz' for i in range(3)]})\npreprocess_dataset(_df2, img_col='img', output_dir=_out1, apply_reorder=False)\n# Should print \"0 processed, 3 skipped\"\n\n# Test 3: With masks\n_df3 = pd.DataFrame({\n 'img': [f'{_tmp}/img_{i}.nii.gz' for i in range(3)],\n 'mask': [f'{_tmp}/mask_{i}.nii.gz' for i in range(3)],\n})\n_out3 = f'{_tmp}/out3'\npreprocess_dataset(_df3, img_col='img', mask_col='mask', output_dir=_out3, 
apply_reorder=False)\ntest_eq(all(Path(p).exists() for p in _df3['img']), True)\ntest_eq(all(Path(p).exists() for p in _df3['mask']), True)\ntest_eq(all('out3/masks/' in p for p in _df3['mask']), True)\n\n# Test 4: Input validation\ntest_fail(lambda: preprocess_dataset(pd.DataFrame(), img_col='img'), contains='empty')\ntest_fail(lambda: preprocess_dataset(pd.DataFrame({'x': [1]}), img_col='img'), contains='not found')\n_df_dup = pd.DataFrame({'img': [f'{_tmp}/img_0.nii.gz', f'{_tmp}/img_0.nii.gz']})\ntest_fail(lambda: preprocess_dataset(_df_dup, img_col='img'), contains='Duplicate')\n\nshutil.rmtree(_tmp)" + }, { "cell_type": "markdown", "id": "74812108-f3eb-4a8d-9f2d-b93132619008", diff --git a/nbs/10_vision_patch.ipynb b/nbs/10_vision_patch.ipynb index eb14276..9eab4de 100644 --- a/nbs/10_vision_patch.ipynb +++ b/nbs/10_vision_patch.ipynb @@ -456,7 +456,10 @@ " patch_tfms: Transforms to apply to extracted patches (training only).\n", " Accepts both fastMONAI wrappers (e.g., RandomAffine, RandomGamma) and\n", " raw TorchIO transforms. fastMONAI wrappers are automatically normalized\n", - " to raw TorchIO for internal use.\n", + " to raw TorchIO for internal use. Mutually exclusive with gpu_augmentation.\n", + " gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation.\n", + " Operates on [B,C,D,H,W] tensors already on GPU, avoiding per-sample CPU\n", + " overhead. Mutually exclusive with patch_tfms. 
Training only.\n", " shuffle: Whether to shuffle subjects and patches.\n", " drop_last: Whether to drop last incomplete batch.\n", " \"\"\"\n", @@ -467,18 +470,20 @@ " config: PatchConfig,\n", " batch_size: int = 4,\n", " patch_tfms: list = None,\n", + " gpu_augmentation=None,\n", " shuffle: bool = True,\n", " drop_last: bool = False\n", " ):\n", " if batch_size <= 0:\n", " raise ValueError(f\"batch_size must be positive, got {batch_size}\")\n", - " \n", + "\n", " self.subjects_dataset = subjects_dataset\n", " self.config = config\n", " self.bs = batch_size\n", " self.shuffle = shuffle\n", " self.drop_last = drop_last\n", " self._device = _get_default_device()\n", + " self.gpu_augmentation = gpu_augmentation\n", "\n", " # Create sampler\n", " self.sampler = create_patch_sampler(config)\n", @@ -517,14 +522,12 @@ " img = batch['image'][tio.DATA] # [B, C, H, W, D]\n", " has_mask = 'mask' in batch\n", "\n", - " # Apply patch transforms if provided\n", + " # Apply CPU patch transforms if provided (per-sample TorchIO loop)\n", " if self.patch_tfms is not None:\n", - " # Apply transforms to each sample in batch\n", " transformed_imgs = []\n", " transformed_masks = [] if has_mask else None\n", "\n", " for i in range(img.shape[0]):\n", - " # Build subject dict with image, and mask if available\n", " subject_dict = {'image': tio.ScalarImage(tensor=batch['image'][tio.DATA][i])}\n", " if has_mask:\n", " subject_dict['mask'] = tio.LabelMap(tensor=batch['mask'][tio.DATA][i])\n", @@ -540,10 +543,19 @@ " else:\n", " mask = batch['mask'][tio.DATA] if has_mask else None\n", "\n", - " # Convert to MedImage/MedMask and move to device\n", - " img = MedImage(img).to(self._device)\n", + " # Move to device\n", + " img = img.to(self._device)\n", + " if mask is not None:\n", + " mask = mask.to(self._device)\n", + "\n", + " # Apply GPU augmentation if provided (batched, on-device)\n", + " if self.gpu_augmentation is not None:\n", + " img, mask = self.gpu_augmentation(img, mask)\n", + "\n", + 
" # Wrap as MedImage/MedMask\n", + " img = MedImage(img)\n", " if mask is not None:\n", - " mask = MedMask(mask).to(self._device)\n", + " mask = MedMask(mask)\n", "\n", " yield img, mask\n", "\n", @@ -612,7 +624,447 @@ "id": "cell-15", "metadata": {}, "outputs": [], - "source": "#| export\nclass MedPatchDataLoaders:\n \"\"\"fastai-compatible DataLoaders for patch-based training with LAZY loading.\n\n This class provides train and validation DataLoaders that work with\n fastai's Learner for patch-based training on 3D medical images.\n\n Memory-efficient: Volumes are loaded on-demand by Queue workers,\n keeping memory usage constant (~150 MB) regardless of dataset size.\n\n Note: Validation uses the same sampling as training (pseudo Dice).\n For true validation metrics, use PatchInferenceEngine with GridSampler\n for full-volume sliding window inference.\n\n Example:\n >>> import torchio as tio\n >>>\n >>> # New pattern: preprocessing params in config (DRY)\n >>> config = PatchConfig(\n ... patch_size=[96, 96, 96],\n ... apply_reorder=True,\n ... target_spacing=[0.5, 0.5, 0.5]\n ... )\n >>> dls = MedPatchDataLoaders.from_df(\n ... df, img_col='image', mask_col='label',\n ... valid_pct=0.2,\n ... patch_config=config,\n ... pre_patch_tfms=[tio.ZNormalization()],\n ... bs=4\n ... 
)\n >>> learn = Learner(dls, model, loss_func=DiceLoss())\n \"\"\"\n\n def __init__(\n self,\n train_dl: MedPatchDataLoader,\n valid_dl: MedPatchDataLoader,\n device: torch.device = None\n ):\n self._train_dl = train_dl\n self._valid_dl = valid_dl\n self._device = device or _get_default_device()\n\n # Move to device\n self._train_dl.to(self._device)\n self._valid_dl.to(self._device)\n\n # Track cleanup state\n self._closed = False\n\n @classmethod\n def from_df(\n cls,\n df: pd.DataFrame,\n img_col: str,\n mask_col: str = None,\n valid_pct: float = 0.2,\n valid_col: str = None,\n patch_config: PatchConfig = None,\n pre_patch_tfms: list = None,\n patch_tfms: list = None,\n apply_reorder: bool = None,\n target_spacing: list = None,\n bs: int = 4,\n seed: int = None,\n device: torch.device = None,\n ensure_affine_consistency: bool = True\n ) -> 'MedPatchDataLoaders':\n \"\"\"Create train/valid DataLoaders from DataFrame with LAZY loading.\n\n Memory-efficient: Only file paths are stored at creation time.\n Volumes are loaded on-demand by Queue workers during training.\n\n Note: Both train and valid use the same sampling strategy from patch_config.\n This gives pseudo Dice during training. For true validation metrics,\n use PatchInferenceEngine with full-volume sliding window inference.\n\n Args:\n df: DataFrame with image paths.\n img_col: Column name for image paths.\n mask_col: Column name for mask paths.\n valid_pct: Fraction of data for validation.\n valid_col: Column name for train/valid split (if pre-defined).\n patch_config: PatchConfig instance. Preprocessing params (apply_reorder,\n target_spacing) can be set here for DRY usage with PatchInferenceEngine.\n pre_patch_tfms: TorchIO transforms applied before patch extraction\n (after reorder/resample). 
Example: [tio.ZNormalization()].\n Accepts both fastMONAI wrappers and raw TorchIO transforms.\n patch_tfms: TorchIO transforms applied to extracted patches (training only).\n apply_reorder: If True, reorder to RAS+ orientation. If None, uses\n patch_config.apply_reorder. Explicit value overrides config.\n target_spacing: Target voxel spacing [x, y, z]. If None, uses\n patch_config.target_spacing. Explicit value overrides config.\n bs: Batch size.\n seed: Random seed for splitting.\n device: Device to use.\n ensure_affine_consistency: If True and mask_col is provided, automatically\n adds tio.CopyAffine(target='image') as the first transform to prevent\n spatial metadata mismatch errors. Defaults to True.\n\n Returns:\n MedPatchDataLoaders instance.\n\n Example:\n >>> # New pattern: config contains preprocessing params\n >>> config = PatchConfig(\n ... patch_size=[96, 96, 96],\n ... apply_reorder=True,\n ... target_spacing=[0.5, 0.5, 0.5],\n ... label_probabilities={0: 0.1, 1: 0.9}\n ... )\n >>> dls = MedPatchDataLoaders.from_df(\n ... df, img_col='image', mask_col='label',\n ... patch_config=config,\n ... pre_patch_tfms=[tio.ZNormalization()],\n ... patch_tfms=[tio.RandomAffine(degrees=10), tio.RandomFlip()],\n ... bs=4\n ... 
)\n >>> # Memory: ~150 MB (queue buffer only)\n \"\"\"\n if patch_config is None:\n patch_config = PatchConfig()\n\n # Use config values, allow explicit overrides for backward compatibility\n _apply_reorder = apply_reorder if apply_reorder is not None else patch_config.apply_reorder\n _target_spacing = target_spacing if target_spacing is not None else patch_config.target_spacing\n\n # Warn if both config and explicit args provided with different values\n _warn_config_override('apply_reorder', patch_config.apply_reorder, apply_reorder)\n _warn_config_override('target_spacing', patch_config.target_spacing, target_spacing)\n\n # Split data\n if valid_col is not None:\n train_df = df[df[valid_col] == False].reset_index(drop=True)\n valid_df = df[df[valid_col] == True].reset_index(drop=True)\n else:\n if seed is not None:\n np.random.seed(seed)\n n = len(df)\n valid_idx = np.random.choice(n, size=int(n * valid_pct), replace=False)\n train_idx = np.setdiff1d(np.arange(n), valid_idx)\n train_df = df.iloc[train_idx].reset_index(drop=True)\n valid_df = df.iloc[valid_idx].reset_index(drop=True)\n\n # Build preprocessing transforms\n all_pre_tfms = []\n\n # Add reorder transform (reorder to RAS+ orientation)\n if _apply_reorder:\n all_pre_tfms.append(tio.ToCanonical())\n\n # Add resample transform\n if _target_spacing is not None:\n all_pre_tfms.append(tio.Resample(_target_spacing))\n\n # Add user-provided transforms (normalize to raw TorchIO transforms)\n if pre_patch_tfms:\n all_pre_tfms.extend(normalize_patch_transforms(pre_patch_tfms))\n\n # Create subjects datasets with lazy loading (paths only, ~0 MB)\n train_subjects = create_subjects_dataset(\n train_df, img_col, mask_col,\n pre_tfms=all_pre_tfms if all_pre_tfms else None,\n ensure_affine_consistency=ensure_affine_consistency\n )\n valid_subjects = create_subjects_dataset(\n valid_df, img_col, mask_col,\n pre_tfms=all_pre_tfms if all_pre_tfms else None,\n ensure_affine_consistency=ensure_affine_consistency\n )\n\n # 
Create DataLoaders (both use same patch_config for consistent sampling)\n train_dl = MedPatchDataLoader(\n train_subjects, patch_config, bs,\n patch_tfms=patch_tfms, shuffle=True, drop_last=True\n )\n valid_dl = MedPatchDataLoader(\n valid_subjects, patch_config, bs,\n patch_tfms=None, # No augmentation for validation\n shuffle=False, drop_last=False\n )\n\n # Create instance and store metadata\n instance = cls(train_dl, valid_dl, device)\n instance._img_col = img_col\n instance._mask_col = mask_col\n instance._pre_patch_tfms = pre_patch_tfms\n instance._apply_reorder = _apply_reorder\n instance._target_spacing = _target_spacing\n instance._ensure_affine_consistency = ensure_affine_consistency\n instance._patch_config = patch_config\n instance._train_source_df = train_df\n instance._valid_source_df = valid_df\n return instance\n\n @property\n def train(self):\n \"\"\"Training DataLoader.\"\"\"\n return self._train_dl\n\n @property\n def valid(self):\n \"\"\"Validation DataLoader.\"\"\"\n return self._valid_dl\n\n @property\n def train_ds(self):\n \"\"\"Training subjects dataset.\"\"\"\n return self._train_dl.subjects_dataset\n\n @property\n def valid_ds(self):\n \"\"\"Validation subjects dataset.\"\"\"\n return self._valid_dl.subjects_dataset\n\n @property\n def device(self):\n \"\"\"Current device.\"\"\"\n return self._device\n\n @property\n def bs(self):\n \"\"\"Batch size.\"\"\"\n return self._train_dl.bs\n\n @property\n def apply_reorder(self):\n \"\"\"Whether reordering to RAS+ is enabled.\"\"\"\n return getattr(self, '_apply_reorder', False)\n\n @property\n def target_spacing(self):\n \"\"\"Target voxel spacing for resampling.\"\"\"\n return getattr(self, '_target_spacing', None)\n\n @property\n def patch_config(self):\n \"\"\"The PatchConfig used for this DataLoaders.\"\"\"\n return getattr(self, '_patch_config', None)\n\n @property\n def split_df(self):\n \"\"\"DataFrame recording train/valid split for reproducibility logging.\"\"\"\n train = 
self._train_source_df.assign(is_valid=False)\n valid = self._valid_source_df.assign(is_valid=True)\n return pd.concat([train, valid], ignore_index=True)\n\n def to(self, device):\n \"\"\"Move DataLoaders to device.\"\"\"\n self._device = device\n self._train_dl.to(device)\n self._valid_dl.to(device)\n return self\n\n def __iter__(self):\n \"\"\"Iterate over training DataLoader.\"\"\"\n return iter(self._train_dl)\n\n def one_batch(self):\n \"\"\"Return one batch from the training DataLoader.\n\n Required for fastai Learner compatibility - used for device\n detection and batch shape validation.\n \"\"\"\n return self._train_dl.one_batch()\n\n def __len__(self):\n \"\"\"Return number of batches in training DataLoader.\"\"\"\n return len(self._train_dl)\n\n def __getitem__(self, idx):\n \"\"\"Get DataLoader by index. Required for fastai Learner compatibility.\n\n Args:\n idx: 0 for training DataLoader, 1 for validation DataLoader.\n\n Returns:\n MedPatchDataLoader instance.\n \"\"\"\n if idx == 0:\n return self._train_dl\n elif idx == 1:\n return self._valid_dl\n else:\n raise IndexError(f\"Index {idx} out of range. 
Use 0 (train) or 1 (valid).\")\n\n def cuda(self):\n \"\"\"Move DataLoaders to CUDA device.\"\"\"\n return self.to(torch.device('cuda'))\n\n def cpu(self):\n \"\"\"Move DataLoaders to CPU.\"\"\"\n return self.to(torch.device('cpu'))\n\n def show_batch(self, dl_idx=0, max_n=6, figsize=None, channel=0,\n slice_index=None, anatomical_plane=0, overlay=False, **kwargs):\n \"\"\"Show a batch of patch samples for visualization.\"\"\"\n\n dl = self[dl_idx]\n x, y = dl.one_batch()\n x = x.cpu()\n if y is not None: y = y.cpu()\n\n nrows = min(x.shape[0], max_n)\n has_mask = y is not None\n\n if overlay and has_mask:\n ncols = x.shape[1]\n else:\n ncols = x.shape[1] + (1 if has_mask else 0)\n\n if figsize is None:\n figsize = (ncols * 3, nrows * 3)\n fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)\n flat_axs = axs.flatten()\n\n imgs, masks_for_overlay, slice_idxs = [], [], []\n for i in range(nrows):\n img = x[i]\n im_channels = [MedImage(c_img[None]) for c_img in img]\n\n if has_mask:\n mask = y[i]\n idx = find_max_slice(mask[0].numpy(), anatomical_plane) if slice_index is None else slice_index\n if overlay:\n masks_for_overlay.extend([MedMask(mask)] * len(im_channels))\n else:\n im_channels.append(MedMask(mask))\n else:\n idx = slice_index\n\n imgs.extend(im_channels)\n slice_idxs.extend([idx] * len(im_channels))\n\n voxel_size = self.target_spacing\n ctxs = [im.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane,\n voxel_size=voxel_size)\n for im, ax, idx in zip(imgs, flat_axs, slice_idxs)]\n\n if overlay and has_mask:\n for mask, ax, idx in zip(masks_for_overlay, flat_axs, slice_idxs):\n mask.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane,\n voxel_size=voxel_size)\n\n plt.tight_layout()\n plt.show()\n\n def new_empty(self):\n \"\"\"Create a new empty version of self for learner export.\n\n Required for fastai Learner.export() compatibility - creates a\n lightweight placeholder that can be pickled without the full 
dataset.\n\n Returns:\n A minimal MedPatchDataLoaders-like object with no data.\n \"\"\"\n class EmptyMedPatchDataLoaders:\n \"\"\"Minimal placeholder for exported learner.\"\"\"\n def __init__(self, device):\n self._device = device\n @property\n def device(self): return self._device\n def to(self, device):\n self._device = device\n return self\n def cpu(self):\n \"\"\"Move to CPU. Required for load_learner compatibility.\"\"\"\n return self.to(torch.device('cpu'))\n\n return EmptyMedPatchDataLoaders(self._device)\n\n def close(self):\n \"\"\"Shut down all DataLoader workers. Safe to call multiple times.\"\"\"\n if self._closed:\n return\n self._closed = True\n if hasattr(self, '_train_dl') and self._train_dl is not None:\n self._train_dl.close()\n if hasattr(self, '_valid_dl') and self._valid_dl is not None:\n self._valid_dl.close()\n\n def __enter__(self):\n return self\n\n def __exit__(self, exc_type, exc_val, exc_tb):\n self.close()\n return False\n\n def __del__(self):\n try:\n self.close()\n except Exception:\n pass" + "source": [ + "#| export\n", + "class MedPatchDataLoaders:\n", + " \"\"\"fastai-compatible DataLoaders for patch-based training with LAZY loading.\n", + "\n", + " This class provides train and validation DataLoaders that work with\n", + " fastai's Learner for patch-based training on 3D medical images.\n", + "\n", + " Memory-efficient: Volumes are loaded on-demand by Queue workers,\n", + " keeping memory usage constant (~150 MB) regardless of dataset size.\n", + "\n", + " Note: Validation uses the same sampling as training (pseudo Dice).\n", + " For true validation metrics, use PatchInferenceEngine with GridSampler\n", + " for full-volume sliding window inference.\n", + "\n", + " Example:\n", + " >>> import torchio as tio\n", + " >>>\n", + " >>> # New pattern: preprocessing params in config (DRY)\n", + " >>> config = PatchConfig(\n", + " ... patch_size=[96, 96, 96],\n", + " ... apply_reorder=True,\n", + " ... 
target_spacing=[0.5, 0.5, 0.5]\n", + " ... )\n", + " >>> dls = MedPatchDataLoaders.from_df(\n", + " ... df, img_col='image', mask_col='label',\n", + " ... valid_pct=0.2,\n", + " ... patch_config=config,\n", + " ... pre_patch_tfms=[tio.ZNormalization()],\n", + " ... bs=4\n", + " ... )\n", + " >>> learn = Learner(dls, model, loss_func=DiceLoss())\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " train_dl: MedPatchDataLoader,\n", + " valid_dl: MedPatchDataLoader,\n", + " device: torch.device = None\n", + " ):\n", + " self._train_dl = train_dl\n", + " self._valid_dl = valid_dl\n", + " self._device = device or _get_default_device()\n", + "\n", + " # Move to device\n", + " self._train_dl.to(self._device)\n", + " self._valid_dl.to(self._device)\n", + "\n", + " # Track cleanup state\n", + " self._closed = False\n", + "\n", + " @classmethod\n", + " def from_df(\n", + " cls,\n", + " df: pd.DataFrame,\n", + " img_col: str,\n", + " mask_col: str = None,\n", + " valid_pct: float = 0.2,\n", + " valid_col: str = None,\n", + " patch_config: PatchConfig = None,\n", + " pre_patch_tfms: list = None,\n", + " patch_tfms: list = None,\n", + " gpu_augmentation=None,\n", + " apply_reorder: bool = None,\n", + " target_spacing: list = None,\n", + " bs: int = 4,\n", + " seed: int = None,\n", + " device: torch.device = None,\n", + " ensure_affine_consistency: bool = True\n", + " ) -> 'MedPatchDataLoaders':\n", + " \"\"\"Create train/valid DataLoaders from DataFrame with LAZY loading.\n", + "\n", + " Memory-efficient: Only file paths are stored at creation time.\n", + " Volumes are loaded on-demand by Queue workers during training.\n", + "\n", + " Note: Both train and valid use the same sampling strategy from patch_config.\n", + " This gives pseudo Dice during training. 
For true validation metrics,\n", + " use PatchInferenceEngine with full-volume sliding window inference.\n", + "\n", + " Args:\n", + " df: DataFrame with image paths.\n", + " img_col: Column name for image paths.\n", + " mask_col: Column name for mask paths.\n", + " valid_pct: Fraction of data for validation.\n", + " valid_col: Column name for train/valid split (if pre-defined).\n", + " patch_config: PatchConfig instance. Preprocessing params (apply_reorder,\n", + " target_spacing) can be set here for DRY usage with PatchInferenceEngine.\n", + " pre_patch_tfms: TorchIO transforms applied before patch extraction\n", + " (after reorder/resample). Example: [tio.ZNormalization()].\n", + " Accepts both fastMONAI wrappers and raw TorchIO transforms.\n", + " patch_tfms: TorchIO transforms applied to extracted patches (training only).\n", + " Mutually exclusive with gpu_augmentation.\n", + " gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation\n", + " (training only). Mutually exclusive with patch_tfms.\n", + " apply_reorder: If True, reorder to RAS+ orientation. If None, uses\n", + " patch_config.apply_reorder. Explicit value overrides config.\n", + " target_spacing: Target voxel spacing [x, y, z]. If None, uses\n", + " patch_config.target_spacing. Explicit value overrides config.\n", + " bs: Batch size.\n", + " seed: Random seed for splitting.\n", + " device: Device to use.\n", + " ensure_affine_consistency: If True and mask_col is provided, automatically\n", + " adds tio.CopyAffine(target='image') as the first transform to prevent\n", + " spatial metadata mismatch errors. Defaults to True.\n", + "\n", + " Returns:\n", + " MedPatchDataLoaders instance.\n", + "\n", + " Example:\n", + " >>> # CPU augmentation path (existing)\n", + " >>> dls = MedPatchDataLoaders.from_df(\n", + " ... df, img_col='image', mask_col='label',\n", + " ... patch_config=config,\n", + " ... patch_tfms=[tio.RandomAffine(degrees=10), tio.RandomFlip()],\n", + " ... 
bs=4\n", + " ... )\n", + " >>>\n", + " >>> # GPU augmentation path (new, faster for long training runs)\n", + " >>> from fastMONAI.vision_augmentation import gpu_patch_augmentations\n", + " >>> gpu_aug = gpu_patch_augmentations(config.patch_size, config.target_spacing)\n", + " >>> dls = MedPatchDataLoaders.from_df(\n", + " ... df, img_col='image', mask_col='label',\n", + " ... patch_config=config,\n", + " ... gpu_augmentation=gpu_aug,\n", + " ... bs=4\n", + " ... )\n", + " \"\"\"\n", + " # Validate mutual exclusivity\n", + " if gpu_augmentation is not None and patch_tfms is not None:\n", + " raise ValueError(\n", + " \"Cannot use both gpu_augmentation and patch_tfms. \"\n", + " \"gpu_augmentation operates on GPU tensors batch-wise, while \"\n", + " \"patch_tfms uses per-sample CPU TorchIO transforms. Choose one.\"\n", + " )\n", + "\n", + " if patch_config is None:\n", + " patch_config = PatchConfig()\n", + "\n", + " # Use config values, allow explicit overrides for backward compatibility\n", + " _apply_reorder = apply_reorder if apply_reorder is not None else patch_config.apply_reorder\n", + " _target_spacing = target_spacing if target_spacing is not None else patch_config.target_spacing\n", + "\n", + " # Warn if both config and explicit args provided with different values\n", + " _warn_config_override('apply_reorder', patch_config.apply_reorder, apply_reorder)\n", + " _warn_config_override('target_spacing', patch_config.target_spacing, target_spacing)\n", + "\n", + " # Split data\n", + " if valid_col is not None:\n", + " train_df = df[df[valid_col] == False].reset_index(drop=True)\n", + " valid_df = df[df[valid_col] == True].reset_index(drop=True)\n", + " else:\n", + " if seed is not None:\n", + " np.random.seed(seed)\n", + " n = len(df)\n", + " valid_idx = np.random.choice(n, size=int(n * valid_pct), replace=False)\n", + " train_idx = np.setdiff1d(np.arange(n), valid_idx)\n", + " train_df = df.iloc[train_idx].reset_index(drop=True)\n", + " valid_df = 
df.iloc[valid_idx].reset_index(drop=True)\n", + "\n", + " # Build preprocessing transforms\n", + " all_pre_tfms = []\n", + "\n", + " # Add reorder transform (reorder to RAS+ orientation)\n", + " if _apply_reorder:\n", + " all_pre_tfms.append(tio.ToCanonical())\n", + "\n", + " # Add resample transform\n", + " if _target_spacing is not None:\n", + " all_pre_tfms.append(tio.Resample(_target_spacing))\n", + "\n", + " # Add user-provided transforms (normalize to raw TorchIO transforms)\n", + " if pre_patch_tfms:\n", + " all_pre_tfms.extend(normalize_patch_transforms(pre_patch_tfms))\n", + "\n", + " # Create subjects datasets with lazy loading (paths only, ~0 MB)\n", + " train_subjects = create_subjects_dataset(\n", + " train_df, img_col, mask_col,\n", + " pre_tfms=all_pre_tfms if all_pre_tfms else None,\n", + " ensure_affine_consistency=ensure_affine_consistency\n", + " )\n", + " valid_subjects = create_subjects_dataset(\n", + " valid_df, img_col, mask_col,\n", + " pre_tfms=all_pre_tfms if all_pre_tfms else None,\n", + " ensure_affine_consistency=ensure_affine_consistency\n", + " )\n", + "\n", + " # Create DataLoaders (both use same patch_config for consistent sampling)\n", + " train_dl = MedPatchDataLoader(\n", + " train_subjects, patch_config, bs,\n", + " patch_tfms=patch_tfms,\n", + " gpu_augmentation=gpu_augmentation,\n", + " shuffle=True, drop_last=True\n", + " )\n", + " valid_dl = MedPatchDataLoader(\n", + " valid_subjects, patch_config, bs,\n", + " patch_tfms=None, # No augmentation for validation\n", + " gpu_augmentation=None, # No augmentation for validation\n", + " shuffle=False, drop_last=False\n", + " )\n", + "\n", + " # Create instance and store metadata\n", + " instance = cls(train_dl, valid_dl, device)\n", + " instance._img_col = img_col\n", + " instance._mask_col = mask_col\n", + " instance._pre_patch_tfms = pre_patch_tfms\n", + " instance._apply_reorder = _apply_reorder\n", + " instance._target_spacing = _target_spacing\n", + " 
instance._ensure_affine_consistency = ensure_affine_consistency\n", + " instance._patch_config = patch_config\n", + " instance._train_source_df = train_df\n", + " instance._valid_source_df = valid_df\n", + " return instance\n", + "\n", + " @property\n", + " def train(self):\n", + " \"\"\"Training DataLoader.\"\"\"\n", + " return self._train_dl\n", + "\n", + " @property\n", + " def valid(self):\n", + " \"\"\"Validation DataLoader.\"\"\"\n", + " return self._valid_dl\n", + "\n", + " @property\n", + " def train_ds(self):\n", + " \"\"\"Training subjects dataset.\"\"\"\n", + " return self._train_dl.subjects_dataset\n", + "\n", + " @property\n", + " def valid_ds(self):\n", + " \"\"\"Validation subjects dataset.\"\"\"\n", + " return self._valid_dl.subjects_dataset\n", + "\n", + " @property\n", + " def device(self):\n", + " \"\"\"Current device.\"\"\"\n", + " return self._device\n", + "\n", + " @property\n", + " def bs(self):\n", + " \"\"\"Batch size.\"\"\"\n", + " return self._train_dl.bs\n", + "\n", + " @property\n", + " def apply_reorder(self):\n", + " \"\"\"Whether reordering to RAS+ is enabled.\"\"\"\n", + " return getattr(self, '_apply_reorder', False)\n", + "\n", + " @property\n", + " def target_spacing(self):\n", + " \"\"\"Target voxel spacing for resampling.\"\"\"\n", + " return getattr(self, '_target_spacing', None)\n", + "\n", + " @property\n", + " def patch_config(self):\n", + " \"\"\"The PatchConfig used for this DataLoaders.\"\"\"\n", + " return getattr(self, '_patch_config', None)\n", + "\n", + " @property\n", + " def split_df(self):\n", + " \"\"\"DataFrame recording train/valid split for reproducibility logging.\"\"\"\n", + " train = self._train_source_df.assign(is_valid=False)\n", + " valid = self._valid_source_df.assign(is_valid=True)\n", + " return pd.concat([train, valid], ignore_index=True)\n", + "\n", + " def to(self, device):\n", + " \"\"\"Move DataLoaders to device.\"\"\"\n", + " self._device = device\n", + " self._train_dl.to(device)\n", + " 
self._valid_dl.to(device)\n", + " return self\n", + "\n", + " def __iter__(self):\n", + " \"\"\"Iterate over training DataLoader.\"\"\"\n", + " return iter(self._train_dl)\n", + "\n", + " def one_batch(self):\n", + " \"\"\"Return one batch from the training DataLoader.\n", + "\n", + " Required for fastai Learner compatibility - used for device\n", + " detection and batch shape validation.\n", + " \"\"\"\n", + " return self._train_dl.one_batch()\n", + "\n", + " def __len__(self):\n", + " \"\"\"Return number of batches in training DataLoader.\"\"\"\n", + " return len(self._train_dl)\n", + "\n", + " def __getitem__(self, idx):\n", + " \"\"\"Get DataLoader by index. Required for fastai Learner compatibility.\n", + "\n", + " Args:\n", + " idx: 0 for training DataLoader, 1 for validation DataLoader.\n", + "\n", + " Returns:\n", + " MedPatchDataLoader instance.\n", + " \"\"\"\n", + " if idx == 0:\n", + " return self._train_dl\n", + " elif idx == 1:\n", + " return self._valid_dl\n", + " else:\n", + " raise IndexError(f\"Index {idx} out of range. 
Use 0 (train) or 1 (valid).\")\n", + "\n", + " def cuda(self):\n", + " \"\"\"Move DataLoaders to CUDA device.\"\"\"\n", + " return self.to(torch.device('cuda'))\n", + "\n", + " def cpu(self):\n", + " \"\"\"Move DataLoaders to CPU.\"\"\"\n", + " return self.to(torch.device('cpu'))\n", + "\n", + " def show_batch(self, dl_idx=0, max_n=6, figsize=None, channel=0,\n", + " slice_index=None, anatomical_plane=0, overlay=False,\n", + " voxel_size=None, **kwargs):\n", + " \"\"\"Show a batch of patch samples for visualization.\"\"\"\n", + "\n", + " dl = self[dl_idx]\n", + " x, y = dl.one_batch()\n", + " x = x.cpu()\n", + " if y is not None: y = y.cpu()\n", + "\n", + " nrows = min(x.shape[0], max_n)\n", + " has_mask = y is not None\n", + "\n", + " if overlay and has_mask:\n", + " ncols = x.shape[1]\n", + " else:\n", + " ncols = x.shape[1] + (1 if has_mask else 0)\n", + "\n", + " if figsize is None:\n", + " figsize = (ncols * 3, nrows * 3)\n", + " fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)\n", + " flat_axs = axs.flatten()\n", + "\n", + " imgs, masks_for_overlay, slice_idxs = [], [], []\n", + " for i in range(nrows):\n", + " img = x[i]\n", + " im_channels = [MedImage(c_img[None]) for c_img in img]\n", + "\n", + " if has_mask:\n", + " mask = y[i]\n", + " idx = find_max_slice(mask[0].numpy(), anatomical_plane) if slice_index is None else slice_index\n", + " if overlay:\n", + " masks_for_overlay.extend([MedMask(mask)] * len(im_channels))\n", + " else:\n", + " im_channels.append(MedMask(mask))\n", + " else:\n", + " idx = slice_index\n", + "\n", + " imgs.extend(im_channels)\n", + " slice_idxs.extend([idx] * len(im_channels))\n", + "\n", + " _voxel_size = voxel_size if voxel_size is not None else self.target_spacing\n", + " ctxs = [im.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane,\n", + " voxel_size=_voxel_size)\n", + " for im, ax, idx in zip(imgs, flat_axs, slice_idxs)]\n", + "\n", + " if overlay and has_mask:\n", + " for mask, ax, idx in 
zip(masks_for_overlay, flat_axs, slice_idxs):\n", + " mask.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane,\n", + " voxel_size=_voxel_size)\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + " def new_empty(self):\n", + " \"\"\"Create a new empty version of self for learner export.\n", + "\n", + " Required for fastai Learner.export() compatibility - creates a\n", + " lightweight placeholder that can be pickled without the full dataset.\n", + "\n", + " Returns:\n", + " A minimal MedPatchDataLoaders-like object with no data.\n", + " \"\"\"\n", + " class EmptyMedPatchDataLoaders:\n", + " \"\"\"Minimal placeholder for exported learner.\"\"\"\n", + " def __init__(self, device):\n", + " self._device = device\n", + " @property\n", + " def device(self): return self._device\n", + " def to(self, device):\n", + " self._device = device\n", + " return self\n", + " def cpu(self):\n", + " \"\"\"Move to CPU. Required for load_learner compatibility.\"\"\"\n", + " return self.to(torch.device('cpu'))\n", + "\n", + " return EmptyMedPatchDataLoaders(self._device)\n", + "\n", + " def close(self):\n", + " \"\"\"Shut down all DataLoader workers. 
Safe to call multiple times.\"\"\"\n", + " if self._closed:\n", + " return\n", + " self._closed = True\n", + " if hasattr(self, '_train_dl') and self._train_dl is not None:\n", + " self._train_dl.close()\n", + " if hasattr(self, '_valid_dl') and self._valid_dl is not None:\n", + " self._valid_dl.close()\n", + "\n", + " def __enter__(self):\n", + " return self\n", + "\n", + " def __exit__(self, exc_type, exc_val, exc_tb):\n", + " self.close()\n", + " return False\n", + "\n", + " def __del__(self):\n", + " try:\n", + " self.close()\n", + " except Exception:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "nltdqti8c0t", + "metadata": {}, + "outputs": [], + "source": [ + "# Test mutual exclusivity of gpu_augmentation and patch_tfms\n", + "from fastMONAI.vision_augmentation import GpuPatchAugmentation\n", + "\n", + "# Should raise ValueError when both gpu_augmentation and patch_tfms are provided\n", + "test_fail(\n", + " lambda: MedPatchDataLoaders.from_df(\n", + " pd.DataFrame({'img': ['fake.nii'], 'mask': ['fake.nii']}),\n", + " img_col='img', mask_col='mask',\n", + " patch_tfms=[tio.RandomFlip()],\n", + " gpu_augmentation=GpuPatchAugmentation(flip={'axes': (0,), 'p': 0.5}),\n", + " ),\n", + " contains='Cannot use both'\n", + ")\n", + "\n", + "# Verify gpu_augmentation is stored on train_dl but not valid_dl\n", + "# (We can't fully instantiate from_df without real files, so test MedPatchDataLoader directly)\n", + "test_eq(MedPatchDataLoader.__init__.__code__.co_varnames[:8],\n", + " ('self', 'subjects_dataset', 'config', 'batch_size',\n", + " 'patch_tfms', 'gpu_augmentation', 'shuffle', 'drop_last'))" + ] }, { "cell_type": "markdown", diff --git a/nbs/12a_tutorial_patch_training.ipynb b/nbs/12a_tutorial_patch_training.ipynb index 03fbce9..937bf7f 100644 --- a/nbs/12a_tutorial_patch_training.ipynb +++ b/nbs/12a_tutorial_patch_training.ipynb @@ -506,6 +506,37 @@ "]" ] }, + { + "cell_type": "markdown", + "id": "cxa906d8xp", + 
"metadata": {}, + "source": [ + "#### Alternative: GPU-batched augmentation\n", + "\n", + "For long training runs (e.g., hundreds of epochs), GPU-batched augmentation can significantly\n", + "reduce training time by moving transforms from CPU to GPU. Instead of per-sample TorchIO\n", + "transforms (`patch_tfms`), `gpu_patch_augmentations` creates a batched augmentation pipeline\n", + "that operates on GPU tensors directly.\n", + "\n", + "Use `gpu_augmentation` **instead of** `patch_tfms` (they are mutually exclusive)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ymnk031qcys", + "metadata": {}, + "outputs": [], + "source": [ + "# # GPU augmentation alternative (uncomment to use instead of patch_tfms above)\n", + "# gpu_aug = gpu_patch_augmentations(patch_config.patch_size, patch_config.target_spacing)\n", + "#\n", + "# # Then pass gpu_augmentation instead of patch_tfms to from_df:\n", + "# # dls = MedPatchDataLoaders.from_df(\n", + "# # ..., gpu_augmentation=gpu_aug, ... # replaces patch_tfms=patch_tfms\n", + "# # )" + ] + }, { "cell_type": "markdown", "id": "cell-dls-header",