From 8b7c0d5608f878e58f4fdac1f2796ab9d68b40ee Mon Sep 17 00:00:00 2001 From: Anthony Gagnon Date: Mon, 16 Mar 2026 14:05:35 -0400 Subject: [PATCH] add support for generating mesh from a brain mask --- yabplot/__init__.py | 2 +- yabplot/plotting.py | 35 +++++++++++++-- yabplot/utils.py | 106 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 138 insertions(+), 5 deletions(-) diff --git a/yabplot/__init__.py b/yabplot/__init__.py index 1f89705..c7b4ae7 100644 --- a/yabplot/__init__.py +++ b/yabplot/__init__.py @@ -3,7 +3,7 @@ from .plotting import plot_cortical, plot_subcortical, plot_tracts, clear_tract_cache, plot_vertexwise from .data import get_available_resources, get_atlas_regions from .atlas_builder import build_cortical_atlas, build_subcortical_atlas -from .utils import load_vertexwise_mesh, project_vol2surf +from .utils import load_vertexwise_mesh, project_vol2surf, load_nifti_mask_as_surface try: __version__ = version("yabplot") diff --git a/yabplot/plotting.py b/yabplot/plotting.py index 38f7acd..d4907a5 100644 --- a/yabplot/plotting.py +++ b/yabplot/plotting.py @@ -15,7 +15,8 @@ load_gii, load_gii2pv, prep_data, generate_distinct_colors, parse_lut, map_values_to_surface, get_puzzle_pieces, apply_internal_blur, apply_dilation, - get_smooth_mask, lines_from_streamlines, make_cortical_mesh + get_smooth_mask, lines_from_streamlines, make_cortical_mesh, + load_nifti_mask_as_surface ) from .scene import ( @@ -480,8 +481,8 @@ def plot_tracts(data=None, atlas=None, custom_atlas_path=None, views=None, layou figsize=(1000, 800), cmap='coolwarm', alpha=1.0, vminmax=[None, None], nan_color='#BDBDBD', nan_alpha=1.0, style='default', bmesh_type='midthickness', bmesh_alpha=0.2, bmesh_color='lightgray', - zoom=1.2, orientation_coloring=False, display_type='static', - tract_kwargs=dict(render_lines_as_tubes=True, line_width=1.2), + nifti_mask=None, nifti_mask_blur=1.5, nifti_smooth=10, zoom=1.2, orientation_coloring=False, + display_type='static', tract_kwargs=dict(render_lines_as_tubes=True, line_width=1.2), export_path=None): """ Visualize data on the white matter tractography bundles using a specified atlas. @@ -529,6 +530,19 @@ def plot_tracts(data=None, atlas=None, custom_atlas_path=None, views=None, layou Opacity of the context brain mesh. Default is 0.2. bmesh_color : str, optional Color of the context brain mesh. + nifti_mask : str, optional + Path to a NIfTI file containing a binary mask. Mesh will be generated + from this mask and used as a background context instead of the standard brain mesh. + `bmesh_type` need to be set to None for this to work. + Default is None. + nifti_mask_blur : float, optional + Standard deviation (in voxels) of Gaussian blur applied to the mask before + surface extraction. Higher values produce smoother surfaces. + Typical range: 1.0-3.0. Default is 1.5. + nifti_smooth : int, optional + Number of smoothing iterations for the extracted NIfTI surface. + Higher values produce smoother surfaces but may lose detail. + Set to 0 to disable smoothing. Default is 10. zoom : float, optional Camera zoom level. >1.0 zooms in, <1.0 zooms out. Default is 1.2. orientation_coloring : bool, optional @@ -577,9 +591,24 @@ def plot_tracts(data=None, atlas=None, custom_atlas_path=None, views=None, layou # load context brain mesh (if requested) bmesh = {} if bmesh_type: + if nifti_mask is not None: + print("Warning: To use a NIfTI mask as background, set `bmesh_type` to None") b_lh_path, b_rh_path = get_surface_paths(bmesh_type, 'bmesh') bmesh['L'] = load_gii2pv(b_lh_path) bmesh['R'] = load_gii2pv(b_rh_path) + elif nifti_mask is not None: + # generate mesh from nifti mask + try: + nifti_mesh = load_nifti_mask_as_surface( + nifti_path=nifti_mask, + mask_blur_sigma=nifti_mask_blur, + smooth_iterations=nifti_smooth, + smooth_factor=0.5 + ) + bmesh['both'] = nifti_mesh + except Exception as e: + print(f"Warning: Failed to load NIfTI background from '{nifti_mask}': {e}") + print("Continuing without background mesh.") # setup plotter sel_views = get_view_configs(views) diff --git a/yabplot/utils.py b/yabplot/utils.py index 0b1c3d9..2e26f56 100644 --- a/yabplot/utils.py +++ b/yabplot/utils.py @@ -5,9 +5,10 @@ import nibabel as nib import pyvista as pv import scipy.sparse as sp -from scipy.ndimage import map_coordinates +from scipy.ndimage import map_coordinates, gaussian_filter import matplotlib.pyplot as plt from importlib.resources import files +from skimage import measure def load_gii(gii_path): """Load GIfTI geometry (vertices, faces).""" @@ -96,6 +97,109 @@ def make_cortical_mesh(verts, faces, scalars, scalar_name='Data'): mesh[scalar_name] = scalars return mesh +def load_nifti_mask_as_surface(nifti_path, mask_blur_sigma=1.5, + smooth_iterations=10, smooth_factor=0.5): + """ + Extract a 3D brain surface mesh from a NIfTI volume using isosurface extraction. + + Converts any non-zero voxels in a NIfTI volume to a binary mask, applies Gaussian + smoothing to create clean boundaries, and extracts a triangulated surface mesh using + marching cubes. The result is a PyVista mesh in native space coordinates suitable + for visualization with volumetric or tractography data. + + Parameters + ---------- + nifti_path : str + Path to the NIfTI file (.nii or .nii.gz). Typically a binary brain mask + or a brain-extract T1w image. + mask_blur_sigma : float, optional + Standard deviation (in voxels) of Gaussian blur applied to the binary mask + before surface extraction. Controls surface smoothness. Higher values (1.5-3.0) + produce smoother surfaces. Default is 1.5. + smooth_iterations : int, optional + Number of Laplacian smoothing iterations applied to the mesh after extraction. + Higher values create smoother meshes but may lose anatomical detail. Set to 0 + to disable. Default is 10. + smooth_factor : float, optional + Relaxation factor for mesh smoothing (range: 0.0 to 1.0). Higher values apply + more aggressive smoothing. Default is 0.5. + + Returns + ------- + pyvista.PolyData + Triangulated surface mesh with vertices in native space (voxel-to-world + coordinates applied via the NIfTI affine matrix). + """ + + # Validate input file exists + if not os.path.exists(nifti_path): + raise FileNotFoundError(f"NIfTI file not found: {nifti_path}") + + # Load NIfTI volume and affine transformation + try: + img = nib.load(nifti_path) + data = img.get_fdata() + affine = img.affine + except Exception as e: + raise RuntimeError(f"Failed to load NIfTI file '{nifti_path}': {str(e)}") + + # Convert to binary mask (any non-zero voxel = 1) + mask_binary = (data > 0).astype(float) + + # Apply Gaussian blur to create smooth boundaries + mask_smoothed = gaussian_filter(mask_binary, sigma=mask_blur_sigma) + + # Extract isosurface at 0.5 (the smooth boundary between 0 and 1) + voxel_spacing = np.abs(np.diag(affine[:3, :3])) + verts, faces, _, _ = measure.marching_cubes( + mask_smoothed, + level=0.5, + spacing=tuple(voxel_spacing) + ) + + # Validate extraction succeeded + if len(verts) == 0 or len(faces) == 0: + raise RuntimeError( + "Surface extraction produced an empty mesh. " + "Verify the NIfTI file contains non-zero voxels." + ) + + # Fill topological holes in extracted mesh + try: + faces_vtk = np.hstack([np.full((faces.shape[0], 1), 3), faces]).flatten().astype(int) + mesh = pv.PolyData(verts, faces_vtk) + mesh = mesh.fill_holes(1000) + verts = mesh.points + faces = mesh.faces.reshape(-1, 4)[:, 1:4] + except Exception as e: + warnings.warn(f"Hole filling failed: {str(e)}. Continuing with original mesh.") + + # Transform vertices from voxel indices to world coordinates + verts_homogeneous = np.c_[verts, np.ones(len(verts))] + verts_world = verts_homogeneous @ affine.T + verts_final = verts_world[:, :3] + + # Convert faces to VTK format + faces_vtk = np.hstack([ + np.full((faces.shape[0], 1), 3, dtype=np.int64), + faces + ]).flatten() + mesh = pv.PolyData(verts_final, faces_vtk) + + # Apply optional Laplacian smoothing + if smooth_iterations > 0: + if not (0.0 <= smooth_factor <= 1.0): + raise ValueError( + f"smooth_factor must be in range [0.0, 1.0], got {smooth_factor}" + ) + try: + mesh = mesh.smooth(n_iter=smooth_iterations, relaxation_factor=smooth_factor) + except Exception as e: + warnings.warn(f"Mesh smoothing failed: {str(e)}. Using unsmoothed mesh.") + + return mesh + + def prep_data(data, regions, atlas, category): """Standardize input data to dictionary.""" if isinstance(data, pd.DataFrame):