diff --git a/fastMONAI/_modidx.py b/fastMONAI/_modidx.py index 67305a8..e9823a0 100644 --- a/fastMONAI/_modidx.py +++ b/fastMONAI/_modidx.py @@ -553,10 +553,18 @@ 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch.PatchInferenceEngine.__init__': ( 'vision_patch.html#patchinferenceengine.__init__', 'fastMONAI/vision_patch.py'), + 'fastMONAI.vision_patch.PatchInferenceEngine._postprocess': ( 'vision_patch.html#patchinferenceengine._postprocess', + 'fastMONAI/vision_patch.py'), + 'fastMONAI.vision_patch.PatchInferenceEngine._prepare_subject': ( 'vision_patch.html#patchinferenceengine._prepare_subject', + 'fastMONAI/vision_patch.py'), + 'fastMONAI.vision_patch.PatchInferenceEngine._run_inference': ( 'vision_patch.html#patchinferenceengine._run_inference', + 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch.PatchInferenceEngine.predict': ( 'vision_patch.html#patchinferenceengine.predict', 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch.PatchInferenceEngine.to': ( 'vision_patch.html#patchinferenceengine.to', 'fastMONAI/vision_patch.py'), + 'fastMONAI.vision_patch._PreparedSubject': ( 'vision_patch.html#_preparedsubject', + 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch._extract_tio_transform': ( 'vision_patch.html#_extract_tio_transform', 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch._get_default_device': ( 'vision_patch.html#_get_default_device', @@ -565,6 +573,8 @@ 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch._predict_patch_tta': ( 'vision_patch.html#_predict_patch_tta', 'fastMONAI/vision_patch.py'), + 'fastMONAI.vision_patch._save_prediction': ( 'vision_patch.html#_save_prediction', + 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch._warn_config_override': ( 'vision_patch.html#_warn_config_override', 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch.create_patch_sampler': ( 'vision_patch.html#create_patch_sampler', diff --git a/fastMONAI/vision_patch.py b/fastMONAI/vision_patch.py index 4e17780..e6d4544 100644 --- a/fastMONAI/vision_patch.py +++ b/fastMONAI/vision_patch.py @@ -1078,6 +1078,18 @@ def _predict_patch_tta(model, patch_input): return (summed_probs / len(_TTA_FLIP_AXES)).cpu() +@dataclass +class _PreparedSubject: + """Intermediate state from image preparation, used for pipelined inference.""" + subject: tio.Subject + org_img: tio.Image + input_img: tio.Image + org_size: tuple + grid_sampler: tio.GridSampler + aggregator: tio.GridAggregator + patch_loader: DataLoader + + class PatchInferenceEngine: """Patch-based inference with automatic volume reconstruction. @@ -1150,30 +1162,19 @@ def __init__( self._device = next(self.model.parameters()).device except StopIteration: self._device = _get_default_device() - - def predict( - self, - img_path: Path | str, - return_probabilities: bool = False, - return_affine: bool = False, - tta: bool = False - ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]: - """Predict on a single volume using patch-based inference. + + def _prepare_subject(self, img_path: Path | str) -> _PreparedSubject: + """Load and preprocess image, create GridSampler/Aggregator/DataLoader. + + Thread-safe: creates only new local objects, reads only immutable self config. Args: img_path: Path to input image. - return_probabilities: If True, return probability map instead of argmax. - return_affine: If True, return (prediction, affine) tuple instead of just prediction. - tta: If True, apply nnU-Net-style mirror test-time augmentation - (8 flip combinations, averaged probabilities). Requires ~8x inference - time but improves prediction quality. Works best when training used - RandomFlip(axes='LRAPIS', p=0.5). Defaults to False. Returns: - Predicted segmentation mask tensor, or tuple (prediction, affine) if return_affine=True. + _PreparedSubject with all intermediate objects needed for inference. """ # Load image - keep org_img and org_size for post-processing - # Note: med_img_reader handles reorder/resample internally, no global state needed org_img, input_img, org_size = med_img_reader( img_path, apply_reorder=self.apply_reorder, target_spacing=self.target_spacing, only_tensor=False ) @@ -1188,11 +1189,10 @@ def predict( subject = self.pre_inference_tfms(subject) # Pad dimensions smaller than patch_size, keep larger dimensions intact - # GridSampler handles large images via overlapping patches img_shape = subject['image'].shape[1:] # Exclude channel dim target_size = [max(s, p) for s, p in zip(img_shape, self.config.patch_size)] - - # Warn if volume needed padding (may cause artifacts if training didn't cover similar sizes) + + # Warn if volume needed padding if any(s < p for s, p in zip(img_shape, self.config.patch_size)): padded_dims = [f"dim{i}: {s}<{p}" for i, (s, p) in enumerate(zip(img_shape, self.config.patch_size)) if s < p] warnings.warn( @@ -1200,66 +1200,82 @@ def predict( f"in {padded_dims}. Padding with mode={self.config.padding_mode}. " "Ensure training data covered similar sizes to avoid artifacts." ) - - # Use padding_mode from config (default: 0 for zero padding, nnU-Net standard) + subject = tio.CropOrPad(target_size, padding_mode=self.config.padding_mode)(subject) # Convert patch_overlap to integer pixel values for TorchIO compatibility patch_overlap = _normalize_patch_overlap(self.config.patch_overlap, self.config.patch_size) - # Create GridSampler grid_sampler = tio.GridSampler( - subject, - patch_size=self.config.patch_size, - patch_overlap=patch_overlap + subject, patch_size=self.config.patch_size, patch_overlap=patch_overlap ) - - # Create GridAggregator aggregator = tio.GridAggregator( - grid_sampler, - overlap_mode=self.config.aggregation_mode + grid_sampler, overlap_mode=self.config.aggregation_mode ) + patch_loader = DataLoader(grid_sampler, batch_size=self.batch_size, num_workers=0) - # Create patch loader - patch_loader = DataLoader( - grid_sampler, - batch_size=self.batch_size, - num_workers=0 + return _PreparedSubject( + subject=subject, org_img=org_img, input_img=input_img, + org_size=org_size, grid_sampler=grid_sampler, + aggregator=aggregator, patch_loader=patch_loader ) - # Predict patches + def _run_inference(self, prepared: _PreparedSubject, tta: bool = False) -> torch.Tensor: + """Run model inference on all patches and aggregate. + + Must run on the main thread (model forward pass). + + Args: + prepared: _PreparedSubject from _prepare_subject(). + tta: If True, apply mirror test-time augmentation. + + Returns: + Raw output tensor from aggregator (probabilities). + """ self.model.eval() - with torch.no_grad(): - for patches_batch in patch_loader: + # inference_mode is slightly faster than no_grad (disables autograd tracking + # and view tracking). Safe here since we don't do in-place ops on outputs. + with torch.inference_mode(): + for patches_batch in prepared.patch_loader: patch_input = patches_batch['image'][tio.DATA].to(self._device) locations = patches_batch[tio.LOCATION] if tta: probs = _predict_patch_tta(self.model, patch_input) else: - # Forward pass - get logits logits = self.model(patch_input) - - # Convert logits to probabilities BEFORE aggregation - # This is critical: softmax is non-linear, so we must aggregate - # probabilities, not logits, to get correct boundary handling n_classes = logits.shape[1] if n_classes == 1: probs = torch.sigmoid(logits) else: - probs = torch.softmax(logits, dim=1) # dim=1 for batch [B, C, H, W, D] - + probs = torch.softmax(logits, dim=1) probs = probs.cpu() - # Add probabilities to aggregator - aggregator.add_batch(probs, locations) + prepared.aggregator.add_batch(probs, locations) + + return prepared.aggregator.get_output_tensor() + + def _postprocess( + self, + output: torch.Tensor, + prepared: _PreparedSubject, + return_probabilities: bool = False + ) -> tuple[torch.Tensor, np.ndarray]: + """Post-process aggregated output: threshold, resize, reorient. + + Always returns (result, affine) tuple. - # Get reconstructed output (now contains probabilities, not logits) - output = aggregator.get_output_tensor() + Args: + output: Raw output tensor from _run_inference(). + prepared: _PreparedSubject with original image metadata. + return_probabilities: If True, keep probability map instead of argmax. + Returns: + Tuple of (prediction tensor, affine matrix). + """ # Convert to prediction mask (only if not returning probabilities) if return_probabilities: - result = output # Keep as float probabilities + result = output else: n_classes = output.shape[0] if n_classes == 1: @@ -1272,43 +1288,62 @@ def predict( from fastMONAI.vision_inference import keep_largest result = keep_largest(result.squeeze(0)).unsqueeze(0) - # Post-processing: resize back to original size and reorient - # This matches the workflow in vision_inference.py - # Wrap result in TorchIO Image for resizing - # Use ScalarImage for probabilities, LabelMap for masks if return_probabilities: - pred_img = tio.ScalarImage(tensor=result.float(), affine=input_img.affine) + pred_img = tio.ScalarImage(tensor=result.float(), affine=prepared.input_img.affine) else: - pred_img = tio.LabelMap(tensor=result.float(), affine=input_img.affine) - + pred_img = tio.LabelMap(tensor=result.float(), affine=prepared.input_img.affine) + # Resize back to original size (before resampling) - pred_img = _do_resize(pred_img, org_size, image_interpolation='nearest') - + pred_img = _do_resize(pred_img, prepared.org_size, image_interpolation='nearest') + # Reorient to original orientation (if reorder was applied) - # Use explicit .cpu() for consistent device handling if self.apply_reorder: reoriented_array = _to_original_orientation( pred_img.as_sitk(), - ('').join(org_img.orientation) + ('').join(prepared.org_img.orientation) ) result = torch.from_numpy(reoriented_array).cpu() - # Only convert to long for masks, not probabilities if not return_probabilities: result = result.long() else: result = pred_img.data.cpu() - # Only convert to long for masks, not probabilities if not return_probabilities: result = result.long() # Use original affine matrix for correct spatial alignment - # org_img.affine is always available from med_img_reader - if not (hasattr(org_img, 'affine') and org_img.affine is not None): + if not (hasattr(prepared.org_img, 'affine') and prepared.org_img.affine is not None): raise RuntimeError( "org_img.affine not available. This should never happen - please report this bug." ) - affine = org_img.affine.copy() + affine = prepared.org_img.affine.copy() + + return result, affine + + def predict( + self, + img_path: Path | str, + return_probabilities: bool = False, + return_affine: bool = False, + tta: bool = False + ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]: + """Predict on a single volume using patch-based inference. + + Args: + img_path: Path to input image. + return_probabilities: If True, return probability map instead of argmax. + return_affine: If True, return (prediction, affine) tuple instead of just prediction. + tta: If True, apply nnU-Net-style mirror test-time augmentation + (8 flip combinations, averaged probabilities). Requires ~8x inference + time but improves prediction quality. Works best when training used + RandomFlip(axes='LRAPIS', p=0.5). Defaults to False. + + Returns: + Predicted segmentation mask tensor, or tuple (prediction, affine) if return_affine=True. + """ + prepared = self._prepare_subject(img_path) + output = self._run_inference(prepared, tta=tta) + result, affine = self._postprocess(output, prepared, return_probabilities) if return_affine: return result, affine @@ -1321,6 +1356,39 @@ def to(self, device): return self # %% ../nbs/10_vision_patch.ipynb #cell-18 +from concurrent.futures import ThreadPoolExecutor + + +def _save_prediction(pred, affine, input_path, save_path, return_probabilities): + """Save a single prediction as NIfTI file. + + Module-level helper (no closure captures) for thread-safe background saving. + + Args: + pred: Prediction tensor. + affine: Affine matrix for spatial alignment. + input_path: Original input file path (for deriving output filename). + save_path: Directory Path to save into. + return_probabilities: If True, save as ScalarImage; else LabelMap. + """ + input_path = Path(input_path) + stem = input_path.stem + if input_path.suffix == '.gz' and stem.endswith('.nii'): + stem = stem[:-4] + out_name = f"{stem}_pred.nii.gz" + elif input_path.suffix == '.nii': + out_name = f"{stem}_pred.nii" + else: + out_name = f"{stem}_pred.nii.gz" + out_path = save_path / out_name + + if return_probabilities: + pred_img = tio.ScalarImage(tensor=pred, affine=affine) + else: + pred_img = tio.LabelMap(tensor=pred, affine=affine) + pred_img.save(out_path) + + def patch_inference( learner, config: PatchConfig, @@ -1332,10 +1400,17 @@ def patch_inference( progress: bool = True, save_dir: str = None, pre_inference_tfms: list = None, - tta: bool = False + tta: bool = False, + prefetch: bool = True ) -> list: """Batch patch-based inference on multiple volumes. - + + When prefetch=True (default), overlaps I/O with compute: while the current + image is being inferred, the next image is loaded and preprocessed in a + background thread, and the previous result is saved in the background. + This eliminates most I/O idle time, especially on GPU where CPU prep and + GPU compute use different hardware. + Args: learner: PyTorch model or fastai Learner. config: PatchConfig with inference settings. Preprocessing params (apply_reorder, @@ -1350,10 +1425,14 @@ def patch_inference( pre_inference_tfms: List of TorchIO transforms to apply before patch extraction. IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]). tta: If True, apply nnU-Net-style mirror TTA (8 flip combinations). - + prefetch: If True (default), overlap I/O with compute using a background + thread for preparation and saving. Holds two subjects in memory + simultaneously (current + next). Set to False for memory-constrained + environments processing very large volumes. + Returns: List of predicted tensors. - + Example: >>> config = PatchConfig( ... patch_size=[96, 96, 96], @@ -1371,53 +1450,74 @@ def patch_inference( # Use config values if not explicitly provided _apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder _target_spacing = target_spacing if target_spacing is not None else config.target_spacing - + engine = PatchInferenceEngine( learner, config, _apply_reorder, _target_spacing, batch_size, pre_inference_tfms ) - + # Create save directory if specified + save_path = None if save_dir is not None: save_path = Path(save_dir) save_path.mkdir(parents=True, exist_ok=True) - + predictions = [] desc = 'Patch inference (TTA)' if tta else 'Patch inference' - iterator = tqdm(file_paths, desc=desc) if progress else file_paths - - for path in iterator: - # Get prediction and affine when saving is needed - if save_dir is not None: - pred, affine = engine.predict(path, return_probabilities, return_affine=True, tta=tta) - else: - pred = engine.predict(path, return_probabilities, tta=tta) - predictions.append(pred) - - # Save prediction if save_dir specified - if save_dir is not None: - input_path = Path(path) - # Create output filename based on input using suffix-based approach - # This handles .nii.gz correctly without corrupting filenames with .nii elsewhere - stem = input_path.stem - if input_path.suffix == '.gz' and stem.endswith('.nii'): - # Handle .nii.gz files: stem is "filename.nii", strip the .nii - stem = stem[:-4] - out_name = f"{stem}_pred.nii.gz" - elif input_path.suffix == '.nii': - # Handle .nii files - out_name = f"{stem}_pred.nii" - else: - # Fallback for other formats - out_name = f"{stem}_pred.nii.gz" - out_path = save_path / out_name - - # affine is guaranteed to be valid from engine.predict() with return_affine=True - # Save as NIfTI using TorchIO with correct type - # Use ScalarImage for probabilities (float), LabelMap for masks (int) - if return_probabilities: - pred_img = tio.ScalarImage(tensor=pred, affine=affine) + n_files = len(file_paths) + + # Pipelined path: overlap I/O with compute + if prefetch and n_files > 1: + pbar = tqdm(total=n_files, desc=desc) if progress else None + with ThreadPoolExecutor(max_workers=1) as pool: + # Kick off preparation of the first image + prefetch_future = pool.submit(engine._prepare_subject, file_paths[0]) + save_future = None + + for i in range(n_files): + # Wait for the prefetched subject + prepared = prefetch_future.result() + + # Start prefetching the next image (if any) + if i + 1 < n_files: + prefetch_future = pool.submit(engine._prepare_subject, file_paths[i + 1]) + + # Run inference on the main thread + output = engine._run_inference(prepared, tta=tta) + result, affine = engine._postprocess(output, prepared, return_probabilities) + predictions.append(result) + + # Wait for previous save to complete before submitting a new one + if save_future is not None: + save_future.result() + + # Submit current save in background + if save_path is not None: + save_future = pool.submit( + _save_prediction, result, affine, file_paths[i], + save_path, return_probabilities + ) + + if pbar is not None: + pbar.update(1) + + # Wait for final save + if save_future is not None: + save_future.result() + + if pbar is not None: + pbar.close() + + # Sequential fallback: single image or prefetch disabled + else: + iterator = tqdm(file_paths, desc=desc) if progress else file_paths + for path in iterator: + if save_dir is not None: + pred, affine = engine.predict(path, return_probabilities, return_affine=True, tta=tta) else: - pred_img = tio.LabelMap(tensor=pred, affine=affine) - pred_img.save(out_path) - + pred = engine.predict(path, return_probabilities, tta=tta) + predictions.append(pred) + + if save_dir is not None: + _save_prediction(pred, affine, path, save_path, return_probabilities) + return predictions diff --git a/nbs/10_vision_patch.ipynb b/nbs/10_vision_patch.ipynb index f7dee84..1111e8b 100644 --- a/nbs/10_vision_patch.ipynb +++ b/nbs/10_vision_patch.ipynb @@ -649,7 +649,7 @@ "id": "cell-17", "metadata": {}, "outputs": [], - "source": "#| export\nimport numbers\n\ndef _normalize_patch_overlap(patch_overlap, patch_size):\n \"\"\"Convert patch_overlap to integer pixel values for TorchIO compatibility.\n\n TorchIO's GridSampler expects patch_overlap as a tuple of even integers.\n This function handles:\n - Fractional overlap (0-1): converted to pixel values based on patch_size\n - Numpy scalar types: converted to native Python types\n - Sequences: converted to tuple of integers\n\n Note: Input validation (negative values, overlap >= patch_size) is handled\n by PatchConfig.__post_init__(). This function focuses on format conversion.\n\n Args:\n patch_overlap: int, float (0-1 for fraction), or sequence\n patch_size: list/tuple of patch dimensions [x, y, z]\n\n Returns:\n Tuple of even integers suitable for TorchIO GridSampler\n \"\"\"\n # Handle scalar fractional overlap (0 < x < 1)\n # Note: excludes 1.0 as 100% overlap creates step_size=0 (infinite patches)\n if isinstance(patch_overlap, (int, float, numbers.Number)) and 0 < float(patch_overlap) < 1:\n # Convert fraction to pixel values, ensure even\n result = []\n for ps in patch_size:\n pixels = int(int(ps) * float(patch_overlap))\n # Ensure even (required by TorchIO)\n if pixels % 2 != 0:\n pixels = pixels - 1 if pixels > 0 else 0\n result.append(pixels)\n return tuple(result)\n\n # Handle scalar integer (including numpy scalars) - values > 1 are pixel counts\n if isinstance(patch_overlap, (int, float, numbers.Number)):\n val = int(patch_overlap)\n # Ensure even\n if val % 2 != 0:\n val = val - 1 if val > 0 else 0\n return tuple(val for _ in patch_size)\n\n # Handle sequences (list, tuple, ndarray)\n result = []\n for val in patch_overlap:\n pixels = int(val)\n if pixels % 2 != 0:\n pixels = pixels - 1 if pixels > 0 else 0\n result.append(pixels)\n return tuple(result)\n\n\n# nnU-Net-style mirror TTA: all 2^3 = 8 flip combinations for 3D.\n# Batch tensor shape: [B, C, D, H, W], spatial dims are 2, 3, 4.\n_TTA_FLIP_AXES = (\n (), # original\n (4,), # flip LR (W)\n (3,), # flip AP (H)\n (2,), # flip IS (D)\n (3, 4), # flip LR+AP\n (2, 4), # flip LR+IS\n (2, 3), # flip AP+IS\n (2, 3, 4), # flip all\n)\n\n\ndef _predict_patch_tta(model, patch_input):\n \"\"\"nnU-Net-style mirror TTA: average probabilities over 8 flip combinations.\n\n Runs 8 forward passes with a running sum for memory efficiency (2x memory,\n not 9x). Each pass: flip input -> forward -> activate -> flip back -> accumulate.\n\n Args:\n model: PyTorch model in eval mode (already on device).\n patch_input: Batch tensor [B, C, D, H, W] already on device.\n\n Returns:\n Averaged probability tensor [B, C, D, H, W] on CPU.\n \"\"\"\n summed_probs = None\n for axes in _TTA_FLIP_AXES:\n flipped = torch.flip(patch_input, list(axes)) if axes else patch_input\n logits = model(flipped)\n n_classes = logits.shape[1]\n probs = torch.sigmoid(logits) if n_classes == 1 else torch.softmax(logits, dim=1)\n if axes:\n probs = torch.flip(probs, list(axes))\n summed_probs = probs if summed_probs is None else summed_probs + probs\n return (summed_probs / len(_TTA_FLIP_AXES)).cpu()\n\n\nclass PatchInferenceEngine:\n \"\"\"Patch-based inference with automatic volume reconstruction.\n \n Uses TorchIO's GridSampler to extract overlapping patches and\n GridAggregator to reconstruct the full volume from predictions.\n \n Args:\n learner: fastai Learner or PyTorch model (nn.Module). When passing a raw\n PyTorch model, load weights first with model.load_state_dict().\n config: PatchConfig with inference settings. Preprocessing params (apply_reorder,\n target_spacing, padding_mode) can be set here for DRY usage.\n apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.\n target_spacing: Target voxel spacing. If None, uses config value.\n batch_size: Number of patches to predict at once. Must be positive.\n pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.\n IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]).\n This ensures preprocessing consistency between training and inference.\n Accepts both fastMONAI wrappers and raw TorchIO transforms.\n \n Example:\n >>> # Option 1: From fastai Learner\n >>> engine = PatchInferenceEngine(learn, config, pre_inference_tfms=[ZNormalization()])\n >>> pred = engine.predict('image.nii.gz')\n \n >>> # Option 2: From raw PyTorch model (recommended for deployment)\n >>> model = UNet(spatial_dims=3, in_channels=1, out_channels=2, ...)\n >>> model.load_state_dict(torch.load('final_weights.pth'))\n >>> model.cuda().eval()\n >>> engine = PatchInferenceEngine(model, config, pre_inference_tfms=[ZNormalization()])\n >>> pred = engine.predict('image.nii.gz')\n \"\"\"\n \n def __init__(\n self,\n learner,\n config: PatchConfig,\n apply_reorder: bool = None,\n target_spacing: list = None,\n batch_size: int = 4,\n pre_inference_tfms: list = None\n ):\n if batch_size <= 0:\n raise ValueError(f\"batch_size must be positive, got {batch_size}\")\n \n # Extract model from Learner if needed (use isinstance for robust detection)\n # Note: We check for Learner explicitly because some models (e.g., MONAI UNet)\n # have a .model attribute that is NOT the full model but an internal Sequential.\n if isinstance(learner, Learner):\n self.model = learner.model\n else:\n self.model = learner # Assume it's already a PyTorch model\n \n self.config = config\n self.batch_size = batch_size\n \n # Normalize transforms to raw TorchIO (accepts both fastMONAI wrappers and raw TorchIO)\n normalized_tfms = normalize_patch_transforms(pre_inference_tfms)\n self.pre_inference_tfms = tio.Compose(normalized_tfms) if normalized_tfms else None\n \n # Use config values, allow explicit overrides for backward compatibility\n self.apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder\n self.target_spacing = target_spacing if target_spacing is not None else config.target_spacing\n \n # Warn if explicit args provided but differ from config (potential mistake)\n _warn_config_override('apply_reorder', config.apply_reorder, apply_reorder)\n _warn_config_override('target_spacing', config.target_spacing, target_spacing)\n \n # Get device from model parameters, with fallback for parameter-less models\n try:\n self._device = next(self.model.parameters()).device\n except StopIteration:\n self._device = _get_default_device()\n \n def predict(\n self,\n img_path: Path | str,\n return_probabilities: bool = False,\n return_affine: bool = False,\n tta: bool = False\n ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:\n \"\"\"Predict on a single volume using patch-based inference.\n\n Args:\n img_path: Path to input image.\n return_probabilities: If True, return probability map instead of argmax.\n return_affine: If True, return (prediction, affine) tuple instead of just prediction.\n tta: If True, apply nnU-Net-style mirror test-time augmentation\n (8 flip combinations, averaged probabilities). Requires ~8x inference\n time but improves prediction quality. Works best when training used\n RandomFlip(axes='LRAPIS', p=0.5). Defaults to False.\n\n Returns:\n Predicted segmentation mask tensor, or tuple (prediction, affine) if return_affine=True.\n \"\"\"\n # Load image - keep org_img and org_size for post-processing\n # Note: med_img_reader handles reorder/resample internally, no global state needed\n org_img, input_img, org_size = med_img_reader(\n img_path, apply_reorder=self.apply_reorder, target_spacing=self.target_spacing, only_tensor=False\n )\n\n # Create TorchIO Subject from preprocessed image\n subject = tio.Subject(\n image=tio.ScalarImage(tensor=input_img.data.float(), affine=input_img.affine)\n )\n\n # Apply pre-inference transforms (e.g., ZNormalization) to match training\n if self.pre_inference_tfms is not None:\n subject = self.pre_inference_tfms(subject)\n\n # Pad dimensions smaller than patch_size, keep larger dimensions intact\n # GridSampler handles large images via overlapping patches\n img_shape = subject['image'].shape[1:] # Exclude channel dim\n target_size = [max(s, p) for s, p in zip(img_shape, self.config.patch_size)]\n \n # Warn if volume needed padding (may cause artifacts if training didn't cover similar sizes)\n if any(s < p for s, p in zip(img_shape, self.config.patch_size)):\n padded_dims = [f\"dim{i}: {s}<{p}\" for i, (s, p) in enumerate(zip(img_shape, self.config.patch_size)) if s < p]\n warnings.warn(\n f\"Image size {list(img_shape)} smaller than patch_size {self.config.patch_size} \"\n f\"in {padded_dims}. Padding with mode={self.config.padding_mode}. \"\n \"Ensure training data covered similar sizes to avoid artifacts.\"\n )\n \n # Use padding_mode from config (default: 0 for zero padding, nnU-Net standard)\n subject = tio.CropOrPad(target_size, padding_mode=self.config.padding_mode)(subject)\n\n # Convert patch_overlap to integer pixel values for TorchIO compatibility\n patch_overlap = _normalize_patch_overlap(self.config.patch_overlap, self.config.patch_size)\n\n # Create GridSampler\n grid_sampler = tio.GridSampler(\n subject,\n patch_size=self.config.patch_size,\n patch_overlap=patch_overlap\n )\n\n # Create GridAggregator\n aggregator = tio.GridAggregator(\n grid_sampler,\n overlap_mode=self.config.aggregation_mode\n )\n\n # Create patch loader\n patch_loader = DataLoader(\n grid_sampler,\n batch_size=self.batch_size,\n num_workers=0\n )\n\n # Predict patches\n self.model.eval()\n with torch.no_grad():\n for patches_batch in patch_loader:\n patch_input = patches_batch['image'][tio.DATA].to(self._device)\n locations = patches_batch[tio.LOCATION]\n\n if tta:\n probs = _predict_patch_tta(self.model, patch_input)\n else:\n # Forward pass - get logits\n logits = self.model(patch_input)\n\n # Convert logits to probabilities BEFORE aggregation\n # This is critical: softmax is non-linear, so we must aggregate\n # probabilities, not logits, to get correct boundary handling\n n_classes = logits.shape[1]\n if n_classes == 1:\n probs = torch.sigmoid(logits)\n else:\n probs = torch.softmax(logits, dim=1) # dim=1 for batch [B, C, H, W, D]\n\n probs = probs.cpu()\n\n # Add probabilities to aggregator\n aggregator.add_batch(probs, locations)\n\n # Get reconstructed output (now contains probabilities, not logits)\n output = aggregator.get_output_tensor()\n\n # Convert to prediction mask (only if not returning probabilities)\n if return_probabilities:\n result = output # Keep as float probabilities\n else:\n n_classes = output.shape[0]\n if n_classes == 1:\n result = (output > 0.5).float()\n else:\n result = output.argmax(dim=0, keepdim=True).float()\n\n # Apply keep_largest post-processing for binary segmentation\n if not return_probabilities and self.config.keep_largest_component:\n from fastMONAI.vision_inference import keep_largest\n result = keep_largest(result.squeeze(0)).unsqueeze(0)\n\n # Post-processing: resize back to original size and reorient\n # This matches the workflow in vision_inference.py\n \n # Wrap result in TorchIO Image for resizing\n # Use ScalarImage for probabilities, LabelMap for masks\n if return_probabilities:\n pred_img = tio.ScalarImage(tensor=result.float(), affine=input_img.affine)\n else:\n pred_img = tio.LabelMap(tensor=result.float(), affine=input_img.affine)\n \n # Resize back to original size (before resampling)\n pred_img = _do_resize(pred_img, org_size, image_interpolation='nearest')\n \n # Reorient to original orientation (if reorder was applied)\n # Use explicit .cpu() for consistent device handling\n if self.apply_reorder:\n reoriented_array = _to_original_orientation(\n pred_img.as_sitk(),\n ('').join(org_img.orientation)\n )\n result = torch.from_numpy(reoriented_array).cpu()\n # Only convert to long for masks, not probabilities\n if not return_probabilities:\n result = result.long()\n else:\n result = pred_img.data.cpu()\n # Only convert to long for masks, not probabilities\n if not return_probabilities:\n result = result.long()\n\n # Use original affine matrix for correct spatial alignment\n # org_img.affine is always available from med_img_reader\n if not (hasattr(org_img, 'affine') and org_img.affine is not None):\n raise RuntimeError(\n \"org_img.affine not available. This should never happen - please report this bug.\"\n )\n affine = org_img.affine.copy()\n\n if return_affine:\n return result, affine\n return result\n \n def to(self, device):\n \"\"\"Move engine to device.\"\"\"\n self._device = device\n self.model.to(device)\n return self" + "source": "#| export\nimport numbers\n\ndef _normalize_patch_overlap(patch_overlap, patch_size):\n \"\"\"Convert patch_overlap to integer pixel values for TorchIO compatibility.\n\n TorchIO's GridSampler expects patch_overlap as a tuple of even integers.\n This function handles:\n - Fractional overlap (0-1): converted to pixel values based on patch_size\n - Numpy scalar types: converted to native Python types\n - Sequences: converted to tuple of integers\n\n Note: Input validation (negative values, overlap >= patch_size) is handled\n by PatchConfig.__post_init__(). This function focuses on format conversion.\n\n Args:\n patch_overlap: int, float (0-1 for fraction), or sequence\n patch_size: list/tuple of patch dimensions [x, y, z]\n\n Returns:\n Tuple of even integers suitable for TorchIO GridSampler\n \"\"\"\n # Handle scalar fractional overlap (0 < x < 1)\n # Note: excludes 1.0 as 100% overlap creates step_size=0 (infinite patches)\n if isinstance(patch_overlap, (int, float, numbers.Number)) and 0 < float(patch_overlap) < 1:\n # Convert fraction to pixel values, ensure even\n result = []\n for ps in patch_size:\n pixels = int(int(ps) * float(patch_overlap))\n # Ensure even (required by TorchIO)\n if pixels % 2 != 0:\n pixels = pixels - 1 if pixels > 0 else 0\n result.append(pixels)\n return tuple(result)\n\n # Handle scalar integer (including numpy scalars) - values > 1 are pixel counts\n if isinstance(patch_overlap, (int, float, numbers.Number)):\n val = int(patch_overlap)\n # Ensure even\n if val % 2 != 0:\n val = val - 1 if val > 0 else 0\n return tuple(val for _ in patch_size)\n\n # Handle sequences (list, tuple, ndarray)\n result = []\n for val in patch_overlap:\n pixels = int(val)\n if pixels % 2 != 0:\n pixels = pixels - 1 if pixels > 0 else 0\n result.append(pixels)\n return tuple(result)\n\n\n# nnU-Net-style mirror TTA: all 2^3 = 8 flip combinations for 3D.\n# Batch tensor shape: [B, C, D, H, W], spatial dims are 2, 3, 4.\n_TTA_FLIP_AXES = (\n (), # original\n (4,), # flip LR (W)\n (3,), # flip AP (H)\n (2,), # flip IS (D)\n (3, 4), # flip LR+AP\n (2, 4), # flip LR+IS\n (2, 3), # flip AP+IS\n (2, 3, 4), # flip all\n)\n\n\ndef _predict_patch_tta(model, patch_input):\n \"\"\"nnU-Net-style mirror TTA: average probabilities over 8 flip combinations.\n\n Runs 8 forward passes with a running sum for memory efficiency (2x memory,\n not 9x). Each pass: flip input -> forward -> activate -> flip back -> accumulate.\n\n Args:\n model: PyTorch model in eval mode (already on device).\n patch_input: Batch tensor [B, C, D, H, W] already on device.\n\n Returns:\n Averaged probability tensor [B, C, D, H, W] on CPU.\n \"\"\"\n summed_probs = None\n for axes in _TTA_FLIP_AXES:\n flipped = torch.flip(patch_input, list(axes)) if axes else patch_input\n logits = model(flipped)\n n_classes = logits.shape[1]\n probs = torch.sigmoid(logits) if n_classes == 1 else torch.softmax(logits, dim=1)\n if axes:\n probs = torch.flip(probs, list(axes))\n summed_probs = probs if summed_probs is None else summed_probs + probs\n return (summed_probs / len(_TTA_FLIP_AXES)).cpu()\n\n\n@dataclass\nclass _PreparedSubject:\n \"\"\"Intermediate state from image preparation, used for pipelined inference.\"\"\"\n subject: tio.Subject\n org_img: tio.Image\n input_img: tio.Image\n org_size: tuple\n grid_sampler: tio.GridSampler\n aggregator: tio.GridAggregator\n patch_loader: DataLoader\n\n\nclass PatchInferenceEngine:\n \"\"\"Patch-based inference with automatic volume reconstruction.\n \n Uses TorchIO's GridSampler to extract overlapping patches and\n GridAggregator to reconstruct the full volume from predictions.\n \n Args:\n learner: fastai Learner or PyTorch model (nn.Module). When passing a raw\n PyTorch model, load weights first with model.load_state_dict().\n config: PatchConfig with inference settings. Preprocessing params (apply_reorder,\n target_spacing, padding_mode) can be set here for DRY usage.\n apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.\n target_spacing: Target voxel spacing. If None, uses config value.\n batch_size: Number of patches to predict at once. Must be positive.\n pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.\n IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]).\n This ensures preprocessing consistency between training and inference.\n Accepts both fastMONAI wrappers and raw TorchIO transforms.\n \n Example:\n >>> # Option 1: From fastai Learner\n >>> engine = PatchInferenceEngine(learn, config, pre_inference_tfms=[ZNormalization()])\n >>> pred = engine.predict('image.nii.gz')\n \n >>> # Option 2: From raw PyTorch model (recommended for deployment)\n >>> model = UNet(spatial_dims=3, in_channels=1, out_channels=2, ...)\n >>> model.load_state_dict(torch.load('final_weights.pth'))\n >>> model.cuda().eval()\n >>> engine = PatchInferenceEngine(model, config, pre_inference_tfms=[ZNormalization()])\n >>> pred = engine.predict('image.nii.gz')\n \"\"\"\n \n def __init__(\n self,\n learner,\n config: PatchConfig,\n apply_reorder: bool = None,\n target_spacing: list = None,\n batch_size: int = 4,\n pre_inference_tfms: list = None\n ):\n if batch_size <= 0:\n raise ValueError(f\"batch_size must be positive, got {batch_size}\")\n \n # Extract model from Learner if needed (use isinstance for robust detection)\n # Note: We check for Learner explicitly because some models (e.g., MONAI UNet)\n # have a .model attribute that is NOT the full model but an internal Sequential.\n if isinstance(learner, Learner):\n self.model = learner.model\n else:\n self.model = learner # Assume it's already a PyTorch model\n \n self.config = config\n self.batch_size = batch_size\n \n # Normalize transforms to raw TorchIO (accepts both fastMONAI wrappers and raw TorchIO)\n normalized_tfms = normalize_patch_transforms(pre_inference_tfms)\n self.pre_inference_tfms = tio.Compose(normalized_tfms) if normalized_tfms else None\n \n # Use config values, allow explicit overrides for backward compatibility\n self.apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder\n self.target_spacing = target_spacing if target_spacing is not None else config.target_spacing\n \n # Warn if explicit args provided but differ from config (potential mistake)\n _warn_config_override('apply_reorder', config.apply_reorder, apply_reorder)\n _warn_config_override('target_spacing', config.target_spacing, target_spacing)\n \n # Get device from model parameters, with fallback for parameter-less models\n try:\n self._device = next(self.model.parameters()).device\n except StopIteration:\n self._device = _get_default_device()\n\n def _prepare_subject(self, img_path: Path | str) -> _PreparedSubject:\n \"\"\"Load and preprocess image, create GridSampler/Aggregator/DataLoader.\n\n Thread-safe: creates only new local objects, reads only immutable self config.\n\n Args:\n img_path: Path to input image.\n\n Returns:\n _PreparedSubject with all intermediate objects needed for inference.\n \"\"\"\n # Load image - keep org_img and org_size for post-processing\n org_img, input_img, org_size = med_img_reader(\n img_path, apply_reorder=self.apply_reorder, target_spacing=self.target_spacing, only_tensor=False\n )\n\n # Create TorchIO Subject from preprocessed image\n subject = tio.Subject(\n image=tio.ScalarImage(tensor=input_img.data.float(), affine=input_img.affine)\n )\n\n # Apply pre-inference transforms (e.g., ZNormalization) to match training\n if self.pre_inference_tfms is not None:\n subject = self.pre_inference_tfms(subject)\n\n # Pad dimensions smaller than patch_size, keep larger dimensions intact\n img_shape = subject['image'].shape[1:] # Exclude channel dim\n target_size = [max(s, p) for s, p in zip(img_shape, self.config.patch_size)]\n\n # Warn if volume needed padding\n if any(s < p for s, p in zip(img_shape, self.config.patch_size)):\n padded_dims = [f\"dim{i}: {s}<{p}\" for i, (s, p) in enumerate(zip(img_shape, self.config.patch_size)) if s < p]\n warnings.warn(\n f\"Image size {list(img_shape)} smaller than patch_size {self.config.patch_size} \"\n f\"in {padded_dims}. Padding with mode={self.config.padding_mode}. \"\n \"Ensure training data covered similar sizes to avoid artifacts.\"\n )\n\n subject = tio.CropOrPad(target_size, padding_mode=self.config.padding_mode)(subject)\n\n # Convert patch_overlap to integer pixel values for TorchIO compatibility\n patch_overlap = _normalize_patch_overlap(self.config.patch_overlap, self.config.patch_size)\n\n grid_sampler = tio.GridSampler(\n subject, patch_size=self.config.patch_size, patch_overlap=patch_overlap\n )\n aggregator = tio.GridAggregator(\n grid_sampler, overlap_mode=self.config.aggregation_mode\n )\n patch_loader = DataLoader(grid_sampler, batch_size=self.batch_size, num_workers=0)\n\n return _PreparedSubject(\n subject=subject, org_img=org_img, input_img=input_img,\n org_size=org_size, grid_sampler=grid_sampler,\n aggregator=aggregator, patch_loader=patch_loader\n )\n\n def _run_inference(self, prepared: _PreparedSubject, tta: bool = False) -> torch.Tensor:\n \"\"\"Run model inference on all patches and aggregate.\n\n Must run on the main thread (model forward pass).\n\n Args:\n prepared: _PreparedSubject from _prepare_subject().\n tta: If True, apply mirror test-time augmentation.\n\n Returns:\n Raw output tensor from aggregator (probabilities).\n \"\"\"\n self.model.eval()\n # inference_mode is slightly faster than no_grad (disables autograd tracking\n # and view tracking). Safe here since we don't do in-place ops on outputs.\n with torch.inference_mode():\n for patches_batch in prepared.patch_loader:\n patch_input = patches_batch['image'][tio.DATA].to(self._device)\n locations = patches_batch[tio.LOCATION]\n\n if tta:\n probs = _predict_patch_tta(self.model, patch_input)\n else:\n logits = self.model(patch_input)\n n_classes = logits.shape[1]\n if n_classes == 1:\n probs = torch.sigmoid(logits)\n else:\n probs = torch.softmax(logits, dim=1)\n probs = probs.cpu()\n\n prepared.aggregator.add_batch(probs, locations)\n\n return prepared.aggregator.get_output_tensor()\n\n def _postprocess(\n self,\n output: torch.Tensor,\n prepared: _PreparedSubject,\n return_probabilities: bool = False\n ) -> tuple[torch.Tensor, np.ndarray]:\n \"\"\"Post-process aggregated output: threshold, resize, reorient.\n\n Always returns (result, affine) tuple.\n\n Args:\n output: Raw output tensor from _run_inference().\n prepared: _PreparedSubject with original image metadata.\n return_probabilities: If True, keep probability map instead of argmax.\n\n Returns:\n Tuple of (prediction tensor, affine matrix).\n \"\"\"\n # Convert to prediction mask (only if not returning probabilities)\n if return_probabilities:\n result = output\n else:\n n_classes = output.shape[0]\n if n_classes == 1:\n result = (output > 0.5).float()\n else:\n result = output.argmax(dim=0, keepdim=True).float()\n\n # Apply keep_largest post-processing for binary segmentation\n if not return_probabilities and self.config.keep_largest_component:\n from fastMONAI.vision_inference import keep_largest\n result = keep_largest(result.squeeze(0)).unsqueeze(0)\n\n # Wrap result in TorchIO Image for resizing\n if return_probabilities:\n pred_img = tio.ScalarImage(tensor=result.float(), affine=prepared.input_img.affine)\n else:\n pred_img = tio.LabelMap(tensor=result.float(), affine=prepared.input_img.affine)\n\n # Resize back to original size (before resampling)\n pred_img = _do_resize(pred_img, prepared.org_size, image_interpolation='nearest')\n\n # Reorient to original orientation (if reorder was applied)\n if self.apply_reorder:\n reoriented_array = _to_original_orientation(\n pred_img.as_sitk(),\n ('').join(prepared.org_img.orientation)\n )\n result = torch.from_numpy(reoriented_array).cpu()\n if not return_probabilities:\n result = result.long()\n else:\n result = pred_img.data.cpu()\n if not return_probabilities:\n result = result.long()\n\n # Use original affine matrix for correct spatial alignment\n if not (hasattr(prepared.org_img, 'affine') and prepared.org_img.affine is not None):\n raise RuntimeError(\n \"org_img.affine not available. This should never happen - please report this bug.\"\n )\n affine = prepared.org_img.affine.copy()\n\n return result, affine\n\n def predict(\n self,\n img_path: Path | str,\n return_probabilities: bool = False,\n return_affine: bool = False,\n tta: bool = False\n ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:\n \"\"\"Predict on a single volume using patch-based inference.\n\n Args:\n img_path: Path to input image.\n return_probabilities: If True, return probability map instead of argmax.\n return_affine: If True, return (prediction, affine) tuple instead of just prediction.\n tta: If True, apply nnU-Net-style mirror test-time augmentation\n (8 flip combinations, averaged probabilities). Requires ~8x inference\n time but improves prediction quality. Works best when training used\n RandomFlip(axes='LRAPIS', p=0.5). Defaults to False.\n\n Returns:\n Predicted segmentation mask tensor, or tuple (prediction, affine) if return_affine=True.\n \"\"\"\n prepared = self._prepare_subject(img_path)\n output = self._run_inference(prepared, tta=tta)\n result, affine = self._postprocess(output, prepared, return_probabilities)\n\n if return_affine:\n return result, affine\n return result\n \n def to(self, device):\n \"\"\"Move engine to device.\"\"\"\n self._device = device\n self.model.to(device)\n return self" }, { "cell_type": "code", @@ -688,7 +688,7 @@ "id": "cell-18", "metadata": {}, "outputs": [], - "source": "#| export\ndef patch_inference(\n learner,\n config: PatchConfig,\n file_paths: list,\n apply_reorder: bool = None,\n target_spacing: list = None,\n batch_size: int = 4,\n return_probabilities: bool = False,\n progress: bool = True,\n save_dir: str = None,\n pre_inference_tfms: list = None,\n tta: bool = False\n) -> list:\n \"\"\"Batch patch-based inference on multiple volumes.\n \n Args:\n learner: PyTorch model or fastai Learner.\n config: PatchConfig with inference settings. Preprocessing params (apply_reorder,\n target_spacing) can be set here for DRY usage.\n file_paths: List of image paths.\n apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.\n target_spacing: Target voxel spacing. If None, uses config value.\n batch_size: Patches per batch.\n return_probabilities: Return probability maps.\n progress: Show progress bar.\n save_dir: Directory to save predictions as NIfTI files. If None, predictions are not saved.\n pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.\n IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]).\n tta: If True, apply nnU-Net-style mirror TTA (8 flip combinations).\n \n Returns:\n List of predicted tensors.\n \n Example:\n >>> config = PatchConfig(\n ... patch_size=[96, 96, 96],\n ... apply_reorder=True,\n ... target_spacing=[0.4102, 0.4102, 1.5]\n ... )\n >>> predictions = patch_inference(\n ... learner=learn,\n ... config=config, # apply_reorder and target_spacing from config\n ... file_paths=val_paths,\n ... pre_inference_tfms=[tio.ZNormalization()],\n ... save_dir='predictions/patch_based'\n ... )\n \"\"\"\n # Use config values if not explicitly provided\n _apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder\n _target_spacing = target_spacing if target_spacing is not None else config.target_spacing\n \n engine = PatchInferenceEngine(\n learner, config, _apply_reorder, _target_spacing, batch_size, pre_inference_tfms\n )\n \n # Create save directory if specified\n if save_dir is not None:\n save_path = Path(save_dir)\n save_path.mkdir(parents=True, exist_ok=True)\n \n predictions = []\n desc = 'Patch inference (TTA)' if tta else 'Patch inference'\n iterator = tqdm(file_paths, desc=desc) if progress else file_paths\n \n for path in iterator:\n # Get prediction and affine when saving is needed\n if save_dir is not None:\n pred, affine = engine.predict(path, return_probabilities, return_affine=True, tta=tta)\n else:\n pred = engine.predict(path, return_probabilities, tta=tta)\n predictions.append(pred)\n \n # Save prediction if save_dir specified\n if save_dir is not None:\n input_path = Path(path)\n # Create output filename based on input using suffix-based approach\n # This handles .nii.gz correctly without corrupting filenames with .nii elsewhere\n stem = input_path.stem\n if input_path.suffix == '.gz' and stem.endswith('.nii'):\n # Handle .nii.gz files: stem is \"filename.nii\", strip the .nii\n stem = stem[:-4]\n out_name = f\"{stem}_pred.nii.gz\"\n elif input_path.suffix == '.nii':\n # Handle .nii files\n out_name = f\"{stem}_pred.nii\"\n else:\n # Fallback for other formats\n out_name = f\"{stem}_pred.nii.gz\"\n out_path = save_path / out_name\n \n # affine is guaranteed to be valid from engine.predict() with return_affine=True\n # Save as NIfTI using TorchIO with correct type\n # Use ScalarImage for probabilities (float), LabelMap for masks (int)\n if return_probabilities:\n pred_img = tio.ScalarImage(tensor=pred, affine=affine)\n else:\n pred_img = tio.LabelMap(tensor=pred, affine=affine)\n pred_img.save(out_path)\n \n return predictions" + "source": "#| export\nfrom concurrent.futures import ThreadPoolExecutor\n\n\ndef _save_prediction(pred, affine, input_path, save_path, return_probabilities):\n \"\"\"Save a single prediction as NIfTI file.\n\n Module-level helper (no closure captures) for thread-safe background saving.\n\n Args:\n pred: Prediction tensor.\n affine: Affine matrix for spatial alignment.\n input_path: Original input file path (for deriving output filename).\n save_path: Directory Path to save into.\n return_probabilities: If True, save as ScalarImage; else LabelMap.\n \"\"\"\n input_path = Path(input_path)\n stem = input_path.stem\n if input_path.suffix == '.gz' and stem.endswith('.nii'):\n stem = stem[:-4]\n out_name = f\"{stem}_pred.nii.gz\"\n elif input_path.suffix == '.nii':\n out_name = f\"{stem}_pred.nii\"\n else:\n out_name = f\"{stem}_pred.nii.gz\"\n out_path = save_path / out_name\n\n if return_probabilities:\n pred_img = tio.ScalarImage(tensor=pred, affine=affine)\n else:\n pred_img = tio.LabelMap(tensor=pred, affine=affine)\n pred_img.save(out_path)\n\n\ndef patch_inference(\n learner,\n config: PatchConfig,\n file_paths: list,\n apply_reorder: bool = None,\n target_spacing: list = None,\n batch_size: int = 4,\n return_probabilities: bool = False,\n progress: bool = True,\n save_dir: str = None,\n pre_inference_tfms: list = None,\n tta: bool = False,\n prefetch: bool = True\n) -> list:\n \"\"\"Batch patch-based inference on multiple volumes.\n\n When prefetch=True (default), overlaps I/O with compute: while the current\n image is being inferred, the next image is loaded and preprocessed in a\n background thread, and the previous result is saved in the background.\n This eliminates most I/O idle time, especially on GPU where CPU prep and\n GPU compute use different hardware.\n\n Args:\n learner: PyTorch model or fastai Learner.\n config: PatchConfig with inference settings. Preprocessing params (apply_reorder,\n target_spacing) can be set here for DRY usage.\n file_paths: List of image paths.\n apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.\n target_spacing: Target voxel spacing. If None, uses config value.\n batch_size: Patches per batch.\n return_probabilities: Return probability maps.\n progress: Show progress bar.\n save_dir: Directory to save predictions as NIfTI files. If None, predictions are not saved.\n pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.\n IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]).\n tta: If True, apply nnU-Net-style mirror TTA (8 flip combinations).\n prefetch: If True (default), overlap I/O with compute using a background\n thread for preparation and saving. Holds two subjects in memory\n simultaneously (current + next). Set to False for memory-constrained\n environments processing very large volumes.\n\n Returns:\n List of predicted tensors.\n\n Example:\n >>> config = PatchConfig(\n ... patch_size=[96, 96, 96],\n ... apply_reorder=True,\n ... target_spacing=[0.4102, 0.4102, 1.5]\n ... )\n >>> predictions = patch_inference(\n ... learner=learn,\n ... config=config, # apply_reorder and target_spacing from config\n ... file_paths=val_paths,\n ... pre_inference_tfms=[tio.ZNormalization()],\n ... save_dir='predictions/patch_based'\n ... )\n \"\"\"\n # Use config values if not explicitly provided\n _apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder\n _target_spacing = target_spacing if target_spacing is not None else config.target_spacing\n\n engine = PatchInferenceEngine(\n learner, config, _apply_reorder, _target_spacing, batch_size, pre_inference_tfms\n )\n\n # Create save directory if specified\n save_path = None\n if save_dir is not None:\n save_path = Path(save_dir)\n save_path.mkdir(parents=True, exist_ok=True)\n\n predictions = []\n desc = 'Patch inference (TTA)' if tta else 'Patch inference'\n n_files = len(file_paths)\n\n # Pipelined path: overlap I/O with compute\n if prefetch and n_files > 1:\n pbar = tqdm(total=n_files, desc=desc) if progress else None\n with ThreadPoolExecutor(max_workers=1) as pool:\n # Kick off preparation of the first image\n prefetch_future = pool.submit(engine._prepare_subject, file_paths[0])\n save_future = None\n\n for i in range(n_files):\n # Wait for the prefetched subject\n prepared = prefetch_future.result()\n\n # Start prefetching the next image (if any)\n if i + 1 < n_files:\n prefetch_future = pool.submit(engine._prepare_subject, file_paths[i + 1])\n\n # Run inference on the main thread\n output = engine._run_inference(prepared, tta=tta)\n result, affine = engine._postprocess(output, prepared, return_probabilities)\n predictions.append(result)\n\n # Wait for previous save to complete before submitting a new one\n if save_future is not None:\n save_future.result()\n\n # Submit current save in background\n if save_path is not None:\n save_future = pool.submit(\n _save_prediction, result, affine, file_paths[i],\n save_path, return_probabilities\n )\n\n if pbar is not None:\n pbar.update(1)\n\n # Wait for final save\n if save_future is not None:\n save_future.result()\n\n if pbar is not None:\n pbar.close()\n\n # Sequential fallback: single image or prefetch disabled\n else:\n iterator = tqdm(file_paths, desc=desc) if progress else file_paths\n for path in iterator:\n if save_dir is not None:\n pred, affine = engine.predict(path, return_probabilities, return_affine=True, tta=tta)\n else:\n pred = engine.predict(path, return_probabilities, tta=tta)\n predictions.append(pred)\n\n if save_dir is not None:\n _save_prediction(pred, affine, path, save_path, return_probabilities)\n\n return predictions" }, { "cell_type": "code", @@ -698,6 +698,14 @@ "outputs": [], "source": "# Test _TTA_FLIP_AXES and _predict_patch_tta\nfrom itertools import combinations\n\n# Test 1: _TTA_FLIP_AXES has exactly 8 entries (2^3 combinations for 3 axes)\ntest_eq(len(_TTA_FLIP_AXES), 8)\n\n# Verify all 2^3 combinations are present (each axis in {2,3,4} independently on/off)\nexpected_combos = set()\naxes = [2, 3, 4]\nfor r in range(len(axes) + 1):\n for combo in combinations(axes, r):\n expected_combos.add(combo)\nactual_combos = set(tuple(sorted(a)) for a in _TTA_FLIP_AXES)\ntest_eq(actual_combos, expected_combos)\n\n# Test 2: _predict_patch_tta output shape and probability range\nimport torch.nn as nn\n\nclass _SimpleConv(nn.Module):\n \"\"\"Minimal model for TTA testing.\"\"\"\n def __init__(self, out_channels):\n super().__init__()\n self.conv = nn.Conv3d(1, out_channels, 1)\n def forward(self, x):\n return self.conv(x)\n\n# Binary case (1 output channel -> sigmoid)\nmodel_bin = _SimpleConv(1).eval()\ndummy_input = torch.randn(2, 1, 8, 8, 8) # [B=2, C=1, D, H, W]\nwith torch.no_grad():\n tta_out = _predict_patch_tta(model_bin, dummy_input)\ntest_eq(tta_out.shape, torch.Size([2, 1, 8, 8, 8]))\nassert tta_out.min() >= 0.0 and tta_out.max() <= 1.0, f\"Probabilities out of range: [{tta_out.min()}, {tta_out.max()}]\"\n\n# Multi-class case (3 output channels -> softmax)\nmodel_mc = _SimpleConv(3).eval()\nwith torch.no_grad():\n tta_out_mc = _predict_patch_tta(model_mc, dummy_input)\ntest_eq(tta_out_mc.shape, torch.Size([2, 3, 8, 8, 8]))\nassert tta_out_mc.min() >= 0.0 and tta_out_mc.max() <= 1.0\n\n# Test 3: TTA on constant input matches single forward pass\n# A constant tensor is invariant to flipping, so TTA should equal single pass\nconst_input = torch.ones(1, 1, 8, 8, 8) * 0.5\nwith torch.no_grad():\n single_logits = model_bin(const_input)\n single_probs = torch.sigmoid(single_logits).cpu()\n tta_probs = _predict_patch_tta(model_bin, const_input)\nassert torch.allclose(single_probs, tta_probs, atol=1e-6), \"TTA on constant input should match single forward pass\"\n\nprint(\"TTA tests passed!\")" }, + { + "cell_type": "code", + "execution_count": null, + "id": "9jwnjpeb5qa", + "metadata": {}, + "outputs": [], + "source": "# Test _PreparedSubject and decomposed predict path\nimport tempfile, os, nibabel as nib\nfrom monai.networks.nets import UNet\nfrom monai.networks.layers import Norm\n\n# Create a small synthetic NIfTI file for testing\n_test_data = np.random.randn(32, 32, 32).astype(np.float32)\n_test_affine = np.eye(4)\n_test_nii = nib.Nifti1Image(_test_data, _test_affine)\n\nwith tempfile.TemporaryDirectory() as tmpdir:\n img_path = os.path.join(tmpdir, 'test_img.nii.gz')\n nib.save(_test_nii, img_path)\n\n _model = UNet(\n spatial_dims=3, in_channels=1, out_channels=2,\n channels=(16, 32), strides=(2,), num_res_units=1,\n norm=Norm.INSTANCE\n ).eval()\n _config = PatchConfig(patch_size=[32, 32, 32])\n _engine = PatchInferenceEngine(_model, _config, apply_reorder=False)\n\n # Test 1: _prepare_subject returns _PreparedSubject with expected attributes\n prepared = _engine._prepare_subject(img_path)\n assert isinstance(prepared, _PreparedSubject), \"Should return _PreparedSubject\"\n assert isinstance(prepared.subject, tio.Subject)\n assert isinstance(prepared.grid_sampler, tio.GridSampler)\n assert isinstance(prepared.aggregator, tio.GridAggregator)\n assert isinstance(prepared.patch_loader, DataLoader)\n assert prepared.org_size is not None\n\n # Test 2: Decomposed path equals predict() output\n pred_decomposed, affine_decomposed = _engine._postprocess(\n _engine._run_inference(\n _engine._prepare_subject(img_path)\n ),\n _engine._prepare_subject(img_path)\n )\n pred_predict, affine_predict = _engine.predict(img_path, return_affine=True)\n\n assert torch.equal(pred_decomposed, pred_predict), \"Decomposed path should match predict()\"\n assert np.array_equal(affine_decomposed, affine_predict), \"Affine should match\"\n\n # Test 3: prefetch=True produces identical results to prefetch=False\n paths = [img_path, img_path] # Two copies to trigger pipeline path\n preds_prefetch = patch_inference(\n _model, _config, paths, apply_reorder=False, progress=False, prefetch=True\n )\n preds_sequential = patch_inference(\n _model, _config, paths, apply_reorder=False, progress=False, prefetch=False\n )\n assert len(preds_prefetch) == len(preds_sequential) == 2\n for p1, p2 in zip(preds_prefetch, preds_sequential):\n assert torch.equal(p1, p2), \"prefetch=True should produce identical results\"\n\n # Test 4: Error propagation -- file-not-found raises (not silently swallowed)\n test_fail(\n lambda: patch_inference(\n _model, _config, ['/nonexistent/file.nii.gz'],\n apply_reorder=False, progress=False, prefetch=True\n )\n )\n test_fail(\n lambda: patch_inference(\n _model, _config, [img_path, '/nonexistent/file.nii.gz'],\n apply_reorder=False, progress=False, prefetch=True\n )\n )\n\n # Test 5: Save pipeline works correctly with prefetch=True\n save_dir = os.path.join(tmpdir, 'preds')\n preds_saved = patch_inference(\n _model, _config, paths, apply_reorder=False, progress=False,\n save_dir=save_dir, prefetch=True\n )\n assert len(preds_saved) == 2\n saved_files = list(Path(save_dir).glob('*.nii.gz'))\n assert len(saved_files) == 1, f\"Expected 1 unique file (same input), got {len(saved_files)}\"\n # Verify the saved file is valid NIfTI\n saved_nii = nib.load(str(saved_files[0]))\n assert saved_nii.shape is not None\n\nprint(\"Pipeline inference tests passed!\")" + }, { "cell_type": "markdown", "id": "cell-19",