"
+ if len(raw_lines) < 3:
+ raise ValueError(f"VTK file too short: {path!r}")
+
+ fmt_line = raw_lines[2].strip().upper()
+ if "BINARY" in fmt_line:
+ raise ValueError(
+ f"Only ASCII legacy VTK POLYDATA is supported; "
+ f"file appears to be BINARY: {path!r}. "
+ f"Convert with: vtk-convert or meshio-convert."
+ )
+ if "ASCII" not in fmt_line:
+ raise ValueError(
+ f"Could not determine VTK format from line 3 (expected 'ASCII' or "
+ f"'BINARY'): {raw_lines[2]!r} in {path!r}."
+ )
+
+ # Scan for DATASET POLYDATA
+ dataset_found = False
+ for line in raw_lines:
+ if line.strip().upper().startswith("DATASET"):
+ if "POLYDATA" not in line.upper():
+ raise ValueError(
+ f"Only POLYDATA VTK datasets are supported, "
+ f"got: {line.strip()!r} in {path!r}."
+ )
+ dataset_found = True
+ break
+ if not dataset_found:
+ raise ValueError(f"No DATASET line found in VTK file {path!r}.")
+
+ # Tokenise everything into a flat list for easy sectioned parsing
+ lines = [raw_ln.strip() for raw_ln in raw_lines if raw_ln.strip() and not raw_ln.strip().startswith("#")]
+
+ vertices = None
+ faces = None
+ i = 0
+ while i < len(lines):
+ upper = lines[i].upper()
+
+ if upper.startswith("POINTS"):
+ parts = lines[i].split()
+ n_pts = int(parts[1])
+ # Collect 3*n_pts floats; they may span multiple lines
+ floats = []
+ i += 1
+ while len(floats) < 3 * n_pts and i < len(lines):
+ # Stop at next keyword section
+ if lines[i].upper().split()[0] in (
+ "POLYGONS", "LINES", "STRIPS", "VERTICES",
+ "POINT_DATA", "CELL_DATA", "FIELD", "NORMALS",
+ "TEXTURE_COORDINATES", "SCALARS", "LOOKUP_TABLE",
+ ):
+ break
+ floats.extend(float(x) for x in lines[i].split())
+ i += 1
+ if len(floats) < 3 * n_pts:
+ raise ValueError(
+ f"Expected {3 * n_pts} floats for POINTS but got "
+ f"{len(floats)} in {path!r}."
+ )
+ vertices = np.array(floats[: 3 * n_pts], dtype=np.float32).reshape(n_pts, 3)
+ continue # i already advanced
+
+ elif upper.startswith("POLYGONS"):
+ parts = lines[i].split()
+ n_polys = int(parts[1])
+ face_list = []
+ i += 1
+ while len(face_list) < n_polys and i < len(lines):
+ if lines[i].upper().split()[0] in (
+ "POINTS", "LINES", "STRIPS", "VERTICES",
+ "POINT_DATA", "CELL_DATA", "FIELD", "NORMALS",
+ "TEXTURE_COORDINATES", "SCALARS", "LOOKUP_TABLE",
+ ):
+ break
+ tokens = lines[i].split()
+ n_poly = int(tokens[0])
+ if n_poly != 3:
+ raise ValueError(
+ f"VTK polygon {len(face_list)} has {n_poly} vertices; "
+ f"only triangles (3) are supported in {path!r}. "
+ f"Triangulate the mesh before loading."
+ )
+ face_list.append([int(tokens[1]), int(tokens[2]), int(tokens[3])])
+ i += 1
+ if len(face_list) < n_polys:
+ raise ValueError(
+ f"Expected {n_polys} polygons but only parsed "
+ f"{len(face_list)} in {path!r}."
+ )
+ faces = np.array(face_list, dtype=np.uint32)
+ continue # i already advanced
+
+ else:
+ i += 1
+
+ if vertices is None:
+ raise ValueError(f"No POINTS section found in VTK file {path!r}.")
+ if faces is None:
+ raise ValueError(f"No POLYGONS section found in VTK file {path!r}.")
+
+ # Bounds check
+ n_verts = vertices.shape[0]
+ if faces.size > 0 and (int(faces.max()) >= n_verts or int(faces.min()) < 0):
+ raise ValueError(
+ f"VTK face indices out of range [0, {n_verts}) in {path!r}."
+ )
+
+ return vertices, faces
+
+
+# ---------------------------------------------------------------------------
+# PLY ASCII reader
+# ---------------------------------------------------------------------------
+
def read_ply_ascii(path):
    """Read an ASCII PLY triangle mesh.

    Only ASCII PLY files are supported. Binary PLY files are explicitly
    rejected with a helpful error message, as are files whose header has
    no ``format`` line at all.

    Parameters
    ----------
    path : str
        Path to the ``.ply`` file.

    Returns
    -------
    vertices : numpy.ndarray, shape (N, 3), dtype float32
    faces : numpy.ndarray, shape (M, 3), dtype uint32

    Raises
    ------
    ValueError
        If the file is binary PLY, if faces are not triangles, or if the
        header is malformed.
    IOError
        If the file cannot be opened.
    """
    with open(path, encoding="utf-8", errors="replace") as fh:
        raw_lines = fh.readlines()

    if not raw_lines or raw_lines[0].strip() != "ply":
        raise ValueError(
            f"File does not start with 'ply' magic; not a PLY file: {path!r}."
        )

    # Check encoding. Scan the whole header (up to end_header) for the
    # 'format' line -- previously only the first three lines after 'ply'
    # were inspected, so a binary file with comment lines before the
    # format line slipped through undetected, and a header without any
    # format line was silently accepted.
    format_found = False
    for line in raw_lines[1:]:
        stripped = line.strip().lower()
        if stripped == "end_header":
            break
        if stripped.startswith("format"):
            format_found = True
            if "ascii" not in stripped:
                raise ValueError(
                    f"PLY binary format not supported; only ASCII PLY is "
                    f"accepted: {path!r}. "
                    f"Convert with: plyconvert or meshio-convert."
                )
            break
    if not format_found:
        raise ValueError(
            f"No 'format' line found in PLY header; malformed PLY file: {path!r}."
        )

    # Parse header
    n_verts = None
    n_faces = None
    vertex_props = []  # ordered list of property names for the vertex element
    in_vertex = False  # True while scanning properties of the vertex element
    header_end = 0

    for idx, line in enumerate(raw_lines):
        stripped = line.strip()
        lower = stripped.lower()

        if lower == "end_header":
            header_end = idx + 1
            break
        if lower.startswith("element vertex"):
            n_verts = int(stripped.split()[2])
            in_vertex = True
        elif lower.startswith("element face"):
            n_faces = int(stripped.split()[2])
            in_vertex = False
        elif lower.startswith("property") and in_vertex:
            # e.g. "property float x" -> record the property name ("x")
            parts = stripped.split()
            if len(parts) >= 3:
                vertex_props.append(parts[-1])
        elif (lower.startswith("element")
                and not lower.startswith("element vertex")
                and not lower.startswith("element face")):
            # Any other element (e.g. "element edge") ends the vertex scope
            in_vertex = False

    if n_verts is None:
        raise ValueError(f"No 'element vertex' found in PLY header: {path!r}.")
    if n_faces is None:
        raise ValueError(f"No 'element face' found in PLY header: {path!r}.")

    # Determine column indices for x, y, z
    try:
        xi = vertex_props.index("x")
        yi = vertex_props.index("y")
        zi = vertex_props.index("z")
    except ValueError as exc:
        raise ValueError(
            f"PLY vertex element missing x/y/z properties in {path!r}; "
            f"found: {vertex_props!r}."
        ) from exc

    data_lines = [raw_line.strip() for raw_line in raw_lines[header_end:] if raw_line.strip()]

    if len(data_lines) < n_verts + n_faces:
        raise ValueError(
            f"PLY file has {len(data_lines)} data lines but expects "
            f"{n_verts} vertices + {n_faces} faces in {path!r}."
        )

    # Parse vertices
    vertices = np.empty((n_verts, 3), dtype=np.float32)
    for i in range(n_verts):
        try:
            tokens = data_lines[i].split()
            vertices[i] = [float(tokens[xi]), float(tokens[yi]), float(tokens[zi])]
        except (ValueError, IndexError) as exc:
            raise ValueError(
                f"Could not parse PLY vertex {i} in {path!r}: "
                f"{data_lines[i]!r}"
            ) from exc

    # Parse faces
    faces = np.empty((n_faces, 3), dtype=np.uint32)
    for j in range(n_faces):
        try:
            tokens = data_lines[n_verts + j].split()
            count = int(tokens[0])
        except (ValueError, IndexError) as exc:
            raise ValueError(
                f"Could not parse PLY face {j} in {path!r}: "
                f"{data_lines[n_verts + j]!r}"
            ) from exc
        if count != 3:
            raise ValueError(
                f"PLY face {j} has {count} vertices; only triangles (3) are "
                f"supported in {path!r}. Triangulate the mesh first."
            )
        try:
            # OverflowError: numpy rejects negative values when assigning
            # into a uint32 array -- report that as a parse error too
            # instead of leaking an unhandled OverflowError.
            faces[j] = [int(tokens[1]), int(tokens[2]), int(tokens[3])]
        except (ValueError, IndexError, OverflowError) as exc:
            raise ValueError(
                f"Could not parse PLY face indices at face {j} in {path!r}: "
                f"{data_lines[n_verts + j]!r}"
            ) from exc

    # Bounds check (negative indices cannot exist in a uint32 array and
    # are already rejected at parse time, so only the upper bound remains)
    if n_faces > 0 and int(faces.max()) >= n_verts:
        raise ValueError(
            f"PLY face indices out of range [0, {n_verts}) in {path!r}."
        )

    return vertices, faces
+
+
+# ---------------------------------------------------------------------------
+# GIfTI surface reader
+# ---------------------------------------------------------------------------
+
# NIFTI intent codes relevant to GIfTI surface files. GIfTI reuses the
# NIfTI-1 intent enumeration to tag what each data array contains; a
# surface file pairs one POINTSET array with one TRIANGLE array.
_GIFTI_INTENT_POINTSET = 1008  # NIFTI_INTENT_POINTSET — vertex coordinates
_GIFTI_INTENT_TRIANGLE = 1009  # NIFTI_INTENT_TRIANGLE — face indices
+
+
def read_gifti_surface(path):
    """Load a GIfTI surface mesh (``.surf.gii`` or ``.gii``).

    A surface GIfTI stores two data arrays: vertex coordinates tagged
    with intent ``NIFTI_INTENT_POINTSET`` (1008), shape ``(N, 3)``, and
    triangle indices tagged with intent ``NIFTI_INTENT_TRIANGLE``
    (1009), shape ``(M, 3)``. This layout is what FreeSurfer's
    ``mris_convert``, Connectome Workbench, fMRIPrep, and the HCP
    pipelines write.

    Parameters
    ----------
    path : str
        Path to a GIfTI surface file.

    Returns
    -------
    vertices : numpy.ndarray, shape (N, 3), dtype float32
    faces : numpy.ndarray, shape (M, 3), dtype uint32

    Raises
    ------
    ImportError
        If ``nibabel`` is not installed.
    ValueError
        If the POINTSET or TRIANGLE array is missing or has an
        unexpected shape. A functional/label GIfTI (no POINTSET array)
        is rejected with a hint to use ``resolve_overlay()`` instead.
    IOError
        If the file cannot be opened or is not a valid GIfTI file.

    Examples
    --------
    >>> v, f = read_gifti_surface('lh.white.surf.gii')
    >>> v.shape  # (N, 3)
    >>> f.shape  # (M, 3)
    """
    try:
        import nibabel as nib  # noqa: PLC0415
    except ImportError as exc:
        raise ImportError(
            "Reading GIfTI surface files requires nibabel. "
            "Install with: pip install nibabel"
        ) from exc

    img = nib.load(path)
    darrays = getattr(img, "darrays", None)
    if not darrays:
        raise ValueError(
            f"GIfTI file {path!r} contains no data arrays. "
            f"Expected a surface file with POINTSET and TRIANGLE arrays."
        )

    # Keep the first array of each relevant intent; ignore any extras.
    coord_data = None
    tri_data = None
    for da in darrays:
        if coord_data is None and da.intent == _GIFTI_INTENT_POINTSET:
            coord_data = da.data
        elif tri_data is None and da.intent == _GIFTI_INTENT_TRIANGLE:
            tri_data = da.data

    if coord_data is None:
        raise ValueError(
            f"GIfTI file {path!r} has no POINTSET (intent 1008) array. "
            f"If this is a functional or label file, use resolve_overlay() instead."
        )
    if tri_data is None:
        raise ValueError(
            f"GIfTI file {path!r} has no TRIANGLE (intent 1009) array. "
            f"A valid surface GIfTI must contain both POINTSET and TRIANGLE arrays."
        )

    vertices = np.asarray(coord_data, dtype=np.float32)
    faces = np.asarray(tri_data, dtype=np.uint32)

    if vertices.ndim != 2 or vertices.shape[1] != 3:
        raise ValueError(
            f"GIfTI POINTSET array in {path!r} has unexpected shape "
            f"{vertices.shape}; expected (N, 3)."
        )
    if faces.ndim != 2 or faces.shape[1] != 3:
        raise ValueError(
            f"GIfTI TRIANGLE array in {path!r} has unexpected shape "
            f"{faces.shape}; expected (M, 3)."
        )

    n_verts = vertices.shape[0]
    if faces.size > 0 and (int(faces.max()) >= n_verts or int(faces.min()) < 0):
        raise ValueError(
            f"GIfTI face indices out of range [0, {n_verts}) in {path!r}."
        )

    return vertices, faces
+
+
+# ---------------------------------------------------------------------------
+# Dispatcher
+# ---------------------------------------------------------------------------
+
# Map lower-case file extension -> reader function used by read_mesh().
_READERS = {
    ".off": read_off,
    ".vtk": read_vtk_ascii_polydata,
    ".ply": read_ply_ascii,
    ".gii": read_gifti_surface,
    # NOTE(review): os.path.splitext() never yields a compound suffix, so
    # this key is not reachable through the splitext lookup; read_mesh()
    # special-cases ".surf.gii" before consulting this table. The entry is
    # kept so the extension appears in _SUPPORTED / error messages.
    ".surf.gii": read_gifti_surface,  # compound extension — checked first
}

# Pre-rendered, sorted extension list used in error messages.
_SUPPORTED = ", ".join(sorted(_READERS))
+
+
def read_mesh(path):
    """Dispatch to a triangle-mesh reader based on file extension.

    Supported extensions (case-insensitive): ``.off``, ``.vtk``,
    ``.ply``, ``.surf.gii``, ``.gii``. FreeSurfer binary surfaces
    usually carry no standard extension (e.g. ``lh.white``); load those
    with :func:`whippersnappy.geometry.read_geometry` directly, or let
    :func:`whippersnappy.geometry.inputs.resolve_mesh` route them.

    Parameters
    ----------
    path : str
        Path to a mesh file.

    Returns
    -------
    vertices : numpy.ndarray, shape (N, 3), dtype float32
    faces : numpy.ndarray, shape (M, 3), dtype uint32

    Raises
    ------
    ValueError
        If the extension is not recognised.
    """
    import os as _os

    lowered = path.lower()
    # os.path.splitext() only returns the final suffix, so the compound
    # ".surf.gii" extension must be matched before the table lookup.
    if lowered.endswith(".surf.gii"):
        return read_gifti_surface(path)

    ext = _os.path.splitext(lowered)[1]
    reader = _READERS.get(ext)
    if reader is not None:
        return reader(path)
    raise ValueError(
        f"Unsupported mesh file extension {ext!r} for {path!r}. "
        f"Supported formats: {_SUPPORTED}. "
        f"For FreeSurfer surfaces (no extension) use resolve_mesh() "
        f"or read_geometry() directly."
    )
+
+
+
+
+
diff --git a/whippersnappy/geometry/overlay_io.py b/whippersnappy/geometry/overlay_io.py
new file mode 100644
index 0000000..04ea510
--- /dev/null
+++ b/whippersnappy/geometry/overlay_io.py
@@ -0,0 +1,345 @@
+"""Lightweight per-vertex scalar and label readers for common open formats.
+
+This module implements pure-Python (stdlib + numpy only) readers for simple
+per-vertex data files, plus a GIfTI reader that reuses the nibabel dependency
+already present in the project.
+
+Supported formats
+-----------------
+* **ASCII text** (``.txt``, ``.csv``) — one numeric value per line; optional
+ single non-numeric header line (skipped automatically); whitespace or
+ comma-separated. Integer values are loaded as ``int32``; all others as
+ ``float32``.
+
+* **NumPy array** (``.npy``) — single 1-D array saved with
+ ``numpy.save``. Any numeric dtype is accepted and kept as-is; callers
+ cast to the required dtype.
+
+* **NumPy archive** (``.npz``) — multi-array archive saved with
+ ``numpy.savez``. The array named ``"data"`` is used if present,
+ otherwise the first array in the archive is used.
+
+* **GIfTI functional / label** (``.func.gii``, ``.label.gii``, ``.gii``) —
+ loaded via ``nibabel``; the first data array is returned. Covers HCP,
+ fMRIPrep, and Connectome Workbench outputs.
+
+The public dispatcher :func:`read_overlay` routes by file extension.
+FreeSurfer binary morph files and MGH/MGZ files are *not* handled here —
+they are loaded by :mod:`whippersnappy.geometry.freesurfer_io` and
+dispatched from :func:`whippersnappy.geometry.inputs._load_overlay_from_file`.
+
+All readers return a flat ``numpy.ndarray`` of shape ``(N,)``. The caller
+is responsible for casting to the desired dtype (``float32`` for overlays and
+background maps, ``bool`` for ROI masks, ``int32`` for label/parcellation
+maps).
+"""
+
+import os
+
+import numpy as np
+
+# ---------------------------------------------------------------------------
+# ASCII text / CSV reader
+# ---------------------------------------------------------------------------
+
def read_txt(path):
    """Read a per-vertex scalar file in plain ASCII format.

    The file must contain one numeric value per line. Blank lines and
    ``#`` comment lines are skipped. An optional single text header (the
    first non-numeric, non-comment line before any data) is silently
    skipped as well. Whitespace and comma separators are both accepted;
    only the *first* value on each line is used (allowing simple CSV
    files with a single data column).

    Integer-valued files (every value equal to its ``int`` cast) are
    returned as ``int32``; all others as ``float32``.

    Parameters
    ----------
    path : str
        Path to the ``.txt`` or ``.csv`` file.

    Returns
    -------
    numpy.ndarray, shape (N,), dtype float32 or int32

    Raises
    ------
    ValueError
        If no numeric values can be parsed from the file.
    IOError
        If the file cannot be opened.

    Examples
    --------
    A valid ``overlay.txt``::

        # optional comment line (skipped)
        0.123
        -1.456
        2.0

    A valid ``labels.csv`` (first column used, header skipped)::

        label
        3
        0
        1
        3
    """
    values = []
    header_skipped = False
    with open(path, encoding="utf-8", errors="replace") as fh:
        for lineno, raw in enumerate(fh, start=1):
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            # Take only the first token (handles CSV with a single data
            # column). IndexError covers degenerate lines like ",".
            try:
                token = line.split(",")[0].split()[0]
                values.append(float(token))
            except (ValueError, IndexError) as exc:
                # Allow exactly one non-numeric header line before any
                # data. Keying this on "no values yet" instead of the
                # raw line number means comments/blank lines before the
                # header no longer defeat the header detection.
                if not values and not header_skipped:
                    header_skipped = True
                    continue
                raise ValueError(
                    f"Could not parse numeric value on line {lineno} of {path!r}: "
                    f"{raw.strip()!r}"
                ) from exc

    if not values:
        raise ValueError(f"No numeric values found in {path!r}.")

    arr = np.array(values, dtype=np.float32)

    # Promote to int32 if all values are integers (label / parcellation file)
    if np.all(arr == arr.astype(np.int32)):
        return arr.astype(np.int32)
    return arr
+
+
+# ---------------------------------------------------------------------------
+# NumPy readers
+# ---------------------------------------------------------------------------
+
def read_npy(path):
    """Read a per-vertex scalar array from a NumPy ``.npy`` file.

    Parameters
    ----------
    path : str
        Path to the ``.npy`` file.

    Returns
    -------
    numpy.ndarray, shape (N,)
        The stored array, squeezed to 1-D.

    Raises
    ------
    ValueError
        If the stored array is not 1-D after squeezing, or is empty.
    IOError
        If the file cannot be opened.
    """
    arr = np.load(path)
    # np.squeeze drops ALL size-1 axes, so a valid single-value overlay
    # (shape (1,) or (1, 1)) would collapse to 0-d and be rejected below.
    # atleast_1d restores the 1-D view in that case.
    arr = np.atleast_1d(np.squeeze(arr))
    if arr.ndim != 1:
        raise ValueError(
            f"NumPy file {path!r} contains an array of shape {arr.shape}; "
            f"expected a 1-D per-vertex array."
        )
    if arr.size == 0:
        raise ValueError(f"NumPy file {path!r} contains an empty array.")
    return arr
+
+
def read_npz(path):
    """Read a per-vertex scalar array from a NumPy ``.npz`` archive.

    The array named ``"data"`` is returned if it exists; otherwise the
    first array in the archive is used.

    Parameters
    ----------
    path : str
        Path to the ``.npz`` file.

    Returns
    -------
    numpy.ndarray, shape (N,)
        The selected array, squeezed to 1-D.

    Raises
    ------
    ValueError
        If no arrays are found, or the selected array is not 1-D after
        squeezing.
    IOError
        If the file cannot be opened.
    """
    # NpzFile keeps the underlying zip file open; use it as a context
    # manager so the handle is released deterministically.
    with np.load(path) as archive:
        keys = list(archive.keys())
        if not keys:
            raise ValueError(f"NumPy archive {path!r} contains no arrays.")

        key = "data" if "data" in keys else keys[0]
        # atleast_1d: squeeze collapses a valid shape-(1,) array to 0-d;
        # promote it back so single-value overlays are accepted.
        arr = np.atleast_1d(np.squeeze(archive[key]))

    if arr.ndim != 1:
        raise ValueError(
            f"NumPy archive {path!r}, array {key!r} has shape {arr.shape}; "
            f"expected a 1-D per-vertex array."
        )
    if arr.size == 0:
        raise ValueError(f"NumPy archive {path!r}, array {key!r} is empty.")
    return arr
+
+
+# ---------------------------------------------------------------------------
+# GIfTI reader
+# ---------------------------------------------------------------------------
+
def read_gifti(path):
    """Read a per-vertex scalar array from a GIfTI functional or label file.

    Supports ``.func.gii`` (continuous scalars, e.g. HCP thickness) and
    ``.label.gii`` (integer parcellation labels, e.g. HCP parcellation).
    Plain ``.gii`` files are also accepted provided they contain a scalar
    data array — **not** a surface geometry file. For surface GIfTI files
    (``.surf.gii`` or ``.gii`` files with POINTSET+TRIANGLE arrays) use
    :func:`whippersnappy.geometry.mesh_io.read_gifti_surface` or pass the
    path to :func:`whippersnappy.geometry.inputs.resolve_mesh`.

    The first non-POINTSET, non-TRIANGLE data array in the file is returned.

    Parameters
    ----------
    path : str
        Path to a GIfTI file.

    Returns
    -------
    numpy.ndarray, shape (N,)
        The first scalar data array, squeezed to 1-D.

    Raises
    ------
    ImportError
        If ``nibabel`` is not installed.
    ValueError
        If the file is a surface GIfTI (POINTSET+TRIANGLE only), contains
        no usable scalar arrays, or the first scalar array is not 1-D.
    IOError
        If the file cannot be opened or is not a valid GIfTI file.
    """
    try:
        import nibabel as nib  # noqa: PLC0415
    except ImportError as exc:
        raise ImportError(
            "Reading GIfTI files requires nibabel. "
            "Install with: pip install nibabel"
        ) from exc

    img = nib.load(path)
    if not hasattr(img, "darrays") or not img.darrays:
        raise ValueError(
            f"GIfTI file {path!r} contains no data arrays."
        )

    # Intent codes for surface geometry — skip these
    _SURFACE_INTENTS = {1008, 1009}  # POINTSET, TRIANGLE

    scalar_da = None
    has_surface_arrays = False
    for da in img.darrays:
        if da.intent in _SURFACE_INTENTS:
            has_surface_arrays = True
        elif scalar_da is None:
            scalar_da = da

    if scalar_da is None:
        if has_surface_arrays:
            raise ValueError(
                f"GIfTI file {path!r} appears to be a surface geometry file "
                f"(contains only POINTSET/TRIANGLE arrays). "
                f"Use resolve_mesh() or read_gifti_surface() to load it as a mesh."
            )
        raise ValueError(
            f"GIfTI file {path!r} contains no scalar data arrays."
        )

    # atleast_1d: consistent with read_npy/read_npz — squeeze collapses a
    # valid single-value array (shape (1,)) to 0-d, which would otherwise
    # be rejected by the ndim check below.
    arr = np.atleast_1d(np.squeeze(scalar_da.data))
    if arr.ndim != 1:
        raise ValueError(
            f"GIfTI file {path!r}: first scalar data array has shape "
            f"{scalar_da.data.shape}; expected a 1-D per-vertex array."
        )
    if arr.size == 0:
        raise ValueError(f"GIfTI file {path!r}: first scalar data array is empty.")
    return arr
+
+
+# ---------------------------------------------------------------------------
+# Dispatcher
+# ---------------------------------------------------------------------------
+
# Map from lower-case file extension to reader function.
# Note: ".func.gii" and ".label.gii" have a compound extension;
# os.path.splitext() only yields the final ".gii", so read_overlay()
# matches the compound suffixes explicitly before the splitext lookup.
# The compound keys below are reached only through that explicit match.
_READERS = {
    ".txt": read_txt,
    ".csv": read_txt,
    ".npy": read_npy,
    ".npz": read_npz,
    ".gii": read_gifti,
    ".func.gii": read_gifti,
    ".label.gii": read_gifti,
}

# Pre-rendered, sorted extension list used in error messages.
_SUPPORTED = ", ".join(sorted(_READERS))
+
+
def read_overlay(path):
    """Read a per-vertex scalar or label array from a file.

    Routes to the matching reader based on the file extension.
    FreeSurfer binary morph files (e.g. ``lh.curv``, ``lh.thickness``)
    and MGH/MGZ files are **not** handled here — pass them through
    :func:`whippersnappy.geometry.inputs._load_overlay_from_file` which
    already routes those formats via
    :mod:`~whippersnappy.geometry.freesurfer_io`.

    Parameters
    ----------
    path : str
        Path to an overlay/label file. Recognised extensions:

        * ``.txt``, ``.csv`` — plain ASCII, one value per line
        * ``.npy`` — NumPy binary array
        * ``.npz`` — NumPy archive (key ``"data"`` or first array)
        * ``.gii``, ``.func.gii``, ``.label.gii`` — GIfTI

    Returns
    -------
    numpy.ndarray, shape (N,)

    Raises
    ------
    ValueError
        If the file extension is not recognised.
    """
    lowered = path.lower()
    # splitext() cannot see compound suffixes, so match them explicitly
    # before falling back to the plain-extension table lookup.
    if lowered.endswith(".func.gii"):
        return _READERS[".func.gii"](path)
    if lowered.endswith(".label.gii"):
        return _READERS[".label.gii"](path)

    ext = os.path.splitext(path)[1].lower()
    reader = _READERS.get(ext)
    if reader is not None:
        return reader(path)
    raise ValueError(
        f"Unsupported overlay file extension {ext!r} for {path!r}. "
        f"Supported formats: {_SUPPORTED}. "
        f"For FreeSurfer morph files (no extension) or .mgh/.mgz files "
        f"the routing is handled automatically by resolve_overlay()."
    )
+
diff --git a/whippersnappy/geometry/prepare.py b/whippersnappy/geometry/prepare.py
new file mode 100644
index 0000000..d5df718
--- /dev/null
+++ b/whippersnappy/geometry/prepare.py
@@ -0,0 +1,454 @@
+"""Geometry helpers for mesh processing and GPU preparation (prepare.py).
+
+This module contains the primary geometry-preparation pipeline. The
+low-level workhorse is :func:`prepare_geometry_from_arrays` which operates
+entirely on numpy arrays. :func:`prepare_geometry` is a thin file-loading
+wrapper that delegates to the resolver functions in
+:mod:`whippersnappy.geometry.inputs` before calling
+:func:`prepare_geometry_from_arrays`.
+"""
+
+import warnings
+
+import numpy as np
+
+from ..utils.colormap import binary_color, heat_color, mask_sign, rescale_overlay
+from ..utils.types import ColorSelection
+from .inputs import resolve_annot, resolve_bg_map, resolve_mesh, resolve_overlay, resolve_roi
+
+
def normalize_mesh(v, scale=1.0):
    """Center and scale mesh vertex coordinates to a unit cube.

    The function recenters the vertices around the origin and scales them
    so that the maximum extent fits into a unit cube, optionally applying
    an additional scale factor.

    Parameters
    ----------
    v : numpy.ndarray
        Vertex coordinate array of shape (n_vertices, 3).
    scale : float, optional
        Additional multiplicative scale applied after normalization.

    Returns
    -------
    numpy.ndarray
        Normalized vertex coordinates with same shape as ``v``.
    """
    bbmax = np.max(v, axis=0)
    bbmin = np.min(v, axis=0)
    v = v - 0.5 * (bbmax + bbmin)
    extent = np.max(bbmax - bbmin)
    if extent <= 0:
        # Degenerate mesh (all vertices coincide): avoid division by zero
        # and return the centered (all-zero) coordinates unscaled.
        extent = 1.0
    return scale * v / extent
+
+
def vertex_normals(v, t):
    """Compute per-vertex unit normals from triangle connectivity.

    Each triangle contributes its (area-weighted) face normal to all
    three of its corners; the accumulated sums are then normalized.

    Parameters
    ----------
    v : numpy.ndarray
        Vertex coordinates (n_vertices, 3).
    t : numpy.ndarray
        Triangle indices (n_faces, 3).

    Returns
    -------
    numpy.ndarray
        Per-vertex unit normals (n_vertices, 3).
    """
    c0, c1, c2 = (v[t[:, k], :] for k in range(3))
    e01 = c1 - c0
    e12 = c2 - c1
    e20 = c0 - c2
    # One cross product per corner (all equal up to rounding); each is
    # accumulated onto its own corner vertex below.
    corner_cross = [
        np.cross(e01, -e20),
        np.cross(e12, -e01),
        np.cross(e20, -e12),
    ]
    all_idx = np.concatenate([t[:, 0], t[:, 1], t[:, 2]])
    all_contrib = np.vstack(corner_cross)
    normals = np.empty((v.shape[0], 3), dtype=np.float64)
    # bincount performs the scatter-add per coordinate axis.
    for axis in range(3):
        normals[:, axis] = np.bincount(
            all_idx, weights=all_contrib[:, axis], minlength=v.shape[0]
        )
    lengths = np.sqrt(np.sum(normals * normals, axis=1))
    # Guard isolated vertices (zero accumulated normal) against 0/0.
    lengths[lengths < np.finfo(float).eps] = 1
    return normals / lengths.reshape(-1, 1)
+
+
+def _estimate_thresholds_from_array(mapdata, minval=None, maxval=None):
+ """Estimate threshold and saturation values from an already-loaded array.
+
+ Parameters
+ ----------
+ mapdata : numpy.ndarray
+ Per-vertex overlay values.
+ minval : float or None, optional
+ If provided, used as-is; otherwise estimated as the minimum absolute
+ value in the data.
+ maxval : float or None, optional
+ If provided, used as-is; otherwise estimated as the maximum absolute
+ value in the data.
+
+ Returns
+ -------
+ minval : float
+ Threshold value (lower bound of the color scale).
+ maxval : float
+ Saturation value (upper bound of the color scale).
+ """
+ valabs = np.abs(mapdata)
+ if maxval is None:
+ maxval = float(np.max(valabs)) if np.any(valabs) else 0.0
+ if minval is None:
+ minval = float(max(0.0, np.min(valabs) if np.any(valabs) else 0.0))
+ return minval, maxval
+
+
def estimate_overlay_thresholds(overlay, minval=None, maxval=None):
    """Estimate threshold and saturation values for an overlay.

    Loads the overlay (if given as a path) and derives ``fmin`` /
    ``fmax`` from its absolute values. Bounds already supplied by the
    caller are returned unchanged, so the function may be called
    unconditionally. No geometry or color work happens here.

    Parameters
    ----------
    overlay : str or array-like
        Path to the overlay file (.mgh or FreeSurfer morph format), or a
        numpy array / array-like of per-vertex scalar values.
    minval : float or None, optional
        Threshold; estimated from the data when ``None``.
    maxval : float or None, optional
        Saturation; estimated from the data when ``None``.

    Returns
    -------
    minval : float
        Threshold value (lower bound of the color scale).
    maxval : float
        Saturation value (upper bound of the color scale).
    """
    if isinstance(overlay, str):
        # n_vertices=None: load without validating against a mesh size.
        data = resolve_overlay(overlay, n_vertices=None)
    else:
        data = np.asarray(overlay)
    return _estimate_thresholds_from_array(data, minval, maxval)
+
+
def prepare_geometry_from_arrays(
    vertices,
    faces,
    overlay=None,
    annot=None,
    ctab=None,
    bg_map=None,
    roi=None,
    minval=None,
    maxval=None,
    invert=False,
    scale=1.85,
    color_mode=ColorSelection.BOTH,
):
    """Prepare vertex and color arrays for GPU upload from numpy arrays.

    This is the core geometry preparation function. All inputs must already
    be resolved numpy arrays; for file-path support use the thin wrapper
    :func:`prepare_geometry`.

    Parameters
    ----------
    vertices : numpy.ndarray
        Vertex coordinate array of shape (N, 3), dtype float32.
    faces : numpy.ndarray
        Triangle index array of shape (M, 3), dtype uint32.
    overlay : numpy.ndarray or None, optional
        Per-vertex scalar values of shape (N,) float32 used for coloring.
    annot : numpy.ndarray or None, optional
        Per-vertex integer label indices of shape (N,) int32.
    ctab : numpy.ndarray or None, optional
        Color table array (n_labels, ≥3) associated with *annot*.
    bg_map : numpy.ndarray or None, optional
        Per-vertex scalar values of shape (N,) float32 whose sign determines
        background shading (binary light/dark). When ``None`` a flat gray
        background is used.
    roi : numpy.ndarray of bool or None, optional
        Boolean mask of shape (N,). ``True`` = vertex is inside the region
        of interest and receives overlay coloring; ``False`` = vertex falls
        back to background shading. When ``None`` all vertices are in-ROI.
    minval, maxval : float or None, optional
        Threshold and saturation values for overlay scaling.
    invert : bool, optional, default False
        Invert color mapping.
    scale : float, optional, default 1.85
        Geometry scaling factor applied by :func:`normalize_mesh`.
    color_mode : ColorSelection, optional, default ColorSelection.BOTH
        Which sign(s) of overlay values to use for coloring.

    Returns
    -------
    vertexdata : numpy.ndarray
        Nx9 array (position x3, normal x3, color x3) ready for GPU upload.
    triangles : numpy.ndarray
        Mx3 uint32 triangle index array.
    fmin, fmax : float or None
        Final threshold and saturation values used for color mapping.
    pos, neg : bool or None
        Flags indicating whether positive/negative overlay values are present.

    Raises
    ------
    ValueError
        If overlay or annotation arrays do not match the surface vertex count.
    """
    # Center/scale the mesh and derive smooth shading normals up front;
    # everything below only computes the per-vertex color channel.
    vertices = normalize_mesh(np.array(vertices, dtype=np.float32), scale)
    triangles = np.array(faces, dtype=np.uint32)
    vnormals = np.array(vertex_normals(vertices, triangles), dtype=np.float32)
    num_vertices = vertices.shape[0]

    # Build background (sulcal) colormap. A size-mismatched bg_map only
    # warns (it is cosmetic) and falls back to the flat gray background,
    # unlike overlay/annot mismatches below which raise.
    if bg_map is not None:
        if bg_map.shape[0] != num_vertices:
            warnings.warn(
                f"bg_map has {bg_map.shape[0]} values but mesh has {num_vertices}.",
                stacklevel=2,
            )
            sulcmap = 0.5 * np.ones(vertices.shape, dtype=np.float32)
        else:
            # Sign-based two-tone shading (light gray / darker gray).
            sulcmap = binary_color(bg_map, 0.0, color_low=0.5, color_high=0.33)
    else:
        sulcmap = 0.5 * np.ones(vertices.shape, dtype=np.float32)

    # Initialize defaults for overlay outputs
    fmin = None
    fmax = None
    pos = None
    neg = None
    colors = sulcmap  # use as default

    # Apply overlay coloring (takes precedence over annot when both given)
    if overlay is not None:
        if overlay.shape[0] != num_vertices:
            raise ValueError(
                f"overlay has {overlay.shape[0]} values but mesh has {num_vertices}.\n"
                "This usually means the overlay does not match the provided surface "
                "(e.g. RH overlay used with LH surface). Provide the correct overlay."
            )
        # Work on a float64 copy so thresholding/rescaling never mutates
        # the caller's array.
        mapdata = overlay.copy().astype(np.float64)
        minval, maxval = _estimate_thresholds_from_array(mapdata, minval, maxval)
        mapdata = mask_sign(mapdata, color_mode)
        mapdata, fmin, fmax, pos, neg = rescale_overlay(mapdata, minval, maxval)
        colors = heat_color(mapdata, invert)
        # Some mapdata values could be nan (below min threshold) — fall back to bg
        missing = np.isnan(mapdata)
        if np.any(missing):
            colors[missing, :] = sulcmap[missing, :]

    elif annot is not None and ctab is not None:
        # Per-vertex annotation coloring
        if annot.shape[0] != num_vertices:
            raise ValueError(
                f"annot has {annot.shape[0]} values but mesh has {num_vertices}.\n"
                "This usually means the .annot does not match the provided surface "
                "(e.g. RH annot used with LH surface). Provide the correct annot file."
            )
        annot = annot.astype(np.int32)
        # Start from a copy of the background so out-of-table labels keep
        # the sulcal shading instead of garbage colors.
        colors = np.array(sulcmap, dtype=np.float32)
        ctab_rgb = np.asarray(ctab[:, 0:3], dtype=np.float32)
        # Heuristic: color tables with values > 1 are assumed to be 0–255
        # and are rescaled to the 0–1 range expected by the shader.
        denom = 255.0 if np.max(ctab_rgb) > 1 else 1.0
        valid = (annot >= 0) & (annot < ctab.shape[0])
        if np.any(valid):
            colors[valid, :] = ctab_rgb[annot[valid], :] / denom

    # Ensure colors dtype matches vertices/normals
    colors = np.asarray(colors, dtype=np.float32)

    # Apply ROI mask: vertices where roi == False fall back to sulcmap
    if roi is not None:
        outside = ~roi
        if np.any(outside):
            colors[outside, :] = sulcmap[outside, :]

    # Interleave position/normal/color into the Nx9 GPU vertex buffer.
    vertexdata = np.concatenate((vertices, vnormals, colors), axis=1)
    return vertexdata, triangles, fmin, fmax, pos, neg
+
+
def prepare_geometry(
    mesh,
    overlay=None,
    annot=None,
    bg_map=None,
    roi=None,
    minval=None,
    maxval=None,
    invert=False,
    scale=1.85,
    color_mode=ColorSelection.BOTH,
):
    """Resolve file/array inputs and build GPU-ready geometry arrays.

    Thin wrapper around :func:`prepare_geometry_from_arrays`: every input
    may be given either as a file path or as a numpy array, and is first
    resolved through the helpers in :mod:`whippersnappy.geometry.inputs`.

    Parameters
    ----------
    mesh : str or tuple of (array-like, array-like)
        Surface file path (FreeSurfer format) **or** a ``(vertices, faces)``
        pair where *vertices* is (N, 3) float and *faces* is (M, 3) int.
    overlay : str, array-like, or None, optional
        Path to an overlay (.mgh / FreeSurfer morph) file, or a (N,) array
        of per-vertex scalar values.
    annot : str, tuple, or None, optional
        Path to a FreeSurfer .annot file, or a ``(labels, ctab)`` /
        ``(labels, ctab, names)`` tuple.
    bg_map : str, array-like, or None, optional
        Path to a curvature/morph file used for background shading, or a
        (N,) array whose sign determines light/dark shading.
    roi : str, array-like, or None, optional
        Path to a FreeSurfer label file or a (N,) boolean array. Vertices
        with ``True`` receive overlay coloring; others fall back to *bg_map*.
    minval, maxval : float or None, optional
        Threshold and saturation values for overlay scaling.
    invert : bool, optional, default False
        Invert color mapping.
    scale : float, optional, default 1.85
        Geometry scaling factor applied by :func:`normalize_mesh`.
    color_mode : ColorSelection, optional, default ColorSelection.BOTH
        Which sign(s) of overlay values to use for coloring.

    Returns
    -------
    vertexdata : numpy.ndarray
        Nx9 array (position x3, normal x3, color x3) ready for GPU upload.
    triangles : numpy.ndarray
        Mx3 uint32 triangle index array.
    fmin, fmax : float or None
        Final threshold and saturation values used for color mapping.
    pos, neg : bool or None
        Flags indicating whether positive/negative overlay values are present.

    Raises
    ------
    TypeError
        If *mesh* is not a valid type.
    ValueError
        If overlay or annotation arrays do not match the surface vertex count.
    """
    verts, tris = resolve_mesh(mesh)
    n_vertices = verts.shape[0]

    # Resolve every optional input against the mesh's vertex count.
    resolved_overlay = resolve_overlay(overlay, n_vertices=n_vertices)
    resolved_bg = resolve_bg_map(bg_map, n_vertices=n_vertices)
    resolved_roi = resolve_roi(roi, n_vertices=n_vertices)
    resolved_annot = resolve_annot(annot, n_vertices=n_vertices)

    # resolve_annot yields (labels, ctab[, names]) or None.
    if resolved_annot is None:
        labels, ctab_arr = None, None
    else:
        labels, ctab_arr = resolved_annot[0], resolved_annot[1]

    return prepare_geometry_from_arrays(
        verts, tris, resolved_overlay, labels, ctab_arr,
        resolved_bg, resolved_roi, minval, maxval, invert, scale, color_mode,
    )
+
+
def prepare_and_validate_geometry(
    mesh,
    overlay=None,
    annot=None,
    bg_map=None,
    roi=None,
    fthresh=None,
    fmax=None,
    invert=False,
    scale=1.85,
    color_mode=ColorSelection.BOTH,
):
    """Prepare geometry via :func:`prepare_geometry` and validate the overlay.

    After loading, verifies that the overlay actually contains values that
    can be displayed under the requested *color_mode*; failures are logged
    and raised as :class:`ValueError` — the same check previously inlined
    in the static snapshot helpers.

    Parameters
    ----------
    mesh : str or tuple
        Passed through to :func:`prepare_geometry`.
    overlay, annot, bg_map, roi : str, array-like, or None
        Passed through to :func:`prepare_geometry`.
    fthresh, fmax : float or None
        Threshold and saturation values passed to the geometry preparer.
    invert : bool
        Passed to the geometry preparer.
    scale : float
        Scaling factor passed to the geometry preparer.
    color_mode : ColorSelection
        Which sign of overlay to display (POSITIVE/NEGATIVE/BOTH).

    Returns
    -------
    tuple
        ``(meshdata, triangles, fthresh, fmax, pos, neg)`` as returned by
        :func:`prepare_geometry`.

    Raises
    ------
    ValueError
        If the overlay contains no values appropriate for ``color_mode``.
    """
    import logging

    logger = logging.getLogger(__name__)

    def _fail(message):
        # Log and raise with the same text so both CLI users and exception
        # handlers see an identical diagnostic.
        logger.error(message)
        raise ValueError(message)

    meshdata, triangles, out_fthresh, out_fmax, pos, neg = prepare_geometry(
        mesh,
        overlay,
        annot,
        bg_map,
        roi,
        fthresh,
        fmax,
        invert,
        scale=scale,
        color_mode=color_mode,
    )

    # Validate overlay presence for the requested color mode.
    if overlay is not None:
        if color_mode == ColorSelection.POSITIVE:
            if neg and not pos:
                _fail("Overlay has no values to display with positive color_mode")
            neg = False
        elif color_mode == ColorSelection.NEGATIVE:
            if pos and not neg:
                _fail("Overlay has no values to display with negative color_mode")
            pos = False
        if not (pos or neg):
            _fail("Overlay has no values to display")

    return meshdata, triangles, out_fthresh, out_fmax, pos, neg
+
diff --git a/whippersnappy/geometry/surf_name.py b/whippersnappy/geometry/surf_name.py
new file mode 100644
index 0000000..1d1f07d
--- /dev/null
+++ b/whippersnappy/geometry/surf_name.py
@@ -0,0 +1,32 @@
+"""Helper for finding a surface file name inside a subject directory.
+
+This replaces the previous `io.py` name which was generic; `surf_name.py` is
+more descriptive (it provides `get_surf_name`).
+"""
+import os
+
+
def get_surf_name(sdir, hemi):
    """Return the preferred surface basename available for *hemi* in *sdir*.

    Looks inside ``<sdir>/surf/`` for, in order of preference,
    ``pial_semi_inflated``, ``white`` and ``inflated``, and returns the
    first basename for which a ``<hemi>.<name>`` file exists.

    Parameters
    ----------
    sdir : str
        Path to the subject directory containing a ``surf/`` subdirectory.
    hemi : {'lh', 'rh'}
        Hemisphere prefix to use when searching for surface files.

    Returns
    -------
    surf_name : str or None
        The surface basename if found, otherwise ``None``.
    """
    surf_dir = os.path.join(sdir, "surf")
    preferred = ("pial_semi_inflated", "white", "inflated")
    return next(
        (
            candidate
            for candidate in preferred
            if os.path.exists(os.path.join(surf_dir, f"{hemi}.{candidate}"))
        ),
        None,
    )
diff --git a/whippersnappy/gl/__init__.py b/whippersnappy/gl/__init__.py
new file mode 100644
index 0000000..4b458bf
--- /dev/null
+++ b/whippersnappy/gl/__init__.py
@@ -0,0 +1,46 @@
+"""OpenGL helper utilities (gl package).
+
+This package replaces the previous `gl_utils.py` module.
+Functions are re-exported at package level for convenience, e.g.:
+
+ from whippersnappy.gl import init_window, setup_shader
+
+"""
+
+from . import _platform # noqa: F401 — MUST be first; sets PYOPENGL_PLATFORM
+from .camera import make_model, make_projection, make_view
+from .shaders import get_default_shaders, get_webgl_shaders
+from .utils import (
+ capture_window,
+ compile_shader_program,
+ create_vao,
+ create_window_with_fallback,
+ init_window,
+ render_scene,
+ set_camera_uniforms,
+ set_default_gl_state,
+ set_lighting_uniforms,
+ setup_buffers,
+ setup_shader,
+ setup_vertex_attributes,
+ terminate_context,
+)
+from .views import (
+ ViewState,
+ arcball_rotation_matrix,
+ arcball_vector,
+ compute_view_matrix,
+ get_view_matrices,
+ get_view_matrix,
+)
+
# Explicit public API of the gl package; keep this list in sync with the
# re-exports above so that star-imports stay stable.
__all__ = [
    'create_vao', 'compile_shader_program', 'setup_buffers', 'setup_vertex_attributes',
    'set_default_gl_state', 'set_camera_uniforms', 'set_lighting_uniforms',
    'init_window', 'render_scene', 'setup_shader', 'capture_window',
    'make_model', 'make_projection', 'make_view',
    'get_default_shaders', 'get_view_matrices', 'get_view_matrix',
    'get_webgl_shaders', 'terminate_context',
    'ViewState', 'compute_view_matrix',
    'arcball_vector', 'arcball_rotation_matrix',
]
diff --git a/whippersnappy/gl/_platform.py b/whippersnappy/gl/_platform.py
new file mode 100644
index 0000000..8a94ad4
--- /dev/null
+++ b/whippersnappy/gl/_platform.py
@@ -0,0 +1,17 @@
+"""Bootstrap PyOpenGL platform selection — must be imported first.
+
+Imported unconditionally at the top of gl/__init__.py before any other
+OpenGL symbol. Sets PYOPENGL_PLATFORM=egl when running headless on Linux
+so that PyOpenGL uses the EGL backend instead of GLX.
+
+On macOS PyOpenGL uses CGL and on Windows it uses WGL — both are handled
+natively without EGL. If the user has already set PYOPENGL_PLATFORM that
+value is always respected.
+"""
+import os
+import sys
+
# Respect any explicit user choice; only intervene on Linux, where
# PyOpenGL's default GLX backend requires a running display server.
if "PYOPENGL_PLATFORM" not in os.environ and sys.platform == "linux":
    # No X11/Wayland display on Linux → force EGL headless backend.
    if not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"):
        os.environ["PYOPENGL_PLATFORM"] = "egl"
diff --git a/whippersnappy/gl/camera.py b/whippersnappy/gl/camera.py
new file mode 100644
index 0000000..bcfa376
--- /dev/null
+++ b/whippersnappy/gl/camera.py
@@ -0,0 +1,76 @@
+"""Camera and transform helpers (moved under gl package)."""
+
+import pyrr
+
+
def make_projection(width, height, fov=20.0, near=0.1, far=100.0):
    """Build a 4x4 perspective projection matrix for the given viewport.

    Parameters
    ----------
    width, height : int
        Viewport dimensions in pixels; only their ratio matters (aspect).
    fov : float, optional, default 20.0
        Vertical field of view in degrees.
    near, far : float, optional, default 0.1, 100.0
        Near and far clipping plane distances.

    Returns
    -------
    numpy.ndarray
        4x4 projection matrix.
    """
    aspect_ratio = width / height
    return pyrr.matrix44.create_perspective_projection(fov, aspect_ratio, near, far)
+
+
def make_view(camera_pos=(0.0, 0.0, -5.0)):
    """Build the view matrix for a camera located at ``camera_pos``.

    Parameters
    ----------
    camera_pos : sequence of float, optional, default (0.0, 0.0, -5.0)
        3-element camera position in world space.

    Returns
    -------
    numpy.ndarray
        4x4 view matrix (a pure translation by *camera_pos*).
    """
    position = pyrr.Vector3(camera_pos)
    return pyrr.matrix44.create_from_translation(position)
+
+
def make_model():
    """Build the default model matrix (translation by the zero vector).

    Returns
    -------
    numpy.ndarray
        4x4 identity-translation model matrix.
    """
    origin = pyrr.Vector3([0.0, 0.0, 0.0])
    return pyrr.matrix44.create_from_translation(origin)
+
+
def make_transform(translation, rotation, scale):
    """Build a model matrix from translation, rotation and uniform scale.

    The resulting matrix applies, in order: scale, then rotation, then
    translation (equivalent to ``T * R * S`` in column-vector notation).

    Parameters
    ----------
    translation : sequence of float
        3-element translation vector.
    rotation : numpy.ndarray
        4x4 rotation matrix (pyrr row-major convention).
    scale : float
        Uniform scaling factor.

    Returns
    -------
    numpy.ndarray
        4x4 combined transformation matrix.
    """
    # BUGFIX: the previous implementation combined the matrices with ``*``,
    # which is ELEMENT-WISE multiplication for numpy arrays (pyrr matrices
    # are numpy arrays), silently producing a wrong transform. pyrr uses a
    # row-major / row-vector convention, so "scale, then rotate, then
    # translate" composes as S @ R @ T via matrix44.multiply.
    scale_matrix = pyrr.matrix44.create_from_scale([scale, scale, scale])
    translation_matrix = pyrr.matrix44.create_from_translation(translation)
    return pyrr.matrix44.multiply(
        pyrr.matrix44.multiply(scale_matrix, rotation), translation_matrix
    )
diff --git a/whippersnappy/gl/egl_context.py b/whippersnappy/gl/egl_context.py
new file mode 100644
index 0000000..950dffa
--- /dev/null
+++ b/whippersnappy/gl/egl_context.py
@@ -0,0 +1,374 @@
+"""EGL off-screen (headless) OpenGL context via pbuffer + FBO.
+
+This module provides a drop-in alternative to GLFW window creation for
+headless environments (CI, Docker, HPC clusters) where no X11/Wayland
+display is available. It requires:
+
+ - A system EGL library (``libegl1`` on Debian/Ubuntu, already present
+ in the WhipperSnapPy Dockerfile).
+ - PyOpenGL >= 3.1 (already a project dependency), which ships
+ ``OpenGL.EGL`` bindings.
+ - Either an NVIDIA GPU with the EGL driver, or Mesa ``libEGL-mesa0``
+ (llvmpipe software renderer) for CPU-only systems.
+
+Typical usage (internal, called from ``create_window_with_fallback``)::
+
+ from whippersnappy.gl.egl_context import EGLContext
+
+ ctx = EGLContext(width, height)
+ ctx.make_current()
+ # ... OpenGL calls ...
+ img = ctx.read_pixels()
+ ctx.destroy()
+"""
+
+import ctypes
+import logging
+import os
+import sys
+
# EGL does not exist on macOS (OpenGL there goes through CGL), so importing
# this module on darwin is always a programming error — fail fast.
if sys.platform == "darwin":
    raise ImportError("EGL is not available on macOS; use GLFW/CGL instead.")

# Must be set before OpenGL.GL is imported anywhere in the process.
# NOTE: this intentionally overrides any other value (e.g. "glx"): by the
# time this module is imported, EGL rendering has been chosen, and PyOpenGL
# would otherwise bind the wrong windowing backend.
# This module is only imported when EGL is actually needed.
if os.environ.get("PYOPENGL_PLATFORM") != "egl":
    os.environ["PYOPENGL_PLATFORM"] = "egl"
+
+import OpenGL.GL as gl
+from PIL import Image
+
+logger = logging.getLogger(__name__)
+
# ---------------------------------------------------------------------------
# EGL constants not exposed by all PyOpenGL versions
# (values taken from the Khronos headers egl.h / eglext.h)
# ---------------------------------------------------------------------------
_EGL_SURFACE_TYPE = 0x3033
_EGL_PBUFFER_BIT = 0x0001
_EGL_RENDERABLE_TYPE = 0x3040
_EGL_OPENGL_BIT = 0x0008
_EGL_NONE = 0x3038  # attribute-list terminator
_EGL_WIDTH = 0x3057
_EGL_HEIGHT = 0x3056
_EGL_OPENGL_API = 0x30A2
_EGL_CONTEXT_MAJOR_VERSION = 0x3098
_EGL_CONTEXT_MINOR_VERSION = 0x30FB
_EGL_PLATFORM_DEVICE_EXT = 0x313F  # from EGL_EXT_platform_device
+
+
class EGLContext:
    """A headless OpenGL 3.3 Core context backed by an EGL pbuffer + FBO.

    The pbuffer surface is created solely to satisfy EGL's requirement for
    a surface when calling ``eglMakeCurrent``. All rendering is directed
    into an off-screen Framebuffer Object (FBO) so that ``glReadPixels``
    captures exactly what was rendered regardless of platform quirks with
    pbuffer readback.

    Parameters
    ----------
    width, height : int
        Dimensions of the off-screen render target in pixels.

    Attributes
    ----------
    width, height : int
        Render target dimensions.
    fbo : int
        OpenGL FBO handle (valid after ``make_current`` is called).

    Raises
    ------
    ImportError
        If ``OpenGL.EGL`` bindings are not available.
    RuntimeError
        If any EGL initialisation step fails.
    """

    def __init__(self, width: int, height: int):
        self.width = width
        self.height = height
        self._libegl = None      # ctypes.CDLL handle for libEGL (kept alive)
        self._display = None     # EGLDisplay pointer
        self._surface = None     # 1x1 dummy pbuffer surface
        self._context = None     # EGLContext pointer
        self._config = None      # chosen EGLConfig
        self.fbo = None          # off-screen framebuffer (set in make_current)
        self._rbo_color = None   # RGBA color renderbuffer
        self._rbo_depth = None   # depth+stencil renderbuffer
        self._init_egl()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _get_ext_fn(self, name, restype, argtypes):
        """Load an EGL extension function via eglGetProcAddress.

        Parameters
        ----------
        name : str
            Extension entry-point name (e.g. ``"eglQueryDevicesEXT"``).
        restype : ctypes type
            Return type of the function.
        argtypes : list of ctypes types
            Argument types of the function.

        Returns
        -------
        callable
            ctypes function wrapper for the extension entry point.

        Raises
        ------
        RuntimeError
            If the driver does not expose the requested function.
        """
        addr = self._libegl.eglGetProcAddress(name.encode())
        if not addr:
            raise RuntimeError(
                f"eglGetProcAddress('{name}') returned NULL — "
                f"extension not available on this driver."
            )
        FuncType = ctypes.CFUNCTYPE(restype, *argtypes)
        return FuncType(addr)

    def _init_egl(self):
        """Load libEGL, pick a display/config and create the GL 3.3 context."""
        import ctypes.util

        egl_name = ctypes.util.find_library("EGL") or "libEGL.so.1"
        try:
            libegl = ctypes.CDLL(egl_name)
        except OSError as e:
            raise RuntimeError(
                f"Could not load {egl_name}. "
                "Install libegl1-mesa and retry."
            ) from e
        self._libegl = libegl  # keep reference alive

        # Set signatures for direct (non-extension) EGL symbols.
        # Without explicit restype/argtypes, ctypes truncates pointers to
        # 32-bit ints on 64-bit systems — these declarations are essential.
        libegl.eglGetProcAddress.restype = ctypes.c_void_p
        libegl.eglGetProcAddress.argtypes = [ctypes.c_char_p]
        libegl.eglQueryString.restype = ctypes.c_char_p
        libegl.eglQueryString.argtypes = [ctypes.c_void_p, ctypes.c_int]
        libegl.eglInitialize.restype = ctypes.c_bool
        libegl.eglInitialize.argtypes = [ctypes.c_void_p,
                                         ctypes.POINTER(ctypes.c_int),
                                         ctypes.POINTER(ctypes.c_int)]
        libegl.eglBindAPI.restype = ctypes.c_bool
        libegl.eglBindAPI.argtypes = [ctypes.c_uint]
        libegl.eglChooseConfig.restype = ctypes.c_bool
        libegl.eglChooseConfig.argtypes = [ctypes.c_void_p,
                                           ctypes.POINTER(ctypes.c_int),
                                           ctypes.c_void_p, ctypes.c_int,
                                           ctypes.POINTER(ctypes.c_int)]
        libegl.eglCreatePbufferSurface.restype = ctypes.c_void_p
        libegl.eglCreatePbufferSurface.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
                                                   ctypes.POINTER(ctypes.c_int)]
        libegl.eglCreateContext.restype = ctypes.c_void_p
        libegl.eglCreateContext.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
                                            ctypes.c_void_p,
                                            ctypes.POINTER(ctypes.c_int)]
        libegl.eglMakeCurrent.restype = ctypes.c_bool
        libegl.eglMakeCurrent.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
                                          ctypes.c_void_p, ctypes.c_void_p]
        libegl.eglDestroyContext.restype = ctypes.c_bool
        libegl.eglDestroyContext.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        libegl.eglDestroySurface.restype = ctypes.c_bool
        libegl.eglDestroySurface.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        libegl.eglTerminate.restype = ctypes.c_bool
        libegl.eglTerminate.argtypes = [ctypes.c_void_p]

        # Check extensions and load ext functions via eglGetProcAddress.
        # Querying with display=None returns the CLIENT extension string.
        _EGL_EXTENSIONS = 0x3055
        client_exts = libegl.eglQueryString(None, _EGL_EXTENSIONS) or b""
        logger.debug("EGL client extensions: %s", client_exts.decode())

        has_device_enum = b"EGL_EXT_device_enumeration" in client_exts
        has_platform_base = b"EGL_EXT_platform_base" in client_exts

        # Preferred path: enumerate EGL devices (works without any display
        # server, e.g. on NVIDIA headless); fall back to the default display.
        display = None
        if has_device_enum and has_platform_base:
            eglQueryDevicesEXT = self._get_ext_fn(
                "eglQueryDevicesEXT",
                ctypes.c_bool,
                [ctypes.c_int, ctypes.c_void_p, ctypes.POINTER(ctypes.c_int)],
            )
            eglGetPlatformDisplayEXT = self._get_ext_fn(
                "eglGetPlatformDisplayEXT",
                ctypes.c_void_p,
                [ctypes.c_int, ctypes.c_void_p, ctypes.POINTER(ctypes.c_int)],
            )
            display = self._open_device_display(
                eglQueryDevicesEXT, eglGetPlatformDisplayEXT
            )

        if display is None:
            logger.debug("Falling back to eglGetDisplay(EGL_DEFAULT_DISPLAY)")
            libegl.eglGetDisplay.restype = ctypes.c_void_p
            libegl.eglGetDisplay.argtypes = [ctypes.c_void_p]
            display = libegl.eglGetDisplay(ctypes.c_void_p(0))

        if not display:
            raise RuntimeError(
                "Could not obtain any EGL display. "
                "Install libegl1-mesa for CPU rendering."
            )
        self._display = display

        major, minor = ctypes.c_int(0), ctypes.c_int(0)
        if not libegl.eglInitialize(
            self._display, ctypes.byref(major), ctypes.byref(minor)
        ):
            raise RuntimeError("eglInitialize failed.")
        logger.debug("EGL %d.%d", major.value, minor.value)

        if not libegl.eglBindAPI(_EGL_OPENGL_API):
            raise RuntimeError("eglBindAPI(OpenGL) failed.")

        # Array is sized 7 with 5 initialisers; the two trailing zero slots
        # sit after the EGL_NONE terminator, so EGL never reads them.
        cfg_attribs = (ctypes.c_int * 7)(
            _EGL_SURFACE_TYPE, _EGL_PBUFFER_BIT,
            _EGL_RENDERABLE_TYPE, _EGL_OPENGL_BIT,
            _EGL_NONE,
        )
        configs = (ctypes.c_void_p * 1)()
        num_cfgs = ctypes.c_int(0)
        if not libegl.eglChooseConfig(
            self._display, cfg_attribs, configs, 1, ctypes.byref(num_cfgs)
        ) or num_cfgs.value == 0:
            raise RuntimeError("eglChooseConfig: no suitable config.")
        self._config = configs[0]

        # Minimal 1x1 pbuffer: rendering happens in the FBO, the surface
        # only exists so eglMakeCurrent has something to bind.
        pbuf_attribs = (ctypes.c_int * 5)(
            _EGL_WIDTH, 1, _EGL_HEIGHT, 1, _EGL_NONE
        )
        self._surface = libegl.eglCreatePbufferSurface(
            self._display, self._config, pbuf_attribs
        )
        if not self._surface:
            raise RuntimeError("eglCreatePbufferSurface failed.")

        # Request an OpenGL 3.3 context (matches the GLSL 330 shaders).
        ctx_attribs = (ctypes.c_int * 5)(
            _EGL_CONTEXT_MAJOR_VERSION, 3,
            _EGL_CONTEXT_MINOR_VERSION, 3,
            _EGL_NONE,
        )
        self._context = libegl.eglCreateContext(
            self._display, self._config, None, ctx_attribs
        )
        if not self._context:
            raise RuntimeError(
                "eglCreateContext for OpenGL 3.3 Core failed. "
                "Try: MESA_GL_VERSION_OVERRIDE=3.3 MESA_GLSL_VERSION_OVERRIDE=330"
            )
        logger.info("EGL context created (%dx%d)", self.width, self.height)


    def _open_device_display(self, eglQueryDevicesEXT, eglGetPlatformDisplayEXT):
        """Enumerate EGL devices and return first usable display pointer."""
        n = ctypes.c_int(0)
        # First call with a NULL array just counts the available devices.
        if not eglQueryDevicesEXT(0, None, ctypes.byref(n)) or n.value == 0:
            logger.warning("eglQueryDevicesEXT: no devices.")
            return None
        logger.debug("EGL: %d device(s) found", n.value)
        devices = (ctypes.c_void_p * n.value)()
        eglQueryDevicesEXT(n.value, devices, ctypes.byref(n))
        no_attribs = (ctypes.c_int * 1)(_EGL_NONE)
        # Return the first device that yields a non-NULL platform display.
        for i, dev in enumerate(devices):
            dpy = eglGetPlatformDisplayEXT(
                _EGL_PLATFORM_DEVICE_EXT, ctypes.c_void_p(dev), no_attribs
            )
            if dpy:
                logger.debug("EGL: using device %d", i)
                return dpy
        return None


    def make_current(self):
        """Make this EGL context current and set up the FBO render target.

        Must be called before any OpenGL commands. Creates and binds an
        FBO backed by two renderbuffers (RGBA color + depth/stencil).
        """
        if not self._libegl.eglMakeCurrent(
            self._display, self._surface, self._surface, self._context
        ):
            raise RuntimeError("eglMakeCurrent failed.")

        # Force PyOpenGL to discover and cache the context we just made current.
        # PyOpenGL's contextdata module only recognizes contexts it has "seen"
        # via at least one GL call; glGetError() is the cheapest trigger.
        gl.glGetError()

        # Build FBO so rendering is directed off-screen
        self.fbo = gl.glGenFramebuffers(1)
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.fbo)

        # Color renderbuffer
        self._rbo_color = gl.glGenRenderbuffers(1)
        gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._rbo_color)
        gl.glRenderbufferStorage(
            gl.GL_RENDERBUFFER, gl.GL_RGBA8, self.width, self.height
        )
        gl.glFramebufferRenderbuffer(
            gl.GL_FRAMEBUFFER,
            gl.GL_COLOR_ATTACHMENT0,
            gl.GL_RENDERBUFFER,
            self._rbo_color,
        )

        # Depth + stencil renderbuffer
        self._rbo_depth = gl.glGenRenderbuffers(1)
        gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._rbo_depth)
        gl.glRenderbufferStorage(
            gl.GL_RENDERBUFFER,
            gl.GL_DEPTH24_STENCIL8,
            self.width,
            self.height,
        )
        gl.glFramebufferRenderbuffer(
            gl.GL_FRAMEBUFFER,
            gl.GL_DEPTH_STENCIL_ATTACHMENT,
            gl.GL_RENDERBUFFER,
            self._rbo_depth,
        )

        status = gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER)
        if status != gl.GL_FRAMEBUFFER_COMPLETE:
            raise RuntimeError(
                f"FBO is not complete after EGL setup (status=0x{status:X})."
            )

        # Set the viewport to match the render target
        gl.glViewport(0, 0, self.width, self.height)
        logger.debug("EGL FBO complete and bound (%dx%d)", self.width, self.height)

    def read_pixels(self) -> Image.Image:
        """Read the FBO contents and return a PIL RGB Image.

        Returns
        -------
        PIL.Image.Image
            Captured frame, vertically flipped to convert from OpenGL's
            bottom-left origin to image top-left convention.
        """
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.fbo)
        # Tightly-packed rows: avoids row-padding artifacts for odd widths.
        gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
        buf = gl.glReadPixels(
            0, 0, self.width, self.height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE
        )
        img = Image.frombytes("RGB", (self.width, self.height), buf)
        return img.transpose(Image.FLIP_TOP_BOTTOM)

    def destroy(self):
        """Release all GL and EGL resources owned by this context.

        GL objects are deleted first (while the context is still current),
        then the context is unbound and the EGL objects torn down.
        """
        libegl = self._libegl
        # GL cleanup first (context must be current)
        if self.fbo is not None:
            gl.glDeleteFramebuffers(1, [self.fbo])
            self.fbo = None
        if self._rbo_color is not None:
            gl.glDeleteRenderbuffers(1, [self._rbo_color])
            self._rbo_color = None
        if self._rbo_depth is not None:
            gl.glDeleteRenderbuffers(1, [self._rbo_depth])
            self._rbo_depth = None
        if self._display:
            libegl.eglMakeCurrent(self._display, None, None, None)
            if self._context:
                libegl.eglDestroyContext(self._display, self._context)
            if self._surface:
                libegl.eglDestroySurface(self._display, self._surface)
            libegl.eglTerminate(self._display)
        self._display = None
        self._context = None
        self._surface = None
        logger.debug("EGL context destroyed.")

    # Allow use as a context manager
    def __enter__(self):
        self.make_current()
        return self

    def __exit__(self, *_):
        self.destroy()
diff --git a/whippersnappy/gl/shaders.py b/whippersnappy/gl/shaders.py
new file mode 100644
index 0000000..fd3c789
--- /dev/null
+++ b/whippersnappy/gl/shaders.py
@@ -0,0 +1,217 @@
+"""Shared shader sources inside the gl package."""
+
def get_default_shaders():
    """Return the default GLSL 330 vertex and fragment shader sources.

    Intended for desktop OpenGL (GLSL 330) in the offline snapshot
    renderer. The fragment stage implements a fixed five-light diffuse
    setup with optional specular highlights and per-vertex colors.

    Returns
    -------
    vertex_shader, fragment_shader : tuple[str, str]
        Full GLSL source strings for the vertex and fragment stages.
    """
    # Vertex stage: transforms positions/normals and forwards the color.
    vs_src = """

    #version 330

    layout (location = 0) in vec3 aPos;
    layout (location = 1) in vec3 aNormal;
    layout (location = 2) in vec3 aColor;

    out vec3 FragPos;
    out vec3 Normal;
    out vec3 Color;

    uniform mat4 transform;
    uniform mat4 model;
    uniform mat4 view;
    uniform mat4 projection;

    void main()
    {
        gl_Position = projection * view * model * transform * vec4(aPos, 1.0f);
        FragPos = vec3(model * transform * vec4(aPos, 1.0));
        // normal matrix should be computed outside and passed!
        Normal = mat3(transpose(inverse(view * model * transform))) * aNormal;
        Color = aColor;
    }

    """

    # Fragment stage: multi-light Lambert shading plus optional specular.
    fs_src = """
    #version 330

    in vec3 FragPos;
    in vec3 Normal;
    in vec3 Color;

    out vec4 FragColor;

    uniform vec3 lightColor = vec3(1.0, 1.0, 1.0);
    uniform bool doSpecular = true;
    uniform float ambientStrength = 0.0;

    void main()
    {
        // ambient
        vec3 ambient = ambientStrength * lightColor;

        // diffuse
        vec3 norm = normalize(Normal);
        vec4 diffweights = vec4(0.6, 0.4, 0.4, 0.3);

        // key light (overhead)
        vec3 lightPos1 = vec3(0.0,5.0,5.0);
        vec3 lightDir = normalize(lightPos1 - FragPos);
        float diff = max(dot(norm, lightDir), 0.0);
        vec3 diffuse = diffweights[0] * diff * lightColor;

        // headlight (at camera)
        vec3 lightPos2 = vec3(0.0,0.0,5.0);
        lightDir = normalize(lightPos2 - FragPos);
        vec3 ohlightDir = lightDir;
        diff = max(dot(norm, lightDir), 0.0);
        diffuse = diffuse + diffweights[1] * diff * lightColor;

        // fill light (from below)
        vec3 lightPos3 = vec3(0.0,-5.0,5.0);
        lightDir = normalize(lightPos3 - FragPos);
        diff = max(dot(norm, lightDir), 0.0);
        diffuse = diffuse + diffweights[2] * diff * lightColor;

        // left right back lights
        vec3 lightPos4 = vec3(5.0,0.0,-5.0);
        lightDir = normalize(lightPos4 - FragPos);
        diff = max(dot(norm, lightDir), 0.0);
        diffuse = diffuse + diffweights[3] * diff * lightColor;

        vec3 lightPos5 = vec3(-5.0,0.0,-5.0);
        lightDir = normalize(lightPos5 - FragPos);
        diff = max(dot(norm, lightDir), 0.0);
        diffuse = diffuse + diffweights[3] * diff * lightColor;

        // specular — camera is at (0,0,-5) in world space (from make_view),
        // not at origin, so viewDir must point from FragPos toward (0,0,-5).
        vec3 result;
        if (doSpecular)
        {
            float specularStrength = 0.5;
            vec3 cameraPos = vec3(0.0, 0.0, -5.0);
            vec3 viewDir = normalize(cameraPos - FragPos);
            vec3 reflectDir = reflect(ohlightDir, norm);
            float spec = pow(max(dot(viewDir, reflectDir), 0.0), 32);
            vec3 specular = specularStrength * spec * lightColor;
            result = (ambient + diffuse + specular) * Color;
        }
        else
        {
            result = (ambient + diffuse) * Color;
        }
        FragColor = vec4(result, 1.0);
    }

    """

    return vs_src, fs_src
+
+
def get_webgl_shaders():
    """Return vertex and fragment shader sources for WebGL / Three.js.

    Used by the Jupyter pythreejs-based viewer. The snippets rely on
    Three.js-provided built-ins (``position``, ``normal``,
    ``projectionMatrix``, ``modelViewMatrix``, ``normalMatrix``), so only
    the custom ``color`` attribute and lighting uniforms are declared.

    Returns
    -------
    vertex_shader, fragment_shader : tuple[str, str]
        Vertex and fragment shader source strings for WebGL / Three.js.
    """
    # Vertex stage: Three.js supplies the matrices and built-in attributes;
    # we only forward position, transformed normal and custom vertex color.
    vs_src = """
    attribute vec3 color;

    varying vec3 vFragPos;
    varying vec3 vNormal;
    varying vec3 vColor;

    void main()
    {
        gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
        vFragPos = vec3(modelViewMatrix * vec4(position, 1.0));
        vNormal = normalMatrix * normal;
        vColor = color;
    }
    """

    # Fragment stage: same five-light model as the desktop shader, written
    # in GLSL ES (varying/gl_FragColor) for WebGL 1 compatibility.
    fs_src = """
    precision highp float;

    varying vec3 vNormal;
    varying vec3 vFragPos;
    varying vec3 vColor;

    uniform vec3 lightColor;
    uniform float ambientStrength;

    void main()
    {
        // ambient
        vec3 ambient = ambientStrength * lightColor;

        // diffuse
        vec3 norm = normalize(vNormal);
        vec4 diffweights = vec4(0.6, 0.4, 0.4, 0.3);

        // key light (overhead)
        vec3 lightPos1 = vec3(0.0, 5.0, 5.0);
        vec3 lightDir = normalize(lightPos1 - vFragPos);
        float diff = max(dot(norm, lightDir), 0.0);
        vec3 diffuse = diffweights[0] * diff * lightColor;

        // headlight (at camera)
        vec3 lightPos2 = vec3(0.0, 0.0, 5.0);
        lightDir = normalize(lightPos2 - vFragPos);
        vec3 ohlightDir = lightDir;
        diff = max(dot(norm, lightDir), 0.0);
        diffuse = diffuse + diffweights[1] * diff * lightColor;

        // fill light (from below)
        vec3 lightPos3 = vec3(0.0, -5.0, 5.0);
        lightDir = normalize(lightPos3 - vFragPos);
        diff = max(dot(norm, lightDir), 0.0);
        diffuse = diffuse + diffweights[2] * diff * lightColor;

        // left right back lights
        vec3 lightPos4 = vec3(5.0, 0.0, -5.0);
        lightDir = normalize(lightPos4 - vFragPos);
        diff = max(dot(norm, lightDir), 0.0);
        diffuse = diffuse + diffweights[3] * diff * lightColor;

        vec3 lightPos5 = vec3(-5.0, 0.0, -5.0);
        lightDir = normalize(lightPos5 - vFragPos);
        diff = max(dot(norm, lightDir), 0.0);
        diffuse = diffuse + diffweights[3] * diff * lightColor;

        // specular
        float specularStrength = 0.5;
        vec3 viewDir = normalize(-vFragPos);
        vec3 reflectDir = reflect(-ohlightDir, norm);
        float spec = pow(max(dot(viewDir, reflectDir), 0.0), 32.0);
        vec3 specular = specularStrength * spec * lightColor;

        vec3 result = (ambient + diffuse + specular) * vColor;
        gl_FragColor = vec4(result, 1.0);
    }
    """

    return vs_src, fs_src
diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py
new file mode 100644
index 0000000..288e383
--- /dev/null
+++ b/whippersnappy/gl/utils.py
@@ -0,0 +1,450 @@
+"""GL helper utilities.
+
+Contains the implementation of OpenGL helpers used by the package.
+"""
+
+import logging
+import os
+import sys
+from typing import Any
+
+import glfw
+import OpenGL.GL as gl
+import OpenGL.GL.shaders as shaders
+from PIL import Image
+
+from .camera import make_model, make_projection, make_view
+from .shaders import get_default_shaders
+
+# Module logger
+logger = logging.getLogger(__name__)
+
+# Module-level EGL context handle (None when GLFW is used instead)
+_egl_context: Any = None
+
+
def create_vao():
    """Generate a fresh Vertex Array Object and make it the current one.

    Returns
    -------
    int
        OpenGL handle of the newly created (and bound) VAO.
    """
    handle = gl.glGenVertexArrays(1)
    gl.glBindVertexArray(handle)
    return handle
+
+
def compile_shader_program(vertex_src, fragment_src):
    """Compile both GLSL stages and link them into a single program.

    Parameters
    ----------
    vertex_src : str
        GLSL source for the vertex stage.
    fragment_src : str
        GLSL source for the fragment stage.

    Returns
    -------
    int
        Handle of the linked OpenGL shader program.
    """
    vert = shaders.compileShader(vertex_src, gl.GL_VERTEX_SHADER)
    frag = shaders.compileShader(fragment_src, gl.GL_FRAGMENT_SHADER)
    return shaders.compileProgram(vert, frag)
+
+
def setup_buffers(meshdata, triangles):
    """Allocate GPU buffers and upload the mesh's vertex and index data.

    Parameters
    ----------
    meshdata : numpy.ndarray
        Interleaved per-vertex attributes (position, normal, color).
    triangles : numpy.ndarray
        Triangle index array for the element buffer.

    Returns
    -------
    tuple of int
        ``(vbo, ebo)`` OpenGL buffer handles.
    """
    vbo = gl.glGenBuffers(1)
    ebo = gl.glGenBuffers(1)

    gl.glBindBuffer(gl.GL_ARRAY_BUFFER, vbo)
    gl.glBufferData(gl.GL_ARRAY_BUFFER, meshdata.nbytes, meshdata, gl.GL_STATIC_DRAW)

    gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, ebo)
    gl.glBufferData(
        gl.GL_ELEMENT_ARRAY_BUFFER, triangles.nbytes, triangles, gl.GL_STATIC_DRAW
    )

    return vbo, ebo
+
+
def setup_vertex_attributes(shader):
    """Configure the interleaved vertex layout (position, normal, color).

    The VBO stores 9 float32 components per vertex — position (3),
    normal (3), color (3) — giving a stride of 9 * 4 bytes with the three
    attributes at byte offsets 0, 12 and 24 respectively.

    Parameters
    ----------
    shader : int
        OpenGL shader program handle used to query attribute locations.
    """
    stride = 9 * 4  # 9 float32 components per vertex
    # (attribute name, component count, byte offset within a vertex record)
    layout = (
        ("aPos", 3, 0),
        ("aNormal", 3, 3 * 4),
        ("aColor", 3, 6 * 4),
    )
    for name, size, offset in layout:
        location = gl.glGetAttribLocation(shader, name)
        gl.glVertexAttribPointer(
            location, size, gl.GL_FLOAT, gl.GL_FALSE, stride, gl.ctypes.c_void_p(offset)
        )
        gl.glEnableVertexAttribArray(location)
+
+
def set_default_gl_state():
    """Apply the renderer's baseline OpenGL state.

    Enables depth testing and sets the clear color to opaque black.
    """
    gl.glEnable(gl.GL_DEPTH_TEST)
    gl.glClearColor(0.0, 0.0, 0.0, 1.0)
+
+
def set_camera_uniforms(shader, view, projection, model):
    """Write the view, projection and model matrices to the shader.

    Parameters
    ----------
    shader : int
        OpenGL shader program handle.
    view, projection, model : array-like
        4x4 matrices uploaded to the uniforms of the same names.
    """
    uploads = (("view", view), ("projection", projection), ("model", model))
    for uniform_name, matrix in uploads:
        location = gl.glGetUniformLocation(shader, uniform_name)
        gl.glUniformMatrix4fv(location, 1, gl.GL_FALSE, matrix)
+
+
def set_lighting_uniforms(shader, specular=True, ambient=0.0, light_color=(1.0, 1.0, 1.0)):
    """Upload the lighting parameters used by the fragment shader.

    Parameters
    ----------
    shader : int
        OpenGL shader program handle.
    specular : bool, optional, default True
        Toggle specular highlights (``doSpecular`` uniform).
    ambient : float, optional, default 0.0
        Ambient light strength (``ambientStrength`` uniform).
    light_color : tuple, optional, default (1.0, 1.0, 1.0)
        RGB light color (``lightColor`` uniform).
    """
    gl.glUniform1i(gl.glGetUniformLocation(shader, "doSpecular"), specular)
    gl.glUniform3f(gl.glGetUniformLocation(shader, "lightColor"), *light_color)
    gl.glUniform1f(gl.glGetUniformLocation(shader, "ambientStrength"), ambient)
+
+
def init_window(width, height, title="PyOpenGL", visible=True):
    """Initialize GLFW, open a window and make its GL context current.

    Parameters
    ----------
    width, height : int
        Window dimensions in pixels.
    title : str, optional, default 'PyOpenGL'
        Window title.
    visible : bool, optional, default True
        When False, request an invisible window (offscreen rendering on
        systems where a display exists but no screen is wanted).

    Returns
    -------
    window or False
        GLFW window handle on success, or False on failure.
    """
    if not glfw.init():
        return False

    hints = [
        (glfw.CONTEXT_VERSION_MAJOR, 3),
        (glfw.CONTEXT_VERSION_MINOR, 3),
        (glfw.OPENGL_FORWARD_COMPAT, True),
        (glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE),
    ]
    if not visible:
        hints.append((glfw.VISIBLE, glfw.FALSE))
    for hint, value in hints:
        glfw.window_hint(hint, value)

    win = glfw.create_window(width, height, title, None, None)
    if not win:
        glfw.terminate()
        return False

    glfw.set_input_mode(win, glfw.STICKY_KEYS, gl.GL_TRUE)
    glfw.make_context_current(win)
    glfw.swap_interval(0)
    return win
+
+
def _create_egl_context(width, height):
    """Create an EGL pbuffer context, make it current and register it.

    Private helper for :func:`create_window_with_fallback` that removes the
    duplicated EGL setup previously inlined in two branches. Sets the
    module-level ``_egl_context`` so ``capture_window`` / ``terminate_context``
    can detect the EGL path. Propagates ImportError/RuntimeError unchanged;
    callers wrap them with context about which fallback chain was attempted.
    """
    global _egl_context
    from .egl_context import EGLContext

    ctx = EGLContext(width, height)
    ctx.make_current()
    _egl_context = ctx
    logger.info("Using EGL headless context — no display server required.")


def create_window_with_fallback(width, height, title="WhipperSnapPy", visible=True):
    """Create an OpenGL context, trying GLFW first and EGL as a fallback.

    The function attempts context creation in this priority order:

    1. **GLFW visible window** — normal path on workstations.
    2. **GLFW invisible window** — when a display exists but no screen
       is needed (e.g. a remote desktop session).
    3. **EGL pbuffer** — fully headless; no display server required.
       Works with NVIDIA/AMD GPU drivers and Mesa (llvmpipe) on CPU-only
       systems. Requires ``libegl1`` (already installed in the Docker
       image) and ``pyopengl >= 3.1``.

    When EGL is used the module-level ``_egl_context`` is set and
    ``make_current()`` is called so that subsequent OpenGL calls work
    identically to the GLFW path.

    Parameters
    ----------
    width : int
        Render target width in pixels.
    height : int
        Render target height in pixels.
    title : str, optional
        Window title (used for GLFW paths only). Default is ``'WhipperSnapPy'``.
    visible : bool, optional
        Prefer a visible window. Default is ``True``.

    Returns
    -------
    GLFWwindow or None
        GLFW window handle when GLFW succeeded, ``None`` when EGL is used
        (the context is already current via ``_egl_context.make_current()``).

    Raises
    ------
    RuntimeError
        If all three methods fail to produce a usable OpenGL context.
    """
    # Fast-path: if the environment already selected EGL, skip the two
    # doomed GLFW attempts and go straight to EGL. This avoids warning
    # noise and wasted time in Docker/CI/headless SSH. The sys.platform
    # guard is preserved — EGL is Linux-only.
    if os.environ.get("PYOPENGL_PLATFORM") == "egl":
        if sys.platform != "linux":
            raise RuntimeError(
                f"Could not create any OpenGL context via GLFW on {sys.platform}. "
                "Ensure a display is available."
            )
        logger.info("No working display detected — using EGL headless directly.")
        try:
            _create_egl_context(width, height)
            return None
        except (ImportError, RuntimeError) as exc:
            raise RuntimeError(
                f"EGL headless context failed: {exc}"
            ) from exc

    # --- Step 1: GLFW visible window ---
    window = init_window(width, height, title, visible=visible)
    if window:
        return window

    # --- Step 2: GLFW invisible window ---
    if visible:
        logger.warning(
            "Could not create visible GLFW window; retrying with invisible window."
        )
        window = init_window(width, height, title, visible=False)
        if window:
            return window

    # --- Step 3: EGL headless pbuffer (Linux only) ---
    logger.warning(
        "GLFW context creation failed entirely (no display?). "
        "Attempting EGL headless context."
    )
    if sys.platform != "linux":
        raise RuntimeError(
            f"Could not create any OpenGL context via GLFW on {sys.platform}. "
            "Ensure a display is available."
        )
    try:
        _create_egl_context(width, height)
        return None
    except (ImportError, RuntimeError) as exc:
        raise RuntimeError(
            "Could not create any OpenGL context (tried GLFW visible, "
            f"GLFW invisible, EGL pbuffer). Last error: {exc}"
        ) from exc
+
+
def terminate_context(window):
    """Tear down the active OpenGL context, whichever backend created it.

    Drop-in replacement for ``glfw.terminate()`` that also covers the EGL
    path; call it at the end of every rendering function instead of
    calling ``glfw.terminate()`` directly.

    Parameters
    ----------
    window : GLFWwindow or None
        Handle returned by ``create_window_with_fallback``, or ``None``
        when EGL is active.
    """
    global _egl_context
    if _egl_context is None:
        glfw.terminate()
        return
    _egl_context.destroy()  # type: ignore[union-attr]
    _egl_context = None
+
+
def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0):
    """Compile the default shaders, upload the mesh, and prime all uniforms.

    Convenience wrapper that creates a VAO/VBO/EBO, links the default
    shader program, activates it, and sets the camera matrices plus the
    lighting uniforms in one call.

    Parameters
    ----------
    meshdata : numpy.ndarray
        Interleaved vertex data (position, normal, color).
    triangles : numpy.ndarray
        Triangle indices.
    width, height : int
        Framebuffer size used to compute the projection matrix.
    specular : bool, optional, default True
        Enable specular highlights.
    ambient : float, optional, default 0.0
        Ambient lighting strength.

    Returns
    -------
    int
        Compiled OpenGL shader program handle.
    """
    vert_src, frag_src = get_default_shaders()

    create_vao()
    program = compile_shader_program(vert_src, frag_src)
    setup_buffers(meshdata, triangles)
    setup_vertex_attributes(program)

    gl.glUseProgram(program)
    set_default_gl_state()

    set_camera_uniforms(
        program, make_view(), make_projection(width, height), make_model()
    )
    set_lighting_uniforms(program, specular=specular, ambient=ambient)

    return program
+
def capture_window(window):
    """Capture the current framebuffer contents as an RGB PIL image.

    Supports both context flavours: when an EGL headless context is active
    (``window`` is ``None``) pixels come from the EGL offscreen FBO and no
    HiDPI scaling applies; otherwise the GLFW default framebuffer is read
    and monitor content scaling is compensated for.

    Parameters
    ----------
    window : GLFWwindow or None
        GLFW window handle, or ``None`` when an EGL context is active.

    Returns
    -------
    PIL.Image.Image
        RGB frame, vertically flipped so the origin is top-left
        (image convention).
    """
    global _egl_context

    # --- EGL path: the context object reads from its own FBO ---
    if _egl_context is not None:
        return _egl_context.read_pixels()  # type: ignore[union-attr]

    # --- GLFW path: invisible/offscreen windows have no monitor scale ---
    monitor = glfw.get_primary_monitor()
    if monitor is None:
        x_scale = y_scale = 1.0
    else:
        x_scale, y_scale = glfw.get_monitor_content_scale(monitor)

    width, height = glfw.get_framebuffer_size(window)
    logger.debug("Framebuffer size = (%s,%s)", width, height)
    logger.debug("Monitor scale = (%s,%s)", x_scale, y_scale)

    gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
    raw = gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE)
    image = Image.frombytes("RGB", (width, height), raw).transpose(
        Image.Transpose.FLIP_TOP_BOTTOM
    )

    if x_scale == 1 and y_scale == 1:
        return image

    # Undo HiDPI scaling so the image matches the logical window size.
    target = (int(round(width / x_scale)), int(round(height / y_scale)))
    logger.debug("Rescale to = (%s,%s)", target[0], target[1])
    image.thumbnail(target, Image.Resampling.LANCZOS)
    return image
+
+
def render_scene(shader, triangles, transform):
    """Issue one indexed draw call with the given transform uniform.

    Parameters
    ----------
    shader : int
        OpenGL shader program handle.
    triangles : numpy.ndarray
        Element/index array used for the draw call.
    transform : array-like
        4x4 matrix uploaded to the shader uniform named ``transform``.

    Raises
    ------
    RuntimeError
        If a GL error occurs during rendering.
    """
    try:
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
    except Exception as exc:
        logger.error("glClear failed: %s", exc)
        raise RuntimeError(f"glClear failed: {exc}") from exc

    location = gl.glGetUniformLocation(shader, "transform")
    gl.glUniformMatrix4fv(location, 1, gl.GL_FALSE, transform)
    gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None)

    status = gl.glGetError()
    if status == gl.GL_NO_ERROR:
        return
    logger.error("OpenGL error after draw: %s", status)
    raise RuntimeError(f"OpenGL error after draw: {status}")
diff --git a/whippersnappy/gl/views.py b/whippersnappy/gl/views.py
new file mode 100644
index 0000000..0d2af24
--- /dev/null
+++ b/whippersnappy/gl/views.py
@@ -0,0 +1,199 @@
+"""View matrices, presets, and interactive view state under gl package."""
+
+from dataclasses import dataclass, field
+
+import numpy as np
+import pyrr # still needed for compute_view_matrix
+
+from ..utils.types import ViewType
+
+# ---------------------------------------------------------------------------
+# ViewState — single source of truth for all mutable view parameters
+# ---------------------------------------------------------------------------
+
@dataclass
class ViewState:
    """Mutable view parameters for the interactive GUI render loop.

    All mouse/keyboard interaction updates this object; the view matrix is
    recomputed from it each frame via :func:`compute_view_matrix`.

    Attributes
    ----------
    rotation : np.ndarray
        4×4 float32 rotation matrix (identity = no rotation applied).
    pan : np.ndarray
        (x, y) pan offset in normalised screen-space units.
    zoom : float
        Z-translation packed into the transform matrix (added to the 0.4
        base offset by :func:`compute_view_matrix`).
    last_mouse_pos : np.ndarray or None
        Last recorded mouse position in pixels; ``None`` when no button held.
    left_button_down : bool
        Whether the left mouse button is currently pressed.
    right_button_down : bool
        Whether the right mouse button is currently pressed.
    middle_button_down : bool
        Whether the middle mouse button is currently pressed.
    """
    rotation: np.ndarray = field(
        default_factory=lambda: np.eye(4, dtype=np.float32)
    )
    pan: np.ndarray = field(
        default_factory=lambda: np.zeros(2, dtype=np.float32)
    )
    zoom: float = 0.4
    last_mouse_pos: np.ndarray | None = None
    left_button_down: bool = False
    right_button_down: bool = False
    middle_button_down: bool = False
+
+
def compute_view_matrix(view_state: ViewState, base_view: np.ndarray) -> np.ndarray:
    """Build the ``transform`` uniform matrix from the interactive state.

    Packs ``translation * rotation * base_view`` into one matrix, matching
    the snap_rotate convention (``viewmat = transl * rot * base_view``).
    The ``model`` and ``view`` uniforms remain as set by ``setup_shader``
    (identity and camera respectively) and must not be overwritten.

    Parameters
    ----------
    view_state : ViewState
        Current interactive view state.
    base_view : np.ndarray
        Fixed 4×4 orientation preset from :func:`get_view_matrices`.

    Returns
    -------
    np.ndarray
        4×4 float32 matrix for the ``transform`` shader uniform.
    """
    pan_x, pan_y = view_state.pan[0], view_state.pan[1]
    translation = pyrr.Matrix44.from_translation((pan_x, pan_y, 0.4 + view_state.zoom))
    rotation = pyrr.Matrix44(view_state.rotation)
    combined = translation * rotation * pyrr.Matrix44(base_view)
    return np.array(combined, dtype=np.float32)
+
+
+
+# ---------------------------------------------------------------------------
+# Arcball helpers
+# ---------------------------------------------------------------------------
+
def arcball_vector(x: float, y: float, width: int, height: int) -> np.ndarray:
    """Project a screen pixel onto the unit arcball sphere.

    The pixel is first normalised to [-1, 1] using the shorter window
    dimension (so the ball is round regardless of aspect ratio), then
    lifted onto the unit sphere; positions outside the ball are clamped
    to its rim (z = 0).

    Parameters
    ----------
    x, y : float
        Mouse position in pixels.
    width, height : int
        Window dimensions in pixels.

    Returns
    -------
    np.ndarray
        Unit 3-vector on (or clamped to) the arcball sphere.
    """
    shorter = min(width, height)
    px = (2.0 * x - width) / shorter
    py = -(2.0 * y - height) / shorter
    point = np.array([px, py, 0.0], dtype=np.float64)

    r2 = px * px + py * py
    if r2 > 1.0:
        point /= np.sqrt(r2)  # outside the ball: clamp onto the rim
    else:
        point[2] = np.sqrt(1.0 - r2)

    length = np.linalg.norm(point)
    return point / length if length > 0 else point
+
+
def arcball_rotation_matrix(v1: np.ndarray, v2: np.ndarray) -> np.ndarray:
    """Build the 4×4 rotation taking unit vector *v1* onto *v2*.

    Implements Rodrigues' rotation formula in pure numpy (no pyrr needed).
    When the two vectors are (numerically) coincident there is no unique
    rotation axis, so the identity is returned.

    Parameters
    ----------
    v1, v2 : np.ndarray
        Unit 3-vectors on the arcball sphere.

    Returns
    -------
    np.ndarray
        4×4 float32 rotation matrix compatible with pyrr.
    """
    axis = np.cross(v1, v2)
    sin_angle = np.linalg.norm(axis)
    if sin_angle < 1e-10:
        return np.eye(4, dtype=np.float32)

    ux, uy, uz = axis / sin_angle
    # For unit inputs |v1×v2| = sin(a) and v1·v2 = cos(a), so atan2 gives a.
    theta = np.arctan2(sin_angle, np.dot(v1, v2))
    c = np.cos(theta)
    s = np.sin(theta)
    t = 1.0 - c

    # Rodrigues: R = I cos(a) + sin(a) [axis]× + (1-cos(a)) axis⊗axis
    out = np.eye(4, dtype=np.float32)
    out[:3, :3] = [
        [t * ux * ux + c, t * ux * uy - s * uz, t * ux * uz + s * uy],
        [t * ux * uy + s * uz, t * uy * uy + c, t * uy * uz - s * ux],
        [t * ux * uz - s * uy, t * uy * uz + s * ux, t * uz * uz + c],
    ]
    return out
+
+
def get_view_matrices():
    """Return canonical 4x4 view matrices for common brain orientations.

    Maps each :class:`whippersnappy.utils.types.ViewType` member to a
    fixed 4x4 float32 orientation matrix usable as a camera/view
    transform in the OpenGL renderer.

    Returns
    -------
    dict
        Mapping of :class:`ViewType` -> 4x4 numpy.ndarray view matrix.
    """
    def as_view(rows):
        return np.array(rows, dtype=np.float32)

    return {
        ViewType.LEFT: as_view([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]),
        ViewType.RIGHT: as_view([[0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]),
        ViewType.BACK: as_view([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]),
        ViewType.FRONT: as_view([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]),
        ViewType.TOP: as_view([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
        ViewType.BOTTOM: as_view([[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]),
    }
+
+
def get_view_matrix(view_type):
    """Look up the 4x4 view matrix for one :class:`ViewType` member.

    Parameters
    ----------
    view_type : ViewType
        Enum member indicating the requested view.

    Returns
    -------
    numpy.ndarray
        4x4 float32 view matrix.
    """
    matrices = get_view_matrices()
    return matrices[view_type]
diff --git a/whippersnappy/gui/__init__.py b/whippersnappy/gui/__init__.py
new file mode 100644
index 0000000..6c5c9a9
--- /dev/null
+++ b/whippersnappy/gui/__init__.py
@@ -0,0 +1,4 @@
+from .config_app import ConfigWindow
+
+__all__ = ["ConfigWindow"]
+
diff --git a/whippersnappy/config_app.py b/whippersnappy/gui/config_app.py
similarity index 73%
rename from whippersnappy/config_app.py
rename to whippersnappy/gui/config_app.py
index e824a79..0f85ed7 100644
--- a/whippersnappy/config_app.py
+++ b/whippersnappy/gui/config_app.py
@@ -5,10 +5,6 @@
Dependencies:
PyQt6
-
-@Author : Ahmed Faisal Abdelrahman
-@Created : 20.03.2022
-
"""
from PyQt6.QtCore import Qt
@@ -23,20 +19,24 @@
class ConfigWindow(QWidget):
- """
- Encapsulates the Qt widget for the parameter configuration.
+ """Qt configuration window for interactive parameter tuning.
+
+ The configuration window exposes sliders and text boxes to adjust the
+ f-threshold and f-max parameters used by the renderer. The widget is
+ intended to run alongside the OpenGL window and to push updated values
+ to the renderer via polling from the main loop.
Parameters
----------
- parent : QWidget
- This widget's parent, if any (usually none).
- screen_dims : tuple
- Integers specifying screen dims in pixels; used to always position
- the window in the top-right corner, if given.
- initial_fthresh_value : float
- Initial fthreshold value is 2.0.
- initial_fmax_value : float
- Initial fmax value is 4.0.
+ parent : QWidget, optional
+ Parent Qt widget. Defaults to ``None``.
+ screen_dims : tuple or None, optional
+ (width, height) of the available screen; used to position the
+ window in the top-right corner when provided.
+ initial_fthresh_value : float, optional
+ Initial threshold value (default 2.0).
+ initial_fmax_value : float, optional
+ Initial fmax value (default 4.0).
"""
def __init__(
@@ -153,11 +153,11 @@ def __init__(
self.setGeometry(0, 0, self.window_size[0], self.window_size[1])
def fthresh_slider_value_cb(self):
- """
- Callback function for user-modified fthresh slider.
+ """Handle changes from the f-threshold slider.
- This function is triggered when the user modifies the fthresh slider. It
- stores the selected value and updates the corresponding user input box.
+ This slot is connected to the slider's valueChanged signal. It maps the
+ slider tick value into the configured value range and updates the
+ text input box accordingly.
"""
self.current_fthresh_value = self.convert_value_to_range(
self.fthresh_slider.value(),
@@ -167,22 +167,12 @@ def fthresh_slider_value_cb(self):
self.fthresh_value_box.setText(str(self.current_fthresh_value))
def fthresh_value_cb(self, new_value):
- """
- Callback function for user input of fthresh value.
-
- This function is triggered when the user inputs a value for fthresh. It
- stores the selected value and updates the corresponding slider.
+ """Handle text input changes for f-threshold.
Parameters
----------
new_value : float or str
- The new value input by the user. It can be a float or a string that
- can be converted to a float.
-
- Returns
- -------
- None
- This function does not return any value.
+ The new value input by the user. May be a float or numeric string.
"""
# Do not react to invalid values:
try:
@@ -200,12 +190,7 @@ def fthresh_value_cb(self, new_value):
self.fthresh_slider.setValue(int(slider_fthresh_value))
def fmax_slider_value_cb(self):
- """
- Callback function for user-modified fmax slider.
-
- This function is triggered when the user modifies the fmax slider. It
- stores the selected value and updates the corresponding user input box.
- """
+ """Handle changes from the f-max slider and update the text box."""
self.current_fmax_value = self.convert_value_to_range(
self.fmax_slider.value(),
self.fmax_slider_tick_limits,
@@ -214,22 +199,12 @@ def fmax_slider_value_cb(self):
self.fmax_value_box.setText(str(self.current_fmax_value))
def fmax_value_cb(self, new_value):
- """
- Callback function for user input of fmax value.
-
- This function is triggered when the user inputs a value for fmax. It
- stores the selected value and updates the corresponding slider.
+ """Handle text input changes for f-max.
Parameters
----------
new_value : float or str
- The new value input by the user. It can be a float or a string that
- can be converted to a float.
-
- Returns
- -------
- None
- This function does not return any value.
+ New value provided by the user.
"""
# Do not react to invalid values:
try:
@@ -247,25 +222,21 @@ def fmax_value_cb(self, new_value):
self.fmax_slider.setValue(int(slider_fmax_value))
def convert_value_to_range(self, value, old_limits, new_limits):
- """
- Convert a given number from one range to another.
-
- This is useful for transforming values from the original range to that
- of the slider widget tick range and vice-versa.
+ """Map ``value`` from ``old_limits`` to ``new_limits``.
Parameters
----------
value : float
Value to be converted.
old_limits : tuple
- Minimum and maximum values that define the source range.
+ (min, max) source range.
new_limits : tuple
- Minimum and maximum values that define the target range.
+ (min, max) target range.
Returns
-------
- new_value : float
- Converted value.
+ float
+ Value mapped into ``new_limits``.
"""
old_range = old_limits[1] - old_limits[0]
new_range = new_limits[1] - new_limits[0]
@@ -274,40 +245,34 @@ def convert_value_to_range(self, value, old_limits, new_limits):
return new_value
def get_fthresh_value(self):
- """
- Return the current stores value for fthresh.
+ """Return the currently selected f-threshold value.
Returns
-------
- current_fthresh_value: float
- Current fthresh value.
+ float
+ Current f-threshold value.
"""
return self.current_fthresh_value
def get_fmax_value(self):
- """
- Return the current stores value for fmax.
+ """Return the currently selected f-max value.
Returns
-------
- current_fmax_value : float
- Current fmax value.
+ float
+ Current f-max value.
"""
return self.current_fmax_value
def keyPressEvent(self, event):
- """
- Close the window when the ESC key is pressed.
+ """Handle key press events for the window.
+
+ The handler closes the window when the ESC key is pressed.
Parameters
----------
event : QKeyEvent
- The key event object representing the key press.
-
- Returns
- -------
- None
- This function return None.
+ Qt key event delivered by the framework.
"""
if event.key() == Qt.Key.Escape:
self.close()
diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py
new file mode 100644
index 0000000..8fefcce
--- /dev/null
+++ b/whippersnappy/plot3d.py
@@ -0,0 +1,249 @@
+"""3D plotting for WhipperSnapPy using pythreejs (Three.js) for Jupyter.
+
+This module provides interactive 3D brain visualization for Jupyter notebooks
+using Three.js/WebGL. It works in all Jupyter environments (browser,
+JupyterLab, Colab, VS Code).
+
+Unlike the desktop GUI (``whippersnap`` command), this renders entirely in the
+browser via WebGL and is designed for notebook environments.
+
+Usage::
+
+ from whippersnappy import plot3d
+ viewer = plot3d(mesh='path/to/lh.white', bg_map='path/to/lh.curv')
+ display(viewer)
+
+Dependencies:
+ pythreejs, ipywidgets, numpy
+"""
+
+import logging
+
+import numpy as np
+import pythreejs as p3js
+from ipywidgets import HTML, VBox
+
+from .geometry import prepare_geometry
+from .gl import get_webgl_shaders
+from .utils.types import ColorSelection
+
+# Module logger
+logger = logging.getLogger(__name__)
+
+
def plot3d(
    mesh,
    overlay=None,
    annot=None,
    bg_map=None,
    roi=None,
    minval=None,
    maxval=None,
    invert=False,
    scale=1.85,
    color_mode=None,
    width=800,
    height=800,
    ambient=0.1,
):
    """Create an interactive 3D notebook viewer using pythreejs (Three.js).

    This function prepares geometry and color information (via
    :func:`whippersnappy.geometry.prepare_geometry`) and constructs a
    pythreejs renderer and controls wrapped in an ``ipywidgets.VBox`` for
    display inside a Jupyter notebook.

    The mesh can be any triangular surface — not just brain surfaces.
    Supported file formats: FreeSurfer binary surface (e.g. ``lh.white``),
    ASCII OFF (``.off``), legacy ASCII VTK PolyData (``.vtk``), ASCII PLY
    (``.ply``), or a ``(vertices, faces)`` numpy array tuple.

    Parameters
    ----------
    mesh : str or tuple of (array-like, array-like)
        Path to a mesh file (FreeSurfer binary, ``.off``, ``.vtk``, or
        ``.ply``) **or** a ``(vertices, faces)`` tuple.
    overlay : str, array-like, or None, optional
        Path to a per-vertex scalar file, or a (N,) array of per-vertex
        scalar values.
    annot : str, tuple, or None, optional
        Path to a FreeSurfer .annot file, or a ``(labels, ctab)`` /
        ``(labels, ctab, names)`` tuple for categorical labeling.
    bg_map : str, array-like, or None, optional
        Path to a per-vertex scalar file **or** a (N,) array used as
        grayscale background shading for non-overlay regions.
    roi : str, array-like, or None, optional
        Path to a FreeSurfer label file **or** a (N,) boolean array to
        restrict overlay coloring to a subset of vertices.
    minval, maxval : float or None, optional
        Threshold and saturation values used for color mapping.
        If ``None``, sensible defaults are chosen automatically.
    invert : bool, optional
        If True, invert the overlay color map. Default is ``False``.
    scale : float, optional
        Global geometry scale applied during preparation. Default is ``1.85``.
    color_mode : ColorSelection or None, optional
        Which sign of overlay values to color (BOTH/POSITIVE/NEGATIVE).
        If None, defaults to ``ColorSelection.BOTH``.
    width, height : int, optional
        Canvas dimensions for the generated renderer. Default is ``800``.
    ambient : float, optional
        Ambient lighting strength passed to the Three.js shader. Default is ``0.1``.

    Returns
    -------
    ipywidgets.VBox
        A widget containing the pythreejs Renderer and a small info panel.

    Raises
    ------
    ValueError, FileNotFoundError
        Errors from :func:`prepare_geometry` are propagated (for example
        shape mismatches between overlay and mesh vertex count).

    Examples
    --------
    In a Jupyter notebook::

        from whippersnappy import plot3d
        from IPython.display import display

        viewer = plot3d('lh.white', overlay='lh.thickness', bg_map='lh.curv')
        display(viewer)
    """
    # Load and prepare mesh data
    color_mode = color_mode or ColorSelection.BOTH
    meshdata, triangles, fmin, fmax, pos, neg = prepare_geometry(
        mesh, overlay, annot, bg_map, roi,
        minval, maxval, invert, scale, color_mode
    )

    logger.info("Loaded mesh: %d vertices, %d faces", meshdata.shape[0], triangles.shape[0])

    # Extract vertices, normals, and colors from the interleaved array
    vertices = meshdata[:, :3]  # x, y, z
    normals = meshdata[:, 3:6]  # nx, ny, nz
    colors = meshdata[:, 6:9]  # r, g, b

    # Center and scale the mesh so the default camera frames it.
    vertices = vertices - vertices.mean(axis=0)
    max_extent = np.abs(vertices).max()
    if max_extent > 0:  # guard against degenerate (single-point) meshes
        vertices = vertices / max_extent * 2.0

    # Distinct name: previously this reassigned the ``mesh`` parameter,
    # shadowing the caller's input argument.
    three_mesh = create_threejs_mesh_with_custom_shaders(
        vertices, triangles, colors, normals, ambient=ambient
    )

    camera = p3js.PerspectiveCamera(
        position=[-5, 0, 0],
        up=[0, 0, 1],
        aspect=width/height,
        fov=45,
        near=0.1,
        far=1000
    )

    # Create scene without lights (use our own custom shader):
    scene = p3js.Scene(
        children=[three_mesh, camera],  # No lights needed
        background='#000000'
    )

    # Create renderer
    renderer = p3js.Renderer(
        camera=camera,
        scene=scene,
        controls=[p3js.OrbitControls(controlling=camera)],
        width=width,
        height=height,
        antialias=True
    )

    # Info panel rendered as an HTML widget. NOTE(review): the original
    # markup was garbled (string literals split across lines); lines are
    # rejoined here with <br> — confirm against the intended layout.
    info_lines = [
        "Interactive 3D Viewer (Three.js) ✓",
        f"Vertices: {len(vertices):,}",
        f"Triangles: {len(triangles):,}",
        "Drag to rotate, scroll to zoom, right-drag to pan",
    ]
    # Explicit None checks: overlay/bg_map may be numpy arrays, whose
    # truth value is ambiguous (``if overlay`` would raise ValueError).
    if overlay is not None or annot is not None:
        info_lines.append("📊 Overlay/annotation loaded")
    elif bg_map is not None:
        info_lines.append("🧠 Curvature (grayscale is correct)")

    info = HTML(value="<br>".join(info_lines))

    # Combine renderer and info
    return VBox([renderer, info])
+
def create_threejs_mesh_with_custom_shaders(vertices, faces, colors, normals, ambient=0.1):
    """Assemble a pythreejs.Mesh with custom-shader material and buffers.

    Builds a BufferGeometry carrying position, color and normal attributes
    plus a triangle index buffer, then wraps it in a ShaderMaterial using
    the WebGL shader sources from :func:`get_webgl_shaders`.

    Parameters
    ----------
    vertices : numpy.ndarray
        Array of shape (N, 3) with vertex positions (float32).
    faces : numpy.ndarray
        Integer face index array shape (M, 3) or flattened (3*M,) dtype uint32.
    colors : numpy.ndarray
        Array of shape (N, 3) with per-vertex RGB colors (float32).
    normals : numpy.ndarray
        Array of shape (N, 3) with per-vertex normals (float32).
    ambient : float, optional, default 0.1
        Ambient lighting strength (passed as a Three.js uniform).

    Returns
    -------
    pythreejs.Mesh
        Mesh object ready to be inserted into a pythreejs.Scene.
    """
    vertex_src, fragment_src = get_webgl_shaders()

    attributes = {
        'position': p3js.BufferAttribute(array=vertices.astype(np.float32)),
        'color': p3js.BufferAttribute(array=colors.astype(np.float32)),
        'normal': p3js.BufferAttribute(array=normals.astype(np.float32)),
    }
    geometry = p3js.BufferGeometry(attributes=attributes)
    geometry.index = p3js.BufferAttribute(array=faces.astype(np.uint32).flatten())

    uniforms = {
        'lightColor': {'value': [1.0, 1.0, 1.0]},
        'ambientStrength': {'value': ambient},
    }
    material = p3js.ShaderMaterial(
        vertexShader=vertex_src,
        fragmentShader=fragment_src,
        uniforms=uniforms,
    )

    return p3js.Mesh(geometry=geometry, material=material)
diff --git a/whippersnappy/Roboto-Regular.ttf b/whippersnappy/resources/fonts/Roboto-Regular.ttf
similarity index 100%
rename from whippersnappy/Roboto-Regular.ttf
rename to whippersnappy/resources/fonts/Roboto-Regular.ttf
diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py
new file mode 100644
index 0000000..176d98e
--- /dev/null
+++ b/whippersnappy/snap.py
@@ -0,0 +1,712 @@
+"""Snapshot (static rendering) API for WhipperSnapPy."""
+
+import logging
+import os
+
+import glfw
+import numpy as np
+import pyrr
+from PIL import Image, ImageFont
+
+from .geometry import estimate_overlay_thresholds, get_surf_name
+from .geometry.prepare import prepare_and_validate_geometry
+from .gl.utils import capture_window, create_window_with_fallback, render_scene, setup_shader, terminate_context
+from .gl.views import get_view_matrices
+from .utils.image import create_colorbar, draw_caption, draw_colorbar, load_roboto_font, text_size
+from .utils.types import ColorSelection, OrientationType, ViewType
+
+# Module logger
+logger = logging.getLogger(__name__)
+
+
+def snap1(
+ mesh,
+ outpath=None,
+ overlay=None,
+ annot=None,
+ bg_map=None,
+ roi=None,
+ view=ViewType.LEFT,
+ viewmat=None,
+ width=700,
+ height=500,
+ fthresh=None,
+ fmax=None,
+ caption=None,
+ caption_x=None,
+ caption_y=None,
+ caption_scale=1,
+ invert=False,
+ colorbar=True,
+ colorbar_x=None,
+ colorbar_y=None,
+ colorbar_scale=1,
+ orientation=OrientationType.HORIZONTAL,
+ color_mode=ColorSelection.BOTH,
+ font_file=None,
+ specular=True,
+ brain_scale=1.5,
+ ambient=0.0,
+):
+ """Render a single static snapshot of a surface mesh.
+
+ This function opens an OpenGL context, uploads the provided
+ surface geometry and colors (overlay or annotation), renders the scene
+ for a single view, captures the framebuffer, and returns a PIL Image.
+ When ``outpath`` is provided the image is also written to disk.
+
+ The mesh can be any triangular surface — not just brain surfaces.
+ Supported file formats: FreeSurfer binary surface (e.g. ``lh.white``),
+ ASCII OFF (``.off``), legacy ASCII VTK PolyData (``.vtk``), ASCII PLY
+ (``.ply``), or a ``(vertices, faces)`` numpy array tuple.
+
+ Parameters
+ ----------
+ mesh : str or tuple of (array-like, array-like)
+ Path to a mesh file (FreeSurfer binary, ``.off``, ``.vtk``, or
+ ``.ply``) **or** a ``(vertices, faces)`` tuple where *vertices* is
+ (N, 3) float and *faces* is (M, 3) int.
+ outpath : str or None, optional
+ When provided, the resulting image is saved to this path.
+ overlay : str, array-like, or None, optional
+ Overlay file path (``.mgh`` or FreeSurfer morph) **or** a (N,) array
+ of per-vertex scalar values. If ``None``, coloring falls back to
+ background shading / annotation.
+ annot : str, tuple, or None, optional
+ Path to a FreeSurfer .annot file **or** a ``(labels, ctab)`` /
+ ``(labels, ctab, names)`` tuple with per-vertex labels.
+ bg_map : str, array-like, or None, optional
+ Path to a per-vertex scalar file **or** a (N,) array whose sign
+ determines light/dark background shading for non-overlay vertices.
+ roi : str, array-like, or None, optional
+ Path to a FreeSurfer label file **or** a (N,) boolean array.
+ Vertices with ``True`` receive overlay coloring; others fall back
+ to *bg_map* shading.
+ view : ViewType, optional
+ Which pre-defined view to render (left, right, front, ...).
+ Default is ``ViewType.LEFT``.
+ viewmat : 4x4 matrix-like, optional
+ Optional view matrix to override the pre-defined view.
+ width, height : int, optional
+ Output canvas size in pixels. Defaults to (700×500).
+ fthresh, fmax : float or None, optional
+ Threshold and saturation values for overlay coloring.
+ caption, caption_x, caption_y, caption_scale : str/float, optional
+ Caption text and layout parameters.
+ invert : bool, optional
+ Invert the color scale. Default is ``False``.
+ colorbar : bool, optional
+ If True, render a colorbar when an overlay is present. Default is ``True``.
+ colorbar_x, colorbar_y, colorbar_scale : float, optional
+ Colorbar positioning and scale. Scale defaults to 1.
+ orientation : OrientationType, optional
+ Colorbar orientation (HORIZONTAL/VERTICAL). Default is ``OrientationType.HORIZONTAL``.
+ color_mode : ColorSelection, optional
+ Which sign of overlay to color (POSITIVE/NEGATIVE/BOTH). Default is ``ColorSelection.BOTH``.
+ font_file : str or None, optional
+ Path to a TTF font for captions; fallback to bundled font if None.
+ specular : bool, optional
+ Enable specular highlights. Default is ``True``.
+ brain_scale : float, optional
+ Scale factor applied when preparing the geometry. Default is ``1.5``.
+ ambient : float, optional
+ Ambient lighting strength for shader. Default is ``0.0``.
+
+ Returns
+ -------
+ PIL.Image.Image
+ Rendered snapshot as a PIL Image.
+
+ Raises
+ ------
+ RuntimeError
+ If the OpenGL/GLFW context cannot be initialized.
+ ValueError
+ If the overlay contains no values to display for the chosen
+ color_mode.
+ FileNotFoundError
+ If a required file cannot be found.
+
+ Examples
+ --------
+ FreeSurfer surface with overlay::
+
+ >>> from whippersnappy import snap1
+ >>> img = snap1('lh.white', overlay='lh.thickness',
+ ... bg_map='lh.curv', roi='lh.cortex.label')
+ >>> img.save('/tmp/lh.png')
+
+ Array inputs (any triangular mesh)::
+
+ >>> import numpy as np
+ >>> v = np.random.randn(100, 3).astype(np.float32)
+ >>> f = np.array([[0, 1, 2]], dtype=np.uint32)
+ >>> img = snap1((v, f))
+
+ OFF / VTK / PLY file::
+
+ >>> img = snap1('mesh.off', overlay='values.mgh')
+ """
+ ref_width = 700
+ ref_height = 500
+ ui_scale = min(width / ref_width, height / ref_height)
+ try:
+ if glfw.init():
+ primary_monitor = glfw.get_primary_monitor()
+ if primary_monitor:
+ mode = glfw.get_video_mode(primary_monitor)
+ if width > mode.size.width:
+ logger.info("Requested width %d exceeds screen width %d, expect black bars",
+ width, mode.size.width)
+ elif height > mode.size.height:
+ logger.info("Requested height %d exceeds screen height %d, expect black bars",
+ height, mode.size.height)
+ except Exception:
+ pass # headless — no monitor info available, that's fine
+
+ image = Image.new("RGB", (width, height))
+
+ bwidth = int(540 * brain_scale * ui_scale)
+ bheight = int(450 * brain_scale * ui_scale)
+ brain_display_width = min(bwidth, width)
+ brain_display_height = min(bheight, height)
+ logger.debug("Requested (width,height) = (%s,%s)", width, height)
+ logger.debug("Brain (width,height) = (%s,%s)", bwidth, bheight)
+ logger.debug("B-Display (width,height) = (%s,%s)", brain_display_width, brain_display_height)
+
+ # will raise exception if it cannot be created
+ window = create_window_with_fallback(brain_display_width, brain_display_height, "WhipperSnapPy", visible=True)
+ try:
+ meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry(
+ mesh,
+ overlay,
+ annot,
+ bg_map,
+ roi,
+ fthresh,
+ fmax,
+ invert,
+ scale=brain_scale,
+ color_mode=color_mode,
+ )
+
+ shader = setup_shader(meshdata, triangles, brain_display_width, brain_display_height,
+ specular=specular, ambient=ambient)
+
+ transl = pyrr.Matrix44.from_translation((0, 0, 0.4))
+ view_mats = get_view_matrices()
+ viewmat = transl * (view_mats[view] if viewmat is None else viewmat)
+ render_scene(shader, triangles, viewmat)
+
+ # Center the brain rendering in the output image, clamp to zero
+ brain_x = max(0, (width - brain_display_width) // 2)
+ brain_y = max(0, (height - brain_display_height) // 2)
+ image.paste(capture_window(window), (brain_x, brain_y))
+
+ bar = (
+ create_colorbar(
+ fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, font_file=font_file
+ )
+ if overlay is not None and colorbar
+ else None
+ )
+ font = (
+ load_roboto_font(int(20 * caption_scale * ui_scale))
+ if font_file is None
+ else ImageFont.truetype(font_file, int(20 * caption_scale * ui_scale))
+ if caption
+ else None
+ )
+
+ # Compute positions to avoid overlap, unless explicit positions are given
+ text_w, text_h = text_size(caption, font) if caption and font else (0, 0)
+ bar_h = bar.height if bar is not None else 0
+ gap = int(4 * ui_scale)
+ bottom_pad = int(20 * ui_scale)
+
+ if orientation == OrientationType.HORIZONTAL:
+ # If explicit positions are given, use them
+ if colorbar_x is not None or colorbar_y is not None or caption_x is not None or caption_y is not None:
+ bx = int(colorbar_x * width) if colorbar_x is not None else None
+ by = int(colorbar_y * height) if colorbar_y is not None else None
+ cx = int(caption_x * width) if caption_x is not None else None
+ cy = int(caption_y * height) if caption_y is not None else None
+ draw_colorbar(image, bar, orientation, x=bx, y=by)
+ draw_caption(image, caption, font, orientation, x=cx, y=cy)
+ else:
+ # Place colorbar above caption if both present
+ if bar is not None and caption:
+ bar_y = image.height - bottom_pad - text_h - gap - bar_h
+ caption_y = image.height - bottom_pad - text_h
+ elif bar is not None:
+ bar_y = image.height - bottom_pad - bar_h
+ caption_y = None
+ elif caption:
+ bar_y = None
+ caption_y = image.height - bottom_pad - text_h
+ else:
+ bar_y = caption_y = None
+ draw_colorbar(image, bar, orientation, y=bar_y)
+ draw_caption(image, caption, font, orientation, y=caption_y)
+ else:
+ # For vertical, allow explicit x/y for both, else use default
+ bx = int(colorbar_x * width) if colorbar_x is not None else None
+ by = int(colorbar_y * height) if colorbar_y is not None else None
+ cx = int(caption_x * width) if caption_x is not None else None
+ cy = int(caption_y * height) if caption_y is not None else None
+ draw_colorbar(image, bar, orientation, x=bx, y=by)
+ draw_caption(image, caption, font, orientation, x=cx, y=cy)
+
+ if outpath:
+ logger.info("Saving snapshot to %s", outpath)
+ image.save(outpath)
+ return image
+ finally:
+ terminate_context(window)
+
+
+def snap4(
+ lh_overlay=None,
+ rh_overlay=None,
+ lh_annot=None,
+ rh_annot=None,
+ fthresh=None,
+ fmax=None,
+ sdir=None,
+ caption=None,
+ invert=False,
+ roi_name="cortex.label",
+ surfname=None,
+ bg_map_name="curv",
+ colorbar=True,
+ outpath=None,
+ font_file=None,
+ specular=True,
+ ambient=0.0,
+ brain_scale=1.85,
+ color_mode=ColorSelection.BOTH,
+):
+ """Render four snapshot views (left/right hemispheres, lateral/medial).
+
+ This convenience function renders four views (lateral/medial for each
+ hemisphere), stitches them together into a single PIL Image and returns
+ it (and saves it to ``outpath`` when provided). It is typically used to
+ produce publication-ready overview figures composed from both
+ hemispheres.
+
+ Parameters
+ ----------
+ lh_overlay, rh_overlay : str, array-like, or None
+ Left/right hemisphere overlay — either a file path (FreeSurfer morph
+ or .mgh) or a per-vertex scalar array. Typically provided as a pair
+ for a coherent two-hemisphere color scale.
+ lh_annot, rh_annot : str, tuple, or None
+ Left/right hemisphere annotation — either a path to a .annot file or
+ a ``(labels, ctab)`` / ``(labels, ctab, names)`` tuple.
+ Cannot be combined with ``lh_overlay``/``rh_overlay``.
+ fthresh, fmax : float or None
+ Threshold and saturation for overlay coloring. Auto-estimated when
+ ``None``.
+ sdir : str or None
+ Subject directory containing ``surf/`` and ``label/`` subdirectories.
+ Falls back to ``$SUBJECTS_DIR`` when ``None``.
+ caption : str or None
+ Caption string to place on the final image.
+ invert : bool, optional
+ Invert color scale. Default is ``False``.
+ roi_name : str, optional
+ Basename of the label file used to restrict overlay coloring (default
+ ``'cortex.label'``). The full path is constructed as
+        ``<sdir>/label/<hemi>.<roi_name>``.
+ surfname : str or None, optional
+ Surface basename to load (e.g. ``'white'``); auto-detected when
+ ``None``.
+ bg_map_name : str, optional
+ Basename of the curvature/morph file used for background shading
+ (default ``'curv'``). The full path is constructed as
+        ``<sdir>/surf/<hemi>.<bg_map_name>``.
+ colorbar : bool, optional
+ Whether to draw a colorbar on the composed image. Default is ``True``.
+ outpath : str or None, optional
+ If provided, save composed image to this path.
+ font_file : str or None, optional
+ Path to a font to use for captions.
+ specular : bool, optional
+ Enable/disable specular highlights in the renderer. Default is ``True``.
+ ambient : float, optional
+ Ambient lighting strength. Default is ``0``.
+ brain_scale : float, optional
+ Scaling factor passed to geometry preparation. Default is ``1.85``.
+ color_mode : ColorSelection, optional
+ Which sign of overlay to color (POSITIVE/NEGATIVE/BOTH). Default is ``ColorSelection.BOTH``.
+
+ Returns
+ -------
+ PIL.Image.Image
+ Composed image of the four views.
+
+ Raises
+ ------
+ ValueError
+ For invalid argument combinations or when required overlay values
+ are absent.
+ FileNotFoundError
+ When required surface files are not found.
+
+ Examples
+ --------
+ >>> from whippersnappy import snap4
+ >>> img = snap4(
+ ... lh_overlay='fsaverage/surf/lh.thickness',
+ ... rh_overlay='fsaverage/surf/rh.thickness',
+ ... sdir='./fsaverage'
+ ... )
+ >>> img.save('/tmp/whippersnappy_overview.png')
+ """
+ wwidth = 540
+ wheight = 450
+
+ # Resolve sdir early so path-building works for both the pre-pass and
+ # the rendering loop.
+ if sdir is None:
+ sdir = os.environ.get("SUBJECTS_DIR")
+ if not sdir and surfname is None:
+ logger.error("No sdir or SUBJECTS_DIR provided")
+ raise ValueError("No sdir or SUBJECTS_DIR provided")
+ if not sdir and surfname is not None:
+ logger.error("surfname provided but sdir is None")
+ raise ValueError("surfname provided but sdir is None; cannot construct mesh path.")
+
+ # Pre-pass: estimate missing fthresh/fmax from overlays for global color scale
+ has_overlay = lh_overlay is not None or rh_overlay is not None
+ if has_overlay and (fthresh is None or fmax is None):
+ est_fthreshs = []
+ est_fmaxs = []
+ for _overlay in filter(None, (lh_overlay, rh_overlay)):
+ h_fthresh, h_fmax = estimate_overlay_thresholds(_overlay, fthresh, fmax)
+ est_fthreshs.append(h_fthresh)
+ est_fmaxs.append(h_fmax)
+ if fthresh is None and est_fthreshs:
+ fthresh = min(est_fthreshs)
+ if fmax is None and est_fmaxs:
+ fmax = max(est_fmaxs)
+ logger.debug("Global color range: fthresh=%s fmax=%s", fthresh, fmax)
+
+ # will raise exception if it cannot be created
+ window = create_window_with_fallback(wwidth, wheight, "WhipperSnapPy", visible=True)
+ try:
+ # Use standard view matrices from get_view_matrices and ViewType
+ view_mats = get_view_matrices()
+ view_left = view_mats[ViewType.LEFT]
+ view_right = view_mats[ViewType.RIGHT]
+ transl = pyrr.Matrix44.from_translation((0, 0, 0.4))
+
+ # Predefine hemisphere images so static analysis knows they exist even if
+ # an earlier step raises an exception (we still will fail at runtime).
+ lhimg = None
+ rhimg = None
+
+ for hemi in ("lh", "rh"):
+ if surfname is None:
+ found_surfname = get_surf_name(sdir, hemi)
+ if found_surfname is None:
+ logger.error("Could not find valid surface in %s for hemi: %s!", sdir, hemi)
+ raise FileNotFoundError(f"Could not find valid surface in {sdir} for hemi: {hemi}")
+ mesh = os.path.join(sdir, "surf", hemi + "." + found_surfname)
+ else:
+ mesh = os.path.join(sdir, "surf", hemi + "." + surfname)
+
+ # Assign derived paths for bg_map and roi
+ bg_map = os.path.join(sdir, "surf", hemi + "." + bg_map_name) if bg_map_name else None
+ roi = os.path.join(sdir, "label", hemi + "." + roi_name) if roi_name else None
+ overlay = lh_overlay if hemi == "lh" else rh_overlay
+ annot = lh_annot if hemi == "lh" else rh_annot
+
+ # If overlay is an array, it doesn't have a path to log; handle gracefully
+ if isinstance(overlay, str):
+ logger.debug("overlay=%s exists=%s", overlay, os.path.exists(overlay))
+ elif overlay is not None:
+                logger.debug("overlay shape=%s", getattr(overlay, 'shape', None))
+
+ # Diagnostic: report mesh and overlay paths and whether they exist
+ logger.debug("hemisphere=%s", hemi)
+ if isinstance(mesh, str):
+ logger.debug("mesh=%s exists=%s", mesh, os.path.exists(mesh))
+ if isinstance(annot, str) and annot is not None:
+ logger.debug("annot=%s exists=%s", annot, os.path.exists(annot))
+ if bg_map is not None:
+ logger.debug("bg_map=%s exists=%s", bg_map, os.path.exists(bg_map))
+
+ try:
+ meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry(
+ mesh, overlay, annot, bg_map, roi, fthresh, fmax, invert,
+ scale=brain_scale, color_mode=color_mode
+ )
+ except Exception as e:
+ logger.error("prepare_geometry failed for %s: %s", mesh, e)
+ raise
+
+ # Diagnostics about mesh data
+ try:
+ logger.debug("meshdata shape: %s; triangles count: %s", getattr(meshdata, 'shape', None),
+ getattr(triangles, 'size', None))
+ except Exception:
+ pass
+
+ try:
+ shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular, ambient=ambient)
+ logger.debug("Shader setup complete")
+ except Exception as e:
+ logger.error("setup_shader failed: %s", e)
+ raise
+
+ render_scene(shader, triangles, transl * view_left)
+ im1 = capture_window(window)
+ render_scene(shader, triangles, transl * view_right)
+ im2 = capture_window(window)
+
+ if hemi == "lh":
+ lhimg = Image.new("RGB", (im1.width, im1.height + im2.height))
+ lhimg.paste(im1, (0, 0))
+ lhimg.paste(im2, (0, im1.height))
+ else:
+ rhimg = Image.new("RGB", (im1.width, im1.height + im2.height))
+ # For right hemisphere, reverse the order: top=im2, bottom=im1
+ rhimg.paste(im2, (0, 0))
+ rhimg.paste(im1, (0, im2.height))
+
+ # Add small padding around each hemisphere to avoid cropping at edges
+ pad = max(4, int(0.03 * wwidth))
+ padded_lh = Image.new("RGB", (lhimg.width + 2 * pad, lhimg.height + 2 * pad), (0, 0, 0))
+ padded_lh.paste(lhimg, (pad, pad))
+ padded_rh = Image.new("RGB", (rhimg.width + 2 * pad, rhimg.height + 2 * pad), (0, 0, 0))
+ padded_rh.paste(rhimg, (pad, pad))
+
+ image = Image.new("RGB", (padded_lh.width + padded_rh.width, padded_lh.height))
+ image.paste(padded_lh, (0, 0))
+ image.paste(padded_rh, (padded_lh.width, 0))
+
+        font = (load_roboto_font(20) if font_file is None else ImageFont.truetype(font_file, 20)) if caption else None
+ # Place caption at bottom, colorbar above if both present
+ text_w, text_h = text_size(caption, font) if caption and font else (0, 0)
+ bottom_pad = 20
+ gap = 4
+ caption_y = image.height - bottom_pad - text_h
+ bar = (
+ create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg)
+ if lh_annot is None and rh_annot is None and colorbar
+ else None
+ )
+ bar_h = bar.height if bar is not None else 0
+ if bar is not None and caption:
+ bar_y = image.height - bottom_pad - text_h - gap - bar_h
+ draw_colorbar(image, bar, OrientationType.HORIZONTAL, y=bar_y)
+ draw_caption(image, caption, font, OrientationType.HORIZONTAL, y=caption_y)
+ elif bar is not None:
+ bar_y = image.height - bottom_pad - bar_h
+ draw_colorbar(image, bar, OrientationType.HORIZONTAL, y=bar_y)
+ elif caption:
+ draw_caption(image, caption, font, OrientationType.HORIZONTAL, y=caption_y)
+
+ # If outpath is specified, save to disk
+ if outpath:
+ logger.info("Saving snapshot to %s", outpath)
+ image.save(outpath)
+
+ return image
+ finally:
+ terminate_context(window)
+
+
+def snap_rotate(
+ mesh,
+ outpath,
+ n_frames=72,
+ fps=24,
+ width=700,
+ height=500,
+ overlay=None,
+ bg_map=None,
+ annot=None,
+ roi=None,
+ fthresh=None,
+ fmax=None,
+ invert=False,
+ specular=True,
+ ambient=0.0,
+ brain_scale=1.5,
+ start_view=ViewType.LEFT,
+ color_mode=ColorSelection.BOTH,
+):
+ """Render a rotating 360° video of a surface mesh.
+
+ Rotates the view around the vertical (Y) axis in ``n_frames`` equal
+ steps, captures each frame via OpenGL, and encodes the result into a
+ video file. An animated GIF can be produced by passing an ``outpath``
+ ending in ``.gif``; in that case ``imageio-ffmpeg`` is not required.
+
+ The mesh can be any triangular surface — not just brain surfaces.
+ Supported file formats: FreeSurfer binary surface, ASCII OFF (``.off``),
+ legacy ASCII VTK PolyData (``.vtk``), ASCII PLY (``.ply``), or a
+ ``(vertices, faces)`` numpy array tuple.
+
+ Parameters
+ ----------
+ mesh : str or tuple of (array-like, array-like)
+ Path to a mesh file (FreeSurfer binary, ``.off``, ``.vtk``, or
+ ``.ply``) **or** a ``(vertices, faces)`` tuple.
+ outpath : str
+ Destination file path. The extension controls the output format:
+
+ * ``.mp4`` — H.264 MP4 (recommended, requires ``imageio-ffmpeg``).
+ * ``.webm`` — VP9 WebM (requires ``imageio-ffmpeg``).
+ * ``.gif`` — animated GIF (no ffmpeg required, but larger file).
+
+ n_frames : int, optional
+ Number of frames for a full 360° rotation. Default is ``72``
+ (one frame every 5°).
+ fps : int, optional
+ Output frame rate in frames per second. Default is ``24``.
+ width, height : int, optional
+ Render resolution in pixels. Defaults are ``700`` and ``500``.
+ overlay : str, array-like, or None, optional
+ Per-vertex overlay file path or array (e.g. thickness).
+ bg_map : str, array-like, or None, optional
+ Curvature/morph file path or array for background shading.
+ annot : str, tuple, or None, optional
+ FreeSurfer ``.annot`` file path or ``(labels, ctab)`` tuple.
+ roi : str, array-like, or None, optional
+ Label file path or boolean array to restrict overlay coloring.
+ fthresh : float or None, optional
+ Overlay threshold value.
+ fmax : float or None, optional
+ Overlay saturation value.
+ invert : bool, optional
+ Invert the overlay color scale. Default is ``False``.
+ specular : bool, optional
+ Enable specular highlights. Default is ``True``.
+ ambient : float, optional
+ Ambient lighting strength. Default is ``0.0``.
+ brain_scale : float, optional
+ Geometry scale factor. Default is ``1.5``.
+ start_view : ViewType, optional
+ Pre-defined view to start the rotation from.
+ Default is ``ViewType.LEFT``.
+ color_mode : ColorSelection, optional
+ Which overlay sign to color (POSITIVE/NEGATIVE/BOTH).
+ Default is ``ColorSelection.BOTH``.
+
+ Returns
+ -------
+ str
+ The resolved ``outpath`` that was written.
+
+ Raises
+ ------
+ ImportError
+ If ``imageio`` or ``imageio-ffmpeg`` is not installed and a
+ video format (``.mp4``, ``.webm``) was requested.
+ RuntimeError
+ If the OpenGL context cannot be initialised.
+ ValueError
+ If the overlay contains no values for the chosen color mode.
+
+ Examples
+ --------
+ >>> from whippersnappy import snap_rotate
+ >>> snap_rotate(
+ ... 'fsaverage/surf/lh.white',
+ ... '/tmp/rotation.mp4',
+ ... overlay='fsaverage/surf/lh.thickness',
+ ... )
+ '/tmp/rotation.mp4'
+ """
+ ext = os.path.splitext(outpath)[1].lower()
+ use_gif = ext == ".gif"
+
+ if not use_gif:
+ try:
+ import imageio # noqa: F401
+ import imageio_ffmpeg # noqa: F401
+ except ImportError as exc:
+ raise ImportError(
+ f"Video output requires the 'imageio' and 'imageio-ffmpeg' packages. "
+ f"Install with: pip install 'whippersnappy[video]'\n"
+ f"Original error: {exc}"
+ ) from exc
+ import imageio
+ else:
+ try:
+ import imageio # noqa: F401
+ except ImportError as exc:
+ raise ImportError(
+ "GIF output requires the 'imageio' package. "
+ "Install with: pip install 'whippersnappy[video]'"
+ ) from exc
+ import imageio
+
+ window = create_window_with_fallback(width, height, "WhipperSnapPy", visible=True)
+ try:
+ meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry(
+ mesh,
+ overlay,
+ annot,
+ bg_map,
+ roi,
+ fthresh,
+ fmax,
+ invert,
+ scale=brain_scale,
+ color_mode=color_mode,
+ )
+ logger.info(
+ "Rendering %d frames at %dx%d (%.0f° per step) → %s",
+ n_frames, width, height, 360.0 / n_frames, outpath,
+ )
+
+ shader = setup_shader(meshdata, triangles, width, height,
+ specular=specular, ambient=ambient)
+
+ transl = pyrr.Matrix44.from_translation((0, 0, 0.4))
+ base_view = get_view_matrices()[start_view]
+
+ frames = []
+ for i in range(n_frames):
+ angle = 2 * np.pi * i / n_frames
+ rot = pyrr.Matrix44.from_y_rotation(angle)
+ viewmat = transl * rot * base_view
+ render_scene(shader, triangles, viewmat)
+ frames.append(np.array(capture_window(window)))
+ if (i + 1) % max(1, n_frames // 10) == 0:
+ logger.debug(" frame %d / %d", i + 1, n_frames)
+
+ finally:
+ terminate_context(window)
+
+ logger.info("Encoding %d frames to %s …", len(frames), outpath)
+ if use_gif:
+ # Pure-PIL GIF — no ffmpeg required
+ pil_frames = [Image.fromarray(f) for f in frames]
+ pil_frames[0].save(
+ outpath,
+ save_all=True,
+ append_images=pil_frames[1:],
+ loop=0,
+ duration=int(1000 / fps),
+ optimize=True,
+ )
+ else:
+ writer_kwargs = {
+ "fps": fps,
+ "codec": "libx264",
+ "quality": 6,
+ "pixelformat": "yuv420p",
+ }
+ if ext == ".webm":
+ writer_kwargs["codec"] = "libvpx-vp9"
+ writer_kwargs.pop("pixelformat", None)
+ imageio.mimwrite(outpath, frames, **writer_kwargs)
+
+ logger.info("Saved rotation video to %s", outpath)
+ return outpath
+
diff --git a/whippersnappy/types.py b/whippersnappy/types.py
deleted file mode 100644
index b707a20..0000000
--- a/whippersnappy/types.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""Contains the types used in WhipperSnapPy.
-
-Dependencies:
- enum
-
-@Author : Abdulla Ahmadkhan
-@Created : 02.10.2025
-@Revised : 02.10.2025
-
-"""
-import enum
-
-
-class ColorSelection(enum.Enum):
- BOTH = 1
- POSITIVE = 2
- NEGATIVE = 3
-
-class OrientationType(enum.Enum):
- HORIZONTAL = 1
- VERTICAL = 2
-
-class ViewType(enum.Enum):
- LEFT = 1
- RIGHT = 2
- BACK = 3
- FRONT = 4
- TOP = 5
- BOTTOM = 6
\ No newline at end of file
diff --git a/whippersnappy/utils/__init__.py b/whippersnappy/utils/__init__.py
index 285e1d4..26ac282 100644
--- a/whippersnappy/utils/__init__.py
+++ b/whippersnappy/utils/__init__.py
@@ -1 +1,4 @@
-"""Utilities module."""
+"""Utils subpackage exports."""
+from . import colormap, datasets, image, types
+
+__all__ = ["colormap", "datasets", "image", "types"]
diff --git a/whippersnappy/utils/_config.py b/whippersnappy/utils/_config.py
deleted file mode 100644
index b1d393c..0000000
--- a/whippersnappy/utils/_config.py
+++ /dev/null
@@ -1,112 +0,0 @@
-import platform
-import re
-import sys
-from functools import partial
-from importlib.metadata import requires, version
-from typing import IO, Callable, Optional
-
-import psutil
-
-
-def sys_info(fid: Optional[IO] = None, developer: bool = False):
- """Print the system information for debugging.
-
- Parameters
- ----------
- fid : file-like, default=None
- The file to write to, passed to :func:`print`.
- Can be None to use :data:`sys.stdout`.
- developer : bool, default=False
- If True, display information about optional dependencies.
- """
-
- ljust = 26
- out = partial(print, end="", file=fid)
- package = __package__.split(".")[0]
-
- # OS information - requires python 3.8 or above
- out("Platform:".ljust(ljust) + platform.platform() + "\n")
- # Python information
- out("Python:".ljust(ljust) + sys.version.replace("\n", " ") + "\n")
- out("Executable:".ljust(ljust) + sys.executable + "\n")
- # CPU information
- out("CPU:".ljust(ljust) + platform.processor() + "\n")
- out("Physical cores:".ljust(ljust) + str(psutil.cpu_count(False)) + "\n")
- out("Logical cores:".ljust(ljust) + str(psutil.cpu_count(True)) + "\n")
- # Memory information
- out("RAM:".ljust(ljust))
- out(f"{psutil.virtual_memory().total / float(2 ** 30):0.1f} GB\n")
- out("SWAP:".ljust(ljust))
- out(f"{psutil.swap_memory().total / float(2 ** 30):0.1f} GB\n")
-
- # dependencies
- out("\nDependencies info\n")
- out(f"{package}:".ljust(ljust) + version(package) + "\n")
- dependencies = [
- elt.split(";")[0].rstrip() for elt in requires(package) if "extra" not in elt
- ]
- _list_dependencies_info(out, ljust, dependencies)
-
- # extras
- if developer:
- keys = (
- "build",
- "doc",
- "test",
- "style",
- )
- for key in keys:
- dependencies = [
- elt.split(";")[0].rstrip()
- for elt in requires(package)
- if f"extra == '{key}'" in elt or f'extra == "{key}"' in elt
- ]
- if len(dependencies) == 0:
- continue
- out(f"\nOptional '{key}' info\n")
- _list_dependencies_info(out, ljust, dependencies)
-
-
-def _list_dependencies_info(out: Callable, ljust: int, dependencies: list[str]):
- """List dependencies names and versions.
-
- Parameters
- ----------
- out : Callable
- output function
- ljust : int
- length of returned string
- dependencies : List[str]
- list of dependencies
-
- """
-
- for dep in dependencies:
- # handle dependencies with version specifiers
- specifiers_pattern = r"(~=|==|!=|<=|>=|<|>|===)"
- specifiers = re.findall(specifiers_pattern, dep)
- if len(specifiers) != 0:
- dep, _ = dep.split(specifiers[0])
- while not dep[-1].isalpha():
- dep = dep[:-1]
- # handle dependencies provided with a [key], e.g. pydocstyle[toml]
- if "[" in dep:
- dep = dep.split("[")[0]
- try:
- version_ = version(dep)
- except Exception:
- version_ = "Not found."
-
- # handle special dependencies with backends, C dep, ..
- if dep in ("matplotlib", "seaborn") and version_ != "Not found.":
- try:
- from matplotlib import pyplot as plt
-
- backend = plt.get_backend()
- except Exception:
- backend = "Not found"
-
- out(f"{dep}:".ljust(ljust) + version_ + f" (backend: {backend})\n")
-
- else:
- out(f"{dep}:".ljust(ljust) + version_ + "\n")
diff --git a/whippersnappy/utils/colormap.py b/whippersnappy/utils/colormap.py
new file mode 100644
index 0000000..460f81b
--- /dev/null
+++ b/whippersnappy/utils/colormap.py
@@ -0,0 +1,183 @@
+"""Colormap and value preprocessing utilities."""
+
+import logging
+
+import numpy as np
+
+from .types import ColorSelection
+
+# Module logger
+logger = logging.getLogger(__name__)
+
+
+def heat_color(values, invert=False):
+ """Convert an array of float values into RGB heat color values.
+
+ Maps scalar values to RGB triplets suitable for visualization. Input
+ values are expected to be in a symmetric range around zero; mapping
+ produces blue-to-red heat colors. NaN inputs propagate to NaN outputs.
+
+ Parameters
+ ----------
+ values : array_like
+ 1-D array of float values to map. May include NaNs.
+ invert : bool, optional
+ If True, invert the sign of the input values before mapping.
+ Default is False.
+
+ Returns
+ -------
+ numpy.ndarray
+ Array of shape (N, 3) and dtype float32 with RGB channels in [0, 1].
+ """
+ if invert:
+ values = -1.0 * values
+ vabs = np.abs(values)
+ colors = np.zeros((vabs.size, 3), dtype=np.float32)
+ crb = 0.5625 + 3 * 0.4375 * vabs
+ cg = 1.5 * (vabs - (1.0 / 3.0))
+ n1 = values < -1.0
+ nm = (values >= -1.0) & (values < -(1.0 / 3.0))
+ n0 = (values >= -(1.0 / 3.0)) & (values < 0)
+ p0 = (values >= 0) & (values < (1.0 / 3.0))
+ pm = (values >= (1.0 / 3.0)) & (values < 1.0)
+ p1 = values >= 1.0
+ colors[n1, 1:3] = 1.0
+ colors[nm, 1] = cg[nm]
+ colors[nm, 2] = 1.0
+ colors[n0, 2] = crb[n0]
+ colors[p0, 0] = crb[p0]
+ colors[pm, 1] = cg[pm]
+ colors[pm, 0] = 1.0
+ colors[p1, 0:2] = 1.0
+ colors[np.isnan(values), :] = np.nan
+ return colors
+
+
+def mask_sign(values, color_mode):
+ """Mask values that don't match the requested sign selection.
+
+ Parameters
+ ----------
+ values : array_like
+ Input numeric array.
+ color_mode : ColorSelection
+ Enum indicating which sign to preserve (POSITIVE, NEGATIVE, BOTH).
+
+ Returns
+ -------
+ numpy.ndarray
+ Copy of ``values`` where elements not matching the requested sign
+ are set to ``np.nan``.
+ """
+ masked_values = np.copy(values)
+ if color_mode == ColorSelection.POSITIVE:
+ masked_values[masked_values < 0] = np.nan
+ elif color_mode == ColorSelection.NEGATIVE:
+ masked_values[masked_values > 0] = np.nan
+ return masked_values
+
+
+def rescale_overlay(values, minval, maxval):
+ """Rescale overlay values into a normalized range for colormap computation.
+
+ Values whose absolute magnitude is below ``minval`` are set to ``NaN``.
+ Remaining values are shifted by ``minval`` and divided by ``(maxval - minval)``.
+
+ Parameters
+ ----------
+ values : numpy.ndarray
+ Numeric array of overlay values (1-D).
+ minval : float
+ Minimum absolute threshold — values with abs < minval are treated as absent.
+ maxval : float
+ Maximum absolute value used for normalization.
+
+ Returns
+ -------
+ tuple
+ ``(values, minval, maxval, pos, neg)`` where ``values`` is the rescaled
+ array, and ``pos``/``neg`` are booleans indicating presence of positive
+ / negative values after rescaling.
+
+ Raises
+ ------
+ ValueError
+ If ``minval`` or ``maxval`` is negative.
+ """
+ valsign = np.sign(values)
+ valabs = np.abs(values)
+
+ if maxval < 0 or minval < 0:
+ logger.error("rescale_overlay ERROR: min and maxval should both be positive!")
+ raise ValueError("minval and maxval must be non-negative")
+
+ values[valabs < minval] = np.nan
+ range_val = maxval - minval
+ if range_val == 0:
+ values = np.zeros_like(values)
+ else:
+ values = values - valsign * minval
+ values = values / range_val
+
+ pos = np.any(values[~np.isnan(values)] > 0)
+ neg = np.any(values[~np.isnan(values)] < 0)
+
+ return values, minval, maxval, pos, neg
+
+
+def binary_color(values, thres, color_low, color_high):
+ """Create a binary colormap for values based on a threshold.
+
+ Parameters
+ ----------
+ values : array_like
+ 1-D array of values to map.
+ thres : float
+ Threshold value used to split the colors.
+ color_low, color_high : scalar or sequence
+ Colors assigned to values below/above the threshold. Scalars are
+ expanded to RGB triplets.
+
+ Returns
+ -------
+ numpy.ndarray
+ Array of shape (N, 3) and dtype float32 containing RGB colors.
+ """
+ if np.isscalar(color_low):
+ color_low = np.array((color_low, color_low, color_low), dtype=np.float32)
+ if np.isscalar(color_high):
+ color_high = np.array((color_high, color_high, color_high), dtype=np.float32)
+ colors = np.empty((values.size, 3), dtype=np.float32)
+ colors[values < thres, :] = color_low
+ colors[values >= thres, :] = color_high
+ return colors
+
+
+def mask_label(values, labelpath=None):
+ """Apply a label file as a mask to an array of per-vertex values.
+
+ If ``labelpath`` is provided the function loads vertex indices from the
+ label file and sets all entries not listed in the label to ``NaN``.
+
+ Parameters
+ ----------
+ values : numpy.ndarray
+ 1-D array indexed by vertex id.
+ labelpath : str or None, optional
+ Path to a label file readable by ``numpy.loadtxt`` (expected format
+ with vertex ids in the first column after two header lines).
+
+ Returns
+ -------
+ numpy.ndarray
+ Array with vertices not included in the label set to ``np.nan``.
+ """
+ if not labelpath:
+ return values
+ maskvids = np.loadtxt(labelpath, dtype=int, skiprows=2, usecols=[0])
+ imask = np.ones(values.shape, dtype=bool)
+ imask[maskvids] = False
+ values[imask] = np.nan
+ return values
+
diff --git a/whippersnappy/utils/datasets.py b/whippersnappy/utils/datasets.py
new file mode 100644
index 0000000..7f2293f
--- /dev/null
+++ b/whippersnappy/utils/datasets.py
@@ -0,0 +1,124 @@
+"""Sample dataset download utility for WhipperSnapPy.
+
+Downloads and caches a small anonymized FreeSurfer subject from the
+WhipperSnapPy GitHub release assets for use in tutorials and tests.
+"""
+
+from pathlib import Path
+
# Download URL template; {file_name} is replaced with the asset's basename.
RELEASE_URL = (
    "https://github.com/Deep-MI/WhipperSnapPy"
    "/releases/download/data-v1.0/{file_name}"
)

# Mapping of relative path inside the subject directory → SHA-256 hash.
# GitHub release assets are flat (no subdirectories), so the URL uses only
# the basename while pooch.retrieve() reconstructs the subdirectory locally.
# NOTE(review): this scheme requires basenames to be unique across
# subdirectories, which holds for the lh./rh. prefixed entries below.
_FILES = {
    "README.md": "sha256:ecb6ddf31cec17f3a8636fc3ecac90099c441228811efed56104e29fcd301bc5",
    "surf/lh.white": "sha256:4ab049fb42ca882ba9b56f8fe0d0e8814973e7fa2e0575a794d8e468abf7d62f",
    "surf/lh.curv": "sha256:9edbde57be8593cd9d89d9d1124e2175edd8ecfee55d53e066d89700c480b12a",
    "surf/lh.thickness": "sha256:40ab3483284608c6c5cca2d3d794a60cd1bcbeb0140bb1ca6ad0fce7962c57c6",
    "surf/rh.white": "sha256:43035c53a8b04bebe4e843c34f80588f253f79052a8dbf7194b706495b11f8d2",
    "surf/rh.curv": "sha256:af2bc71133d7ef17ce1a3a6f4208d2495a5a4c96da00c80b59be03bb7c8ea83f",
    "surf/rh.thickness": "sha256:50ec291c73928cd697156edd9e0e77f5c54d15c56cf84810d2564b496876e132",
    "label/lh.aparc.DKTatlas.mapped.annot": "sha256:4d48d33f4fd8278ab973a1552f6ea9c396dfc1791b707ed17ad8e761299c4960",
    "label/lh.cortex.label": "sha256:79ae17fcfde6b2e0a75a0652fcc0f3c072e4ea62a541843b7338e01c598b0b6e",
    "label/rh.aparc.DKTatlas.mapped.annot": "sha256:12217166d8ef43ee1fa280511ec2ba0796c6885f527a4455b93760acc73ce273",
    "label/rh.cortex.label": "sha256:162c97c887eb1ec857fe575b8cc4e4b950c7dd5ec181a581d709bbe7fca58f9e",
}
+
+
+def _build_dict(base: Path) -> dict:
+ """Build the return dictionary of paths from a subject base directory."""
+ return {
+ "sdir": str(base),
+ "lh_white": str(base / "surf/lh.white"),
+ "lh_curv": str(base / "surf/lh.curv"),
+ "lh_thickness": str(base / "surf/lh.thickness"),
+ "rh_white": str(base / "surf/rh.white"),
+ "rh_curv": str(base / "surf/rh.curv"),
+ "rh_thickness": str(base / "surf/rh.thickness"),
+ "lh_annot": str(base / "label/lh.aparc.DKTatlas.mapped.annot"),
+ "lh_label": str(base / "label/lh.cortex.label"),
+ "rh_annot": str(base / "label/rh.aparc.DKTatlas.mapped.annot"),
+ "rh_label": str(base / "label/rh.cortex.label"),
+ }
+
+
def fetch_sample_subject() -> dict:
    """Download and cache the WhipperSnapPy sample subject (Rhineland Study).

    Fetches FreeSurfer surface files for one anonymized subject into the
    OS-specific user cache directory and returns a dictionary mapping
    descriptive keys to the local file paths. Each file is downloaded at
    most once; later calls are served from the local cache.

    When a ``sub-rs/`` directory containing every required file exists at
    the repository root (i.e. in a source checkout), it is used directly
    and no network access happens at all. This lets the Sphinx doc build
    run before the GitHub release assets are published.

    Returns
    -------
    dict
        Dictionary with the following keys:

        * ``sdir`` -- path to the subject root directory (``sub-rs/``),
          usable directly as the ``sdir`` argument to :func:`~whippersnappy.snap4`.
        * ``lh_white`` / ``rh_white`` -- paths to ``surf/?h.white``.
        * ``lh_curv`` / ``rh_curv`` -- paths to ``surf/?h.curv``.
        * ``lh_thickness`` / ``rh_thickness`` -- paths to ``surf/?h.thickness``.
        * ``lh_annot`` / ``rh_annot`` -- paths to
          ``label/?h.aparc.DKTatlas.mapped.annot``.
        * ``lh_label`` / ``rh_label`` -- paths to ``label/?h.cortex.label``.

    Raises
    ------
    ImportError
        If ``pooch`` is not installed. Install with
        ``pip install 'whippersnappy[notebook]'``.

    Notes
    -----
    Data from the Rhineland Study (Koch et al.),
    https://doi.org/10.5281/zenodo.11186582, CC BY 4.0.

    Examples
    --------
    >>> from whippersnappy import fetch_sample_subject
    >>> data = fetch_sample_subject()
    >>> print(data["sdir"])
    """
    try:
        import pooch
    except ImportError as e:
        raise ImportError(
            "fetch_sample_subject() requires pooch. "
            "Install with: pip install 'whippersnappy[notebook]'"
        ) from e

    # Prefer a sub-rs/ directory shipped with the source checkout. Three
    # levels up from this module (utils/datasets.py -> utils -> package ->
    # repository root); only complete local copies are accepted.
    repo_root = Path(__file__).parent.parent.parent
    local_subject = repo_root / "sub-rs"
    if local_subject.is_dir() and all(
        (local_subject / relpath).exists() for relpath in _FILES
    ):
        return _build_dict(local_subject)

    # Otherwise download each file from the GitHub release into the pooch
    # OS cache, recreating the surf/ and label/ subdirectory layout that
    # the flat release assets lack.
    cache_base = Path(pooch.os_cache("whippersnappy")) / "sub-rs"
    for relpath, digest in _FILES.items():
        relfile = Path(relpath)
        pooch.retrieve(
            url=RELEASE_URL.format(file_name=relfile.name),
            known_hash=digest,
            fname=relfile.name,
            path=cache_base / relfile.parent,
        )

    return _build_dict(cache_base)
diff --git a/whippersnappy/utils/image.py b/whippersnappy/utils/image.py
new file mode 100644
index 0000000..f9d2ea0
--- /dev/null
+++ b/whippersnappy/utils/image.py
@@ -0,0 +1,343 @@
+"""Image and text helper utilities used by snapshot renderers (moved under utils).
+"""
+import numpy as np
+from PIL import Image, ImageDraw
+
+from .colormap import heat_color
+from .types import OrientationType
+
+try:
+ # Prefer stdlib importlib.resources
+ from importlib import resources
+except Exception:
+ import importlib_resources as resources
+import warnings
+
+from PIL import ImageFont
+
+
def text_size(caption, font):
    """Measure the rendered pixel size of a caption in a given font.

    Parameters
    ----------
    caption : str
        Text to measure.
    font : PIL.ImageFont.FreeTypeFont or similar
        Font object used for measurement.

    Returns
    -------
    (width, height) : tuple[int, int]
        Pixel dimensions of rendered text.
    """
    # Measure against a throwaway 1x1 image; textbbox never draws anything.
    probe = ImageDraw.Draw(Image.new("L", (1, 1)))
    left, top, right, bottom = probe.textbbox((0, 0), caption, font=font, anchor="lt")
    return right - left, bottom - top
+
+
def get_colorbar_label_positions(
    font,
    labels,
    colorbar_rect,
    gapspace=0,
    pos=True,
    neg=True,
    orientation=OrientationType.HORIZONTAL,
):
    """Compute positions for colorbar label text.

    Parameters
    ----------
    font : PIL.ImageFont
        Font used to measure text sizes.
    labels : dict
        Mapping of label keys to text strings. Keys read: ``'upper'`` and
        ``'lower'`` always; ``'middle'`` when both signs are present and
        ``gapspace`` is 0; ``'middle_neg'``/``'middle_pos'`` when both
        signs are present and ``gapspace`` is nonzero.
    colorbar_rect : tuple
        Rectangle for the colorbar (x, y, width, height).
    gapspace : int, optional, default 0
        Additional spacing used for split colorbars.
    pos, neg : bool, optional, default True, True
        Whether positive/negative sides are present.
    orientation : OrientationType, optional, default OrientationType.HORIZONTAL
        Orientation of the colorbar.

    Returns
    -------
    positions : dict
        Mapping of label key -> (x, y) pixel position. Coordinates may be
        floats (from the ``1.5 * h`` offsets); callers cast to int when
        drawing.
    """
    positions = {}
    cb_x, cb_y, cb_width, cb_height = colorbar_rect
    cb_labels_gap = 5  # pixel gap between the bar and its labels

    if orientation == OrientationType.HORIZONTAL:
        # All labels sit on one text line just below the bar.
        label_y = cb_y + cb_height + cb_labels_gap

        # Upper (right end) label, right-aligned with the bar; when the
        # positive side is absent, shift inward by the gap if there is one.
        w, _ = text_size(labels["upper"], font)
        if pos:
            positions["upper"] = (cb_x + cb_width - w, label_y)
        else:
            upper_x = cb_x + cb_width - w - int(gapspace) if gapspace > 0 else cb_x + cb_width - w
            positions["upper"] = (upper_x, label_y)

        # Lower (left end) label, left-aligned; mirrored gap handling.
        w, _ = text_size(labels["lower"], font)
        if neg:
            positions["lower"] = (cb_x, label_y)
        else:
            lower_x = cb_x + int(gapspace) if gapspace > 0 else cb_x
            positions["lower"] = (lower_x, label_y)

        if neg and pos:
            if gapspace == 0:
                # Continuous two-sided bar: a single centered middle label.
                w, _ = text_size(labels["middle"], font)
                positions["middle"] = (cb_x + cb_width // 2 - w // 2, label_y)
            else:
                # Split bar: one label on each side of the center gap.
                w, _ = text_size(labels["middle_neg"], font)
                positions["middle_neg"] = (cb_x + cb_width // 2 - w - int(gapspace), label_y)
                w, _ = text_size(labels["middle_pos"], font)
                positions["middle_pos"] = (cb_x + cb_width // 2 + int(gapspace), label_y)
    else:
        # Vertical bar: labels form a column to the right of the bar.
        label_x = cb_x + cb_width + cb_labels_gap

        # Top label; shifted down by the gap when the positive side is absent.
        _, h = text_size(labels["upper"], font)
        if pos:
            positions["upper"] = (label_x, cb_y)
        else:
            upper_y = cb_y + int(gapspace) if gapspace > 0 else cb_y
            positions["upper"] = (label_x, upper_y)

        # Bottom label, nudged up by 1.5 text heights so it stays within
        # the bar's vertical extent.
        _, h = text_size(labels["lower"], font)
        if neg:
            positions["lower"] = (label_x, cb_y + cb_height - 1.5 * h)
        else:
            lower_y = cb_y + cb_height - int(gapspace) - 1.5 * h if gapspace > 0 else cb_y + cb_height - 1.5 * h
            positions["lower"] = (label_x, lower_y)

        if neg and pos:
            if gapspace == 0:
                _, h = text_size(labels["middle"], font)
                positions["middle"] = (label_x, cb_y + cb_height // 2 - h // 2)
            else:
                # Split bar: positive label above the center gap, negative below.
                _, h = text_size(labels["middle_pos"], font)
                positions["middle_pos"] = (label_x, cb_y + cb_height // 2 - 1.5 * h - int(gapspace))
                _, h = text_size(labels["middle_neg"], font)
                positions["middle_neg"] = (label_x, cb_y + cb_height // 2 + int(gapspace))

    return positions
+
+
def create_colorbar(
    fmin,
    fmax,
    invert,
    orientation=OrientationType.HORIZONTAL,
    colorbar_scale=1,
    pos=True,
    neg=True,
    font_file=None,
):
    """Create a colored colorbar as a PIL.Image.

    The colorbar visualizes the overlay color mapping (using
    :func:`whippersnappy.utils.colormap.heat_color`) and optionally draws
    numeric labels for the min/threshold/saturation positions.

    Parameters
    ----------
    fmin, fmax : float
        Threshold and saturation values used to label the colorbar.
    invert : bool
        Invert the heat color mapping.
    orientation : OrientationType, optional, default OrientationType.HORIZONTAL
        Orientation of the colorbar (HORIZONTAL/VERTICAL).
    colorbar_scale : float, optional, default 1
        Scale factor for resulting image size.
    pos, neg : bool, optional, default True, True
        Whether the colorbar has positive/negative regions.
    font_file : str or None, optional
        Path to a TTF font file to use for labels.

    Returns
    -------
    PIL.Image.Image or None
        A PIL image containing the colorbar, or ``None`` if inputs are
        insufficient (e.g. fmin/fmax are None).
    """
    # If fmin/fmax are not specified, we cannot create a meaningful colorbar.
    if fmin is None or fmax is None:
        return None

    # Base bar geometry (before padding), scaled by colorbar_scale.
    cwidth = int(200 * colorbar_scale)
    cheight = int(30 * colorbar_scale)
    gapspace = 0

    # With a non-trivial threshold the bar is split: each colored segment
    # spans 42% of the width and an 8%-per-side gray gap marks the
    # sub-threshold region; otherwise each sign gets half the bar.
    if fmin > 0.01:
        num = int(0.42 * cwidth)
        gapspace = 0.08 * cwidth
    else:
        num = int(0.5 * cwidth)
    # One-sided bars devote the full width to the single colored segment.
    if not neg or not pos:
        num = num * 2
        gapspace = gapspace * 2

    # Build the per-pixel value ramp; NaN marks pixels outside the colored
    # segments (the gap and any absent side).
    values = np.nan * np.ones(cwidth)
    steps = np.linspace(0.01, 1, num)
    if pos and not neg:
        values[-steps.size:] = steps
    elif not pos and neg:
        values[: steps.size] = -1.0 * np.flip(steps)
    else:
        values[: steps.size] = -1.0 * np.flip(steps)
        values[-steps.size:] = steps

    # Map values to RGB, paint NaN (gap) pixels a neutral gray, and stack
    # the 1-pixel row into a cheight-tall bar.
    colors = heat_color(values, invert)
    colors[np.isnan(values), :] = 0.33 * np.ones((1, 3))
    img_bar = np.uint8(np.tile(colors, (cheight, 1, 1)) * 255)

    # Embed the bar in a black canvas with a small margin on every side.
    pad_top, pad_left = 3, 10
    img_buf = np.zeros((cheight + 2 * pad_top, cwidth + 2 * pad_left, 3), dtype=np.uint8)
    img_buf[pad_top : cheight + pad_top, pad_left : cwidth + pad_left, :] = img_bar
    image = Image.fromarray(img_buf)

    if font_file is None:
        # Try to load bundled font from package resources
        font = None
        try:
            font_trav = resources.files("whippersnappy").joinpath("resources", "fonts", "Roboto-Regular.ttf")
            with resources.as_file(font_trav) as font_path:
                font = ImageFont.truetype(str(font_path), int(12 * colorbar_scale))
        except Exception:
            warnings.warn("Roboto font not found in package resources; falling back to default font",
                          UserWarning, stacklevel=2)
            font = ImageFont.load_default()
    else:
        # User-supplied font file; silently fall back on any load failure.
        try:
            font = ImageFont.truetype(font_file, int(12 * colorbar_scale))
        except Exception:
            font = ImageFont.load_default()

    # Compose the numeric labels: end labels show the saturation value
    # (fmax) for sides that are present; middle labels show the threshold
    # (fmin) on a split bar, or a single "0" on a continuous one.
    labels = {}
    labels["upper"] = f">{fmax:.2f}" if pos else (f"{-fmin:.2f}" if gapspace != 0 else "0")
    labels["lower"] = f"<{-fmax:.2f}" if neg else (f"{fmin:.2f}" if gapspace != 0 else "0")
    if neg and pos and gapspace != 0:
        labels["middle_neg"] = f"{-fmin:.2f}"
        labels["middle_pos"] = f"{fmin:.2f}"
    elif neg and pos and gapspace == 0:
        labels["middle"] = "0"

    # Measure all labels so the canvas can be grown just enough to fit them.
    caption_sizes = [text_size(caption, font) for caption in labels.values()]
    max_caption_width = int(max([caption_size[0] for caption_size in caption_sizes]))
    max_caption_height = int(max([caption_size[1] for caption_size in caption_sizes]))

    if orientation == OrientationType.VERTICAL:
        # Rotate the bar and widen the canvas so labels fit to its right.
        # After the 90° rotation the bar occupies x:[pad_top, pad_top+cheight]
        # and y:[pad_left, pad_left+cwidth], hence the swapped rect below.
        image = image.rotate(90, expand=True)
        new_width = image.width + int(max_caption_width)
        new_image = Image.new("RGB", (new_width, image.height), (0, 0, 0))
        new_image.paste(image, (0, 0))
        image = new_image
        colorbar_rect = (pad_top, pad_left, cheight, cwidth)
    else:
        # Extend the canvas downward so labels fit below the bar.
        new_height = image.height + int(max_caption_height * 2)
        new_image = Image.new("RGB", (image.width, new_height), (0, 0, 0))
        new_image.paste(image, (0, 0))
        image = new_image
        colorbar_rect = (pad_left, pad_top, cwidth, cheight)

    # Compute label anchor points and draw them in light gray.
    positions = get_colorbar_label_positions(font, labels, colorbar_rect, gapspace, pos, neg, orientation)
    draw = ImageDraw.Draw(image)
    for label_key, position in positions.items():
        draw.text((int(position[0]), int(position[1])), labels[label_key], fill=(220, 220, 220), font=font)

    return image
+
+
def load_roboto_font(size=14):
    """Load the bundled Roboto font from package resources.

    Parameters
    ----------
    size : int, optional
        Requested point size.

    Returns
    -------
    PIL.ImageFont.FreeTypeFont or PIL.ImageFont.ImageFont or None
        A PIL font object; falls back to ``ImageFont.load_default()`` or
        ``None`` if fonts cannot be loaded.
    """
    # First choice: the Roboto TTF shipped inside the package resources
    # (resolved via importlib.resources, imported at module level).
    try:
        bundled = resources.files("whippersnappy").joinpath("resources", "fonts", "Roboto-Regular.ttf")
        with resources.as_file(bundled) as ttf_path:
            return ImageFont.truetype(str(ttf_path), size)
    except Exception:
        warnings.warn(
            "Roboto font not found in package resources; falling back to default font",
            UserWarning,
            stacklevel=2,
        )
    # Second choice: PIL's built-in default bitmap font; last resort: None.
    try:
        return ImageFont.load_default()
    except Exception:
        return None
+
+
def draw_colorbar(image, bar, orientation, x=None, y=None):
    """Paste a colorbar image onto the target image at the specified position.

    Parameters
    ----------
    image : PIL.Image.Image
        The target image to paste onto.
    bar : PIL.Image.Image
        The colorbar image to paste.
    orientation : OrientationType
        Orientation of the colorbar (HORIZONTAL/VERTICAL).
    x, y : int or None, optional
        Position to paste the colorbar. If None, defaults to centered at
        bottom (horizontal) or right (vertical).
    """
    # Nothing to draw without a bar (create_colorbar may return None).
    if bar is None:
        return
    # Default placement: horizontal bars centered along the bottom edge,
    # vertical bars centered along the right edge, 10 px from the border.
    if orientation == OrientationType.HORIZONTAL:
        default_x = int(0.5 * (image.width - bar.width))
        default_y = image.height - bar.height - 10
    else:
        default_x = image.width - bar.width - 10
        default_y = int(0.5 * (image.height - bar.height))
    paste_x = default_x if x is None else x
    paste_y = default_y if y is None else y
    image.paste(bar, (paste_x, paste_y))
+
+
def draw_caption(image, caption, font, orientation, x=None, y=None):
    """Draw a caption string onto the image, rotated for vertical layouts.

    Parameters
    ----------
    image : PIL.Image.Image
        The target image to draw onto.
    caption : str
        The caption text to draw.
    font : PIL.ImageFont
        Font to use for the caption.
    orientation : OrientationType
        Orientation of the caption (HORIZONTAL/VERTICAL).
    x, y : int or None, optional
        Position to draw the caption. If None, defaults to centered at
        bottom (horizontal) or right (vertical).
    """
    # Nothing to do without text or a usable font.
    if not caption or font is None:
        return
    width, height = text_size(caption, font)
    if orientation == OrientationType.HORIZONTAL:
        # Draw directly: centered at the bottom unless a position is given.
        px = int(0.5 * (image.width - width)) if x is None else x
        py = image.height - height - 10 if y is None else y
        ImageDraw.Draw(image).text((px, py), caption, (220, 220, 220), font=font, anchor="lt")
        return
    # Vertical: render onto a transparent scratch image, rotate it 90°,
    # then paste with its own alpha as the mask.
    scratch = Image.new("RGBA", (width, height), (0, 0, 0, 0))
    ImageDraw.Draw(scratch).text((0, 0), caption, font=font, anchor="lt")
    rotated = scratch.rotate(90, expand=True, fillcolor=(0, 0, 0, 0))
    px = image.width - rotated.width - 10 if x is None else x
    py = int(0.5 * (image.height - rotated.height)) if y is None else y
    image.paste(rotated, (px, py), rotated)
diff --git a/whippersnappy/utils/types.py b/whippersnappy/utils/types.py
new file mode 100644
index 0000000..8d0e417
--- /dev/null
+++ b/whippersnappy/utils/types.py
@@ -0,0 +1,93 @@
+"""Contains the types used in WhipperSnapPy.
+
+This module defines small enumeration types used across the package for
+controlling color selection, colorbar orientation, and predefined views.
+
+Classes
+-------
+ColorSelection
+ Which sign(s) of overlay values should be used to produce colors.
+OrientationType
+ Orientation of UI elements such as the colorbar (horizontal or vertical).
+ViewType
+ Predefined canonical view orientations for rendering the brain surface.
+"""
+
+import enum
+
+
class ColorSelection(enum.Enum):
    """Which sign(s) of overlay values participate in coloring.

    Parameters
    ----------
    *values : tuple
        Positional arguments passed to the Enum constructor (not used by
        consumers of this enum). Documented here to satisfy documentation
        linters that inspect the class signature.

    Attributes
    ----------
    BOTH : int
        Color both positive and negative values.
    POSITIVE : int
        Color positive values only.
    NEGATIVE : int
        Color negative values only.
    """

    BOTH = enum.auto()      # == 1
    POSITIVE = enum.auto()  # == 2
    NEGATIVE = enum.auto()  # == 3
+
+
class OrientationType(enum.Enum):
    """Layout direction for UI elements such as the colorbar.

    Parameters
    ----------
    *values : tuple
        Positional arguments passed to the Enum constructor (not used by
        consumers of this enum).

    Attributes
    ----------
    HORIZONTAL : int
        Layout along the horizontal axis.
    VERTICAL : int
        Layout along the vertical axis.
    """

    HORIZONTAL = enum.auto()  # == 1
    VERTICAL = enum.auto()    # == 2
+
+
class ViewType(enum.Enum):
    """Predefined canonical view directions used by snapshot renderers.

    Parameters
    ----------
    *values : tuple
        Positional arguments passed to the Enum constructor (not used by
        consumers of this enum).

    Attributes
    ----------
    LEFT : int
        Left hemisphere lateral view.
    RIGHT : int
        Right hemisphere lateral view.
    BACK : int
        Posterior view.
    FRONT : int
        Anterior/frontal view.
    TOP : int
        Superior/top view.
    BOTTOM : int
        Inferior/bottom view.
    """

    LEFT = enum.auto()    # == 1
    RIGHT = enum.auto()   # == 2
    BACK = enum.auto()    # == 3
    FRONT = enum.auto()   # == 4
    TOP = enum.auto()     # == 5
    BOTTOM = enum.auto()  # == 6