From d0edf25d7b5904efa9220128e0eea5fff01aa513 Mon Sep 17 00:00:00 2001 From: Sathiesh Date: Wed, 18 Feb 2026 12:19:13 +0100 Subject: [PATCH] feat: add preprocessed mode to PatchConfig, preserve original columns in preprocess_dataset; bump to 0.8.2 --- fastMONAI/__init__.py | 2 +- fastMONAI/dataset_info.py | 13 +- fastMONAI/utils.py | 1 + fastMONAI/vision_patch.py | 28 ++- nbs/07_utils.ipynb | 2 +- nbs/08_dataset_info.ipynb | 4 +- nbs/10_vision_patch.ipynb | 437 +------------------------------------- settings.ini | 2 +- 8 files changed, 35 insertions(+), 454 deletions(-) diff --git a/fastMONAI/__init__.py b/fastMONAI/__init__.py index 8088f75..deded32 100644 --- a/fastMONAI/__init__.py +++ b/fastMONAI/__init__.py @@ -1 +1 @@ -__version__ = "0.8.1" +__version__ = "0.8.2" diff --git a/fastMONAI/dataset_info.py b/fastMONAI/dataset_info.py index 6414935..5d3e401 100644 --- a/fastMONAI/dataset_info.py +++ b/fastMONAI/dataset_info.py @@ -554,11 +554,11 @@ def round_to_divisor(val, div): def preprocess_dataset(df, img_col, mask_col=None, output_dir='preprocessed', target_spacing=None, apply_reorder=True, transforms=None, max_workers=4, skip_existing=True): - """Preprocess dataset to disk and update DataFrame path columns in-place. + """Preprocess dataset to disk, creating new columns for preprocessed paths. Processes images (and optionally masks) through a transform pipeline, - saves to output_dir, then updates df[img_col] and df[mask_col] in-place - to point to the preprocessed files. + saves to output_dir, then creates new '{col}_preprocessed' columns in + the DataFrame. Original columns are preserved unchanged. Transform pipeline order: CopyAffine (if masks) -> ToCanonical (if apply_reorder) @@ -678,10 +678,11 @@ def _process_case(item): failed_cases.append(Path(item['img_path']).name) warnings.warn(f"Failed to process {item['img_path']}: {e}") - # Update DataFrame in-place - df[img_col] = [str(img_dir / Path(p).name) for p in df[img_col]] + # Create new columns for preprocessed paths (preserve originals) + df[f'{img_col}_preprocessed'] = [str(img_dir / Path(p).name) for p in df[img_col]] + if mask_col is not None: - df[mask_col] = [str(mask_dir / Path(p).name) for p in df[mask_col]] + df[f'{mask_col}_preprocessed'] = [str(mask_dir / Path(p).name) for p in df[mask_col]] print(f"Preprocessing complete: {processed} processed, {skipped} skipped, {failed} failed") if failed_cases: diff --git a/fastMONAI/utils.py b/fastMONAI/utils.py index b9ef97e..252618f 100644 --- a/fastMONAI/utils.py +++ b/fastMONAI/utils.py @@ -237,6 +237,7 @@ def _extract_patch_config(learn) -> dict: 'aggregation_mode': patch_config.aggregation_mode, 'padding_mode': patch_config.padding_mode, 'keep_largest_component': patch_config.keep_largest_component, + 'preprocessed': patch_config.preprocessed, } else: config['patch_config'] = None diff --git a/fastMONAI/vision_patch.py b/fastMONAI/vision_patch.py index 55d0d61..4e17780 100644 --- a/fastMONAI/vision_patch.py +++ b/fastMONAI/vision_patch.py @@ -116,6 +116,11 @@ class PatchConfig: training and inference. Defaults to True (the common case). target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between training and inference. + preprocessed: If True, data has been preprocessed externally (e.g., via + preprocess_dataset()). Training will skip reorder, resample, AND + pre_patch_tfms (e.g., normalization) since they were already applied. + Inference is unaffected and always applies pre_inference_tfms to raw + images. Defaults to False. padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding) to align with nnU-Net's approach. Can be int, float, or string (e.g., 'minimum', 'mean'). keep_largest_component: If True, keep only the largest connected component @@ -142,6 +147,7 @@ class PatchConfig: # Preprocessing parameters - must match between training and inference apply_reorder: bool = True # Defaults to True (the common case) target_spacing: list = None + preprocessed: bool = False # True = data already preprocessed, skip all preprocessing during training padding_mode: int | float | str = 0 # Zero padding (nnU-Net standard) # Post-processing (binary segmentation only) keep_largest_component: bool = False @@ -653,6 +659,8 @@ def from_df( pre_patch_tfms: TorchIO transforms applied before patch extraction (after reorder/resample). Example: [tio.ZNormalization()]. Accepts both fastMONAI wrappers and raw TorchIO transforms. + Skipped when preprocessed=True (include in preprocess_dataset() + transforms instead). Still needed for inference via pre_inference_tfms. patch_tfms: TorchIO transforms applied to extracted patches (training only). Mutually exclusive with gpu_augmentation. gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation @@ -725,17 +733,19 @@ def from_df( # Build preprocessing transforms all_pre_tfms = [] - # Add reorder transform (reorder to RAS+ orientation) - if _apply_reorder: - all_pre_tfms.append(tio.ToCanonical()) + # Skip all preprocessing if data was already preprocessed externally + if not patch_config.preprocessed: + # Add reorder transform (reorder to RAS+ orientation) + if _apply_reorder: + all_pre_tfms.append(tio.ToCanonical()) - # Add resample transform - if _target_spacing is not None: - all_pre_tfms.append(tio.Resample(_target_spacing)) + # Add resample transform + if _target_spacing is not None: + all_pre_tfms.append(tio.Resample(_target_spacing)) - # Add user-provided transforms (normalize to raw TorchIO transforms) - if pre_patch_tfms: - all_pre_tfms.extend(normalize_patch_transforms(pre_patch_tfms)) + # Add user-provided transforms (normalize to raw TorchIO transforms) + if pre_patch_tfms: + all_pre_tfms.extend(normalize_patch_transforms(pre_patch_tfms)) # Create subjects datasets with lazy loading (paths only, ~0 MB) train_subjects = create_subjects_dataset( diff --git a/nbs/07_utils.ipynb b/nbs/07_utils.ipynb index 9604693..84ffc05 100644 --- a/nbs/07_utils.ipynb +++ b/nbs/07_utils.ipynb @@ -199,7 +199,7 @@ "id": "czquspt567w", "metadata": {}, "outputs": [], - "source": "#| export\ndef _detect_patch_workflow(dls) -> bool:\n \"\"\"Detect if DataLoaders are patch-based (MedPatchDataLoaders).\n \n Args:\n dls: DataLoaders instance\n \n Returns:\n True if dls is a MedPatchDataLoaders instance\n \"\"\"\n return hasattr(dls, 'patch_config') or hasattr(dls, '_patch_config')\n\n\ndef _extract_size_from_transforms(tfms) -> list | None:\n \"\"\"Extract target size from PadOrCrop transform if present.\n \n Args:\n tfms: List of transforms\n \n Returns:\n Target size as list, or None if not found\n \"\"\"\n if tfms is None:\n return None\n for tfm in tfms:\n if hasattr(tfm, 'pad_or_crop') and hasattr(tfm.pad_or_crop, 'target_shape'):\n return list(tfm.pad_or_crop.target_shape)\n return None\n\n\ndef _extract_standard_config(learn) -> dict:\n \"\"\"Extract config from standard MedDataBlock workflow.\n \n Args:\n learn: fastai Learner instance\n \n Returns:\n Dictionary with extracted configuration\n \"\"\"\n from fastMONAI.vision_core import MedBase\n dls = learn.dls\n\n # Get preprocessing from MedBase class attributes\n apply_reorder = MedBase.apply_reorder\n target_spacing = MedBase.target_spacing\n\n # Extract item_tfms from DataLoaders pipeline\n item_tfms = []\n if hasattr(dls, 'after_item') and dls.after_item:\n item_tfms = list(dls.after_item.fs)\n\n # Extract size from PadOrCrop transform\n size = _extract_size_from_transforms(item_tfms)\n\n return {\n 'apply_reorder': apply_reorder,\n 'target_spacing': target_spacing,\n 'size': size,\n 'item_tfms': item_tfms,\n 'batch_size': dls.bs,\n 'patch_config': None,\n }\n\n\ndef _extract_patch_config(learn) -> dict:\n \"\"\"Extract config from MedPatchDataLoaders workflow.\n \n Args:\n learn: fastai Learner instance\n \n Returns:\n Dictionary with extracted configuration including patch-specific params\n \"\"\"\n dls = learn.dls\n patch_config = getattr(dls, '_patch_config', None) or getattr(dls, 'patch_config', None)\n\n config = {\n 'apply_reorder': getattr(dls, '_apply_reorder', patch_config.apply_reorder if patch_config else False),\n 'target_spacing': getattr(dls, '_target_spacing', patch_config.target_spacing if patch_config else None),\n 'size': patch_config.patch_size if patch_config else None,\n 'item_tfms': getattr(dls, '_pre_patch_tfms', []) or [],\n 'batch_size': dls.bs,\n }\n\n # Add patch-specific params for logging\n if patch_config:\n config['patch_config'] = {\n 'patch_size': patch_config.patch_size,\n 'patch_overlap': patch_config.patch_overlap,\n 'samples_per_volume': patch_config.samples_per_volume,\n 'sampler_type': patch_config.sampler_type,\n 'label_probabilities': str(patch_config.label_probabilities) if patch_config.label_probabilities else None,\n 'queue_length': patch_config.queue_length,\n 'aggregation_mode': patch_config.aggregation_mode,\n 'padding_mode': patch_config.padding_mode,\n 'keep_largest_component': patch_config.keep_largest_component,\n }\n else:\n config['patch_config'] = None\n\n return config\n\n\ndef _extract_loss_name(learn) -> str:\n \"\"\"Extract loss function name from Learner.\n \n Args:\n learn: fastai Learner instance\n \n Returns:\n Name of the loss function\n \"\"\"\n loss_func = learn.loss_func\n # Handle CustomLoss wrapper\n if hasattr(loss_func, 'loss_func'):\n inner = loss_func.loss_func\n return inner._get_name() if hasattr(inner, '_get_name') else inner.__class__.__name__\n return loss_func._get_name() if hasattr(loss_func, '_get_name') else loss_func.__class__.__name__\n\n\ndef _extract_model_name(learn) -> str:\n \"\"\"Extract model architecture name from Learner.\n \n Args:\n learn: fastai Learner instance\n \n Returns:\n Name of the model architecture\n \"\"\"\n model = learn.model\n return model._get_name() if hasattr(model, '_get_name') else model.__class__.__name__" + "source": "#| export\ndef _detect_patch_workflow(dls) -> bool:\n \"\"\"Detect if DataLoaders are patch-based (MedPatchDataLoaders).\n \n Args:\n dls: DataLoaders instance\n \n Returns:\n True if dls is a MedPatchDataLoaders instance\n \"\"\"\n return hasattr(dls, 'patch_config') or hasattr(dls, '_patch_config')\n\n\ndef _extract_size_from_transforms(tfms) -> list | None:\n \"\"\"Extract target size from PadOrCrop transform if present.\n \n Args:\n tfms: List of transforms\n \n Returns:\n Target size as list, or None if not found\n \"\"\"\n if tfms is None:\n return None\n for tfm in tfms:\n if hasattr(tfm, 'pad_or_crop') and hasattr(tfm.pad_or_crop, 'target_shape'):\n return list(tfm.pad_or_crop.target_shape)\n return None\n\n\ndef _extract_standard_config(learn) -> dict:\n \"\"\"Extract config from standard MedDataBlock workflow.\n \n Args:\n learn: fastai Learner instance\n \n Returns:\n Dictionary with extracted configuration\n \"\"\"\n from fastMONAI.vision_core import MedBase\n dls = learn.dls\n\n # Get preprocessing from MedBase class attributes\n apply_reorder = MedBase.apply_reorder\n target_spacing = MedBase.target_spacing\n\n # Extract item_tfms from DataLoaders pipeline\n item_tfms = []\n if hasattr(dls, 'after_item') and dls.after_item:\n item_tfms = list(dls.after_item.fs)\n\n # Extract size from PadOrCrop transform\n size = _extract_size_from_transforms(item_tfms)\n\n return {\n 'apply_reorder': apply_reorder,\n 'target_spacing': target_spacing,\n 'size': size,\n 'item_tfms': item_tfms,\n 'batch_size': dls.bs,\n 'patch_config': None,\n }\n\n\ndef _extract_patch_config(learn) -> dict:\n \"\"\"Extract config from MedPatchDataLoaders workflow.\n \n Args:\n learn: fastai Learner instance\n \n Returns:\n Dictionary with extracted configuration including patch-specific params\n \"\"\"\n dls = learn.dls\n patch_config = getattr(dls, '_patch_config', None) or getattr(dls, 'patch_config', None)\n\n config = {\n 'apply_reorder': getattr(dls, '_apply_reorder', patch_config.apply_reorder if patch_config else False),\n 'target_spacing': getattr(dls, '_target_spacing', patch_config.target_spacing if patch_config else None),\n 'size': patch_config.patch_size if patch_config else None,\n 'item_tfms': getattr(dls, '_pre_patch_tfms', []) or [],\n 'batch_size': dls.bs,\n }\n\n # Add patch-specific params for logging\n if patch_config:\n config['patch_config'] = {\n 'patch_size': patch_config.patch_size,\n 'patch_overlap': patch_config.patch_overlap,\n 'samples_per_volume': patch_config.samples_per_volume,\n 'sampler_type': patch_config.sampler_type,\n 'label_probabilities': str(patch_config.label_probabilities) if patch_config.label_probabilities else None,\n 'queue_length': patch_config.queue_length,\n 'aggregation_mode': patch_config.aggregation_mode,\n 'padding_mode': patch_config.padding_mode,\n 'keep_largest_component': patch_config.keep_largest_component,\n 'preprocessed': patch_config.preprocessed,\n }\n else:\n config['patch_config'] = None\n\n return config\n\n\ndef _extract_loss_name(learn) -> str:\n \"\"\"Extract loss function name from Learner.\n \n Args:\n learn: fastai Learner instance\n \n Returns:\n Name of the loss function\n \"\"\"\n loss_func = learn.loss_func\n # Handle CustomLoss wrapper\n if hasattr(loss_func, 'loss_func'):\n inner = loss_func.loss_func\n return inner._get_name() if hasattr(inner, '_get_name') else inner.__class__.__name__\n return loss_func._get_name() if hasattr(loss_func, '_get_name') else loss_func.__class__.__name__\n\n\ndef _extract_model_name(learn) -> str:\n \"\"\"Extract model architecture name from Learner.\n \n Args:\n learn: fastai Learner instance\n \n Returns:\n Name of the model architecture\n \"\"\"\n model = learn.model\n return model._get_name() if hasattr(model, '_get_name') else model.__class__.__name__" }, { "cell_type": "code", diff --git a/nbs/08_dataset_info.ipynb b/nbs/08_dataset_info.ipynb index 1f27b71..7539a9b 100644 --- a/nbs/08_dataset_info.ipynb +++ b/nbs/08_dataset_info.ipynb @@ -146,7 +146,7 @@ "id": "mbn5svtmzkh", "metadata": {}, "outputs": [], - "source": "#| export\ndef preprocess_dataset(df, img_col, mask_col=None, output_dir='preprocessed',\n target_spacing=None, apply_reorder=True, transforms=None,\n max_workers=4, skip_existing=True):\n \"\"\"Preprocess dataset to disk and update DataFrame path columns in-place.\n\n Processes images (and optionally masks) through a transform pipeline,\n saves to output_dir, then updates df[img_col] and df[mask_col] in-place\n to point to the preprocessed files.\n\n Transform pipeline order:\n CopyAffine (if masks) -> ToCanonical (if apply_reorder)\n -> Resample (if target_spacing) -> user transforms\n\n Args:\n df: DataFrame with file paths.\n img_col: Column name for image paths.\n mask_col: Optional column name for mask paths.\n output_dir: Output directory. Creates images/ and masks/ subdirectories.\n target_spacing: Target voxel spacing for resampling (e.g., [1.0, 1.0, 1.0]).\n apply_reorder: Whether to reorder to RAS+ canonical orientation.\n transforms: Additional TorchIO or fastMONAI transforms to apply after\n reordering and resampling.\n max_workers: Number of parallel workers. Each worker loads a full 3D\n volume into memory, so reduce for large volumes.\n skip_existing: Skip files that already exist on disk (with size > 0).\n \"\"\"\n # Input validation\n if len(df) == 0:\n raise ValueError(\"DataFrame is empty\")\n if img_col not in df.columns:\n raise ValueError(f\"Column '{img_col}' not found in DataFrame\")\n if mask_col is not None and mask_col not in df.columns:\n raise ValueError(f\"Column '{mask_col}' not found in DataFrame\")\n\n img_names = [Path(p).name for p in df[img_col]]\n if len(set(img_names)) != len(img_names):\n dupes = set(n for n in img_names if img_names.count(n) > 1)\n raise ValueError(f\"Duplicate image file names: {dupes}\")\n\n if mask_col is not None:\n mask_names = [Path(p).name for p in df[mask_col]]\n if len(set(mask_names)) != len(mask_names):\n dupes = set(n for n in mask_names if mask_names.count(n) > 1)\n raise ValueError(f\"Duplicate mask file names: {dupes}\")\n\n # Build transform pipeline (canonical order)\n all_tfms = []\n if mask_col is not None:\n all_tfms.append(tio.CopyAffine(target='image'))\n if apply_reorder:\n all_tfms.append(tio.ToCanonical())\n if target_spacing is not None:\n all_tfms.append(tio.Resample(target_spacing))\n if transforms:\n all_tfms.extend([getattr(t, 'tio_transform', t) for t in transforms])\n pipeline = tio.Compose(all_tfms) if all_tfms else None\n\n # Create output directories\n output_dir = Path(output_dir)\n img_dir = output_dir / 'images'\n img_dir.mkdir(parents=True, exist_ok=True)\n if mask_col is not None:\n mask_dir = output_dir / 'masks'\n mask_dir.mkdir(parents=True, exist_ok=True)\n\n # Build work items, filtering skip_existing\n work_items = []\n skipped = 0\n for idx in range(len(df)):\n img_path = df[img_col].iloc[idx]\n out_img = img_dir / Path(img_path).name\n\n mask_path = df[mask_col].iloc[idx] if mask_col is not None else None\n out_mask = (mask_dir / Path(mask_path).name) if mask_col is not None else None\n\n if skip_existing:\n img_ok = out_img.exists() and out_img.stat().st_size > 0\n mask_ok = out_mask is None or (out_mask.exists() and out_mask.stat().st_size > 0)\n if img_ok and mask_ok:\n skipped += 1\n continue\n\n work_items.append({\n 'idx': idx, 'img_path': img_path, 'mask_path': mask_path,\n 'out_img': out_img, 'out_mask': out_mask,\n })\n\n # Process cases\n processed = 0\n failed = 0\n failed_cases = []\n\n def _process_case(item):\n subject_dict = {'image': tio.ScalarImage(item['img_path'])}\n if item['mask_path'] is not None:\n subject_dict['mask'] = tio.LabelMap(item['mask_path'])\n\n subject = tio.Subject(**subject_dict)\n if pipeline is not None:\n subject = pipeline(subject)\n\n # Atomic write: save to temp file (with valid NIfTI extension), then rename\n out_img = item['out_img']\n tmp_img = out_img.parent / f'.tmp_{out_img.name}'\n subject['image'].save(str(tmp_img))\n os.rename(str(tmp_img), str(out_img))\n\n if item['out_mask'] is not None:\n out_mask = item['out_mask']\n tmp_mask = out_mask.parent / f'.tmp_{out_mask.name}'\n subject['mask'].save(str(tmp_mask))\n os.rename(str(tmp_mask), str(out_mask))\n\n if work_items:\n with ThreadPoolExecutor(max_workers=max_workers) as executor:\n futures = {executor.submit(_process_case, item): item for item in work_items}\n for future in tqdm(as_completed(futures), total=len(futures),\n desc='Preprocessing'):\n item = futures[future]\n try:\n future.result()\n processed += 1\n except Exception as e:\n failed += 1\n failed_cases.append(Path(item['img_path']).name)\n warnings.warn(f\"Failed to process {item['img_path']}: {e}\")\n\n # Update DataFrame in-place\n df[img_col] = [str(img_dir / Path(p).name) for p in df[img_col]]\n if mask_col is not None:\n df[mask_col] = [str(mask_dir / Path(p).name) for p in df[mask_col]]\n\n print(f\"Preprocessing complete: {processed} processed, {skipped} skipped, {failed} failed\")\n if failed_cases:\n print(f\"Failed cases: {failed_cases}\")" + "source": "#| export\ndef preprocess_dataset(df, img_col, mask_col=None, output_dir='preprocessed',\n target_spacing=None, apply_reorder=True, transforms=None,\n max_workers=4, skip_existing=True):\n \"\"\"Preprocess dataset to disk, creating new columns for preprocessed paths.\n\n Processes images (and optionally masks) through a transform pipeline,\n saves to output_dir, then creates new '{col}_preprocessed' columns in\n the DataFrame. Original columns are preserved unchanged.\n\n Transform pipeline order:\n CopyAffine (if masks) -> ToCanonical (if apply_reorder)\n -> Resample (if target_spacing) -> user transforms\n\n Args:\n df: DataFrame with file paths.\n img_col: Column name for image paths.\n mask_col: Optional column name for mask paths.\n output_dir: Output directory. Creates images/ and masks/ subdirectories.\n target_spacing: Target voxel spacing for resampling (e.g., [1.0, 1.0, 1.0]).\n apply_reorder: Whether to reorder to RAS+ canonical orientation.\n transforms: Additional TorchIO or fastMONAI transforms to apply after\n reordering and resampling.\n max_workers: Number of parallel workers. Each worker loads a full 3D\n volume into memory, so reduce for large volumes.\n skip_existing: Skip files that already exist on disk (with size > 0).\n \"\"\"\n # Input validation\n if len(df) == 0:\n raise ValueError(\"DataFrame is empty\")\n if img_col not in df.columns:\n raise ValueError(f\"Column '{img_col}' not found in DataFrame\")\n if mask_col is not None and mask_col not in df.columns:\n raise ValueError(f\"Column '{mask_col}' not found in DataFrame\")\n\n img_names = [Path(p).name for p in df[img_col]]\n if len(set(img_names)) != len(img_names):\n dupes = set(n for n in img_names if img_names.count(n) > 1)\n raise ValueError(f\"Duplicate image file names: {dupes}\")\n\n if mask_col is not None:\n mask_names = [Path(p).name for p in df[mask_col]]\n if len(set(mask_names)) != len(mask_names):\n dupes = set(n for n in mask_names if mask_names.count(n) > 1)\n raise ValueError(f\"Duplicate mask file names: {dupes}\")\n\n # Build transform pipeline (canonical order)\n all_tfms = []\n if mask_col is not None:\n all_tfms.append(tio.CopyAffine(target='image'))\n if apply_reorder:\n all_tfms.append(tio.ToCanonical())\n if target_spacing is not None:\n all_tfms.append(tio.Resample(target_spacing))\n if transforms:\n all_tfms.extend([getattr(t, 'tio_transform', t) for t in transforms])\n pipeline = tio.Compose(all_tfms) if all_tfms else None\n\n # Create output directories\n output_dir = Path(output_dir)\n img_dir = output_dir / 'images'\n img_dir.mkdir(parents=True, exist_ok=True)\n if mask_col is not None:\n mask_dir = output_dir / 'masks'\n mask_dir.mkdir(parents=True, exist_ok=True)\n\n # Build work items, filtering skip_existing\n work_items = []\n skipped = 0\n for idx in range(len(df)):\n img_path = df[img_col].iloc[idx]\n out_img = img_dir / Path(img_path).name\n\n mask_path = df[mask_col].iloc[idx] if mask_col is not None else None\n out_mask = (mask_dir / Path(mask_path).name) if mask_col is not None else None\n\n if skip_existing:\n img_ok = out_img.exists() and out_img.stat().st_size > 0\n mask_ok = out_mask is None or (out_mask.exists() and out_mask.stat().st_size > 0)\n if img_ok and mask_ok:\n skipped += 1\n continue\n\n work_items.append({\n 'idx': idx, 'img_path': img_path, 'mask_path': mask_path,\n 'out_img': out_img, 'out_mask': out_mask,\n })\n\n # Process cases\n processed = 0\n failed = 0\n failed_cases = []\n\n def _process_case(item):\n subject_dict = {'image': tio.ScalarImage(item['img_path'])}\n if item['mask_path'] is not None:\n subject_dict['mask'] = tio.LabelMap(item['mask_path'])\n\n subject = tio.Subject(**subject_dict)\n if pipeline is not None:\n subject = pipeline(subject)\n\n # Atomic write: save to temp file (with valid NIfTI extension), then rename\n out_img = item['out_img']\n tmp_img = out_img.parent / f'.tmp_{out_img.name}'\n subject['image'].save(str(tmp_img))\n os.rename(str(tmp_img), str(out_img))\n\n if item['out_mask'] is not None:\n out_mask = item['out_mask']\n tmp_mask = out_mask.parent / f'.tmp_{out_mask.name}'\n subject['mask'].save(str(tmp_mask))\n os.rename(str(tmp_mask), str(out_mask))\n\n if work_items:\n with ThreadPoolExecutor(max_workers=max_workers) as executor:\n futures = {executor.submit(_process_case, item): item for item in work_items}\n for future in tqdm(as_completed(futures), total=len(futures),\n desc='Preprocessing'):\n item = futures[future]\n try:\n future.result()\n processed += 1\n except Exception as e:\n failed += 1\n failed_cases.append(Path(item['img_path']).name)\n warnings.warn(f\"Failed to process {item['img_path']}: {e}\")\n\n # Create new columns for preprocessed paths (preserve originals)\n df[f'{img_col}_preprocessed'] = [str(img_dir / Path(p).name) for p in df[img_col]]\n\n if mask_col is not None:\n df[f'{mask_col}_preprocessed'] = [str(mask_dir / Path(p).name) for p in df[mask_col]]\n\n print(f\"Preprocessing complete: {processed} processed, {skipped} skipped, {failed} failed\")\n if failed_cases:\n print(f\"Failed cases: {failed_cases}\")" }, { "cell_type": "code", @@ -154,7 +154,7 @@ "id": "rkoedtvhegm", "metadata": {}, "outputs": [], - "source": "import tempfile, shutil\nfrom fastcore.test import test_eq, test_fail\n\n_tmp = tempfile.mkdtemp()\n\n# Create synthetic NIfTI files\nfor i in range(3):\n tio.ScalarImage(tensor=torch.randn(1, 10, 10, 10)).save(f'{_tmp}/img_{i}.nii.gz')\n tio.LabelMap(tensor=torch.randint(0, 2, (1, 10, 10, 10))).save(f'{_tmp}/mask_{i}.nii.gz')\n\n# Test 1: Image-only preprocessing\n_df1 = pd.DataFrame({'img': [f'{_tmp}/img_{i}.nii.gz' for i in range(3)]})\n_out1 = f'{_tmp}/out1'\npreprocess_dataset(_df1, img_col='img', output_dir=_out1, apply_reorder=False)\ntest_eq(all(Path(p).exists() for p in _df1['img']), True)\ntest_eq(all('out1/images/' in p for p in _df1['img']), True)\n\n# Test 2: Skip-existing (rerun with original paths pointing to same filenames)\n_df2 = pd.DataFrame({'img': [f'{_tmp}/img_{i}.nii.gz' for i in range(3)]})\npreprocess_dataset(_df2, img_col='img', output_dir=_out1, apply_reorder=False)\n# Should print \"0 processed, 3 skipped\"\n\n# Test 3: With masks\n_df3 = pd.DataFrame({\n 'img': [f'{_tmp}/img_{i}.nii.gz' for i in range(3)],\n 'mask': [f'{_tmp}/mask_{i}.nii.gz' for i in range(3)],\n})\n_out3 = f'{_tmp}/out3'\npreprocess_dataset(_df3, img_col='img', mask_col='mask', output_dir=_out3, apply_reorder=False)\ntest_eq(all(Path(p).exists() for p in _df3['img']), True)\ntest_eq(all(Path(p).exists() for p in _df3['mask']), True)\ntest_eq(all('out3/masks/' in p for p in _df3['mask']), True)\n\n# Test 4: Input validation\ntest_fail(lambda: preprocess_dataset(pd.DataFrame(), img_col='img'), contains='empty')\ntest_fail(lambda: preprocess_dataset(pd.DataFrame({'x': [1]}), img_col='img'), contains='not found')\n_df_dup = pd.DataFrame({'img': [f'{_tmp}/img_0.nii.gz', f'{_tmp}/img_0.nii.gz']})\ntest_fail(lambda: preprocess_dataset(_df_dup, img_col='img'), contains='Duplicate')\n\nshutil.rmtree(_tmp)" + "source": "import tempfile, shutil\nfrom fastcore.test import test_eq, test_fail\n\n_tmp = tempfile.mkdtemp()\n\n# Create synthetic NIfTI files\nfor i in range(3):\n tio.ScalarImage(tensor=torch.randn(1, 10, 10, 10)).save(f'{_tmp}/img_{i}.nii.gz')\n tio.LabelMap(tensor=torch.randint(0, 2, (1, 10, 10, 10))).save(f'{_tmp}/mask_{i}.nii.gz')\n\n# Test 1: Image-only preprocessing (new columns, originals preserved)\n_df1 = pd.DataFrame({'img': [f'{_tmp}/img_{i}.nii.gz' for i in range(3)]})\n_orig_paths1 = _df1['img'].tolist()\n_out1 = f'{_tmp}/out1'\npreprocess_dataset(_df1, img_col='img', output_dir=_out1, apply_reorder=False)\n# Original column preserved\ntest_eq(_df1['img'].tolist(), _orig_paths1)\n# New preprocessed column created\ntest_eq('img_preprocessed' in _df1.columns, True)\ntest_eq(all(Path(p).exists() for p in _df1['img_preprocessed']), True)\ntest_eq(all('out1/images/' in p for p in _df1['img_preprocessed']), True)\n\n# Test 2: Skip-existing (rerun with original paths pointing to same filenames)\n_df2 = pd.DataFrame({'img': [f'{_tmp}/img_{i}.nii.gz' for i in range(3)]})\npreprocess_dataset(_df2, img_col='img', output_dir=_out1, apply_reorder=False)\n# Should print \"0 processed, 3 skipped\"\n\n# Test 3: With masks (both columns preserved, new columns created)\n_df3 = pd.DataFrame({\n 'img': [f'{_tmp}/img_{i}.nii.gz' for i in range(3)],\n 'mask': [f'{_tmp}/mask_{i}.nii.gz' for i in range(3)],\n})\n_orig_img3 = _df3['img'].tolist()\n_orig_mask3 = _df3['mask'].tolist()\n_out3 = f'{_tmp}/out3'\npreprocess_dataset(_df3, img_col='img', mask_col='mask', output_dir=_out3, apply_reorder=False)\n# Original columns preserved\ntest_eq(_df3['img'].tolist(), _orig_img3)\ntest_eq(_df3['mask'].tolist(), _orig_mask3)\n# New preprocessed columns created\ntest_eq(all(Path(p).exists() for p in _df3['img_preprocessed']), True)\ntest_eq(all(Path(p).exists() for p in _df3['mask_preprocessed']), True)\ntest_eq(all('out3/masks/' in p for p in _df3['mask_preprocessed']), True)\n\n# Test 4: Input validation\ntest_fail(lambda: preprocess_dataset(pd.DataFrame(), img_col='img'), contains='empty')\ntest_fail(lambda: preprocess_dataset(pd.DataFrame({'x': [1]}), img_col='img'), contains='not found')\n_df_dup = pd.DataFrame({'img': [f'{_tmp}/img_0.nii.gz', f'{_tmp}/img_0.nii.gz']})\ntest_fail(lambda: preprocess_dataset(_df_dup, img_col='img'), contains='Duplicate')\n\nshutil.rmtree(_tmp)" }, { "cell_type": "markdown", diff --git a/nbs/10_vision_patch.ipynb b/nbs/10_vision_patch.ipynb index 9eab4de..f7dee84 100644 --- a/nbs/10_vision_patch.ipynb +++ b/nbs/10_vision_patch.ipynb @@ -178,7 +178,7 @@ "id": "cell-5", "metadata": {}, "outputs": [], - "source": "#| export\n@dataclass\nclass PatchConfig:\n \"\"\"Configuration for patch-based training and inference.\n \n Args:\n patch_size: Size of patches [x, y, z].\n patch_overlap: Overlap for inference GridSampler (int, float 0-1, or list).\n - Float 0-1: fraction of patch_size (e.g., 0.5 = 50% overlap)\n - Int >= 1: pixel overlap (e.g., 48 = 48 pixel overlap)\n - List: per-dimension overlap in pixels\n samples_per_volume: Number of patches to extract per volume during training.\n sampler_type: Type of sampler ('uniform', 'label', 'weighted').\n label_probabilities: For LabelSampler, dict mapping label values to probabilities.\n queue_length: Maximum number of patches to store in queue.\n queue_num_workers: Number of workers for parallel patch extraction.\n aggregation_mode: For inference, how to combine overlapping patches ('crop', 'average', 'hann').\n apply_reorder: Whether to reorder to RAS+ canonical orientation. Must match between\n training and inference. Defaults to True (the common case).\n target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between\n training and inference.\n padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding)\n to align with nnU-Net's approach. Can be int, float, or string (e.g., 'minimum', 'mean').\n keep_largest_component: If True, keep only the largest connected component\n in binary segmentation predictions. Only applies during inference when\n return_probabilities=False. Defaults to False.\n \n Example:\n >>> config = PatchConfig(\n ... patch_size=[96, 96, 96],\n ... samples_per_volume=16,\n ... sampler_type='label',\n ... label_probabilities={0: 0.1, 1: 0.9},\n ... target_spacing=[0.5, 0.5, 0.5]\n ... )\n \"\"\"\n patch_size: list = field(default_factory=lambda: [96, 96, 96])\n patch_overlap: int | float | list = 0\n samples_per_volume: int = 8\n sampler_type: str = 'uniform'\n label_probabilities: dict = None\n queue_length: int = 300\n queue_num_workers: int = 4\n aggregation_mode: str = 'hann'\n # Preprocessing parameters - must match between training and inference\n apply_reorder: bool = True # Defaults to True (the common case)\n target_spacing: list = None\n padding_mode: int | float | str = 0 # Zero padding (nnU-Net standard)\n # Post-processing (binary segmentation only)\n keep_largest_component: bool = False\n \n def __post_init__(self):\n \"\"\"Validate configuration.\"\"\"\n valid_samplers = ['uniform', 'label', 'weighted']\n if self.sampler_type not in valid_samplers:\n raise ValueError(f\"sampler_type must be one of {valid_samplers}\")\n \n valid_aggregation = ['crop', 'average', 'hann']\n if self.aggregation_mode not in valid_aggregation:\n raise ValueError(f\"aggregation_mode must be one of {valid_aggregation}\")\n \n # Validate patch_overlap\n # Negative overlap doesn't make sense\n if isinstance(self.patch_overlap, (int, float)):\n if self.patch_overlap < 0:\n raise ValueError(\"patch_overlap cannot be negative\")\n # Check if overlap as pixels would exceed patch_size (causes step_size=0)\n if self.patch_overlap >= 1: # Pixel value, not fraction\n for ps in self.patch_size:\n if self.patch_overlap >= ps:\n raise ValueError(\n f\"patch_overlap ({self.patch_overlap}) must be less than patch_size ({ps}). \"\n f\"Overlap >= patch_size creates step_size <= 0 (infinite patches).\"\n )\n elif isinstance(self.patch_overlap, (list, tuple)):\n for i, (overlap, ps) in enumerate(zip(self.patch_overlap, self.patch_size)):\n if overlap < 0:\n raise ValueError(f\"patch_overlap[{i}] cannot be negative\")\n if overlap >= ps:\n raise ValueError(\n f\"patch_overlap[{i}] ({overlap}) must be less than patch_size[{i}] ({ps}). \"\n f\"Overlap >= patch_size creates step_size <= 0 (infinite patches).\"\n )\n\n # Warn if patch_size dimensions are not divisible by 16\n non_div = [s for s in self.patch_size if s % 16 != 0]\n if non_div:\n warnings.warn(\n f\"patch_size {self.patch_size} has dimensions not divisible by 16. \"\n f\"Most encoder-decoder architectures (e.g., U-Net) require patch sizes \"\n f\"divisible by 16 (2^4 for 4 downsampling levels).\"\n )\n\n @classmethod\n def from_dataset(\n cls,\n dataset: 'MedDataset',\n target_spacing: list = None,\n min_patch_size: list = None,\n max_patch_size: list = None,\n divisor: int = 16,\n **kwargs\n ) -> 'PatchConfig':\n \"\"\"Create PatchConfig with automatic patch_size from dataset analysis.\n\n Combines dataset preprocessing suggestions with patch size calculation\n for a complete, DRY configuration.\n\n Args:\n dataset: MedDataset instance with analyzed images.\n target_spacing: Target voxel spacing [x, y, z]. If None, uses\n dataset.get_suggestion()['target_spacing'].\n min_patch_size: Minimum per dimension [32, 32, 32].\n max_patch_size: Maximum per dimension [256, 256, 256].\n divisor: Divisibility constraint (default 16 for UNet compatibility).\n **kwargs: Additional PatchConfig parameters (samples_per_volume,\n sampler_type, label_probabilities, etc.).\n\n Returns:\n PatchConfig with suggested patch_size, apply_reorder, target_spacing.\n\n Example:\n >>> from fastMONAI.dataset_info import MedDataset\n >>> dataset = MedDataset(dataframe=df, mask_col='mask_path', dtype=MedMask)\n >>> \n >>> # Use recommended spacing\n >>> config = PatchConfig.from_dataset(dataset, samples_per_volume=16)\n >>> \n >>> # Use custom spacing\n >>> config = PatchConfig.from_dataset(\n ... dataset,\n ... target_spacing=[1.0, 1.0, 2.0],\n ... samples_per_volume=16\n ... )\n \"\"\"\n # Get preprocessing suggestion from dataset\n suggestion = dataset.get_suggestion()\n\n # Use explicit spacing or dataset suggestion\n _target_spacing = target_spacing if target_spacing is not None else suggestion['target_spacing']\n\n # Calculate patch size for the target spacing\n patch_size = suggest_patch_size(\n dataset,\n target_spacing=_target_spacing,\n min_patch_size=min_patch_size,\n max_patch_size=max_patch_size,\n divisor=divisor\n )\n\n # Merge with explicit kwargs (kwargs override defaults)\n # Use dataset.apply_reorder directly (not from get_suggestion() since it's not data-derived)\n config_kwargs = {\n 'patch_size': patch_size,\n 'apply_reorder': dataset.apply_reorder,\n 'target_spacing': _target_spacing,\n }\n config_kwargs.update(kwargs)\n\n return cls(**config_kwargs)" + "source": "#| export\n@dataclass\nclass PatchConfig:\n \"\"\"Configuration for patch-based training and inference.\n \n Args:\n patch_size: Size of patches [x, y, z].\n patch_overlap: Overlap for inference GridSampler (int, float 0-1, or list).\n - Float 0-1: fraction of patch_size (e.g., 0.5 = 50% overlap)\n - Int >= 1: pixel overlap (e.g., 48 = 48 pixel overlap)\n - List: per-dimension overlap in pixels\n samples_per_volume: Number of patches to extract per volume during training.\n sampler_type: Type of sampler ('uniform', 'label', 'weighted').\n label_probabilities: For LabelSampler, dict mapping label values to probabilities.\n queue_length: Maximum number of patches to store in queue.\n queue_num_workers: Number of workers for parallel patch extraction.\n aggregation_mode: For inference, how to combine overlapping patches ('crop', 'average', 'hann').\n apply_reorder: Whether to reorder to RAS+ canonical orientation. Must match between\n training and inference. Defaults to True (the common case).\n target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between\n training and inference.\n preprocessed: If True, data has been preprocessed externally (e.g., via\n preprocess_dataset()). Training will skip reorder, resample, AND\n pre_patch_tfms (e.g., normalization) since they were already applied.\n Inference is unaffected and always applies pre_inference_tfms to raw\n images. Defaults to False.\n padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding)\n to align with nnU-Net's approach. Can be int, float, or string (e.g., 'minimum', 'mean').\n keep_largest_component: If True, keep only the largest connected component\n in binary segmentation predictions. Only applies during inference when\n return_probabilities=False. Defaults to False.\n \n Example:\n >>> config = PatchConfig(\n ... patch_size=[96, 96, 96],\n ... samples_per_volume=16,\n ... sampler_type='label',\n ... label_probabilities={0: 0.1, 1: 0.9},\n ... target_spacing=[0.5, 0.5, 0.5]\n ... )\n \"\"\"\n patch_size: list = field(default_factory=lambda: [96, 96, 96])\n patch_overlap: int | float | list = 0\n samples_per_volume: int = 8\n sampler_type: str = 'uniform'\n label_probabilities: dict = None\n queue_length: int = 300\n queue_num_workers: int = 4\n aggregation_mode: str = 'hann'\n # Preprocessing parameters - must match between training and inference\n apply_reorder: bool = True # Defaults to True (the common case)\n target_spacing: list = None\n preprocessed: bool = False # True = data already preprocessed, skip all preprocessing during training\n padding_mode: int | float | str = 0 # Zero padding (nnU-Net standard)\n # Post-processing (binary segmentation only)\n keep_largest_component: bool = False\n \n def __post_init__(self):\n \"\"\"Validate configuration.\"\"\"\n valid_samplers = ['uniform', 'label', 'weighted']\n if self.sampler_type not in valid_samplers:\n raise ValueError(f\"sampler_type must be one of {valid_samplers}\")\n \n valid_aggregation = ['crop', 'average', 'hann']\n if self.aggregation_mode not in valid_aggregation:\n raise ValueError(f\"aggregation_mode must be one of {valid_aggregation}\")\n \n # Validate patch_overlap\n # Negative overlap doesn't make sense\n if isinstance(self.patch_overlap, (int, float)):\n if self.patch_overlap < 0:\n raise ValueError(\"patch_overlap cannot be negative\")\n # Check if overlap as pixels would exceed patch_size (causes step_size=0)\n if self.patch_overlap >= 1: # Pixel value, not fraction\n for ps in self.patch_size:\n if self.patch_overlap >= ps:\n raise ValueError(\n f\"patch_overlap ({self.patch_overlap}) must be less than patch_size ({ps}). \"\n f\"Overlap >= patch_size creates step_size <= 0 (infinite patches).\"\n )\n elif isinstance(self.patch_overlap, (list, tuple)):\n for i, (overlap, ps) in enumerate(zip(self.patch_overlap, self.patch_size)):\n if overlap < 0:\n raise ValueError(f\"patch_overlap[{i}] cannot be negative\")\n if overlap >= ps:\n raise ValueError(\n f\"patch_overlap[{i}] ({overlap}) must be less than patch_size[{i}] ({ps}). \"\n f\"Overlap >= patch_size creates step_size <= 0 (infinite patches).\"\n )\n\n # Warn if patch_size dimensions are not divisible by 16\n non_div = [s for s in self.patch_size if s % 16 != 0]\n if non_div:\n warnings.warn(\n f\"patch_size {self.patch_size} has dimensions not divisible by 16. \"\n f\"Most encoder-decoder architectures (e.g., U-Net) require patch sizes \"\n f\"divisible by 16 (2^4 for 4 downsampling levels).\"\n )\n\n @classmethod\n def from_dataset(\n cls,\n dataset: 'MedDataset',\n target_spacing: list = None,\n min_patch_size: list = None,\n max_patch_size: list = None,\n divisor: int = 16,\n **kwargs\n ) -> 'PatchConfig':\n \"\"\"Create PatchConfig with automatic patch_size from dataset analysis.\n\n Combines dataset preprocessing suggestions with patch size calculation\n for a complete, DRY configuration.\n\n Args:\n dataset: MedDataset instance with analyzed images.\n target_spacing: Target voxel spacing [x, y, z]. If None, uses\n dataset.get_suggestion()['target_spacing'].\n min_patch_size: Minimum per dimension [32, 32, 32].\n max_patch_size: Maximum per dimension [256, 256, 256].\n divisor: Divisibility constraint (default 16 for UNet compatibility).\n **kwargs: Additional PatchConfig parameters (samples_per_volume,\n sampler_type, label_probabilities, etc.).\n\n Returns:\n PatchConfig with suggested patch_size, apply_reorder, target_spacing.\n\n Example:\n >>> from fastMONAI.dataset_info import MedDataset\n >>> dataset = MedDataset(dataframe=df, mask_col='mask_path', dtype=MedMask)\n >>> \n >>> # Use recommended spacing\n >>> config = PatchConfig.from_dataset(dataset, samples_per_volume=16)\n >>> \n >>> # Use custom spacing\n >>> config = PatchConfig.from_dataset(\n ... dataset,\n ... target_spacing=[1.0, 1.0, 2.0],\n ... samples_per_volume=16\n ... )\n \"\"\"\n # Get preprocessing suggestion from dataset\n suggestion = dataset.get_suggestion()\n\n # Use explicit spacing or dataset suggestion\n _target_spacing = target_spacing if target_spacing is not None else suggestion['target_spacing']\n\n # Calculate patch size for the target spacing\n patch_size = suggest_patch_size(\n dataset,\n target_spacing=_target_spacing,\n min_patch_size=min_patch_size,\n max_patch_size=max_patch_size,\n divisor=divisor\n )\n\n # Merge with explicit kwargs (kwargs override defaults)\n # Use dataset.apply_reorder directly (not from get_suggestion() since it's not data-derived)\n config_kwargs = {\n 'patch_size': patch_size,\n 'apply_reorder': dataset.apply_reorder,\n 'target_spacing': _target_spacing,\n }\n config_kwargs.update(kwargs)\n\n return cls(**config_kwargs)" }, { "cell_type": "code", @@ -186,26 +186,7 @@ "id": "cell-6", "metadata": {}, "outputs": [], - "source": [ - "# Test PatchConfig\n", - "config = PatchConfig(patch_size=[96, 96, 96], samples_per_volume=16)\n", - "test_eq(config.patch_size, [96, 96, 96])\n", - "test_eq(config.samples_per_volume, 16)\n", - "test_eq(config.sampler_type, 'uniform')\n", - "test_eq(config.apply_reorder, True) # Default is now True (the common case)\n", - "test_eq(config.target_spacing, None)\n", - "test_eq(config.padding_mode, 0)\n", - "\n", - "# Test with preprocessing params\n", - "config2 = PatchConfig(\n", - " patch_size=[64, 64, 64],\n", - " apply_reorder=True,\n", - " target_spacing=[0.5, 0.5, 0.5],\n", - " padding_mode=0\n", - ")\n", - "test_eq(config2.apply_reorder, True)\n", - "test_eq(config2.target_spacing, [0.5, 0.5, 0.5])" - ] + "source": "# Test PatchConfig\nconfig = PatchConfig(patch_size=[96, 96, 96], samples_per_volume=16)\ntest_eq(config.patch_size, [96, 96, 96])\ntest_eq(config.samples_per_volume, 16)\ntest_eq(config.sampler_type, 'uniform')\ntest_eq(config.apply_reorder, True) # Default is now True (the common case)\ntest_eq(config.target_spacing, None)\ntest_eq(config.preprocessed, False) # Default is False\ntest_eq(config.padding_mode, 0)\n\n# Test with preprocessing params\nconfig2 = PatchConfig(\n patch_size=[64, 64, 64],\n apply_reorder=True,\n target_spacing=[0.5, 0.5, 0.5],\n padding_mode=0\n)\ntest_eq(config2.apply_reorder, True)\ntest_eq(config2.target_spacing, [0.5, 0.5, 0.5])\n\n# Test preprocessed=True with actual preprocessing params (no warning)\nconfig3 = PatchConfig(\n patch_size=[96, 96, 96],\n apply_reorder=True,\n target_spacing=[0.5, 0.5, 0.5],\n preprocessed=True\n)\ntest_eq(config3.preprocessed, True)\ntest_eq(config3.apply_reorder, True)\ntest_eq(config3.target_spacing, [0.5, 0.5, 0.5])\n\n# Test preprocessed=True without preprocessing params does NOT warn\n# (preprocessed=True still has effect: skips pre_patch_tfms during training)\nwith warnings.catch_warnings(record=True) as w:\n warnings.simplefilter(\"always\")\n config4 = PatchConfig(\n patch_size=[96, 96, 96],\n apply_reorder=False,\n target_spacing=None,\n preprocessed=True\n )\n preprocessed_warns = [x for x in w if 'preprocessed' in str(x.message).lower()]\n test_eq(len(preprocessed_warns), 0)" }, { "cell_type": "markdown", @@ -624,419 +605,7 @@ "id": "cell-15", "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "class MedPatchDataLoaders:\n", - " \"\"\"fastai-compatible DataLoaders for patch-based training with LAZY loading.\n", - "\n", - " This class provides train and validation DataLoaders that work with\n", - " fastai's Learner for patch-based training on 3D medical images.\n", - "\n", - " Memory-efficient: Volumes are loaded on-demand by Queue workers,\n", - " keeping memory usage constant (~150 MB) regardless of dataset size.\n", - "\n", - " Note: Validation uses the same sampling as training (pseudo Dice).\n", - " For true validation metrics, use PatchInferenceEngine with GridSampler\n", - " for full-volume sliding window inference.\n", - "\n", - " Example:\n", - " >>> import torchio as tio\n", - " >>>\n", - " >>> # New pattern: preprocessing params in config (DRY)\n", - " >>> config = PatchConfig(\n", - " ... patch_size=[96, 96, 96],\n", - " ... apply_reorder=True,\n", - " ... target_spacing=[0.5, 0.5, 0.5]\n", - " ... )\n", - " >>> dls = MedPatchDataLoaders.from_df(\n", - " ... df, img_col='image', mask_col='label',\n", - " ... valid_pct=0.2,\n", - " ... patch_config=config,\n", - " ... pre_patch_tfms=[tio.ZNormalization()],\n", - " ... bs=4\n", - " ... )\n", - " >>> learn = Learner(dls, model, loss_func=DiceLoss())\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " train_dl: MedPatchDataLoader,\n", - " valid_dl: MedPatchDataLoader,\n", - " device: torch.device = None\n", - " ):\n", - " self._train_dl = train_dl\n", - " self._valid_dl = valid_dl\n", - " self._device = device or _get_default_device()\n", - "\n", - " # Move to device\n", - " self._train_dl.to(self._device)\n", - " self._valid_dl.to(self._device)\n", - "\n", - " # Track cleanup state\n", - " self._closed = False\n", - "\n", - " @classmethod\n", - " def from_df(\n", - " cls,\n", - " df: pd.DataFrame,\n", - " img_col: str,\n", - " mask_col: str = None,\n", - " valid_pct: float = 0.2,\n", - " valid_col: str = None,\n", - " patch_config: PatchConfig = None,\n", - " pre_patch_tfms: list = None,\n", - " patch_tfms: list = None,\n", - " gpu_augmentation=None,\n", - " apply_reorder: bool = None,\n", - " target_spacing: list = None,\n", - " bs: int = 4,\n", - " seed: int = None,\n", - " device: torch.device = None,\n", - " ensure_affine_consistency: bool = True\n", - " ) -> 'MedPatchDataLoaders':\n", - " \"\"\"Create train/valid DataLoaders from DataFrame with LAZY loading.\n", - "\n", - " Memory-efficient: Only file paths are stored at creation time.\n", - " Volumes are loaded on-demand by Queue workers during training.\n", - "\n", - " Note: Both train and valid use the same sampling strategy from patch_config.\n", - " This gives pseudo Dice during training. For true validation metrics,\n", - " use PatchInferenceEngine with full-volume sliding window inference.\n", - "\n", - " Args:\n", - " df: DataFrame with image paths.\n", - " img_col: Column name for image paths.\n", - " mask_col: Column name for mask paths.\n", - " valid_pct: Fraction of data for validation.\n", - " valid_col: Column name for train/valid split (if pre-defined).\n", - " patch_config: PatchConfig instance. Preprocessing params (apply_reorder,\n", - " target_spacing) can be set here for DRY usage with PatchInferenceEngine.\n", - " pre_patch_tfms: TorchIO transforms applied before patch extraction\n", - " (after reorder/resample). Example: [tio.ZNormalization()].\n", - " Accepts both fastMONAI wrappers and raw TorchIO transforms.\n", - " patch_tfms: TorchIO transforms applied to extracted patches (training only).\n", - " Mutually exclusive with gpu_augmentation.\n", - " gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation\n", - " (training only). Mutually exclusive with patch_tfms.\n", - " apply_reorder: If True, reorder to RAS+ orientation. If None, uses\n", - " patch_config.apply_reorder. Explicit value overrides config.\n", - " target_spacing: Target voxel spacing [x, y, z]. If None, uses\n", - " patch_config.target_spacing. Explicit value overrides config.\n", - " bs: Batch size.\n", - " seed: Random seed for splitting.\n", - " device: Device to use.\n", - " ensure_affine_consistency: If True and mask_col is provided, automatically\n", - " adds tio.CopyAffine(target='image') as the first transform to prevent\n", - " spatial metadata mismatch errors. Defaults to True.\n", - "\n", - " Returns:\n", - " MedPatchDataLoaders instance.\n", - "\n", - " Example:\n", - " >>> # CPU augmentation path (existing)\n", - " >>> dls = MedPatchDataLoaders.from_df(\n", - " ... df, img_col='image', mask_col='label',\n", - " ... patch_config=config,\n", - " ... patch_tfms=[tio.RandomAffine(degrees=10), tio.RandomFlip()],\n", - " ... bs=4\n", - " ... )\n", - " >>>\n", - " >>> # GPU augmentation path (new, faster for long training runs)\n", - " >>> from fastMONAI.vision_augmentation import gpu_patch_augmentations\n", - " >>> gpu_aug = gpu_patch_augmentations(config.patch_size, config.target_spacing)\n", - " >>> dls = MedPatchDataLoaders.from_df(\n", - " ... df, img_col='image', mask_col='label',\n", - " ... patch_config=config,\n", - " ... gpu_augmentation=gpu_aug,\n", - " ... bs=4\n", - " ... )\n", - " \"\"\"\n", - " # Validate mutual exclusivity\n", - " if gpu_augmentation is not None and patch_tfms is not None:\n", - " raise ValueError(\n", - " \"Cannot use both gpu_augmentation and patch_tfms. \"\n", - " \"gpu_augmentation operates on GPU tensors batch-wise, while \"\n", - " \"patch_tfms uses per-sample CPU TorchIO transforms. Choose one.\"\n", - " )\n", - "\n", - " if patch_config is None:\n", - " patch_config = PatchConfig()\n", - "\n", - " # Use config values, allow explicit overrides for backward compatibility\n", - " _apply_reorder = apply_reorder if apply_reorder is not None else patch_config.apply_reorder\n", - " _target_spacing = target_spacing if target_spacing is not None else patch_config.target_spacing\n", - "\n", - " # Warn if both config and explicit args provided with different values\n", - " _warn_config_override('apply_reorder', patch_config.apply_reorder, apply_reorder)\n", - " _warn_config_override('target_spacing', patch_config.target_spacing, target_spacing)\n", - "\n", - " # Split data\n", - " if valid_col is not None:\n", - " train_df = df[df[valid_col] == False].reset_index(drop=True)\n", - " valid_df = df[df[valid_col] == True].reset_index(drop=True)\n", - " else:\n", - " if seed is not None:\n", - " np.random.seed(seed)\n", - " n = len(df)\n", - " valid_idx = np.random.choice(n, size=int(n * valid_pct), replace=False)\n", - " train_idx = np.setdiff1d(np.arange(n), valid_idx)\n", - " train_df = df.iloc[train_idx].reset_index(drop=True)\n", - " valid_df = df.iloc[valid_idx].reset_index(drop=True)\n", - "\n", - " # Build preprocessing transforms\n", - " all_pre_tfms = []\n", - "\n", - " # Add reorder transform (reorder to RAS+ orientation)\n", - " if _apply_reorder:\n", - " all_pre_tfms.append(tio.ToCanonical())\n", - "\n", - " # Add resample transform\n", - " if _target_spacing is not None:\n", - " all_pre_tfms.append(tio.Resample(_target_spacing))\n", - "\n", - " # Add user-provided transforms (normalize to raw TorchIO transforms)\n", - " if pre_patch_tfms:\n", - " all_pre_tfms.extend(normalize_patch_transforms(pre_patch_tfms))\n", - "\n", - " # Create subjects datasets with lazy loading (paths only, ~0 MB)\n", - " train_subjects = create_subjects_dataset(\n", - " train_df, img_col, mask_col,\n", - " pre_tfms=all_pre_tfms if all_pre_tfms else None,\n", - " ensure_affine_consistency=ensure_affine_consistency\n", - " )\n", - " valid_subjects = create_subjects_dataset(\n", - " valid_df, img_col, mask_col,\n", - " pre_tfms=all_pre_tfms if all_pre_tfms else None,\n", - " ensure_affine_consistency=ensure_affine_consistency\n", - " )\n", - "\n", - " # Create DataLoaders (both use same patch_config for consistent sampling)\n", - " train_dl = MedPatchDataLoader(\n", - " train_subjects, patch_config, bs,\n", - " patch_tfms=patch_tfms,\n", - " gpu_augmentation=gpu_augmentation,\n", - " shuffle=True, drop_last=True\n", - " )\n", - " valid_dl = MedPatchDataLoader(\n", - " valid_subjects, patch_config, bs,\n", - " patch_tfms=None, # No augmentation for validation\n", - " gpu_augmentation=None, # No augmentation for validation\n", - " shuffle=False, drop_last=False\n", - " )\n", - "\n", - " # Create instance and store metadata\n", - " instance = cls(train_dl, valid_dl, device)\n", - " instance._img_col = img_col\n", - " instance._mask_col = mask_col\n", - " instance._pre_patch_tfms = pre_patch_tfms\n", - " instance._apply_reorder = _apply_reorder\n", - " instance._target_spacing = _target_spacing\n", - " instance._ensure_affine_consistency = ensure_affine_consistency\n", - " instance._patch_config = patch_config\n", - " instance._train_source_df = train_df\n", - " instance._valid_source_df = valid_df\n", - " return instance\n", - "\n", - " @property\n", - " def train(self):\n", - " \"\"\"Training DataLoader.\"\"\"\n", - " return self._train_dl\n", - "\n", - " @property\n", - " def valid(self):\n", - " \"\"\"Validation DataLoader.\"\"\"\n", - " return self._valid_dl\n", - "\n", - " @property\n", - " def train_ds(self):\n", - " \"\"\"Training subjects dataset.\"\"\"\n", - " return self._train_dl.subjects_dataset\n", - "\n", - " @property\n", - " def valid_ds(self):\n", - " \"\"\"Validation subjects dataset.\"\"\"\n", - " return self._valid_dl.subjects_dataset\n", - "\n", - " @property\n", - " def device(self):\n", - " \"\"\"Current device.\"\"\"\n", - " return self._device\n", - "\n", - " @property\n", - " def bs(self):\n", - " \"\"\"Batch size.\"\"\"\n", - " return self._train_dl.bs\n", - "\n", - " @property\n", - " def apply_reorder(self):\n", - " \"\"\"Whether reordering to RAS+ is enabled.\"\"\"\n", - " return getattr(self, '_apply_reorder', False)\n", - "\n", - " @property\n", - " def target_spacing(self):\n", - " \"\"\"Target voxel spacing for resampling.\"\"\"\n", - " return getattr(self, '_target_spacing', None)\n", - "\n", - " @property\n", - " def patch_config(self):\n", - " \"\"\"The PatchConfig used for this DataLoaders.\"\"\"\n", - " return getattr(self, '_patch_config', None)\n", - "\n", - " @property\n", - " def split_df(self):\n", - " \"\"\"DataFrame recording train/valid split for reproducibility logging.\"\"\"\n", - " train = self._train_source_df.assign(is_valid=False)\n", - " valid = self._valid_source_df.assign(is_valid=True)\n", - " return pd.concat([train, valid], ignore_index=True)\n", - "\n", - " def to(self, device):\n", - " \"\"\"Move DataLoaders to device.\"\"\"\n", - " self._device = device\n", - " self._train_dl.to(device)\n", - " self._valid_dl.to(device)\n", - " return self\n", - "\n", - " def __iter__(self):\n", - " \"\"\"Iterate over training DataLoader.\"\"\"\n", - " return iter(self._train_dl)\n", - "\n", - " def one_batch(self):\n", - " \"\"\"Return one batch from the training DataLoader.\n", - "\n", - " Required for fastai Learner compatibility - used for device\n", - " detection and batch shape validation.\n", - " \"\"\"\n", - " return self._train_dl.one_batch()\n", - "\n", - " def __len__(self):\n", - " \"\"\"Return number of batches in training DataLoader.\"\"\"\n", - " return len(self._train_dl)\n", - "\n", - " def __getitem__(self, idx):\n", - " \"\"\"Get DataLoader by index. Required for fastai Learner compatibility.\n", - "\n", - " Args:\n", - " idx: 0 for training DataLoader, 1 for validation DataLoader.\n", - "\n", - " Returns:\n", - " MedPatchDataLoader instance.\n", - " \"\"\"\n", - " if idx == 0:\n", - " return self._train_dl\n", - " elif idx == 1:\n", - " return self._valid_dl\n", - " else:\n", - " raise IndexError(f\"Index {idx} out of range. Use 0 (train) or 1 (valid).\")\n", - "\n", - " def cuda(self):\n", - " \"\"\"Move DataLoaders to CUDA device.\"\"\"\n", - " return self.to(torch.device('cuda'))\n", - "\n", - " def cpu(self):\n", - " \"\"\"Move DataLoaders to CPU.\"\"\"\n", - " return self.to(torch.device('cpu'))\n", - "\n", - " def show_batch(self, dl_idx=0, max_n=6, figsize=None, channel=0,\n", - " slice_index=None, anatomical_plane=0, overlay=False,\n", - " voxel_size=None, **kwargs):\n", - " \"\"\"Show a batch of patch samples for visualization.\"\"\"\n", - "\n", - " dl = self[dl_idx]\n", - " x, y = dl.one_batch()\n", - " x = x.cpu()\n", - " if y is not None: y = y.cpu()\n", - "\n", - " nrows = min(x.shape[0], max_n)\n", - " has_mask = y is not None\n", - "\n", - " if overlay and has_mask:\n", - " ncols = x.shape[1]\n", - " else:\n", - " ncols = x.shape[1] + (1 if has_mask else 0)\n", - "\n", - " if figsize is None:\n", - " figsize = (ncols * 3, nrows * 3)\n", - " fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)\n", - " flat_axs = axs.flatten()\n", - "\n", - " imgs, masks_for_overlay, slice_idxs = [], [], []\n", - " for i in range(nrows):\n", - " img = x[i]\n", - " im_channels = [MedImage(c_img[None]) for c_img in img]\n", - "\n", - " if has_mask:\n", - " mask = y[i]\n", - " idx = find_max_slice(mask[0].numpy(), anatomical_plane) if slice_index is None else slice_index\n", - " if overlay:\n", - " masks_for_overlay.extend([MedMask(mask)] * len(im_channels))\n", - " else:\n", - " im_channels.append(MedMask(mask))\n", - " else:\n", - " idx = slice_index\n", - "\n", - " imgs.extend(im_channels)\n", - " slice_idxs.extend([idx] * len(im_channels))\n", - "\n", - " _voxel_size = voxel_size if voxel_size is not None else self.target_spacing\n", - " ctxs = [im.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane,\n", - " voxel_size=_voxel_size)\n", - " for im, ax, idx in zip(imgs, flat_axs, slice_idxs)]\n", - "\n", - " if overlay and has_mask:\n", - " for mask, ax, idx in zip(masks_for_overlay, flat_axs, slice_idxs):\n", - " mask.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane,\n", - " voxel_size=_voxel_size)\n", - "\n", - " plt.tight_layout()\n", - " plt.show()\n", - "\n", - " def new_empty(self):\n", - " \"\"\"Create a new empty version of self for learner export.\n", - "\n", - " Required for fastai Learner.export() compatibility - creates a\n", - " lightweight placeholder that can be pickled without the full dataset.\n", - "\n", - " Returns:\n", - " A minimal MedPatchDataLoaders-like object with no data.\n", - " \"\"\"\n", - " class EmptyMedPatchDataLoaders:\n", - " \"\"\"Minimal placeholder for exported learner.\"\"\"\n", - " def __init__(self, device):\n", - " self._device = device\n", - " @property\n", - " def device(self): return self._device\n", - " def to(self, device):\n", - " self._device = device\n", - " return self\n", - " def cpu(self):\n", - " \"\"\"Move to CPU. Required for load_learner compatibility.\"\"\"\n", - " return self.to(torch.device('cpu'))\n", - "\n", - " return EmptyMedPatchDataLoaders(self._device)\n", - "\n", - " def close(self):\n", - " \"\"\"Shut down all DataLoader workers. Safe to call multiple times.\"\"\"\n", - " if self._closed:\n", - " return\n", - " self._closed = True\n", - " if hasattr(self, '_train_dl') and self._train_dl is not None:\n", - " self._train_dl.close()\n", - " if hasattr(self, '_valid_dl') and self._valid_dl is not None:\n", - " self._valid_dl.close()\n", - "\n", - " def __enter__(self):\n", - " return self\n", - "\n", - " def __exit__(self, exc_type, exc_val, exc_tb):\n", - " self.close()\n", - " return False\n", - "\n", - " def __del__(self):\n", - " try:\n", - " self.close()\n", - " except Exception:\n", - " pass" - ] + "source": "#| export\nclass MedPatchDataLoaders:\n \"\"\"fastai-compatible DataLoaders for patch-based training with LAZY loading.\n\n This class provides train and validation DataLoaders that work with\n fastai's Learner for patch-based training on 3D medical images.\n\n Memory-efficient: Volumes are loaded on-demand by Queue workers,\n keeping memory usage constant (~150 MB) regardless of dataset size.\n\n Note: Validation uses the same sampling as training (pseudo Dice).\n For true validation metrics, use PatchInferenceEngine with GridSampler\n for full-volume sliding window inference.\n\n Example:\n >>> import torchio as tio\n >>>\n >>> # New pattern: preprocessing params in config (DRY)\n >>> config = PatchConfig(\n ... patch_size=[96, 96, 96],\n ... apply_reorder=True,\n ... target_spacing=[0.5, 0.5, 0.5]\n ... )\n >>> dls = MedPatchDataLoaders.from_df(\n ... df, img_col='image', mask_col='label',\n ... valid_pct=0.2,\n ... patch_config=config,\n ... pre_patch_tfms=[tio.ZNormalization()],\n ... bs=4\n ... )\n >>> learn = Learner(dls, model, loss_func=DiceLoss())\n \"\"\"\n\n def __init__(\n self,\n train_dl: MedPatchDataLoader,\n valid_dl: MedPatchDataLoader,\n device: torch.device = None\n ):\n self._train_dl = train_dl\n self._valid_dl = valid_dl\n self._device = device or _get_default_device()\n\n # Move to device\n self._train_dl.to(self._device)\n self._valid_dl.to(self._device)\n\n # Track cleanup state\n self._closed = False\n\n @classmethod\n def from_df(\n cls,\n df: pd.DataFrame,\n img_col: str,\n mask_col: str = None,\n valid_pct: float = 0.2,\n valid_col: str = None,\n patch_config: PatchConfig = None,\n pre_patch_tfms: list = None,\n patch_tfms: list = None,\n gpu_augmentation=None,\n apply_reorder: bool = None,\n target_spacing: list = None,\n bs: int = 4,\n seed: int = None,\n device: torch.device = None,\n ensure_affine_consistency: bool = True\n ) -> 'MedPatchDataLoaders':\n \"\"\"Create train/valid DataLoaders from DataFrame with LAZY loading.\n\n Memory-efficient: Only file paths are stored at creation time.\n Volumes are loaded on-demand by Queue workers during training.\n\n Note: Both train and valid use the same sampling strategy from patch_config.\n This gives pseudo Dice during training. For true validation metrics,\n use PatchInferenceEngine with full-volume sliding window inference.\n\n Args:\n df: DataFrame with image paths.\n img_col: Column name for image paths.\n mask_col: Column name for mask paths.\n valid_pct: Fraction of data for validation.\n valid_col: Column name for train/valid split (if pre-defined).\n patch_config: PatchConfig instance. Preprocessing params (apply_reorder,\n target_spacing) can be set here for DRY usage with PatchInferenceEngine.\n pre_patch_tfms: TorchIO transforms applied before patch extraction\n (after reorder/resample). Example: [tio.ZNormalization()].\n Accepts both fastMONAI wrappers and raw TorchIO transforms.\n Skipped when preprocessed=True (include in preprocess_dataset()\n transforms instead). Still needed for inference via pre_inference_tfms.\n patch_tfms: TorchIO transforms applied to extracted patches (training only).\n Mutually exclusive with gpu_augmentation.\n gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation\n (training only). Mutually exclusive with patch_tfms.\n apply_reorder: If True, reorder to RAS+ orientation. If None, uses\n patch_config.apply_reorder. Explicit value overrides config.\n target_spacing: Target voxel spacing [x, y, z]. If None, uses\n patch_config.target_spacing. Explicit value overrides config.\n bs: Batch size.\n seed: Random seed for splitting.\n device: Device to use.\n ensure_affine_consistency: If True and mask_col is provided, automatically\n adds tio.CopyAffine(target='image') as the first transform to prevent\n spatial metadata mismatch errors. Defaults to True.\n\n Returns:\n MedPatchDataLoaders instance.\n\n Example:\n >>> # CPU augmentation path (existing)\n >>> dls = MedPatchDataLoaders.from_df(\n ... df, img_col='image', mask_col='label',\n ... patch_config=config,\n ... patch_tfms=[tio.RandomAffine(degrees=10), tio.RandomFlip()],\n ... bs=4\n ... )\n >>>\n >>> # GPU augmentation path (new, faster for long training runs)\n >>> from fastMONAI.vision_augmentation import gpu_patch_augmentations\n >>> gpu_aug = gpu_patch_augmentations(config.patch_size, config.target_spacing)\n >>> dls = MedPatchDataLoaders.from_df(\n ... df, img_col='image', mask_col='label',\n ... patch_config=config,\n ... gpu_augmentation=gpu_aug,\n ... bs=4\n ... )\n \"\"\"\n # Validate mutual exclusivity\n if gpu_augmentation is not None and patch_tfms is not None:\n raise ValueError(\n \"Cannot use both gpu_augmentation and patch_tfms. \"\n \"gpu_augmentation operates on GPU tensors batch-wise, while \"\n \"patch_tfms uses per-sample CPU TorchIO transforms. Choose one.\"\n )\n\n if patch_config is None:\n patch_config = PatchConfig()\n\n # Use config values, allow explicit overrides for backward compatibility\n _apply_reorder = apply_reorder if apply_reorder is not None else patch_config.apply_reorder\n _target_spacing = target_spacing if target_spacing is not None else patch_config.target_spacing\n\n # Warn if both config and explicit args provided with different values\n _warn_config_override('apply_reorder', patch_config.apply_reorder, apply_reorder)\n _warn_config_override('target_spacing', patch_config.target_spacing, target_spacing)\n\n # Split data\n if valid_col is not None:\n train_df = df[df[valid_col] == False].reset_index(drop=True)\n valid_df = df[df[valid_col] == True].reset_index(drop=True)\n else:\n if seed is not None:\n np.random.seed(seed)\n n = len(df)\n valid_idx = np.random.choice(n, size=int(n * valid_pct), replace=False)\n train_idx = np.setdiff1d(np.arange(n), valid_idx)\n train_df = df.iloc[train_idx].reset_index(drop=True)\n valid_df = df.iloc[valid_idx].reset_index(drop=True)\n\n # Build preprocessing transforms\n all_pre_tfms = []\n\n # Skip all preprocessing if data was already preprocessed externally\n if not patch_config.preprocessed:\n # Add reorder transform (reorder to RAS+ orientation)\n if _apply_reorder:\n all_pre_tfms.append(tio.ToCanonical())\n\n # Add resample transform\n if _target_spacing is not None:\n all_pre_tfms.append(tio.Resample(_target_spacing))\n\n # Add user-provided transforms (normalize to raw TorchIO transforms)\n if pre_patch_tfms:\n all_pre_tfms.extend(normalize_patch_transforms(pre_patch_tfms))\n\n # Create subjects datasets with lazy loading (paths only, ~0 MB)\n train_subjects = create_subjects_dataset(\n train_df, img_col, mask_col,\n pre_tfms=all_pre_tfms if all_pre_tfms else None,\n ensure_affine_consistency=ensure_affine_consistency\n )\n valid_subjects = create_subjects_dataset(\n valid_df, img_col, mask_col,\n pre_tfms=all_pre_tfms if all_pre_tfms else None,\n ensure_affine_consistency=ensure_affine_consistency\n )\n\n # Create DataLoaders (both use same patch_config for consistent sampling)\n train_dl = MedPatchDataLoader(\n train_subjects, patch_config, bs,\n patch_tfms=patch_tfms,\n gpu_augmentation=gpu_augmentation,\n shuffle=True, drop_last=True\n )\n valid_dl = MedPatchDataLoader(\n valid_subjects, patch_config, bs,\n patch_tfms=None, # No augmentation for validation\n gpu_augmentation=None, # No augmentation for validation\n shuffle=False, drop_last=False\n )\n\n # Create instance and store metadata\n instance = cls(train_dl, valid_dl, device)\n instance._img_col = img_col\n instance._mask_col = mask_col\n instance._pre_patch_tfms = pre_patch_tfms\n instance._apply_reorder = _apply_reorder\n instance._target_spacing = _target_spacing\n instance._ensure_affine_consistency = ensure_affine_consistency\n instance._patch_config = patch_config\n instance._train_source_df = train_df\n instance._valid_source_df = valid_df\n return instance\n\n @property\n def train(self):\n \"\"\"Training DataLoader.\"\"\"\n return self._train_dl\n\n @property\n def valid(self):\n \"\"\"Validation DataLoader.\"\"\"\n return self._valid_dl\n\n @property\n def train_ds(self):\n \"\"\"Training subjects dataset.\"\"\"\n return self._train_dl.subjects_dataset\n\n @property\n def valid_ds(self):\n \"\"\"Validation subjects dataset.\"\"\"\n return self._valid_dl.subjects_dataset\n\n @property\n def device(self):\n \"\"\"Current device.\"\"\"\n return self._device\n\n @property\n def bs(self):\n \"\"\"Batch size.\"\"\"\n return self._train_dl.bs\n\n @property\n def apply_reorder(self):\n \"\"\"Whether reordering to RAS+ is enabled.\"\"\"\n return getattr(self, '_apply_reorder', False)\n\n @property\n def target_spacing(self):\n \"\"\"Target voxel spacing for resampling.\"\"\"\n return getattr(self, '_target_spacing', None)\n\n @property\n def patch_config(self):\n \"\"\"The PatchConfig used for this DataLoaders.\"\"\"\n return getattr(self, '_patch_config', None)\n\n @property\n def split_df(self):\n \"\"\"DataFrame recording train/valid split for reproducibility logging.\"\"\"\n train = self._train_source_df.assign(is_valid=False)\n valid = self._valid_source_df.assign(is_valid=True)\n return pd.concat([train, valid], ignore_index=True)\n\n def to(self, device):\n \"\"\"Move DataLoaders to device.\"\"\"\n self._device = device\n self._train_dl.to(device)\n self._valid_dl.to(device)\n return self\n\n def __iter__(self):\n \"\"\"Iterate over training DataLoader.\"\"\"\n return iter(self._train_dl)\n\n def one_batch(self):\n \"\"\"Return one batch from the training DataLoader.\n\n Required for fastai Learner compatibility - used for device\n detection and batch shape validation.\n \"\"\"\n return self._train_dl.one_batch()\n\n def __len__(self):\n \"\"\"Return number of batches in training DataLoader.\"\"\"\n return len(self._train_dl)\n\n def __getitem__(self, idx):\n \"\"\"Get DataLoader by index. Required for fastai Learner compatibility.\n\n Args:\n idx: 0 for training DataLoader, 1 for validation DataLoader.\n\n Returns:\n MedPatchDataLoader instance.\n \"\"\"\n if idx == 0:\n return self._train_dl\n elif idx == 1:\n return self._valid_dl\n else:\n raise IndexError(f\"Index {idx} out of range. Use 0 (train) or 1 (valid).\")\n\n def cuda(self):\n \"\"\"Move DataLoaders to CUDA device.\"\"\"\n return self.to(torch.device('cuda'))\n\n def cpu(self):\n \"\"\"Move DataLoaders to CPU.\"\"\"\n return self.to(torch.device('cpu'))\n\n def show_batch(self, dl_idx=0, max_n=6, figsize=None, channel=0,\n slice_index=None, anatomical_plane=0, overlay=False,\n voxel_size=None, **kwargs):\n \"\"\"Show a batch of patch samples for visualization.\"\"\"\n\n dl = self[dl_idx]\n x, y = dl.one_batch()\n x = x.cpu()\n if y is not None: y = y.cpu()\n\n nrows = min(x.shape[0], max_n)\n has_mask = y is not None\n\n if overlay and has_mask:\n ncols = x.shape[1]\n else:\n ncols = x.shape[1] + (1 if has_mask else 0)\n\n if figsize is None:\n figsize = (ncols * 3, nrows * 3)\n fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)\n flat_axs = axs.flatten()\n\n imgs, masks_for_overlay, slice_idxs = [], [], []\n for i in range(nrows):\n img = x[i]\n im_channels = [MedImage(c_img[None]) for c_img in img]\n\n if has_mask:\n mask = y[i]\n idx = find_max_slice(mask[0].numpy(), anatomical_plane) if slice_index is None else slice_index\n if overlay:\n masks_for_overlay.extend([MedMask(mask)] * len(im_channels))\n else:\n im_channels.append(MedMask(mask))\n else:\n idx = slice_index\n\n imgs.extend(im_channels)\n slice_idxs.extend([idx] * len(im_channels))\n\n _voxel_size = voxel_size if voxel_size is not None else self.target_spacing\n ctxs = [im.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane,\n voxel_size=_voxel_size)\n for im, ax, idx in zip(imgs, flat_axs, slice_idxs)]\n\n if overlay and has_mask:\n for mask, ax, idx in zip(masks_for_overlay, flat_axs, slice_idxs):\n mask.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane,\n voxel_size=_voxel_size)\n\n plt.tight_layout()\n plt.show()\n\n def new_empty(self):\n \"\"\"Create a new empty version of self for learner export.\n\n Required for fastai Learner.export() compatibility - creates a\n lightweight placeholder that can be pickled without the full dataset.\n\n Returns:\n A minimal MedPatchDataLoaders-like object with no data.\n \"\"\"\n class EmptyMedPatchDataLoaders:\n \"\"\"Minimal placeholder for exported learner.\"\"\"\n def __init__(self, device):\n self._device = device\n @property\n def device(self): return self._device\n def to(self, device):\n self._device = device\n return self\n def cpu(self):\n \"\"\"Move to CPU. Required for load_learner compatibility.\"\"\"\n return self.to(torch.device('cpu'))\n\n return EmptyMedPatchDataLoaders(self._device)\n\n def close(self):\n \"\"\"Shut down all DataLoader workers. Safe to call multiple times.\"\"\"\n if self._closed:\n return\n self._closed = True\n if hasattr(self, '_train_dl') and self._train_dl is not None:\n self._train_dl.close()\n if hasattr(self, '_valid_dl') and self._valid_dl is not None:\n self._valid_dl.close()\n\n def __enter__(self):\n return self\n\n def __exit__(self, exc_type, exc_val, exc_tb):\n self.close()\n return False\n\n def __del__(self):\n try:\n self.close()\n except Exception:\n pass" }, { "cell_type": "code", diff --git a/settings.ini b/settings.ini index 736e347..a481ccf 100644 --- a/settings.ini +++ b/settings.ini @@ -5,7 +5,7 @@ ### Python Library ### lib_name = fastMONAI min_python = 3.10 -version = 0.8.1 +version = 0.8.2 ### OPTIONAL ### requirements = fastai==2.8.6 monai==1.5.2 torchio==0.21.2 xlrd>=1.2.0 scikit-image==0.26.0 imagedata==3.8.14 mlflow==3.9.0 huggingface-hub gdown gradio opencv-python plum-dispatch