From e972b029e272bb11f1d8400b2bd1ad729e8c67aa Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Mon, 16 Feb 2026 15:32:52 +0100 Subject: [PATCH 01/83] v2.0 RC : major refactor, webgl in notebook (plot3d) --- README.md | 243 ++- examples/whippersnappy_demo.ipynb | 260 +++ pyproject.toml | 11 +- .../utils/tests => tests}/__init__.py | 0 .../utils/tests => tests}/test_config.py | 0 whippersnappy/__init__.py | 68 + whippersnappy/_version.py | 9 +- whippersnappy/cli/whippersnap.py | 11 +- whippersnappy/core.py | 1498 ----------------- whippersnappy/geometry/__init__.py | 11 + whippersnappy/geometry/prepare.py | 158 ++ whippersnappy/{ => geometry}/read_geometry.py | 0 whippersnappy/geometry/surf_name.py | 19 + whippersnappy/gl/__init__.py | 33 + whippersnappy/gl/camera.py | 29 + whippersnappy/gl/shaders.py | 194 +++ whippersnappy/gl/utils.py | 174 ++ whippersnappy/gl/views.py | 29 + whippersnappy/gui/__init__.py | 4 + whippersnappy/{ => gui}/config_app.py | 0 whippersnappy/plot3d.py | 194 +++ .../{ => resources/fonts}/Roboto-Regular.ttf | Bin whippersnappy/snap.py | 385 +++++ whippersnappy/utils/__init__.py | 5 +- whippersnappy/utils/colormap.py | 88 + whippersnappy/utils/image.py | 217 +++ whippersnappy/{ => utils}/types.py | 0 27 files changed, 2049 insertions(+), 1591 deletions(-) create mode 100644 examples/whippersnappy_demo.ipynb rename {whippersnappy/utils/tests => tests}/__init__.py (100%) rename {whippersnappy/utils/tests => tests}/test_config.py (100%) delete mode 100644 whippersnappy/core.py create mode 100644 whippersnappy/geometry/__init__.py create mode 100644 whippersnappy/geometry/prepare.py rename whippersnappy/{ => geometry}/read_geometry.py (100%) create mode 100644 whippersnappy/geometry/surf_name.py create mode 100644 whippersnappy/gl/__init__.py create mode 100644 whippersnappy/gl/camera.py create mode 100644 whippersnappy/gl/shaders.py create mode 100644 whippersnappy/gl/utils.py create mode 100644 whippersnappy/gl/views.py create mode 100644 whippersnappy/gui/__init__.py rename whippersnappy/{ => gui}/config_app.py (100%) create mode 100644 whippersnappy/plot3d.py rename whippersnappy/{ => resources/fonts}/Roboto-Regular.ttf (100%) create mode 100644 whippersnappy/snap.py create mode 100644 whippersnappy/utils/colormap.py create mode 100644 whippersnappy/utils/image.py rename whippersnappy/{ => utils}/types.py (100%) diff --git a/README.md b/README.md index 1b2895f..daf2486 100644 --- a/README.md +++ b/README.md @@ -1,80 +1,163 @@ -# WhipperSnapPy - -WhipperSnapPY is a small Python OpenGL program to render FreeSurfer and -FastSurfer surface models and color overlays and generate screen shots. - -## Contents: - -- Capture 4x4 surface plots (front & back, left and right) -- OpenGL window for interactive visualization - -## Installation: - -The `WhipperSnapPy` package can be installed from pypi via -``` -python3 -m pip install whippersnappy -``` - -Note, that currently no off-screen rendering is natively supported. Even in snap -mode an invisible window will be created to render the openGL output -and capture the contents to an image. In order to run this on a headless -server, inside Docker, or via ssh we recommend to install xvfb and run - -``` -sudo apt update && apt install -y python3 python3-pip xvfb libxcb-xinerama0 -pip3 install pyopengl glfw pillow numpy pyrr PyQt6 -pip3 install whippersnappy -xvfb-run whippersnap ... -``` - -## Usage: - -### Local: - -After installing the Python package, the whippersnap program can be run using -the installed command line tool such as in the following example: -``` -whippersnap -lh $OVERLAY_DIR/$LH_OVERLAY_FILE \ - -rh $OVERLAY_DIR/$RH_OVERLAY_FILE \ - -sd $SURF_SUBJECT_DIR \ - --fmax 4 --fthresh 2 --invert \ - --caption caption.txt \ - -o $OUTPUT_DIR/whippersnappy_image.png \ -``` - -For more options see `whippersnap --help`. -Note, that adding the `--interactive` flag will start an interactive GUI that -includes a visualization of one hemisphere side and a simple application through -which color threshold values can be configured. - -### Docker: - -The whippersnap program can be run within a docker container to capture -a snapshot by building the provided Docker image and running a container as -follows: -``` -docker build --rm=true -t whippersnappy -f ./Dockerfile . -``` -``` -docker run --rm --init --name my_whippersnappy -v $SURF_SUBJECT_DIR:/surf_subject_dir \ - -v $OVERLAY_DIR:/overlay_dir \ - -v $OUTPUT_DIR:/output_dir \ - --user $(id -u):$(id -g) whippersnappy:latest \ - --lh_overlay /overlay_dir/$LH_OVERLAY_FILE \ - --rh_overlay /overlay_dir/$RH_OVERLAY_FILE \ - --sdir /surf_subject_dir \ - --output_path /output_dir/whippersnappy_image.png -``` - -In this example: `$SURF_SUBJECT_DIR` contains the surface files, `$OVERLAY_DIR` contains the overlays to be loaded on to the surfaces, `$OUTPUT_DIR` is the local output directory in which the snapshot will be saved, and `${LH/RH}_OVERLAY_FILE` point to the specific overlay files to load. - -**Note:** The `--init` flag to Docker is needed for the `xvfb-run` tool to be used correctly for off screen rendering. - - -## API Documentation - -The API Documentation can be found at https://deep-mi.org/WhipperSnapPy . - -## Links: - -We also invite you to check out our lab webpage at https://deep-mi.org +# WhipperSnapPy + +WhipperSnapPY is a small Python OpenGL program to render FreeSurfer and +FastSurfer surface models and color overlays and generate screen shots. + +## Contents: + +- Capture 4x4 surface plots (front & back, left and right) +- OpenGL window for interactive visualization (GUI) +- Interactive 3D viewer for Jupyter notebooks with mouse-controlled rotation + +## Installation: + +The `WhipperSnapPy` package can be installed from pypi via +``` +python3 -m pip install whippersnappy +``` + +Note, that currently no off-screen rendering is natively supported. Even in snap +mode an invisible window will be created to render the openGL output +and capture the contents to an image. In order to run this on a headless +server, inside Docker, or via ssh we recommend to install xvfb and run + +``` +sudo apt update && apt install -y python3 python3-pip xvfb libxcb-xinerama0 +pip3 install pyopengl glfw pillow numpy pyrr PyQt6 +pip3 install whippersnappy +xvfb-run whippersnap ... +``` + +## Usage: + +### Local: + +After installing the Python package, the whippersnap program can be run using +the installed command line tool such as in the following example: +``` +whippersnap -lh $OVERLAY_DIR/$LH_OVERLAY_FILE \ + -rh $OVERLAY_DIR/$RH_OVERLAY_FILE \ + -sd $SURF_SUBJECT_DIR \ + --fmax 4 --fthresh 2 --invert \ + --caption caption.txt \ + -o $OUTPUT_DIR/whippersnappy_image.png \ +``` + +For more options see `whippersnap --help`. +Note, that adding the `--interactive` flag will start an interactive GUI that +includes a visualization of one hemisphere side and a simple application through +which color threshold values can be configured. + +## Quick Imports + +```python +from whippersnappy import snap1, snap4, plot3d +``` + +### Jupyter Notebooks: + +WhipperSnapPy supports both static and **fully interactive 3D visualization** in Jupyter notebooks. + +#### Interactive 3D Plotting + +For **interactive mouse-controlled 3D rendering**: + +```bash +pip install 'whippersnappy[notebook]' +``` + +```python +from whippersnappy import plot3d +from IPython.display import display + +viewer = plot3d( + meshpath='/path/to/surf/lh.white', + curvpath='/path/to/surf/lh.curv', # curvature + overlaypath='/path/to/surf/lh.thickness', # optional: for colored overlays + labelpath='/path/to/label/lh.cortex', # optional: for masking + minval=0.0, + maxval=5.5, + width=800, + height=800, +) +display(viewer) +``` + +**Features:** +- ✅ Works in ALL Jupyter environments (browser, JupyterLab, Colab, VS Code) +- ✅ Mouse-controlled rotation, zoom, and pan +- ✅ Professional lighting (Three.js/WebGL) +- ✅ Supports overlays, annotations, and curvature +- ✅ Same technology Plotly uses for 3D plots + +#### Static Rendering + +For static publication-quality images: + +```python +from whippersnappy import snap1 +from whippersnappy.types import ViewType +from IPython.display import display + +img = snap1( + meshpath='/path/to/surf/lh.white', + overlaypath='/path/to/surf/lh.thickness', + curvpath='/path/to/surf/lh.curv', + view=ViewType.LEFT, # or RIGHT, FRONT, BACK, TOP, BOTTOM + width=800, + height=800, + brain_scale=1.5, + specular=True, +) +display(img) +``` + +**Benefits:** +- ✅ Full PyOpenGL control for custom lighting +- ✅ Publication-quality output +- ✅ Fast performance +- ✅ Identical to GUI version + +See `examples/whippersnappy_demo.ipynb` for complete examples. + +### Desktop GUI: + +For interactive desktop application with GUI controls: + +```bash +whippersnap --interactive -lh path/to/lh.white -rh path/to/rh.white +``` + +This launches a native desktop GUI (not a notebook) with sliders and controls. + +### Docker: + +The whippersnap program can be run within a docker container to capture +a snapshot by building the provided Docker image and running a container as +follows: +``` +docker build --rm=true -t whippersnappy -f ./Dockerfile . +``` +``` +docker run --rm --init --name my_whippersnappy -v $SURF_SUBJECT_DIR:/surf_subject_dir \ + -v $OVERLAY_DIR:/overlay_dir \ + -v $OUTPUT_DIR:/output_dir \ + --user $(id -u):$(id -g) whippersnappy:latest \ + --lh_overlay /overlay_dir/$LH_OVERLAY_FILE \ + --rh_overlay /overlay_dir/$RH_OVERLAY_FILE \ + --sdir /surf_subject_dir \ + --output_path /output_dir/whippersnappy_image.png +``` + +In this example: `$SURF_SUBJECT_DIR` contains the surface files, `$OVERLAY_DIR` contains the overlays to be loaded on to the surfaces, `$OUTPUT_DIR` is the local output directory in which the snapshot will be saved, and `${LH/RH}_OVERLAY_FILE` point to the specific overlay files to load. + +**Note:** The `--init` flag to Docker is needed for the `xvfb-run` tool to be used correctly for off-screen rendering. + + +## API Documentation + +The API Documentation can be found at https://deep-mi.org/WhipperSnapPy . + +## Links: + +We also invite you to check out our lab webpage at https://deep-mi.org diff --git a/examples/whippersnappy_demo.ipynb b/examples/whippersnappy_demo.ipynb new file mode 100644 index 0000000..f6e74dd --- /dev/null +++ b/examples/whippersnappy_demo.ipynb @@ -0,0 +1,260 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "dfc714ca8243e82", + "metadata": {}, + "outputs": [], + "source": [ + "# WhipperSnapPy Demo - Static & Interactive Rendering\n", + "# This notebook demonstrates both static and interactive 3D brain visualization.\n", + "#\n", + "# Installation:\n", + "# pip install 'whippersnappy[notebook]'\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "472f4e1b64de35a9", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from IPython.display import display\n", + "\n", + "from whippersnappy import plot3d, snap1, snap4\n", + "from whippersnappy.utils.types import ViewType\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82ee18d3ef6732df", + "metadata": {}, + "outputs": [], + "source": [ + "# Setup Paths\n", + "# Edit these to point to your FreeSurfer/FastSurfer data:\n", + "\n", + "# Set your subject directory here (either an absolute path to a subject's directory\n", + "# containing a `surf/` subdirectory, or None to use the SUBJECTS_DIR environment variable).\n", + "# Example: sdir = '/home/user/freesurfer/subjects/subject01'\n", + "# IMPORTANT: set this value before running the rest of the notebook.\n", + "sdir = '/path/to/your/subjectdir' # <-- update this to your SUBJECTS_DIR or subject path\n", + "\n", + "# Verify that sdir exists and is a directory\n", + "if not os.path.isdir(sdir):\n", + " raise ValueError(f\"Subject directory does not exist: {sdir}\\nPlease set `sdir` to a valid subject directory containing a 'surf/' subdirectory.\")\n", + "\n", + "# Derive per-hemisphere paths from the subject directory\n", + "lh_surf_path = os.path.join(sdir, 'surf', 'lh.white')\n", + "lh_thickness_path = os.path.join(sdir, 'surf', 'lh.thickness')\n", + "lh_curv_path = os.path.join(sdir, 'surf', 'lh.curv')\n", + "lh_label_path = os.path.join(sdir, 'label', 'lh.cortex.label')\n", + "lh_annot_path = os.path.join(sdir, 'label', 'lh.aparc.annot')\n", + "\n", + "rh_surf_path = os.path.join(sdir, 'surf', 'rh.white')\n", + "rh_thickness_path = os.path.join(sdir, 'surf', 'rh.thickness')\n", + "rh_curv_path = os.path.join(sdir, 'surf', 'rh.curv')\n", + "rh_label_path = os.path.join(sdir, 'label', 'rh.cortex.label')\n", + "rh_annot_path = os.path.join(sdir, 'label', 'rh.aparc.annot')\n", + "\n", + "# Preset overlay variables for convenience for Snap4\n", + "lh_overlay = lh_thickness_path if os.path.exists(lh_thickness_path) else None\n", + "rh_overlay = rh_thickness_path if os.path.exists(rh_thickness_path) else None\n", + "\n", + "print(f\"Subject dir: {sdir}\")\n", + "print(f\"Surface exists? LH: {os.path.exists(lh_surf_path)} | RH: {os.path.exists(rh_surf_path)}\")\n", + "print(f\"Thickness exists? LH: {os.path.exists(lh_thickness_path)} | RH: {os.path.exists(rh_thickness_path)}\")\n", + "print(f\"Curv exists? LH: {os.path.exists(lh_curv_path)} | RH: {os.path.exists(rh_curv_path)}\")\n", + "print(f\"Label exists? LH: {os.path.exists(lh_label_path)} | RH: {os.path.exists(rh_label_path)}\")\n", + "print(f\"Annot exists? LH: {os.path.exists(lh_thickness_path)} | RH: {os.path.exists(rh_thickness_path)}\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd25577540008818", + "metadata": {}, + "outputs": [], + "source": [ + "# Part 1: Static Rendering - Single View\n", + "# Generate publication-quality static images with full PyOpenGL control.\n", + "\n", + "# Render a single view\n", + "img = snap1(\n", + " meshpath=lh_surf_path,\n", + " overlaypath=lh_overlay,\n", + " curvpath=lh_curv_path if os.path.exists(lh_curv_path) else None,\n", + " view=ViewType.LEFT,\n", + " width=800,\n", + " height=800,\n", + " brain_scale=1.5, # Adjust to avoid cropping\n", + " specular=True, # Professional lighting\n", + ")\n", + "display(img)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43a6d7a79e9a2264", + "metadata": {}, + "outputs": [], + "source": [ + "# Part 1: Static Rendering - Snap4 (both hemispheres)\n", + "# Use snap4 to render front/back views for left and right hemispheres and display the combined image.\n", + "\n", + "print(f\"Using subjects dir: {sdir}\")\n", + "\n", + "# lh_overlay and rh_overlay are precomputed in the Setup cell\n", + "# Call snap4 and receive a PIL Image directly when outpath=None\n", + "img4 = snap4(\n", + " lhoverlaypath=lh_overlay,\n", + " rhoverlaypath=rh_overlay,\n", + " sdir=sdir,\n", + " caption='Snap4 - both hemispheres',\n", + " outpath=None, # return PIL image instead of writing to disk\n", + " specular=True,\n", + " brain_scale=1.8,\n", + ")\n", + "\n", + "# Display result (snap4 returns a PIL.Image when outpath is None)\n", + "if img4 is not None:\n", + " display(img4)\n", + "else:\n", + " print(\"snap4 did not return an image; check inputs and OpenGL context.\")\n" + ] + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Part 2: Interactive 3D Rendering\n", + "# Mouse-controlled 3D visualization using Three.js (works in all Jupyter environments).\n", + "#\n", + "# Controls:\n", + "# - Drag: Rotate\n", + "# - Scroll: Zoom\n", + "# - Right-drag: Pan\n", + "\n", + "# Interactive viewer with curvature (grayscale)\n", + "viewer = plot3d(\n", + " meshpath=lh_surf_path,\n", + " curvpath=lh_curv_path if os.path.exists(lh_curv_path) else None,\n", + " width=800,\n", + " height=800,\n", + ")\n", + "display(viewer)\n" + ], + "id": "ca7e2c155177b7c5" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Interactive colored overlay (if available)\n", + "if lh_overlay:\n", + " viewer = plot3d(\n", + " meshpath=lh_surf_path,\n", + " overlaypath=lh_overlay,\n", + " curvpath=lh_curv_path if os.path.exists(lh_curv_path) else None,\n", + " labelpath=lh_label_path,\n", + " minval=0.0, # Threshold\n", + " maxval=5.5, # Saturation\n", + " width=800,\n", + " height=800,\n", + " )\n", + " display(viewer)\n", + "else:\n", + " print(\"Thickness overlay not found - skipping colored example\")\n" + ], + "id": "f712c59cf4a54a0d" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Interactive label map overlay\n", + "if lh_annot_path:\n", + " viewer = plot3d(\n", + " meshpath=lh_surf_path,\n", + " annotpath=lh_annot_path,\n", + " curvpath=lh_curv_path if os.path.exists(lh_curv_path) else None,\n", + " width=800,\n", + " height=800,\n", + " )\n", + " display(viewer)\n", + "else:\n", + " print(\"Annot overlay not found - skipping label map example\")\n" + ], + "id": "9d5f61470c130e6e" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd1b1b7561ebecfa", + "metadata": {}, + "outputs": [], + "source": [ + "# Notes\n", + "#\n", + "# Static Rendering:\n", + "# - Returns PIL Image objects (no disk I/O needed)\n", + "# - Full PyOpenGL control for custom lighting\n", + "# - Publication-quality output\n", + "# - Fast and deterministic\n", + "#\n", + "# Interactive Rendering:\n", + "# - Uses Three.js/WebGL (runs in browser)\n", + "# - Works in all Jupyter environments\n", + "# - Full mouse control (rotate, zoom, pan)\n", + "# - Same technology as Plotly 3D plots\n", + "#\n", + "# Color Notes:\n", + "# - Curvature: Grayscale (sulci = dark, gyri = light) - this is correct!\n", + "# - Overlays: Colored heatmaps (thickness, activation, statistics)\n", + "# - Annotations: Distinct colored regions (parcellations)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a065ca6a14ecdda9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 88dafe6..751b3d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = 'setuptools.build_meta' [project] name = 'whippersnappy' -version = '1.4.0-dev' +version = '2.0.0-dev' description = 'A package to plot and capture FastSurfer and FreeSurfer-style surface overlays.' readme = 'README.md' license = {file = 'LICENSE'} @@ -66,6 +66,10 @@ doc = [ 'IPython', # For syntax highlighting in notebooks 'ipykernel', ] +notebook = [ + 'pythreejs', # Three.js for interactive 3D (works in all Jupyter environments) + 'ipywidgets', # Required for pythreejs +] style = [ 'bibclean', 'codespell', @@ -82,6 +86,7 @@ all = [ 'whippersnappy[doc]', 'whippersnappy[style]', 'whippersnappy[test]', + 'whippersnappy[notebook]', ] full = [ 'whippersnappy[all]', @@ -98,14 +103,14 @@ whippersnap = 'whippersnappy.cli:run' whippersnappy-sys_info = 'whippersnappy.commands.sys_info:run' [tool.setuptools] -include-package-data = false +include-package-data = true [tool.setuptools.packages.find] include = ['whippersnappy*'] exclude = ['whippersnappy*tests'] [tool.setuptools.package-data] -whippersnappy = ['*.ttf'] +whippersnappy = ['resources/fonts/*.ttf'] [tool.pydocstyle] convention = 'numpy' diff --git a/whippersnappy/utils/tests/__init__.py b/tests/__init__.py similarity index 100% rename from whippersnappy/utils/tests/__init__.py rename to tests/__init__.py diff --git a/whippersnappy/utils/tests/test_config.py b/tests/test_config.py similarity index 100% rename from whippersnappy/utils/tests/test_config.py rename to tests/test_config.py diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index f7f862b..8cb37cc 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -1,2 +1,70 @@ +"""WhipperSnapPy: Plot and capture FastSurfer and FreeSurfer-style surface overlays. + +WhipperSnapPy provides tools for rendering brain surface meshes with statistical +overlays and annotations. It includes: + +- **Static rendering**: `snap1()` and `snap4()` functions for publication-quality images +- **3D plotting**: For Jupyter notebooks with mouse-controlled 3D (via Three.js) +- **GUI**: Desktop application with `--interactive` flag +- **CLI tools**: Command-line interface for batch processing +- **Custom shaders**: Full control over OpenGL lighting and rendering + +For static image generation: + + from whippersnappy import snap1, snap4 + from whippersnappy.utils.types import ViewType + from IPython.display import display + + img = snap1(meshpath='path/to/surface.white', view=ViewType.LEFT) + display(img) + +For interactive 3D in Jupyter notebooks: + + # Requires: pip install 'whippersnappy[notebook]' + from whippersnappy import plot3d + + viewer = plot3d( + meshpath='path/to/surface.white', + curvpath='path/to/curv', + overlaypath='path/to/thickness.mgh' # optional: for colors + ) + display(viewer) + +For desktop GUI: + + # Command line + whippersnap --interactive -lh path/to/lh.white -rh path/to/rh.white + +Features: +- Works in ALL Jupyter environments (browser, JupyterLab, Colab, VS Code) +- Mouse-controlled rotation, zoom, and pan +- Professional lighting via Three.js/WebGL +- Same technology Plotly uses for 3D plots + +""" + +from .utils.types import ViewType + from ._version import __version__ # noqa: F401 +from .snap import snap1, snap4 from .utils._config import sys_info # noqa: F401 + +# 3D plotting for notebooks (Three.js-based, works in all Jupyter environments) +try: + from .plot3d import plot3d + _has_plot3d = True +except ImportError: + _has_plot3d = False + +# Export list +__all__ = [ + "__version__", + "sys_info", + "snap1", + "snap4", +] + +if _has_plot3d: + __all__.append("plot3d") +# Top-level convenience export for frequently used enum +__all__.append("ViewType") diff --git a/whippersnappy/_version.py b/whippersnappy/_version.py index 77f24d8..bf95b14 100644 --- a/whippersnappy/_version.py +++ b/whippersnappy/_version.py @@ -1,5 +1,8 @@ """Version number.""" -from importlib.metadata import version - -__version__ = version(__package__) +try: + from importlib.metadata import version + __version__ = version(__package__) +except Exception: + # Fallback when package is not installed (e.g., running from source) + __version__ = "1.4.0-dev" diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 5af061e..1a44869 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -33,14 +33,13 @@ import pyrr from PyQt6.QtWidgets import QApplication -from whippersnappy.config_app import ConfigWindow -from whippersnappy.core import ( - get_surf_name, +from whippersnappy import snap4 +from whippersnappy.geometry import get_surf_name, prepare_geometry +from whippersnappy.gl import ( init_window, - prepare_geometry, setup_shader, - snap4, ) +from whippersnappy.gui import ConfigWindow # Global variables for config app configuration state: current_fthresh_ = None @@ -366,7 +365,7 @@ def run(): # pip3 install pyopengl glfw pillow numpy pyrr # xvfb-run python3 test4.py -# instead of the above one could really do headless off screen rendering via +# instead of the above one could really do headless off-screen rendering via # EGL (preferred) or OSMesa. The latter looks doable. EGL looks tricky. # EGL is part of any modern NVIDIA driver # OSMesa needs to be installed, but should work almost everywhere diff --git a/whippersnappy/core.py b/whippersnappy/core.py deleted file mode 100644 index 5201251..0000000 --- a/whippersnappy/core.py +++ /dev/null @@ -1,1498 +0,0 @@ -"""Contains the core functionalities of WhipperSnapPy. - -Dependencies: - numpy, glfw, pyrr, PyOpenGL, pillow - -@Author : Martin Reuter -@Created : 27.02.2022 -@Revised : 02.10.2025 - -""" - -import math -import os -import sys - -import glfw -import numpy as np -import OpenGL.GL as gl -import OpenGL.GL.shaders as shaders -import pyrr -from PIL import Image, ImageDraw, ImageFont - -from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data -from .types import ColorSelection, OrientationType, ViewType - - -def normalize_mesh(v, scale=1.0): - """ - Normalize mesh vertex coordinates. - - - Center their bounding box at the origin. - - Ensure that the longest side-length is equal to the scale variable (default 1). - - Parameters - ---------- - v : numpy.ndarray - Vertex array (Nvert X 3). - scale : float - Scaling constant. - - Returns - ------- - v: numpy.ndarray - Normalized vertex array (Nvert X 3). - """ - # center bounding box at origin - # scale longest side to scale (default 1) - bbmax = np.max(v, axis=0) - bbmin = np.min(v, axis=0) - v = v - 0.5 * (bbmax + bbmin) - v = scale * v / np.max(bbmax - bbmin) - return v - - -# adopted from lapy -def vertex_normals(v, t): - """ - Compute vertex normals. - - Triangle normals around each vertex are averaged, weighted by the angle - that they contribute. - Vertex ordering is important in t: counterclockwise when looking at the - triangle from above, so that normals point outwards. - - Parameters - ---------- - v : numpy.ndarray - Vertex array (Nvert X 3). - t : numpy.ndarray - Triangle array (Ntria X 3). - - Returns - ------- - normals: numpy.ndarray - Normals array: n - normals (Nvert X 3). - """ - # Compute vertex coordinates and a difference vector for each triangle: - v0 = v[t[:, 0], :] - v1 = v[t[:, 1], :] - v2 = v[t[:, 2], :] - v1mv0 = v1 - v0 - v2mv1 = v2 - v1 - v0mv2 = v0 - v2 - # Compute cross product at every vertex - # will point into the same direction with lengths depending on spanned area - cr0 = np.cross(v1mv0, -v0mv2) - cr1 = np.cross(v2mv1, -v1mv0) - cr2 = np.cross(v0mv2, -v2mv1) - # Add normals at each vertex (there can be duplicate indices in t at vertex i) - n = np.zeros(v.shape) - np.add.at(n, t[:, 0], cr0) - np.add.at(n, t[:, 1], cr1) - np.add.at(n, t[:, 2], cr2) - # Normalize normals - ln = np.sqrt(np.sum(n * n, axis=1)) - ln[ln < sys.float_info.epsilon] = 1 # avoid division by zero - n = n / ln.reshape(-1, 1) - return n - - -def heat_color(values, invert=False): - """ - Convert an array of float values into RBG heat color values. - - Only values between -1 and 1 will receive gradient and colors will - max-out at -1 and 1. Negative values will be blue and positive - red (unless invert is passed to flip the heatmap). Masked values - (nan) will map to masked colors (nan,nan,nan). - - Parameters - ---------- - values : numpy.ndarray - Float values of function on the surface mesh (length Nvert). - invert : bool - Whether to invert the heat map (blue is positive and red negative). - - Returns - ------- - colors: numpy.ndarray - (Nvert x 3) array of RGB of heat map as 0.0 .. 1.0 floats. - """ - # values (1 dim array length n) will receive gradient between -1 and 1 - # nan will return (nan,nan,nan) - # returns colors (r,g,b) as n x 3 array - if invert: - values = -1.0 * values - vabs = np.abs(values) - colors = np.zeros((vabs.size, 3), dtype=np.float32) - crb = 0.5625 + 3 * 0.4375 * vabs - cg = 1.5 * (vabs - (1.0 / 3.0)) - n1 = values < -1.0 - nm = (values >= -1.0) & (values < -(1.0 / 3.0)) - n0 = (values >= -(1.0 / 3.0)) & (values < 0) - p0 = (values >= 0) & (values < (1.0 / 3.0)) - pm = (values >= (1.0 / 3.0)) & (values < 1.0) - p1 = values >= 1.0 - # fill in colors for the 5 blocks - colors[n1, 1:3] = 1.0 # bright blue - colors[nm, 1] = cg[nm] # cg increasing green channel - colors[nm, 2] = 1.0 # and keeping blue on full - colors[n0, 2] = crb[n0] # crb increasing blue channel - colors[p0, 0] = crb[p0] # crb increasing red channel - colors[pm, 1] = cg[pm] # cg increasing green channel - colors[pm, 0] = 1.0 # and keeping red on full - colors[p1, 0:2] = 1.0 # yellow - colors[np.isnan(values), :] = np.nan - return colors - -def mask_sign(values, color_mode): - """ - Mask values don't have the same sign as the color_mode. - - The masked values will be replaced by nan. - - Parameters - ---------- - values : numpy.ndarray - Float values of function on the surface mesh (length Nvert). - color_mode : ColorSelection - Select which values to color, can be ColorSelection.BOTH, ColorSelection.POSITIVE - or ColorSelection.NEGATIVE. Default: ColorSelection.BOTH. - - Returns - ------- - values: numpy.ndarray - Float array of input function on mesh (length Nvert). - """ - masked_values = np.copy(values) - if color_mode == ColorSelection.POSITIVE: - masked_values[masked_values < 0] = np.nan - elif color_mode == ColorSelection.NEGATIVE: - masked_values[masked_values > 0] = np.nan - return masked_values - -def rescale_overlay(values, minval=None, maxval=None): - """ - Rescale values for color map computation. - - minval and maxval are two positive floats (maxval>minval). - Values between -minval and minval will be masked (np.nan); - others will be shifted towards zero (from both sides) - and scaled so that -maxval and maxval are at -1 and +1. - - Parameters - ---------- - values : numpy.ndarray - Float values of function on the surface (length Nvert). - minval : float - Minimum value. - maxval : float - Maximum value. - - Returns - ------- - values: numpy.ndarray - Float array of input function on mesh (length Nvert). - minval: float - Positive minimum value (crop values whose absolute value is below). - maxval: float - Positive maximum value (saturate color at maxval and -maxval). - pos: bool - Whether positive values are present at all after cropping. - neg: bool - Whether negative values are present at all after cropping. - """ - valsign = np.sign(values) - valabs = np.abs(values) - - if maxval < 0 or minval < 0: - print("resacle_overlay ERROR: min and maxval should both be positive!") - exit(1) - - # Mask values below minval - values[valabs < minval] = np.nan - - # Rescale map symmetrically to -1 .. 1 with the minval = 0 - # Any arithmetic operation containing NaN values results in NaN - range_val = maxval - minval - if range_val == 0: - values = np.zeros_like(values) - else: - values = values - valsign * minval - values = values / range_val - - # Check if there are any positive or negative values - pos = np.any(values[~np.isnan(values)] > 0) - neg = np.any(values[~np.isnan(values)] < 0) - - return values, minval, maxval, pos, neg - - -def binary_color(values, thres, color_low, color_high): - """ - Create a binary colormap based on a threshold value. - - This function assigns colors to input values based on whether they are - below or equal to the threshold (thres) or greater than the threshold. - - Values below thres are color_low, others are color_high. - color_low and color_high can be float (gray scale), or 1x3 array of RGB. - - Parameters - ---------- - values : numpy.ndarray - Input vertex function as float array (length Nvert). - thres : float - Threshold value. - color_low : float or numpy.ndarray - Lower color value(s). - color_high : float or numpy.ndarray - Higher color value(s). - - Returns - ------- - colors : numpy.ndarray - Binary colormap. - """ - if np.isscalar(color_low): - color_low = np.array((color_low, color_low, color_low), dtype=np.float32) - if np.isscalar(color_high): - color_high = np.array((color_high, color_high, color_high), dtype=np.float32) - colors = np.empty((values.size, 3), dtype=np.float32) - colors[values < thres, :] = color_low - colors[values >= thres, :] = color_high - return colors - - -def mask_label(values, labelpath=None): - """ - Apply a labelfile as a mask. - - Labelfile freesurfer format has indices of values that should be kept; - all other values will be set to np.nan. - - Parameters - ---------- - values : numpy.ndarray - Float values of function defined at vertices (a 1-dim array). - labelpath : str - Absolute path to label file. - - Returns - ------- - values: numpy.ndarray - Masked surface function values. - """ - if not labelpath: - return values - # this is the mask of vertices to keep, e.g. cortex labels - maskvids = np.loadtxt(labelpath, dtype=int, skiprows=2, usecols=[0]) - imask = np.ones(values.shape, dtype=bool) - imask[maskvids] = False - values[imask] = np.nan - return values - - -def prepare_geometry( - surfpath, - overlaypath=None, - annotpath=None, - curvpath=None, - labelpath=None, - minval=None, - maxval=None, - invert=False, - scale=1.85, - color_mode=ColorSelection.BOTH -): - """ - Prepare meshdata for upload to GPU. - - Vertex coordinates, vertex normals and color values are concatenated into - large vertexdata array. Also returns triangles, minimum and maximum overlay - values as well as whether negative values are present or not in triangles. - - Parameters - ---------- - surfpath : str - Path to surface file (usually lh or rh.pial_semi_inflated). - overlaypath : str - Path to overlay file. - annotpath : str - Path to annotation file. - curvpath : str - Path to curvature file (usually lh or rh.curv). - labelpath : str - Path to label file (mask; usually cortex.label). - minval : float - Minimum threshold to stop coloring (-minval used for neg values). - maxval : float - Maximum value to saturate (-maxval used for negative values). - invert : bool - Invert color map. - scale : float - Global scaling factor. Default: 1.85. - color_mode : ColorSelection - Select which values to color, can be ColorSelection.BOTH, ColorSelection.POSITIVE - or ColorSelection.NEGATIVE. Default: ColorSelection.BOTH. - - Returns - ------- - vertexdata: numpy.ndarray - Concatenated array with vertex coords, vertex normals and colors - as a (Nvert X 9) float32 array. - triangles: numpy.ndarray - Triangle array as a (Ntria X 3) uint32 array. - fmin: float - Minimum value of overlay function after rescale. - fmax: float - Maximum value of overlay function after rescale. - pos: bool - Whether positive values are there after rescale/cropping. - neg: bool - Whether negative values are there after rescale/cropping. - """ - - # read vertices and triangles - surf = read_geometry(surfpath, read_metadata=False) - vertices = normalize_mesh(np.array(surf[0], dtype=np.float32), scale) - triangles = np.array(surf[1], dtype=np.uint32) - # compute vertex normals - vnormals = np.array(vertex_normals(vertices, triangles), dtype=np.float32) - # read curvature - if curvpath: - curv = read_morph_data(curvpath) - sulcmap = binary_color(curv, 0.0, color_low=0.5, color_high=0.33) - else: - # if no curv pattern, color mesh in mid-gray - sulcmap = 0.5 * np.ones(vertices.shape, dtype=np.float32) - # read map (stats etc) or annotation - if overlaypath: - _, file_extension = os.path.splitext(overlaypath) - - if file_extension == ".mgh": - mapdata = read_mgh_data(overlaypath) - else: - mapdata = read_morph_data(overlaypath) - - valabs = np.abs(mapdata) - if maxval is None: - maxval = np.max(valabs) if np.any(valabs) else 0 - if minval is None: - minval = max(0.0, np.min(valabs) if np.any(valabs) else 0) - - # Mask map and get either positive and/or negative values - mapdata = mask_sign(mapdata, color_mode) - - # Rescale the map with minval and maxval - mapdata, fmin, fmax, pos, neg = rescale_overlay(mapdata, minval, maxval) - - # mask map with label - mapdata = mask_label(mapdata, labelpath) - - # compute color - colors = heat_color(mapdata, invert) - - missing = np.isnan(mapdata) - colors[missing, :] = sulcmap[missing, :] - elif annotpath: - annot, ctab, names = read_annot_data(annotpath) - # compute color - colors = ctab[annot, 0:3] / np.max(ctab[:, 0:3]) - # annot can contain -1 indices; these indicate non-annotated - # regions, but are valid indices in Python; need to recode - # them as missing - colors[annot==-1,:] = sulcmap[annot==-1,:] - colors = colors.astype(np.float32) - fmin = None - fmax = None - pos = None - neg = None - else: - colors = sulcmap - fmin = None - fmax = None - pos = None - neg = None - # concatenate matrices - vertexdata = np.concatenate((vertices, vnormals, colors), axis=1) - return vertexdata, triangles, fmin, fmax, pos, neg - - -def init_window(width, height, title="PyOpenGL", visible=True): - """ - Create window with width, height, title. - - If visible False, hide window. - - Parameters - ---------- - width : int - Window width. - height : int - Window height. - title : str - Window title. - visible : bool - Window visibility. - - Returns - ------- - window: glfw.LP__GLFWwindow - GUI window. - """ - if not glfw.init(): - return False - - glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 3) - glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3) - glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, True) - glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE) - if not visible: - glfw.window_hint(glfw.VISIBLE, glfw.FALSE) - window = glfw.create_window(width, height, title, None, None) - if not window: - glfw.terminate() - return False - # Enable key events - glfw.set_input_mode(window, glfw.STICKY_KEYS, gl.GL_TRUE) - # Enable key event callback - # glfw.set_key_callback(window,key_event) - glfw.make_context_current(window) - # vsync and glfw do not play nice. when vsync is enabled mouse movement is jittery. - glfw.swap_interval(0) - return window - - -def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0): - """ - Create vertex and fragment shaders. - - Set up data and parameters (such as the initial view matrix) on the GPU. - - In meshdata: - - the first 3 columns are the vertex coordinates - - the next 3 columns are the vertex normals - - the final 3 columns are the color RGB values - - Parameters - ---------- - meshdata : numpy.ndarray - Mesh array (shape: n x 9, dtype: np.float32). - triangles : bool - Triangle indices array (shape: m x 3). - width : int - Window width (to set perspective projection). - height : int - Window height (to set perspective projection). - specular : Boolean - By default specular is set as True. - ambient : float - Ambient light strength, by default 0: use only diffuse light sources. - - Returns - ------- - shader: ShaderProgram - Compiled OpenGL shader program. - """ - - VERTEX_SHADER = """ - - #version 330 - - layout (location = 0) in vec3 aPos; - layout (location = 1) in vec3 aNormal; - layout (location = 2) in vec3 aColor; - - out vec3 FragPos; - out vec3 Normal; - out vec3 Color; - - uniform mat4 transform; - uniform mat4 model; - uniform mat4 view; - uniform mat4 projection; - - void main() - { - gl_Position = projection * view * model * transform * vec4(aPos, 1.0f); - FragPos = vec3(model * transform * vec4(aPos, 1.0)); - // normal matrix should be computed outside and passed! - Normal = mat3(transpose(inverse(view * model * transform))) * aNormal; - Color = aColor; - } - - """ - - FRAGMENT_SHADER = """ - #version 330 - - in vec3 Normal; - in vec3 FragPos; - in vec3 Color; - - out vec4 FragColor; - - uniform vec3 lightColor = vec3(1.0, 1.0, 1.0); - uniform bool doSpecular = true; - uniform float ambientStrength = 0.0; - - void main() - { - // ambient - vec3 ambient = ambientStrength * lightColor; - - // diffuse - vec3 norm = normalize(Normal); - // values for overhead, front, below, back lights - //vec4 diffweights = vec4(0.4, 0.6, 0.4, 0.4); //more light below - vec4 diffweights = vec4(0.6, 0.4, 0.4, 0.3); //orig more shadow - - // key light (overhead) - vec3 lightPos1 = vec3(0.0,5.0,5.0); - vec3 lightDir = normalize(lightPos1 - FragPos); - float diff = max(dot(norm, lightDir), 0.0); - vec3 diffuse = diffweights[0] * diff * lightColor; - - // headlight (at camera) - vec3 lightPos2 = vec3(0.0,0.0,5.0); - lightDir = normalize(lightPos2 - FragPos); - vec3 ohlightDir = lightDir; // needed for specular - diff = max(dot(norm, lightDir), 0.0); - diffuse = diffuse + diffweights[1] * diff * lightColor; - - // fill light (from below) - vec3 lightPos3 = vec3(0.0,-5.0,5.0); - lightDir = normalize(lightPos3 - FragPos); - diff = max(dot(norm, lightDir), 0.0); - diffuse = diffuse + diffweights[2] * diff * lightColor; - - // left right back lights (both are same brightness) - vec3 lightPos4 = vec3(5.0,0.0,-5.0); - lightDir = normalize(lightPos4 - FragPos); - diff = max(dot(norm, lightDir), 0.0); - diffuse = diffuse + diffweights[3] * diff * lightColor; - - vec3 lightPos5 = vec3(-5.0,0.0,-5.0); - lightDir = normalize(lightPos5 - FragPos); - diff = max(dot(norm, lightDir), 0.0); - diffuse = diffuse + diffweights[3] * diff * lightColor; - - // specular - vec3 result; - if (doSpecular) - { - float specularStrength = 0.5; - // the viewer is always at (0,0,0) in view-space, - // so viewDir is (0,0,0) - Position => -Position - vec3 viewDir = normalize(-FragPos); - vec3 reflectDir = reflect(ohlightDir, norm); - float spec = pow(max(dot(viewDir, reflectDir), 0.0), 32); - vec3 specular = specularStrength * spec * lightColor; - // final color - result = (ambient + diffuse + specular) * Color; - } - else - { - // final color no specular - result = (ambient + diffuse) * Color; - } - FragColor = vec4(result, 1.0); - } - - """ - - # Create Vertex Buffer object in gpu - VBO = gl.glGenBuffers(1) - # Bind the buffer - gl.glBindBuffer(gl.GL_ARRAY_BUFFER, VBO) - gl.glBufferData(gl.GL_ARRAY_BUFFER, meshdata.nbytes, meshdata, gl.GL_STATIC_DRAW) - - # Create Vertex Array object - VAO = gl.glGenVertexArrays(1) - # Bind array - gl.glBindVertexArray(VAO) - gl.glBufferData(gl.GL_ARRAY_BUFFER, meshdata.nbytes, meshdata, gl.GL_STATIC_DRAW) - - # Create Element Buffer Object - EBO = gl.glGenBuffers(1) - gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, EBO) - gl.glBufferData( - gl.GL_ELEMENT_ARRAY_BUFFER, triangles.nbytes, triangles, gl.GL_STATIC_DRAW - ) - - # Compile The Program and shaders - shader = gl.shaders.compileProgram( - shaders.compileShader(VERTEX_SHADER, gl.GL_VERTEX_SHADER), - shaders.compileShader(FRAGMENT_SHADER, gl.GL_FRAGMENT_SHADER), - ) - - # get the position from shader - position = gl.glGetAttribLocation(shader, "aPos") - gl.glVertexAttribPointer( - position, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(0) - ) - gl.glEnableVertexAttribArray(position) - - vnormalpos = gl.glGetAttribLocation(shader, "aNormal") - gl.glVertexAttribPointer( - vnormalpos, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(3 * 4) - ) - gl.glEnableVertexAttribArray(vnormalpos) - - colorpos = gl.glGetAttribLocation(shader, "aColor") - gl.glVertexAttribPointer( - colorpos, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(6 * 4) - ) - gl.glEnableVertexAttribArray(colorpos) - - gl.glUseProgram(shader) - - gl.glClearColor(0.0, 0.0, 0.0, 1.0) - gl.glEnable(gl.GL_DEPTH_TEST) - - # Creating Projection Matrix - view = pyrr.matrix44.create_from_translation(pyrr.Vector3([0.0, 0.0, -5.0])) - projection = pyrr.matrix44.create_perspective_projection( - 20.0, width / height, 0.1, 100.0 - ) - model = pyrr.matrix44.create_from_translation(pyrr.Vector3([0.0, 0.0, 0.0])) - - # Set matrices in vertex shader - view_loc = gl.glGetUniformLocation(shader, "view") - proj_loc = gl.glGetUniformLocation(shader, "projection") - model_loc = gl.glGetUniformLocation(shader, "model") - gl.glUniformMatrix4fv(view_loc, 1, gl.GL_FALSE, view) - gl.glUniformMatrix4fv(proj_loc, 1, gl.GL_FALSE, projection) - gl.glUniformMatrix4fv(model_loc, 1, gl.GL_FALSE, model) - - # setup doSpecular in fragment shader - specular_loc = gl.glGetUniformLocation(shader, "doSpecular") - gl.glUniform1i(specular_loc, specular) - - # setup light color in fragment shader - lightColor_loc = gl.glGetUniformLocation(shader, "lightColor") - gl.glUniform3f(lightColor_loc, 1.0, 1.0, 1.0) - - # setup ambient light strength (default=0) - ambientLight_loc = gl.glGetUniformLocation(shader, "ambientStrength") - gl.glUniform1f(ambientLight_loc, ambient) - - return shader - - -def capture_window(width, height): - """ - Capture the GL region (0,0) .. (width,height) into PIL Image. - - Parameters - ---------- - width : int - Window width. - height : int - Window height. - - Returns - ------- - image: PIL.Image.Image - Captured image. - """ - if sys.platform == "darwin": - # not sure why on mac the drawing area is 4 times as large (2x2): - width = 2 * width - height = 2 * height - gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) # may not be needed - img_buf = gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE) - image = Image.frombytes("RGB", (width, height), img_buf) - image = image.transpose(Image.FLIP_TOP_BOTTOM) - if sys.platform == "darwin": - image.thumbnail((0.5 * width, 0.5 * height), Image.Resampling.LANCZOS) - return image - -def text_size(caption, font): - """ - Get the size of the text. - - Parameters - ---------- - caption : str - Text that is to be rendered. - font : PIL.ImageFont.FreeTypeFont - Font of the labels. - - Returns - ------- - text_width: int - Width of the text in pixels. - text_height: int - Height of the text in pixels. - """ - dummy_img = Image.new("L", (1, 1)) - draw = ImageDraw.Draw(dummy_img) - bbox = draw.textbbox((0, 0), caption, font=font, anchor="lt") - text_width = bbox[2] - bbox[0] - text_height = bbox[3] - bbox[1] - return text_width, text_height - -def get_colorbar_label_positions( - font, - labels, - colorbar_rect, - gapspace=0, - pos=True, - neg=True, - orientation=OrientationType.HORIZONTAL -): - """ - Get the positions of the labels for the colorbar. - - Parameters - ---------- - font : PIL.ImageFont.FreeTypeFont - Font of the labels. - labels : dict - Label texts that are to be rendered. - colorbar_rect : tuple - The coordinate values of the colorbar edges. - gapspace : int - Length of the gray space representing the threshold. Default : 0. - pos : bool - Show positive axis. Default: True. - neg : bool - Show negative axis. Default: True. - orientation : OrientationType - Orientation of the colorbar, can be OrientationType.HORIZONTAL or - OrientationType.VERTICAL. Default : OrientationType.HORIZONTAL. - - Returns - ------- - positions: dict - Positions of all labels. - """ - positions = {} - cb_x, cb_y, cb_width, cb_height = colorbar_rect - cb_labels_gap = 5 - - if orientation == OrientationType.HORIZONTAL: - label_y = cb_y + cb_height + cb_labels_gap - - # Upper - w, h = text_size(labels["upper"], font) - if pos: - positions["upper"] = (cb_x + cb_width - w, label_y) - else: - upper_x = cb_x + cb_width - w - int(gapspace) if gapspace > 0 else cb_x + cb_width - w - positions["upper"] = (upper_x, label_y) - - # Lower - w, h = text_size(labels["lower"], font) - if neg: - positions["lower"] = (cb_x, label_y) - else: - lower_x = cb_x + int(gapspace) if gapspace > 0 else cb_x - positions["lower"] = (lower_x, label_y) - - # Middle - if neg and pos: - if gapspace == 0: - # Single middle - w, h = text_size(labels["middle"], font) - positions["middle"] = (cb_x + cb_width // 2 - w // 2, label_y) - else: - # Middle Negative - w, h = text_size(labels["middle_neg"], font) - positions["middle_neg"] = (cb_x + cb_width // 2 - w - int(gapspace), label_y) - - # Middle Positive - w, h = text_size(labels["middle_pos"], font) - positions["middle_pos"] = (cb_x + cb_width // 2 + int(gapspace), label_y) - - else: # orientation == OrientationType.VERTICAL - label_x = cb_x + cb_width + cb_labels_gap - - # Upper - w, h = text_size(labels["upper"], font) - if pos: - positions["upper"] = (label_x, cb_y) - else: - upper_y = cb_y + int(gapspace) if gapspace > 0 else cb_y - positions["upper"] = (label_x, upper_y) - - # Lower - w, h = text_size(labels["lower"], font) - if neg: - positions["lower"] = (label_x, cb_y + cb_height - 1.5 * h) - else: - lower_y = cb_y + cb_height - int(gapspace) - 1.5 * h if gapspace > 0 else cb_y + cb_height - 1.5 * h - positions["lower"] = (label_x, lower_y) - - # Middle labels - if neg and pos: - if gapspace == 0: - # Single middle - w, h = text_size(labels["middle"], font) - positions["middle"] = (label_x, cb_y + cb_height // 2 - h // 2) - else: - # Middle Positive - w, h = text_size(labels["middle_pos"], font) - positions["middle_pos"] = (label_x, cb_y + cb_height // 2 - 1.5 * h - int(gapspace)) - - # Middle Negative - w, h = text_size(labels["middle_neg"], font) - positions["middle_neg"] = (label_x, cb_y + cb_height // 2 + int(gapspace)) - - return positions - -def create_colorbar( - fmin, - fmax, - invert, - orientation=OrientationType.HORIZONTAL, - colorbar_scale=1, - pos=True, - neg=True, - font_file=None -): - """ - Create colorbar image with text indicating min and max values. - - Parameters - ---------- - fmin : int - Absolute min value that receives color (threshold). - fmax : int - Absolute max value where color saturates. - invert : bool - Color invert. - orientation : OrientationType - Orientation of the colorbar, can be OrientationType.HORIZONTAL or - OrientationType.VERTICAL. Default : OrientationType.HORIZONTAL. - colorbar_scale : number - Colorbar scaling factor. Default: 1. - pos : bool - Show positive axis. - neg : bool - Show negative axis. - font_file : str - Path to the file describing the font to be used. - - Returns - ------- - image: PIL.Image.Image - Colorbar image. - """ - cwidth = int(200 * colorbar_scale) - cheight = int(30 * colorbar_scale) - gapspace = 0 - - # Add gray gap if needed - if fmin > 0.01: - # Leave gray gap - num = int(0.42 * cwidth) - gapspace = 0.08 * cwidth - else: - num = int(0.5 * cwidth) - if not neg or not pos: - num = num * 2 - gapspace = gapspace * 2 - - # Set the values for the colorbar - values = np.nan * np.ones(cwidth) - steps = np.linspace(0.01, 1, num) - if pos and not neg: - values[-steps.size :] = steps - elif not pos and neg: - values[: steps.size] = -1.0 * np.flip(steps) - else: - values[: steps.size] = -1.0 * np.flip(steps) - values[-steps.size :] = steps - - # Set the colors - colors = heat_color(values, invert) - colors[np.isnan(values), :] = 0.33 * np.ones((1, 3)) - img_bar = np.uint8(np.tile(colors, (cheight, 1, 1)) * 255) - - # Pad with black - pad_top, pad_left = 3, 10 - img_buf = np.zeros((cheight + 2 * pad_top, cwidth + 2 * pad_left, 3), dtype=np.uint8) - img_buf[pad_top : cheight + pad_top, pad_left : cwidth + pad_left, :] = img_bar - image = Image.fromarray(img_buf) - - # Get the font for the labels - if font_file is None: - script_dir = "/".join(str(__file__).split("/")[:-1]) - font_file = os.path.join(script_dir, "Roboto-Regular.ttf") - font = ImageFont.truetype(font_file, int(12 * colorbar_scale)) - - # Labels for the colorbar - labels = {} - labels["upper"] = f">{fmax:.2f}" if pos else (f"{-fmin:.2f}" if gapspace != 0 else "0") - labels["lower"] = f"<{-fmax:.2f}" if neg else (f"{fmin:.2f}" if gapspace != 0 else "0") - if neg and pos and gapspace != 0: - labels["middle_neg"] = f"{-fmin:.2f}" - labels["middle_pos"] = f"{fmin:.2f}" - elif neg and pos and gapspace == 0: - labels["middle"] = "0" - - # Maximum caption sizes - caption_sizes = [text_size(caption, font) for caption in labels.values()] - max_caption_width = int(max([caption_size[0] for caption_size in caption_sizes])) - max_caption_height = int(max([caption_size[1] for caption_size in caption_sizes])) - - # Extend colorbar image by the maximum caption size to fit the labels and rotate image if needed - if orientation == OrientationType.VERTICAL: - image = image.rotate(90, expand=True) - - new_width = image.width + int(max_caption_width) - new_image = Image.new("RGB", (new_width, image.height), (0, 0, 0)) - new_image.paste(image, (0, 0)) - image = new_image - - colorbar_rect = (pad_top, pad_left, cheight, cwidth) - else: - new_height = image.height + int(max_caption_height * 2) - new_image = Image.new("RGB", (image.width, new_height), (0, 0, 0)) - new_image.paste(image, (0, 0)) - image = new_image - - colorbar_rect = (pad_left, pad_top, cwidth, cheight) - - # Get positions of the labels - positions = get_colorbar_label_positions(font, labels, colorbar_rect, gapspace, pos, neg, orientation) - - # Draw the labels - draw = ImageDraw.Draw(image) - for label_key, position in positions.items(): - draw.text((int(position[0]), int(position[1])), labels[label_key], fill=(220, 220, 220), font=font) - - return image - -def snap1( - meshpath, - outpath, - overlaypath=None, - annotpath=None, - labelpath=None, - curvpath=None, - view=ViewType.LEFT, - viewmat=None, - width=None, - height=None, - fthresh=None, - fmax=None, - caption=None, - caption_x=None, - caption_y=None, - caption_scale=1, - invert=False, - colorbar=True, - colorbar_x=None, - colorbar_y=None, - colorbar_scale=1, - orientation=OrientationType.HORIZONTAL, - color_mode=ColorSelection.BOTH, - font_file=None, - specular=True, - brain_scale=1, - ambient=0.0, -): - """ - Snap one view (view and hemisphere is determined by the user). - - Colorbar, caption, and saving are optional. - - Parameters - ---------- - meshpath : str - Path to the surface file (FreeSurfer format). - outpath : str - Path to the output image file. - overlaypath : str - Path to the overlay file (FreeSurfer format). - annotpath : str - Path to the annotation file (FreeSurfer format). - labelpath : str - Path to the label file (FreeSurfer format). - curvpath : str - Path to the curvature file for texture in non-colored regions. - view : ViewType - Predefined views, can be ViewType.LEFT, ViewType.RIGHT, ViewType.BACK, - ViewType.FRONT, ViewType.TOP or ViewType.BOTTOM. Default: ViewType.LEFT. - viewmat : array-like - User-defined 4x4 viewing matrix. Overwrites view. - width : number - Width of the image. Default: automatically chosen. - height : number - Height of the image. Default: automatically chosen. - fthresh : float - Pos absolute value under which no color is shown. - fmax : float - Pos absolute value above which color is saturated. - caption : str - Caption text to be placed on the image. - caption_x : number - Normalized horizontal position of the caption. Default: automatically chosen. - caption_y : number - Normalized vertical position of the caption. Default: automatically chosen. - caption_scale : number - Caption scaling factor. Default: 1. - invert : bool - Invert color (blue positive, red negative). - colorbar : bool - Show colorbar on image. - colorbar_x : number - Normalized horizontal position of the colorbar. Default: automatically chosen. - colorbar_y : number - Normalized vertical position of the colorbar. Default: automatically chosen. - colorbar_scale : number - Colorbar scaling factor. Default: 1. - orientation : OrientationType - Orientation of the colorbar and caption, can be OrientationType.VERTICAL or - OrientationType.HORIZONTAL. Default: OrientationType.HORIZONTAL. - color_mode : ColorSelection - Select which values to color, can be ColorSelection.BOTH, ColorSelection.POSITIVE - or ColorSelection.NEGATIVE. Default: ColorSelection.BOTH. - font_file : str - Path to the file describing the font to be used in captions. - specular : bool - Specular is by default set as True. - brain_scale : float - Brain scaling factor. Default: 1. - ambient : float - Ambient light, default 0, only use diffuse light sources. - - Returns - ------- - None - This function returns None. - """ - # Setup base image - REFWWIDTH = 700 - REFWHEIGHT = 500 - WWIDTH = REFWWIDTH if width is None else width - WHEIGHT = REFWHEIGHT if height is None else height - UI_SCALE = min(WWIDTH / REFWWIDTH, WHEIGHT / REFWHEIGHT) - - # Check screen resolution - if not glfw.init(): - print( - "[ERROR] Could not init glfw!" - ) - sys.exit(1) - primary_monitor = glfw.get_primary_monitor() - mode = glfw.get_video_mode(primary_monitor) - screen_width = mode.size.width - screen_height = mode.size.height - if WWIDTH > screen_width: - print( - f"[INFO] Requested width {WWIDTH} exceeds screen width {screen_width}, expect black bars" - ) - elif WHEIGHT > screen_height: - print( - f"[INFO] Requested height {WHEIGHT} exceeds screen height {screen_height}, expect black bars" - ) - - # Create the base image - image = Image.new("RGB", (WWIDTH, WHEIGHT)) - - # Setup brain image - # (keep aspect ratio, as the mesh scale and distances are set accordingly) - BWIDTH = int(540 * brain_scale * UI_SCALE) - BHEIGHT = int(450 * brain_scale * UI_SCALE) - brain_display_width = min(BWIDTH, WWIDTH) - brain_display_height = min(BHEIGHT, WHEIGHT) - - visible = True - window = init_window(brain_display_width, brain_display_height, "WhipperSnapPy 2.0", visible) - if not window: - return False # need raise error here in future - - viewLeft = np.array([[ 0, 0,-1, 0], [-1, 0, 0, 0], [ 0, 1, 0, 0], [ 0, 0, 0, 1]]) # left w top up // right - viewRight = np.array([[ 0, 0, 1, 0], [ 1, 0, 0, 0], [ 0, 1, 0, 0], [ 0, 0, 0, 1]]) # right w top up // right - viewBack = np.array([[ 1, 0, 0, 0], [ 0, 0,-1, 0], [ 0, 1, 0, 0], [ 0, 0, 0, 1]]) # back w top up // back - viewFront = np.array([[-1, 0 ,0, 0], [ 0, 0, 1, 0], [ 0, 1, 0, 0], [ 0, 0, 0, 1]]) # front w top up // front - viewBottom = np.array([[-1, 0, 0, 0], [ 0, 1, 0, 0], [ 0, 0,-1, 0], [ 0, 0, 0, 1]]) # bottom ant up // bottom - viewTop = np.array([[ 1, 0, 0, 0], [ 0, 1, 0, 0], [ 0, 0, 1, 0], [ 0, 0, 0, 1]]) # top w ant up // top - - transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) - - # Load and colorize data - meshdata, triangles, fthresh, fmax, pos, neg = prepare_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, - scale=brain_scale, color_mode=color_mode - ) - - # Check if there is data to display - if overlaypath is not None: - if color_mode == ColorSelection.POSITIVE: - if not pos and neg: - print( - "[Error] Overlay has no values to display with positive color_mode" - ) - sys.exit(1) - neg = False - elif color_mode == ColorSelection.NEGATIVE: - if pos and not neg: - print( - "[Error] Overlay has no values to display with negative color_mode" - ) - sys.exit(1) - pos = False - if not pos and not neg: - print( - "[Error] Overlay has no values to display" - ) - sys.exit(1) - - # Upload to GPU and compile shaders - shader = setup_shader(meshdata, triangles, brain_display_width, brain_display_height, - specular=specular, ambient=ambient) - - # Draw - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - transformLoc = gl.glGetUniformLocation(shader, "transform") - if viewmat is None: - if view == ViewType.LEFT: - viewmat = transl * viewLeft - elif view == ViewType.RIGHT: - viewmat = transl * viewRight - elif view == ViewType.BACK: - viewmat = transl * viewBack - elif view == ViewType.FRONT: - viewmat = transl * viewFront - elif view == ViewType.BOTTOM: - viewmat = transl * viewBottom - elif view == ViewType.TOP: - viewmat = transl * viewTop - else: - viewmat = transl * viewmat - - gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, viewmat) - gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - - im1 = capture_window(brain_display_width, brain_display_height) - - # Center brain - brain_x = 0 if WWIDTH < BWIDTH else (WWIDTH - BWIDTH) // 2 - brain_y = 0 if WHEIGHT < BHEIGHT else (WHEIGHT - BHEIGHT) // 2 - - image.paste(im1, (brain_x, brain_y)) - - # Create colorbar - bar = None - bar_w = bar_h = 0 - if overlaypath is not None and colorbar: - bar = create_colorbar(fthresh, fmax, invert, orientation, colorbar_scale * UI_SCALE, - pos, neg, font_file=font_file) - bar_w, bar_h = bar.size - - # Create caption - font = None - text_w = text_h = 0 - if caption: - if font_file is None: - script_dir = "/".join(str(__file__).split("/")[:-1]) - font_file = os.path.join(script_dir, "Roboto-Regular.ttf") - font = ImageFont.truetype(font_file, 20 * caption_scale * UI_SCALE) - text_w, text_h = text_size(caption, font) - - text_w = int(text_w) - text_h = int(text_h) - - # Constants defining the position of the caption and colorbar - BOTTOM_PAD = int(20 * UI_SCALE) - RIGHT_PAD = int(20 * UI_SCALE) - GAP = int(4 * UI_SCALE) - - if orientation == OrientationType.HORIZONTAL: - # Place the colorbar - if bar is not None: - if colorbar_x is None: - bx = int(0.5 * (image.width - bar_w)) - else: - bx = int(colorbar_x * WWIDTH) - if colorbar_y is None: - gap_and_caption = (GAP + text_h) if caption and caption_y is None else 0 - by = image.height - BOTTOM_PAD - gap_and_caption - bar_h - else: - by = int(colorbar_y * WHEIGHT) - image.paste(bar, (bx, by)) - - # Place the caption - if caption: - if caption_x is None: - cx = int(0.5 * (image.width - text_w)) - else: - cx = int(caption_x * WWIDTH) - if caption_y is None: - cy = image.height - BOTTOM_PAD - text_h - else: - cy = int(caption_y * WHEIGHT) - ImageDraw.Draw(image).text( - (cx, cy), caption, (220, 220, 220), font=font, anchor="lt" - ) - else: # orientation == OrientationType.VERTICAL - # Place the colorbar - if bar is not None: - if colorbar_x is None: - gap_and_caption = (GAP + text_h) if caption and caption_x is None else 0 - bx = image.width - RIGHT_PAD - gap_and_caption - bar_w - else: - bx = int(colorbar_x * WWIDTH) - if colorbar_y is None: - by = int(0.5 * (image.height - bar_h)) - else: - by = int(colorbar_y * WHEIGHT) - image.paste(bar, (bx, by)) - - # Place the caption - if caption: - # Create a new transparent image and rotate it - temp_caption_img = Image.new("RGBA", (text_w, text_h), (0,0,0,0)) - ImageDraw.Draw(temp_caption_img).text((0, 0), caption, font=font, anchor="lt") - rotated_caption = temp_caption_img.rotate(90, expand=True, fillcolor=(0,0,0,0)) - rotated_w, rotated_h = rotated_caption.size - - if caption_x is None: - cx = image.width - RIGHT_PAD - rotated_w - else: - cx = int(caption_x * WWIDTH) - if caption_y is None: - cy = int(0.5 * (image.height - rotated_h)) - else: - cy = int(caption_y * WHEIGHT) - - image.paste(rotated_caption, (cx, cy), rotated_caption) - - # save image - print(f"[INFO] Saving snapshot to {outpath}") - image.save(outpath) - - glfw.terminate() - - return None - -def snap4( - lhoverlaypath=None, - rhoverlaypath=None, - lhannotpath=None, - rhannotpath=None, - fthresh=None, - fmax=None, - sdir=None, - caption=None, - invert=False, - labelname="cortex.label", - surfname=None, - curvname="curv", - colorbar=True, - outpath=None, - font_file=None, - specular=True, - ambient=0.0, -): - """ - Snap four views (front and back for left and right hemispheres). - - Save an image that includes the views and a color bar. - - Parameters - ---------- - lhoverlaypath : str - Path to the overlay files for left hemi (FreeSurfer format). - rhoverlaypath : str - Path to the overlay files for right hemi (FreeSurfer format). - lhannotpath : str - Path to the annotation files for left hemi (FreeSurfer format). - rhannotpath : str - Path to the annotation files for right hemi (FreeSurfer format). - fthresh : float - Pos absolute value under which no color is shown. - fmax : float - Pos absolute value above which color is saturated. - sdir : str - Subject dir containing surf files. - caption : str - Caption text to be placed on the image. - invert : bool - Invert color (blue positive, red negative). - labelname : str - Label for masking, usually cortex.label. - surfname : str - Surface to display values on, usually pial_semi_inflated from fsaverage. - curvname : str - Curvature file for texture in non-colored regions (default curv). - colorbar : bool - Show colorbar on image. Will be ignored for annotation files. - outpath : str - Path to the output image file. - font_file : str - Path to the file describing the font to be used in captions. - specular : bool - Specular is by default set as True. - ambient : float - Ambient light, default 0, only use diffuse light sources. - - Returns - ------- - None - This function returns None. - """ - # setup window - # (keep aspect ratio, as the mesh scale and distances are set accordingly) - wwidth = 540 - wheight = 450 - visible = True - window = init_window(wwidth, wheight, "WhipperSnapPy 2.0", visible) - if not window: - return False # need raise error here in future - - # set up matrices to show object left and right side: - rot_z = pyrr.Matrix44.from_z_rotation(-0.5 * math.pi) - rot_x = pyrr.Matrix44.from_x_rotation(0.5 * math.pi) - # rot_y = pyrr.Matrix44.from_y_rotation(math.pi/6) - viewLeft = rot_x * rot_z - rot_y = pyrr.Matrix44.from_y_rotation(math.pi) - viewRight = rot_y * viewLeft - transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) - - for hemi in ("lh", "rh"): - if surfname is None: - print( - "[INFO] No surf_name provided. Looking for options in surf directory..." - ) - - if sdir is None: - sdir = os.environ.get("SUBJECTS_DIR") - if not sdir: - print( - "[INFO] No surf_name or subjects directory (sdir) \ -provided, can not find surf file" - ) - sys.exit(1) - - found_surfname = get_surf_name(sdir, hemi) - - if found_surfname is None: - print( - f"[ERROR] Could not find valid surface in {sdir} for hemi: {hemi}!" - ) - sys.exit(1) - meshpath = os.path.join(sdir, "surf", hemi + "." + found_surfname) - else: - meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) - - curvpath = None - if curvname: - curvpath = os.path.join(sdir, "surf", hemi + "." + curvname) - labelpath = None - if labelname: - labelpath = os.path.join(sdir, "label", hemi + "." + labelname) - if hemi == "lh": - overlaypath = lhoverlaypath - annotpath = lhannotpath - else: - overlaypath = rhoverlaypath - annotpath = rhannotpath - - # load and colorize data - meshdata, triangles, fthresh, fmax, pos, neg = prepare_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert - ) - - # Check if there is something to display - if pos == 0 and neg == 0: - print( - "[Error] Overlay has no values to display" - ) - sys.exit(1) - - # upload to GPU and compile shaders - shader = setup_shader(meshdata, triangles, wwidth, wheight, - specular=specular, ambient=ambient) - - # draw - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - transformLoc = gl.glGetUniformLocation(shader, "transform") - viewmat = viewLeft - if hemi == "lh": - viewmat = transl * viewmat - gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, viewmat) - gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - - im1 = capture_window(wwidth, wheight) - - glfw.swap_buffers(window) - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - viewmat = viewRight - if hemi == "rh": - viewmat = transl * viewmat - gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, viewmat) - gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - - im2 = capture_window(wwidth, wheight) - - if hemi == "lh": - lhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) - lhimg.paste(im1, (0, 0)) - lhimg.paste(im2, (0, im1.height)) - else: - rhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) - rhimg.paste(im2, (0, 0)) - rhimg.paste(im1, (0, im2.height)) - - image = Image.new("RGB", (lhimg.width + rhimg.width, lhimg.height)) - image.paste(lhimg, (0, 0)) - image.paste(rhimg, (im1.width, 0)) - - if caption: - if font_file is None: - script_dir = "/".join(str(__file__).split("/")[:-1]) - font_file = os.path.join(script_dir, "Roboto-Regular.ttf") - font = ImageFont.truetype(font_file, 20) - xpos = 0.5 * (image.width - font.getlength(caption)) - ImageDraw.Draw(image).text( - (xpos, image.height - 40), caption, (220, 220, 220), font=font - ) - - if lhannotpath is None and rhannotpath is None and colorbar: - bar = create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) - xpos = int(0.5 * (image.width - bar.width)) - ypos = int(0.5 * (image.height - bar.height)) - image.paste(bar, (xpos, ypos)) - - if outpath: - print(f"[INFO] Saving snapshot to {outpath}") - image.save(outpath) - - glfw.terminate() - - return None - -def get_surf_name(sdir, hemi): - """ - Find a valid surface file in the specified subject directory. - - A valid file can be one of: ['pial_semi_inflated', 'white', 'inflated']. - - Parameters - ---------- - sdir : str - Subject directory. - hemi : str - Hemisphere; one of: ['lh', 'rh']. - - Returns - ------- - surfname: str - Valid and existing surf file's name; otherwise, None. - """ - for surf_name_option in ["pial_semi_inflated", "white", "inflated"]: - if os.path.exists(os.path.join(sdir, "surf", hemi + "." + surf_name_option)): - print("[INFO] Found {}".format(hemi + "." + surf_name_option)) - return surf_name_option - else: - print("[INFO] No {} file found".format(hemi + "." + surf_name_option)) - else: - return None diff --git a/whippersnappy/geometry/__init__.py b/whippersnappy/geometry/__init__.py new file mode 100644 index 0000000..b7ac525 --- /dev/null +++ b/whippersnappy/geometry/__init__.py @@ -0,0 +1,11 @@ +"""Geometry subpackage exports. + +Expose prepare_geometry and small IO helpers under `whippersnappy.geometry`. +""" +from .prepare import prepare_geometry +from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data +from .surf_name import get_surf_name + +__all__ = [ + 'prepare_geometry', 'read_geometry', 'read_annot_data', 'read_mgh_data', 'read_morph_data', 'get_surf_name' +] diff --git a/whippersnappy/geometry/prepare.py b/whippersnappy/geometry/prepare.py new file mode 100644 index 0000000..6035151 --- /dev/null +++ b/whippersnappy/geometry/prepare.py @@ -0,0 +1,158 @@ +"""Geometry helpers for mesh processing and GPU preparation (prepare.py). + +This module contains the primary `prepare_geometry` function used to +normalize meshes, compute normals and assemble vertex arrays for OpenGL. +""" + +import os +import warnings + +import numpy as np + +from whippersnappy.geometry.read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data +from whippersnappy.utils.colormap import binary_color, heat_color, mask_label, mask_sign, rescale_overlay +from whippersnappy.utils.types import ColorSelection + + +def normalize_mesh(v, scale=1.0): + """Normalize mesh vertex coordinates.""" + bbmax = np.max(v, axis=0) + bbmin = np.min(v, axis=0) + v = v - 0.5 * (bbmax + bbmin) + v = scale * v / np.max(bbmax - bbmin) + return v + + +def vertex_normals(v, t): + """Compute vertex normals.""" + v0 = v[t[:, 0], :] + v1 = v[t[:, 1], :] + v2 = v[t[:, 2], :] + v1mv0 = v1 - v0 + v2mv1 = v2 - v1 + v0mv2 = v0 - v2 + cr0 = np.cross(v1mv0, -v0mv2) + cr1 = np.cross(v2mv1, -v1mv0) + cr2 = np.cross(v0mv2, -v2mv1) + n = np.zeros(v.shape) + np.add.at(n, t[:, 0], cr0) + np.add.at(n, t[:, 1], cr1) + np.add.at(n, t[:, 2], cr2) + ln = np.sqrt(np.sum(n * n, axis=1)) + ln[ln < np.finfo(float).eps] = 1 + n = n / ln.reshape(-1, 1) + return n + + +def prepare_geometry( + surfpath, + overlaypath=None, + annotpath=None, + curvpath=None, + labelpath=None, + minval=None, + maxval=None, + invert=False, + scale=1.85, + color_mode=ColorSelection.BOTH, +): + """Prepare meshdata for upload to GPU.""" + surf = read_geometry(surfpath, read_metadata=False) + vertices = normalize_mesh(np.array(surf[0], dtype=np.float32), scale) + triangles = np.array(surf[1], dtype=np.uint32) + vnormals = np.array(vertex_normals(vertices, triangles), dtype=np.float32) + num_vertices = vertices.shape[0] + + # try to load sulcal colormap + sulcmap = 0.5 * np.ones(vertices.shape, dtype=np.float32) + if curvpath: + curv = read_morph_data(curvpath) + if curv.shape[0] != num_vertices: + warnings.warn(f"Curvature file {curvpath} has {curv.shape[0]} values, but mesh has {num_vertices}.") + else: + sulcmap = binary_color(curv, 0.0, color_low=0.5, color_high=0.33) + + # Initialize defaults for overlay outputs + fmin = None + fmax = None + pos = None + neg = None + colors = sulcmap # use as default + + # try to load overlay data + if overlaypath: + _, file_extension = os.path.splitext(overlaypath) + if file_extension == ".mgh": + mapdata = read_mgh_data(overlaypath) + else: + mapdata = read_morph_data(overlaypath) + + # Check if overlay length matches number of vertices. If not, raise an error. + if mapdata.shape[0] != num_vertices: + raise ValueError( + f"Overlay file {overlaypath} has {mapdata.shape[0]} values but mesh has {num_vertices}.\n" + "This usually means the overlay does not match the provided surface " + "(e.g. RH overlay used with LH surface). Provide the correct overlay " + "file." + ) + else: + valabs = np.abs(mapdata) + if maxval is None: + maxval = np.max(valabs) if np.any(valabs) else 0 + if minval is None: + minval = max(0.0, np.min(valabs) if np.any(valabs) else 0) + + mapdata = mask_sign(mapdata, color_mode) + mapdata, fmin, fmax, pos, neg = rescale_overlay(mapdata, minval, maxval) + colors = heat_color(mapdata, invert) + # some mapdata values could be nan (below min threshold) + missing = np.isnan(mapdata) + if np.any(missing): + colors[missing, :] = sulcmap[missing, :] + # alternatively try to load annotation data + elif annotpath: + # Read annotation (per-vertex labels) and colormap table. + annot, ctab, _ = read_annot_data(annotpath) + + # Check if annotation length matches number of vertices. If not, raise an error. + if annot.shape[0] != num_vertices: + raise ValueError( + f"Annotation file {annotpath} has {annot.shape[0]} values but mesh has {num_vertices}.\n" + "This usually means the .annot does not match the provided surface " + "(e.g. RH annot used with LH surface). Provide the correct annot " + "file." + ) + else: + # If annot is shorter, pad with -1 (meaning 'no label') to match + # mesh vertices. + if annot.shape[0] < num_vertices: + pad_len = num_vertices - annot.shape[0] + annot = np.pad(annot, (0, pad_len), mode="constant", constant_values=-1) + + # Ensure integer type for safe indexing + annot = annot.astype(np.int32) + + # Start with sulcmap as the default and only overwrite valid label indices + colors = np.array(sulcmap, dtype=np.float32) + + # Normalize colortable: detect whether ctab is 0-255 or 0-1 + ctab_rgb = ctab[:, 0:3].astype(np.float32) + denom = 255.0 if np.max(ctab_rgb) > 1 else 1.0 + + # Only assign colors for valid annotation indices (>=0 and within the color table) + valid = (annot >= 0) & (annot < ctab.shape[0]) + if np.any(valid): + colors[valid, :] = ctab_rgb[annot[valid], :] / denom + + # Ensure colors dtype matches vertices/normals + colors = np.asarray(colors, dtype=np.float32) + + # Apply label mask to colors if labelpath is provided, + # regardless of whether overlay or annot data was loaded + if labelpath: + mask = np.isnan(mask_label(np.ones(num_vertices), labelpath)) + if np.any(mask): + colors[mask, :] = sulcmap[mask, :] + + vertexdata = np.concatenate((vertices, vnormals, colors), axis=1) + return vertexdata, triangles, fmin, fmax, pos, neg diff --git a/whippersnappy/read_geometry.py b/whippersnappy/geometry/read_geometry.py similarity index 100% rename from whippersnappy/read_geometry.py rename to whippersnappy/geometry/read_geometry.py diff --git a/whippersnappy/geometry/surf_name.py b/whippersnappy/geometry/surf_name.py new file mode 100644 index 0000000..69e517e --- /dev/null +++ b/whippersnappy/geometry/surf_name.py @@ -0,0 +1,19 @@ +"""Helper for finding a surface file name inside a subject directory. + +This replaces the previous `io.py` name which was generic; `surf_name.py` is +more descriptive (it provides `get_surf_name`). +""" +import os + + +def get_surf_name(sdir, hemi): + """Find a valid surface file in the specified subject directory. + + Returns the surface basename (e.g. 'white', 'inflated', etc.) or None. + """ + for surf_name_option in ["pial_semi_inflated", "white", "inflated"]: + path = os.path.join(sdir, "surf", f"{hemi}.{surf_name_option}") + if os.path.exists(path): + return surf_name_option + return None + diff --git a/whippersnappy/gl/__init__.py b/whippersnappy/gl/__init__.py new file mode 100644 index 0000000..f2894dc --- /dev/null +++ b/whippersnappy/gl/__init__.py @@ -0,0 +1,33 @@ +"""OpenGL helper utilities (gl package). + +This package replaces the previous `gl_utils.py` module. +Functions are re-exported at package level for convenience, e.g.: + + from whippersnappy.gl import init_window, setup_shader + +""" + +from .camera import make_model, make_projection, make_view +from .shaders import get_default_shaders, get_webgl_shaders +from .utils import ( + capture_window, + compile_shader_program, + create_vao, + init_window, + set_camera_uniforms, + set_default_gl_state, + set_lighting_uniforms, + setup_buffers, + setup_shader, + setup_vertex_attributes, +) +from .views import get_view_matrices, get_view_matrix + +__all__ = [ + 'create_vao', 'compile_shader_program', 'setup_buffers', 'setup_vertex_attributes', + 'set_default_gl_state', 'set_camera_uniforms', 'set_lighting_uniforms', + 'init_window', 'setup_shader', 'capture_window', + 'make_model', 'make_projection', 'make_view', + 'get_default_shaders', 'get_view_matrices', 'get_view_matrix', + 'get_webgl_shaders' +] diff --git a/whippersnappy/gl/camera.py b/whippersnappy/gl/camera.py new file mode 100644 index 0000000..089dd52 --- /dev/null +++ b/whippersnappy/gl/camera.py @@ -0,0 +1,29 @@ +"""Camera and transform helpers (moved under gl package).""" + +import pyrr + + +def make_projection(width, height, fov=20.0, near=0.1, far=100.0): + """Create a perspective projection matrix.""" + return pyrr.matrix44.create_perspective_projection(fov, width / height, near, far) + + +def make_view(camera_pos=(0.0, 0.0, -5.0)): + """Create a view matrix from a camera position.""" + return pyrr.matrix44.create_from_translation(pyrr.Vector3(camera_pos)) + + +def make_model(): + """Create a default model matrix.""" + return pyrr.matrix44.create_from_translation(pyrr.Vector3([0.0, 0.0, 0.0])) + + +def make_transform(translation, rotation, scale): + """Create a model transform from translation, rotation and scale.""" + scale_matrix = pyrr.matrix44.create_from_scale([scale, scale, scale]) + return ( + pyrr.matrix44.create_from_translation(translation) + * rotation + * scale_matrix + ) + diff --git a/whippersnappy/gl/shaders.py b/whippersnappy/gl/shaders.py new file mode 100644 index 0000000..acdc9ed --- /dev/null +++ b/whippersnappy/gl/shaders.py @@ -0,0 +1,194 @@ +"""Shared shader sources inside the gl package.""" + +def get_default_shaders(): + """Return the default vertex and fragment shader sources (GLSL 330).""" + vertex_shader = """ + + #version 330 + + layout (location = 0) in vec3 aPos; + layout (location = 1) in vec3 aNormal; + layout (location = 2) in vec3 aColor; + + out vec3 FragPos; + out vec3 Normal; + out vec3 Color; + + uniform mat4 transform; + uniform mat4 model; + uniform mat4 view; + uniform mat4 projection; + + void main() + { + gl_Position = projection * view * model * transform * vec4(aPos, 1.0f); + FragPos = vec3(model * transform * vec4(aPos, 1.0)); + // normal matrix should be computed outside and passed! + Normal = mat3(transpose(inverse(view * model * transform))) * aNormal; + Color = aColor; + } + + """ + + fragment_shader = """ + #version 330 + + in vec3 Normal; + in vec3 FragPos; + in vec3 Color; + + out vec4 FragColor; + + uniform vec3 lightColor = vec3(1.0, 1.0, 1.0); + uniform bool doSpecular = true; + uniform float ambientStrength = 0.0; + + void main() + { + // ambient + vec3 ambient = ambientStrength * lightColor; + + // diffuse + vec3 norm = normalize(Normal); + vec4 diffweights = vec4(0.6, 0.4, 0.4, 0.3); + + // key light (overhead) + vec3 lightPos1 = vec3(0.0,5.0,5.0); + vec3 lightDir = normalize(lightPos1 - FragPos); + float diff = max(dot(norm, lightDir), 0.0); + vec3 diffuse = diffweights[0] * diff * lightColor; + + // headlight (at camera) + vec3 lightPos2 = vec3(0.0,0.0,5.0); + lightDir = normalize(lightPos2 - FragPos); + vec3 ohlightDir = lightDir; + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[1] * diff * lightColor; + + // fill light (from below) + vec3 lightPos3 = vec3(0.0,-5.0,5.0); + lightDir = normalize(lightPos3 - FragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[2] * diff * lightColor; + + // left right back lights + vec3 lightPos4 = vec3(5.0,0.0,-5.0); + lightDir = normalize(lightPos4 - FragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[3] * diff * lightColor; + + vec3 lightPos5 = vec3(-5.0,0.0,-5.0); + lightDir = normalize(lightPos5 - FragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[3] * diff * lightColor; + + // specular + vec3 result; + if (doSpecular) + { + float specularStrength = 0.5; + vec3 viewDir = normalize(-FragPos); + vec3 reflectDir = reflect(ohlightDir, norm); + float spec = pow(max(dot(viewDir, reflectDir), 0.0), 32); + vec3 specular = specularStrength * spec * lightColor; + result = (ambient + diffuse + specular) * Color; + } + else + { + result = (ambient + diffuse) * Color; + } + FragColor = vec4(result, 1.0); + } + + """ + + return vertex_shader, fragment_shader + + +def get_webgl_shaders(): + """Return the default vertex and fragment shader sources (GLSL 330).""" + + # Only declare custom attributes - Three.js provides position, normal, matrices + # Don't declare position, normal, *Matrix + # Only attributes like color , or uniforms like lightColor, ambientStrenght + # Use normalMatrix instead of computing transpose... + vertex_shader = """ + attribute vec3 color; + + varying vec3 vFragPos; + varying vec3 vNormal; + varying vec3 vColor; + + void main() + { + gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0); + vFragPos = vec3(modelViewMatrix * vec4(position, 1.0)); + vNormal = normalMatrix * normal; + vColor = color; + } + """ + + fragment_shader = """ + precision highp float; + + varying vec3 vNormal; + varying vec3 vFragPos; + varying vec3 vColor; + + uniform vec3 lightColor; + uniform float ambientStrength; + + void main() + { + // ambient + vec3 ambient = ambientStrength * lightColor; + + // diffuse + vec3 norm = normalize(vNormal); + vec4 diffweights = vec4(0.6, 0.4, 0.4, 0.3); + + // key light (overhead) + vec3 lightPos1 = vec3(0.0, 5.0, 5.0); + vec3 lightDir = normalize(lightPos1 - vFragPos); + float diff = max(dot(norm, lightDir), 0.0); + vec3 diffuse = diffweights[0] * diff * lightColor; + + // headlight (at camera) + vec3 lightPos2 = vec3(0.0, 0.0, 5.0); + lightDir = normalize(lightPos2 - vFragPos); + vec3 ohlightDir = lightDir; + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[1] * diff * lightColor; + + // fill light (from below) + vec3 lightPos3 = vec3(0.0, -5.0, 5.0); + lightDir = normalize(lightPos3 - vFragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[2] * diff * lightColor; + + // left right back lights + vec3 lightPos4 = vec3(5.0, 0.0, -5.0); + lightDir = normalize(lightPos4 - vFragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[3] * diff * lightColor; + + vec3 lightPos5 = vec3(-5.0, 0.0, -5.0); + lightDir = normalize(lightPos5 - vFragPos); + diff = max(dot(norm, lightDir), 0.0); + diffuse = diffuse + diffweights[3] * diff * lightColor; + + // specular + float specularStrength = 0.5; + vec3 viewDir = normalize(-vFragPos); + vec3 reflectDir = reflect(-ohlightDir, norm); + float spec = pow(max(dot(viewDir, reflectDir), 0.0), 32.0); + vec3 specular = specularStrength * spec * lightColor; + + vec3 result = (ambient + diffuse + specular) * vColor; + gl_FragColor = vec4(result, 1.0); + } + """ + + return vertex_shader, fragment_shader + + diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py new file mode 100644 index 0000000..1aa3b58 --- /dev/null +++ b/whippersnappy/gl/utils.py @@ -0,0 +1,174 @@ +"""GL helper utilities. + +Contains the implementation of OpenGL helpers used by the package. +""" + +import sys + +import glfw +import OpenGL.GL as gl +import OpenGL.GL.shaders as shaders +from PIL import Image + +from .camera import make_model, make_projection, make_view +from .shaders import get_default_shaders + + +def create_vao(): + """Create and bind a VAO, returning its handle.""" + vao = gl.glGenVertexArrays(1) + gl.glBindVertexArray(vao) + return vao + + +def compile_shader_program(vertex_src, fragment_src): + """Compile and link a shader program.""" + return gl.shaders.compileProgram( + shaders.compileShader(vertex_src, gl.GL_VERTEX_SHADER), + shaders.compileShader(fragment_src, gl.GL_FRAGMENT_SHADER), + ) + + +def setup_buffers(meshdata, triangles): + """Create VBO/EBO and upload mesh data.""" + vbo = gl.glGenBuffers(1) + gl.glBindBuffer(gl.GL_ARRAY_BUFFER, vbo) + gl.glBufferData(gl.GL_ARRAY_BUFFER, meshdata.nbytes, meshdata, gl.GL_STATIC_DRAW) + + ebo = gl.glGenBuffers(1) + gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, ebo) + gl.glBufferData( + gl.GL_ELEMENT_ARRAY_BUFFER, triangles.nbytes, triangles, gl.GL_STATIC_DRAW + ) + + return vbo, ebo + + +def setup_vertex_attributes(shader): + """Configure vertex attribute pointers for position, normal, color.""" + position = gl.glGetAttribLocation(shader, "aPos") + gl.glVertexAttribPointer( + position, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(0) + ) + gl.glEnableVertexAttribArray(position) + + vnormalpos = gl.glGetAttribLocation(shader, "aNormal") + gl.glVertexAttribPointer( + vnormalpos, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(3 * 4) + ) + gl.glEnableVertexAttribArray(vnormalpos) + + colorpos = gl.glGetAttribLocation(shader, "aColor") + gl.glVertexAttribPointer( + colorpos, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(6 * 4) + ) + gl.glEnableVertexAttribArray(colorpos) + + +def set_default_gl_state(): + """Apply common GL state for rendering.""" + gl.glClearColor(0.0, 0.0, 0.0, 1.0) + gl.glEnable(gl.GL_DEPTH_TEST) + + +def set_camera_uniforms(shader, view, projection, model): + """Set view/projection/model uniforms in the shader.""" + view_loc = gl.glGetUniformLocation(shader, "view") + proj_loc = gl.glGetUniformLocation(shader, "projection") + model_loc = gl.glGetUniformLocation(shader, "model") + gl.glUniformMatrix4fv(view_loc, 1, gl.GL_FALSE, view) + gl.glUniformMatrix4fv(proj_loc, 1, gl.GL_FALSE, projection) + gl.glUniformMatrix4fv(model_loc, 1, gl.GL_FALSE, model) + + +def set_lighting_uniforms(shader, specular=True, ambient=0.0, light_color=(1.0, 1.0, 1.0)): + """Set lighting uniforms in the shader.""" + specular_loc = gl.glGetUniformLocation(shader, "doSpecular") + gl.glUniform1i(specular_loc, specular) + + light_color_loc = gl.glGetUniformLocation(shader, "lightColor") + gl.glUniform3f(light_color_loc, *light_color) + + ambient_loc = gl.glGetUniformLocation(shader, "ambientStrength") + gl.glUniform1f(ambient_loc, ambient) + + +def init_window(width, height, title="PyOpenGL", visible=True): + """Create an OpenGL window (GLFW) and make its context current.""" + if not glfw.init(): + return False + + glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 3) + glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3) + glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, True) + glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE) + if not visible: + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + window = glfw.create_window(width, height, title, None, None) + if not window: + glfw.terminate() + return False + glfw.set_input_mode(window, glfw.STICKY_KEYS, gl.GL_TRUE) + glfw.make_context_current(window) + glfw.swap_interval(0) + return window + + +def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0): + """Create vertex and fragment shaders, set up VAO/VBO/EBO, and initialize camera/lighting uniforms. + + This function composes several low-level helpers in this module and returns the compiled shader program handle. + """ + vertex_shader, fragment_shader = get_default_shaders() + + create_vao() + shader = compile_shader_program(vertex_shader, fragment_shader) + setup_buffers(meshdata, triangles) + setup_vertex_attributes(shader) + + gl.glUseProgram(shader) + set_default_gl_state() + + view = make_view() + projection = make_projection(width, height) + model = make_model() + set_camera_uniforms(shader, view, projection, model) + set_lighting_uniforms(shader, specular=specular, ambient=ambient) + + return shader + + +def capture_window(width, height): + """Capture the current GL framebuffer region into a PIL Image (RGB). + + On macOS we adjust for the retina scaling factor by reading at double resolution and downsampling. + """ + if sys.platform == "darwin": + rwidth = 2 * width + rheight = 2 * height + else: + rwidth = width + rheight = height + + gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) + img_buf = gl.glReadPixels(0, 0, rwidth, rheight, gl.GL_RGB, gl.GL_UNSIGNED_BYTE) + image = Image.frombytes("RGB", (rwidth, rheight), img_buf) + image = image.transpose(Image.FLIP_TOP_BOTTOM) + if sys.platform == "darwin": + image.thumbnail((0.5 * rwidth, 0.5 * rheight), Image.Resampling.LANCZOS) + return image + + +__all__ = [ + "create_vao", + "compile_shader_program", + "setup_buffers", + "setup_vertex_attributes", + "set_default_gl_state", + "set_camera_uniforms", + "set_lighting_uniforms", + "init_window", + "setup_shader", + "capture_window", +] + diff --git a/whippersnappy/gl/views.py b/whippersnappy/gl/views.py new file mode 100644 index 0000000..477f655 --- /dev/null +++ b/whippersnappy/gl/views.py @@ -0,0 +1,29 @@ +"""View matrices and presets under gl package.""" + +import numpy as np + +from whippersnappy.utils.types import ViewType + + +def get_view_matrices(): + """Return canonical view matrices for left/right/front/back/top/bottom.""" + view_left = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) + view_right = np.array([[0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) + view_back = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) + view_front = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) + view_bottom = np.array([[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=np.float32) + view_top = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32) + + return { + ViewType.LEFT: view_left, + ViewType.RIGHT: view_right, + ViewType.BACK: view_back, + ViewType.FRONT: view_front, + ViewType.TOP: view_top, + ViewType.BOTTOM: view_bottom, + } + + +def get_view_matrix(view_type): + """Return a view matrix for a single view type.""" + return get_view_matrices()[view_type] diff --git a/whippersnappy/gui/__init__.py b/whippersnappy/gui/__init__.py new file mode 100644 index 0000000..6c5c9a9 --- /dev/null +++ b/whippersnappy/gui/__init__.py @@ -0,0 +1,4 @@ +from .config_app import ConfigWindow + +__all__ = ["ConfigWindow"] + diff --git a/whippersnappy/config_app.py b/whippersnappy/gui/config_app.py similarity index 100% rename from whippersnappy/config_app.py rename to whippersnappy/gui/config_app.py diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py new file mode 100644 index 0000000..789db48 --- /dev/null +++ b/whippersnappy/plot3d.py @@ -0,0 +1,194 @@ +"""3D plotting for WhipperSnapPy using pythreejs (Three.js) for Jupyter. + +This module provides interactive 3D brain visualization for Jupyter notebooks using +Three.js/WebGL. It works in all Jupyter environments (browser, JupyterLab, Colab, VS Code). + +Unlike the desktop GUI (launched with --interactive flag), this plots in the browser +using WebGL and is specifically designed for notebook environments. + +Usage: + from whippersnappy import plot3d + viewer = plot3d(meshpath='path/to/lh.white', curvpath='path/to/lh.curv') + display(viewer) + +Dependencies: + pythreejs, ipywidgets, numpy + +@Author : Martin Reuter +@Created : 14.02.2026 +""" + +import numpy as np +import pythreejs as p3js +from ipywidgets import HTML, VBox + +from whippersnappy.utils.types import ColorSelection + +from .geometry import prepare_geometry +from .gl import get_webgl_shaders + + +def plot3d( + meshpath, + overlaypath=None, + annotpath=None, + curvpath=None, + labelpath=None, + minval=None, + maxval=None, + invert=False, + scale=1.85, + color_mode=None, + width=800, + height=800, +): + """Create an interactive 3D notebook viewer using pythreejs (Three.js). + + This creates a browser-based interactive 3D viewer for Jupyter notebooks. + Works in all Jupyter environments (browser, JupyterLab, Colab, VS Code). + + Note: This is different from the desktop GUI (launched with --interactive flag). + + Parameters + ---------- + meshpath : str + Path to surface file + overlaypath : str, optional + Path to overlay file + annotpath : str, optional + Path to annotation file + curvpath : str, optional + Path to curvature file + labelpath : str, optional + Path to label file + minval : float, optional + Minimum threshold for coloring + maxval : float, optional + Maximum value for color saturation + invert : bool, default False + Invert color map + scale : float, default 1.85 + Global scaling factor + color_mode : ColorSelection, optional + Select which values to color + width : int, default 800 + Canvas width + height : int, default 800 + Canvas height + + Returns + ------- + viewer : ipywidgets.VBox + Interactive 3D viewer widget + + Examples + -------- + >>> from whippersnappy import plot3d + >>> viewer = plot3d('path/to/lh.white', curvpath='path/to/lh.curv') + >>> display(viewer) + """ + + # Load and prepare mesh data + color_mode = color_mode or ColorSelection.BOTH + meshdata, triangles, fmin, fmax, pos, neg = prepare_geometry( + meshpath, overlaypath, annotpath, curvpath, labelpath, + minval, maxval, invert, scale, color_mode + ) + + print(f"Loaded mesh: {meshdata.shape[0]} vertices, {triangles.shape[0]} faces") + + # Extract vertices, normals, and colors + vertices = meshdata[:, :3] # x, y, z + normals = meshdata[:, 3:6] # nx, ny, nz + colors = meshdata[:, 6:9] # r, g, b + + # Center and scale the mesh + center = vertices.mean(axis=0) + vertices = vertices - center + max_extent = np.abs(vertices).max() + vertices = vertices / max_extent * 2.0 + + # Create Three.js mesh + mesh = create_threejs_mesh_with_custom_shaders(vertices, triangles, colors, normals) + + camera = p3js.PerspectiveCamera( + position=[-5, 0, 0], + up=[0, 0, 1], + aspect=width/height, + fov=45, + near=0.1, + far=1000 + ) + + # Create scene without lights (use our own custom shader): + scene = p3js.Scene( + children=[mesh, camera], # No lights needed + background='#000000' + ) + + # Create renderer + renderer = p3js.Renderer( + camera=camera, + scene=scene, + controls=[p3js.OrbitControls(controlling=camera)], + width=width, + height=height, + antialias=True + ) + + # Create info display + info_text = f""" +
+ Interactive 3D Viewer (Three.js) ✓
+ Vertices: {len(vertices):,}
+ Triangles: {len(triangles):,}
+
+ Drag to rotate, scroll to zoom, right-drag to pan
+ """ + + if overlaypath or annotpath: + info_text += "
📊 Overlay/annotation loaded" + elif curvpath: + info_text += "
🧠 Curvature (grayscale is correct)" + + info_text += "
" + + info = HTML(value=info_text) + + # Combine renderer and info + viewer = VBox([renderer, info]) + + return viewer + + + +def create_threejs_mesh_with_custom_shaders(vertices, faces, colors, normals): + """Custom lighting shader - fixed for Three.js.""" + + vertices = vertices.astype(np.float32) + colors = colors.astype(np.float32) + normals = normals.astype(np.float32) + faces = faces.astype(np.uint32).flatten() + + vertex_shader, fragment_shader = get_webgl_shaders() + + geometry = p3js.BufferGeometry( + attributes={ + 'position': p3js.BufferAttribute(array=vertices), + 'color': p3js.BufferAttribute(array=colors), + 'normal': p3js.BufferAttribute(array=normals), + } + ) + geometry.index = p3js.BufferAttribute(array=faces) + + material = p3js.ShaderMaterial( + vertexShader=vertex_shader, + fragmentShader=fragment_shader, + uniforms={ + 'lightColor': {'value': [1.0, 1.0, 1.0]}, + 'ambientStrength': {'value': 0.1} + } + ) + + mesh = p3js.Mesh(geometry=geometry, material=material) + return mesh diff --git a/whippersnappy/Roboto-Regular.ttf b/whippersnappy/resources/fonts/Roboto-Regular.ttf similarity index 100% rename from whippersnappy/Roboto-Regular.ttf rename to whippersnappy/resources/fonts/Roboto-Regular.ttf diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py new file mode 100644 index 0000000..832eaf0 --- /dev/null +++ b/whippersnappy/snap.py @@ -0,0 +1,385 @@ +"""Snapshot (static rendering) API for WhipperSnapPy. + +""" + +import os +import sys + +import glfw +import numpy as np +import OpenGL.GL as gl +import pyrr +from PIL import Image, ImageDraw, ImageFont + +from whippersnappy.geometry import get_surf_name, prepare_geometry +from whippersnappy.utils.image import create_colorbar, load_roboto_font, text_size +from whippersnappy.utils.types import ColorSelection, OrientationType, ViewType + +from . import gl as _gl +from .gl import get_view_matrices + + +def snap1( + meshpath, + outpath=None, + overlaypath=None, + annotpath=None, + labelpath=None, + curvpath=None, + view=ViewType.LEFT, + viewmat=None, + width=None, + height=None, + fthresh=None, + fmax=None, + caption=None, + caption_x=None, + caption_y=None, + caption_scale=1, + invert=False, + colorbar=True, + colorbar_x=None, + colorbar_y=None, + colorbar_scale=1, + orientation=OrientationType.HORIZONTAL, + color_mode=ColorSelection.BOTH, + font_file=None, + specular=True, + brain_scale=1.5, + ambient=0.0, +): + """Snap one view (view and hemisphere is determined by the user).""" + ref_width = 700 + ref_height = 500 + wwidth = ref_width if width is None else width + wheight = ref_height if height is None else height + ui_scale = min(wwidth / ref_width, wheight / ref_height) + + if not glfw.init(): + print("[ERROR] Could not init glfw!") + sys.exit(1) + primary_monitor = glfw.get_primary_monitor() + mode = glfw.get_video_mode(primary_monitor) + screen_width = mode.size.width + screen_height = mode.size.height + if wwidth > screen_width: + print(f"[INFO] Requested width {wwidth} exceeds screen width {screen_width}, expect black bars") + elif wheight > screen_height: + print(f"[INFO] Requested height {wheight} exceeds screen height {screen_height}, expect black bars") + + image = Image.new("RGB", (wwidth, wheight)) + + bwidth = int(540 * brain_scale * ui_scale) + bheight = int(450 * brain_scale * ui_scale) + brain_display_width = min(bwidth, wwidth) + brain_display_height = min(bheight, wheight) + + window = _gl.init_window(brain_display_width, brain_display_height, "WhipperSnapPy 2.0", visible=True) + if not window: + return False + + transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + + meshdata, triangles, fthresh, fmax, pos, neg = prepare_geometry( + meshpath, + overlaypath, + annotpath, + curvpath, + labelpath, + fthresh, + fmax, + invert, + scale=brain_scale, + color_mode=color_mode, + ) + + if overlaypath is not None: + if color_mode == ColorSelection.POSITIVE: + if not pos and neg: + print("[Error] Overlay has no values to display with positive color_mode") + sys.exit(1) + neg = False + elif color_mode == ColorSelection.NEGATIVE: + if pos and not neg: + print("[Error] Overlay has no values to display with negative color_mode") + sys.exit(1) + pos = False + if not pos and not neg: + print("[Error] Overlay has no values to display") + sys.exit(1) + + shader = _gl.setup_shader(meshdata, triangles, brain_display_width, brain_display_height, + specular=specular, ambient=ambient) + + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) + transform_loc = gl.glGetUniformLocation(shader, "transform") + view_mats = get_view_matrices() + viewmat = transl * (view_mats[view] if viewmat is None else viewmat) + gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, viewmat) + gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) + + im1 = _gl.capture_window(brain_display_width, brain_display_height) + + brain_x = 0 if wwidth < bwidth else (wwidth - bwidth) // 2 + brain_y = 0 if wheight < bheight else (wheight - bheight) // 2 + image.paste(im1, (brain_x, brain_y)) + + bar = None + bar_w = bar_h = 0 + if overlaypath is not None and colorbar: + bar = create_colorbar(fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, font_file=font_file) + bar_w, bar_h = bar.size + + font = None + text_w = text_h = 0 + if caption: + if font_file is None: + font = load_roboto_font(int(20 * caption_scale * ui_scale)) + else: + try: + font = ImageFont.truetype(font_file, int(20 * caption_scale * ui_scale)) + except Exception: + font = load_roboto_font(int(20 * caption_scale * ui_scale)) + text_w, text_h = text_size(caption, font) + text_w = int(text_w) + text_h = int(text_h) + + bottom_pad = int(20 * ui_scale) + right_pad = int(20 * ui_scale) + gap = int(4 * ui_scale) + + if orientation == OrientationType.HORIZONTAL: + if bar is not None: + bx = int(0.5 * (image.width - bar_w)) if colorbar_x is None else int(colorbar_x * wwidth) + if colorbar_y is None: + gap_and_caption = (gap + text_h) if caption and caption_y is None else 0 + by = image.height - bottom_pad - gap_and_caption - bar_h + else: + by = int(colorbar_y * wheight) + image.paste(bar, (bx, by)) + + if caption: + cx = int(0.5 * (image.width - text_w)) if caption_x is None else int(caption_x * wwidth) + cy = image.height - bottom_pad - text_h if caption_y is None else int(caption_y * wheight) + ImageDraw.Draw(image).text((cx, cy), caption, (220, 220, 220), font=font, anchor="lt") + else: + if bar is not None: + if colorbar_x is None: + gap_and_caption = (gap + text_h) if caption and caption_x is None else 0 + bx = image.width - right_pad - gap_and_caption - bar_w + else: + bx = int(colorbar_x * wwidth) + by = int(0.5 * (image.height - bar_h)) if colorbar_y is None else int(colorbar_y * wheight) + image.paste(bar, (bx, by)) + + if caption: + temp_caption_img = Image.new("RGBA", (text_w, text_h), (0, 0, 0, 0)) + ImageDraw.Draw(temp_caption_img).text((0, 0), caption, font=font, anchor="lt") + rotated_caption = temp_caption_img.rotate(90, expand=True, fillcolor=(0, 0, 0, 0)) + rotated_w, rotated_h = rotated_caption.size + + cx = image.width - right_pad - rotated_w if caption_x is None else int(caption_x * wwidth) + cy = int(0.5 * (image.height - rotated_h)) if caption_y is None else int(caption_y * wheight) + image.paste(rotated_caption, (cx, cy), rotated_caption) + + if outpath is None: + glfw.terminate() + return image + + print(f"[INFO] Saving snapshot to {outpath}") + image.save(outpath) + glfw.terminate() + return None + + +def snap4( + lhoverlaypath=None, + rhoverlaypath=None, + lhannotpath=None, + rhannotpath=None, + fthresh=None, + fmax=None, + sdir=None, + caption=None, + invert=False, + labelname="cortex.label", + surfname=None, + curvname="curv", + colorbar=True, + outpath=None, + font_file=None, + specular=True, + ambient=0.0, + brain_scale=1.85, +): + """Snap four views (front and back for left and right hemispheres).""" + wwidth = 540 + wheight = 450 + # Try to create a visible window first (better for debugging), + # but fall back to an invisible/offscreen window if that fails. + window = _gl.init_window(wwidth, wheight, "WhipperSnapPy 2.0", visible=True) + if not window: + print("[WARNING] Could not create visible GLFW window; retrying with invisible window (offscreen).") + window = _gl.init_window(wwidth, wheight, "WhipperSnapPy 2.0", visible=False) + if not window: + print("[ERROR] Could not create any GLFW window/context. OpenGL context unavailable.") + return None + + rot_z = pyrr.Matrix44.from_z_rotation(-0.5 * np.pi) + rot_x = pyrr.Matrix44.from_x_rotation(0.5 * np.pi) + view_left = rot_x * rot_z + rot_y = pyrr.Matrix44.from_y_rotation(np.pi) + view_right = rot_y * view_left + transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + + for hemi in ("lh", "rh"): + if surfname is None: + if sdir is None: + sdir = os.environ.get("SUBJECTS_DIR") + if not sdir: + print("[INFO] No surf_name or subjects directory (sdir) provided") + sys.exit(1) + found_surfname = get_surf_name(sdir, hemi) + if found_surfname is None: + print(f"[ERROR] Could not find valid surface in {sdir} for hemi: {hemi}!") + sys.exit(1) + meshpath = os.path.join(sdir, "surf", hemi + "." + found_surfname) + else: + meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) + + # Assign derived paths + curvpath = os.path.join(sdir, "surf", hemi + "." + curvname) if curvname else None + labelpath = os.path.join(sdir, "label", hemi + "." + labelname) if labelname else None + overlaypath = lhoverlaypath if hemi == "lh" else rhoverlaypath + annotpath = lhannotpath if hemi == "lh" else rhannotpath + + # Diagnostic: report mesh and overlay paths and whether they exist + print(f"[DEBUG] hemisphere={hemi}") + print(f"[DEBUG] meshpath={meshpath} exists={os.path.exists(meshpath)}") + if overlaypath is not None: + print(f"[DEBUG] overlaypath={overlaypath} exists={os.path.exists(overlaypath)}") + if annotpath is not None: + print(f"[DEBUG] annotpath={annotpath} exists={os.path.exists(annotpath)}") + if curvpath is not None: + print(f"[DEBUG] curvpath={curvpath} exists={os.path.exists(curvpath)}") + + try: + meshdata, triangles, fthresh, fmax, pos, neg = prepare_geometry( + meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, scale=brain_scale + ) + except Exception as e: + print(f"[ERROR] prepare_geometry failed for {meshpath}: {e}") + glfw.terminate() + return None + + # Diagnostics about mesh data + try: + print(f"[DEBUG] meshdata shape: {getattr(meshdata, 'shape', None)}; triangles count: {getattr(triangles, 'size', None)}") + except Exception: + pass + + if pos == 0 and neg == 0: + print("[Error] Overlay has no values to display") + sys.exit(1) + + try: + shader = _gl.setup_shader(meshdata, triangles, wwidth, wheight, specular=specular, ambient=ambient) + print("[DEBUG] Shader setup complete") + except Exception as e: + print(f"[ERROR] setup_shader failed: {e}") + glfw.terminate() + return None + + try: + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) + except Exception as e: + print(f"[ERROR] glClear failed: {e}") + print(f"glError: {gl.glGetError()}") + glfw.terminate() + return None + transform_loc = gl.glGetUniformLocation(shader, "transform") + viewmat = view_left if hemi == "lh" else view_right + gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, transl * viewmat) + gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) + try: + im1 = _gl.capture_window(wwidth, wheight) + print(f"[DEBUG] Captured image 1 size: {im1.size}") + except Exception as e: + print(f"[ERROR] capture_window failed: {e}") + glfw.terminate() + return None + + glfw.swap_buffers(window) + try: + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) + except Exception as e: + print(f"[ERROR] glClear failed: {e}") + print(f"glError: {gl.glGetError()}") + glfw.terminate() + return None + viewmat = view_right if hemi == "lh" else view_left + gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, transl * viewmat) + gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) + try: + im2 = _gl.capture_window(wwidth, wheight) + print(f"[DEBUG] Captured image 2 size: {im2.size}") + except Exception as e: + print(f"[ERROR] capture_window failed: {e}") + glfw.terminate() + return None + + if hemi == "lh": + lhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) + lhimg.paste(im1, (0, 0)) + lhimg.paste(im2, (0, im1.height)) + else: + rhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) + # Keep same top/bottom ordering as left hemisphere: top=im1, bottom=im2 + rhimg.paste(im1, (0, 0)) + rhimg.paste(im2, (0, im1.height)) + + # Add small padding around each hemisphere to avoid cropping at edges + pad = max(4, int(0.03 * wwidth)) + padded_lh = Image.new("RGB", (lhimg.width + 2 * pad, lhimg.height + 2 * pad), (0, 0, 0)) + padded_lh.paste(lhimg, (pad, pad)) + padded_rh = Image.new("RGB", (rhimg.width + 2 * pad, rhimg.height + 2 * pad), (0, 0, 0)) + padded_rh.paste(rhimg, (pad, pad)) + + image = Image.new("RGB", (padded_lh.width + padded_rh.width, padded_lh.height)) + image.paste(padded_lh, (0, 0)) + image.paste(padded_rh, (padded_lh.width, 0)) + + if caption: + if font_file is None: + font = load_roboto_font(20) + else: + try: + font = ImageFont.truetype(font_file, 20) + except Exception: + font = load_roboto_font(20) + if font is not None: + xpos = 0.5 * (image.width - getattr(font, 'getlength', lambda s: 0)(caption)) + ImageDraw.Draw(image).text((xpos, image.height - 40), caption, (220, 220, 220), font=font) + else: + ImageDraw.Draw(image).text((10, image.height - 40), caption, (220, 220, 220)) + + if lhannotpath is None and rhannotpath is None and colorbar: + bar = create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) + if bar is not None: + xpos = int(0.5 * (image.width - bar.width)) + ypos = int(0.5 * (image.height - bar.height)) + image.paste(bar, (xpos, ypos)) + + # If outpath is None, return the PIL Image object directly (no disk I/O) + if outpath is None: + glfw.terminate() + return image + + # Otherwise save to disk + if outpath: + print(f"[INFO] Saving snapshot to {outpath}") + image.save(outpath) + + glfw.terminate() + return None + diff --git a/whippersnappy/utils/__init__.py b/whippersnappy/utils/__init__.py index 285e1d4..dcb119b 100644 --- a/whippersnappy/utils/__init__.py +++ b/whippersnappy/utils/__init__.py @@ -1 +1,4 @@ -"""Utilities module.""" +"""Utils subpackage exports.""" +from . import colormap, image, types + +__all__ = ["image", "colormap", "types"] diff --git a/whippersnappy/utils/colormap.py b/whippersnappy/utils/colormap.py new file mode 100644 index 0000000..125cc97 --- /dev/null +++ b/whippersnappy/utils/colormap.py @@ -0,0 +1,88 @@ +"""Colormap and value preprocessing utilities.""" + +import numpy as np + +from whippersnappy.utils.types import ColorSelection + + +def heat_color(values, invert=False): + """Convert an array of float values into RBG heat color values.""" + if invert: + values = -1.0 * values + vabs = np.abs(values) + colors = np.zeros((vabs.size, 3), dtype=np.float32) + crb = 0.5625 + 3 * 0.4375 * vabs + cg = 1.5 * (vabs - (1.0 / 3.0)) + n1 = values < -1.0 + nm = (values >= -1.0) & (values < -(1.0 / 3.0)) + n0 = (values >= -(1.0 / 3.0)) & (values < 0) + p0 = (values >= 0) & (values < (1.0 / 3.0)) + pm = (values >= (1.0 / 3.0)) & (values < 1.0) + p1 = values >= 1.0 + colors[n1, 1:3] = 1.0 + colors[nm, 1] = cg[nm] + colors[nm, 2] = 1.0 + colors[n0, 2] = crb[n0] + colors[p0, 0] = crb[p0] + colors[pm, 1] = cg[pm] + colors[pm, 0] = 1.0 + colors[p1, 0:2] = 1.0 + colors[np.isnan(values), :] = np.nan + return colors + + +def mask_sign(values, color_mode): + """Mask values that don't have the same sign as color_mode.""" + masked_values = np.copy(values) + if color_mode == ColorSelection.POSITIVE: + masked_values[masked_values < 0] = np.nan + elif color_mode == ColorSelection.NEGATIVE: + masked_values[masked_values > 0] = np.nan + return masked_values + + +def rescale_overlay(values, minval=None, maxval=None): + """Rescale values for color map computation.""" + valsign = np.sign(values) + valabs = np.abs(values) + + if maxval < 0 or minval < 0: + print("rescale_overlay ERROR: min and maxval should both be positive!") + exit(1) + + values[valabs < minval] = np.nan + range_val = maxval - minval + if range_val == 0: + values = np.zeros_like(values) + else: + values = values - valsign * minval + values = values / range_val + + pos = np.any(values[~np.isnan(values)] > 0) + neg = np.any(values[~np.isnan(values)] < 0) + + return values, minval, maxval, pos, neg + + +def binary_color(values, thres, color_low, color_high): + """Create a binary colormap based on a threshold value.""" + if np.isscalar(color_low): + color_low = np.array((color_low, color_low, color_low), dtype=np.float32) + if np.isscalar(color_high): + color_high = np.array((color_high, color_high, color_high), dtype=np.float32) + colors = np.empty((values.size, 3), dtype=np.float32) + colors[values < thres, :] = color_low + colors[values >= thres, :] = color_high + return colors + + +def mask_label(values, labelpath=None): + """Apply a label file as a mask.""" + if not labelpath: + return values + maskvids = np.loadtxt(labelpath, dtype=int, skiprows=2, usecols=[0]) + imask = np.ones(values.shape, dtype=bool) + imask[maskvids] = False + values[imask] = np.nan + return values + diff --git a/whippersnappy/utils/image.py b/whippersnappy/utils/image.py new file mode 100644 index 0000000..52c2645 --- /dev/null +++ b/whippersnappy/utils/image.py @@ -0,0 +1,217 @@ +"""Image and text helper utilities used by snapshot renderers (moved under utils). +""" +import numpy as np +from PIL import Image, ImageDraw + +from whippersnappy.utils.colormap import heat_color +from whippersnappy.utils.types import OrientationType + +try: + # Prefer stdlib importlib.resources + from importlib import resources +except Exception: + import importlib_resources as resources +import warnings + +from PIL import ImageFont + + +def text_size(caption, font): + """Return text width and height in pixels.""" + dummy_img = Image.new("L", (1, 1)) + draw = ImageDraw.Draw(dummy_img) + bbox = draw.textbbox((0, 0), caption, font=font, anchor="lt") + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + return text_width, text_height + + +def get_colorbar_label_positions( + font, + labels, + colorbar_rect, + gapspace=0, + pos=True, + neg=True, + orientation=OrientationType.HORIZONTAL, +): + """Return label positions for a colorbar.""" + positions = {} + cb_x, cb_y, cb_width, cb_height = colorbar_rect + cb_labels_gap = 5 + + if orientation == OrientationType.HORIZONTAL: + label_y = cb_y + cb_height + cb_labels_gap + + w, _ = text_size(labels["upper"], font) + if pos: + positions["upper"] = (cb_x + cb_width - w, label_y) + else: + upper_x = cb_x + cb_width - w - int(gapspace) if gapspace > 0 else cb_x + cb_width - w + positions["upper"] = (upper_x, label_y) + + w, _ = text_size(labels["lower"], font) + if neg: + positions["lower"] = (cb_x, label_y) + else: + lower_x = cb_x + int(gapspace) if gapspace > 0 else cb_x + positions["lower"] = (lower_x, label_y) + + if neg and pos: + if gapspace == 0: + w, _ = text_size(labels["middle"], font) + positions["middle"] = (cb_x + cb_width // 2 - w // 2, label_y) + else: + w, _ = text_size(labels["middle_neg"], font) + positions["middle_neg"] = (cb_x + cb_width // 2 - w - int(gapspace), label_y) + w, _ = text_size(labels["middle_pos"], font) + positions["middle_pos"] = (cb_x + cb_width // 2 + int(gapspace), label_y) + else: + label_x = cb_x + cb_width + cb_labels_gap + + _, h = text_size(labels["upper"], font) + if pos: + positions["upper"] = (label_x, cb_y) + else: + upper_y = cb_y + int(gapspace) if gapspace > 0 else cb_y + positions["upper"] = (label_x, upper_y) + + _, h = text_size(labels["lower"], font) + if neg: + positions["lower"] = (label_x, cb_y + cb_height - 1.5 * h) + else: + lower_y = cb_y + cb_height - int(gapspace) - 1.5 * h if gapspace > 0 else cb_y + cb_height - 1.5 * h + positions["lower"] = (label_x, lower_y) + + if neg and pos: + if gapspace == 0: + _, h = text_size(labels["middle"], font) + positions["middle"] = (label_x, cb_y + cb_height // 2 - h // 2) + else: + _, h = text_size(labels["middle_pos"], font) + positions["middle_pos"] = (label_x, cb_y + cb_height // 2 - 1.5 * h - int(gapspace)) + _, h = text_size(labels["middle_neg"], font) + positions["middle_neg"] = (label_x, cb_y + cb_height // 2 + int(gapspace)) + + return positions + + +def create_colorbar( + fmin, + fmax, + invert, + orientation=OrientationType.HORIZONTAL, + colorbar_scale=1, + pos=True, + neg=True, + font_file=None, +): + """Create a colorbar image (PIL.Image) using the project's heat_color. + + Parameters mirror the previous implementation in `render.py`. + """ + # If fmin/fmax are not specified, we cannot create a meaningful colorbar. + if fmin is None or fmax is None: + return None + + cwidth = int(200 * colorbar_scale) + cheight = int(30 * colorbar_scale) + gapspace = 0 + + if fmin > 0.01: + num = int(0.42 * cwidth) + gapspace = 0.08 * cwidth + else: + num = int(0.5 * cwidth) + if not neg or not pos: + num = num * 2 + gapspace = gapspace * 2 + + values = np.nan * np.ones(cwidth) + steps = np.linspace(0.01, 1, num) + if pos and not neg: + values[-steps.size:] = steps + elif not pos and neg: + values[: steps.size] = -1.0 * np.flip(steps) + else: + values[: steps.size] = -1.0 * np.flip(steps) + values[-steps.size:] = steps + + colors = heat_color(values, invert) + colors[np.isnan(values), :] = 0.33 * np.ones((1, 3)) + img_bar = np.uint8(np.tile(colors, (cheight, 1, 1)) * 255) + + pad_top, pad_left = 3, 10 + img_buf = np.zeros((cheight + 2 * pad_top, cwidth + 2 * pad_left, 3), dtype=np.uint8) + img_buf[pad_top : cheight + pad_top, pad_left : cwidth + pad_left, :] = img_bar + image = Image.fromarray(img_buf) + + if font_file is None: + # Try to load bundled font from package resources + font = None + try: + font_trav = resources.files("whippersnappy").joinpath("resources", "fonts", "Roboto-Regular.ttf") + with resources.as_file(font_trav) as font_path: + font = ImageFont.truetype(str(font_path), int(12 * colorbar_scale)) + except Exception: + warnings.warn("Roboto font not found in package resources; falling back to default font", UserWarning) + font = ImageFont.load_default() + else: + try: + font = ImageFont.truetype(font_file, int(12 * colorbar_scale)) + except Exception: + font = ImageFont.load_default() + + labels = {} + labels["upper"] = f">{fmax:.2f}" if pos else (f"{-fmin:.2f}" if gapspace != 0 else "0") + labels["lower"] = f"<{-fmax:.2f}" if neg else (f"{fmin:.2f}" if gapspace != 0 else "0") + if neg and pos and gapspace != 0: + labels["middle_neg"] = f"{-fmin:.2f}" + labels["middle_pos"] = f"{fmin:.2f}" + elif neg and pos and gapspace == 0: + labels["middle"] = "0" + + caption_sizes = [text_size(caption, font) for caption in labels.values()] + max_caption_width = int(max([caption_size[0] for caption_size in caption_sizes])) + max_caption_height = int(max([caption_size[1] for caption_size in caption_sizes])) + + if orientation == OrientationType.VERTICAL: + image = image.rotate(90, expand=True) + new_width = image.width + int(max_caption_width) + new_image = Image.new("RGB", (new_width, image.height), (0, 0, 0)) + new_image.paste(image, (0, 0)) + image = new_image + colorbar_rect = (pad_top, pad_left, cheight, cwidth) + else: + new_height = image.height + int(max_caption_height * 2) + new_image = Image.new("RGB", (image.width, new_height), (0, 0, 0)) + new_image.paste(image, (0, 0)) + image = new_image + colorbar_rect = (pad_left, pad_top, cwidth, cheight) + + positions = get_colorbar_label_positions(font, labels, colorbar_rect, gapspace, pos, neg, orientation) + draw = ImageDraw.Draw(image) + for label_key, position in positions.items(): + draw.text((int(position[0]), int(position[1])), labels[label_key], fill=(220, 220, 220), font=font) + + return image + + +def load_roboto_font(size=14): + """Load bundled Roboto-Regular.ttf from package resources. + + Returns a PIL ImageFont instance. Falls back to ImageFont.load_default() + if the bundled font isn't available. + """ + try: + # resources was imported earlier in this module + font_trav = resources.files("whippersnappy").joinpath("resources", "fonts", "Roboto-Regular.ttf") + with resources.as_file(font_trav) as font_path: + return ImageFont.truetype(str(font_path), size) + except Exception: + warnings.warn("Roboto font not found in package resources; falling back to default font", UserWarning) + try: + return ImageFont.load_default() + except Exception: + return None + diff --git a/whippersnappy/types.py b/whippersnappy/utils/types.py similarity index 100% rename from whippersnappy/types.py rename to whippersnappy/utils/types.py From c24b56614fc630d72681e4b7e8dd388303e6f610 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Mon, 16 Feb 2026 16:22:22 +0100 Subject: [PATCH 02/83] add logger and raise errors instead of exit --- whippersnappy/cli/whippersnap.py | 68 +++++++++++++++++------- whippersnappy/plot3d.py | 15 ++++-- whippersnappy/snap.py | 88 +++++++++++++++++--------------- whippersnappy/utils/_config.py | 4 +- whippersnappy/utils/colormap.py | 9 +++- 5 files changed, 117 insertions(+), 67 deletions(-) diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 1a44869..532b0bf 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -25,13 +25,17 @@ import math import os import signal -import sys import threading +import logging import glfw import OpenGL.GL as gl import pyrr -from PyQt6.QtWidgets import QApplication +try: + from PyQt6.QtWidgets import QApplication +except Exception: + # GUI dependency missing; handle at runtime when interactive mode is requested + QApplication = None from whippersnappy import snap4 from whippersnappy.geometry import get_surf_name, prepare_geometry @@ -41,6 +45,9 @@ ) from whippersnappy.gui import ConfigWindow +# Module logger +logger = logging.getLogger(__name__) + # Global variables for config app configuration state: current_fthresh_ = None current_fmax_ = None @@ -101,13 +108,12 @@ def show_window( return False if surfname is None: - print("[INFO] No surf_name provided. Looking for options in surf directory...") + logger.info("No surf_name provided. Looking for options in surf directory...") found_surfname = get_surf_name(sdir, hemi) if found_surfname is None: - print( - f"[ERROR] Could not find a valid surf file in {sdir} for hemi: {hemi}!" - ) - sys.exit(0) + msg = f"Could not find a valid surf file in {sdir} for hemi: {hemi}!" + logger.error(msg) + raise FileNotFoundError(msg) meshpath = os.path.join(sdir, "surf", hemi + "." + found_surfname) else: meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) @@ -132,11 +138,7 @@ def show_window( ) shader = setup_shader(meshdata, triangles, wwidth, weight, specular=specular) - print() - print("Keys:") - print("Left - Right : Rotate Geometry") - print("ESC : Quit") - print() + logger.info("\nKeys:\nLeft - Right : Rotate Geometry\nESC : Quit\n") ypos = 0 while glfw.get_key( @@ -194,6 +196,10 @@ def config_app_exit_handler(): def run(): global current_fthresh_, current_fmax_, app_, app_window_ + # Configure basic logging for CLI invocation so messages from module loggers + # are visible to end users. Avoid configuring on import by doing this here. + import logging as _logging + _logging.basicConfig(level=_logging.INFO, format='%(levelname)s: %(message)s') parser = argparse.ArgumentParser() parser.add_argument( @@ -282,20 +288,37 @@ def run(): # check for mutually exclusive arguments if (args.lh_overlay or args.rh_overlay) and (args.lh_annot or args.rh_annot): - print("[ERROR] Cannot use lh_overlay/rh_overlay and lh_annot/rh_annot arguments at the same time.") - sys.exit(0) + msg = "Cannot use lh_overlay/rh_overlay and lh_annot/rh_annot arguments at the same time." + logger.error(msg) + raise ValueError(msg) # check if at least one variant is present if args.lh_overlay is None and args.rh_overlay is None and args.lh_annot is None and args.rh_annot is None: - print("[ERROR] Either lh_overlay/rh_overlay or lh_annot/rh_annot must be present.") - sys.exit(0) + msg = "Either lh_overlay/rh_overlay or lh_annot/rh_annot must be present." + logger.error(msg) + raise ValueError(msg) # check if both hemis are present if (args.lh_overlay is None and args.rh_overlay is not None) or \ (args.lh_overlay is not None and args.rh_overlay is None) or \ (args.lh_annot is None and args.rh_annot is not None) or \ (args.lh_annot is not None and args.rh_annot is None): - print("[ERROR] If lh_overlay or lh_annot is present, rh_overlay or rh_annot must also be present " \ - "(and vice versa).") - sys.exit(0) + msg = "If lh_overlay or lh_annot is present, rh_overlay or rh_annot must also be present (and vice versa)." + logger.error(msg) + raise ValueError(msg) + + logger.info(f"Left hemisphere overlay: {args.lh_overlay}") + logger.info(f"Right hemisphere overlay: {args.rh_overlay}") + logger.info(f"Left hemisphere annotation: {args.lh_annot}") + logger.info(f"Right hemisphere annotation: {args.rh_annot}") + logger.info(f"Subject directory: {args.sdir}") + logger.info(f"Surface name: {args.surf_name}") + logger.info(f"Output path: {args.output_path}") + logger.info(f"Caption: {args.caption}") + logger.info(f"Colorbar: {'enabled' if not args.no_colorbar else 'disabled'}") + logger.info(f"fmax: {args.fmax}") + logger.info(f"fthresh: {args.fthresh}") + logger.info(f"Interactive mode: {'enabled' if args.interactive else 'disabled'}") + logger.info(f"Color scale inversion: {'enabled' if args.invert else 'disabled'}") + logger.info(f"Specular reflection: {'enabled' if args.specular else 'disabled'}") # if not args.interactive: @@ -336,6 +359,13 @@ def run(): ) thread.start() + # Ensure GUI toolkit is available + if QApplication is None: + raise ImportError( + "Interactive mode requires PyQt6. Install it (pip install PyQt6) " + "or run without --interactive." + ) + # Setting up and running config app window (must be main thread): app_ = QApplication([]) app_.setStyle("Fusion") # the default diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py index 789db48..90d9b96 100644 --- a/whippersnappy/plot3d.py +++ b/whippersnappy/plot3d.py @@ -18,6 +18,7 @@ @Created : 14.02.2026 """ +import logging import numpy as np import pythreejs as p3js from ipywidgets import HTML, VBox @@ -27,6 +28,9 @@ from .geometry import prepare_geometry from .gl import get_webgl_shaders +# Module logger +logger = logging.getLogger(__name__) + def plot3d( meshpath, @@ -83,9 +87,12 @@ def plot3d( Examples -------- - >>> from whippersnappy import plot3d - >>> viewer = plot3d('path/to/lh.white', curvpath='path/to/lh.curv') - >>> display(viewer) + In a notebook: + + from whippersnappy import plot3d + from IPython.display import display + viewer = plot3d('path/to/lh.white', curvpath='path/to/lh.curv') + display(viewer) """ # Load and prepare mesh data @@ -95,7 +102,7 @@ def plot3d( minval, maxval, invert, scale, color_mode ) - print(f"Loaded mesh: {meshdata.shape[0]} vertices, {triangles.shape[0]} faces") + logger.info("Loaded mesh: %d vertices, %d faces", meshdata.shape[0], triangles.shape[0]) # Extract vertices, normals, and colors vertices = meshdata[:, :3] # x, y, z diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index 832eaf0..825ed3c 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -3,7 +3,7 @@ """ import os -import sys +import logging import glfw import numpy as np @@ -18,6 +18,9 @@ from . import gl as _gl from .gl import get_view_matrices +# Module logger +logger = logging.getLogger(__name__) + def snap1( meshpath, @@ -56,16 +59,16 @@ def snap1( ui_scale = min(wwidth / ref_width, wheight / ref_height) if not glfw.init(): - print("[ERROR] Could not init glfw!") - sys.exit(1) + logger.error("Could not init glfw!") + raise RuntimeError("Could not initialize GLFW; OpenGL context unavailable") primary_monitor = glfw.get_primary_monitor() mode = glfw.get_video_mode(primary_monitor) screen_width = mode.size.width screen_height = mode.size.height if wwidth > screen_width: - print(f"[INFO] Requested width {wwidth} exceeds screen width {screen_width}, expect black bars") + logger.info("Requested width %d exceeds screen width %d, expect black bars", wwidth, screen_width) elif wheight > screen_height: - print(f"[INFO] Requested height {wheight} exceeds screen height {screen_height}, expect black bars") + logger.info("Requested height %d exceeds screen height %d, expect black bars", wheight, screen_height) image = Image.new("RGB", (wwidth, wheight)) @@ -96,17 +99,17 @@ def snap1( if overlaypath is not None: if color_mode == ColorSelection.POSITIVE: if not pos and neg: - print("[Error] Overlay has no values to display with positive color_mode") - sys.exit(1) + logger.error("Overlay has no values to display with positive color_mode") + raise ValueError("Overlay has no values to display with positive color_mode") neg = False elif color_mode == ColorSelection.NEGATIVE: if pos and not neg: - print("[Error] Overlay has no values to display with negative color_mode") - sys.exit(1) + logger.error("Overlay has no values to display with negative color_mode") + raise ValueError("Overlay has no values to display with negative color_mode") pos = False if not pos and not neg: - print("[Error] Overlay has no values to display") - sys.exit(1) + logger.error("Overlay has no values to display") + raise ValueError("Overlay has no values to display") shader = _gl.setup_shader(meshdata, triangles, brain_display_width, brain_display_height, specular=specular, ambient=ambient) @@ -186,7 +189,7 @@ def snap1( glfw.terminate() return image - print(f"[INFO] Saving snapshot to {outpath}") + logger.info("Saving snapshot to %s", outpath) image.save(outpath) glfw.terminate() return None @@ -217,12 +220,12 @@ def snap4( wheight = 450 # Try to create a visible window first (better for debugging), # but fall back to an invisible/offscreen window if that fails. - window = _gl.init_window(wwidth, wheight, "WhipperSnapPy 2.0", visible=True) + window = _gl.init_window(wwidth, wheight, "WhipperSnapPy", visible=True) if not window: - print("[WARNING] Could not create visible GLFW window; retrying with invisible window (offscreen).") - window = _gl.init_window(wwidth, wheight, "WhipperSnapPy 2.0", visible=False) + logger.warning("Could not create visible GLFW window; retrying with invisible window (offscreen).") + window = _gl.init_window(wwidth, wheight, "WhipperSnapPy", visible=False) if not window: - print("[ERROR] Could not create any GLFW window/context. OpenGL context unavailable.") + logger.error("Could not create any GLFW window/context. OpenGL context unavailable.") return None rot_z = pyrr.Matrix44.from_z_rotation(-0.5 * np.pi) @@ -232,17 +235,22 @@ def snap4( view_right = rot_y * view_left transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + # Predefine hemisphere images so static analysis knows they exist even if + # an earlier step raises an exception (we still will fail at runtime). + lhimg = None + rhimg = None + for hemi in ("lh", "rh"): if surfname is None: if sdir is None: sdir = os.environ.get("SUBJECTS_DIR") if not sdir: - print("[INFO] No surf_name or subjects directory (sdir) provided") - sys.exit(1) + logger.error("No surf_name or subjects directory (sdir) provided") + raise ValueError("No surf_name or SUBJECTS_DIR provided") found_surfname = get_surf_name(sdir, hemi) if found_surfname is None: - print(f"[ERROR] Could not find valid surface in {sdir} for hemi: {hemi}!") - sys.exit(1) + logger.error("Could not find valid surface in %s for hemi: %s!", sdir, hemi) + raise FileNotFoundError(f"Could not find valid surface in {sdir} for hemi: {hemi}") meshpath = os.path.join(sdir, "surf", hemi + "." + found_surfname) else: meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) @@ -254,47 +262,47 @@ def snap4( annotpath = lhannotpath if hemi == "lh" else rhannotpath # Diagnostic: report mesh and overlay paths and whether they exist - print(f"[DEBUG] hemisphere={hemi}") - print(f"[DEBUG] meshpath={meshpath} exists={os.path.exists(meshpath)}") + logger.debug("hemisphere=%s", hemi) + logger.debug("meshpath=%s exists=%s", meshpath, os.path.exists(meshpath)) if overlaypath is not None: - print(f"[DEBUG] overlaypath={overlaypath} exists={os.path.exists(overlaypath)}") + logger.debug("overlaypath=%s exists=%s", overlaypath, os.path.exists(overlaypath)) if annotpath is not None: - print(f"[DEBUG] annotpath={annotpath} exists={os.path.exists(annotpath)}") + logger.debug("annotpath=%s exists=%s", annotpath, os.path.exists(annotpath)) if curvpath is not None: - print(f"[DEBUG] curvpath={curvpath} exists={os.path.exists(curvpath)}") + logger.debug("curvpath=%s exists=%s", curvpath, os.path.exists(curvpath)) try: meshdata, triangles, fthresh, fmax, pos, neg = prepare_geometry( meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, scale=brain_scale ) except Exception as e: - print(f"[ERROR] prepare_geometry failed for {meshpath}: {e}") + logger.error("prepare_geometry failed for %s: %s", meshpath, e) glfw.terminate() return None # Diagnostics about mesh data try: - print(f"[DEBUG] meshdata shape: {getattr(meshdata, 'shape', None)}; triangles count: {getattr(triangles, 'size', None)}") + logger.debug("meshdata shape: %s; triangles count: %s", getattr(meshdata, 'shape', None), getattr(triangles, 'size', None)) except Exception: pass if pos == 0 and neg == 0: - print("[Error] Overlay has no values to display") - sys.exit(1) + logger.error("Overlay has no values to display") + raise ValueError("Overlay has no values to display") try: shader = _gl.setup_shader(meshdata, triangles, wwidth, wheight, specular=specular, ambient=ambient) - print("[DEBUG] Shader setup complete") + logger.debug("Shader setup complete") except Exception as e: - print(f"[ERROR] setup_shader failed: {e}") + logger.error("setup_shader failed: %s", e) glfw.terminate() return None try: gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) except Exception as e: - print(f"[ERROR] glClear failed: {e}") - print(f"glError: {gl.glGetError()}") + logger.error("glClear failed: %s", e) + logger.error("glError: %s", gl.glGetError()) glfw.terminate() return None transform_loc = gl.glGetUniformLocation(shader, "transform") @@ -303,9 +311,9 @@ def snap4( gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) try: im1 = _gl.capture_window(wwidth, wheight) - print(f"[DEBUG] Captured image 1 size: {im1.size}") + logger.debug("Captured image 1 size: %s", im1.size) except Exception as e: - print(f"[ERROR] capture_window failed: {e}") + logger.error("capture_window failed: %s", e) glfw.terminate() return None @@ -313,8 +321,8 @@ def snap4( try: gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) except Exception as e: - print(f"[ERROR] glClear failed: {e}") - print(f"glError: {gl.glGetError()}") + logger.error("glClear failed: %s", e) + logger.error("glError: %s", gl.glGetError()) glfw.terminate() return None viewmat = view_right if hemi == "lh" else view_left @@ -322,9 +330,9 @@ def snap4( gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) try: im2 = _gl.capture_window(wwidth, wheight) - print(f"[DEBUG] Captured image 2 size: {im2.size}") + logger.debug("Captured image 2 size: %s", im2.size) except Exception as e: - print(f"[ERROR] capture_window failed: {e}") + logger.error("capture_window failed: %s", e) glfw.terminate() return None @@ -377,7 +385,7 @@ def snap4( # Otherwise save to disk if outpath: - print(f"[INFO] Saving snapshot to {outpath}") + logger.info("Saving snapshot to %s", outpath) image.save(outpath) glfw.terminate() diff --git a/whippersnappy/utils/_config.py b/whippersnappy/utils/_config.py index b1d393c..2c2ca61 100644 --- a/whippersnappy/utils/_config.py +++ b/whippersnappy/utils/_config.py @@ -100,8 +100,8 @@ def _list_dependencies_info(out: Callable, ljust: int, dependencies: list[str]): # handle special dependencies with backends, C dep, .. if dep in ("matplotlib", "seaborn") and version_ != "Not found.": try: - from matplotlib import pyplot as plt - + import importlib + plt = importlib.import_module("matplotlib.pyplot") backend = plt.get_backend() except Exception: backend = "Not found" diff --git a/whippersnappy/utils/colormap.py b/whippersnappy/utils/colormap.py index 125cc97..26a6100 100644 --- a/whippersnappy/utils/colormap.py +++ b/whippersnappy/utils/colormap.py @@ -1,10 +1,15 @@ """Colormap and value preprocessing utilities.""" +import logging import numpy as np from whippersnappy.utils.types import ColorSelection +# Module logger +logger = logging.getLogger(__name__) + + def heat_color(values, invert=False): """Convert an array of float values into RBG heat color values.""" if invert: @@ -47,8 +52,8 @@ def rescale_overlay(values, minval=None, maxval=None): valabs = np.abs(values) if maxval < 0 or minval < 0: - print("rescale_overlay ERROR: min and maxval should both be positive!") - exit(1) + logger.error("rescale_overlay ERROR: min and maxval should both be positive!") + raise ValueError("minval and maxval must be non-negative") values[valabs < minval] = np.nan range_val = maxval - minval From 870e1dcfc93b5959bef88991e335c560a2d8504f Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Mon, 16 Feb 2026 17:27:21 +0100 Subject: [PATCH 03/83] add docstrings --- examples/whippersnappy_demo.ipynb | 15 ++- whippersnappy/__init__.py | 3 +- whippersnappy/cli/whippersnap.py | 100 +++++++++++++----- whippersnappy/commands/sys_info.py | 16 ++- whippersnappy/geometry/prepare.py | 83 ++++++++++++++- whippersnappy/geometry/surf_name.py | 19 +++- whippersnappy/gl/camera.py | 57 +++++++++- whippersnappy/gl/shaders.py | 27 ++++- whippersnappy/gl/utils.py | 147 +++++++++++++++++++++----- whippersnappy/gl/views.py | 25 ++++- whippersnappy/gui/config_app.py | 104 +++++++++---------- whippersnappy/plot3d.py | 94 +++++++++++------ whippersnappy/snap.py | 156 +++++++++++++++++++++++++++- whippersnappy/utils/colormap.py | 104 +++++++++++++++++-- whippersnappy/utils/image.py | 89 ++++++++++++++-- 15 files changed, 847 insertions(+), 192 deletions(-) diff --git a/examples/whippersnappy_demo.ipynb b/examples/whippersnappy_demo.ipynb index f6e74dd..8010485 100644 --- a/examples/whippersnappy_demo.ipynb +++ b/examples/whippersnappy_demo.ipynb @@ -39,15 +39,19 @@ "# Setup Paths\n", "# Edit these to point to your FreeSurfer/FastSurfer data:\n", "\n", - "# Set your subject directory here (either an absolute path to a subject's directory\n", - "# containing a `surf/` subdirectory, or None to use the SUBJECTS_DIR environment variable).\n", + "# Set your subject directory here\n", + "# (either an absolute path to a subject's directory\n", + "# containing a `surf/` subdirectory).\n", "# Example: sdir = '/home/user/freesurfer/subjects/subject01'\n", "# IMPORTANT: set this value before running the rest of the notebook.\n", - "sdir = '/path/to/your/subjectdir' # <-- update this to your SUBJECTS_DIR or subject path\n", + "sdir = '/path/to/your/subjectdir' # <-- update\n", "\n", "# Verify that sdir exists and is a directory\n", "if not os.path.isdir(sdir):\n", - " raise ValueError(f\"Subject directory does not exist: {sdir}\\nPlease set `sdir` to a valid subject directory containing a 'surf/' subdirectory.\")\n", + " raise ValueError(\n", + " f\"Subject directory does not exist: {sdir}\\n\"\n", + " \"Please set `sdir` to a valid subject directory containing a 'surf/' subdirectory.\"\n", + " )\n", "\n", "# Derive per-hemisphere paths from the subject directory\n", "lh_surf_path = os.path.join(sdir, 'surf', 'lh.white')\n", @@ -107,7 +111,8 @@ "outputs": [], "source": [ "# Part 1: Static Rendering - Snap4 (both hemispheres)\n", - "# Use snap4 to render front/back views for left and right hemispheres and display the combined image.\n", + "# Use snap4 to render front/back views for left and right\n", + "# hemispheres and display the combined image.\n", "\n", "print(f\"Using subjects dir: {sdir}\")\n", "\n", diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 8cb37cc..d649b9a 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -43,11 +43,10 @@ """ -from .utils.types import ViewType - from ._version import __version__ # noqa: F401 from .snap import snap1, snap4 from .utils._config import sys_info # noqa: F401 +from .utils.types import ViewType # 3D plotting for notebooks (Three.js-based, works in all Jupyter environments) try: diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 532b0bf..26e8fd4 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -22,15 +22,16 @@ """ import argparse +import logging import math import os import signal import threading -import logging import glfw import OpenGL.GL as gl import pyrr + try: from PyQt6.QtWidgets import QApplication except Exception: @@ -68,36 +69,49 @@ def show_window( curvname="curv", specular=True, ): - """ - Start an interactive window in which an overlay can be viewed. + """Start a live interactive OpenGL window for viewing a hemisphere. + + The function initializes a GLFW window and renders the requested + hemisphere with any provided overlay/annotation. It polls for + configuration updates from the separate configuration GUI and updates + the rendered scene accordingly. Parameters ---------- - hemi : str - Hemisphere; one of: ['lh', 'rh']. - overlaypath : str - Path to the overlay file for the specified hemi (FreeSurfer format). - annotpath : str - Path to the annotation file for the specified hemi (FreeSurfer format). - sdir : str - Subject dir containing surf files. - caption : str - Caption text to be placed on the image. - invert : bool - Invert color (blue positive, red negative). - labelname : str - Label for masking, usually cortex.label. - surfname : str - Surface to display values on, usually pial_semi_inflated from fsaverage. - curvname : str - Curvature file for texture in non-colored regions (default curv). - specular : bool, optional - If True, enable specular. + hemi : {'lh','rh'} + Hemisphere to display. + overlaypath : str or None, optional + Path to a per-vertex overlay file (e.g. thickness). If ``None`` no + overlay will be applied. + annotpath : str or None, optional + Path to a .annot file providing categorical labels for vertices. + sdir : str or None, optional + Subject directory containing `surf/` and `label/` subdirectories. + caption : str or None, optional + Caption text to display in the viewer window. + invert : bool, optional, default False + Invert the overlay color mapping. + labelname : str, optional, default 'cortex.label' + Label filename used to mask vertices. + surfname : str or None, optional + Surface basename (e.g. 'white'); if ``None`` the function will try + to auto-detect a suitable surface in ``sdir``. + curvname : str or None, optional, default 'curv' + Curvature filename used to texture non-colored regions. + specular : bool, optional, default True + Enable specular highlights in the shader. Returns ------- - None - This function does not return any value. + bool + ``False`` if the window/context could not be created, ``None`` on + normal termination. The function primarily drives an interactive + event loop and does not return programmatic geometry objects. + + Raises + ------ + FileNotFoundError + If a requested surface file cannot be located in ``sdir``. """ global current_fthresh_, current_fmax_, app_, app_window_, app_window_closed_ @@ -190,11 +204,47 @@ def show_window( def config_app_exit_handler(): + """Mark the configuration application as closed. + + This handler is connected to the configuration app's about-to-quit + signal and sets a module-level flag that the main OpenGL loop polls to + terminate cleanly. + + Returns + ------- + None + """ global app_window_closed_ app_window_closed_ = True def run(): + """Command-line entry point for the WhipperSnapPy snapshot/interactive tool. + + Parses command-line arguments, validates argument combinations, and + either launches a non-interactive snapshot generation (``snap4``) or + starts the interactive viewer and configuration GUI. + + Behavior + -------- + - Validates that either overlay or annotation inputs are provided for + both hemispheres (or raises ``ValueError``). + - In non-interactive mode calls :func:`whippersnappy.snap4` to produce + and optionally save a composed image. + - In interactive mode spawns the OpenGL viewer thread and launches the + PyQt6-based configuration window in the main thread. + + Returns + ------- + None + + Raises + ------ + ValueError + For invalid or mutually exclusive argument combinations. + ImportError + If interactive mode is requested but PyQt6 is not available. + """ global current_fthresh_, current_fmax_, app_, app_window_ # Configure basic logging for CLI invocation so messages from module loggers # are visible to end users. Avoid configuring on import by doing this here. diff --git a/whippersnappy/commands/sys_info.py b/whippersnappy/commands/sys_info.py index c1607cf..b000cfa 100644 --- a/whippersnappy/commands/sys_info.py +++ b/whippersnappy/commands/sys_info.py @@ -4,7 +4,21 @@ def run(): - """Run sys_info() command.""" + """Run the sys_info command-line helper. + + Parses CLI arguments and delegates to the package-level `sys_info` + function which prints system and dependency information. + + Parameters + ---------- + None + + Returns + ------- + None + This function prints information to stdout via the `sys_info` + helper and does not return a value. + """ parser = argparse.ArgumentParser( prog=f"{__package__.split('.')[0]}-sys_info", description="sys_info" ) diff --git a/whippersnappy/geometry/prepare.py b/whippersnappy/geometry/prepare.py index 6035151..8bd63f9 100644 --- a/whippersnappy/geometry/prepare.py +++ b/whippersnappy/geometry/prepare.py @@ -15,7 +15,24 @@ def normalize_mesh(v, scale=1.0): - """Normalize mesh vertex coordinates.""" + """Center and scale mesh vertex coordinates to a unit cube. + + The function recenters the vertices around the origin and scales them so + that the maximum extent fits into a unit cube, optionally applying an + additional scale factor. + + Parameters + ---------- + v : numpy.ndarray + Vertex coordinate array of shape (n_vertices, 3). + scale : float, optional + Additional multiplicative scale applied after normalization. + + Returns + ------- + numpy.ndarray + Normalized vertex coordinates with same shape as ``v``. + """ bbmax = np.max(v, axis=0) bbmin = np.min(v, axis=0) v = v - 0.5 * (bbmax + bbmin) @@ -24,7 +41,20 @@ def normalize_mesh(v, scale=1.0): def vertex_normals(v, t): - """Compute vertex normals.""" + """Compute per-vertex normals from triangle connectivity. + + Parameters + ---------- + v : numpy.ndarray + Vertex coordinates (n_vertices, 3). + t : numpy.ndarray + Triangle indices (n_faces, 3). + + Returns + ------- + numpy.ndarray + Per-vertex unit normals (n_vertices, 3). + """ v0 = v[t[:, 0], :] v1 = v[t[:, 1], :] v2 = v[t[:, 2], :] @@ -56,7 +86,51 @@ def prepare_geometry( scale=1.85, color_mode=ColorSelection.BOTH, ): - """Prepare meshdata for upload to GPU.""" + """Prepare vertex and color arrays for GPU upload. + + This function loads a surface geometry from ``surfpath``, optionally + loads an overlay (mgh/curv) or annotation (.annot) and produces an + interleaved vertex array containing positions, normals and colors + suitable for uploading to OpenGL (vertex buffer objects). + + Parameters + ---------- + surfpath : str + Path to the surface file. + overlaypath : str or None, optional + Path to an overlay (mgh/curv) file providing per-vertex scalar + values used for coloring. + annotpath : str or None, optional + Path to a FreeSurfer .annot file for categorical labeling. + curvpath : str or None, optional + Path to curvature data used as fallback texture. + labelpath : str or None, optional + Path to a label file used to mask vertices. + minval, maxval : float or None, optional + Threshold and saturation values for overlay scaling. + invert : bool, optional, default False + Invert color mapping. + scale : float, optional, default 1.85 + Geometry scaling factor applied by ``normalize_mesh``. + color_mode : ColorSelection, optional, default ColorSelection.BOTH + Which sign(s) of overlay values to use for coloring. + + Returns + ------- + vertexdata : numpy.ndarray + Nx9 array (position x3, normal x3, color x3) ready for GPU upload. + triangles : numpy.ndarray + Mx3 uint32 triangle index array. + fmin, fmax : float or None + Final threshold and saturation values used for color mapping. + pos, neg : bool or None + Flags indicating whether positive/negative overlay values are present. + + Raises + ------ + ValueError + If overlay or annotation arrays do not match the surface vertex count. + """ surf = read_geometry(surfpath, read_metadata=False) vertices = normalize_mesh(np.array(surf[0], dtype=np.float32), scale) triangles = np.array(surf[1], dtype=np.uint32) @@ -68,7 +142,8 @@ def prepare_geometry( if curvpath: curv = read_morph_data(curvpath) if curv.shape[0] != num_vertices: - warnings.warn(f"Curvature file {curvpath} has {curv.shape[0]} values, but mesh has {num_vertices}.") + warnings.warn(f"Curvature file {curvpath} has {curv.shape[0]} values, but mesh has {num_vertices}.", + stacklevel=2) else: sulcmap = binary_color(curv, 0.0, color_low=0.5, color_high=0.33) diff --git a/whippersnappy/geometry/surf_name.py b/whippersnappy/geometry/surf_name.py index 69e517e..1d1f07d 100644 --- a/whippersnappy/geometry/surf_name.py +++ b/whippersnappy/geometry/surf_name.py @@ -7,13 +7,26 @@ def get_surf_name(sdir, hemi): - """Find a valid surface file in the specified subject directory. + """Find a suitable surface basename in a subject directory. - Returns the surface basename (e.g. 'white', 'inflated', etc.) or None. + The function searches the standard FreeSurfer `surf/` directory for a + common surface name in order of preference and returns the basename + (search for 'pial_semi_inflated', 'white', and then 'inflated'). + + Parameters + ---------- + sdir : str + Path to the subject directory containing a `surf/` subdirectory. + hemi : {'lh','rh'} + Hemisphere prefix to use when searching for surface files. + + Returns + ------- + surf_name : str or None + The surface basename if found, otherwise ``None``. """ for surf_name_option in ["pial_semi_inflated", "white", "inflated"]: path = os.path.join(sdir, "surf", f"{hemi}.{surf_name_option}") if os.path.exists(path): return surf_name_option return None - diff --git a/whippersnappy/gl/camera.py b/whippersnappy/gl/camera.py index 089dd52..bcfa376 100644 --- a/whippersnappy/gl/camera.py +++ b/whippersnappy/gl/camera.py @@ -4,26 +4,73 @@ def make_projection(width, height, fov=20.0, near=0.1, far=100.0): - """Create a perspective projection matrix.""" + """Create a 4x4 perspective projection matrix. + + Parameters + ---------- + width, height : int + Viewport dimensions in pixels (used to compute aspect ratio). + fov : float, optional, default 20.0 + Vertical field of view in degrees. + near, far : float, optional, default 0.1, 100.0 + Near and far clipping planes. Default are 0.1 and 100.0, respectively. + + Returns + ------- + numpy.ndarray + 4x4 projection matrix. + """ return pyrr.matrix44.create_perspective_projection(fov, width / height, near, far) def make_view(camera_pos=(0.0, 0.0, -5.0)): - """Create a view matrix from a camera position.""" + """Create a view matrix for a camera located at ``camera_pos``. + + Parameters + ---------- + camera_pos : sequence of float, optional, default (0.0, 0.0, -5.0) + 3-element position of the camera in world space. + Default is (0.0, 0.0, -5.0). + + Returns + ------- + numpy.ndarray + 4x4 view matrix. + """ return pyrr.matrix44.create_from_translation(pyrr.Vector3(camera_pos)) def make_model(): - """Create a default model matrix.""" + """Create a default model matrix (identity translation). + + Returns + ------- + numpy.ndarray + 4x4 model matrix. + """ return pyrr.matrix44.create_from_translation(pyrr.Vector3([0.0, 0.0, 0.0])) def make_transform(translation, rotation, scale): - """Create a model transform from translation, rotation and scale.""" + """Build a model transform matrix from translation, rotation and uniform scale. + + Parameters + ---------- + translation : sequence of float + 3-element translation vector. + rotation : numpy.ndarray + 4x4 rotation matrix. + scale : float + Uniform scaling factor. + + Returns + ------- + numpy.ndarray + 4x4 transformation matrix (translation * rotation * scale). + """ scale_matrix = pyrr.matrix44.create_from_scale([scale, scale, scale]) return ( pyrr.matrix44.create_from_translation(translation) * rotation * scale_matrix ) - diff --git a/whippersnappy/gl/shaders.py b/whippersnappy/gl/shaders.py index acdc9ed..c6c237a 100644 --- a/whippersnappy/gl/shaders.py +++ b/whippersnappy/gl/shaders.py @@ -1,7 +1,18 @@ """Shared shader sources inside the gl package.""" def get_default_shaders(): - """Return the default vertex and fragment shader sources (GLSL 330).""" + """Return the default GLSL 330 vertex and fragment shader sources. + + These shaders are intended for desktop OpenGL (GLSL 330) use in the + offline snapshot renderer. The returned strings contain full GLSL + shader sources for vertex and fragment stages. + + Returns + ------- + vertex_shader, fragment_shader : tuple[str, str] + Tuple containing the vertex shader and fragment shader source code + as plain strings. + """ vertex_shader = """ #version 330 @@ -106,7 +117,18 @@ def get_default_shaders(): def get_webgl_shaders(): - """Return the default vertex and fragment shader sources (GLSL 330).""" + """Return vertex and fragment shader source strings suitable for WebGL/Three.js. + + These shader snippets are small GLSL pieces that expect Three.js to + provide built-in attributes/uniforms (e.g. projectionMatrix, + modelViewMatrix, normalMatrix). They are used by the Jupyter + pythreejs-based viewer. + + Returns + ------- + vertex_shader, fragment_shader : tuple[str, str] + Vertex and fragment shader source strings for WebGL / Three.js. + """ # Only declare custom attributes - Three.js provides position, normal, matrices # Don't declare position, normal, *Matrix @@ -191,4 +213,3 @@ def get_webgl_shaders(): return vertex_shader, fragment_shader - diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index 1aa3b58..bfcf0c1 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -15,14 +15,33 @@ def create_vao(): - """Create and bind a VAO, returning its handle.""" + """Create and bind a Vertex Array Object (VAO). + + Returns + ------- + int + OpenGL handle for the created VAO. + """ vao = gl.glGenVertexArrays(1) gl.glBindVertexArray(vao) return vao def compile_shader_program(vertex_src, fragment_src): - """Compile and link a shader program.""" + """Compile GLSL vertex and fragment sources and link them into a program. + + Parameters + ---------- + vertex_src : str + Vertex shader source code. + fragment_src : str + Fragment shader source code. + + Returns + ------- + int + OpenGL program handle. + """ return gl.shaders.compileProgram( shaders.compileShader(vertex_src, gl.GL_VERTEX_SHADER), shaders.compileShader(fragment_src, gl.GL_FRAGMENT_SHADER), @@ -30,7 +49,20 @@ def compile_shader_program(vertex_src, fragment_src): def setup_buffers(meshdata, triangles): - """Create VBO/EBO and upload mesh data.""" + """Create and upload vertex and element buffers for the mesh. + + Parameters + ---------- + meshdata : numpy.ndarray + Vertex array with interleaved attributes (position, normal, color). + triangles : numpy.ndarray + Face index array. + + Returns + ------- + (vbo, ebo) : tuple + OpenGL buffer handles for the VBO and EBO. + """ vbo = gl.glGenBuffers(1) gl.glBindBuffer(gl.GL_ARRAY_BUFFER, vbo) gl.glBufferData(gl.GL_ARRAY_BUFFER, meshdata.nbytes, meshdata, gl.GL_STATIC_DRAW) @@ -45,7 +77,13 @@ def setup_buffers(meshdata, triangles): def setup_vertex_attributes(shader): - """Configure vertex attribute pointers for position, normal, color.""" + """Configure vertex attribute pointers for position, normal and color. + + Parameters + ---------- + shader : int + OpenGL shader program handle used to query attribute locations. + """ position = gl.glGetAttribLocation(shader, "aPos") gl.glVertexAttribPointer( position, 3, gl.GL_FLOAT, gl.GL_FALSE, 9 * 4, gl.ctypes.c_void_p(0) @@ -66,13 +104,24 @@ def setup_vertex_attributes(shader): def set_default_gl_state(): - """Apply common GL state for rendering.""" + """Set frequently used default OpenGL state for rendering. + + This function enables depth testing and sets a default clear color. + """ gl.glClearColor(0.0, 0.0, 0.0, 1.0) gl.glEnable(gl.GL_DEPTH_TEST) def set_camera_uniforms(shader, view, projection, model): - """Set view/projection/model uniforms in the shader.""" + """Upload camera MVP (view, projection, model) matrices to the shader. + + Parameters + ---------- + shader : int + OpenGL shader program handle. + view, projection, model : array-like + 4x4 matrices to be uploaded to the corresponding shader uniforms. + """ view_loc = gl.glGetUniformLocation(shader, "view") proj_loc = gl.glGetUniformLocation(shader, "projection") model_loc = gl.glGetUniformLocation(shader, "model") @@ -82,7 +131,19 @@ def set_camera_uniforms(shader, view, projection, model): def set_lighting_uniforms(shader, specular=True, ambient=0.0, light_color=(1.0, 1.0, 1.0)): - """Set lighting uniforms in the shader.""" + """Set lighting-related uniforms (specular toggle, ambient, light color). + + Parameters + ---------- + shader : int + OpenGL shader program handle. + specular : bool, optional, default True + Enable specular highlights. + ambient : float, optional, default 0.0 + Ambient light strength. + light_color : tuple, optional, default (1.0, 1.0, 1.0) + RGB light color. + """ specular_loc = gl.glGetUniformLocation(shader, "doSpecular") gl.glUniform1i(specular_loc, specular) @@ -94,7 +155,22 @@ def set_lighting_uniforms(shader, specular=True, ambient=0.0, light_color=(1.0, def init_window(width, height, title="PyOpenGL", visible=True): - """Create an OpenGL window (GLFW) and make its context current.""" + """Create a GLFW window, make an OpenGL context current and return the window handle. + + Parameters + ---------- + width, height : int + Window dimensions in pixels. + title : str, optional, default 'PyOpenGL' + Window title. + visible : bool, optional, default True + If False create an invisible/offscreen window (useful for headless rendering). + + Returns + ------- + window or False + GLFW window handle on success, or False on failure. + """ if not glfw.init(): return False @@ -115,9 +191,28 @@ def init_window(width, height, title="PyOpenGL", visible=True): def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0): - """Create vertex and fragment shaders, set up VAO/VBO/EBO, and initialize camera/lighting uniforms. - - This function composes several low-level helpers in this module and returns the compiled shader program handle. + """Create shader program, upload mesh and initialize camera & lighting. + + This is a convenience wrapper that compiles default shaders, creates + VAO/VBO/EBO, and configures common uniforms (camera matrices, lighting). + + Parameters + ---------- + meshdata : numpy.ndarray + Interleaved vertex data. + triangles : numpy.ndarray + Triangle indices. + width, height : int + Framebuffer size used to compute projection matrix. + specular : bool, optional, default True + Enable specular highlights. + ambient : float, optional, default 0.0 + Ambient lighting strength. + + Returns + ------- + shader : int + Compiled OpenGL shader program handle. """ vertex_shader, fragment_shader = get_default_shaders() @@ -139,9 +234,20 @@ def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0) def capture_window(width, height): - """Capture the current GL framebuffer region into a PIL Image (RGB). + """Read the current GL framebuffer and return it as a PIL.Image (RGB). + + On macOS (retina) this function reads at double resolution and downscales + the result to compensate for pixel-density differences. - On macOS we adjust for the retina scaling factor by reading at double resolution and downsampling. + Parameters + ---------- + width, height : int + Desired output image dimensions. + + Returns + ------- + PIL.Image.Image + RGB image containing the captured framebuffer content. """ if sys.platform == "darwin": rwidth = 2 * width @@ -157,18 +263,3 @@ def capture_window(width, height): if sys.platform == "darwin": image.thumbnail((0.5 * rwidth, 0.5 * rheight), Image.Resampling.LANCZOS) return image - - -__all__ = [ - "create_vao", - "compile_shader_program", - "setup_buffers", - "setup_vertex_attributes", - "set_default_gl_state", - "set_camera_uniforms", - "set_lighting_uniforms", - "init_window", - "setup_shader", - "capture_window", -] - diff --git a/whippersnappy/gl/views.py b/whippersnappy/gl/views.py index 477f655..5a04b51 100644 --- a/whippersnappy/gl/views.py +++ b/whippersnappy/gl/views.py @@ -6,7 +6,17 @@ def get_view_matrices(): - """Return canonical view matrices for left/right/front/back/top/bottom.""" + """Return canonical 4x4 view matrices for common brain orientations. + + The returned dictionary maps :class:`whippersnappy.utils.types.ViewType` + enum members to corresponding 4x4 view matrices (dtype float32) that + can be used as camera/view transforms in the OpenGL renderer. + + Returns + ------- + dict + Mapping of :class:`ViewType` -> 4x4 numpy.ndarray view matrix. + """ view_left = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) view_right = np.array([[0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) view_back = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) @@ -25,5 +35,16 @@ def get_view_matrices(): def get_view_matrix(view_type): - """Return a view matrix for a single view type.""" + """Return the 4x4 view matrix for a single :class:`ViewType`. + + Parameters + ---------- + view_type : ViewType + Enum member indicating the requested view. + + Returns + ------- + numpy.ndarray + 4x4 float32 view matrix. + """ return get_view_matrices()[view_type] diff --git a/whippersnappy/gui/config_app.py b/whippersnappy/gui/config_app.py index e824a79..95d5571 100644 --- a/whippersnappy/gui/config_app.py +++ b/whippersnappy/gui/config_app.py @@ -23,20 +23,24 @@ class ConfigWindow(QWidget): - """ - Encapsulates the Qt widget for the parameter configuration. + """Qt configuration window for interactive parameter tuning. + + The configuration window exposes sliders and text boxes to adjust the + f-threshold and f-max parameters used by the renderer. The widget is + intended to run alongside the OpenGL window and to push updated values + to the renderer via polling from the main loop. Parameters ---------- - parent : QWidget - This widget's parent, if any (usually none). - screen_dims : tuple - Integers specifying screen dims in pixels; used to always position - the window in the top-right corner, if given. - initial_fthresh_value : float - Initial fthreshold value is 2.0. - initial_fmax_value : float - Initial fmax value is 4.0. + parent : QWidget, optional + Parent Qt widget. Defaults to ``None``. + screen_dims : tuple or None, optional + (width, height) of the available screen; used to position the + window in the top-right corner when provided. + initial_fthresh_value : float, optional + Initial threshold value (default 2.0). + initial_fmax_value : float, optional + Initial fmax value (default 4.0). """ def __init__( @@ -153,11 +157,15 @@ def __init__( self.setGeometry(0, 0, self.window_size[0], self.window_size[1]) def fthresh_slider_value_cb(self): - """ - Callback function for user-modified fthresh slider. + """Handle changes from the f-threshold slider. - This function is triggered when the user modifies the fthresh slider. It - stores the selected value and updates the corresponding user input box. + This slot is connected to the slider's valueChanged signal. It maps the + slider tick value into the configured value range and updates the + text input box accordingly. + + Returns + ------- + None """ self.current_fthresh_value = self.convert_value_to_range( self.fthresh_slider.value(), @@ -167,22 +175,16 @@ def fthresh_slider_value_cb(self): self.fthresh_value_box.setText(str(self.current_fthresh_value)) def fthresh_value_cb(self, new_value): - """ - Callback function for user input of fthresh value. - - This function is triggered when the user inputs a value for fthresh. It - stores the selected value and updates the corresponding slider. + """Handle text input changes for f-threshold. Parameters ---------- new_value : float or str - The new value input by the user. It can be a float or a string that - can be converted to a float. + The new value input by the user. May be a float or numeric string. Returns ------- None - This function does not return any value. """ # Do not react to invalid values: try: @@ -200,11 +202,11 @@ def fthresh_value_cb(self, new_value): self.fthresh_slider.setValue(int(slider_fthresh_value)) def fmax_slider_value_cb(self): - """ - Callback function for user-modified fmax slider. + """Handle changes from the f-max slider and update the text box. - This function is triggered when the user modifies the fmax slider. It - stores the selected value and updates the corresponding user input box. + Returns + ------- + None """ self.current_fmax_value = self.convert_value_to_range( self.fmax_slider.value(), @@ -214,22 +216,16 @@ def fmax_slider_value_cb(self): self.fmax_value_box.setText(str(self.current_fmax_value)) def fmax_value_cb(self, new_value): - """ - Callback function for user input of fmax value. - - This function is triggered when the user inputs a value for fmax. It - stores the selected value and updates the corresponding slider. + """Handle text input changes for f-max. Parameters ---------- new_value : float or str - The new value input by the user. It can be a float or a string that - can be converted to a float. + New value provided by the user. Returns ------- - None - This function does not return any value. + None """ # Do not react to invalid values: try: @@ -247,25 +243,21 @@ def fmax_value_cb(self, new_value): self.fmax_slider.setValue(int(slider_fmax_value)) def convert_value_to_range(self, value, old_limits, new_limits): - """ - Convert a given number from one range to another. - - This is useful for transforming values from the original range to that - of the slider widget tick range and vice-versa. + """Map ``value`` from ``old_limits`` to ``new_limits``. Parameters ---------- value : float Value to be converted. old_limits : tuple - Minimum and maximum values that define the source range. + (min, max) source range. new_limits : tuple - Minimum and maximum values that define the target range. + (min, max) target range. Returns ------- - new_value : float - Converted value. + float + Value mapped into ``new_limits``. """ old_range = old_limits[1] - old_limits[0] new_range = new_limits[1] - new_limits[0] @@ -274,40 +266,38 @@ def convert_value_to_range(self, value, old_limits, new_limits): return new_value def get_fthresh_value(self): - """ - Return the current stores value for fthresh. + """Return the currently selected f-threshold value. Returns ------- - current_fthresh_value: float - Current fthresh value. + float + Current f-threshold value. """ return self.current_fthresh_value def get_fmax_value(self): - """ - Return the current stores value for fmax. + """Return the currently selected f-max value. Returns ------- - current_fmax_value : float - Current fmax value. + float + Current f-max value. """ return self.current_fmax_value def keyPressEvent(self, event): - """ - Close the window when the ESC key is pressed. + """Handle key press events for the window. + + The handler closes the window when the ESC key is pressed. Parameters ---------- event : QKeyEvent - The key event object representing the key press. + Qt key event delivered by the framework. Returns ------- None - This function return None. """ if event.key() == Qt.Key.Escape: self.close() diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py index 90d9b96..ee780b7 100644 --- a/whippersnappy/plot3d.py +++ b/whippersnappy/plot3d.py @@ -19,6 +19,7 @@ """ import logging + import numpy as np import pythreejs as p3js from ipywidgets import HTML, VBox @@ -48,50 +49,54 @@ def plot3d( ): """Create an interactive 3D notebook viewer using pythreejs (Three.js). - This creates a browser-based interactive 3D viewer for Jupyter notebooks. - Works in all Jupyter environments (browser, JupyterLab, Colab, VS Code). - - Note: This is different from the desktop GUI (launched with --interactive flag). + This function prepares geometry and color information (via + :func:`whippersnappy.geometry.prepare_geometry`) and constructs a + pythreejs renderer and controls wrapped in an ``ipywidgets.VBox`` for + display inside a Jupyter notebook. Parameters ---------- meshpath : str - Path to surface file - overlaypath : str, optional - Path to overlay file - annotpath : str, optional - Path to annotation file - curvpath : str, optional - Path to curvature file - labelpath : str, optional - Path to label file - minval : float, optional - Minimum threshold for coloring - maxval : float, optional - Maximum value for color saturation - invert : bool, default False - Invert color map - scale : float, default 1.85 - Global scaling factor - color_mode : ColorSelection, optional - Select which values to color - width : int, default 800 - Canvas width - height : int, default 800 - Canvas height + Path to the surface file (FreeSurfer-style surface, e.g. "lh.white"). + overlaypath : str or None, optional + Path to a per-vertex overlay (thickness/curvature) file. + annotpath : str or None, optional + Path to a FreeSurfer .annot file for categorical labeling. + curvpath : str or None, optional + Path to a curvature file used as grayscale texture for unlabeled regions. + labelpath : str or None, optional + Path to a label file used to mask out vertices. + minval, maxval : float or None, optional + Threshold and saturation values used for color mapping (passed to + :func:`prepare_geometry`). If ``None``, sensible defaults are chosen. + invert : bool, optional, default False + If True, invert the overlay color map. + scale : float, optional, default 1.85 + Global geometry scale applied during preparation. + color_mode : ColorSelection or None, optional + Which sign of overlay values to color (BOTH/POSITIVE/NEGATIVE). + If None, defaults to ``ColorSelection.BOTH``. + width, height : int, optional, default 800 + Canvas dimensions for the generated renderer. Returns ------- - viewer : ipywidgets.VBox - Interactive 3D viewer widget + ipywidgets.VBox + A widget containing the pythreejs Renderer and a small info panel. + + Raises + ------ + ValueError, FileNotFoundError + Errors originating from :func:`prepare_geometry` (for example when + input arrays don't match the mesh vertex count) are propagated. - Examples - -------- - In a notebook: + Example + ------- + In a Jupyter notebook:: from whippersnappy import plot3d from IPython.display import display - viewer = plot3d('path/to/lh.white', curvpath='path/to/lh.curv') + viewer = plot3d('fsaverage/surf/lh.white', overlaypath='fsaverage/surf/lh.thickness') display(viewer) """ @@ -170,7 +175,28 @@ def plot3d( def create_threejs_mesh_with_custom_shaders(vertices, faces, colors, normals): - """Custom lighting shader - fixed for Three.js.""" + """Create a pythreejs.Mesh using custom shader material and buffers. + + The function builds a BufferGeometry with position, color and normal + attributes, attaches an index buffer, and creates a ShaderMaterial + using the WebGL shader snippets returned by :func:`get_webgl_shaders`. + + Parameters + ---------- + vertices : numpy.ndarray + Array of shape (N, 3) with vertex positions (float32). + faces : numpy.ndarray + Integer face index array shape (M, 3) or flattened (3*M,) dtype uint32. + colors : numpy.ndarray + Array of shape (N, 3) with per-vertex RGB colors (float32). + normals : numpy.ndarray + Array of shape (N, 3) with per-vertex normals (float32). + + Returns + ------- + pythreejs.Mesh + Mesh object ready to be inserted into a pythreejs.Scene. + """ vertices = vertices.astype(np.float32) colors = colors.astype(np.float32) diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index 825ed3c..0ac067b 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -2,8 +2,8 @@ """ -import os import logging +import os import glfw import numpy as np @@ -51,7 +51,84 @@ def snap1( brain_scale=1.5, ambient=0.0, ): - """Snap one view (view and hemisphere is determined by the user).""" + """Render a single static snapshot of a surface view. + + This function opens an (offscreen) OpenGL context, uploads the provided + surface geometry and colors (overlay or annotation), renders the scene + for a single view, captures the framebuffer, and returns a PIL Image + containing the rendered brain view. When ``outpath`` is provided the + image is also written to disk. + + Parameters + ---------- + meshpath : str + Path to the surface file (FreeSurfer-format, e.g. "lh.white"). + outpath : str or None, optional + When provided, the resulting image is saved to this path. If ``None`` + the PIL Image object is returned. + overlaypath : str or None, optional + Path to overlay/mgh file providing per-vertex values to color the + surface. If ``None``, coloring falls back to curvature/annotation. + annotpath : str or None, optional + Path to a FreeSurfer .annot file with per-vertex labels. + labelpath : str or None, optional + Path to a label file (cortex.label) used to mask overlay values. + curvpath : str or None, optional + Path to curvature file used to texture non-colored regions. + view : ViewType, optional + Which pre-defined view to render (left, right, front, ...). Default is ``ViewType.LEFT``. + viewmat : 4x4 matrix-like, optional + Optional view matrix to override the pre-defined view. + width, height : int or None, optional + Requested overall canvas width/height in pixels. If ``None`` defaults + are used (700x500 reference). + fthresh, fmax : float or None, optional + Threshold and saturation values for overlay coloring. + caption, caption_x, caption_y, caption_scale : str/float, optional + Caption text and layout parameters. Caption defaults to ``None`` and caption_scale defaults to 1. + invert : bool, optional + Invert the color scale. Default is ``False``. + colorbar : bool, optional + If True, render a colorbar when an overlay is present. Default is ``True``. + colorbar_x, colorbar_y, colorbar_scale : float, optional + Colorbar positioning and scale flags. Scale defaults to 1. + orientation : OrientationType, optional + Orientation of the colorbar (HORIZONTAL/VERTICAL). Default is ``OrientationType.HORIZONTAL``. + color_mode : ColorSelection, optional + Which sign of overlay to color (POSITIVE/NEGATIVE/BOTH). Default is ``ColorSelection.BOTH``. + font_file : str or None, optional + Path to a TTF font for captions; fallback to bundled font if None. + specular : bool, optional + Enable specular highlights. Default is ``True``. + brain_scale : float, optional + Scale factor applied when preparing the geometry. Default is ``1.5``. + ambient : float, optional + Ambient lighting strength for shader. Default is ``0.0``. + + Returns + ------- + PIL.Image.Image or None + If ``outpath`` is ``None`` the function returns a PIL Image object + containing the rendered snapshot. If ``outpath`` is provided the + image is saved to disk and ``None`` is returned. + + Raises + ------ + RuntimeError + If the OpenGL/GLFW context cannot be initialized. + ValueError + If the overlay contains no values to display for the chosen + color_mode. + FileNotFoundError + If required surface files cannot be found when deriving from + SUBJECTS_DIR in multi-view helpers. + + Example + ------- + >>> from whippersnappy import snap1 + >>> img = snap1('fsaverage/surf/lh.white', overlaypath='fsaverage/surf/lh.thickness') + >>> img.save('/tmp/lh.png') + """ ref_width = 700 ref_height = 500 wwidth = ref_width if width is None else width @@ -130,7 +207,8 @@ def snap1( bar = None bar_w = bar_h = 0 if overlaypath is not None and colorbar: - bar = create_colorbar(fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, font_file=font_file) + bar = create_colorbar(fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, + font_file=font_file) bar_w, bar_h = bar.size font = None @@ -215,7 +293,74 @@ def snap4( ambient=0.0, brain_scale=1.85, ): - """Snap four views (front and back for left and right hemispheres).""" + """Render four snapshot views (left/right hemispheres, front/back). + + This convenience function renders four views (top/bottom for each + hemisphere), stitches them together into a single PIL Image and returns + it (or saves it to ``outpath`` when provided). It is typically used to + produce publication-ready overview figures composed from both + hemispheres. + + Parameters + ---------- + lhoverlaypath, rhoverlaypath : str or None + Paths to left/right hemisphere overlay files (mutually required if + either is provided). + lhannotpath, rhannotpath : str or None + Paths to left/right hemisphere annotation (.annot) files. + fthresh, fmax : float or None + Threshold and saturation for overlay coloring. + sdir : str or None + Subject directory (used when surfname is not provided). If not + supplied the environment variable ``SUBJECTS_DIR`` is consulted. + caption : str or None + Caption string to place on the final image. + invert : bool, optional + Invert color scale. Default is ``False``. + labelname : str, optional + Name of the label file (default 'cortex.label'). + surfname : str or None, optional + Surface basename to load (if None the function will auto-discover a + suitable surface). + curvname : str or None, optional + Curvature file basename to load for texturing non-colored regions. Default is ``curv``. + colorbar : bool, optional + Whether to draw a colorbar on the composed image. Default is ``True``. + outpath : str or None, optional + If provided, save composed image to this path and return ``None``. + font_file : str or None, optional + Path to a font to use for captions. + specular : bool, optional + Enable/disable specular highlights in the renderer. Default is ``True``. + ambient : float, optional + Ambient lighting strength. Default is ``0``. + brain_scale : float, optional + Scaling factor passed to geometry preparation. Default is ``1.85``. + + Returns + ------- + PIL.Image.Image or None + Composed image of the four views, or ``None`` if ``outpath`` was + provided and the image was written to disk. + + Raises + ------ + ValueError + For invalid argument combinations or when required overlay values + are absent. + FileNotFoundError + When required surface files are not found. + + Example + ------- + >>> from whippersnappy import snap4 + >>> img = snap4( + >>> lhoverlaypath='fsaverage/surf/lh.thickness', + >>> rhoverlaypath='fsaverage/surf/rh.thickness', + >>> sdir='./fsaverage' + >>> ) + >>> img.save('/tmp/whippersnappy_overview.png') + """ wwidth = 540 wheight = 450 # Try to create a visible window first (better for debugging), @@ -282,7 +427,8 @@ def snap4( # Diagnostics about mesh data try: - logger.debug("meshdata shape: %s; triangles count: %s", getattr(meshdata, 'shape', None), getattr(triangles, 'size', None)) + logger.debug("meshdata shape: %s; triangles count: %s", getattr(meshdata, 'shape', None), + getattr(triangles, 'size', None)) except Exception: pass diff --git a/whippersnappy/utils/colormap.py b/whippersnappy/utils/colormap.py index 26a6100..292ee76 100644 --- a/whippersnappy/utils/colormap.py +++ b/whippersnappy/utils/colormap.py @@ -1,17 +1,35 @@ """Colormap and value preprocessing utilities.""" import logging + import numpy as np from whippersnappy.utils.types import ColorSelection - # Module logger logger = logging.getLogger(__name__) def heat_color(values, invert=False): - """Convert an array of float values into RBG heat color values.""" + """Convert an array of float values into RGB heat color values. + + Maps scalar values to RGB triplets suitable for visualization. Input + values are expected to be in a symmetric range around zero; mapping + produces blue-to-red heat colors. NaN inputs propagate to NaN outputs. + + Parameters + ---------- + values : array_like + 1-D array of float values to map. May include NaNs. + invert : bool, optional + If True, invert the sign of the input values before mapping. + Default is False. + + Returns + ------- + numpy.ndarray + Array of shape (N, 3) and dtype float32 with RGB channels in [0, 1]. + """ if invert: values = -1.0 * values vabs = np.abs(values) @@ -37,7 +55,21 @@ def heat_color(values, invert=False): def mask_sign(values, color_mode): - """Mask values that don't have the same sign as color_mode.""" + """Mask values that don't match the requested sign selection. + + Parameters + ---------- + values : array_like + Input numeric array. + color_mode : ColorSelection + Enum indicating which sign to preserve (POSITIVE, NEGATIVE, BOTH). + + Returns + ------- + numpy.ndarray + Copy of ``values`` where elements not matching the requested sign + are set to ``np.nan``. + """ masked_values = np.copy(values) if color_mode == ColorSelection.POSITIVE: masked_values[masked_values < 0] = np.nan @@ -46,8 +78,33 @@ def mask_sign(values, color_mode): return masked_values -def rescale_overlay(values, minval=None, maxval=None): - """Rescale values for color map computation.""" +def rescale_overlay(values, minval, maxval): + """Rescale overlay values into a normalized range for colormap computation. + + Values whose absolute magnitude is below ``minval`` are set to ``NaN``. + Remaining values are shifted by ``minval`` and divided by ``(maxval - minval)``. + + Parameters + ---------- + values : numpy.ndarray + Numeric array of overlay values (1-D). + minval : float + Minimum absolute threshold — values with abs < minval are treated as absent. + maxval : float + Maximum absolute value used for normalization. + + Returns + ------- + tuple + ``(values, minval, maxval, pos, neg)`` where ``values`` is the rescaled + array, and ``pos``/``neg`` are booleans indicating presence of positive + / negative values after rescaling. + + Raises + ------ + ValueError + If ``minval`` or ``maxval`` is negative. + """ valsign = np.sign(values) valabs = np.abs(values) @@ -70,7 +127,23 @@ def rescale_overlay(values, minval=None, maxval=None): def binary_color(values, thres, color_low, color_high): - """Create a binary colormap based on a threshold value.""" + """Create a binary colormap for values based on a threshold. + + Parameters + ---------- + values : array_like + 1-D array of values to map. + thres : float + Threshold value used to split the colors. + color_low, color_high : scalar or sequence + Colors assigned to values below/above the threshold. Scalars are + expanded to RGB triplets. + + Returns + ------- + numpy.ndarray + Array of shape (N, 3) and dtype float32 containing RGB colors. + """ if np.isscalar(color_low): color_low = np.array((color_low, color_low, color_low), dtype=np.float32) if np.isscalar(color_high): @@ -82,7 +155,24 @@ def binary_color(values, thres, color_low, color_high): def mask_label(values, labelpath=None): - """Apply a label file as a mask.""" + """Apply a label file as a mask to an array of per-vertex values. + + If ``labelpath`` is provided the function loads vertex indices from the + label file and sets all entries not listed in the label to ``NaN``. + + Parameters + ---------- + values : numpy.ndarray + 1-D array indexed by vertex id. + labelpath : str or None, optional + Path to a label file readable by ``numpy.loadtxt`` (expected format + with vertex ids in the first column after two header lines). + + Returns + ------- + numpy.ndarray + Array with vertices not included in the label set to ``np.nan``. + """ if not labelpath: return values maskvids = np.loadtxt(labelpath, dtype=int, skiprows=2, usecols=[0]) diff --git a/whippersnappy/utils/image.py b/whippersnappy/utils/image.py index 52c2645..57b8499 100644 --- a/whippersnappy/utils/image.py +++ b/whippersnappy/utils/image.py @@ -17,7 +17,20 @@ def text_size(caption, font): - """Return text width and height in pixels.""" + """Return text width and height in pixels for a given caption and font. + + Parameters + ---------- + caption : str + Text to measure. + font : PIL.ImageFont.FreeTypeFont or similar + Font object used for measurement. + + Returns + ------- + (width, height) : tuple[int, int] + Pixel dimensions of rendered text. + """ dummy_img = Image.new("L", (1, 1)) draw = ImageDraw.Draw(dummy_img) bbox = draw.textbbox((0, 0), caption, font=font, anchor="lt") @@ -35,7 +48,28 @@ def get_colorbar_label_positions( neg=True, orientation=OrientationType.HORIZONTAL, ): - """Return label positions for a colorbar.""" + """Compute positions for colorbar label text. + + Parameters + ---------- + font : PIL.ImageFont + Font used to measure text sizes. + labels : dict + Mapping of label keys to text strings (e.g. 'upper','lower','middle'). + colorbar_rect : tuple + Rectangle for the colorbar (x, y, width, height). + gapspace : int, optional, default 0 + Additional spacing used for split colorbars. + pos, neg : bool, optional, default True, True + Whether positive/negative sides are present. + orientation : OrientationType, optional, default OrientationType.HORIZONTAL + Orientation of the colorbar. + + Returns + ------- + positions : dict + Mapping of label key -> (x, y) pixel position. + """ positions = {} cb_x, cb_y, cb_width, cb_height = colorbar_rect cb_labels_gap = 5 @@ -106,9 +140,32 @@ def create_colorbar( neg=True, font_file=None, ): - """Create a colorbar image (PIL.Image) using the project's heat_color. - - Parameters mirror the previous implementation in `render.py`. + """Create a colored colorbar as a PIL.Image. + + The colorbar visualizes the overlay color mapping (using + :func:`whippersnappy.utils.colormap.heat_color`) and optionally draws + numeric labels for the min/threshold/saturation positions. + + Parameters + ---------- + fmin, fmax : float + Threshold and saturation values used to label the colorbar. + invert : bool + Invert the heat color mapping. + orientation : OrientationType, optional, default OrientationType.HORIZONTAL + Orientation of the colorbar (HORIZONTAL/VERTICAL). + colorbar_scale : float, optional, default 1 + Scale factor for resulting image size. + pos, neg : bool, optional, default True, True + Whether the colorbar has positive/negative regions. + font_file : str or None, optional + Path to a TTF font file to use for labels. + + Returns + ------- + PIL.Image.Image or None + A PIL image containing the colorbar, or ``None`` if inputs are + insufficient (e.g. fmin/fmax are None). """ # If fmin/fmax are not specified, we cannot create a meaningful colorbar. if fmin is None or fmax is None: @@ -154,7 +211,8 @@ def create_colorbar( with resources.as_file(font_trav) as font_path: font = ImageFont.truetype(str(font_path), int(12 * colorbar_scale)) except Exception: - warnings.warn("Roboto font not found in package resources; falling back to default font", UserWarning) + warnings.warn("Roboto font not found in package resources; falling back to default font", + UserWarning, stacklevel=2) font = ImageFont.load_default() else: try: @@ -198,10 +256,18 @@ def create_colorbar( def load_roboto_font(size=14): - """Load bundled Roboto-Regular.ttf from package resources. - - Returns a PIL ImageFont instance. Falls back to ImageFont.load_default() - if the bundled font isn't available. + """Load the bundled Roboto font from package resources. + + Parameters + ---------- + size : int, optional + Requested point size. + + Returns + ------- + PIL.ImageFont.FreeTypeFont or PIL.ImageFont.ImageFont or None + A PIL font object; falls back to ``ImageFont.load_default()`` or + ``None`` if fonts cannot be loaded. """ try: # resources was imported earlier in this module @@ -209,7 +275,8 @@ def load_roboto_font(size=14): with resources.as_file(font_trav) as font_path: return ImageFont.truetype(str(font_path), size) except Exception: - warnings.warn("Roboto font not found in package resources; falling back to default font", UserWarning) + warnings.warn("Roboto font not found in package resources; falling back to default font", UserWarning, + stacklevel=2) try: return ImageFont.load_default() except Exception: From 297e812138db696319b46e4ef809f34c10f7d21d Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Mon, 16 Feb 2026 20:09:44 +0100 Subject: [PATCH 04/83] update testing --- tests/test_config.py | 2 +- whippersnappy/__init__.py | 2 +- whippersnappy/{utils => }/_config.py | 75 +++++++++++++++++++++++++--- whippersnappy/_version.py | 2 +- 4 files changed, 72 insertions(+), 9 deletions(-) rename whippersnappy/{utils => }/_config.py (55%) diff --git a/tests/test_config.py b/tests/test_config.py index 961f472..5210734 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,6 @@ from io import StringIO -from .._config import sys_info +from whippersnappy._config import sys_info def test_sys_info(): diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index d649b9a..83584e2 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -45,7 +45,7 @@ from ._version import __version__ # noqa: F401 from .snap import snap1, snap4 -from .utils._config import sys_info # noqa: F401 +from ._config import sys_info # noqa: F401 from .utils.types import ViewType # 3D plotting for notebooks (Three.js-based, works in all Jupyter environments) diff --git a/whippersnappy/utils/_config.py b/whippersnappy/_config.py similarity index 55% rename from whippersnappy/utils/_config.py rename to whippersnappy/_config.py index 2c2ca61..97a5ff6 100644 --- a/whippersnappy/utils/_config.py +++ b/whippersnappy/_config.py @@ -1,3 +1,7 @@ +"""Configuration and system-info helpers (top-level module). + +""" + import platform import re import sys @@ -41,10 +45,43 @@ def sys_info(fid: Optional[IO] = None, developer: bool = False): # dependencies out("\nDependencies info\n") - out(f"{package}:".ljust(ljust) + version(package) + "\n") - dependencies = [ - elt.split(";")[0].rstrip() for elt in requires(package) if "extra" not in elt - ] + # package version may not be available when running tests from the + # repository root (package not installed). Handle gracefully. + try: + pkg_version = version(package) + except Exception: + pkg_version = "Not installed." + out(f"{package}:".ljust(ljust) + pkg_version + "\n") + + try: + raw_requires = requires(package) or [] + except Exception: + raw_requires = [] + + # If package metadata is not present (e.g. running from source), try to + # read dependencies declared in pyproject.toml so tests running against + # the tree without installing still report expected deps. + if not raw_requires: + try: + from pathlib import Path + try: + import tomllib as _toml + except Exception: + _toml = None + + repo_root = Path(__file__).resolve().parents[1] + pyproject_path = repo_root / "pyproject.toml" + if _toml is not None and pyproject_path.exists(): + with pyproject_path.open("rb") as fh: + data = _toml.load(fh) + proj = data.get("project", {}) + deps = proj.get("dependencies", []) or [] + # dependencies may be in the form 'pkg>=1.2' etc. + raw_requires = deps + except Exception: + raw_requires = [] + + dependencies = [elt.split(";")[0].rstrip() for elt in raw_requires if "extra" not in elt] _list_dependencies_info(out, ljust, dependencies) # extras @@ -56,10 +93,36 @@ def sys_info(fid: Optional[IO] = None, developer: bool = False): "style", ) for key in keys: + try: + raw_requires = requires(package) or [] + except Exception: + raw_requires = [] + + # If package metadata missing, fall back to pyproject.toml optional-dependencies + if not raw_requires: + try: + from pathlib import Path + try: + import tomllib as _toml + except Exception: + _toml = None + + repo_root = Path(__file__).resolve().parents[1] + pyproject_path = repo_root / "pyproject.toml" + if _toml is not None and pyproject_path.exists(): + with pyproject_path.open("rb") as fh: + data = _toml.load(fh) + proj = data.get("project", {}) + opt = proj.get("optional-dependencies", {}) or {} + deps = opt.get(key, []) or [] + raw_requires = deps + except Exception: + raw_requires = [] + dependencies = [ elt.split(";")[0].rstrip() - for elt in requires(package) - if f"extra == '{key}'" in elt or f'extra == "{key}"' in elt + for elt in raw_requires + if f"extra == '{key}'" in elt or f"extra == \"{key}\"" in elt or True ] if len(dependencies) == 0: continue diff --git a/whippersnappy/_version.py b/whippersnappy/_version.py index bf95b14..3481ece 100644 --- a/whippersnappy/_version.py +++ b/whippersnappy/_version.py @@ -5,4 +5,4 @@ __version__ = version(__package__) except Exception: # Fallback when package is not installed (e.g., running from source) - __version__ = "1.4.0-dev" + __version__ = "dev" From 6479bc4bb185405bec90646281d8e35ed449eee2 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Mon, 16 Feb 2026 22:01:15 +0100 Subject: [PATCH 05/83] bump actions to newer python --- .codespellignore | 4 ---- .github/workflows/build.yml | 2 +- .github/workflows/code-style.yml | 14 +++++--------- .github/workflows/doc.yml | 18 ++++++++++-------- .github/workflows/pytest.yml | 2 +- pyproject.toml | 6 ++++++ 6 files changed, 23 insertions(+), 23 deletions(-) delete mode 100644 .codespellignore diff --git a/.codespellignore b/.codespellignore deleted file mode 100644 index 26c58f9..0000000 --- a/.codespellignore +++ /dev/null @@ -1,4 +0,0 @@ -coo -daty -anormal -wheight diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2b5c781..b493b94 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,7 +15,7 @@ jobs: fail-fast: false matrix: os: [ubuntu, macos, windows] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] name: ${{ matrix.os }} - py${{ matrix.python-version }} runs-on: ${{ matrix.os }}-latest defaults: diff --git a/.github/workflows/code-style.yml b/.github/workflows/code-style.yml index 889727e..ecc3e00 100644 --- a/.github/workflows/code-style.yml +++ b/.github/workflows/code-style.yml @@ -15,10 +15,11 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v6 - - name: Setup Python 3.10 + - name: Setup Python 3.13 uses: actions/setup-python@v6 with: - python-version: '3.10' + python-version: '3.13' + cache: 'pip' - name: Install dependencies run: | python -m pip install --progress-bar off --upgrade pip setuptools wheel @@ -27,10 +28,5 @@ jobs: run: ruff check . - name: Run codespell uses: codespell-project/actions-codespell@master - with: - check_filenames: true - check_hidden: true - skip: './.git,./build,./.mypy_cache,./.pytest_cache' - ignore_words_file: ./.codespellignore - # - name: Run pydocstyle - # run: pydocstyle . + - name: Run pydocstyle + run: pydocstyle . diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 2290615..f1e7b56 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -20,17 +20,17 @@ jobs: uses: actions/checkout@v6 with: path: ./main - - name: Setup Python 3.10 + - name: Setup Python 3.13 uses: actions/setup-python@v6 with: - python-version: '3.10' - - name: Install Linux dependencies - if: runner.os == 'Linux' + python-version: '3.13' + cache: pip + - name: Install system dependencies run: | - sudo apt update + # sudo apt update sudo apt-get update - sudo apt-get -y install libgl1 libegl1 - sudo apt install libxcb-cursor0 + sudo apt-get -y install libgl1 libegl1 libxcb-cursor0 pandoc + # sudo apt install libxcb-cursor0 - name: Install package run: | python -m pip install --progress-bar off --upgrade pip setuptools wheel @@ -43,7 +43,9 @@ jobs: uses: actions/upload-artifact@v6 with: name: doc-dev - path: ./doc-build/dev + path: | + doc-build/dev + !doc-build/dev/.doctrees deploy: if: github.event_name == 'push' diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index f427086..af2d59a 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: os: [ubuntu, macos, windows] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] name: ${{ matrix.os }} - py${{ matrix.python-version }} runs-on: ${{ matrix.os }}-latest defaults: diff --git a/pyproject.toml b/pyproject.toml index 751b3d6..c0f6c17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,6 +140,12 @@ select = [ [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] +[tool.codespell] +#ignore-words-list = 'coo,daty' +check-filenames = true +check-hidden = true +skip = './.git,./build,./.mypy_cache,./.pytest_cache' + [tool.pytest.ini_options] minversion = '6.0' addopts = '--durations 20 --junit-xml=junit-results.xml --verbose' From 0f41d593aed1603484880c259d2dba9fd8e56ac7 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Mon, 16 Feb 2026 23:42:12 +0100 Subject: [PATCH 06/83] fix input order --- whippersnappy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 83584e2..6e6e002 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -43,9 +43,9 @@ """ +from ._config import sys_info # noqa: F401 from ._version import __version__ # noqa: F401 from .snap import snap1, snap4 -from ._config import sys_info # noqa: F401 from .utils.types import ViewType # 3D plotting for notebooks (Three.js-based, works in all Jupyter environments) From ebd9b6278a2733581918b8fa37295d708e160782 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Mon, 16 Feb 2026 23:49:25 +0100 Subject: [PATCH 07/83] re-add ignore-world-list --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c0f6c17..3a4abdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,7 +141,7 @@ select = [ "__init__.py" = ["F401"] [tool.codespell] -#ignore-words-list = 'coo,daty' +ignore-words-list = 'aNormal,wheight' check-filenames = true check-hidden = true skip = './.git,./build,./.mypy_cache,./.pytest_cache' From ee0cc3f2551f5b8b5ca58590b7800b2dd02ec707 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 17 Feb 2026 00:06:44 +0100 Subject: [PATCH 08/83] update doc --- doc/api/index.rst | 5 +++-- doc/conf.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/doc/api/index.rst b/doc/api/index.rst index 2b668c7..3083ecc 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -8,7 +8,8 @@ API References :toctree: generated/ - config_app.ConfigWindow - core + snap + plot3d + gui.config_app.ConfigWindow cli.whippersnap diff --git a/doc/conf.py b/doc/conf.py index 1bc8980..c0e1c26 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -101,6 +101,26 @@ autodoc_warningiserror = True autoclass_content = "class" +# Mock modules that may not be available in the doc builder environment +# (PyQt6, OpenGL, GLFW, pythreejs, etc.). Adjust this list if your builder +# provides any of these packages. +autodoc_mock_imports = [ + "PyQt6", + "PyQt6.QtWidgets", + "PyQt6.QtCore", + "PyQt6.QtGui", + "glfw", + "OpenGL", + "OpenGL.GL", + "OpenGL.GL.shaders", + "pythreejs", + "ipywidgets", + "pyopengl", + "pyrr", + "PIL", + "matplotlib", +] + # -- intersphinx ------------------------------------------------------------- intersphinx_mapping = { "matplotlib": ("https://matplotlib.org/stable", None), From 36b859a432ff7fb39fcef76dd4d7b47ca2e7004c Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 17 Feb 2026 00:12:44 +0100 Subject: [PATCH 09/83] fix pydocstyle errors --- whippersnappy/_config.py | 6 +----- whippersnappy/plot3d.py | 2 -- whippersnappy/snap.py | 4 +--- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/whippersnappy/_config.py b/whippersnappy/_config.py index 97a5ff6..112450d 100644 --- a/whippersnappy/_config.py +++ b/whippersnappy/_config.py @@ -1,6 +1,4 @@ -"""Configuration and system-info helpers (top-level module). - -""" +"""Configuration and system-info helpers (top-level module).""" import platform import re @@ -23,7 +21,6 @@ def sys_info(fid: Optional[IO] = None, developer: bool = False): developer : bool, default=False If True, display information about optional dependencies. """ - ljust = 26 out = partial(print, end="", file=fid) package = __package__.split(".")[0] @@ -143,7 +140,6 @@ def _list_dependencies_info(out: Callable, ljust: int, dependencies: list[str]): list of dependencies """ - for dep in dependencies: # handle dependencies with version specifiers specifiers_pattern = r"(~=|==|!=|<=|>=|<|>|===)" diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py index ee780b7..86ba4fe 100644 --- a/whippersnappy/plot3d.py +++ b/whippersnappy/plot3d.py @@ -99,7 +99,6 @@ def plot3d( viewer = plot3d('fsaverage/surf/lh.white', overlaypath='fsaverage/surf/lh.thickness') display(viewer) """ - # Load and prepare mesh data color_mode = color_mode or ColorSelection.BOTH meshdata, triangles, fmin, fmax, pos, neg = prepare_geometry( @@ -197,7 +196,6 @@ def create_threejs_mesh_with_custom_shaders(vertices, faces, colors, normals): pythreejs.Mesh Mesh object ready to be inserted into a pythreejs.Scene. """ - vertices = vertices.astype(np.float32) colors = colors.astype(np.float32) normals = normals.astype(np.float32) diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index 0ac067b..b109bb8 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -1,6 +1,4 @@ -"""Snapshot (static rendering) API for WhipperSnapPy. - -""" +"""Snapshot (static rendering) API for WhipperSnapPy.""" import logging import os From 678b2bab968a3757df3faa3bdc396908dbfe568b Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 17 Feb 2026 00:26:57 +0100 Subject: [PATCH 10/83] fix some docstrings --- whippersnappy/cli/whippersnap.py | 11 ------ whippersnappy/gui/config_app.py | 23 +---------- whippersnappy/plot3d.py | 4 +- whippersnappy/snap.py | 8 ++-- whippersnappy/utils/types.py | 65 +++++++++++++++++++++++++++++--- 5 files changed, 66 insertions(+), 45 deletions(-) diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 26e8fd4..13bfc42 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -209,10 +209,6 @@ def config_app_exit_handler(): This handler is connected to the configuration app's about-to-quit signal and sets a module-level flag that the main OpenGL loop polls to terminate cleanly. - - Returns - ------- - None """ global app_window_closed_ app_window_closed_ = True @@ -224,9 +220,6 @@ def run(): Parses command-line arguments, validates argument combinations, and either launches a non-interactive snapshot generation (``snap4``) or starts the interactive viewer and configuration GUI. - - Behavior - -------- - Validates that either overlay or annotation inputs are provided for both hemispheres (or raises ``ValueError``). - In non-interactive mode calls :func:`whippersnappy.snap4` to produce @@ -234,10 +227,6 @@ def run(): - In interactive mode spawns the OpenGL viewer thread and launches the PyQt6-based configuration window in the main thread. - Returns - ------- - None - Raises ------ ValueError diff --git a/whippersnappy/gui/config_app.py b/whippersnappy/gui/config_app.py index 95d5571..857943a 100644 --- a/whippersnappy/gui/config_app.py +++ b/whippersnappy/gui/config_app.py @@ -162,10 +162,6 @@ def fthresh_slider_value_cb(self): This slot is connected to the slider's valueChanged signal. It maps the slider tick value into the configured value range and updates the text input box accordingly. - - Returns - ------- - None """ self.current_fthresh_value = self.convert_value_to_range( self.fthresh_slider.value(), @@ -181,10 +177,6 @@ def fthresh_value_cb(self, new_value): ---------- new_value : float or str The new value input by the user. May be a float or numeric string. - - Returns - ------- - None """ # Do not react to invalid values: try: @@ -202,12 +194,7 @@ def fthresh_value_cb(self, new_value): self.fthresh_slider.setValue(int(slider_fthresh_value)) def fmax_slider_value_cb(self): - """Handle changes from the f-max slider and update the text box. - - Returns - ------- - None - """ + """Handle changes from the f-max slider and update the text box.""" self.current_fmax_value = self.convert_value_to_range( self.fmax_slider.value(), self.fmax_slider_tick_limits, @@ -222,10 +209,6 @@ def fmax_value_cb(self, new_value): ---------- new_value : float or str New value provided by the user. - - Returns - ------- - None """ # Do not react to invalid values: try: @@ -294,10 +277,6 @@ def keyPressEvent(self, event): ---------- event : QKeyEvent Qt key event delivered by the framework. - - Returns - ------- - None """ if event.key() == Qt.Key.Escape: self.close() diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py index 86ba4fe..32a325d 100644 --- a/whippersnappy/plot3d.py +++ b/whippersnappy/plot3d.py @@ -90,8 +90,8 @@ def plot3d( Errors originating from :func:`prepare_geometry` (for example when input arrays don't match the mesh vertex count) are propagated. - Example - ------- + Examples + -------- In a Jupyter notebook:: from whippersnappy import plot3d diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index b109bb8..abbfa7f 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -121,8 +121,8 @@ def snap1( If required surface files cannot be found when deriving from SUBJECTS_DIR in multi-view helpers. - Example - ------- + Examples + -------- >>> from whippersnappy import snap1 >>> img = snap1('fsaverage/surf/lh.white', overlaypath='fsaverage/surf/lh.thickness') >>> img.save('/tmp/lh.png') @@ -349,8 +349,8 @@ def snap4( FileNotFoundError When required surface files are not found. - Example - ------- + Examples + -------- >>> from whippersnappy import snap4 >>> img = snap4( >>> lhoverlaypath='fsaverage/surf/lh.thickness', diff --git a/whippersnappy/utils/types.py b/whippersnappy/utils/types.py index b707a20..eae363f 100644 --- a/whippersnappy/utils/types.py +++ b/whippersnappy/utils/types.py @@ -1,29 +1,82 @@ """Contains the types used in WhipperSnapPy. -Dependencies: - enum +This module defines small enumeration types used across the package for +controlling color selection, colorbar orientation, and predefined views. -@Author : Abdulla Ahmadkhan -@Created : 02.10.2025 -@Revised : 02.10.2025 +Classes +------- +ColorSelection + Which sign(s) of overlay values should be used to produce colors. +OrientationType + Orientation of UI elements such as the colorbar (horizontal or vertical). +ViewType + Predefined canonical view orientations for rendering the brain surface. +Examples +-------- +>>> from whippersnappy.utils.types import ColorSelection, ViewType +>>> ColorSelection.BOTH + +>>> ViewType.LEFT + """ + import enum class ColorSelection(enum.Enum): + """Enum to select which sign(s) of overlay values to color. + + Members + ------- + BOTH : int + Use both positive and negative values for coloring. + POSITIVE : int + Use only positive values for coloring. + NEGATIVE : int + Use only negative values for coloring. + """ BOTH = 1 POSITIVE = 2 NEGATIVE = 3 + class OrientationType(enum.Enum): + """Enum describing orientation choices for elements like the colorbar. + + Members + ------- + HORIZONTAL : int + Layout along the horizontal axis. + VERTICAL : int + Layout along the vertical axis. + """ HORIZONTAL = 1 VERTICAL = 2 + class ViewType(enum.Enum): + """Predefined canonical view directions used by snapshot renderers. + + Members + ------- + LEFT : int + Left hemisphere lateral view. + RIGHT : int + Right hemisphere lateral view. + BACK : int + Posterior view. + FRONT : int + Anterior/frontal view. + TOP : int + Superior/top view. + BOTTOM : int + Inferior/bottom view. + """ LEFT = 1 RIGHT = 2 BACK = 3 FRONT = 4 TOP = 5 - BOTTOM = 6 \ No newline at end of file + BOTTOM = 6 + From fbadd9283414560602f7042773cc36b3dc17bcff Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 17 Feb 2026 00:35:18 +0100 Subject: [PATCH 11/83] explicit test path --- .github/workflows/pytest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index af2d59a..54be48f 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -39,7 +39,7 @@ jobs: - name: Display system information run: whippersnappy-sys_info --developer - name: Run pytest - run: pytest whippersnappy --cov=whippersnappy --cov-report=xml --cov-config=pyproject.toml + run: pytest tests --cov=whippersnappy --cov-report=xml --cov-config=pyproject.toml - name: Upload to codecov if: ${{ matrix.os == 'ubuntu' && matrix.python-version == '3.10' && github.repository == 'Deep-MI/WhipperSnapPy' }} uses: codecov/codecov-action@v5 From 1488c0ca44195598f9df2b075fe1cda19fa86b7d Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 17 Feb 2026 00:39:35 +0100 Subject: [PATCH 12/83] improve doc strings --- whippersnappy/cli/whippersnap.py | 18 ++++++++++++------ whippersnappy/utils/types.py | 12 ++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 13bfc42..65fc0b2 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -220,12 +220,18 @@ def run(): Parses command-line arguments, validates argument combinations, and either launches a non-interactive snapshot generation (``snap4``) or starts the interactive viewer and configuration GUI. - - Validates that either overlay or annotation inputs are provided for - both hemispheres (or raises ``ValueError``). - - In non-interactive mode calls :func:`whippersnappy.snap4` to produce - and optionally save a composed image. - - In interactive mode spawns the OpenGL viewer thread and launches the - PyQt6-based configuration window in the main thread. + + Notes + ----- + The function validates that either overlay or annotation inputs are + provided for both hemispheres; it raises ``ValueError`` for invalid + combinations. + + In non-interactive mode the function calls :func:`whippersnappy.snap4` + to produce and optionally save a composed image. + + In interactive mode it spawns the OpenGL viewer thread and launches + the PyQt6-based configuration window in the main thread. Raises ------ diff --git a/whippersnappy/utils/types.py b/whippersnappy/utils/types.py index eae363f..04555f4 100644 --- a/whippersnappy/utils/types.py +++ b/whippersnappy/utils/types.py @@ -27,8 +27,8 @@ class ColorSelection(enum.Enum): """Enum to select which sign(s) of overlay values to color. - Members - ------- + Attributes + ---------- BOTH : int Use both positive and negative values for coloring. POSITIVE : int @@ -44,8 +44,8 @@ class ColorSelection(enum.Enum): class OrientationType(enum.Enum): """Enum describing orientation choices for elements like the colorbar. - Members - ------- + Attributes + ---------- HORIZONTAL : int Layout along the horizontal axis. VERTICAL : int @@ -58,8 +58,8 @@ class OrientationType(enum.Enum): class ViewType(enum.Enum): """Predefined canonical view directions used by snapshot renderers. - Members - ------- + Attributes + ---------- LEFT : int Left hemisphere lateral view. RIGHT : int From b4849ac06d00bd5a28aca847fb067a9822b07d3c Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 17 Feb 2026 00:50:50 +0100 Subject: [PATCH 13/83] more sphinx fixes --- doc/conf.py | 5 ++++- whippersnappy/cli/whippersnap.py | 14 +++++++------- whippersnappy/utils/types.py | 8 -------- 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index c0e1c26..352c5e0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -183,6 +183,10 @@ r"\.__div__", r"\.__neg__", } +# Exclude the small types module from numpydoc validation. The enum classes +# in `whippersnappy.utils.types` intentionally do not document Enum +# constructor varargs; exclude the module to silence PR01 for this file only. +numpydoc_validation_exclude.update({r"^whippersnappy\.utils\.types($|\.)"}) # -- sphinxcontrib-bibtex ---------------------------------------------------- bibtex_bibfiles = ["./references.bib"] @@ -276,4 +280,3 @@ def ensure_pandoc_installed(_): def setup(app): app.connect("builder-inited", ensure_pandoc_installed) - diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 65fc0b2..7246660 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -221,6 +221,13 @@ def run(): either launches a non-interactive snapshot generation (``snap4``) or starts the interactive viewer and configuration GUI. + Raises + ------ + ValueError + For invalid or mutually exclusive argument combinations. + ImportError + If interactive mode is requested but PyQt6 is not available. + Notes ----- The function validates that either overlay or annotation inputs are @@ -232,13 +239,6 @@ def run(): In interactive mode it spawns the OpenGL viewer thread and launches the PyQt6-based configuration window in the main thread. - - Raises - ------ - ValueError - For invalid or mutually exclusive argument combinations. - ImportError - If interactive mode is requested but PyQt6 is not available. """ global current_fthresh_, current_fmax_, app_, app_window_ # Configure basic logging for CLI invocation so messages from module loggers diff --git a/whippersnappy/utils/types.py b/whippersnappy/utils/types.py index 04555f4..8110eb7 100644 --- a/whippersnappy/utils/types.py +++ b/whippersnappy/utils/types.py @@ -11,14 +11,6 @@ Orientation of UI elements such as the colorbar (horizontal or vertical). ViewType Predefined canonical view orientations for rendering the brain surface. - -Examples --------- ->>> from whippersnappy.utils.types import ColorSelection, ViewType ->>> ColorSelection.BOTH - ->>> ViewType.LEFT - """ import enum From fe6ff6014555a4a9ac6cb7f7c0c28022ed4f3878 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 17 Feb 2026 00:53:27 +0100 Subject: [PATCH 14/83] more sphinx fixes --- doc/conf.py | 4 ---- whippersnappy/utils/types.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 352c5e0..e558aa1 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -183,10 +183,6 @@ r"\.__div__", r"\.__neg__", } -# Exclude the small types module from numpydoc validation. The enum classes -# in `whippersnappy.utils.types` intentionally do not document Enum -# constructor varargs; exclude the module to silence PR01 for this file only. -numpydoc_validation_exclude.update({r"^whippersnappy\.utils\.types($|\.)"}) # -- sphinxcontrib-bibtex ---------------------------------------------------- bibtex_bibfiles = ["./references.bib"] diff --git a/whippersnappy/utils/types.py b/whippersnappy/utils/types.py index 8110eb7..8d0e417 100644 --- a/whippersnappy/utils/types.py +++ b/whippersnappy/utils/types.py @@ -19,6 +19,13 @@ class ColorSelection(enum.Enum): """Enum to select which sign(s) of overlay values to color. + Parameters + ---------- + *values : tuple + Positional arguments passed to the Enum constructor (not used by + consumers of this enum). Documented here to satisfy documentation + linters that inspect the class signature. + Attributes ---------- BOTH : int @@ -36,6 +43,12 @@ class ColorSelection(enum.Enum): class OrientationType(enum.Enum): """Enum describing orientation choices for elements like the colorbar. + Parameters + ---------- + *values : tuple + Positional arguments passed to the Enum constructor (not used by + consumers of this enum). + Attributes ---------- HORIZONTAL : int @@ -50,6 +63,12 @@ class OrientationType(enum.Enum): class ViewType(enum.Enum): """Predefined canonical view directions used by snapshot renderers. + Parameters + ---------- + *values : tuple + Positional arguments passed to the Enum constructor (not used by + consumers of this enum). + Attributes ---------- LEFT : int From df597a71bf97959ef7a89def0a9d97a72abfada8 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 00:47:05 +0100 Subject: [PATCH 15/83] types toplevel, better window capture --- examples/whippersnappy_demo.ipynb | 27 ++++---- whippersnappy/__init__.py | 2 +- whippersnappy/gl/utils.py | 69 ++++++++++++++------ whippersnappy/snap.py | 103 +++++++++++++----------------- 4 files changed, 111 insertions(+), 90 deletions(-) diff --git a/examples/whippersnappy_demo.ipynb b/examples/whippersnappy_demo.ipynb index 8010485..fbafd8b 100644 --- a/examples/whippersnappy_demo.ipynb +++ b/examples/whippersnappy_demo.ipynb @@ -25,8 +25,7 @@ "\n", "from IPython.display import display\n", "\n", - "from whippersnappy import plot3d, snap1, snap4\n", - "from whippersnappy.utils.types import ViewType\n" + "from whippersnappy import ViewType, plot3d, snap1, snap4\n" ] }, { @@ -136,10 +135,11 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "ca7e2c155177b7c5", + "metadata": {}, + "outputs": [], "source": [ "# Part 2: Interactive 3D Rendering\n", "# Mouse-controlled 3D visualization using Three.js (works in all Jupyter environments).\n", @@ -157,14 +157,14 @@ " height=800,\n", ")\n", "display(viewer)\n" - ], - "id": "ca7e2c155177b7c5" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "f712c59cf4a54a0d", + "metadata": {}, + "outputs": [], "source": [ "# Interactive colored overlay (if available)\n", "if lh_overlay:\n", @@ -181,14 +181,14 @@ " display(viewer)\n", "else:\n", " print(\"Thickness overlay not found - skipping colored example\")\n" - ], - "id": "f712c59cf4a54a0d" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "9d5f61470c130e6e", + "metadata": {}, + "outputs": [], "source": [ "# Interactive label map overlay\n", "if lh_annot_path:\n", @@ -202,8 +202,7 @@ " display(viewer)\n", "else:\n", " print(\"Annot overlay not found - skipping label map example\")\n" - ], - "id": "9d5f61470c130e6e" + ] }, { "cell_type": "code", diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 6e6e002..052b68c 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -46,7 +46,7 @@ from ._config import sys_info # noqa: F401 from ._version import __version__ # noqa: F401 from .snap import snap1, snap4 -from .utils.types import ViewType +from .utils.types import ColorSelection, OrientationType, ViewType # 3D plotting for notebooks (Three.js-based, works in all Jupyter environments) try: diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index bfcf0c1..7ecb202 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -3,7 +3,7 @@ Contains the implementation of OpenGL helpers used by the package. """ -import sys +import logging import glfw import OpenGL.GL as gl @@ -13,6 +13,9 @@ from .camera import make_model, make_projection, make_view from .shaders import get_default_shaders +# Module logger +logger = logging.getLogger(__name__) + def create_vao(): """Create and bind a Vertex Array Object (VAO). @@ -232,34 +235,64 @@ def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0) return shader - -def capture_window(width, height): +def capture_window(window): """Read the current GL framebuffer and return it as a PIL.Image (RGB). - On macOS (retina) this function reads at double resolution and downscales - the result to compensate for pixel-density differences. + This function captures the framebuffer for the provided GLFW `window` + and returns an RGB :class:`PIL.Image.Image`. On HiDPI displays (e.g. + macOS Retina) the framebuffer may be larger than the logical window + size; the function will downscale the captured physical framebuffer to + logical pixel dimensions when a non-1.0 monitor content scale is + detected. Parameters ---------- - width, height : int - Desired output image dimensions. + window : GLFWwindow + GLFW window handle whose current OpenGL context/framebuffer will be + read. The function calls :func:`glfw.get_framebuffer_size` to obtain + the read dimensions and :func:`glfw.get_primary_monitor` / + :func:`glfw.get_monitor_content_scale` to detect the display scale. Returns ------- PIL.Image.Image - RGB image containing the captured framebuffer content. + RGB image containing the captured framebuffer content. On standard + (1x) displays the returned image has the same dimensions as the + framebuffer. On HiDPI displays the image is downscaled to logical + window dimensions (framebuffer size divided by the monitor content + scale) using ``Image.Resampling.LANCZOS``. + + Notes + ----- + - The function uses ``glReadPixels`` with ``GL_PACK_ALIGNMENT=1`` and + converts the raw bytes into a PIL image, performing a vertical flip + to convert OpenGL's bottom-left origin to the image top-left origin. + - Prefer :func:`glfw.get_window_content_scale` or + :func:`glfw.get_monitor_content_scale` to detect per-window/monitor + scaling. The function currently uses the primary monitor's content + scale as a heuristic for HiDPI detection. + - If strict static analyzers complain about ``Image.FLIP_TOP_BOTTOM`` + you can switch to ``Image.Transpose.FLIP_TOP_BOTTOM`` for newer + Pillow versions. """ - if sys.platform == "darwin": - rwidth = 2 * width - rheight = 2 * height - else: - rwidth = width - rheight = height + # Get primary monitor + monitor = glfw.get_primary_monitor() + # Get scale factors + x_scale, y_scale = glfw.get_monitor_content_scale(monitor) + # Get framebuffer size + width, height = glfw.get_framebuffer_size(window) + + logger.debug("Framebuffer size = (%s,%s)", width, height) + logger.debug("Monitor scale = (%s,%s)", x_scale, y_scale) gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) - img_buf = gl.glReadPixels(0, 0, rwidth, rheight, gl.GL_RGB, gl.GL_UNSIGNED_BYTE) - image = Image.frombytes("RGB", (rwidth, rheight), img_buf) + img_buf = gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE) + image = Image.frombytes("RGB", (width, height), img_buf) image = image.transpose(Image.FLIP_TOP_BOTTOM) - if sys.platform == "darwin": - image.thumbnail((0.5 * rwidth, 0.5 * rheight), Image.Resampling.LANCZOS) + + if x_scale != 1 or y_scale != 1: + rwidth = int(round(width / x_scale)) + rheight = int(round(height / y_scale)) + logger.debug("Rescale to = (%s,%s)", rwidth, rheight) + image.thumbnail((rwidth, rheight), Image.Resampling.LANCZOS) return image diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index abbfa7f..fd7e683 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -29,8 +29,8 @@ def snap1( curvpath=None, view=ViewType.LEFT, viewmat=None, - width=None, - height=None, + width=700, + height=500, fthresh=None, fmax=None, caption=None, @@ -51,7 +51,7 @@ def snap1( ): """Render a single static snapshot of a surface view. - This function opens an (offscreen) OpenGL context, uploads the provided + This function opens an OpenGL context, uploads the provided surface geometry and colors (overlay or annotation), renders the scene for a single view, captures the framebuffer, and returns a PIL Image containing the rendered brain view. When ``outpath`` is provided the @@ -62,8 +62,7 @@ def snap1( meshpath : str Path to the surface file (FreeSurfer-format, e.g. "lh.white"). outpath : str or None, optional - When provided, the resulting image is saved to this path. If ``None`` - the PIL Image object is returned. + When provided, the resulting image is saved to this path. overlaypath : str or None, optional Path to overlay/mgh file providing per-vertex values to color the surface. If ``None``, coloring falls back to curvature/annotation. @@ -77,9 +76,8 @@ def snap1( Which pre-defined view to render (left, right, front, ...). Default is ``ViewType.LEFT``. viewmat : 4x4 matrix-like, optional Optional view matrix to override the pre-defined view. - width, height : int or None, optional - Requested overall canvas width/height in pixels. If ``None`` defaults - are used (700x500 reference). + width, height : int, optional + Requested overall canvas width/height in pixels. Defaults to (700x500). fthresh, fmax : float or None, optional Threshold and saturation values for overlay coloring. caption, caption_x, caption_y, caption_scale : str/float, optional @@ -105,10 +103,8 @@ def snap1( Returns ------- - PIL.Image.Image or None - If ``outpath`` is ``None`` the function returns a PIL Image object - containing the rendered snapshot. If ``outpath`` is provided the - image is saved to disk and ``None`` is returned. + PIL.Image.Image + Returns a PIL Image object containing the rendered snapshot. Raises ------ @@ -129,9 +125,7 @@ def snap1( """ ref_width = 700 ref_height = 500 - wwidth = ref_width if width is None else width - wheight = ref_height if height is None else height - ui_scale = min(wwidth / ref_width, wheight / ref_height) + ui_scale = min(width / ref_width, height / ref_height) if not glfw.init(): logger.error("Could not init glfw!") @@ -140,24 +134,26 @@ def snap1( mode = glfw.get_video_mode(primary_monitor) screen_width = mode.size.width screen_height = mode.size.height - if wwidth > screen_width: - logger.info("Requested width %d exceeds screen width %d, expect black bars", wwidth, screen_width) - elif wheight > screen_height: - logger.info("Requested height %d exceeds screen height %d, expect black bars", wheight, screen_height) + if width > screen_width: + logger.info("Requested width %d exceeds screen width %d, expect black bars", width, screen_width) + elif height > screen_height: + logger.info("Requested height %d exceeds screen height %d, expect black bars", height, screen_height) - image = Image.new("RGB", (wwidth, wheight)) + image = Image.new("RGB", (width, height)) bwidth = int(540 * brain_scale * ui_scale) bheight = int(450 * brain_scale * ui_scale) - brain_display_width = min(bwidth, wwidth) - brain_display_height = min(bheight, wheight) - - window = _gl.init_window(brain_display_width, brain_display_height, "WhipperSnapPy 2.0", visible=True) + brain_display_width = min(bwidth, width) + brain_display_height = min(bheight, height) + logger.debug("Requested (width,height) = (%s,%s)", width, height) + logger.debug("Screen (width,height) = (%s,%s)", screen_width, screen_height) + logger.debug("Brain (width,height) = (%s,%s)", bwidth, bheight) + logger.debug("B-Display (width,height) = (%s,%s)", brain_display_width, brain_display_height) + + window = _gl.init_window(brain_display_width, brain_display_height, "WhipperSnapPy", visible=True) if not window: return False - transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) - meshdata, triangles, fthresh, fmax, pos, neg = prepare_geometry( meshpath, overlaypath, @@ -191,15 +187,17 @@ def snap1( gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) transform_loc = gl.glGetUniformLocation(shader, "transform") + # Small translation to move the brain into the view frustum + transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) view_mats = get_view_matrices() viewmat = transl * (view_mats[view] if viewmat is None else viewmat) gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, viewmat) gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - im1 = _gl.capture_window(brain_display_width, brain_display_height) + im1 = _gl.capture_window(window) - brain_x = 0 if wwidth < bwidth else (wwidth - bwidth) // 2 - brain_y = 0 if wheight < bheight else (wheight - bheight) // 2 + brain_x = 0 if width < bwidth else (width - bwidth) // 2 + brain_y = 0 if height < bheight else (height - bheight) // 2 image.paste(im1, (brain_x, brain_y)) bar = None @@ -229,17 +227,17 @@ def snap1( if orientation == OrientationType.HORIZONTAL: if bar is not None: - bx = int(0.5 * (image.width - bar_w)) if colorbar_x is None else int(colorbar_x * wwidth) + bx = int(0.5 * (image.width - bar_w)) if colorbar_x is None else int(colorbar_x * width) if colorbar_y is None: gap_and_caption = (gap + text_h) if caption and caption_y is None else 0 by = image.height - bottom_pad - gap_and_caption - bar_h else: - by = int(colorbar_y * wheight) + by = int(colorbar_y * height) image.paste(bar, (bx, by)) if caption: - cx = int(0.5 * (image.width - text_w)) if caption_x is None else int(caption_x * wwidth) - cy = image.height - bottom_pad - text_h if caption_y is None else int(caption_y * wheight) + cx = int(0.5 * (image.width - text_w)) if caption_x is None else int(caption_x * width) + cy = image.height - bottom_pad - text_h if caption_y is None else int(caption_y * height) ImageDraw.Draw(image).text((cx, cy), caption, (220, 220, 220), font=font, anchor="lt") else: if bar is not None: @@ -247,8 +245,8 @@ def snap1( gap_and_caption = (gap + text_h) if caption and caption_x is None else 0 bx = image.width - right_pad - gap_and_caption - bar_w else: - bx = int(colorbar_x * wwidth) - by = int(0.5 * (image.height - bar_h)) if colorbar_y is None else int(colorbar_y * wheight) + bx = int(colorbar_x * width) + by = int(0.5 * (image.height - bar_h)) if colorbar_y is None else int(colorbar_y * height) image.paste(bar, (bx, by)) if caption: @@ -257,18 +255,15 @@ def snap1( rotated_caption = temp_caption_img.rotate(90, expand=True, fillcolor=(0, 0, 0, 0)) rotated_w, rotated_h = rotated_caption.size - cx = image.width - right_pad - rotated_w if caption_x is None else int(caption_x * wwidth) - cy = int(0.5 * (image.height - rotated_h)) if caption_y is None else int(caption_y * wheight) + cx = image.width - right_pad - rotated_w if caption_x is None else int(caption_x * width) + cy = int(0.5 * (image.height - rotated_h)) if caption_y is None else int(caption_y * height) image.paste(rotated_caption, (cx, cy), rotated_caption) - if outpath is None: - glfw.terminate() - return image - - logger.info("Saving snapshot to %s", outpath) - image.save(outpath) + if outpath: + logger.info("Saving snapshot to %s", outpath) + image.save(outpath) glfw.terminate() - return None + return image def snap4( @@ -295,7 +290,7 @@ def snap4( This convenience function renders four views (top/bottom for each hemisphere), stitches them together into a single PIL Image and returns - it (or saves it to ``outpath`` when provided). It is typically used to + it (and saves it to ``outpath`` when provided). It is typically used to produce publication-ready overview figures composed from both hemispheres. @@ -325,7 +320,7 @@ def snap4( colorbar : bool, optional Whether to draw a colorbar on the composed image. Default is ``True``. outpath : str or None, optional - If provided, save composed image to this path and return ``None``. + If provided, save composed image to this path. font_file : str or None, optional Path to a font to use for captions. specular : bool, optional @@ -337,9 +332,8 @@ def snap4( Returns ------- - PIL.Image.Image or None - Composed image of the four views, or ``None`` if ``outpath`` was - provided and the image was written to disk. + PIL.Image.Image + Composed image of the four views. Raises ------ @@ -454,7 +448,7 @@ def snap4( gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, transl * viewmat) gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) try: - im1 = _gl.capture_window(wwidth, wheight) + im1 = _gl.capture_window(window) logger.debug("Captured image 1 size: %s", im1.size) except Exception as e: logger.error("capture_window failed: %s", e) @@ -473,7 +467,7 @@ def snap4( gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, transl * viewmat) gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) try: - im2 = _gl.capture_window(wwidth, wheight) + im2 = _gl.capture_window(window) logger.debug("Captured image 2 size: %s", im2.size) except Exception as e: logger.error("capture_window failed: %s", e) @@ -522,16 +516,11 @@ def snap4( ypos = int(0.5 * (image.height - bar.height)) image.paste(bar, (xpos, ypos)) - # If outpath is None, return the PIL Image object directly (no disk I/O) - if outpath is None: - glfw.terminate() - return image - # Otherwise save to disk if outpath: logger.info("Saving snapshot to %s", outpath) image.save(outpath) glfw.terminate() - return None + return image From 5b0a5e6e898fd99af1e47b512f5aa3cd0d81861a Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 01:15:13 +0100 Subject: [PATCH 16/83] make QT6 optional for gui --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3a4abdd..af5de1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ dependencies = [ 'pillow', 'pyopengl==3.1.6', 'nibabel', - 'PyQt6', 'psutil' ] @@ -70,6 +69,9 @@ notebook = [ 'pythreejs', # Three.js for interactive 3D (works in all Jupyter environments) 'ipywidgets', # Required for pythreejs ] +gui = [ + 'PyQt6', +] style = [ 'bibclean', 'codespell', @@ -87,6 +89,7 @@ all = [ 'whippersnappy[style]', 'whippersnappy[test]', 'whippersnappy[notebook]', + 'whippersnappy[gui]', ] full = [ 'whippersnappy[all]', From 452f26c788d19520bc4a7060d74ec9e4a9902353 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 01:15:37 +0100 Subject: [PATCH 17/83] remove Qt6 in docker (no gui) --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8a12393..9c768ee 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:20.04 +FROM ubuntu:24.04 # Install packages RUN apt-get update && apt-get install -y --no-install-recommends \ @@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ # Install python packages RUN pip install --upgrade pip -RUN pip install pyopengl glfw pillow numpy pyrr PyQt6 +RUN pip install pyopengl glfw pillow numpy pyrr COPY . /WhipperSnapPy RUN pip install /WhipperSnapPy From 567a6fe49132984eaa235d0770ce72f478941ba2 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 01:16:08 +0100 Subject: [PATCH 18/83] update documentation slightly --- doc/index.rst | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/doc/index.rst b/doc/index.rst index 439547d..e3c824d 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -10,7 +10,7 @@ Welcome to whippersnappy's documentation! ========================================= WhipperSnapPY is a small Python OpenGL program to render -FreeSurfer and FastSurfer surface models and color overlays and generate screen shots. +FreeSurfer and FastSurfer surface models with color overlays or parcellations and generate screen shots. License ------- @@ -21,16 +21,17 @@ A full copy of the license can be found `on GitHub `_. Contents -------- -- Capture 4x4 surface plots (front & back, left and right) -- OpenGL window for interactive visualization +- Snap1: Capture a single shapshot of a surface with an overlay +- Snap4: Capture 4x4 surface plots (front & back, left and right) of a Free- or FastSurfer brain surface with an overlay +- Plot3d: Interactive 3D WebGL visualization in IPython notebooks +- OpenGL QT GUI for interactive visualization Note, that currently no off-screen rendering is supported. Even in snap mode an invisible window will be created to render the openGL output and capture the contents to an image. In order to run this on a headless server, inside Docker, or via ssh we recommend to install xvfb and run .. code-block:: bash apt update && apt install -y python3 python3-pip xvfb - pip3 install pyopengl glfw pillow numpy pyrr PyQt5==5.15.6 - pip3 install . + pip3 install whippersnappy xvfb-run whippersnap ... Installation From 2f980f54106387a397f69ec1e4e4127cb6eb6811 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 07:12:48 +0100 Subject: [PATCH 19/83] fix typo --- doc/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/index.rst b/doc/index.rst index e3c824d..81e9c4e 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -21,7 +21,7 @@ A full copy of the license can be found `on GitHub `_. Contents -------- -- Snap1: Capture a single shapshot of a surface with an overlay +- Snap1: Capture a single snapshot of a surface with an overlay - Snap4: Capture 4x4 surface plots (front & back, left and right) of a Free- or FastSurfer brain surface with an overlay - Plot3d: Interactive 3D WebGL visualization in IPython notebooks - OpenGL QT GUI for interactive visualization From b895807d8b554e93a153042f58f718a4d0898de4 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 15:39:44 +0100 Subject: [PATCH 20/83] local imports, further refactor of snap1 and snap4 to avoid code duplication --- whippersnappy/__init__.py | 2 +- whippersnappy/cli/whippersnap.py | 23 ++- whippersnappy/geometry/prepare.py | 83 +++++++++- whippersnappy/gl/__init__.py | 1 + whippersnappy/gl/utils.py | 63 ++++++++ whippersnappy/gl/views.py | 2 +- whippersnappy/plot3d.py | 2 +- whippersnappy/snap.py | 260 ++++++++++-------------------- whippersnappy/utils/colormap.py | 2 +- whippersnappy/utils/image.py | 63 +++++++- 10 files changed, 307 insertions(+), 194 deletions(-) diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 052b68c..6e6e002 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -46,7 +46,7 @@ from ._config import sys_info # noqa: F401 from ._version import __version__ # noqa: F401 from .snap import snap1, snap4 -from .utils.types import ColorSelection, OrientationType, ViewType +from .utils.types import ViewType # 3D plotting for notebooks (Three.js-based, works in all Jupyter environments) try: diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 7246660..9f63433 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -38,13 +38,12 @@ # GUI dependency missing; handle at runtime when interactive mode is requested QApplication = None -from whippersnappy import snap4 -from whippersnappy.geometry import get_surf_name, prepare_geometry -from whippersnappy.gl import ( - init_window, - setup_shader, +from .. import snap4 +from ..geometry import get_surf_name, prepare_geometry +from ..gl import ( + init_window, render_scene, setup_shader, capture_window, get_view_matrices ) -from whippersnappy.gui import ConfigWindow +from ..gui import ConfigWindow # Module logger logger = logging.getLogger(__name__) @@ -119,7 +118,8 @@ def show_window( weight = 600 window = init_window(wwidth, weight, "WhipperSnapPy", visible=True) if not window: - return False + logger.error("Could not create any GLFW window/context. OpenGL context unavailable.") + raise RuntimeError("Could not create any GLFW window/context. OpenGL context unavailable.") if surfname is None: logger.info("No surf_name provided. Looking for options in surf directory...") @@ -155,9 +155,7 @@ def show_window( logger.info("\nKeys:\nLeft - Right : Rotate Geometry\nESC : Quit\n") ypos = 0 - while glfw.get_key( - window, glfw.KEY_ESCAPE - ) != glfw.PRESS and not glfw.window_should_close(window): + while glfw.get_key(window, glfw.KEY_ESCAPE) != glfw.PRESS and not glfw.window_should_close(window): # Terminate if config app window was closed: if app_window_closed_: break @@ -181,9 +179,7 @@ def show_window( current_fthresh_, current_fmax_, ) - shader = setup_shader( - meshdata, triangles, wwidth, weight, specular=specular - ) + shader = setup_shader(meshdata, triangles, wwidth, weight, specular=specular) transformLoc = gl.glGetUniformLocation(shader, "transform") gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, rot_y * viewLeft) @@ -196,7 +192,6 @@ def show_window( # Draw gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - glfw.swap_buffers(window) glfw.terminate() diff --git a/whippersnappy/geometry/prepare.py b/whippersnappy/geometry/prepare.py index 8bd63f9..6db8350 100644 --- a/whippersnappy/geometry/prepare.py +++ b/whippersnappy/geometry/prepare.py @@ -9,9 +9,9 @@ import numpy as np -from whippersnappy.geometry.read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data -from whippersnappy.utils.colormap import binary_color, heat_color, mask_label, mask_sign, rescale_overlay -from whippersnappy.utils.types import ColorSelection +from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data +from ..utils.colormap import binary_color, heat_color, mask_label, mask_sign, rescale_overlay +from ..utils.types import ColorSelection def normalize_mesh(v, scale=1.0): @@ -231,3 +231,80 @@ def prepare_geometry( vertexdata = np.concatenate((vertices, vnormals, colors), axis=1) return vertexdata, triangles, fmin, fmax, pos, neg + + +def prepare_and_validate_geometry( + meshpath, + overlaypath, + annotpath, + curvpath, + labelpath, + fthresh, + fmax, + invert, + scale, + color_mode, +): + """Load and validate mesh geometry and overlay/annotation inputs. + + This is a small wrapper around :func:`prepare_geometry` + that performs the same overlay presence validation used throughout the + static snapshot helpers. + + Parameters + ---------- + meshpath, overlaypath, annotpath, curvpath, labelpath : str or None + Paths passed through to :func:`prepare_geometry`. + fthresh, fmax : float or None + Threshold and saturation values passed to the geometry preparer. + invert : bool + Passed to the geometry preparer. + scale : float + Scaling factor passed to the geometry preparer. + color_mode : ColorSelection + Which sign of overlay to display (POSITIVE/NEGATIVE/BOTH). + + Returns + ------- + tuple + ``(meshdata, triangles, fthresh, fmax, pos, neg)`` as returned by + :func:`prepare_geometry`. + + Raises + ------ + ValueError + If the overlay contains no values appropriate for ``color_mode``. + """ + import logging + logger = logging.getLogger(__name__) + meshdata, triangles, out_fthresh, out_fmax, pos, neg = prepare_geometry( + meshpath, + overlaypath, + annotpath, + curvpath, + labelpath, + fthresh, + fmax, + invert, + scale=scale, + color_mode=color_mode, + ) + + # Validate overlay presence similar to previous inline checks + if overlaypath is not None: + if color_mode == ColorSelection.POSITIVE: + if not pos and neg: + logger.error("Overlay has no values to display with positive color_mode") + raise ValueError("Overlay has no values to display with positive color_mode") + neg = False + elif color_mode == ColorSelection.NEGATIVE: + if pos and not neg: + logger.error("Overlay has no values to display with negative color_mode") + raise ValueError("Overlay has no values to display with negative color_mode") + pos = False + if not pos and not neg: + logger.error("Overlay has no values to display") + raise ValueError("Overlay has no values to display") + + return meshdata, triangles, out_fthresh, out_fmax, pos, neg + diff --git a/whippersnappy/gl/__init__.py b/whippersnappy/gl/__init__.py index f2894dc..73996dd 100644 --- a/whippersnappy/gl/__init__.py +++ b/whippersnappy/gl/__init__.py @@ -13,6 +13,7 @@ capture_window, compile_shader_program, create_vao, + create_window_with_fallback, init_window, set_camera_uniforms, set_default_gl_state, diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index 7ecb202..09ff93e 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -193,6 +193,36 @@ def init_window(width, height, title="PyOpenGL", visible=True): return window +def create_window_with_fallback(width, height, title="WhipperSnapPy", visible=True): + """Create a GLFW window, preferring a visible window and falling back to an invisible one. + + Parameters + ---------- + width : int + Requested window width in logical pixels. + height : int + Requested window height in logical pixels. + title : str, optional + Window title. Default is ``'WhipperSnapPy'``. + visible : bool, optional + Prefer a visible window when True (default). If creation fails the + function will retry with an invisible/offscreen window. + + Returns + ------- + GLFWwindow or None + The created GLFW window handle, or ``None`` if creation failed. + """ + window = init_window(width, height, title, visible=visible) + if not window and visible: + logger.warning("Could not create visible GLFW window; retrying with invisible window (offscreen).") + window = init_window(width, height, title, visible=False) + if not window: + logger.error("Could not create any GLFW window/context. OpenGL context unavailable.") + raise RuntimeError("Could not create any GLFW window/context. OpenGL context unavailable.") + return window + + def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0): """Create shader program, upload mesh and initialize camera & lighting. @@ -296,3 +326,36 @@ def capture_window(window): logger.debug("Rescale to = (%s,%s)", rwidth, rheight) image.thumbnail((rwidth, rheight), Image.Resampling.LANCZOS) return image + +def render_scene(shader, triangles, transform): + """Render a single draw call using the supplied shader/indices. + + Parameters + ---------- + shader : int + OpenGL shader program handle. + triangles : numpy.ndarray + Element/index array used for the draw call. + transform : array-like + 4x4 transform matrix (model/view/projection combined) to upload to + the shader uniform named ``transform``. + + Raises + ------ + RuntimeError + If a GL error occurs during rendering. + """ + try: + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) + except Exception as exc: + logger.error("glClear failed: %s", exc) + raise RuntimeError(f"glClear failed: {exc}") + + transform_loc = gl.glGetUniformLocation(shader, "transform") + gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, transform) + gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) + + err = gl.glGetError() + if err != gl.GL_NO_ERROR: + logger.error("OpenGL error after draw: %s", err) + raise RuntimeError(f"OpenGL error after draw: {err}") diff --git a/whippersnappy/gl/views.py b/whippersnappy/gl/views.py index 5a04b51..507b89f 100644 --- a/whippersnappy/gl/views.py +++ b/whippersnappy/gl/views.py @@ -2,7 +2,7 @@ import numpy as np -from whippersnappy.utils.types import ViewType +from ..utils.types import ViewType def get_view_matrices(): diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py index 32a325d..88938f3 100644 --- a/whippersnappy/plot3d.py +++ b/whippersnappy/plot3d.py @@ -24,7 +24,7 @@ import pythreejs as p3js from ipywidgets import HTML, VBox -from whippersnappy.utils.types import ColorSelection +from .utils.types import ColorSelection from .geometry import prepare_geometry from .gl import get_webgl_shaders diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index fd7e683..e7515a9 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -4,17 +4,16 @@ import os import glfw -import numpy as np -import OpenGL.GL as gl import pyrr from PIL import Image, ImageDraw, ImageFont -from whippersnappy.geometry import get_surf_name, prepare_geometry -from whippersnappy.utils.image import create_colorbar, load_roboto_font, text_size -from whippersnappy.utils.types import ColorSelection, OrientationType, ViewType +from .geometry import get_surf_name +from .utils.image import create_colorbar, load_roboto_font, text_size, draw_colorbar, draw_caption +from .utils.types import ColorSelection, OrientationType, ViewType -from . import gl as _gl -from .gl import get_view_matrices +from .gl.views import get_view_matrices +from .gl.utils import render_scene, create_window_with_fallback, capture_window, setup_shader +from .geometry.prepare import prepare_and_validate_geometry # Module logger logger = logging.getLogger(__name__) @@ -150,11 +149,10 @@ def snap1( logger.debug("Brain (width,height) = (%s,%s)", bwidth, bheight) logger.debug("B-Display (width,height) = (%s,%s)", brain_display_width, brain_display_height) - window = _gl.init_window(brain_display_width, brain_display_height, "WhipperSnapPy", visible=True) - if not window: - return False + # will raise exception if it cannot be created + window = create_window_with_fallback(brain_display_width, brain_display_height, "WhipperSnapPy", visible=True) - meshdata, triangles, fthresh, fmax, pos, neg = prepare_geometry( + meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( meshpath, overlaypath, annotpath, @@ -167,97 +165,60 @@ def snap1( color_mode=color_mode, ) - if overlaypath is not None: - if color_mode == ColorSelection.POSITIVE: - if not pos and neg: - logger.error("Overlay has no values to display with positive color_mode") - raise ValueError("Overlay has no values to display with positive color_mode") - neg = False - elif color_mode == ColorSelection.NEGATIVE: - if pos and not neg: - logger.error("Overlay has no values to display with negative color_mode") - raise ValueError("Overlay has no values to display with negative color_mode") - pos = False - if not pos and not neg: - logger.error("Overlay has no values to display") - raise ValueError("Overlay has no values to display") - - shader = _gl.setup_shader(meshdata, triangles, brain_display_width, brain_display_height, + shader = setup_shader(meshdata, triangles, brain_display_width, brain_display_height, specular=specular, ambient=ambient) - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - transform_loc = gl.glGetUniformLocation(shader, "transform") - # Small translation to move the brain into the view frustum transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) view_mats = get_view_matrices() viewmat = transl * (view_mats[view] if viewmat is None else viewmat) - gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, viewmat) - gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - - im1 = _gl.capture_window(window) - - brain_x = 0 if width < bwidth else (width - bwidth) // 2 - brain_y = 0 if height < bheight else (height - bheight) // 2 - image.paste(im1, (brain_x, brain_y)) - - bar = None - bar_w = bar_h = 0 - if overlaypath is not None and colorbar: - bar = create_colorbar(fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, - font_file=font_file) - bar_w, bar_h = bar.size - - font = None - text_w = text_h = 0 - if caption: - if font_file is None: - font = load_roboto_font(int(20 * caption_scale * ui_scale)) - else: - try: - font = ImageFont.truetype(font_file, int(20 * caption_scale * ui_scale)) - except Exception: - font = load_roboto_font(int(20 * caption_scale * ui_scale)) - text_w, text_h = text_size(caption, font) - text_w = int(text_w) - text_h = int(text_h) + render_scene(shader, triangles, viewmat) - bottom_pad = int(20 * ui_scale) - right_pad = int(20 * ui_scale) + # Center the brain rendering in the output image, clamp to zero + brain_x = max(0, (width - brain_display_width) // 2) + brain_y = max(0, (height - brain_display_height) // 2) + image.paste(capture_window(window), (brain_x, brain_y)) + + bar = create_colorbar(fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, font_file=font_file) if overlaypath is not None and colorbar else None + font = load_roboto_font(int(20 * caption_scale * ui_scale)) if font_file is None else ImageFont.truetype(font_file, int(20 * caption_scale * ui_scale)) if caption else None + + # Compute positions to avoid overlap, unless explicit positions are given + text_w, text_h = text_size(caption, font) if caption and font else (0, 0) + bar_h = bar.height if bar is not None else 0 gap = int(4 * ui_scale) + bottom_pad = int(20 * ui_scale) if orientation == OrientationType.HORIZONTAL: - if bar is not None: - bx = int(0.5 * (image.width - bar_w)) if colorbar_x is None else int(colorbar_x * width) - if colorbar_y is None: - gap_and_caption = (gap + text_h) if caption and caption_y is None else 0 - by = image.height - bottom_pad - gap_and_caption - bar_h + # If explicit positions are given, use them + if colorbar_x is not None or colorbar_y is not None or caption_x is not None or caption_y is not None: + bx = int(colorbar_x * width) if colorbar_x is not None else None + by = int(colorbar_y * height) if colorbar_y is not None else None + cx = int(caption_x * width) if caption_x is not None else None + cy = int(caption_y * height) if caption_y is not None else None + draw_colorbar(image, bar, orientation, x=bx, y=by) + draw_caption(image, caption, font, orientation, x=cx, y=cy) + else: + # Place colorbar above caption if both present + if bar is not None and caption: + bar_y = image.height - bottom_pad - text_h - gap - bar_h + caption_y = image.height - bottom_pad - text_h + elif bar is not None: + bar_y = image.height - bottom_pad - bar_h + caption_y = None + elif caption: + bar_y = None + caption_y = image.height - bottom_pad - text_h else: - by = int(colorbar_y * height) - image.paste(bar, (bx, by)) - - if caption: - cx = int(0.5 * (image.width - text_w)) if caption_x is None else int(caption_x * width) - cy = image.height - bottom_pad - text_h if caption_y is None else int(caption_y * height) - ImageDraw.Draw(image).text((cx, cy), caption, (220, 220, 220), font=font, anchor="lt") + bar_y = caption_y = None + draw_colorbar(image, bar, orientation, y=bar_y) + draw_caption(image, caption, font, orientation, y=caption_y) else: - if bar is not None: - if colorbar_x is None: - gap_and_caption = (gap + text_h) if caption and caption_x is None else 0 - bx = image.width - right_pad - gap_and_caption - bar_w - else: - bx = int(colorbar_x * width) - by = int(0.5 * (image.height - bar_h)) if colorbar_y is None else int(colorbar_y * height) - image.paste(bar, (bx, by)) - - if caption: - temp_caption_img = Image.new("RGBA", (text_w, text_h), (0, 0, 0, 0)) - ImageDraw.Draw(temp_caption_img).text((0, 0), caption, font=font, anchor="lt") - rotated_caption = temp_caption_img.rotate(90, expand=True, fillcolor=(0, 0, 0, 0)) - rotated_w, rotated_h = rotated_caption.size - - cx = image.width - right_pad - rotated_w if caption_x is None else int(caption_x * width) - cy = int(0.5 * (image.height - rotated_h)) if caption_y is None else int(caption_y * height) - image.paste(rotated_caption, (cx, cy), rotated_caption) + # For vertical, allow explicit x/y for both, else use default + bx = int(colorbar_x * width) if colorbar_x is not None else None + by = int(colorbar_y * height) if colorbar_y is not None else None + cx = int(caption_x * width) if caption_x is not None else None + cy = int(caption_y * height) if caption_y is not None else None + draw_colorbar(image, bar, orientation, x=bx, y=by) + draw_caption(image, caption, font, orientation, x=cx, y=cy) if outpath: logger.info("Saving snapshot to %s", outpath) @@ -285,6 +246,7 @@ def snap4( specular=True, ambient=0.0, brain_scale=1.85, + color_mode=ColorSelection.BOTH, ): """Render four snapshot views (left/right hemispheres, front/back). @@ -349,27 +311,19 @@ def snap4( >>> img = snap4( >>> lhoverlaypath='fsaverage/surf/lh.thickness', >>> rhoverlaypath='fsaverage/surf/rh.thickness', - >>> sdir='./fsaverage' + >>> sdir='./fsaverage' >>> ) >>> img.save('/tmp/whippersnappy_overview.png') """ wwidth = 540 wheight = 450 - # Try to create a visible window first (better for debugging), - # but fall back to an invisible/offscreen window if that fails. - window = _gl.init_window(wwidth, wheight, "WhipperSnapPy", visible=True) - if not window: - logger.warning("Could not create visible GLFW window; retrying with invisible window (offscreen).") - window = _gl.init_window(wwidth, wheight, "WhipperSnapPy", visible=False) - if not window: - logger.error("Could not create any GLFW window/context. OpenGL context unavailable.") - return None + # will raise exception if it cannot be created + window = create_window_with_fallback(wwidth, wheight, "WhipperSnapPy", visible=True) - rot_z = pyrr.Matrix44.from_z_rotation(-0.5 * np.pi) - rot_x = pyrr.Matrix44.from_x_rotation(0.5 * np.pi) - view_left = rot_x * rot_z - rot_y = pyrr.Matrix44.from_y_rotation(np.pi) - view_right = rot_y * view_left + # Use standard view matrices from get_view_matrices and ViewType + view_mats = get_view_matrices() + view_left = view_mats[ViewType.LEFT] + view_right = view_mats[ViewType.RIGHT] transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) # Predefine hemisphere images so static analysis knows they exist even if @@ -409,8 +363,8 @@ def snap4( logger.debug("curvpath=%s exists=%s", curvpath, os.path.exists(curvpath)) try: - meshdata, triangles, fthresh, fmax, pos, neg = prepare_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, scale=brain_scale + meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( + meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, scale=brain_scale, color_mode=color_mode ) except Exception as e: logger.error("prepare_geometry failed for %s: %s", meshpath, e) @@ -429,50 +383,17 @@ def snap4( raise ValueError("Overlay has no values to display") try: - shader = _gl.setup_shader(meshdata, triangles, wwidth, wheight, specular=specular, ambient=ambient) + shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular, ambient=ambient) logger.debug("Shader setup complete") except Exception as e: logger.error("setup_shader failed: %s", e) glfw.terminate() return None - try: - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - except Exception as e: - logger.error("glClear failed: %s", e) - logger.error("glError: %s", gl.glGetError()) - glfw.terminate() - return None - transform_loc = gl.glGetUniformLocation(shader, "transform") - viewmat = view_left if hemi == "lh" else view_right - gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, transl * viewmat) - gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - try: - im1 = _gl.capture_window(window) - logger.debug("Captured image 1 size: %s", im1.size) - except Exception as e: - logger.error("capture_window failed: %s", e) - glfw.terminate() - return None - - glfw.swap_buffers(window) - try: - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - except Exception as e: - logger.error("glClear failed: %s", e) - logger.error("glError: %s", gl.glGetError()) - glfw.terminate() - return None - viewmat = view_right if hemi == "lh" else view_left - gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, transl * viewmat) - gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) - try: - im2 = _gl.capture_window(window) - logger.debug("Captured image 2 size: %s", im2.size) - except Exception as e: - logger.error("capture_window failed: %s", e) - glfw.terminate() - return None + render_scene(shader, triangles, transl * view_left) + im1 = capture_window(window) + render_scene(shader, triangles, transl * view_right) + im2 = capture_window(window) if hemi == "lh": lhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) @@ -480,9 +401,9 @@ def snap4( lhimg.paste(im2, (0, im1.height)) else: rhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) - # Keep same top/bottom ordering as left hemisphere: top=im1, bottom=im2 - rhimg.paste(im1, (0, 0)) - rhimg.paste(im2, (0, im1.height)) + # For right hemisphere, reverse the order: top=im2, bottom=im1 + rhimg.paste(im2, (0, 0)) + rhimg.paste(im1, (0, im2.height)) # Add small padding around each hemisphere to avoid cropping at edges pad = max(4, int(0.03 * wwidth)) @@ -495,28 +416,25 @@ def snap4( image.paste(padded_lh, (0, 0)) image.paste(padded_rh, (padded_lh.width, 0)) - if caption: - if font_file is None: - font = load_roboto_font(20) - else: - try: - font = ImageFont.truetype(font_file, 20) - except Exception: - font = load_roboto_font(20) - if font is not None: - xpos = 0.5 * (image.width - getattr(font, 'getlength', lambda s: 0)(caption)) - ImageDraw.Draw(image).text((xpos, image.height - 40), caption, (220, 220, 220), font=font) - else: - ImageDraw.Draw(image).text((10, image.height - 40), caption, (220, 220, 220)) - - if lhannotpath is None and rhannotpath is None and colorbar: - bar = create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) - if bar is not None: - xpos = int(0.5 * (image.width - bar.width)) - ypos = int(0.5 * (image.height - bar.height)) - image.paste(bar, (xpos, ypos)) - - # Otherwise save to disk + font = load_roboto_font(20) if font_file is None else ImageFont.truetype(font_file, 20) if caption else None + # Place caption at bottom, colorbar above if both present + text_w, text_h = text_size(caption, font) if caption and font else (0, 0) + bottom_pad = 20 + gap = 4 + caption_y = image.height - bottom_pad - text_h + bar = create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) if lhannotpath is None and rhannotpath is None and colorbar else None + bar_h = bar.height if bar is not None else 0 + if bar is not None and caption: + bar_y = image.height - bottom_pad - text_h - gap - bar_h + draw_colorbar(image, bar, OrientationType.HORIZONTAL, y=bar_y) + draw_caption(image, caption, font, OrientationType.HORIZONTAL, y=caption_y) + elif bar is not None: + bar_y = image.height - bottom_pad - bar_h + draw_colorbar(image, bar, OrientationType.HORIZONTAL, y=bar_y) + elif caption: + draw_caption(image, caption, font, OrientationType.HORIZONTAL, y=caption_y) + + # If outpath is specified, save to disk if outpath: logger.info("Saving snapshot to %s", outpath) image.save(outpath) diff --git a/whippersnappy/utils/colormap.py b/whippersnappy/utils/colormap.py index 292ee76..460f81b 100644 --- a/whippersnappy/utils/colormap.py +++ b/whippersnappy/utils/colormap.py @@ -4,7 +4,7 @@ import numpy as np -from whippersnappy.utils.types import ColorSelection +from .types import ColorSelection # Module logger logger = logging.getLogger(__name__) diff --git a/whippersnappy/utils/image.py b/whippersnappy/utils/image.py index 57b8499..f9d2ea0 100644 --- a/whippersnappy/utils/image.py +++ b/whippersnappy/utils/image.py @@ -3,8 +3,8 @@ import numpy as np from PIL import Image, ImageDraw -from whippersnappy.utils.colormap import heat_color -from whippersnappy.utils.types import OrientationType +from .colormap import heat_color +from .types import OrientationType try: # Prefer stdlib importlib.resources @@ -282,3 +282,62 @@ def load_roboto_font(size=14): except Exception: return None + +def draw_colorbar(image, bar, orientation, x=None, y=None): + """Paste a colorbar image onto the target image at the specified position. + + Parameters + ---------- + image : PIL.Image.Image + The target image to paste onto. + bar : PIL.Image.Image + The colorbar image to paste. + orientation : OrientationType + Orientation of the colorbar (HORIZONTAL/VERTICAL). + x, y : int or None, optional + Position to paste the colorbar. If None, defaults to centered at bottom (horizontal) or right (vertical). + """ + if bar is None: + return + if orientation == OrientationType.HORIZONTAL: + bx = int(0.5 * (image.width - bar.width)) if x is None else x + by = image.height - bar.height - 10 if y is None else y + image.paste(bar, (bx, by)) + else: + bx = image.width - bar.width - 10 if x is None else x + by = int(0.5 * (image.height - bar.height)) if y is None else y + image.paste(bar, (bx, by)) + + +def draw_caption(image, caption, font, orientation, x=None, y=None): + """Draw a caption string onto the image at the specified position and orientation. + + Parameters + ---------- + image : PIL.Image.Image + The target image to draw onto. + caption : str + The caption text to draw. + font : PIL.ImageFont + Font to use for the caption. + orientation : OrientationType + Orientation of the caption (HORIZONTAL/VERTICAL). + x, y : int or None, optional + Position to draw the caption. If None, defaults to centered at bottom (horizontal) or right (vertical). + """ + if not caption or font is None: + return + text_w, text_h = text_size(caption, font) + draw = ImageDraw.Draw(image) + if orientation == OrientationType.HORIZONTAL: + cx = int(0.5 * (image.width - text_w)) if x is None else x + cy = image.height - text_h - 10 if y is None else y + draw.text((cx, cy), caption, (220, 220, 220), font=font, anchor="lt") + else: + temp_caption_img = Image.new("RGBA", (text_w, text_h), (0, 0, 0, 0)) + ImageDraw.Draw(temp_caption_img).text((0, 0), caption, font=font, anchor="lt") + rotated_caption = temp_caption_img.rotate(90, expand=True, fillcolor=(0, 0, 0, 0)) + rotated_w, rotated_h = rotated_caption.size + cx = image.width - rotated_w - 10 if x is None else x + cy = int(0.5 * (image.height - rotated_h)) if y is None else y + image.paste(rotated_caption, (cx, cy), rotated_caption) From 308e562914c64152f8e46443a3ff1939fac4eb6e Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 16:37:46 +0100 Subject: [PATCH 21/83] bug fixes in CLI, added whippersnap1 CLI (untested) --- pyproject.toml | 3 +- whippersnappy/cli/whippersnap.py | 138 +++++++++++++++--------------- whippersnappy/cli/whippersnap1.py | 119 ++++++++++++++++++++++++++ whippersnappy/geometry/prepare.py | 2 +- whippersnappy/gl/utils.py | 2 +- whippersnappy/plot3d.py | 3 +- whippersnappy/snap.py | 36 +++++--- 7 files changed, 218 insertions(+), 85 deletions(-) create mode 100644 whippersnappy/cli/whippersnap1.py diff --git a/pyproject.toml b/pyproject.toml index af5de1c..5d3e392 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,8 @@ source = 'https://github.com/Deep-MI/WhipperSnapPy' tracker = 'https://github.com/Deep-MI/WhipperSnapPy/issues' [project.scripts] -whippersnap = 'whippersnappy.cli:run' +whippersnap = 'whippersnappy.cli.whippersnap:run' +whippersnap1 = 'whippersnappy.cli.whippersnap1:run' whippersnappy-sys_info = 'whippersnappy.commands.sys_info:run' [tool.setuptools] diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 9f63433..e4581a4 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -15,17 +15,14 @@ -sd $SURF_SUBJECT_DIR -o $OUTPUT_PATH (See help for full list of arguments.) - -@Author1 : Martin Reuter -@Author2 : Ahmed Faisal Abdelrahman -@Created : 16.03.2022 """ import argparse import logging -import math import os import signal +import sys +import tempfile import threading import glfw @@ -39,11 +36,11 @@ QApplication = None from .. import snap4 +from .._version import __version__ from ..geometry import get_surf_name, prepare_geometry -from ..gl import ( - init_window, render_scene, setup_shader, capture_window, get_view_matrices -) +from ..gl import get_view_matrices, init_window, setup_shader from ..gui import ConfigWindow +from ..utils.types import ViewType # Module logger logger = logging.getLogger(__name__) @@ -102,21 +99,21 @@ def show_window( Returns ------- - bool - ``False`` if the window/context could not be created, ``None`` on - normal termination. The function primarily drives an interactive - event loop and does not return programmatic geometry objects. + None + The function primarily drives an interactive event loop and does not return programmatic geometry objects. Raises ------ + RuntimeError + If the window/context could not be created. FileNotFoundError If a requested surface file cannot be located in ``sdir``. """ global current_fthresh_, current_fmax_, app_, app_window_, app_window_closed_ wwidth = 720 - weight = 600 - window = init_window(wwidth, weight, "WhipperSnapPy", visible=True) + wheight = 600 + window = init_window(wwidth, wheight, "WhipperSnapPy", visible=True) if not window: logger.error("Could not create any GLFW window/context. OpenGL context unavailable.") raise RuntimeError("Could not create any GLFW window/context. OpenGL context unavailable.") @@ -139,18 +136,17 @@ def show_window( if labelname: labelpath = os.path.join(sdir, "label", hemi + "." + labelname) - # set up matrices to show object left and right side: - rot_z = pyrr.Matrix44.from_z_rotation(-0.5 * math.pi) - rot_x = pyrr.Matrix44.from_x_rotation(0.5 * math.pi) - viewLeft = rot_x * rot_z - # rot_y = pyrr.Matrix44.from_y_rotation(math.pi) - # viewRight = rot_y * viewLeft + # set up canonical view matrix for the selected hemisphere + view_mats = get_view_matrices() + viewmat = view_mats[ViewType.LEFT] # fallback + if hemi == "rh": + viewmat = view_mats[ViewType.RIGHT] rot_y = pyrr.Matrix44.from_y_rotation(0) meshdata, triangles, fthresh, fmax, neg = prepare_geometry( meshpath, overlaypath, annotpath, curvpath, labelpath, current_fthresh_, current_fmax_ ) - shader = setup_shader(meshdata, triangles, wwidth, weight, specular=specular) + shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular) logger.info("\nKeys:\nLeft - Right : Rotate Geometry\nESC : Quit\n") @@ -174,15 +170,16 @@ def show_window( meshdata, triangles, fthresh, fmax, neg = prepare_geometry( meshpath, overlaypath, + annotpath, curvpath, labelpath, current_fthresh_, current_fmax_, ) - shader = setup_shader(meshdata, triangles, wwidth, weight, specular=specular) + shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular) transformLoc = gl.glGetUniformLocation(shader, "transform") - gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, rot_y * viewLeft) + gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, rot_y * viewmat) if glfw.get_key(window, glfw.KEY_RIGHT) == glfw.PRESS: ypos = ypos + 0.0004 @@ -195,7 +192,9 @@ def show_window( glfw.swap_buffers(window) glfw.terminate() - app_.quit() + # Do NOT call app_.quit() here; QApplication teardown must be handled in the main thread. + # Only set app_window_closed_ = True in this thread. + app_window_closed_ = True def config_app_exit_handler(): @@ -238,10 +237,14 @@ def run(): global current_fthresh_, current_fmax_, app_, app_window_ # Configure basic logging for CLI invocation so messages from module loggers # are visible to end users. Avoid configuring on import by doing this here. - import logging as _logging - _logging.basicConfig(level=_logging.INFO, format='%(levelname)s: %(message)s') + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') parser = argparse.ArgumentParser() + parser.add_argument( + "--version", + action="version", + version=f"%(prog)s {__version__}" + ) parser.add_argument( "-lh", "--lh_overlay", @@ -291,7 +294,7 @@ def run(): "-o", "--output_path", type=str, - default="/tmp/whippersnappy_snap.png", + default=os.path.join(tempfile.gettempdir(), "whippersnappy_snap.png"), help="Absolute path to the output file (snapshot image), " "if not running interactive mode.", ) @@ -304,8 +307,14 @@ def run(): action="store_true", default=False, help="Switch off colorbar.") - parser.add_argument("--fmax", type=float, default=4.0) - parser.add_argument("--fthresh", type=float, default=2.0) + parser.add_argument("--fmax", + type=float, + default=4.0, + help="Overlay saturation value (default: 4.0)") + parser.add_argument("--fthresh", + type=float, + default=2.0, + help="Overlay threshold value (default: 2.0)") parser.add_argument( "-i", "--interactive", @@ -326,39 +335,29 @@ def run(): args = parser.parse_args() - # check for mutually exclusive arguments - if (args.lh_overlay or args.rh_overlay) and (args.lh_annot or args.rh_annot): - msg = "Cannot use lh_overlay/rh_overlay and lh_annot/rh_annot arguments at the same time." - logger.error(msg) - raise ValueError(msg) - # check if at least one variant is present - if args.lh_overlay is None and args.rh_overlay is None and args.lh_annot is None and args.rh_annot is None: - msg = "Either lh_overlay/rh_overlay or lh_annot/rh_annot must be present." - logger.error(msg) - raise ValueError(msg) - # check if both hemis are present - if (args.lh_overlay is None and args.rh_overlay is not None) or \ - (args.lh_overlay is not None and args.rh_overlay is None) or \ - (args.lh_annot is None and args.rh_annot is not None) or \ - (args.lh_annot is not None and args.rh_annot is None): - msg = "If lh_overlay or lh_annot is present, rh_overlay or rh_annot must also be present (and vice versa)." - logger.error(msg) - raise ValueError(msg) - - logger.info(f"Left hemisphere overlay: {args.lh_overlay}") - logger.info(f"Right hemisphere overlay: {args.rh_overlay}") - logger.info(f"Left hemisphere annotation: {args.lh_annot}") - logger.info(f"Right hemisphere annotation: {args.rh_annot}") - logger.info(f"Subject directory: {args.sdir}") - logger.info(f"Surface name: {args.surf_name}") - logger.info(f"Output path: {args.output_path}") - logger.info(f"Caption: {args.caption}") - logger.info(f"Colorbar: {'enabled' if not args.no_colorbar else 'disabled'}") - logger.info(f"fmax: {args.fmax}") - logger.info(f"fthresh: {args.fthresh}") - logger.info(f"Interactive mode: {'enabled' if args.interactive else 'disabled'}") - logger.info(f"Color scale inversion: {'enabled' if args.invert else 'disabled'}") - logger.info(f"Specular reflection: {'enabled' if args.specular else 'disabled'}") + try: + # check for mutually exclusive arguments + if (args.lh_overlay or args.rh_overlay) and (args.lh_annot or args.rh_annot): + msg = "Cannot use lh_overlay/rh_overlay and lh_annot/rh_annot arguments at the same time." + logger.error(msg) + raise ValueError(msg) + # check if at least one variant is present + if args.lh_overlay is None and args.rh_overlay is None and args.lh_annot is None and args.rh_annot is None: + msg = "Either lh_overlay/rh_overlay or lh_annot/rh_annot must be present." + logger.error(msg) + raise ValueError(msg) + # check if both hemis are present + if (args.lh_overlay is None and args.rh_overlay is not None) or \ + (args.lh_overlay is not None and args.rh_overlay is None) or \ + (args.lh_annot is None and args.rh_annot is not None) or \ + (args.lh_annot is not None and args.rh_annot is None): + msg = "If lh_overlay or lh_annot is present, rh_overlay or rh_annot must also be present (and vice versa)." + logger.error(msg) + raise ValueError(msg) + except ValueError as e: + parser.error(str(e)) + + logger.debug("Parsed args: %s", vars(args)) # if not args.interactive: @@ -381,6 +380,12 @@ def run(): current_fthresh_ = args.fthresh current_fmax_ = args.fmax + # Ensure GUI toolkit is available + if QApplication is None: + print("ERROR: Interactive mode requires PyQt6. Install it (pip install PyQt6)" + " or run without --interactive.", file=sys.stderr) + sys.exit(1) + # Starting interactive OpenGL window in a separate thread: thread = threading.Thread( target=show_window, @@ -399,17 +404,10 @@ def run(): ) thread.start() - # Ensure GUI toolkit is available - if QApplication is None: - raise ImportError( - "Interactive mode requires PyQt6. Install it (pip install PyQt6) " - "or run without --interactive." - ) - # Setting up and running config app window (must be main thread): app_ = QApplication([]) app_.setStyle("Fusion") # the default - app_.signals.aboutToQuit.connect(config_app_exit_handler) + app_.aboutToQuit.connect(config_app_exit_handler) screen_geometry = app_.primaryScreen().availableGeometry() app_window_ = ConfigWindow( diff --git a/whippersnappy/cli/whippersnap1.py b/whippersnappy/cli/whippersnap1.py new file mode 100644 index 0000000..832cfd0 --- /dev/null +++ b/whippersnappy/cli/whippersnap1.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +"""CLI entry point for single-mesh snapshot via snap1.""" + +import argparse +import logging +import os +import tempfile + +from .. import snap1 +from .._version import __version__ +from ..utils.types import ColorSelection, OrientationType, ViewType + +_VIEW_CHOICES = {v.name.lower(): v for v in ViewType} +_ORIENT_CHOICES = {o.name.lower(): o for o in OrientationType} +_COLOR_CHOICES = {c.name.lower(): c for c in ColorSelection} + + +def run(): + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + parser = argparse.ArgumentParser( + prog="whippersnap1", + description=( + "Render a single-view screenshot of any triangular surface mesh " + "(FreeSurfer or otherwise) without a GUI." + ), + ) + parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") + + # --- Required --- + parser.add_argument( + "meshpath", + type=str, + help="Path to the surface file. FreeSurfer binary format (e.g. lh.white) " + "or any mesh readable by the geometry module.", + ) + + # --- Output --- + parser.add_argument( + "-o", "--output", + type=str, + default=os.path.join(tempfile.gettempdir(), "whippersnappy_snap1.png"), + help="Output PNG path. Defaults to a temp file.", + ) + + # --- Optional overlay / annotation / label / curv --- + parser.add_argument("--overlay", type=str, default=None, help="Per-vertex overlay file.") + parser.add_argument("--annot", type=str, default=None, help="FreeSurfer .annot file.") + parser.add_argument("--label", type=str, default=None, help="Label file for masking.") + parser.add_argument("--curv", type=str, default=None, help="Curvature file for texturing.") + + # --- View --- + parser.add_argument( + "--view", + type=str, + default="left", + choices=list(_VIEW_CHOICES), + help="Pre-defined view direction (default: left).", + ) + + # --- Appearance --- + parser.add_argument("--width", type=int, default=700) + parser.add_argument("--height", type=int, default=500) + parser.add_argument("--fthresh", type=float, default=None, help="Overlay threshold.") + parser.add_argument("--fmax", type=float, default=None, help="Overlay saturation value.") + parser.add_argument("--caption", type=str, default=None) + parser.add_argument("--invert", action="store_true", help="Invert color scale.") + parser.add_argument("--no-colorbar", dest="no_colorbar", action="store_true") + parser.add_argument( + "--color-mode", + type=str, + default="both", + choices=list(_COLOR_CHOICES), + help="Which overlay sign to display (default: both).", + ) + parser.add_argument( + "--orientation", + type=str, + default="horizontal", + choices=list(_ORIENT_CHOICES), + help="Colorbar orientation (default: horizontal).", + ) + parser.add_argument("--diffuse", dest="specular", action="store_false", default=True, + help="Use diffuse-only shading (no specular).") + parser.add_argument("--brain-scale", type=float, default=1.5, + help="Geometry scale factor (default: 1.5).") + parser.add_argument("--ambient", type=float, default=0.0, + help="Ambient light strength (default: 0.0).") + parser.add_argument("--font", type=str, default=None, + help="Path to a TTF font for captions.") + + args = parser.parse_args() + + try: + img = snap1( + meshpath=args.meshpath, + outpath=args.output, + overlaypath=args.overlay, + annotpath=args.annot, + labelpath=args.label, + curvpath=args.curv, + view=_VIEW_CHOICES[args.view], + width=args.width, + height=args.height, + fthresh=args.fthresh, + fmax=args.fmax, + caption=args.caption, + invert=args.invert, + colorbar=not args.no_colorbar, + color_mode=_COLOR_CHOICES[args.color_mode], + orientation=_ORIENT_CHOICES[args.orientation], + font_file=args.font, + specular=args.specular, + brain_scale=args.brain_scale, + ambient=args.ambient, + ) + logging.getLogger(__name__).info("Snapshot saved to %s (%dx%d)", args.output, img.width, img.height) + except (RuntimeError, FileNotFoundError, ValueError) as e: + parser.error(str(e)) diff --git a/whippersnappy/geometry/prepare.py b/whippersnappy/geometry/prepare.py index 6db8350..70efbe2 100644 --- a/whippersnappy/geometry/prepare.py +++ b/whippersnappy/geometry/prepare.py @@ -9,9 +9,9 @@ import numpy as np -from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data from ..utils.colormap import binary_color, heat_color, mask_label, mask_sign, rescale_overlay from ..utils.types import ColorSelection +from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data def normalize_mesh(v, scale=1.0): diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index 09ff93e..0029079 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -349,7 +349,7 @@ def render_scene(shader, triangles, transform): gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) except Exception as exc: logger.error("glClear failed: %s", exc) - raise RuntimeError(f"glClear failed: {exc}") + raise RuntimeError(f"glClear failed: {exc}") from exc transform_loc = gl.glGetUniformLocation(shader, "transform") gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, transform) diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py index 88938f3..d93be86 100644 --- a/whippersnappy/plot3d.py +++ b/whippersnappy/plot3d.py @@ -24,10 +24,9 @@ import pythreejs as p3js from ipywidgets import HTML, VBox -from .utils.types import ColorSelection - from .geometry import prepare_geometry from .gl import get_webgl_shaders +from .utils.types import ColorSelection # Module logger logger = logging.getLogger(__name__) diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index e7515a9..0e8ddab 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -5,15 +5,14 @@ import glfw import pyrr -from PIL import Image, ImageDraw, ImageFont +from PIL import Image, ImageFont from .geometry import get_surf_name -from .utils.image import create_colorbar, load_roboto_font, text_size, draw_colorbar, draw_caption -from .utils.types import ColorSelection, OrientationType, ViewType - -from .gl.views import get_view_matrices -from .gl.utils import render_scene, create_window_with_fallback, capture_window, setup_shader from .geometry.prepare import prepare_and_validate_geometry +from .gl.utils import capture_window, create_window_with_fallback, render_scene, setup_shader +from .gl.views import get_view_matrices +from .utils.image import create_colorbar, draw_caption, draw_colorbar, load_roboto_font, text_size +from .utils.types import ColorSelection, OrientationType, ViewType # Module logger logger = logging.getLogger(__name__) @@ -178,8 +177,20 @@ def snap1( brain_y = max(0, (height - brain_display_height) // 2) image.paste(capture_window(window), (brain_x, brain_y)) - bar = create_colorbar(fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, font_file=font_file) if overlaypath is not None and colorbar else None - font = load_roboto_font(int(20 * caption_scale * ui_scale)) if font_file is None else ImageFont.truetype(font_file, int(20 * caption_scale * ui_scale)) if caption else None + bar = ( + create_colorbar( + fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, font_file=font_file + ) + if overlaypath is not None and colorbar + else None + ) + font = ( + load_roboto_font(int(20 * caption_scale * ui_scale)) + if font_file is None + else ImageFont.truetype(font_file, int(20 * caption_scale * ui_scale)) + if caption + else None + ) # Compute positions to avoid overlap, unless explicit positions are given text_w, text_h = text_size(caption, font) if caption and font else (0, 0) @@ -364,7 +375,8 @@ def snap4( try: meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, scale=brain_scale, color_mode=color_mode + meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, + scale=brain_scale, color_mode=color_mode ) except Exception as e: logger.error("prepare_geometry failed for %s: %s", meshpath, e) @@ -422,7 +434,11 @@ def snap4( bottom_pad = 20 gap = 4 caption_y = image.height - bottom_pad - text_h - bar = create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) if lhannotpath is None and rhannotpath is None and colorbar else None + bar = ( + create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) + if lhannotpath is None and rhannotpath is None and colorbar + else None + ) bar_h = bar.height if bar is not None else 0 if bar is not None and caption: bar_y = image.height - bottom_pad - text_h - gap - bar_h From a3e52de475113865cff6ed7b8c4ec998f3bfbe91 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 16:44:28 +0100 Subject: [PATCH 22/83] add docstring parameter --- whippersnappy/snap.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index 0e8ddab..7a7fb38 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -302,6 +302,8 @@ def snap4( Ambient lighting strength. Default is ``0``. brain_scale : float, optional Scaling factor passed to geometry preparation. Default is ``1.85``. + color_mode : ColorSelection, optional + Which sign of overlay to color (POSITIVE/NEGATIVE/BOTH). Default is ``ColorSelection.BOTH``. Returns ------- From edff5d41be15d93b2337e6cd6281c87b63c3c5d5 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 17:26:11 +0100 Subject: [PATCH 23/83] adding egl off-screen rendering capabilities --- Dockerfile | 11 +- whippersnappy/gl/__init__.py | 7 +- whippersnappy/gl/egl_context.py | 287 ++++++++++++++++++++++++++++++++ whippersnappy/gl/utils.py | 148 ++++++++++------ whippersnappy/snap.py | 6 +- 5 files changed, 401 insertions(+), 58 deletions(-) create mode 100644 whippersnappy/gl/egl_context.py diff --git a/Dockerfile b/Dockerfile index 9c768ee..ece2e19 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,18 +1,17 @@ FROM ubuntu:24.04 -# Install packages RUN apt-get update && apt-get install -y --no-install-recommends \ - python3 pip xvfb libglib2.0-0 libxkbcommon-x11-0 libgl1 libegl1 \ - libfontconfig1 libdbus-1-3 && \ + python3 pip \ + libegl1 \ + libglib2.0-0 libfontconfig1 libdbus-1-3 && \ apt clean && \ - rm -rf /var/libs/apt/lists/* /tmp/* /var/tmp/* + rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -# Install python packages RUN pip install --upgrade pip RUN pip install pyopengl glfw pillow numpy pyrr COPY . /WhipperSnapPy RUN pip install /WhipperSnapPy -ENTRYPOINT ["xvfb-run","whippersnap"] +ENTRYPOINT ["whippersnap"] CMD ["--help"] diff --git a/whippersnappy/gl/__init__.py b/whippersnappy/gl/__init__.py index 73996dd..fd5cb38 100644 --- a/whippersnappy/gl/__init__.py +++ b/whippersnappy/gl/__init__.py @@ -15,20 +15,23 @@ create_vao, create_window_with_fallback, init_window, + render_scene, set_camera_uniforms, set_default_gl_state, set_lighting_uniforms, setup_buffers, setup_shader, setup_vertex_attributes, + terminate_context, ) from .views import get_view_matrices, get_view_matrix +from .egl_context import EGLContext __all__ = [ 'create_vao', 'compile_shader_program', 'setup_buffers', 'setup_vertex_attributes', 'set_default_gl_state', 'set_camera_uniforms', 'set_lighting_uniforms', - 'init_window', 'setup_shader', 'capture_window', + 'init_window', 'render_scene', 'setup_shader', 'capture_window', 'make_model', 'make_projection', 'make_view', 'get_default_shaders', 'get_view_matrices', 'get_view_matrix', - 'get_webgl_shaders' + 'get_webgl_shaders', 'terminate_context', 'EGLContext', ] diff --git a/whippersnappy/gl/egl_context.py b/whippersnappy/gl/egl_context.py new file mode 100644 index 0000000..594bf69 --- /dev/null +++ b/whippersnappy/gl/egl_context.py @@ -0,0 +1,287 @@ +"""EGL off-screen (headless) OpenGL context via pbuffer + FBO. + +This module provides a drop-in alternative to GLFW window creation for +headless environments (CI, Docker, HPC clusters) where no X11/Wayland +display is available. It requires: + + - A system EGL library (``libegl1`` on Debian/Ubuntu, already present + in the WhipperSnapPy Dockerfile). + - PyOpenGL >= 3.1 (already a project dependency), which ships + ``OpenGL.EGL`` bindings. + - Either an NVIDIA GPU with the EGL driver, or Mesa ``libEGL-mesa0`` + (llvmpipe software renderer) for CPU-only systems. + +Typical usage (internal, called from ``create_window_with_fallback``):: + + from whippersnappy.gl.egl_context import EGLContext + + ctx = EGLContext(width, height) + ctx.make_current() + # ... OpenGL calls ... + img = ctx.read_pixels() + ctx.destroy() +""" + +import ctypes +import logging + +import OpenGL.GL as gl +from PIL import Image + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# EGL constants not exposed by all PyOpenGL versions +# --------------------------------------------------------------------------- +_EGL_SURFACE_TYPE = 0x3033 +_EGL_PBUFFER_BIT = 0x0001 +_EGL_RENDERABLE_TYPE = 0x3040 +_EGL_OPENGL_BIT = 0x0008 +_EGL_NONE = 0x3038 +_EGL_WIDTH = 0x3057 +_EGL_HEIGHT = 0x3056 +_EGL_OPENGL_API = 0x30A2 +_EGL_CONTEXT_MAJOR_VERSION = 0x3098 +_EGL_CONTEXT_MINOR_VERSION = 0x30FB + + +def _check_egl_available(): + """Raise ImportError with a helpful message if EGL bindings are absent.""" + try: + from OpenGL import EGL as _EGL # noqa: F401 + except (ImportError, AttributeError) as exc: + raise ImportError( + "OpenGL.EGL is not available. Make sure pyopengl >= 3.1 is " + "installed and libegl1 (or equivalent) is present on the system." + ) from exc + + +class EGLContext: + """A headless OpenGL 3.3 Core context backed by an EGL pbuffer + FBO. + + The pbuffer surface is created solely to satisfy EGL's requirement for + a surface when calling ``eglMakeCurrent``. All rendering is directed + into an off-screen Framebuffer Object (FBO) so that ``glReadPixels`` + captures exactly what was rendered regardless of platform quirks with + pbuffer readback. + + Parameters + ---------- + width, height : int + Dimensions of the off-screen render target in pixels. + + Attributes + ---------- + width, height : int + Render target dimensions. + fbo : int + OpenGL FBO handle (valid after ``make_current`` is called). + + Raises + ------ + ImportError + If ``OpenGL.EGL`` bindings are not available. + RuntimeError + If any EGL initialisation step fails. + """ + + def __init__(self, width: int, height: int): + _check_egl_available() + from OpenGL import EGL + + self._EGL = EGL + self.width = width + self.height = height + self._display = None + self._surface = None + self._context = None + self.fbo = None + self._rbo_color = None + self._rbo_depth = None + + self._init_egl() + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _init_egl(self): + EGL = self._EGL + + # 1. Get the default EGL display (works for both GPU and Mesa) + self._display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY) + if self._display == EGL.EGL_NO_DISPLAY: + raise RuntimeError("eglGetDisplay returned EGL_NO_DISPLAY.") + + major = ctypes.c_int(0) + minor = ctypes.c_int(0) + if not EGL.eglInitialize(self._display, major, minor): + raise RuntimeError("eglInitialize failed.") + logger.debug("EGL version %d.%d", major.value, minor.value) + + # 2. Bind the OpenGL API (not OpenGL ES) + if not EGL.eglBindAPI(_EGL_OPENGL_API): + raise RuntimeError("eglBindAPI(OpenGL) failed.") + + # 3. Choose a framebuffer config + cfg_attribs = (ctypes.c_int * 7)( + _EGL_SURFACE_TYPE, _EGL_PBUFFER_BIT, + _EGL_RENDERABLE_TYPE, _EGL_OPENGL_BIT, + _EGL_NONE, + ) + configs = (EGL.EGLConfig * 1)() + num_cfg = ctypes.c_int(0) + if not EGL.eglChooseConfig( + self._display, cfg_attribs, configs, 1, ctypes.byref(num_cfg) + ) or num_cfg.value == 0: + raise RuntimeError( + "eglChooseConfig found no suitable configs. " + "Ensure a Mesa or GPU EGL driver is installed (libegl1-mesa or libegl1)." + ) + + # 4. Create a minimal pbuffer surface (1×1 is sufficient — rendering + # goes into the FBO, not this surface) + pbuf_attribs = (ctypes.c_int * 5)( + _EGL_WIDTH, 1, + _EGL_HEIGHT, 1, + _EGL_NONE, + ) + self._surface = EGL.eglCreatePbufferSurface( + self._display, configs[0], pbuf_attribs + ) + if self._surface == EGL.EGL_NO_SURFACE: + raise RuntimeError("eglCreatePbufferSurface failed.") + + # 5. Create an OpenGL 3.3 Core context + ctx_attribs = (ctypes.c_int * 5)( + _EGL_CONTEXT_MAJOR_VERSION, 3, + _EGL_CONTEXT_MINOR_VERSION, 3, + _EGL_NONE, + ) + self._context = EGL.eglCreateContext( + self._display, configs[0], EGL.EGL_NO_CONTEXT, ctx_attribs + ) + if self._context == EGL.EGL_NO_CONTEXT: + raise RuntimeError( + "eglCreateContext failed. " + "The EGL driver may not support OpenGL 3.3 Core. " + "Check with: glxinfo | grep 'OpenGL version'" + ) + + logger.info("EGL headless context created (%dx%d)", self.width, self.height) + + def make_current(self): + """Make this EGL context current and set up the FBO render target. + + Must be called before any OpenGL commands. Creates and binds an + FBO backed by two renderbuffers (RGBA color + depth/stencil). + """ + EGL = self._EGL + if not EGL.eglMakeCurrent( + self._display, self._surface, self._surface, self._context + ): + raise RuntimeError("eglMakeCurrent failed.") + + # Build FBO so rendering is directed off-screen + self.fbo = gl.glGenFramebuffers(1) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.fbo) + + # Color renderbuffer + self._rbo_color = gl.glGenRenderbuffers(1) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._rbo_color) + gl.glRenderbufferStorage( + gl.GL_RENDERBUFFER, gl.GL_RGBA8, self.width, self.height + ) + gl.glFramebufferRenderbuffer( + gl.GL_FRAMEBUFFER, + gl.GL_COLOR_ATTACHMENT0, + gl.GL_RENDERBUFFER, + self._rbo_color, + ) + + # Depth + stencil renderbuffer + self._rbo_depth = gl.glGenRenderbuffers(1) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._rbo_depth) + gl.glRenderbufferStorage( + gl.GL_RENDERBUFFER, + gl.GL_DEPTH24_STENCIL8, + self.width, + self.height, + ) + gl.glFramebufferRenderbuffer( + gl.GL_FRAMEBUFFER, + gl.GL_DEPTH_STENCIL_ATTACHMENT, + gl.GL_RENDERBUFFER, + self._rbo_depth, + ) + + status = gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) + if status != gl.GL_FRAMEBUFFER_COMPLETE: + raise RuntimeError( + f"FBO is not complete after EGL setup (status=0x{status:X})." + ) + + # Set the viewport to match the render target + gl.glViewport(0, 0, self.width, self.height) + logger.debug("EGL FBO complete and bound (%dx%d)", self.width, self.height) + + def read_pixels(self) -> Image.Image: + """Read the FBO contents and return a PIL RGB Image. + + Returns + ------- + PIL.Image.Image + Captured frame, vertically flipped to convert from OpenGL's + bottom-left origin to image top-left convention. + """ + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.fbo) + gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) + buf = gl.glReadPixels( + 0, 0, self.width, self.height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE + ) + img = Image.frombytes("RGB", (self.width, self.height), buf) + return img.transpose(Image.FLIP_TOP_BOTTOM) + + def destroy(self): + """Release the FBO, renderbuffers, EGL context and surface. + + Safe to call multiple times; subsequent calls are no-ops. + """ + EGL = self._EGL + + # Clean up GL objects first (context must still be current) + if self.fbo is not None: + gl.glDeleteFramebuffers(1, [self.fbo]) + self.fbo = None + if self._rbo_color is not None: + gl.glDeleteRenderbuffers(1, [self._rbo_color]) + self._rbo_color = None + if self._rbo_depth is not None: + gl.glDeleteRenderbuffers(1, [self._rbo_depth]) + self._rbo_depth = None + + if self._display is not None: + EGL.eglMakeCurrent( + self._display, + EGL.EGL_NO_SURFACE, + EGL.EGL_NO_SURFACE, + EGL.EGL_NO_CONTEXT, + ) + if self._context is not None: + EGL.eglDestroyContext(self._display, self._context) + self._context = None + if self._surface is not None: + EGL.eglDestroySurface(self._display, self._surface) + self._surface = None + EGL.eglTerminate(self._display) + self._display = None + + logger.debug("EGL context destroyed.") + + # Allow use as a context manager + def __enter__(self): + self.make_current() + return self + + def __exit__(self, *_): + self.destroy() diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index 0029079..f83a624 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -12,10 +12,14 @@ from .camera import make_model, make_projection, make_view from .shaders import get_default_shaders +from .egl_context import EGLContext # Module logger logger = logging.getLogger(__name__) +# Module-level EGL context handle (None when GLFW is used instead) +_egl_context: "EGLContext | None" = None + def create_vao(): """Create and bind a Vertex Array Object (VAO). @@ -167,7 +171,8 @@ def init_window(width, height, title="PyOpenGL", visible=True): title : str, optional, default 'PyOpenGL' Window title. visible : bool, optional, default True - If False create an invisible/offscreen window (useful for headless rendering). + If False create an invisible/offscreen window (useful for headless + rendering when a display is available but no screen is needed). Returns ------- @@ -194,33 +199,97 @@ def init_window(width, height, title="PyOpenGL", visible=True): def create_window_with_fallback(width, height, title="WhipperSnapPy", visible=True): - """Create a GLFW window, preferring a visible window and falling back to an invisible one. + """Create an OpenGL context, trying GLFW first and EGL as a fallback. + + The function attempts context creation in this priority order: + + 1. **GLFW visible window** — normal path on workstations. + 2. **GLFW invisible window** — when a display exists but no screen + is needed (e.g. a remote desktop session). + 3. **EGL pbuffer** — fully headless; no display server required. + Works with NVIDIA/AMD GPU drivers and Mesa (llvmpipe) on CPU-only + systems. Requires ``libegl1`` (already installed in the Docker + image) and ``pyopengl >= 3.1``. + + When EGL is used the module-level ``_egl_context`` is set and + ``make_current()`` is called so that subsequent OpenGL calls work + identically to the GLFW path. Parameters ---------- width : int - Requested window width in logical pixels. + Render target width in pixels. height : int - Requested window height in logical pixels. + Render target height in pixels. title : str, optional - Window title. Default is ``'WhipperSnapPy'``. + Window title (used for GLFW paths only). Default is ``'WhipperSnapPy'``. visible : bool, optional - Prefer a visible window when True (default). If creation fails the - function will retry with an invisible/offscreen window. + Prefer a visible window. Default is ``True``. Returns ------- GLFWwindow or None - The created GLFW window handle, or ``None`` if creation failed. + GLFW window handle when GLFW succeeded, ``None`` when EGL is used + (the context is already current via ``_egl_context.make_current()``). + + Raises + ------ + RuntimeError + If all three methods fail to produce a usable OpenGL context. """ + global _egl_context + + # --- Step 1: GLFW visible window --- window = init_window(width, height, title, visible=visible) - if not window and visible: - logger.warning("Could not create visible GLFW window; retrying with invisible window (offscreen).") + if window: + return window + + # --- Step 2: GLFW invisible window --- + if visible: + logger.warning( + "Could not create visible GLFW window; retrying with invisible window." + ) window = init_window(width, height, title, visible=False) - if not window: - logger.error("Could not create any GLFW window/context. OpenGL context unavailable.") - raise RuntimeError("Could not create any GLFW window/context. OpenGL context unavailable.") - return window + if window: + return window + + # --- Step 3: EGL headless pbuffer --- + logger.warning( + "GLFW context creation failed entirely (no display?). " + "Attempting EGL headless context." + ) + try: + ctx = EGLContext(width, height) + ctx.make_current() + _egl_context = ctx + logger.info("Using EGL headless context — no display server required.") + return None # callers treat None as "EGL is active" + except (ImportError, RuntimeError) as exc: + raise RuntimeError( + "Could not create any OpenGL context (tried GLFW visible, " + f"GLFW invisible, EGL pbuffer). Last error: {exc}" + ) from exc + + +def terminate_context(window): + """Release the active OpenGL context regardless of how it was created. + + This is a drop-in replacement for ``glfw.terminate()`` that also + handles the EGL path. Call it at the end of every rendering function + instead of calling ``glfw.terminate()`` directly. + + Parameters + ---------- + window : GLFWwindow or None + The GLFW window handle returned by ``create_window_with_fallback``, + or ``None`` when EGL is active. + """ + global _egl_context + if _egl_context is not None: + _egl_context.destroy() + _egl_context = None + else: + glfw.terminate() def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0): @@ -266,50 +335,33 @@ def setup_shader(meshdata, triangles, width, height, specular=True, ambient=0.0) return shader def capture_window(window): - """Read the current GL framebuffer and return it as a PIL.Image (RGB). + """Read the current GL framebuffer and return it as a PIL Image (RGB). - This function captures the framebuffer for the provided GLFW `window` - and returns an RGB :class:`PIL.Image.Image`. On HiDPI displays (e.g. - macOS Retina) the framebuffer may be larger than the logical window - size; the function will downscale the captured physical framebuffer to - logical pixel dimensions when a non-1.0 monitor content scale is - detected. + Works for both GLFW windows and EGL headless contexts. When EGL is + active (``window`` is ``None``) the pixels are read from the FBO that + was set up by :class:`~whippersnappy.gl.egl_context.EGLContext`; in + that case there is no HiDPI scaling to account for. Parameters ---------- - window : GLFWwindow - GLFW window handle whose current OpenGL context/framebuffer will be - read. The function calls :func:`glfw.get_framebuffer_size` to obtain - the read dimensions and :func:`glfw.get_primary_monitor` / - :func:`glfw.get_monitor_content_scale` to detect the display scale. + window : GLFWwindow or None + GLFW window handle, or ``None`` when an EGL context is active. Returns ------- PIL.Image.Image - RGB image containing the captured framebuffer content. On standard - (1x) displays the returned image has the same dimensions as the - framebuffer. On HiDPI displays the image is downscaled to logical - window dimensions (framebuffer size divided by the monitor content - scale) using ``Image.Resampling.LANCZOS``. - - Notes - ----- - - The function uses ``glReadPixels`` with ``GL_PACK_ALIGNMENT=1`` and - converts the raw bytes into a PIL image, performing a vertical flip - to convert OpenGL's bottom-left origin to the image top-left origin. - - Prefer :func:`glfw.get_window_content_scale` or - :func:`glfw.get_monitor_content_scale` to detect per-window/monitor - scaling. The function currently uses the primary monitor's content - scale as a heuristic for HiDPI detection. - - If strict static analyzers complain about ``Image.FLIP_TOP_BOTTOM`` - you can switch to ``Image.Transpose.FLIP_TOP_BOTTOM`` for newer - Pillow versions. + RGB image of the rendered frame, with the vertical flip applied so + that the origin is at the top-left (image convention). """ - # Get primary monitor + global _egl_context + + # --- EGL path: read directly from the FBO --- + if _egl_context is not None: + return _egl_context.read_pixels() + + # --- GLFW path: read from the default framebuffer --- monitor = glfw.get_primary_monitor() - # Get scale factors x_scale, y_scale = glfw.get_monitor_content_scale(monitor) - # Get framebuffer size width, height = glfw.get_framebuffer_size(window) logger.debug("Framebuffer size = (%s,%s)", width, height) @@ -325,8 +377,10 @@ def capture_window(window): rheight = int(round(height / y_scale)) logger.debug("Rescale to = (%s,%s)", rwidth, rheight) image.thumbnail((rwidth, rheight), Image.Resampling.LANCZOS) + return image + def render_scene(shader, triangles, transform): """Render a single draw call using the supplied shader/indices. diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index 7a7fb38..572fabd 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -9,7 +9,7 @@ from .geometry import get_surf_name from .geometry.prepare import prepare_and_validate_geometry -from .gl.utils import capture_window, create_window_with_fallback, render_scene, setup_shader +from .gl.utils import capture_window, create_window_with_fallback, render_scene, setup_shader, terminate_context from .gl.views import get_view_matrices from .utils.image import create_colorbar, draw_caption, draw_colorbar, load_roboto_font, text_size from .utils.types import ColorSelection, OrientationType, ViewType @@ -234,7 +234,7 @@ def snap1( if outpath: logger.info("Saving snapshot to %s", outpath) image.save(outpath) - glfw.terminate() + terminate_context(window) return image @@ -457,6 +457,6 @@ def snap4( logger.info("Saving snapshot to %s", outpath) image.save(outpath) - glfw.terminate() + terminate_context(window) return image From a2bf6cd48c056b347720ee227e7e2415679ba7f7 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 18:00:20 +0100 Subject: [PATCH 24/83] egl updates --- whippersnappy/gl/egl_context.py | 211 ++++++++++++++++++++++---------- 1 file changed, 146 insertions(+), 65 deletions(-) diff --git a/whippersnappy/gl/egl_context.py b/whippersnappy/gl/egl_context.py index 594bf69..d8513d9 100644 --- a/whippersnappy/gl/egl_context.py +++ b/whippersnappy/gl/egl_context.py @@ -105,70 +105,167 @@ def __init__(self, width: int, height: int): # Private helpers # ------------------------------------------------------------------ + def _get_ext_fn(self, name, restype, argtypes): + """Load an EGL extension function via eglGetProcAddress.""" + addr = self._libegl.eglGetProcAddress(name.encode()) + if not addr: + raise RuntimeError( + f"eglGetProcAddress('{name}') returned NULL — " + f"extension not available on this driver." + ) + FuncType = ctypes.CFUNCTYPE(restype, *argtypes) + return FuncType(addr) + def _init_egl(self): - EGL = self._EGL + import ctypes.util + + egl_name = ctypes.util.find_library("EGL") or "libEGL.so.1" + try: + libegl = ctypes.CDLL(egl_name) + except OSError as e: + raise RuntimeError( + f"Could not load {egl_name}. " + "Install libegl1-mesa and retry." + ) from e + self._libegl = libegl # keep reference alive + + # Set signatures for direct (non-extension) EGL symbols + libegl.eglGetProcAddress.restype = ctypes.c_void_p + libegl.eglGetProcAddress.argtypes = [ctypes.c_char_p] + libegl.eglQueryString.restype = ctypes.c_char_p + libegl.eglQueryString.argtypes = [ctypes.c_void_p, ctypes.c_int] + libegl.eglInitialize.restype = ctypes.c_bool + libegl.eglInitialize.argtypes = [ctypes.c_void_p, + ctypes.POINTER(ctypes.c_int), + ctypes.POINTER(ctypes.c_int)] + libegl.eglBindAPI.restype = ctypes.c_bool + libegl.eglBindAPI.argtypes = [ctypes.c_uint] + libegl.eglChooseConfig.restype = ctypes.c_bool + libegl.eglChooseConfig.argtypes = [ctypes.c_void_p, + ctypes.POINTER(ctypes.c_int), + ctypes.c_void_p, ctypes.c_int, + ctypes.POINTER(ctypes.c_int)] + libegl.eglCreatePbufferSurface.restype = ctypes.c_void_p + libegl.eglCreatePbufferSurface.argtypes = [ctypes.c_void_p, ctypes.c_void_p, + ctypes.POINTER(ctypes.c_int)] + libegl.eglCreateContext.restype = ctypes.c_void_p + libegl.eglCreateContext.argtypes = [ctypes.c_void_p, ctypes.c_void_p, + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_int)] + libegl.eglMakeCurrent.restype = ctypes.c_bool + libegl.eglMakeCurrent.argtypes = [ctypes.c_void_p, ctypes.c_void_p, + ctypes.c_void_p, ctypes.c_void_p] + libegl.eglDestroyContext.restype = ctypes.c_bool + libegl.eglDestroyContext.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + libegl.eglDestroySurface.restype = ctypes.c_bool + libegl.eglDestroySurface.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + libegl.eglTerminate.restype = ctypes.c_bool + libegl.eglTerminate.argtypes = [ctypes.c_void_p] + + # Check extensions and load ext functions via eglGetProcAddress + _EGL_EXTENSIONS = 0x3055 + client_exts = libegl.eglQueryString(None, _EGL_EXTENSIONS) or b"" + logger.debug("EGL client extensions: %s", client_exts.decode()) + + has_device_enum = b"EGL_EXT_device_enumeration" in client_exts + has_platform_base = b"EGL_EXT_platform_base" in client_exts + + display = None + if has_device_enum and has_platform_base: + eglQueryDevicesEXT = self._get_ext_fn( + "eglQueryDevicesEXT", + ctypes.c_bool, + [ctypes.c_int, ctypes.c_void_p, ctypes.POINTER(ctypes.c_int)], + ) + eglGetPlatformDisplayEXT = self._get_ext_fn( + "eglGetPlatformDisplayEXT", + ctypes.c_void_p, + [ctypes.c_int, ctypes.c_void_p, ctypes.POINTER(ctypes.c_int)], + ) + display = self._open_device_display( + eglQueryDevicesEXT, eglGetPlatformDisplayEXT + ) - # 1. Get the default EGL display (works for both GPU and Mesa) - self._display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY) - if self._display == EGL.EGL_NO_DISPLAY: - raise RuntimeError("eglGetDisplay returned EGL_NO_DISPLAY.") + if display is None: + logger.debug("Falling back to eglGetDisplay(EGL_DEFAULT_DISPLAY)") + libegl.eglGetDisplay.restype = ctypes.c_void_p + libegl.eglGetDisplay.argtypes = [ctypes.c_void_p] + display = libegl.eglGetDisplay(ctypes.c_void_p(0)) + + if not display: + raise RuntimeError( + "Could not obtain any EGL display. " + "Install libegl1-mesa for CPU rendering." + ) + self._display = display - major = ctypes.c_int(0) - minor = ctypes.c_int(0) - if not EGL.eglInitialize(self._display, major, minor): + major, minor = ctypes.c_int(0), ctypes.c_int(0) + if not libegl.eglInitialize( + self._display, ctypes.byref(major), ctypes.byref(minor) + ): raise RuntimeError("eglInitialize failed.") - logger.debug("EGL version %d.%d", major.value, minor.value) + logger.debug("EGL %d.%d", major.value, minor.value) - # 2. Bind the OpenGL API (not OpenGL ES) - if not EGL.eglBindAPI(_EGL_OPENGL_API): + if not libegl.eglBindAPI(_EGL_OPENGL_API): raise RuntimeError("eglBindAPI(OpenGL) failed.") - # 3. Choose a framebuffer config cfg_attribs = (ctypes.c_int * 7)( - _EGL_SURFACE_TYPE, _EGL_PBUFFER_BIT, + _EGL_SURFACE_TYPE, _EGL_PBUFFER_BIT, _EGL_RENDERABLE_TYPE, _EGL_OPENGL_BIT, _EGL_NONE, ) - configs = (EGL.EGLConfig * 1)() - num_cfg = ctypes.c_int(0) - if not EGL.eglChooseConfig( - self._display, cfg_attribs, configs, 1, ctypes.byref(num_cfg) - ) or num_cfg.value == 0: - raise RuntimeError( - "eglChooseConfig found no suitable configs. " - "Ensure a Mesa or GPU EGL driver is installed (libegl1-mesa or libegl1)." - ) + configs = (ctypes.c_void_p * 1)() + num_cfgs = ctypes.c_int(0) + if not libegl.eglChooseConfig( + self._display, cfg_attribs, configs, 1, ctypes.byref(num_cfgs) + ) or num_cfgs.value == 0: + raise RuntimeError("eglChooseConfig: no suitable config.") + self._config = configs[0] - # 4. Create a minimal pbuffer surface (1×1 is sufficient — rendering - # goes into the FBO, not this surface) pbuf_attribs = (ctypes.c_int * 5)( - _EGL_WIDTH, 1, - _EGL_HEIGHT, 1, - _EGL_NONE, + _EGL_WIDTH, 1, _EGL_HEIGHT, 1, _EGL_NONE ) - self._surface = EGL.eglCreatePbufferSurface( - self._display, configs[0], pbuf_attribs + self._surface = libegl.eglCreatePbufferSurface( + self._display, self._config, pbuf_attribs ) - if self._surface == EGL.EGL_NO_SURFACE: + if not self._surface: raise RuntimeError("eglCreatePbufferSurface failed.") - # 5. Create an OpenGL 3.3 Core context ctx_attribs = (ctypes.c_int * 5)( _EGL_CONTEXT_MAJOR_VERSION, 3, _EGL_CONTEXT_MINOR_VERSION, 3, _EGL_NONE, ) - self._context = EGL.eglCreateContext( - self._display, configs[0], EGL.EGL_NO_CONTEXT, ctx_attribs + self._context = libegl.eglCreateContext( + self._display, self._config, None, ctx_attribs ) - if self._context == EGL.EGL_NO_CONTEXT: + if not self._context: raise RuntimeError( - "eglCreateContext failed. " - "The EGL driver may not support OpenGL 3.3 Core. " - "Check with: glxinfo | grep 'OpenGL version'" + "eglCreateContext for OpenGL 3.3 Core failed. " + "Try: MESA_GL_VERSION_OVERRIDE=3.3 MESA_GLSL_VERSION_OVERRIDE=330" + ) + logger.info("EGL context created (%dx%d)", self.width, self.height) + + + def _open_device_display(self, eglQueryDevicesEXT, eglGetPlatformDisplayEXT): + """Enumerate EGL devices and return first usable display pointer.""" + n = ctypes.c_int(0) + if not eglQueryDevicesEXT(0, None, ctypes.byref(n)) or n.value == 0: + logger.warning("eglQueryDevicesEXT: no devices.") + return None + logger.debug("EGL: %d device(s) found", n.value) + devices = (ctypes.c_void_p * n.value)() + eglQueryDevicesEXT(n.value, devices, ctypes.byref(n)) + no_attribs = (ctypes.c_int * 1)(_EGL_NONE) + for i, dev in enumerate(devices): + dpy = eglGetPlatformDisplayEXT( + _EGL_PLATFORM_DEVICE_EXT, ctypes.c_void_p(dev), no_attribs ) + if dpy: + logger.debug("EGL: using device %d", i) + return dpy + return None - logger.info("EGL headless context created (%dx%d)", self.width, self.height) def make_current(self): """Make this EGL context current and set up the FBO render target. @@ -243,39 +340,23 @@ def read_pixels(self) -> Image.Image: return img.transpose(Image.FLIP_TOP_BOTTOM) def destroy(self): - """Release the FBO, renderbuffers, EGL context and surface. - - Safe to call multiple times; subsequent calls are no-ops. - """ - EGL = self._EGL - - # Clean up GL objects first (context must still be current) + libegl = self._libegl + # GL cleanup first (context must be current) if self.fbo is not None: - gl.glDeleteFramebuffers(1, [self.fbo]) + gl.glDeleteFramebuffers(1, [self.fbo]); self.fbo = None if self._rbo_color is not None: - gl.glDeleteRenderbuffers(1, [self._rbo_color]) + gl.glDeleteRenderbuffers(1, [self._rbo_color]); self._rbo_color = None if self._rbo_depth is not None: - gl.glDeleteRenderbuffers(1, [self._rbo_depth]) + gl.glDeleteRenderbuffers(1, [self._rbo_depth]); self._rbo_depth = None - - if self._display is not None: - EGL.eglMakeCurrent( - self._display, - EGL.EGL_NO_SURFACE, - EGL.EGL_NO_SURFACE, - EGL.EGL_NO_CONTEXT, - ) - if self._context is not None: - EGL.eglDestroyContext(self._display, self._context) - self._context = None - if self._surface is not None: - EGL.eglDestroySurface(self._display, self._surface) - self._surface = None - EGL.eglTerminate(self._display) - self._display = None - + if self._display: + libegl.eglMakeCurrent(self._display, None, None, None) + if self._context: libegl.eglDestroyContext(self._display, self._context) + if self._surface: libegl.eglDestroySurface(self._display, self._surface) + libegl.eglTerminate(self._display) + self._display = self._context = self._surface = None logger.debug("EGL context destroyed.") # Allow use as a context manager From e9609f5ca86fb42462b2873d8b2a9745bd8ce23e Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 18:04:14 +0100 Subject: [PATCH 25/83] egl updates --- whippersnappy/gl/egl_context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/whippersnappy/gl/egl_context.py b/whippersnappy/gl/egl_context.py index d8513d9..ad99ea4 100644 --- a/whippersnappy/gl/egl_context.py +++ b/whippersnappy/gl/egl_context.py @@ -43,6 +43,7 @@ _EGL_OPENGL_API = 0x30A2 _EGL_CONTEXT_MAJOR_VERSION = 0x3098 _EGL_CONTEXT_MINOR_VERSION = 0x30FB +_EGL_PLATFORM_DEVICE_EXT = 0x313F def _check_egl_available(): From a5bdc62e0f2003a2906978f089dd20f3c119b1bb Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 18:08:11 +0100 Subject: [PATCH 26/83] egl updates --- whippersnappy/gl/egl_context.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/whippersnappy/gl/egl_context.py b/whippersnappy/gl/egl_context.py index ad99ea4..0ec67e5 100644 --- a/whippersnappy/gl/egl_context.py +++ b/whippersnappy/gl/egl_context.py @@ -46,17 +46,6 @@ _EGL_PLATFORM_DEVICE_EXT = 0x313F -def _check_egl_available(): - """Raise ImportError with a helpful message if EGL bindings are absent.""" - try: - from OpenGL import EGL as _EGL # noqa: F401 - except (ImportError, AttributeError) as exc: - raise ImportError( - "OpenGL.EGL is not available. Make sure pyopengl >= 3.1 is " - "installed and libegl1 (or equivalent) is present on the system." - ) from exc - - class EGLContext: """A headless OpenGL 3.3 Core context backed by an EGL pbuffer + FBO. @@ -87,19 +76,16 @@ class EGLContext: """ def __init__(self, width: int, height: int): - _check_egl_available() - from OpenGL import EGL - - self._EGL = EGL self.width = width self.height = height + self._libegl = None self._display = None self._surface = None self._context = None + self._config = None self.fbo = None self._rbo_color = None self._rbo_depth = None - self._init_egl() # ------------------------------------------------------------------ @@ -274,9 +260,8 @@ def make_current(self): Must be called before any OpenGL commands. Creates and binds an FBO backed by two renderbuffers (RGBA color + depth/stencil). """ - EGL = self._EGL - if not EGL.eglMakeCurrent( - self._display, self._surface, self._surface, self._context + if not self._libegl.eglMakeCurrent( + self._display, self._surface, self._surface, self._context ): raise RuntimeError("eglMakeCurrent failed.") From 272e936c369382b03d95e6ba5780478bc7bd27ec Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 18:25:29 +0100 Subject: [PATCH 27/83] egl updates --- whippersnappy/snap.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index 572fabd..d62551a 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -125,17 +125,21 @@ def snap1( ref_height = 500 ui_scale = min(width / ref_width, height / ref_height) - if not glfw.init(): - logger.error("Could not init glfw!") - raise RuntimeError("Could not initialize GLFW; OpenGL context unavailable") - primary_monitor = glfw.get_primary_monitor() - mode = glfw.get_video_mode(primary_monitor) - screen_width = mode.size.width - screen_height = mode.size.height - if width > screen_width: - logger.info("Requested width %d exceeds screen width %d, expect black bars", width, screen_width) - elif height > screen_height: - logger.info("Requested height %d exceeds screen height %d, expect black bars", height, screen_height) + # Screen size check only makes sense with a real display; skip on headless. + # create_window_with_fallback will handle context creation + EGL fallback. + try: + if glfw.init(): + primary_monitor = glfw.get_primary_monitor() + if primary_monitor: + mode = glfw.get_video_mode(primary_monitor) + if width > mode.size.width: + logger.info("Requested width %d exceeds screen width %d, expect black bars", + width, mode.size.width) + elif height > mode.size.height: + logger.info("Requested height %d exceeds screen height %d, expect black bars", + height, mode.size.height) + except Exception: + pass # headless — no monitor info available, that's fine image = Image.new("RGB", (width, height)) @@ -382,7 +386,7 @@ def snap4( ) except Exception as e: logger.error("prepare_geometry failed for %s: %s", meshpath, e) - glfw.terminate() + terminate_context(window) return None # Diagnostics about mesh data @@ -401,7 +405,7 @@ def snap4( logger.debug("Shader setup complete") except Exception as e: logger.error("setup_shader failed: %s", e) - glfw.terminate() + terminate_context(window) return None render_scene(shader, triangles, transl * view_left) From d1da6d6f899f8167269041c2c89c745f392c8b33 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 18:27:27 +0100 Subject: [PATCH 28/83] egl updates --- whippersnappy/snap.py | 1 - 1 file changed, 1 deletion(-) diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index d62551a..cb5d9b7 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -148,7 +148,6 @@ def snap1( brain_display_width = min(bwidth, width) brain_display_height = min(bheight, height) logger.debug("Requested (width,height) = (%s,%s)", width, height) - logger.debug("Screen (width,height) = (%s,%s)", screen_width, screen_height) logger.debug("Brain (width,height) = (%s,%s)", bwidth, bheight) logger.debug("B-Display (width,height) = (%s,%s)", brain_display_width, brain_display_height) From de4553f582264cb83d3d87d4ea081bc764819f66 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 18:34:44 +0100 Subject: [PATCH 29/83] egl updates --- whippersnappy/gl/egl_context.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/whippersnappy/gl/egl_context.py b/whippersnappy/gl/egl_context.py index 0ec67e5..01447a3 100644 --- a/whippersnappy/gl/egl_context.py +++ b/whippersnappy/gl/egl_context.py @@ -265,6 +265,11 @@ def make_current(self): ): raise RuntimeError("eglMakeCurrent failed.") + # Force PyOpenGL to discover and cache the context we just made current. + # PyOpenGL's contextdata module only recognizes contexts it has "seen" + # via at least one GL call; glGetError() is the cheapest trigger. + gl.glGetError() + # Build FBO so rendering is directed off-screen self.fbo = gl.glGenFramebuffers(1) gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.fbo) From 1f259e9deb7bf7ad28f47a260d474b052d1efea2 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 18:45:26 +0100 Subject: [PATCH 30/83] egl updates --- whippersnappy/gl/__init__.py | 2 ++ whippersnappy/gl/_platform.py | 14 ++++++++++++++ whippersnappy/gl/egl_context.py | 7 +++++++ whippersnappy/gl/utils.py | 31 ++++++++++++++++++++++++++++++- 4 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 whippersnappy/gl/_platform.py diff --git a/whippersnappy/gl/__init__.py b/whippersnappy/gl/__init__.py index fd5cb38..0d0bfda 100644 --- a/whippersnappy/gl/__init__.py +++ b/whippersnappy/gl/__init__.py @@ -7,6 +7,8 @@ """ +from . import _platform # noqa: F401 — MUST be first; sets PYOPENGL_PLATFORM + from .camera import make_model, make_projection, make_view from .shaders import get_default_shaders, get_webgl_shaders from .utils import ( diff --git a/whippersnappy/gl/_platform.py b/whippersnappy/gl/_platform.py new file mode 100644 index 0000000..e68db02 --- /dev/null +++ b/whippersnappy/gl/_platform.py @@ -0,0 +1,14 @@ +"""Bootstrap PyOpenGL platform selection — must be imported first. + +Imported unconditionally at the top of gl/__init__.py before any other +OpenGL symbol. Sets PYOPENGL_PLATFORM=egl when no display is available +so that PyOpenGL uses the EGL backend instead of GLX. + +If the user has already set PYOPENGL_PLATFORM, that value is respected. +""" +import os + +if "PYOPENGL_PLATFORM" not in os.environ: + # No display = definitely headless, force EGL now before any import + if not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"): + os.environ["PYOPENGL_PLATFORM"] = "egl" diff --git a/whippersnappy/gl/egl_context.py b/whippersnappy/gl/egl_context.py index 01447a3..c93600e 100644 --- a/whippersnappy/gl/egl_context.py +++ b/whippersnappy/gl/egl_context.py @@ -24,6 +24,13 @@ import ctypes import logging +import os + +# Must be set before OpenGL.GL is imported anywhere in the process. +# If already set (e.g. user set it, or GLFW succeeded), respect it. +# We set it here because this module is only imported when EGL is needed. +if os.environ.get("PYOPENGL_PLATFORM") != "egl": + os.environ["PYOPENGL_PLATFORM"] = "egl" import OpenGL.GL as gl from PIL import Image diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index f83a624..d9dacbd 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -258,12 +258,20 @@ def create_window_with_fallback(width, height, title="WhipperSnapPy", visible=Tr "GLFW context creation failed entirely (no display?). " "Attempting EGL headless context." ) + # GLFW proved the display is unusable. Force EGL platform now so that + # PyOpenGL uses the EGL backend. This must happen before egl_context + # imports OpenGL.GL — but since utils.py already imported it, we need + # to also reset PyOpenGL's platform object. + os.environ["PYOPENGL_PLATFORM"] = "egl" + _force_pyopengl_egl_platform() + try: + from .egl_context import EGLContext ctx = EGLContext(width, height) ctx.make_current() _egl_context = ctx logger.info("Using EGL headless context — no display server required.") - return None # callers treat None as "EGL is active" + return None except (ImportError, RuntimeError) as exc: raise RuntimeError( "Could not create any OpenGL context (tried GLFW visible, " @@ -413,3 +421,24 @@ def render_scene(shader, triangles, transform): if err != gl.GL_NO_ERROR: logger.error("OpenGL error after draw: %s", err) raise RuntimeError(f"OpenGL error after draw: {err}") + +def _force_pyopengl_egl_platform(): + """Switch PyOpenGL's active platform to EGL at runtime. + + PyOpenGL caches its platform object at first import. When GLFW fails + and we fall back to EGL, we need to replace the cached platform so + that subsequent GL calls use the EGL context rather than expecting a + GLX context. This is only called once, when the fallback triggers. + """ + try: + import OpenGL.platform + import OpenGL.platform.egl as egl_platform_module + new_platform = egl_platform_module.EGLPlatform() + OpenGL.platform.PLATFORM = new_platform + # Also patch the GL function resolver to use the new platform + import OpenGL.GL + OpenGL.GL.glGetError.__self__.__class__.__bases__ # touch the class + logger.debug("Switched PyOpenGL platform to EGL.") + except Exception as e: + logger.warning("Could not switch PyOpenGL platform to EGL: %s", e) + # Not fatal — gl.glGetError() in make_current may still work From 4e30dc8a35e6a292a22d522fbf03e991890f17df Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 18:46:25 +0100 Subject: [PATCH 31/83] egl updates --- whippersnappy/gl/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index d9dacbd..bc6bf91 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -4,6 +4,7 @@ """ import logging +import os import glfw import OpenGL.GL as gl From bb99b05e2e9e21a1c76e34985a058a91c51e6df2 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 18:53:04 +0100 Subject: [PATCH 32/83] egl updates --- whippersnappy/__init__.py | 26 ++++++++++++++++++++++++++ whippersnappy/gl/utils.py | 27 --------------------------- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 6e6e002..88bfb61 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -43,6 +43,32 @@ """ +import os + +def _check_display(): + """Return True if a working X11 display connection can be opened.""" + display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY") + if not display: + return False + try: + import ctypes, ctypes.util + libx11 = ctypes.CDLL(ctypes.util.find_library("X11") or "libX11.so.6") + libx11.XOpenDisplay.restype = ctypes.c_void_p + libx11.XOpenDisplay.argtypes = [ctypes.c_char_p] + libx11.XCloseDisplay.restype = None + libx11.XCloseDisplay.argtypes = [ctypes.c_void_p] + dpy = libx11.XOpenDisplay(display.encode()) + if dpy: + libx11.XCloseDisplay(dpy) + return True + return False + except Exception: + return False + +if "PYOPENGL_PLATFORM" not in os.environ: + if not _check_display(): + os.environ["PYOPENGL_PLATFORM"] = "egl" + from ._config import sys_info # noqa: F401 from ._version import __version__ # noqa: F401 from .snap import snap1, snap4 diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index bc6bf91..25ba866 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -259,13 +259,6 @@ def create_window_with_fallback(width, height, title="WhipperSnapPy", visible=Tr "GLFW context creation failed entirely (no display?). " "Attempting EGL headless context." ) - # GLFW proved the display is unusable. Force EGL platform now so that - # PyOpenGL uses the EGL backend. This must happen before egl_context - # imports OpenGL.GL — but since utils.py already imported it, we need - # to also reset PyOpenGL's platform object. - os.environ["PYOPENGL_PLATFORM"] = "egl" - _force_pyopengl_egl_platform() - try: from .egl_context import EGLContext ctx = EGLContext(width, height) @@ -423,23 +416,3 @@ def render_scene(shader, triangles, transform): logger.error("OpenGL error after draw: %s", err) raise RuntimeError(f"OpenGL error after draw: {err}") -def _force_pyopengl_egl_platform(): - """Switch PyOpenGL's active platform to EGL at runtime. - - PyOpenGL caches its platform object at first import. When GLFW fails - and we fall back to EGL, we need to replace the cached platform so - that subsequent GL calls use the EGL context rather than expecting a - GLX context. This is only called once, when the fallback triggers. - """ - try: - import OpenGL.platform - import OpenGL.platform.egl as egl_platform_module - new_platform = egl_platform_module.EGLPlatform() - OpenGL.platform.PLATFORM = new_platform - # Also patch the GL function resolver to use the new platform - import OpenGL.GL - OpenGL.GL.glGetError.__self__.__class__.__bases__ # touch the class - logger.debug("Switched PyOpenGL platform to EGL.") - except Exception as e: - logger.warning("Could not switch PyOpenGL platform to EGL: %s", e) - # Not fatal — gl.glGetError() in make_current may still work From 44563f3f8191b8810a151e8dce91348636faae7e Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 19:10:35 +0100 Subject: [PATCH 33/83] egl updates --- whippersnappy/gl/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index 25ba866..4022e43 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -13,7 +13,6 @@ from .camera import make_model, make_projection, make_view from .shaders import get_default_shaders -from .egl_context import EGLContext # Module logger logger = logging.getLogger(__name__) From b6cb745bbaf3cd44a2a4e8d71695e6a2c441afca Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 19:17:04 +0100 Subject: [PATCH 34/83] egl updates --- whippersnappy/__init__.py | 3 +++ whippersnappy/gl/egl_context.py | 3 +++ whippersnappy/gl/utils.py | 5 +++++ 3 files changed, 11 insertions(+) diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 88bfb61..fd893c9 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -47,6 +47,9 @@ def _check_display(): """Return True if a working X11 display connection can be opened.""" + # macOS uses CGL/Cocoa — GLFW handles context creation natively, no EGL needed + if sys.platform == "darwin": + return True display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY") if not display: return False diff --git a/whippersnappy/gl/egl_context.py b/whippersnappy/gl/egl_context.py index c93600e..f0e4fbf 100644 --- a/whippersnappy/gl/egl_context.py +++ b/whippersnappy/gl/egl_context.py @@ -25,6 +25,9 @@ import ctypes import logging import os +import sys +if sys.platform == "darwin": + raise ImportError("EGL is not available on macOS; use GLFW/CGL instead.") # Must be set before OpenGL.GL is imported anywhere in the process. # If already set (e.g. user set it, or GLFW succeeded), respect it. diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index 4022e43..41f91ef 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -258,6 +258,11 @@ def create_window_with_fallback(width, height, title="WhipperSnapPy", visible=Tr "GLFW context creation failed entirely (no display?). " "Attempting EGL headless context." ) + if sys.platform == "darwin": + raise RuntimeError( + "Could not create any OpenGL context via GLFW on macOS. " + "Ensure you are running with a display available." + ) try: from .egl_context import EGLContext ctx = EGLContext(width, height) From 3435bf01551cd1ed840c531f597a787415428c0a Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 19:18:48 +0100 Subject: [PATCH 35/83] egl updates --- whippersnappy/__init__.py | 1 + whippersnappy/gl/utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index fd893c9..0e86672 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -44,6 +44,7 @@ """ import os +import sys def _check_display(): """Return True if a working X11 display connection can be opened.""" diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index 41f91ef..8112b9a 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -4,7 +4,7 @@ """ import logging -import os +import sys import glfw import OpenGL.GL as gl From b0026a6b3c902437d2b96a5c37b327ff0664a4a9 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 19:21:01 +0100 Subject: [PATCH 36/83] egl updates --- whippersnappy/gl/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/whippersnappy/gl/__init__.py b/whippersnappy/gl/__init__.py index 0d0bfda..5876f0a 100644 --- a/whippersnappy/gl/__init__.py +++ b/whippersnappy/gl/__init__.py @@ -27,7 +27,6 @@ terminate_context, ) from .views import get_view_matrices, get_view_matrix -from .egl_context import EGLContext __all__ = [ 'create_vao', 'compile_shader_program', 'setup_buffers', 'setup_vertex_attributes', @@ -35,5 +34,5 @@ 'init_window', 'render_scene', 'setup_shader', 'capture_window', 'make_model', 'make_projection', 'make_view', 'get_default_shaders', 'get_view_matrices', 'get_view_matrix', - 'get_webgl_shaders', 'terminate_context', 'EGLContext', + 'get_webgl_shaders', 'terminate_context', ] From bc9bafe9e30b2c33f6d83b084757d2de5bb8bae2 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 19:30:46 +0100 Subject: [PATCH 37/83] egl updates --- Dockerfile | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index ece2e19..5103578 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,11 +1,13 @@ -FROM ubuntu:24.04 +FROM python:3.11-slim RUN apt-get update && apt-get install -y --no-install-recommends \ - python3 pip \ libegl1 \ - libglib2.0-0 libfontconfig1 libdbus-1-3 && \ - apt clean && \ - rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* + libgl1 \ + libglib2.0-0 \ + libfontconfig1 \ + libdbus-1-3 && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* RUN pip install --upgrade pip RUN pip install pyopengl glfw pillow numpy pyrr @@ -15,3 +17,4 @@ RUN pip install /WhipperSnapPy ENTRYPOINT ["whippersnap"] CMD ["--help"] + From 8aa55f71d1676f585d68b4bbce87a5a7afa313f5 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 18 Feb 2026 19:43:59 +0100 Subject: [PATCH 38/83] egl updates --- pyproject.toml | 2 +- whippersnappy/cli/whippersnap.py | 9 ++++++++- whippersnappy/gui/config_app.py | 4 ---- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5d3e392..70e925a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ 'numpy>=1.21', 'pyrr', 'pillow', - 'pyopengl==3.1.6', + 'pyopengl>=3.1.8', 'nibabel', 'psutil' ] diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index e4581a4..161b9ac 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -39,7 +39,6 @@ from .._version import __version__ from ..geometry import get_surf_name, prepare_geometry from ..gl import get_view_matrices, init_window, setup_shader -from ..gui import ConfigWindow from ..utils.types import ViewType # Module logger @@ -377,6 +376,14 @@ def run(): specular=args.specular, ) else: + try: + from ..gui import ConfigWindow # lazy import + except ModuleNotFoundError as e: + raise RuntimeError( + "Interactive mode requires the optional dependency PyQt6. " + "Install with: pip install 'whippersnappy[gui]'" + ) from e + current_fthresh_ = args.fthresh current_fmax_ = args.fmax diff --git a/whippersnappy/gui/config_app.py b/whippersnappy/gui/config_app.py index 857943a..0f85ed7 100644 --- a/whippersnappy/gui/config_app.py +++ b/whippersnappy/gui/config_app.py @@ -5,10 +5,6 @@ Dependencies: PyQt6 - -@Author : Ahmed Faisal Abdelrahman -@Created : 20.03.2022 - """ from PyQt6.QtCore import Qt From ddff0593c80557a7ac2e9986bac1274d816fc0a0 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Thu, 19 Feb 2026 18:54:29 +0100 Subject: [PATCH 39/83] fixes, especially fixed thresholds in snap4 for both hemis --- whippersnappy/_config.py | 2 +- whippersnappy/geometry/__init__.py | 5 +- whippersnappy/geometry/prepare.py | 87 +++++- whippersnappy/plot3d.py | 13 +- whippersnappy/snap.py | 434 +++++++++++++++-------------- 5 files changed, 309 insertions(+), 232 deletions(-) diff --git a/whippersnappy/_config.py b/whippersnappy/_config.py index 112450d..871f75b 100644 --- a/whippersnappy/_config.py +++ b/whippersnappy/_config.py @@ -119,7 +119,7 @@ def sys_info(fid: Optional[IO] = None, developer: bool = False): dependencies = [ elt.split(";")[0].rstrip() for elt in raw_requires - if f"extra == '{key}'" in elt or f"extra == \"{key}\"" in elt or True + if f"extra == '{key}'" in elt or f"extra == \"{key}\"" in elt ] if len(dependencies) == 0: continue diff --git a/whippersnappy/geometry/__init__.py b/whippersnappy/geometry/__init__.py index b7ac525..85deec0 100644 --- a/whippersnappy/geometry/__init__.py +++ b/whippersnappy/geometry/__init__.py @@ -2,10 +2,11 @@ Expose prepare_geometry and small IO helpers under `whippersnappy.geometry`. """ -from .prepare import prepare_geometry +from .prepare import estimate_overlay_thresholds, prepare_geometry from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data from .surf_name import get_surf_name __all__ = [ - 'prepare_geometry', 'read_geometry', 'read_annot_data', 'read_mgh_data', 'read_morph_data', 'get_surf_name' + 'prepare_geometry', 'estimate_overlay_thresholds', + 'read_geometry', 'read_annot_data', 'read_mgh_data', 'read_morph_data', 'get_surf_name', ] diff --git a/whippersnappy/geometry/prepare.py b/whippersnappy/geometry/prepare.py index 70efbe2..f4da0ce 100644 --- a/whippersnappy/geometry/prepare.py +++ b/whippersnappy/geometry/prepare.py @@ -65,15 +65,84 @@ def vertex_normals(v, t): cr1 = np.cross(v2mv1, -v1mv0) cr2 = np.cross(v0mv2, -v2mv1) n = np.zeros(v.shape) - np.add.at(n, t[:, 0], cr0) - np.add.at(n, t[:, 1], cr1) - np.add.at(n, t[:, 2], cr2) + #np.add.at(n, t[:, 0], cr0) + #np.add.at(n, t[:, 1], cr1) + #np.add.at(n, t[:, 2], cr2) + # Vectorized accumulation using bincount + idx = np.concatenate([t[:, 0], t[:, 1], t[:, 2]]) + contribs = np.vstack([cr0, cr1, cr2]) + n = np.empty((v.shape[0], 3), dtype=np.float64) + for j in range(3): + n[:, j] = np.bincount(idx, weights=contribs[:, j], minlength=v.shape[0]) ln = np.sqrt(np.sum(n * n, axis=1)) ln[ln < np.finfo(float).eps] = 1 n = n / ln.reshape(-1, 1) return n +def _estimate_thresholds_from_array(mapdata, minval=None, maxval=None): + """Estimate threshold and saturation values from an already-loaded array. + + Parameters + ---------- + mapdata : numpy.ndarray + Per-vertex overlay values. + minval : float or None, optional + If provided, used as-is; otherwise estimated as the minimum absolute + value in the data. + maxval : float or None, optional + If provided, used as-is; otherwise estimated as the maximum absolute + value in the data. + + Returns + ------- + minval : float + Threshold value (lower bound of the color scale). + maxval : float + Saturation value (upper bound of the color scale). + """ + valabs = np.abs(mapdata) + if maxval is None: + maxval = float(np.max(valabs)) if np.any(valabs) else 0.0 + if minval is None: + minval = float(max(0.0, np.min(valabs) if np.any(valabs) else 0.0)) + return minval, maxval + + +def estimate_overlay_thresholds(overlaypath, minval=None, maxval=None): + """Estimate threshold and saturation values from an overlay file. + + Reads the overlay data and derives ``fmin`` / ``fmax`` from the absolute + values without performing any geometry or color work. Both values are + returned unchanged when they are already provided by the caller, making + the function safe to call unconditionally. + + Parameters + ---------- + overlaypath : str + Path to the overlay file (.mgh or FreeSurfer morph format). + minval : float or None, optional + If provided, used as-is for the threshold; otherwise estimated as + the minimum absolute value in the overlay. + maxval : float or None, optional + If provided, used as-is for the saturation; otherwise estimated as + the maximum absolute value in the overlay. + + Returns + ------- + minval : float + Threshold value (lower bound of the color scale). + maxval : float + Saturation value (upper bound of the color scale). + """ + _, file_extension = os.path.splitext(overlaypath) + if file_extension == ".mgh": + mapdata = read_mgh_data(overlaypath) + else: + mapdata = read_morph_data(overlaypath) + return _estimate_thresholds_from_array(mapdata, minval, maxval) + + def prepare_geometry( surfpath, overlaypath=None, @@ -171,11 +240,7 @@ def prepare_geometry( "file." ) else: - valabs = np.abs(mapdata) - if maxval is None: - maxval = np.max(valabs) if np.any(valabs) else 0 - if minval is None: - minval = max(0.0, np.min(valabs) if np.any(valabs) else 0) + minval, maxval = _estimate_thresholds_from_array(mapdata, minval, maxval) mapdata = mask_sign(mapdata, color_mode) mapdata, fmin, fmax, pos, neg = rescale_overlay(mapdata, minval, maxval) @@ -198,12 +263,6 @@ def prepare_geometry( "file." ) else: - # If annot is shorter, pad with -1 (meaning 'no label') to match - # mesh vertices. - if annot.shape[0] < num_vertices: - pad_len = num_vertices - annot.shape[0] - annot = np.pad(annot, (0, pad_len), mode="constant", constant_values=-1) - # Ensure integer type for safe indexing annot = annot.astype(np.int32) diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py index d93be86..4be6af9 100644 --- a/whippersnappy/plot3d.py +++ b/whippersnappy/plot3d.py @@ -45,6 +45,7 @@ def plot3d( color_mode=None, width=800, height=800, + ambient=0.1, ): """Create an interactive 3D notebook viewer using pythreejs (Three.js). @@ -77,6 +78,8 @@ def plot3d( If None, defaults to ``ColorSelection.BOTH``. width, height : int, optional, default 800 Canvas dimensions for the generated renderer. + ambient : float, optional, default 0.1 + Ambient lighting strength for the shader (passed to Three.js uniform). Returns ------- @@ -119,7 +122,7 @@ def plot3d( vertices = vertices / max_extent * 2.0 # Create Three.js mesh - mesh = create_threejs_mesh_with_custom_shaders(vertices, triangles, colors, normals) + mesh = create_threejs_mesh_with_custom_shaders(vertices, triangles, colors, normals, ambient=ambient) camera = p3js.PerspectiveCamera( position=[-5, 0, 0], @@ -170,9 +173,7 @@ def plot3d( return viewer - - -def create_threejs_mesh_with_custom_shaders(vertices, faces, colors, normals): +def create_threejs_mesh_with_custom_shaders(vertices, faces, colors, normals, ambient=0.1): """Create a pythreejs.Mesh using custom shader material and buffers. The function builds a BufferGeometry with position, color and normal @@ -189,6 +190,8 @@ def create_threejs_mesh_with_custom_shaders(vertices, faces, colors, normals): Array of shape (N, 3) with per-vertex RGB colors (float32). normals : numpy.ndarray Array of shape (N, 3) with per-vertex normals (float32). + ambient : float, optional, default 0.1 + Ambient lighting strength for the shader (passed to Three.js uniform). Returns ------- @@ -216,7 +219,7 @@ def create_threejs_mesh_with_custom_shaders(vertices, faces, colors, normals): fragmentShader=fragment_shader, uniforms={ 'lightColor': {'value': [1.0, 1.0, 1.0]}, - 'ambientStrength': {'value': 0.1} + 'ambientStrength': {'value': ambient} } ) diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index cb5d9b7..daddd6e 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -7,7 +7,7 @@ import pyrr from PIL import Image, ImageFont -from .geometry import get_surf_name +from .geometry import estimate_overlay_thresholds, get_surf_name from .geometry.prepare import prepare_and_validate_geometry from .gl.utils import capture_window, create_window_with_fallback, render_scene, setup_shader, terminate_context from .gl.views import get_view_matrices @@ -124,9 +124,6 @@ def snap1( ref_width = 700 ref_height = 500 ui_scale = min(width / ref_width, height / ref_height) - - # Screen size check only makes sense with a real display; skip on headless. - # create_window_with_fallback will handle context creation + EGL fallback. try: if glfw.init(): primary_monitor = glfw.get_primary_monitor() @@ -153,92 +150,93 @@ def snap1( # will raise exception if it cannot be created window = create_window_with_fallback(brain_display_width, brain_display_height, "WhipperSnapPy", visible=True) + try: + meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( + meshpath, + overlaypath, + annotpath, + curvpath, + labelpath, + fthresh, + fmax, + invert, + scale=brain_scale, + color_mode=color_mode, + ) + + shader = setup_shader(meshdata, triangles, brain_display_width, brain_display_height, + specular=specular, ambient=ambient) - meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( - meshpath, - overlaypath, - annotpath, - curvpath, - labelpath, - fthresh, - fmax, - invert, - scale=brain_scale, - color_mode=color_mode, - ) - - shader = setup_shader(meshdata, triangles, brain_display_width, brain_display_height, - specular=specular, ambient=ambient) - - transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) - view_mats = get_view_matrices() - viewmat = transl * (view_mats[view] if viewmat is None else viewmat) - render_scene(shader, triangles, viewmat) - - # Center the brain rendering in the output image, clamp to zero - brain_x = max(0, (width - brain_display_width) // 2) - brain_y = max(0, (height - brain_display_height) // 2) - image.paste(capture_window(window), (brain_x, brain_y)) - - bar = ( - create_colorbar( - fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, font_file=font_file + transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + view_mats = get_view_matrices() + viewmat = transl * (view_mats[view] if viewmat is None else viewmat) + render_scene(shader, triangles, viewmat) + + # Center the brain rendering in the output image, clamp to zero + brain_x = max(0, (width - brain_display_width) // 2) + brain_y = max(0, (height - brain_display_height) // 2) + image.paste(capture_window(window), (brain_x, brain_y)) + + bar = ( + create_colorbar( + fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, font_file=font_file + ) + if overlaypath is not None and colorbar + else None ) - if overlaypath is not None and colorbar - else None - ) - font = ( - load_roboto_font(int(20 * caption_scale * ui_scale)) - if font_file is None - else ImageFont.truetype(font_file, int(20 * caption_scale * ui_scale)) - if caption - else None - ) - - # Compute positions to avoid overlap, unless explicit positions are given - text_w, text_h = text_size(caption, font) if caption and font else (0, 0) - bar_h = bar.height if bar is not None else 0 - gap = int(4 * ui_scale) - bottom_pad = int(20 * ui_scale) - - if orientation == OrientationType.HORIZONTAL: - # If explicit positions are given, use them - if colorbar_x is not None or colorbar_y is not None or caption_x is not None or caption_y is not None: + font = ( + load_roboto_font(int(20 * caption_scale * ui_scale)) + if font_file is None + else ImageFont.truetype(font_file, int(20 * caption_scale * ui_scale)) + if caption + else None + ) + + # Compute positions to avoid overlap, unless explicit positions are given + text_w, text_h = text_size(caption, font) if caption and font else (0, 0) + bar_h = bar.height if bar is not None else 0 + gap = int(4 * ui_scale) + bottom_pad = int(20 * ui_scale) + + if orientation == OrientationType.HORIZONTAL: + # If explicit positions are given, use them + if colorbar_x is not None or colorbar_y is not None or caption_x is not None or caption_y is not None: + bx = int(colorbar_x * width) if colorbar_x is not None else None + by = int(colorbar_y * height) if colorbar_y is not None else None + cx = int(caption_x * width) if caption_x is not None else None + cy = int(caption_y * height) if caption_y is not None else None + draw_colorbar(image, bar, orientation, x=bx, y=by) + draw_caption(image, caption, font, orientation, x=cx, y=cy) + else: + # Place colorbar above caption if both present + if bar is not None and caption: + bar_y = image.height - bottom_pad - text_h - gap - bar_h + caption_y = image.height - bottom_pad - text_h + elif bar is not None: + bar_y = image.height - bottom_pad - bar_h + caption_y = None + elif caption: + bar_y = None + caption_y = image.height - bottom_pad - text_h + else: + bar_y = caption_y = None + draw_colorbar(image, bar, orientation, y=bar_y) + draw_caption(image, caption, font, orientation, y=caption_y) + else: + # For vertical, allow explicit x/y for both, else use default bx = int(colorbar_x * width) if colorbar_x is not None else None by = int(colorbar_y * height) if colorbar_y is not None else None cx = int(caption_x * width) if caption_x is not None else None cy = int(caption_y * height) if caption_y is not None else None draw_colorbar(image, bar, orientation, x=bx, y=by) draw_caption(image, caption, font, orientation, x=cx, y=cy) - else: - # Place colorbar above caption if both present - if bar is not None and caption: - bar_y = image.height - bottom_pad - text_h - gap - bar_h - caption_y = image.height - bottom_pad - text_h - elif bar is not None: - bar_y = image.height - bottom_pad - bar_h - caption_y = None - elif caption: - bar_y = None - caption_y = image.height - bottom_pad - text_h - else: - bar_y = caption_y = None - draw_colorbar(image, bar, orientation, y=bar_y) - draw_caption(image, caption, font, orientation, y=caption_y) - else: - # For vertical, allow explicit x/y for both, else use default - bx = int(colorbar_x * width) if colorbar_x is not None else None - by = int(colorbar_y * height) if colorbar_y is not None else None - cx = int(caption_x * width) if caption_x is not None else None - cy = int(caption_y * height) if caption_y is not None else None - draw_colorbar(image, bar, orientation, x=bx, y=by) - draw_caption(image, caption, font, orientation, x=cx, y=cy) - - if outpath: - logger.info("Saving snapshot to %s", outpath) - image.save(outpath) - terminate_context(window) - return image + + if outpath: + logger.info("Saving snapshot to %s", outpath) + image.save(outpath) + return image + finally: + terminate_context(window) def snap4( @@ -333,133 +331,149 @@ def snap4( """ wwidth = 540 wheight = 450 - # will raise exception if it cannot be created - window = create_window_with_fallback(wwidth, wheight, "WhipperSnapPy", visible=True) - # Use standard view matrices from get_view_matrices and ViewType - view_mats = get_view_matrices() - view_left = view_mats[ViewType.LEFT] - view_right = view_mats[ViewType.RIGHT] - transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) - - # Predefine hemisphere images so static analysis knows they exist even if - # an earlier step raises an exception (we still will fail at runtime). - lhimg = None - rhimg = None - - for hemi in ("lh", "rh"): - if surfname is None: - if sdir is None: - sdir = os.environ.get("SUBJECTS_DIR") - if not sdir: - logger.error("No surf_name or subjects directory (sdir) provided") - raise ValueError("No surf_name or SUBJECTS_DIR provided") - found_surfname = get_surf_name(sdir, hemi) - if found_surfname is None: - logger.error("Could not find valid surface in %s for hemi: %s!", sdir, hemi) - raise FileNotFoundError(f"Could not find valid surface in {sdir} for hemi: {hemi}") - meshpath = os.path.join(sdir, "surf", hemi + "." + found_surfname) - else: - meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) - - # Assign derived paths - curvpath = os.path.join(sdir, "surf", hemi + "." + curvname) if curvname else None - labelpath = os.path.join(sdir, "label", hemi + "." + labelname) if labelname else None - overlaypath = lhoverlaypath if hemi == "lh" else rhoverlaypath - annotpath = lhannotpath if hemi == "lh" else rhannotpath - - # Diagnostic: report mesh and overlay paths and whether they exist - logger.debug("hemisphere=%s", hemi) - logger.debug("meshpath=%s exists=%s", meshpath, os.path.exists(meshpath)) - if overlaypath is not None: - logger.debug("overlaypath=%s exists=%s", overlaypath, os.path.exists(overlaypath)) - if annotpath is not None: - logger.debug("annotpath=%s exists=%s", annotpath, os.path.exists(annotpath)) - if curvpath is not None: - logger.debug("curvpath=%s exists=%s", curvpath, os.path.exists(curvpath)) - - try: - meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, - scale=brain_scale, color_mode=color_mode - ) - except Exception as e: - logger.error("prepare_geometry failed for %s: %s", meshpath, e) - terminate_context(window) - return None - - # Diagnostics about mesh data - try: - logger.debug("meshdata shape: %s; triangles count: %s", getattr(meshdata, 'shape', None), - getattr(triangles, 'size', None)) - except Exception: - pass - - if pos == 0 and neg == 0: - logger.error("Overlay has no values to display") - raise ValueError("Overlay has no values to display") - - try: - shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular, ambient=ambient) - logger.debug("Shader setup complete") - except Exception as e: - logger.error("setup_shader failed: %s", e) - terminate_context(window) - return None - - render_scene(shader, triangles, transl * view_left) - im1 = capture_window(window) - render_scene(shader, triangles, transl * view_right) - im2 = capture_window(window) - - if hemi == "lh": - lhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) - lhimg.paste(im1, (0, 0)) - lhimg.paste(im2, (0, im1.height)) - else: - rhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) - # For right hemisphere, reverse the order: top=im2, bottom=im1 - rhimg.paste(im2, (0, 0)) - rhimg.paste(im1, (0, im2.height)) - - # Add small padding around each hemisphere to avoid cropping at edges - pad = max(4, int(0.03 * wwidth)) - padded_lh = Image.new("RGB", (lhimg.width + 2 * pad, lhimg.height + 2 * pad), (0, 0, 0)) - padded_lh.paste(lhimg, (pad, pad)) - padded_rh = Image.new("RGB", (rhimg.width + 2 * pad, rhimg.height + 2 * pad), (0, 0, 0)) - padded_rh.paste(rhimg, (pad, pad)) - - image = Image.new("RGB", (padded_lh.width + padded_rh.width, padded_lh.height)) - image.paste(padded_lh, (0, 0)) - image.paste(padded_rh, (padded_lh.width, 0)) - - font = load_roboto_font(20) if font_file is None else ImageFont.truetype(font_file, 20) if caption else None - # Place caption at bottom, colorbar above if both present - text_w, text_h = text_size(caption, font) if caption and font else (0, 0) - bottom_pad = 20 - gap = 4 - caption_y = image.height - bottom_pad - text_h - bar = ( - create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) - if lhannotpath is None and rhannotpath is None and colorbar - else None - ) - bar_h = bar.height if bar is not None else 0 - if bar is not None and caption: - bar_y = image.height - bottom_pad - text_h - gap - bar_h - draw_colorbar(image, bar, OrientationType.HORIZONTAL, y=bar_y) - draw_caption(image, caption, font, OrientationType.HORIZONTAL, y=caption_y) - elif bar is not None: - bar_y = image.height - bottom_pad - bar_h - draw_colorbar(image, bar, OrientationType.HORIZONTAL, y=bar_y) - elif caption: - draw_caption(image, caption, font, OrientationType.HORIZONTAL, y=caption_y) - - # If outpath is specified, save to disk - if outpath: - logger.info("Saving snapshot to %s", outpath) - image.save(outpath) - - terminate_context(window) - return image + # Resolve sdir early so path-building works for both the pre-pass and + # the rendering loop. + if sdir is None: + sdir = os.environ.get("SUBJECTS_DIR") + if not sdir and surfname is None: + logger.error("No sdir or SUBJECTS_DIR provided") + raise ValueError("No sdir or SUBJECTS_DIR provided") + if not sdir and surfname is not None: + logger.error("surfname provided but sdir is None") + raise ValueError("surfname provided but sdir is None; cannot construct meshpath.") + + # Pre-pass: estimate missing fthresh/fmax from overlays for global color scale + has_overlay = lhoverlaypath is not None or rhoverlaypath is not None + if has_overlay and (fthresh is None or fmax is None): + est_fthreshs = [] + est_fmaxs = [] + for overlaypath in filter(None, (lhoverlaypath, rhoverlaypath)): + h_fthresh, h_fmax = estimate_overlay_thresholds(overlaypath, fthresh, fmax) + est_fthreshs.append(h_fthresh) + est_fmaxs.append(h_fmax) + if fthresh is None and est_fthreshs: + fthresh = min(est_fthreshs) + if fmax is None and est_fmaxs: + fmax = max(est_fmaxs) + logger.debug("Global color range: fthresh=%s fmax=%s", fthresh, fmax) + # will raise exception if it cannot be created + window = create_window_with_fallback(wwidth, wheight, "WhipperSnapPy", visible=True) + try: + # Use standard view matrices from get_view_matrices and ViewType + view_mats = get_view_matrices() + view_left = view_mats[ViewType.LEFT] + view_right = view_mats[ViewType.RIGHT] + transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + + # Predefine hemisphere images so static analysis knows they exist even if + # an earlier step raises an exception (we still will fail at runtime). + lhimg = None + rhimg = None + + for hemi in ("lh", "rh"): + if surfname is None: + found_surfname = get_surf_name(sdir, hemi) + if found_surfname is None: + logger.error("Could not find valid surface in %s for hemi: %s!", sdir, hemi) + raise FileNotFoundError(f"Could not find valid surface in {sdir} for hemi: {hemi}") + meshpath = os.path.join(sdir, "surf", hemi + "." + found_surfname) + else: + meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) + + # Assign derived paths + curvpath = os.path.join(sdir, "surf", hemi + "." + curvname) if curvname else None + labelpath = os.path.join(sdir, "label", hemi + "." + labelname) if labelname else None + overlaypath = lhoverlaypath if hemi == "lh" else rhoverlaypath + annotpath = lhannotpath if hemi == "lh" else rhannotpath + + # Diagnostic: report mesh and overlay paths and whether they exist + logger.debug("hemisphere=%s", hemi) + logger.debug("meshpath=%s exists=%s", meshpath, os.path.exists(meshpath)) + if overlaypath is not None: + logger.debug("overlaypath=%s exists=%s", overlaypath, os.path.exists(overlaypath)) + if annotpath is not None: + logger.debug("annotpath=%s exists=%s", annotpath, os.path.exists(annotpath)) + if curvpath is not None: + logger.debug("curvpath=%s exists=%s", curvpath, os.path.exists(curvpath)) + + try: + meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( + meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, + scale=brain_scale, color_mode=color_mode + ) + except Exception as e: + logger.error("prepare_geometry failed for %s: %s", meshpath, e) + raise + + # Diagnostics about mesh data + try: + logger.debug("meshdata shape: %s; triangles count: %s", getattr(meshdata, 'shape', None), + getattr(triangles, 'size', None)) + except Exception: + pass + + try: + shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular, ambient=ambient) + logger.debug("Shader setup complete") + except Exception as e: + logger.error("setup_shader failed: %s", e) + raise + + render_scene(shader, triangles, transl * view_left) + im1 = capture_window(window) + render_scene(shader, triangles, transl * view_right) + im2 = capture_window(window) + + if hemi == "lh": + lhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) + lhimg.paste(im1, (0, 0)) + lhimg.paste(im2, (0, im1.height)) + else: + rhimg = Image.new("RGB", (im1.width, im1.height + im2.height)) + # For right hemisphere, reverse the order: top=im2, bottom=im1 + rhimg.paste(im2, (0, 0)) + rhimg.paste(im1, (0, im2.height)) + + # Add small padding around each hemisphere to avoid cropping at edges + pad = max(4, int(0.03 * wwidth)) + padded_lh = Image.new("RGB", (lhimg.width + 2 * pad, lhimg.height + 2 * pad), (0, 0, 0)) + padded_lh.paste(lhimg, (pad, pad)) + padded_rh = Image.new("RGB", (rhimg.width + 2 * pad, rhimg.height + 2 * pad), (0, 0, 0)) + padded_rh.paste(rhimg, (pad, pad)) + + image = Image.new("RGB", (padded_lh.width + padded_rh.width, padded_lh.height)) + image.paste(padded_lh, (0, 0)) + image.paste(padded_rh, (padded_lh.width, 0)) + + font = load_roboto_font(20) if font_file is None else ImageFont.truetype(font_file, 20) if caption else None + # Place caption at bottom, colorbar above if both present + text_w, text_h = text_size(caption, font) if caption and font else (0, 0) + bottom_pad = 20 + gap = 4 + caption_y = image.height - bottom_pad - text_h + bar = ( + create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) + if lhannotpath is None and rhannotpath is None and colorbar + else None + ) + bar_h = bar.height if bar is not None else 0 + if bar is not None and caption: + bar_y = image.height - bottom_pad - text_h - gap - bar_h + draw_colorbar(image, bar, OrientationType.HORIZONTAL, y=bar_y) + draw_caption(image, caption, font, OrientationType.HORIZONTAL, y=caption_y) + elif bar is not None: + bar_y = image.height - bottom_pad - bar_h + draw_colorbar(image, bar, OrientationType.HORIZONTAL, y=bar_y) + elif caption: + draw_caption(image, caption, font, OrientationType.HORIZONTAL, y=caption_y) + + # If outpath is specified, save to disk + if outpath: + logger.info("Saving snapshot to %s", outpath) + image.save(outpath) + + return image + finally: + terminate_context(window) From 55bfdda3bcd55ab526835d079a0c353a84a7b9d6 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Thu, 19 Feb 2026 22:12:51 +0100 Subject: [PATCH 40/83] update build and test --- .github/workflows/build.yml | 3 +++ .github/workflows/pytest.yml | 3 +++ whippersnappy/__init__.py | 11 ++++++++--- whippersnappy/_config.py | 16 +++++++++++----- whippersnappy/gl/_platform.py | 11 +++++++---- whippersnappy/gl/utils.py | 21 ++++++++++++--------- 6 files changed, 44 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b493b94..2b04065 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,6 +28,9 @@ jobs: uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} + - name: Install system OpenGL/EGL libraries (Ubuntu) + if: matrix.os == 'ubuntu' + run: sudo apt-get install -y --no-install-recommends libegl1 libgl1 - name: Install dependencies run: | python -m pip install --progress-bar off --upgrade pip setuptools wheel diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 54be48f..59f2311 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -32,6 +32,9 @@ jobs: uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} + - name: Install system OpenGL/EGL libraries (Ubuntu) + if: matrix.os == 'ubuntu' + run: sudo apt-get install -y --no-install-recommends libegl1 libgl1 - name: Install package run: | python -m pip install --progress-bar off --upgrade pip setuptools wheel diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 0e86672..7660d73 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -47,9 +47,14 @@ import sys def _check_display(): - """Return True if a working X11 display connection can be opened.""" - # macOS uses CGL/Cocoa — GLFW handles context creation natively, no EGL needed - if sys.platform == "darwin": + """Return True if a display is available or the platform handles GL natively. + + On macOS (CGL) and Windows (WGL) PyOpenGL does not use EGL, so we always + return True to avoid setting ``PYOPENGL_PLATFORM=egl`` on those systems. + On Linux we probe for an X11/Wayland display server. + """ + if sys.platform != "linux": + # macOS uses CGL, Windows uses WGL — no EGL needed on either. return True display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY") if not display: diff --git a/whippersnappy/_config.py b/whippersnappy/_config.py index 871f75b..6617445 100644 --- a/whippersnappy/_config.py +++ b/whippersnappy/_config.py @@ -90,6 +90,7 @@ def sys_info(fid: Optional[IO] = None, developer: bool = False): "style", ) for key in keys: + _from_pyproject = False try: raw_requires = requires(package) or [] except Exception: @@ -113,14 +114,19 @@ def sys_info(fid: Optional[IO] = None, developer: bool = False): opt = proj.get("optional-dependencies", {}) or {} deps = opt.get(key, []) or [] raw_requires = deps + _from_pyproject = True except Exception: raw_requires = [] - dependencies = [ - elt.split(";")[0].rstrip() - for elt in raw_requires - if f"extra == '{key}'" in elt or f"extra == \"{key}\"" in elt - ] + if _from_pyproject: + # pyproject.toml entries have no extras markers — use as-is + dependencies = [elt.split(";")[0].rstrip() for elt in raw_requires] + else: + dependencies = [ + elt.split(";")[0].rstrip() + for elt in raw_requires + if f"extra == '{key}'" in elt or f"extra == \"{key}\"" in elt + ] if len(dependencies) == 0: continue out(f"\nOptional '{key}' info\n") diff --git a/whippersnappy/gl/_platform.py b/whippersnappy/gl/_platform.py index e68db02..8a94ad4 100644 --- a/whippersnappy/gl/_platform.py +++ b/whippersnappy/gl/_platform.py @@ -1,14 +1,17 @@ """Bootstrap PyOpenGL platform selection — must be imported first. Imported unconditionally at the top of gl/__init__.py before any other -OpenGL symbol. Sets PYOPENGL_PLATFORM=egl when no display is available +OpenGL symbol. Sets PYOPENGL_PLATFORM=egl when running headless on Linux so that PyOpenGL uses the EGL backend instead of GLX. -If the user has already set PYOPENGL_PLATFORM, that value is respected. +On macOS PyOpenGL uses CGL and on Windows it uses WGL — both are handled +natively without EGL. If the user has already set PYOPENGL_PLATFORM that +value is always respected. """ import os +import sys -if "PYOPENGL_PLATFORM" not in os.environ: - # No display = definitely headless, force EGL now before any import +if "PYOPENGL_PLATFORM" not in os.environ and sys.platform == "linux": + # No X11/Wayland display on Linux → force EGL headless backend. if not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"): os.environ["PYOPENGL_PLATFORM"] = "egl" diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index 8112b9a..568dcca 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -5,6 +5,7 @@ import logging import sys +from typing import Any import glfw import OpenGL.GL as gl @@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) # Module-level EGL context handle (None when GLFW is used instead) -_egl_context: "EGLContext | None" = None +_egl_context: Any = None def create_vao(): @@ -253,15 +254,15 @@ def create_window_with_fallback(width, height, title="WhipperSnapPy", visible=Tr if window: return window - # --- Step 3: EGL headless pbuffer --- + # --- Step 3: EGL headless pbuffer (Linux only) --- logger.warning( "GLFW context creation failed entirely (no display?). " "Attempting EGL headless context." ) - if sys.platform == "darwin": + if sys.platform != "linux": raise RuntimeError( - "Could not create any OpenGL context via GLFW on macOS. " - "Ensure you are running with a display available." + f"Could not create any OpenGL context via GLFW on {sys.platform}. " + "Ensure a display is available." ) try: from .egl_context import EGLContext @@ -292,7 +293,7 @@ def terminate_context(window): """ global _egl_context if _egl_context is not None: - _egl_context.destroy() + _egl_context.destroy() # type: ignore[union-attr] _egl_context = None else: glfw.terminate() @@ -363,7 +364,7 @@ def capture_window(window): # --- EGL path: read directly from the FBO --- if _egl_context is not None: - return _egl_context.read_pixels() + return _egl_context.read_pixels() # type: ignore[union-attr] # --- GLFW path: read from the default framebuffer --- monitor = glfw.get_primary_monitor() @@ -376,7 +377,10 @@ def capture_window(window): gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) img_buf = gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE) image = Image.frombytes("RGB", (width, height), img_buf) - image = image.transpose(Image.FLIP_TOP_BOTTOM) + # Image.Transpose.FLIP_TOP_BOTTOM is the preferred form since Pillow 9.1; + # fall back to the legacy integer constant for older installations. + _flip = getattr(Image, "Transpose", Image).FLIP_TOP_BOTTOM + image = image.transpose(_flip) if x_scale != 1 or y_scale != 1: rwidth = int(round(width / x_scale)) @@ -419,4 +423,3 @@ def render_scene(shader, triangles, transform): if err != gl.GL_NO_ERROR: logger.error("OpenGL error after draw: %s", err) raise RuntimeError(f"OpenGL error after draw: {err}") - From 3d1fe010649de5e927369e6bc0b9af2bee51dd6a Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Thu, 19 Feb 2026 22:40:48 +0100 Subject: [PATCH 41/83] fix ruff errors --- whippersnappy/__init__.py | 14 ++++++++------ whippersnappy/gl/__init__.py | 3 +-- whippersnappy/gl/egl_context.py | 17 +++++++++++------ 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 7660d73..500825f 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -46,6 +46,7 @@ import os import sys + def _check_display(): """Return True if a display is available or the platform handles GL natively. @@ -60,7 +61,8 @@ def _check_display(): if not display: return False try: - import ctypes, ctypes.util + import ctypes + import ctypes.util libx11 = ctypes.CDLL(ctypes.util.find_library("X11") or "libX11.so.6") libx11.XOpenDisplay.restype = ctypes.c_void_p libx11.XOpenDisplay.argtypes = [ctypes.c_char_p] @@ -78,14 +80,14 @@ def _check_display(): if not _check_display(): os.environ["PYOPENGL_PLATFORM"] = "egl" -from ._config import sys_info # noqa: F401 -from ._version import __version__ # noqa: F401 -from .snap import snap1, snap4 -from .utils.types import ViewType +from ._config import sys_info # noqa: F401, E402 +from ._version import __version__ # noqa: F401, E402 +from .snap import snap1, snap4 # noqa: E402 +from .utils.types import ViewType # noqa: E402 # 3D plotting for notebooks (Three.js-based, works in all Jupyter environments) try: - from .plot3d import plot3d + from .plot3d import plot3d # noqa: E402 _has_plot3d = True except ImportError: _has_plot3d = False diff --git a/whippersnappy/gl/__init__.py b/whippersnappy/gl/__init__.py index 5876f0a..a2654d7 100644 --- a/whippersnappy/gl/__init__.py +++ b/whippersnappy/gl/__init__.py @@ -7,8 +7,7 @@ """ -from . import _platform # noqa: F401 — MUST be first; sets PYOPENGL_PLATFORM - +from . import _platform # noqa: F401 — MUST be first; sets PYOPENGL_PLATFORM from .camera import make_model, make_projection, make_view from .shaders import get_default_shaders, get_webgl_shaders from .utils import ( diff --git a/whippersnappy/gl/egl_context.py b/whippersnappy/gl/egl_context.py index f0e4fbf..950dffa 100644 --- a/whippersnappy/gl/egl_context.py +++ b/whippersnappy/gl/egl_context.py @@ -26,6 +26,7 @@ import logging import os import sys + if sys.platform == "darwin": raise ImportError("EGL is not available on macOS; use GLFW/CGL instead.") @@ -344,20 +345,24 @@ def destroy(self): libegl = self._libegl # GL cleanup first (context must be current) if self.fbo is not None: - gl.glDeleteFramebuffers(1, [self.fbo]); + gl.glDeleteFramebuffers(1, [self.fbo]) self.fbo = None if self._rbo_color is not None: - gl.glDeleteRenderbuffers(1, [self._rbo_color]); + gl.glDeleteRenderbuffers(1, [self._rbo_color]) self._rbo_color = None if self._rbo_depth is not None: - gl.glDeleteRenderbuffers(1, [self._rbo_depth]); + gl.glDeleteRenderbuffers(1, [self._rbo_depth]) self._rbo_depth = None if self._display: libegl.eglMakeCurrent(self._display, None, None, None) - if self._context: libegl.eglDestroyContext(self._display, self._context) - if self._surface: libegl.eglDestroySurface(self._display, self._surface) + if self._context: + libegl.eglDestroyContext(self._display, self._context) + if self._surface: + libegl.eglDestroySurface(self._display, self._surface) libegl.eglTerminate(self._display) - self._display = self._context = self._surface = None + self._display = None + self._context = None + self._surface = None logger.debug("EGL context destroyed.") # Allow use as a context manager From 4c5b8c2dda1d16eea1ae35bb9abd76920d057bc0 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Thu, 19 Feb 2026 23:01:15 +0100 Subject: [PATCH 42/83] shortcut EGL without DISPLAY on linux --- whippersnappy/gl/utils.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index 568dcca..d84d17d 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -240,6 +240,29 @@ def create_window_with_fallback(width, height, title="WhipperSnapPy", visible=Tr """ global _egl_context + # Fast-path: if _check_display() already determined there is no working + # display, skip the two doomed GLFW attempts and go straight to EGL. + # This avoids warning noise and wasted time in Docker/CI/headless SSH. + # The sys.platform guard is preserved — EGL is Linux-only. + if os.environ.get("PYOPENGL_PLATFORM") == "egl": + if sys.platform != "linux": + raise RuntimeError( + f"Could not create any OpenGL context via GLFW on {sys.platform}. " + "Ensure a display is available." + ) + logger.info("No working display detected — using EGL headless directly.") + try: + from .egl_context import EGLContext + ctx = EGLContext(width, height) + ctx.make_current() + _egl_context = ctx + logger.info("Using EGL headless context — no display server required.") + return None + except (ImportError, RuntimeError) as exc: + raise RuntimeError( + f"EGL headless context failed: {exc}" + ) from exc + # --- Step 1: GLFW visible window --- window = init_window(width, height, title, visible=visible) if window: From 91c34da1cc8120dd32b90ad36147913a63646c99 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Thu, 19 Feb 2026 23:26:57 +0100 Subject: [PATCH 43/83] new feature: video of rotation --- pyproject.toml | 5 + whippersnappy/__init__.py | 3 +- whippersnappy/cli/whippersnap1.py | 122 ++++++++++++++----- whippersnappy/gl/utils.py | 1 + whippersnappy/snap.py | 192 ++++++++++++++++++++++++++++++ 5 files changed, 294 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 70e925a..589d6dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,10 @@ notebook = [ gui = [ 'PyQt6', ] +video = [ + 'imageio>=2.28', + 'imageio-ffmpeg>=0.4.9', # bundles its own ffmpeg binary +] style = [ 'bibclean', 'codespell', @@ -90,6 +94,7 @@ all = [ 'whippersnappy[test]', 'whippersnappy[notebook]', 'whippersnappy[gui]', + 'whippersnappy[video]', ] full = [ 'whippersnappy[all]', diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 500825f..ca88081 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -82,7 +82,7 @@ def _check_display(): from ._config import sys_info # noqa: F401, E402 from ._version import __version__ # noqa: F401, E402 -from .snap import snap1, snap4 # noqa: E402 +from .snap import snap1, snap4, snap_rotate # noqa: E402 from .utils.types import ViewType # noqa: E402 # 3D plotting for notebooks (Three.js-based, works in all Jupyter environments) @@ -98,6 +98,7 @@ def _check_display(): "sys_info", "snap1", "snap4", + "snap_rotate", ] if _has_plot3d: diff --git a/whippersnappy/cli/whippersnap1.py b/whippersnappy/cli/whippersnap1.py index 832cfd0..0a3583b 100644 --- a/whippersnappy/cli/whippersnap1.py +++ b/whippersnappy/cli/whippersnap1.py @@ -6,7 +6,7 @@ import os import tempfile -from .. import snap1 +from .. import snap1, snap_rotate from .._version import __version__ from ..utils.types import ColorSelection, OrientationType, ViewType @@ -22,7 +22,8 @@ def run(): prog="whippersnap1", description=( "Render a single-view screenshot of any triangular surface mesh " - "(FreeSurfer or otherwise) without a GUI." + "(FreeSurfer or otherwise) without a GUI. " + "Pass --rotate to produce a 360° rotation video instead." ), ) parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") @@ -39,8 +40,12 @@ def run(): parser.add_argument( "-o", "--output", type=str, - default=os.path.join(tempfile.gettempdir(), "whippersnappy_snap1.png"), - help="Output PNG path. Defaults to a temp file.", + default=None, + help=( + "Output file path. For snapshots defaults to a temp .png file; " + "for rotation videos defaults to a temp .mp4 file. " + "Use .gif for an animated GIF (no ffmpeg required)." + ), ) # --- Optional overlay / annotation / label / curv --- @@ -89,31 +94,92 @@ def run(): parser.add_argument("--font", type=str, default=None, help="Path to a TTF font for captions.") + # --- Rotation video --- + rotate_group = parser.add_argument_group("rotation video (--rotate)") + rotate_group.add_argument( + "--rotate", + action="store_true", + help="Produce a 360° rotation video instead of a static snapshot.", + ) + rotate_group.add_argument( + "--rotate-frames", + type=int, + default=72, + metavar="N", + help="Number of frames for a full rotation (default: 72, i.e. 5° per frame).", + ) + rotate_group.add_argument( + "--rotate-fps", + type=int, + default=24, + metavar="FPS", + help="Frame rate of the output video (default: 24).", + ) + rotate_group.add_argument( + "--rotate-start-view", + type=str, + default="left", + choices=list(_VIEW_CHOICES), + metavar="VIEW", + help="Starting view for the rotation (default: left).", + ) + args = parser.parse_args() + log = logging.getLogger(__name__) + try: - img = snap1( - meshpath=args.meshpath, - outpath=args.output, - overlaypath=args.overlay, - annotpath=args.annot, - labelpath=args.label, - curvpath=args.curv, - view=_VIEW_CHOICES[args.view], - width=args.width, - height=args.height, - fthresh=args.fthresh, - fmax=args.fmax, - caption=args.caption, - invert=args.invert, - colorbar=not args.no_colorbar, - color_mode=_COLOR_CHOICES[args.color_mode], - orientation=_ORIENT_CHOICES[args.orientation], - font_file=args.font, - specular=args.specular, - brain_scale=args.brain_scale, - ambient=args.ambient, - ) - logging.getLogger(__name__).info("Snapshot saved to %s (%dx%d)", args.output, img.width, img.height) - except (RuntimeError, FileNotFoundError, ValueError) as e: + if args.rotate: + outpath = args.output or os.path.join( + tempfile.gettempdir(), "whippersnappy_rotation.mp4" + ) + snap_rotate( + meshpath=args.meshpath, + outpath=outpath, + n_frames=args.rotate_frames, + fps=args.rotate_fps, + width=args.width, + height=args.height, + overlaypath=args.overlay, + curvpath=args.curv, + annotpath=args.annot, + labelpath=args.label, + fthresh=args.fthresh, + fmax=args.fmax, + invert=args.invert, + specular=args.specular, + ambient=args.ambient, + brain_scale=args.brain_scale, + start_view=_VIEW_CHOICES[args.rotate_start_view], + color_mode=_COLOR_CHOICES[args.color_mode], + ) + log.info("Rotation video saved to %s", outpath) + else: + outpath = args.output or os.path.join( + tempfile.gettempdir(), "whippersnappy_snap1.png" + ) + img = snap1( + meshpath=args.meshpath, + outpath=outpath, + overlaypath=args.overlay, + annotpath=args.annot, + labelpath=args.label, + curvpath=args.curv, + view=_VIEW_CHOICES[args.view], + width=args.width, + height=args.height, + fthresh=args.fthresh, + fmax=args.fmax, + caption=args.caption, + invert=args.invert, + colorbar=not args.no_colorbar, + color_mode=_COLOR_CHOICES[args.color_mode], + orientation=_ORIENT_CHOICES[args.orientation], + font_file=args.font, + specular=args.specular, + brain_scale=args.brain_scale, + ambient=args.ambient, + ) + log.info("Snapshot saved to %s (%dx%d)", outpath, img.width, img.height) + except (RuntimeError, FileNotFoundError, ValueError, ImportError) as e: parser.error(str(e)) diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index d84d17d..de279d7 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -4,6 +4,7 @@ """ import logging +import os import sys from typing import Any diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index daddd6e..7095d6c 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -4,6 +4,7 @@ import os import glfw +import numpy as np import pyrr from PIL import Image, ImageFont @@ -477,3 +478,194 @@ def snap4( return image finally: terminate_context(window) + + +def snap_rotate( + meshpath, + outpath, + n_frames=72, + fps=24, + width=700, + height=500, + overlaypath=None, + curvpath=None, + annotpath=None, + labelpath=None, + fthresh=None, + fmax=None, + invert=False, + specular=True, + ambient=0.0, + brain_scale=1.5, + start_view=ViewType.LEFT, + color_mode=ColorSelection.BOTH, +): + """Render a rotating 360° video of a surface mesh. + + Rotates the view around the vertical (Y) axis in ``n_frames`` equal + steps, captures each frame via OpenGL, and encodes the result into a + compressed video file using ``imageio`` with the ``ffmpeg`` backend + (provided by ``imageio-ffmpeg``). + + An animated GIF can also be produced by passing an ``outpath`` ending + in ``.gif``; in that case ``imageio-ffmpeg`` is not required. + + Parameters + ---------- + meshpath : str + Path to the surface file (FreeSurfer binary format, e.g. ``lh.white``). + outpath : str + Destination file path. The extension controls the output format: + + * ``.mp4`` — H.264 MP4 (recommended, requires ``imageio-ffmpeg``). + * ``.webm`` — VP9 WebM (requires ``imageio-ffmpeg``). + * ``.gif`` — animated GIF (no ffmpeg required, but larger file). + + n_frames : int, optional + Number of frames for a full 360° rotation. Default is ``72`` + (one frame every 5°). + fps : int, optional + Output frame rate in frames per second. Default is ``24``. + width, height : int, optional + Render resolution in pixels. Defaults are ``700`` and ``500``. + overlaypath : str or None, optional + Path to per-vertex overlay file (e.g. thickness). + curvpath : str or None, optional + Path to curvature file for texturing non-colored regions. + annotpath : str or None, optional + Path to FreeSurfer ``.annot`` file. + labelpath : str or None, optional + Path to label file used to mask overlay values. + fthresh : float or None, optional + Overlay threshold value. + fmax : float or None, optional + Overlay saturation value. + invert : bool, optional + Invert the overlay color scale. Default is ``False``. + specular : bool, optional + Enable specular highlights. Default is ``True``. + ambient : float, optional + Ambient lighting strength. Default is ``0.0``. + brain_scale : float, optional + Geometry scale factor. Default is ``1.5``. + start_view : ViewType, optional + Pre-defined view to start the rotation from. + Default is ``ViewType.LEFT``. + color_mode : ColorSelection, optional + Which overlay sign to color (POSITIVE/NEGATIVE/BOTH). + Default is ``ColorSelection.BOTH``. + + Returns + ------- + str + The resolved ``outpath`` that was written. + + Raises + ------ + ImportError + If ``imageio`` or ``imageio-ffmpeg`` is not installed and a + video format (``.mp4``, ``.webm``) was requested. + RuntimeError + If the OpenGL context cannot be initialised. + ValueError + If the overlay contains no values for the chosen color mode. + + Examples + -------- + >>> from whippersnappy import snap_rotate + >>> snap_rotate( + ... 'fsaverage/surf/lh.white', + ... '/tmp/rotation.mp4', + ... overlaypath='fsaverage/surf/lh.thickness', + ... ) + '/tmp/rotation.mp4' + """ + ext = os.path.splitext(outpath)[1].lower() + use_gif = ext == ".gif" + + if not use_gif: + try: + import imageio # noqa: F401 + import imageio_ffmpeg # noqa: F401 + except ImportError as exc: + raise ImportError( + f"Video output requires the 'imageio' and 'imageio-ffmpeg' packages. " + f"Install with: pip install 'whippersnappy[video]'\n" + f"Original error: {exc}" + ) from exc + import imageio + else: + try: + import imageio # noqa: F401 + except ImportError as exc: + raise ImportError( + "GIF output requires the 'imageio' package. " + "Install with: pip install 'whippersnappy[video]'" + ) from exc + import imageio + + window = create_window_with_fallback(width, height, "WhipperSnapPy", visible=True) + try: + meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( + meshpath, + overlaypath, + annotpath, + curvpath, + labelpath, + fthresh, + fmax, + invert, + scale=brain_scale, + color_mode=color_mode, + ) + logger.info( + "Rendering %d frames at %dx%d (%.0f° per step) → %s", + n_frames, width, height, 360.0 / n_frames, outpath, + ) + + shader = setup_shader(meshdata, triangles, width, height, + specular=specular, ambient=ambient) + + transl = pyrr.Matrix44.from_translation((0, 0, 0.4)) + base_view = get_view_matrices()[start_view] + + frames = [] + for i in range(n_frames): + angle = 2 * np.pi * i / n_frames + rot = pyrr.Matrix44.from_y_rotation(angle) + viewmat = transl * rot * base_view + render_scene(shader, triangles, viewmat) + frames.append(np.array(capture_window(window))) + if (i + 1) % max(1, n_frames // 10) == 0: + logger.debug(" frame %d / %d", i + 1, n_frames) + + finally: + terminate_context(window) + + logger.info("Encoding %d frames to %s …", len(frames), outpath) + if use_gif: + # Pure-PIL GIF — no ffmpeg required + pil_frames = [Image.fromarray(f) for f in frames] + pil_frames[0].save( + outpath, + save_all=True, + append_images=pil_frames[1:], + loop=0, + duration=int(1000 / fps), + optimize=True, + ) + else: + writer_kwargs = { + "fps": fps, + "codec": "libx264", + "quality": 6, + "pixelformat": "yuv420p", + } + if ext == ".webm": + writer_kwargs["codec"] = "libvpx-vp9" + writer_kwargs.pop("pixelformat", None) + imageio.mimwrite(outpath, frames, **writer_kwargs) + + logger.info("Saved rotation video to %s", outpath) + return outpath + From 7eed4185e595fba2f23a23c57696fdc0898bc4bf Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Thu, 19 Feb 2026 23:42:14 +0100 Subject: [PATCH 44/83] restructure CLI, keep snap4 and GUI separate --- pyproject.toml | 1 + whippersnappy/cli/whippersnap.py | 396 +++++++++--------------------- whippersnappy/cli/whippersnap4.py | 146 +++++++++++ 3 files changed, 269 insertions(+), 274 deletions(-) create mode 100644 whippersnappy/cli/whippersnap4.py diff --git a/pyproject.toml b/pyproject.toml index 589d6dc..9b3b16c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,7 @@ tracker = 'https://github.com/Deep-MI/WhipperSnapPy/issues' [project.scripts] whippersnap = 'whippersnappy.cli.whippersnap:run' whippersnap1 = 'whippersnappy.cli.whippersnap1:run' +whippersnap4 = 'whippersnappy.cli.whippersnap4:run' whippersnappy-sys_info = 'whippersnappy.commands.sys_info:run' [tool.setuptools] diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 161b9ac..940bc46 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -1,20 +1,18 @@ #!/usr/bin/python3 -""" -Executes the whippersnappy program in an interactive or non-interactive mode. +"""Interactive GUI viewer for WhipperSnapPy. -The non-interactive mode (the default) creates an image that contains four -views of the surface, an optional color bar, and a configurable caption. +Opens a live OpenGL window for a single hemisphere together with a +Qt-based configuration panel that allows adjusting overlay thresholds +at runtime. -The interactive mode (--interactive) opens a simple GUI with a controllable -view of one of the hemispheres. In addition, the view through a separate -configuration app which allows adjusting thresholds, etc. during runtime. +Usage:: -Usage: - $ python3 whippersnap.py -lh $LH_OVERLAY_FILE -rh $RH_OVERLAY_FILE \ - -sd $SURF_SUBJECT_DIR -o $OUTPUT_PATH + whippersnap -lh -sd + whippersnap --lh_annot --rh_annot -sd -(See help for full list of arguments.) +See ``whippersnap --help`` for the full list of options. +For non-interactive four-view batch rendering use ``whippersnap4``. """ import argparse @@ -22,7 +20,6 @@ import os import signal import sys -import tempfile import threading import glfw @@ -32,10 +29,9 @@ try: from PyQt6.QtWidgets import QApplication except Exception: - # GUI dependency missing; handle at runtime when interactive mode is requested + # GUI dependency missing; raise a clear error at runtime QApplication = None -from .. import snap4 from .._version import __version__ from ..geometry import get_surf_name, prepare_geometry from ..gl import get_view_matrices, init_window, setup_shader @@ -44,7 +40,7 @@ # Module logger logger = logging.getLogger(__name__) -# Global variables for config app configuration state: +# Global state shared between the GL thread and the Qt main thread current_fthresh_ = None current_fmax_ = None app_ = None @@ -76,12 +72,11 @@ def show_window( hemi : {'lh','rh'} Hemisphere to display. overlaypath : str or None, optional - Path to a per-vertex overlay file (e.g. thickness). If ``None`` no - overlay will be applied. + Path to a per-vertex overlay file (e.g. thickness). annotpath : str or None, optional - Path to a .annot file providing categorical labels for vertices. + Path to a ``.annot`` file providing categorical labels for vertices. sdir : str or None, optional - Subject directory containing `surf/` and `label/` subdirectories. + Subject directory containing ``surf/`` and ``label/`` subdirectories. caption : str or None, optional Caption text to display in the viewer window. invert : bool, optional, default False @@ -89,22 +84,17 @@ def show_window( labelname : str, optional, default 'cortex.label' Label filename used to mask vertices. surfname : str or None, optional - Surface basename (e.g. 'white'); if ``None`` the function will try - to auto-detect a suitable surface in ``sdir``. + Surface basename (e.g. ``'white'``); if ``None`` the function will + auto-detect a suitable surface in ``sdir``. curvname : str or None, optional, default 'curv' Curvature filename used to texture non-colored regions. specular : bool, optional, default True Enable specular highlights in the shader. - Returns - ------- - None - The function primarily drives an interactive event loop and does not return programmatic geometry objects. - Raises ------ RuntimeError - If the window/context could not be created. + If the GLFW window or OpenGL context could not be created. FileNotFoundError If a requested surface file cannot be located in ``sdir``. """ @@ -128,18 +118,11 @@ def show_window( else: meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) - curvpath = None - if curvname: - curvpath = os.path.join(sdir, "surf", hemi + "." + curvname) - labelpath = None - if labelname: - labelpath = os.path.join(sdir, "label", hemi + "." + labelname) + curvpath = os.path.join(sdir, "surf", hemi + "." + curvname) if curvname else None + labelpath = os.path.join(sdir, "label", hemi + "." + labelname) if labelname else None - # set up canonical view matrix for the selected hemisphere view_mats = get_view_matrices() - viewmat = view_mats[ViewType.LEFT] # fallback - if hemi == "rh": - viewmat = view_mats[ViewType.RIGHT] + viewmat = view_mats[ViewType.RIGHT] if hemi == "rh" else view_mats[ViewType.LEFT] rot_y = pyrr.Matrix44.from_y_rotation(0) meshdata, triangles, fthresh, fmax, neg = prepare_geometry( @@ -151,12 +134,10 @@ def show_window( ypos = 0 while glfw.get_key(window, glfw.KEY_ESCAPE) != glfw.PRESS and not glfw.window_should_close(window): - # Terminate if config app window was closed: if app_window_closed_: break glfw.poll_events() - gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) if app_window_ is not None: @@ -167,13 +148,8 @@ def show_window( current_fthresh_ = app_window_.get_fthresh_value() current_fmax_ = app_window_.get_fmax_value() meshdata, triangles, fthresh, fmax, neg = prepare_geometry( - meshpath, - overlaypath, - annotpath, - curvpath, - labelpath, - current_fthresh_, - current_fmax_, + meshpath, overlaypath, annotpath, curvpath, labelpath, + current_fthresh_, current_fmax_, ) shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular) @@ -181,274 +157,146 @@ def show_window( gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, rot_y * viewmat) if glfw.get_key(window, glfw.KEY_RIGHT) == glfw.PRESS: - ypos = ypos + 0.0004 + ypos += 0.0004 if glfw.get_key(window, glfw.KEY_LEFT) == glfw.PRESS: - ypos = ypos - 0.0004 + ypos -= 0.0004 rot_y = pyrr.Matrix44.from_y_rotation(ypos) - # Draw gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) glfw.swap_buffers(window) glfw.terminate() - # Do NOT call app_.quit() here; QApplication teardown must be handled in the main thread. - # Only set app_window_closed_ = True in this thread. + # Signal the main thread to tear down the Qt app app_window_closed_ = True def config_app_exit_handler(): """Mark the configuration application as closed. - This handler is connected to the configuration app's about-to-quit - signal and sets a module-level flag that the main OpenGL loop polls to - terminate cleanly. + Connected to the Qt app's ``aboutToQuit`` signal so the OpenGL loop + in the worker thread terminates cleanly. """ global app_window_closed_ app_window_closed_ = True def run(): - """Command-line entry point for the WhipperSnapPy snapshot/interactive tool. + """Command-line entry point for the WhipperSnapPy interactive GUI. - Parses command-line arguments, validates argument combinations, and - either launches a non-interactive snapshot generation (``snap4``) or - starts the interactive viewer and configuration GUI. + Parses command-line arguments, validates them, then spawns the OpenGL + viewer thread and launches the PyQt6 configuration window in the main + thread. Raises ------ + RuntimeError + If PyQt6 is not installed. ValueError For invalid or mutually exclusive argument combinations. - ImportError - If interactive mode is requested but PyQt6 is not available. - - Notes - ----- - The function validates that either overlay or annotation inputs are - provided for both hemispheres; it raises ``ValueError`` for invalid - combinations. - - In non-interactive mode the function calls :func:`whippersnappy.snap4` - to produce and optionally save a composed image. - - In interactive mode it spawns the OpenGL viewer thread and launches - the PyQt6-based configuration window in the main thread. """ global current_fthresh_, current_fmax_, app_, app_window_ - # Configure basic logging for CLI invocation so messages from module loggers - # are visible to end users. Avoid configuring on import by doing this here. - logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') - - parser = argparse.ArgumentParser() - parser.add_argument( - "--version", - action="version", - version=f"%(prog)s {__version__}" - ) - parser.add_argument( - "-lh", - "--lh_overlay", - type=str, - default=None, - required=False, - help="Absolute path to the lh overlay file.", - ) - parser.add_argument( - "-rh", - "--rh_overlay", - type=str, - default=None, - required=False, - help="Absolute path to the rh overlay file.", - ) - parser.add_argument( - "--lh_annot", - type=str, - default=None, - required=False, - help="Absolute path to the lh annotation file.", - ) - parser.add_argument( - "--rh_annot", - type=str, - default=None, - required=False, - help="Absolute path to the rh annotation file.", - ) - parser.add_argument( - "-sd", - "--sdir", - type=str, - required=True, - help="Absolute path to subject directory from which surfaces will be loaded. " - "This is assumed to contain the surface files in a surf/ sub-directory.", - ) - parser.add_argument( - "-s", - "--surf_name", - type=str, - default=None, - help="Name of the surface file to load.", - ) - parser.add_argument( - "-o", - "--output_path", - type=str, - default=os.path.join(tempfile.gettempdir(), "whippersnappy_snap.png"), - help="Absolute path to the output file (snapshot image), " - "if not running interactive mode.", - ) - parser.add_argument( - "-c", "--caption", type=str, default="", help="Caption to place on the figure" - ) - parser.add_argument( - "--no-colorbar", - dest="no_colorbar", - action="store_true", - default=False, - help="Switch off colorbar.") - parser.add_argument("--fmax", - type=float, - default=4.0, - help="Overlay saturation value (default: 4.0)") - parser.add_argument("--fthresh", - type=float, - default=2.0, - help="Overlay threshold value (default: 2.0)") - parser.add_argument( - "-i", - "--interactive", - dest="interactive", - action="store_true", - help="Start an interactive GUI session.", - ) - parser.add_argument( - "--invert", dest="invert", action="store_true", help="Invert the color scale." - ) - parser.add_argument( - "--diffuse", - dest="specular", - action="store_false", - default=True, - help="Diffuse surface reflection (switch-off specular).", + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + parser = argparse.ArgumentParser( + prog="whippersnap", + description=( + "Interactive GUI viewer for a single hemisphere. " + "For batch four-view rendering use whippersnap4." + ), ) + parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") + parser.add_argument("-lh", "--lh_overlay", type=str, default=None, + help="Path to the lh overlay file.") + parser.add_argument("-rh", "--rh_overlay", type=str, default=None, + help="Path to the rh overlay file.") + parser.add_argument("--lh_annot", type=str, default=None, + help="Path to the lh annotation file.") + parser.add_argument("--rh_annot", type=str, default=None, + help="Path to the rh annotation file.") + parser.add_argument("-sd", "--sdir", type=str, required=True, + help="Subject directory containing surf/ and label/ subdirectories.") + parser.add_argument("-s", "--surf_name", type=str, default=None, + help="Surface basename to load (e.g. 'white').") + parser.add_argument("-c", "--caption", type=str, default="", + help="Caption text.") + parser.add_argument("--fmax", type=float, default=4.0, + help="Overlay saturation value (default: 4.0).") + parser.add_argument("--fthresh", type=float, default=2.0, + help="Overlay threshold value (default: 2.0).") + parser.add_argument("--invert", action="store_true", + help="Invert the color scale.") + parser.add_argument("--diffuse", dest="specular", action="store_false", default=True, + help="Diffuse-only shading (no specular).") args = parser.parse_args() try: - # check for mutually exclusive arguments if (args.lh_overlay or args.rh_overlay) and (args.lh_annot or args.rh_annot): - msg = "Cannot use lh_overlay/rh_overlay and lh_annot/rh_annot arguments at the same time." - logger.error(msg) - raise ValueError(msg) - # check if at least one variant is present - if args.lh_overlay is None and args.rh_overlay is None and args.lh_annot is None and args.rh_annot is None: - msg = "Either lh_overlay/rh_overlay or lh_annot/rh_annot must be present." - logger.error(msg) - raise ValueError(msg) - # check if both hemis are present - if (args.lh_overlay is None and args.rh_overlay is not None) or \ - (args.lh_overlay is not None and args.rh_overlay is None) or \ - (args.lh_annot is None and args.rh_annot is not None) or \ - (args.lh_annot is not None and args.rh_annot is None): - msg = "If lh_overlay or lh_annot is present, rh_overlay or rh_annot must also be present (and vice versa)." - logger.error(msg) - raise ValueError(msg) + raise ValueError( + "Cannot use lh_overlay/rh_overlay and lh_annot/rh_annot at the same time." + ) + if not any([args.lh_overlay, args.rh_overlay, args.lh_annot, args.rh_annot]): + raise ValueError( + "Either lh_overlay/rh_overlay or lh_annot/rh_annot must be present." + ) + if (args.lh_overlay is None) != (args.rh_overlay is None): + raise ValueError("Both -lh and -rh overlays must be provided together.") + if (args.lh_annot is None) != (args.rh_annot is None): + raise ValueError("Both --lh_annot and --rh_annot must be provided together.") except ValueError as e: parser.error(str(e)) - logger.debug("Parsed args: %s", vars(args)) - - # - if not args.interactive: - snap4( - lhoverlaypath=args.lh_overlay, - rhoverlaypath=args.rh_overlay, - lhannotpath=args.lh_annot, - rhannotpath=args.rh_annot, - sdir=args.sdir, - caption=args.caption, - surfname=args.surf_name, - fthresh=args.fthresh, - fmax=args.fmax, - invert=args.invert, - colorbar=not(args.no_colorbar), - outpath=args.output_path, - specular=args.specular, + if QApplication is None: + print( + "ERROR: Interactive mode requires PyQt6. " + "Install with: pip install 'whippersnappy[gui]'", + file=sys.stderr, ) - else: - try: - from ..gui import ConfigWindow # lazy import - except ModuleNotFoundError as e: - raise RuntimeError( - "Interactive mode requires the optional dependency PyQt6. " - "Install with: pip install 'whippersnappy[gui]'" - ) from e - - current_fthresh_ = args.fthresh - current_fmax_ = args.fmax - - # Ensure GUI toolkit is available - if QApplication is None: - print("ERROR: Interactive mode requires PyQt6. Install it (pip install PyQt6)" - " or run without --interactive.", file=sys.stderr) - sys.exit(1) - - # Starting interactive OpenGL window in a separate thread: - thread = threading.Thread( - target=show_window, - args=( - "lh", - args.lh_overlay, - args.lh_annot, - args.sdir, - None, - False, - "cortex.label", - args.surf_name, - "curv", - args.specular, - ), - ) - thread.start() - - # Setting up and running config app window (must be main thread): - app_ = QApplication([]) - app_.setStyle("Fusion") # the default - app_.aboutToQuit.connect(config_app_exit_handler) - - screen_geometry = app_.primaryScreen().availableGeometry() - app_window_ = ConfigWindow( - screen_dims=(screen_geometry.width(), screen_geometry.height()), - initial_fthresh_value=current_fthresh_, - initial_fmax_value=current_fmax_, + raise RuntimeError( + "Interactive mode requires PyQt6. " + "Install with: pip install 'whippersnappy[gui]'" ) - # The following is a way to allow CTRL+C termination of the app: - signal.signal(signal.SIGINT, signal.SIG_DFL) - - app_window_.show() - app_.exec() - + try: + from ..gui import ConfigWindow # noqa: PLC0415 + except ModuleNotFoundError as e: + raise RuntimeError( + "Interactive mode requires PyQt6. " + "Install with: pip install 'whippersnappy[gui]'" + ) from e + + current_fthresh_ = args.fthresh + current_fmax_ = args.fmax + + thread = threading.Thread( + target=show_window, + args=( + "lh", + args.lh_overlay, + args.lh_annot, + args.sdir, + None, + False, + "cortex.label", + args.surf_name, + "curv", + args.specular, + ), + ) + thread.start() -# headless docker test using xvfb: -# Note, xvfb is a display server implementing the X11 protocol, and performing -# all graphics on memory. -# glfw needs a windows to render even if that is invisible, so above code -# will not work via ssh or on a headless server. xvfb can solve this by wrapping: -# docker run --name headless_test -ti -v$(pwd):/test ubuntu /bin/bash -# apt update && apt install -y python3 python3-pip xvfb -# pip3 install pyopengl glfw pillow numpy pyrr -# xvfb-run python3 test4.py + app_ = QApplication([]) + app_.setStyle("Fusion") + app_.aboutToQuit.connect(config_app_exit_handler) -# instead of the above one could really do headless off-screen rendering via -# EGL (preferred) or OSMesa. The latter looks doable. EGL looks tricky. -# EGL is part of any modern NVIDIA driver -# OSMesa needs to be installed, but should work almost everywhere + screen_geometry = app_.primaryScreen().availableGeometry() + app_window_ = ConfigWindow( + screen_dims=(screen_geometry.width(), screen_geometry.height()), + initial_fthresh_value=current_fthresh_, + initial_fmax_value=current_fmax_, + ) -# using EGL maybe like this: -# https://github.com/eduble/gl -# or via these bindings: -# https://github.com/perey/pegl + signal.signal(signal.SIGINT, signal.SIG_DFL) + app_window_.show() + app_.exec() -# or OSMesa -# https://github.com/AntonOvsyannikov/DockerGL diff --git a/whippersnappy/cli/whippersnap4.py b/whippersnappy/cli/whippersnap4.py new file mode 100644 index 0000000..5835055 --- /dev/null +++ b/whippersnappy/cli/whippersnap4.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +"""CLI entry point for four-view batch rendering via snap4. + +Renders left and right hemisphere surfaces in lateral and medial views, +stitches them into a single image and writes it to disk. + +Usage:: + + whippersnap4 -lh -rh -sd -o out.png + whippersnap4 --lh_annot --rh_annot -sd + +See ``whippersnap4 --help`` for the full list of options. +For the interactive GUI use ``whippersnap``. +""" + +import argparse +import logging +import os +import tempfile + +from .. import snap4 +from .._version import __version__ + +# Module logger +logger = logging.getLogger(__name__) + + +def run(): + """Command-line entry point for WhipperSnapPy four-view batch rendering. + + Parses command-line arguments, validates them, and calls + :func:`whippersnappy.snap4` to produce a four-view composed image. + + Raises + ------ + ValueError + For invalid or mutually exclusive argument combinations. + RuntimeError + If the OpenGL context cannot be initialised. + FileNotFoundError + If required surface files cannot be found. + """ + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + parser = argparse.ArgumentParser( + prog="whippersnap4", + description=( + "Render a four-view (left/right hemisphere, lateral/medial) " + "batch snapshot without a GUI. " + "For the interactive GUI use whippersnap." + ), + ) + parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") + + # --- Overlay / annotation inputs --- + parser.add_argument("-lh", "--lh_overlay", type=str, default=None, + help="Path to the lh overlay file.") + parser.add_argument("-rh", "--rh_overlay", type=str, default=None, + help="Path to the rh overlay file.") + parser.add_argument("--lh_annot", type=str, default=None, + help="Path to the lh annotation (.annot) file.") + parser.add_argument("--rh_annot", type=str, default=None, + help="Path to the rh annotation (.annot) file.") + + # --- Subject directory / surface --- + parser.add_argument("-sd", "--sdir", type=str, required=True, + help="Subject directory containing surf/ and label/ subdirectories.") + parser.add_argument("-s", "--surf_name", type=str, default=None, + help="Surface basename to load (e.g. 'white'); " + "auto-detected if not provided.") + + # --- Output --- + parser.add_argument( + "-o", "--output_path", + type=str, + default=os.path.join(tempfile.gettempdir(), "whippersnappy_snap4.png"), + help="Output image path (default: temp file).", + ) + + # --- Overlay appearance --- + parser.add_argument("--fmax", type=float, default=None, + help="Overlay saturation value (auto-estimated if not set).") + parser.add_argument("--fthresh", type=float, default=None, + help="Overlay threshold value (auto-estimated if not set).") + parser.add_argument("--invert", action="store_true", + help="Invert the color scale.") + parser.add_argument("--no-colorbar", dest="no_colorbar", action="store_true", + default=False, help="Suppress the colorbar.") + parser.add_argument("-c", "--caption", type=str, default="", + help="Caption text to place on the figure.") + + # --- Rendering --- + parser.add_argument("--diffuse", dest="specular", action="store_false", default=True, + help="Diffuse-only shading (no specular).") + parser.add_argument("--ambient", type=float, default=0.0, + help="Ambient light strength (default: 0.0).") + parser.add_argument("--brain-scale", type=float, default=1.85, + help="Geometry scale factor (default: 1.85).") + parser.add_argument("--font", type=str, default=None, + help="Path to a TTF font for captions.") + + args = parser.parse_args() + + try: + if (args.lh_overlay or args.rh_overlay) and (args.lh_annot or args.rh_annot): + raise ValueError( + "Cannot use lh_overlay/rh_overlay and lh_annot/rh_annot at the same time." + ) + if not any([args.lh_overlay, args.rh_overlay, args.lh_annot, args.rh_annot]): + raise ValueError( + "Either lh_overlay/rh_overlay or lh_annot/rh_annot must be present." + ) + if (args.lh_overlay is None) != (args.rh_overlay is None): + raise ValueError("Both -lh and -rh overlays must be provided together.") + if (args.lh_annot is None) != (args.rh_annot is None): + raise ValueError("Both --lh_annot and --rh_annot must be provided together.") + except ValueError as e: + parser.error(str(e)) + + logger.debug("Parsed args: %s", vars(args)) + + try: + img = snap4( + lhoverlaypath=args.lh_overlay, + rhoverlaypath=args.rh_overlay, + lhannotpath=args.lh_annot, + rhannotpath=args.rh_annot, + sdir=args.sdir, + caption=args.caption, + surfname=args.surf_name, + fthresh=args.fthresh, + fmax=args.fmax, + invert=args.invert, + colorbar=not args.no_colorbar, + outpath=args.output_path, + font_file=args.font, + specular=args.specular, + ambient=args.ambient, + brain_scale=args.brain_scale, + ) + logger.info( + "Snapshot saved to %s (%dx%d)", args.output_path, img.width, img.height + ) + except (RuntimeError, FileNotFoundError, ValueError) as e: + parser.error(str(e)) + From c1213dde8b4b560065f430c4d243b85e21197c28 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Thu, 19 Feb 2026 23:59:03 +0100 Subject: [PATCH 45/83] update Dockerfile to whippersnap4 and video --- Dockerfile | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 5103578..ea56592 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,11 +10,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ rm -rf /var/lib/apt/lists/* RUN pip install --upgrade pip -RUN pip install pyopengl glfw pillow numpy pyrr COPY . /WhipperSnapPy -RUN pip install /WhipperSnapPy +RUN pip install /WhipperSnapPy[video] -ENTRYPOINT ["whippersnap"] +ENTRYPOINT ["whippersnap4"] CMD ["--help"] - From a19499394e86b986a57e33b737e7425a67820535 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 00:12:09 +0100 Subject: [PATCH 46/83] add DOCKER doc and update README.md and sphinx --- DOCKER.md | 170 +++++++++++++++++++++++++++++++++++++++++++++++++ README.md | 128 ++++++++++++++++++++++--------------- doc/conf.py | 7 ++ doc/docker.rst | 4 ++ doc/index.rst | 92 +------------------------- pyproject.toml | 1 + 6 files changed, 263 insertions(+), 139 deletions(-) create mode 100644 DOCKER.md create mode 100644 doc/docker.rst diff --git a/DOCKER.md b/DOCKER.md new file mode 100644 index 0000000..c8b4b7b --- /dev/null +++ b/DOCKER.md @@ -0,0 +1,170 @@ +# WhipperSnapPy — Docker Guide + +The Docker image provides a fully headless rendering environment with EGL +off-screen support. No display server or `xvfb` is required. + +The default entry point is `whippersnap4` (four-view batch rendering). +`whippersnap1` (single-view snapshot and rotation video) can be invoked by +overriding the entry point. + +--- + +## Building the image + +From the repository root: + +```bash +docker build --rm -t whippersnappy -f Dockerfile . +``` + +--- + +## Running — four-view batch rendering (`whippersnap4`) + +`whippersnap4` renders lateral and medial views of both hemispheres and writes +a single composed PNG image. + +Mount your local directories into the container and pass the in-container paths +as arguments: + +```bash +docker run --rm --init \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + -lh /subject/surf/lh.thickness \ + -rh /subject/surf/rh.thickness \ + -sd /subject \ + -o /output/snap4.png +``` + +### With an annotation file instead of an overlay + +```bash +docker run --rm --init \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + --lh_annot /subject/label/lh.aparc.annot \ + --rh_annot /subject/label/rh.aparc.annot \ + -sd /subject \ + -o /output/snap4_annot.png +``` + +### With a caption and custom thresholds + +```bash +docker run --rm --init \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + -lh /subject/surf/lh.thickness \ + -rh /subject/surf/rh.thickness \ + -sd /subject \ + --fthresh 2.0 --fmax 4.0 \ + --caption "Cortical thickness" \ + -o /output/snap4_thickness.png +``` + +### All `whippersnap4` options + +``` +docker run --rm whippersnappy --help +``` + +--- + +## Running — single-view snapshot (`whippersnap1`) + +Override the entry point with `--entrypoint whippersnap1`: + +```bash +docker run --rm --init \ + --entrypoint whippersnap1 \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + /subject/surf/lh.white \ + --overlay /subject/surf/lh.thickness \ + --curv /subject/surf/lh.curv \ + --view left \ + --fthresh 2.0 --fmax 4.0 \ + -o /output/snap1.png +``` + +### All `whippersnap1` options + +```bash +docker run --rm --entrypoint whippersnap1 whippersnappy --help +``` + +--- + +## Running — 360° rotation video (`whippersnap1 --rotate`) + +`whippersnap1 --rotate` renders a full 360° rotation video and writes an +`.mp4`, `.webm`, or `.gif` file. `imageio-ffmpeg` is bundled in the image — +no system ffmpeg is required. + +### MP4 (H.264, recommended) + +```bash +docker run --rm --init \ + --entrypoint whippersnap1 \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + /subject/surf/lh.white \ + --overlay /subject/surf/lh.thickness \ + --curv /subject/surf/lh.curv \ + --rotate \ + --rotate-frames 72 \ + --rotate-fps 24 \ + -o /output/rotation.mp4 +``` + +### Animated GIF (no ffmpeg needed) + +```bash +docker run --rm --init \ + --entrypoint whippersnap1 \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + /subject/surf/lh.white \ + --overlay /subject/surf/lh.thickness \ + --rotate \ + --rotate-frames 36 \ + --rotate-fps 12 \ + -o /output/rotation.gif +``` + +--- + +## Path mapping summary + +| Host path | Container path | Purpose | +|-----------|---------------|---------| +| `/path/to/subject` | `/subject` | FreeSurfer subject directory (contains `surf/`, `label/`) | +| `/path/to/output` | `/output` | Directory where output files are written | + +All output files are written to the container path you pass via `-o`; mount the +parent directory to retrieve them on the host. + +--- + +## Notes + +- The `--init` flag is recommended so that signals (e.g. `Ctrl-C`) are handled + correctly inside the container. +- `--user $(id -u):$(id -g)` ensures output files are owned by your host user, + not root. +- The interactive GUI (`whippersnap`) is **not** available in the Docker image — + it requires a display server and PyQt6, which are not installed. + diff --git a/README.md b/README.md index daf2486..cb3498b 100644 --- a/README.md +++ b/README.md @@ -1,57 +1,78 @@ # WhipperSnapPy -WhipperSnapPY is a small Python OpenGL program to render FreeSurfer and -FastSurfer surface models and color overlays and generate screen shots. +WhipperSnapPy is a Python OpenGL program to render FreeSurfer and +FastSurfer surface models with color overlays or parcellations and generate +screenshots. ## Contents: -- Capture 4x4 surface plots (front & back, left and right) -- OpenGL window for interactive visualization (GUI) -- Interactive 3D viewer for Jupyter notebooks with mouse-controlled rotation +- `snap1` — single-view surface snapshot +- `snap4` — four-view composed image (lateral/medial, both hemispheres) +- `snap_rotate` — 360° rotation video (MP4, WebM, or GIF) +- `plot3d` — interactive 3D WebGL viewer for Jupyter notebooks +- `whippersnap` — desktop GUI with live Qt controls ## Installation: -The `WhipperSnapPy` package can be installed from pypi via -``` +The `WhipperSnapPy` package can be installed from PyPI via: + +```bash python3 -m pip install whippersnappy ``` -Note, that currently no off-screen rendering is natively supported. Even in snap -mode an invisible window will be created to render the openGL output -and capture the contents to an image. In order to run this on a headless -server, inside Docker, or via ssh we recommend to install xvfb and run +For rotation video support (MP4/WebM): +```bash +pip install 'whippersnappy[video]' ``` -sudo apt update && apt install -y python3 python3-pip xvfb libxcb-xinerama0 -pip3 install pyopengl glfw pillow numpy pyrr PyQt6 -pip3 install whippersnappy -xvfb-run whippersnap ... + +For the interactive desktop GUI: + +```bash +pip install 'whippersnappy[gui]' ``` +For interactive 3D in Jupyter notebooks: + +```bash +pip install 'whippersnappy[notebook]' +``` + +Off-screen (headless) rendering is supported natively via EGL on Linux — no +`xvfb` required. See the [Docker guide](DOCKER.md) for headless usage. + ## Usage: ### Local: -After installing the Python package, the whippersnap program can be run using -the installed command line tool such as in the following example: -``` -whippersnap -lh $OVERLAY_DIR/$LH_OVERLAY_FILE \ - -rh $OVERLAY_DIR/$RH_OVERLAY_FILE \ - -sd $SURF_SUBJECT_DIR \ - --fmax 4 --fthresh 2 --invert \ - --caption caption.txt \ - -o $OUTPUT_DIR/whippersnappy_image.png \ +After installing the Python package, the command-line tools can be run as in +the following examples: + +```bash +# Four-view batch rendering (both hemispheres) +whippersnap4 -lh $LH_OVERLAY -rh $RH_OVERLAY \ + -sd $SURF_SUBJECT_DIR \ + --fmax 4 --fthresh 2 --invert \ + --caption "My caption" \ + -o $OUTPUT_DIR/snap4.png + +# Single-view snapshot +whippersnap1 $SURF_SUBJECT_DIR/surf/lh.white \ + --overlay $LH_OVERLAY \ + --view left -o $OUTPUT_DIR/snap1.png + +# 360° rotation video +whippersnap1 $SURF_SUBJECT_DIR/surf/lh.white \ + --overlay $LH_OVERLAY \ + --rotate -o $OUTPUT_DIR/rotation.mp4 ``` -For more options see `whippersnap --help`. -Note, that adding the `--interactive` flag will start an interactive GUI that -includes a visualization of one hemisphere side and a simple application through -which color threshold values can be configured. +For more options see `whippersnap4 --help` or `whippersnap1 --help`. ## Quick Imports ```python -from whippersnappy import snap1, snap4, plot3d +from whippersnappy import snap1, snap4, snap_rotate, plot3d ``` ### Jupyter Notebooks: @@ -96,7 +117,7 @@ For static publication-quality images: ```python from whippersnappy import snap1 -from whippersnappy.types import ViewType +from whippersnappy.utils.types import ViewType from IPython.display import display img = snap1( @@ -122,36 +143,43 @@ See `examples/whippersnappy_demo.ipynb` for complete examples. ### Desktop GUI: -For interactive desktop application with GUI controls: +For interactive desktop visualization with Qt controls: ```bash -whippersnap --interactive -lh path/to/lh.white -rh path/to/rh.white +whippersnap -lh /path/to/lh.thickness -sd /path/to/subject ``` -This launches a native desktop GUI (not a notebook) with sliders and controls. +This launches a native desktop GUI with a live OpenGL window and a +configuration panel for adjusting overlay thresholds at runtime. +Requires `pip install 'whippersnappy[gui]'`. ### Docker: -The whippersnap program can be run within a docker container to capture -a snapshot by building the provided Docker image and running a container as -follows: -``` -docker build --rm=true -t whippersnappy -f ./Dockerfile . -``` -``` -docker run --rm --init --name my_whippersnappy -v $SURF_SUBJECT_DIR:/surf_subject_dir \ - -v $OVERLAY_DIR:/overlay_dir \ - -v $OUTPUT_DIR:/output_dir \ - --user $(id -u):$(id -g) whippersnappy:latest \ - --lh_overlay /overlay_dir/$LH_OVERLAY_FILE \ - --rh_overlay /overlay_dir/$RH_OVERLAY_FILE \ - --sdir /surf_subject_dir \ - --output_path /output_dir/whippersnappy_image.png +The Docker image provides a fully headless EGL rendering environment — no +display server or `xvfb` required. + +Build the image: + +```bash +docker build --rm -t whippersnappy -f Dockerfile . ``` -In this example: `$SURF_SUBJECT_DIR` contains the surface files, `$OVERLAY_DIR` contains the overlays to be loaded on to the surfaces, `$OUTPUT_DIR` is the local output directory in which the snapshot will be saved, and `${LH/RH}_OVERLAY_FILE` point to the specific overlay files to load. +Run a four-view batch snapshot: + +```bash +docker run --rm --init \ + -v /path/to/subject:/subject \ + -v /path/to/output:/output \ + --user $(id -u):$(id -g) \ + whippersnappy \ + -lh /subject/surf/lh.thickness \ + -rh /subject/surf/rh.thickness \ + -sd /subject \ + -o /output/snap4.png +``` -**Note:** The `--init` flag to Docker is needed for the `xvfb-run` tool to be used correctly for off-screen rendering. +For single-view snapshots, rotation videos, annotation overlays, custom +thresholds, and more examples see **[DOCKER.md](DOCKER.md)**. ## API Documentation diff --git a/doc/conf.py b/doc/conf.py index e558aa1..9265300 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -37,6 +37,7 @@ # extensions coming with Sphinx (named "sphinx.ext.*") or your custom # ones. extensions = [ + "myst_parser", "sphinx.ext.autodoc", "sphinx.ext.autosectionlabel", "sphinx.ext.autosummary", @@ -48,6 +49,12 @@ "sphinx_design", ] +# Tell Sphinx to parse both .rst and .md files; MyST handles the Markdown side. +source_suffix = { + ".rst": "restructuredtext", + ".md": "markdown", +} + templates_path = ["_templates"] exclude_patterns = [ "_build", diff --git a/doc/docker.rst b/doc/docker.rst new file mode 100644 index 0000000..f8249cd --- /dev/null +++ b/doc/docker.rst @@ -0,0 +1,4 @@ +.. _docker: + +.. include:: ../DOCKER.md + :parser: myst_parser.sphinx_ diff --git a/doc/index.rst b/doc/index.rst index 81e9c4e..13cfc84 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,97 +1,11 @@ .. include:: ./links.inc - -.. whippersnappy documentation master file, created by - sphinx-quickstart on Wed Jul 12 09:51:03 2023. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -Welcome to whippersnappy's documentation! -========================================= - -WhipperSnapPY is a small Python OpenGL program to render -FreeSurfer and FastSurfer surface models with color overlays or parcellations and generate screen shots. - -License -------- - -`Whippersnappy `_ is licensed under the `MIT license`_. -A full copy of the license can be found `on GitHub `_. - -Contents --------- - -- Snap1: Capture a single snapshot of a surface with an overlay -- Snap4: Capture 4x4 surface plots (front & back, left and right) of a Free- or FastSurfer brain surface with an overlay -- Plot3d: Interactive 3D WebGL visualization in IPython notebooks -- OpenGL QT GUI for interactive visualization - -Note, that currently no off-screen rendering is supported. Even in snap mode an invisible window will be created to render the openGL output and capture the contents to an image. In order to run this on a headless server, inside Docker, or via ssh we recommend to install xvfb and run - -.. code-block:: bash - - apt update && apt install -y python3 python3-pip xvfb - pip3 install whippersnappy - xvfb-run whippersnap ... - -Installation ------------- - -The `Whippersnappy `_ package can be installed from this repository using: - -.. code-block:: bash - - python3 -m pip install . - -Usage ------ - -Local -''''' - -After installing the Python package, the whippersnap program can be run using the installed command line tool such as in the following example: - -.. code-block:: bash - - whippersnap -lh $OVERLAY_DIR/$LH_OVERLAY_FILE \ - -rh $OVERLAY_DIR/$RH_OVERLAY_FILE \ - -sd $SURF_SUBJECT_DIR \ - -o $OUTPUT_DIR/whippersnappy_image.png - -Note that adding the `--interactive` flag will start an interactive GUI that includes a visualization of one hemisphere side and a simple application through which color threshold values can be configured. - -Docker -'''''' - -the whippersnap program can be run within a docker container to capture a snapshot by building the provided Docker image and running a container as follows: - -.. code-block:: bash - - docker build --rm=true -t whippersnappy -f ./Dockerfile . - -.. code-block:: bash - - docker run --rm --init --name my_whippersnappy -v $SURF_SUBJECT_DIR:/surf_subject_dir \ - -v $OVERLAY_DIR:/overlay_dir \ - -v $OUTPUT_DIR:/output_dir \ - --user $(id -u):$(id -g) whippersnappy:latest \ - --lh_overlay /overlay_dir/$LH_OVERLAY_FILE \ - --rh_overlay /overlay_dir/$RH_OVERLAY_FILE \ - --sdir /surf_subject_dir \ - --output_path /output_dir/whippersnappy_image.png - -In this example: `$SURF_SUBJECT_DIR` contains the surface files, `$OVERLAY_DIR` contains the overlays to be loaded on to the surfaces, `$OUTPUT_DIR` is the local output directory in which the snapshot will be saved, and `${LH/RH}_OVERLAY_FILE` point to the specific overlay files to load. - -**Note:** The `--init` flag is needed for the `xvfb-run` tool to be used correctly. - -Links ------ - -We also invite you to check out our lab webpage at https://deep-mi.org +.. include:: ../README.md + :parser: myst_parser.sphinx_ .. toctree:: :hidden: + docker api/index - diff --git a/pyproject.toml b/pyproject.toml index 9b3b16c..d187670 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ doc = [ 'furo!=2023.8.17', 'matplotlib', 'memory-profiler', + 'myst-parser', 'numpydoc', 'sphinx!=7.2.*', 'sphinxcontrib-bibtex', From fb5c6cd93239f27e5bd38ec91b1d76cb60c564e7 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 00:26:03 +0100 Subject: [PATCH 47/83] fix sphinx --- doc/DOCKER.md | 1 + doc/README.md | 1 + doc/conf.py | 3 +-- doc/docker.rst | 2 +- doc/index.rst | 2 +- 5 files changed, 5 insertions(+), 4 deletions(-) create mode 120000 doc/DOCKER.md create mode 120000 doc/README.md diff --git a/doc/DOCKER.md b/doc/DOCKER.md new file mode 120000 index 0000000..6916342 --- /dev/null +++ b/doc/DOCKER.md @@ -0,0 +1 @@ +../DOCKER.md \ No newline at end of file diff --git a/doc/README.md b/doc/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/doc/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/doc/conf.py b/doc/conf.py index 9265300..0eb3204 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -12,8 +12,6 @@ from importlib import import_module from typing import Dict, Optional -# from sphinx_gallery.sorting import FileNameSortKey - import whippersnappy project = "WhipperSnapPy" @@ -281,5 +279,6 @@ def ensure_pandoc_installed(_): delete_installer=True, ) + def setup(app): app.connect("builder-inited", ensure_pandoc_installed) diff --git a/doc/docker.rst b/doc/docker.rst index f8249cd..467f6eb 100644 --- a/doc/docker.rst +++ b/doc/docker.rst @@ -1,4 +1,4 @@ .. _docker: -.. include:: ../DOCKER.md +.. include:: DOCKER.md :parser: myst_parser.sphinx_ diff --git a/doc/index.rst b/doc/index.rst index 13cfc84..3a32dc5 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,6 +1,6 @@ .. include:: ./links.inc -.. include:: ../README.md +.. include:: README.md :parser: myst_parser.sphinx_ .. toctree:: From 607bbfedf3b050a5ba5912a8ee26c821bade82b3 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 00:37:58 +0100 Subject: [PATCH 48/83] fix sphinx --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cb3498b..263d8d0 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ pip install 'whippersnappy[notebook]' ``` Off-screen (headless) rendering is supported natively via EGL on Linux — no -`xvfb` required. See the [Docker guide](DOCKER.md) for headless usage. +`xvfb` required. See the Docker guide for headless usage. ## Usage: @@ -179,7 +179,7 @@ docker run --rm --init \ ``` For single-view snapshots, rotation videos, annotation overlays, custom -thresholds, and more examples see **[DOCKER.md](DOCKER.md)**. +thresholds, and more examples see DOCKER.md. ## API Documentation From cb09c2c271cca6662015c4c38649ea8027283b81 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 01:03:25 +0100 Subject: [PATCH 49/83] remove sphinx-gallery --- doc/_templates/autosummary/function.rst | 3 -- doc/conf.py | 39 ++++++++++++------------- pyproject.toml | 1 - 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/doc/_templates/autosummary/function.rst b/doc/_templates/autosummary/function.rst index cdbecc4..27ca987 100644 --- a/doc/_templates/autosummary/function.rst +++ b/doc/_templates/autosummary/function.rst @@ -3,6 +3,3 @@ .. currentmodule:: {{ module }} .. autofunction:: {{ objname }} - -.. minigallery:: {{ fullname }} - :add-heading: diff --git a/doc/conf.py b/doc/conf.py index 0eb3204..f274fc8 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -47,20 +47,21 @@ "sphinx_design", ] -# Tell Sphinx to parse both .rst and .md files; MyST handles the Markdown side. -source_suffix = { - ".rst": "restructuredtext", - ".md": "markdown", -} - -templates_path = ["_templates"] +# .md files are included via '.. include:: :parser: myst_parser.sphinx_' +# in the RST stubs; they must NOT be registered as standalone Sphinx source +# documents or autosectionlabel will produce duplicate-label warnings from +# both the real file and the doc/ symlink. exclude_patterns = [ "_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints", + "*.md", # exclude symlinked .md files inside doc/ + "../*.md", # exclude root-level .md files ] +templates_path = ["_templates"] + # Sphinx will warn about all references where the target cannot be found. nitpicky = False nitpick_ignore = [] @@ -265,19 +266,17 @@ def linkcode_resolve(domain: str, info: Dict[str, str]) -> Optional[str]: DOCS_DIRECTORY = os.path.dirname(os.path.abspath(getsourcefile(lambda: 0))) def ensure_pandoc_installed(_): - import pypandoc - - # Download pandoc if necessary. If pandoc is already installed and on - # the PATH, the installed version will be used. Otherwise, we will - # download a copy of pandoc into docs/bin/ and add that to our PATH. - pandoc_dir = os.path.join(DOCS_DIRECTORY, "bin") - # Add dir containing pandoc binary to the PATH environment variable - if pandoc_dir not in os.environ["PATH"].split(os.pathsep): - os.environ["PATH"] += os.pathsep + pandoc_dir - pypandoc.ensure_pandoc_installed( - targetfolder=pandoc_dir, - delete_installer=True, - ) + try: + import pypandoc + pandoc_dir = os.path.join(DOCS_DIRECTORY, "bin") + if pandoc_dir not in os.environ["PATH"].split(os.pathsep): + os.environ["PATH"] += os.pathsep + pandoc_dir + pypandoc.ensure_pandoc_installed( + targetfolder=pandoc_dir, + delete_installer=True, + ) + except Exception: + pass # pandoc already on PATH (CI) or download failed (local SSL) — continue def setup(app): diff --git a/pyproject.toml b/pyproject.toml index d187670..fefddb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ doc = [ 'sphinxcontrib-bibtex', 'sphinx-copybutton', 'sphinx-design', - 'sphinx-gallery', 'sphinx-issues', 'pypandoc', 'nbsphinx', From 2e9b8cc4613a746a59b2bdba878ac2c0351c0099 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 12:55:01 +0100 Subject: [PATCH 50/83] renamed examples -> tutorials --- README.md | 2 +- {examples => tutorials}/whippersnappy_demo.ipynb | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename {examples => tutorials}/whippersnappy_demo.ipynb (100%) diff --git a/README.md b/README.md index 263d8d0..c3475c2 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ display(img) - ✅ Fast performance - ✅ Identical to GUI version -See `examples/whippersnappy_demo.ipynb` for complete examples. +See `tutorials/whippersnappy_demo.ipynb` for complete examples. ### Desktop GUI: diff --git a/examples/whippersnappy_demo.ipynb b/tutorials/whippersnappy_demo.ipynb similarity index 100% rename from examples/whippersnappy_demo.ipynb rename to tutorials/whippersnappy_demo.ipynb From d7b61a4c7d331c51c14ffcb0ede075c0c0c50f09 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 17:48:22 +0100 Subject: [PATCH 51/83] add code for example download via pooch and github release data --- .github/workflows/doc.yml | 5 + README.md | 2 +- pyproject.toml | 5 + tutorials/whippersnappy_demo.ipynb | 264 ------------------------- tutorials/whippersnappy_tutorial.ipynb | 205 +++++++++++++++++++ whippersnappy/__init__.py | 5 +- whippersnappy/utils/__init__.py | 4 +- whippersnappy/utils/datasets.py | 107 ++++++++++ 8 files changed, 328 insertions(+), 269 deletions(-) delete mode 100644 tutorials/whippersnappy_demo.ipynb create mode 100644 tutorials/whippersnappy_tutorial.ipynb create mode 100644 whippersnappy/utils/datasets.py diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index f1e7b56..632ca48 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -37,6 +37,11 @@ jobs: python -m pip install --progress-bar off main/.[doc] - name: Display system information run: whippersnappy-sys_info --developer + - name: Cache WhipperSnapPy sample data + uses: actions/cache@v4 + with: + path: ~/.cache/whippersnappy + key: sample-data-sub-rs-v2.0.0 - name: Build doc run: TZ=UTC sphinx-build ./main/doc ./doc-build/dev -W --keep-going - name: Upload documentation diff --git a/README.md b/README.md index c3475c2..799c397 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ display(img) - ✅ Fast performance - ✅ Identical to GUI version -See `tutorials/whippersnappy_demo.ipynb` for complete examples. +See `tutorials/whippersnappy_tutorial.ipynb` for complete examples. ### Desktop GUI: diff --git a/pyproject.toml b/pyproject.toml index fefddb4..2f95100 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,8 +64,13 @@ doc = [ 'nbsphinx', 'IPython', # For syntax highlighting in notebooks 'ipykernel', + # Needed to execute the tutorial notebook via nbsphinx: + 'pooch>=1.6', + 'pythreejs', + 'ipywidgets', ] notebook = [ + 'pooch>=1.6', # For downloading tutorial data files in the notebook 'pythreejs', # Three.js for interactive 3D (works in all Jupyter environments) 'ipywidgets', # Required for pythreejs ] diff --git a/tutorials/whippersnappy_demo.ipynb b/tutorials/whippersnappy_demo.ipynb deleted file mode 100644 index fbafd8b..0000000 --- a/tutorials/whippersnappy_demo.ipynb +++ /dev/null @@ -1,264 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "dfc714ca8243e82", - "metadata": {}, - "outputs": [], - "source": [ - "# WhipperSnapPy Demo - Static & Interactive Rendering\n", - "# This notebook demonstrates both static and interactive 3D brain visualization.\n", - "#\n", - "# Installation:\n", - "# pip install 'whippersnappy[notebook]'\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "472f4e1b64de35a9", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "from IPython.display import display\n", - "\n", - "from whippersnappy import ViewType, plot3d, snap1, snap4\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "82ee18d3ef6732df", - "metadata": {}, - "outputs": [], - "source": [ - "# Setup Paths\n", - "# Edit these to point to your FreeSurfer/FastSurfer data:\n", - "\n", - "# Set your subject directory here\n", - "# (either an absolute path to a subject's directory\n", - "# containing a `surf/` subdirectory).\n", - "# Example: sdir = '/home/user/freesurfer/subjects/subject01'\n", - "# IMPORTANT: set this value before running the rest of the notebook.\n", - "sdir = '/path/to/your/subjectdir' # <-- update\n", - "\n", - "# Verify that sdir exists and is a directory\n", - "if not os.path.isdir(sdir):\n", - " raise ValueError(\n", - " f\"Subject directory does not exist: {sdir}\\n\"\n", - " \"Please set `sdir` to a valid subject directory containing a 'surf/' subdirectory.\"\n", - " )\n", - "\n", - "# Derive per-hemisphere paths from the subject directory\n", - "lh_surf_path = os.path.join(sdir, 'surf', 'lh.white')\n", - "lh_thickness_path = os.path.join(sdir, 'surf', 'lh.thickness')\n", - "lh_curv_path = os.path.join(sdir, 'surf', 'lh.curv')\n", - "lh_label_path = os.path.join(sdir, 'label', 'lh.cortex.label')\n", - "lh_annot_path = os.path.join(sdir, 'label', 'lh.aparc.annot')\n", - "\n", - "rh_surf_path = os.path.join(sdir, 'surf', 'rh.white')\n", - "rh_thickness_path = os.path.join(sdir, 'surf', 'rh.thickness')\n", - "rh_curv_path = os.path.join(sdir, 'surf', 'rh.curv')\n", - "rh_label_path = os.path.join(sdir, 'label', 'rh.cortex.label')\n", - "rh_annot_path = os.path.join(sdir, 'label', 'rh.aparc.annot')\n", - "\n", - "# Preset overlay variables for convenience for Snap4\n", - "lh_overlay = lh_thickness_path if os.path.exists(lh_thickness_path) else None\n", - "rh_overlay = rh_thickness_path if os.path.exists(rh_thickness_path) else None\n", - "\n", - "print(f\"Subject dir: {sdir}\")\n", - "print(f\"Surface exists? LH: {os.path.exists(lh_surf_path)} | RH: {os.path.exists(rh_surf_path)}\")\n", - "print(f\"Thickness exists? LH: {os.path.exists(lh_thickness_path)} | RH: {os.path.exists(rh_thickness_path)}\")\n", - "print(f\"Curv exists? LH: {os.path.exists(lh_curv_path)} | RH: {os.path.exists(rh_curv_path)}\")\n", - "print(f\"Label exists? LH: {os.path.exists(lh_label_path)} | RH: {os.path.exists(rh_label_path)}\")\n", - "print(f\"Annot exists? LH: {os.path.exists(lh_thickness_path)} | RH: {os.path.exists(rh_thickness_path)}\")\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fd25577540008818", - "metadata": {}, - "outputs": [], - "source": [ - "# Part 1: Static Rendering - Single View\n", - "# Generate publication-quality static images with full PyOpenGL control.\n", - "\n", - "# Render a single view\n", - "img = snap1(\n", - " meshpath=lh_surf_path,\n", - " overlaypath=lh_overlay,\n", - " curvpath=lh_curv_path if os.path.exists(lh_curv_path) else None,\n", - " view=ViewType.LEFT,\n", - " width=800,\n", - " height=800,\n", - " brain_scale=1.5, # Adjust to avoid cropping\n", - " specular=True, # Professional lighting\n", - ")\n", - "display(img)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43a6d7a79e9a2264", - "metadata": {}, - "outputs": [], - "source": [ - "# Part 1: Static Rendering - Snap4 (both hemispheres)\n", - "# Use snap4 to render front/back views for left and right\n", - "# hemispheres and display the combined image.\n", - "\n", - "print(f\"Using subjects dir: {sdir}\")\n", - "\n", - "# lh_overlay and rh_overlay are precomputed in the Setup cell\n", - "# Call snap4 and receive a PIL Image directly when outpath=None\n", - "img4 = snap4(\n", - " lhoverlaypath=lh_overlay,\n", - " rhoverlaypath=rh_overlay,\n", - " sdir=sdir,\n", - " caption='Snap4 - both hemispheres',\n", - " outpath=None, # return PIL image instead of writing to disk\n", - " specular=True,\n", - " brain_scale=1.8,\n", - ")\n", - "\n", - "# Display result (snap4 returns a PIL.Image when outpath is None)\n", - "if img4 is not None:\n", - " display(img4)\n", - "else:\n", - " print(\"snap4 did not return an image; check inputs and OpenGL context.\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ca7e2c155177b7c5", - "metadata": {}, - "outputs": [], - "source": [ - "# Part 2: Interactive 3D Rendering\n", - "# Mouse-controlled 3D visualization using Three.js (works in all Jupyter environments).\n", - "#\n", - "# Controls:\n", - "# - Drag: Rotate\n", - "# - Scroll: Zoom\n", - "# - Right-drag: Pan\n", - "\n", - "# Interactive viewer with curvature (grayscale)\n", - "viewer = plot3d(\n", - " meshpath=lh_surf_path,\n", - " curvpath=lh_curv_path if os.path.exists(lh_curv_path) else None,\n", - " width=800,\n", - " height=800,\n", - ")\n", - "display(viewer)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f712c59cf4a54a0d", - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive colored overlay (if available)\n", - "if lh_overlay:\n", - " viewer = plot3d(\n", - " meshpath=lh_surf_path,\n", - " overlaypath=lh_overlay,\n", - " curvpath=lh_curv_path if os.path.exists(lh_curv_path) else None,\n", - " labelpath=lh_label_path,\n", - " minval=0.0, # Threshold\n", - " maxval=5.5, # Saturation\n", - " width=800,\n", - " height=800,\n", - " )\n", - " display(viewer)\n", - "else:\n", - " print(\"Thickness overlay not found - skipping colored example\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9d5f61470c130e6e", - "metadata": {}, - "outputs": [], - "source": [ - "# Interactive label map overlay\n", - "if lh_annot_path:\n", - " viewer = plot3d(\n", - " meshpath=lh_surf_path,\n", - " annotpath=lh_annot_path,\n", - " curvpath=lh_curv_path if os.path.exists(lh_curv_path) else None,\n", - " width=800,\n", - " height=800,\n", - " )\n", - " display(viewer)\n", - "else:\n", - " print(\"Annot overlay not found - skipping label map example\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cd1b1b7561ebecfa", - "metadata": {}, - "outputs": [], - "source": [ - "# Notes\n", - "#\n", - "# Static Rendering:\n", - "# - Returns PIL Image objects (no disk I/O needed)\n", - "# - Full PyOpenGL control for custom lighting\n", - "# - Publication-quality output\n", - "# - Fast and deterministic\n", - "#\n", - "# Interactive Rendering:\n", - "# - Uses Three.js/WebGL (runs in browser)\n", - "# - Works in all Jupyter environments\n", - "# - Full mouse control (rotate, zoom, pan)\n", - "# - Same technology as Plotly 3D plots\n", - "#\n", - "# Color Notes:\n", - "# - Curvature: Grayscale (sulci = dark, gyri = light) - this is correct!\n", - "# - Overlays: Colored heatmaps (thickness, activation, statistics)\n", - "# - Annotations: Distinct colored regions (parcellations)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a065ca6a14ecdda9", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tutorials/whippersnappy_tutorial.ipynb b/tutorials/whippersnappy_tutorial.ipynb new file mode 100644 index 0000000..f89d2bb --- /dev/null +++ b/tutorials/whippersnappy_tutorial.ipynb @@ -0,0 +1,205 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "raw", + "source": [ + "{\n", + " \"cells\": [\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"# WhipperSnapPy Tutorial\\n\",\n", + " \"\\n\",\n", + " \"This notebook demonstrates static and interactive 3D brain surface visualization\\n\",\n", + " \"using WhipperSnapPy. It covers single-view snapshots (`snap1`), four-view overview\\n\",\n", + " \"images (`snap4`), and interactive WebGL rendering (`plot3d`).\\n\",\n", + " \"\\n\",\n", + " \"**Tutorial data** from the Rhineland Study (Koch et al.),\\n\",\n", + " \"[Zenodo: https://doi.org/10.5281/zenodo.11186582](https://doi.org/10.5281/zenodo.11186582), CC BY 4.0.\\n\",\n", + " \"\\n\",\n", + " \"Files are downloaded once and cached locally — no manual setup required.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"from IPython.display import display\\n\",\n", + " \"\\n\",\n", + " \"from whippersnappy import fetch_sample_subject, snap1, snap4\\n\",\n", + " \"from whippersnappy.utils.types import ViewType\\n\",\n", + " \"\\n\",\n", + " \"# Downloads ~10 MB on first run; subsequent calls use the local cache.\\n\",\n", + " \"data = fetch_sample_subject()\\n\",\n", + " \"print(\\\"Sample subject cached at:\\\", data[\\\"sdir\\\"])\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## snap1 — Basic Single View\\n\",\n", + " \"\\n\",\n", + " \"`snap1` renders a single static view of a surface mesh into a PIL Image.\\n\",\n", + " \"Here we render the left hemisphere with curvature texturing only (no overlay),\\n\",\n", + " \"which gives the classic sulcal depth shading.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"img = snap1(data[\\\"lh_white\\\"], curvpath=data[\\\"lh_curv\\\"])\\n\",\n", + " \"display(img)\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## snap1 — With Thickness Overlay\\n\",\n", + " \"\\n\",\n", + " \"By passing `overlaypath` and `labelpath`, the surface is colored by cortical\\n\",\n", + " \"thickness values, masked to the cortex label. The `view` parameter selects\\n\",\n", + " \"the lateral view of the left hemisphere.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"img = snap1(\\n\",\n", + " \" data[\\\"lh_white\\\"],\\n\",\n", + " \" overlaypath=data[\\\"lh_thickness\\\"],\\n\",\n", + " \" curvpath=data[\\\"lh_curv\\\"],\\n\",\n", + " \" labelpath=data[\\\"lh_label\\\"],\\n\",\n", + " \" view=ViewType.LEFT,\\n\",\n", + " \")\\n\",\n", + " \"display(img)\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## snap1 — With Parcellation Annotation\\n\",\n", + " \"\\n\",\n", + " \"`annotpath` accepts a FreeSurfer `.annot` file and colors each vertex by\\n\",\n", + " \"its parcellation label. This example uses the DKTatlas parcellation.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"img = snap1(\\n\",\n", + " \" data[\\\"lh_white\\\"],\\n\",\n", + " \" annotpath=data[\\\"lh_annot\\\"],\\n\",\n", + " \" curvpath=data[\\\"lh_curv\\\"],\\n\",\n", + " \")\\n\",\n", + " \"display(img)\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## snap4 — Four-View Overview\\n\",\n", + " \"\\n\",\n", + " \"`snap4` renders lateral and medial views of both hemispheres and stitches\\n\",\n", + " \"them into a single composed image. Here we use the parcellation annotation\\n\",\n", + " \"for both hemispheres.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"img = snap4(\\n\",\n", + " \" sdir=data[\\\"sdir\\\"],\\n\",\n", + " \" lhannotpath=data[\\\"lh_annot\\\"],\\n\",\n", + " \" rhannotpath=data[\\\"rh_annot\\\"],\\n\",\n", + " \")\\n\",\n", + " \"display(img)\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"markdown\",\n", + " \"metadata\": {},\n", + " \"source\": [\n", + " \"## plot3d — Interactive 3D Viewer\\n\",\n", + " \"\\n\",\n", + " \"`plot3d` creates an interactive Three.js/WebGL viewer that works in all\\n\",\n", + " \"Jupyter environments. You can rotate, zoom, and pan with the mouse.\\n\",\n", + " \"Requires `pip install 'whippersnappy[notebook]'`.\"\n", + " ]\n", + " },\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"execution_count\": null,\n", + " \"metadata\": {},\n", + " \"outputs\": [],\n", + " \"source\": [\n", + " \"from whippersnappy import plot3d\\n\",\n", + " \"\\n\",\n", + " \"viewer = plot3d(\\n\",\n", + " \" meshpath=data[\\\"lh_white\\\"],\\n\",\n", + " \" curvpath=data[\\\"lh_curv\\\"],\\n\",\n", + " \" overlaypath=data[\\\"lh_thickness\\\"],\\n\",\n", + " \")\\n\",\n", + " \"display(viewer)\"\n", + " ]\n", + " }\n", + " ],\n", + " \"metadata\": {\n", + " \"kernelspec\": {\n", + " \"display_name\": \"Python 3\",\n", + " \"language\": \"python\",\n", + " \"name\": \"python3\"\n", + " },\n", + " \"language_info\": {\n", + " \"name\": \"python\",\n", + " \"version\": \"3.11.0\"\n", + " }\n", + " },\n", + " \"nbformat\": 4,\n", + " \"nbformat_minor\": 5\n", + "}\n" + ], + "id": "4f5ff62ca6bf4fd9" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index ca88081..2e3a41e 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -83,6 +83,7 @@ def _check_display(): from ._config import sys_info # noqa: F401, E402 from ._version import __version__ # noqa: F401, E402 from .snap import snap1, snap4, snap_rotate # noqa: E402 +from .utils.datasets import fetch_sample_subject # noqa: E402 from .utils.types import ViewType # noqa: E402 # 3D plotting for notebooks (Three.js-based, works in all Jupyter environments) @@ -99,9 +100,9 @@ def _check_display(): "snap1", "snap4", "snap_rotate", + "fetch_sample_subject", + "ViewType", ] if _has_plot3d: __all__.append("plot3d") -# Top-level convenience export for frequently used enum -__all__.append("ViewType") diff --git a/whippersnappy/utils/__init__.py b/whippersnappy/utils/__init__.py index dcb119b..26ac282 100644 --- a/whippersnappy/utils/__init__.py +++ b/whippersnappy/utils/__init__.py @@ -1,4 +1,4 @@ """Utils subpackage exports.""" -from . import colormap, image, types +from . import colormap, datasets, image, types -__all__ = ["image", "colormap", "types"] +__all__ = ["colormap", "datasets", "image", "types"] diff --git a/whippersnappy/utils/datasets.py b/whippersnappy/utils/datasets.py new file mode 100644 index 0000000..3facfc2 --- /dev/null +++ b/whippersnappy/utils/datasets.py @@ -0,0 +1,107 @@ +"""Sample dataset download utility for WhipperSnapPy. + +Downloads and caches a small anonymized FreeSurfer subject from the +WhipperSnapPy GitHub release assets for use in tutorials and tests. +""" + +from pathlib import Path + +RELEASE_URL = ( + "https://github.com/Deep-MI/WhipperSnapPy" + "/releases/download/v2.0.0/{file_name}" +) + +# Mapping of relative path inside the subject directory → SHA-256 hash. +# GitHub release assets are flat (no subdirectories), so the URL uses only +# the basename while pooch.retrieve() reconstructs the subdirectory locally. +_FILES = { + "README.md": "sha256:ecb6ddf31cec17f3a8636fc3ecac90099c441228811efed56104e29fcd301bc5", + "surf/lh.white": "sha256:4ab049fb42ca882ba9b56f8fe0d0e8814973e7fa2e0575a794d8e468abf7d62f", + "surf/lh.curv": "sha256:9edbde57be8593cd9d89d9d1124e2175edd8ecfee55d53e066d89700c480b12a", + "surf/lh.thickness": "sha256:40ab3483284608c6c5cca2d3d794a60cd1bcbeb0140bb1ca6ad0fce7962c57c6", + "surf/rh.white": "sha256:43035c53a8b04bebe4e843c34f80588f253f79052a8dbf7194b706495b11f8d2", + "surf/rh.curv": "sha256:af2bc71133d7ef17ce1a3a6f4208d2495a5a4c96da00c80b59be03bb7c8ea83f", + "surf/rh.thickness": "sha256:50ec291c73928cd697156edd9e0e77f5c54d15c56cf84810d2564b496876e132", + "label/lh.aparc.DKTatlas.mapped.annot": "sha256:4d48d33f4fd8278ab973a1552f6ea9c396dfc1791b707ed17ad8e761299c4960", + "label/lh.cortex.label": "sha256:79ae17fcfde6b2e0a75a0652fcc0f3c072e4ea62a541843b7338e01c598b0b6e", + "label/rh.aparc.DKTatlas.mapped.annot": "sha256:12217166d8ef43ee1fa280511ec2ba0796c6885f527a4455b93760acc73ce273", + "label/rh.cortex.label": "sha256:162c97c887eb1ec857fe575b8cc4e4b950c7dd5ec181a581d709bbe7fca58f9e", +} + + +def fetch_sample_subject() -> dict: + """Download and cache the WhipperSnapPy sample subject (Rhineland Study). + + Downloads FreeSurfer surface files for one anonymized subject into the + OS-specific user cache directory and returns a dictionary of paths to + all files. Files are only downloaded once; subsequent calls use the + local cache. + + Returns + ------- + dict + Dictionary with the following keys: + + * ``sdir`` -- path to the subject root directory (``sub-rs/``), + usable directly as the ``sdir`` argument to :func:`~whippersnappy.snap4`. + * ``lh_white`` -- path to ``surf/lh.white``. + * ``lh_curv`` -- path to ``surf/lh.curv``. + * ``lh_thickness`` -- path to ``surf/lh.thickness``. + * ``rh_white`` -- path to ``surf/rh.white``. + * ``rh_curv`` -- path to ``surf/rh.curv``. + * ``rh_thickness`` -- path to ``surf/rh.thickness``. + * ``lh_annot`` -- path to ``label/lh.aparc.DKTatlas.mapped.annot``. + * ``lh_label`` -- path to ``label/lh.cortex.label``. + * ``rh_annot`` -- path to ``label/rh.aparc.DKTatlas.mapped.annot``. + * ``rh_label`` -- path to ``label/rh.cortex.label``. + + Raises + ------ + ImportError + If ``pooch`` is not installed. Install with + ``pip install 'whippersnappy[notebook]'``. + + Notes + ----- + Data from the Rhineland Study (Koch et al.), + https://doi.org/10.5281/zenodo.11186582, CC BY 4.0. + + Examples + -------- + >>> from whippersnappy import fetch_sample_subject + >>> data = fetch_sample_subject() + >>> print(data["sdir"]) + """ + try: + import pooch + except ImportError as e: + raise ImportError( + "fetch_sample_subject() requires pooch. " + "Install with: pip install 'whippersnappy[notebook]'" + ) from e + + base = Path(pooch.os_cache("whippersnappy")) / "sub-rs" + + for rel_path, known_hash in _FILES.items(): + rel = Path(rel_path) + pooch.retrieve( + url=RELEASE_URL.format(file_name=rel.name), + known_hash=known_hash, + fname=rel.name, + path=base / rel.parent, + ) + + return { + "sdir": str(base), + "lh_white": str(base / "surf/lh.white"), + "lh_curv": str(base / "surf/lh.curv"), + "lh_thickness": str(base / "surf/lh.thickness"), + "rh_white": str(base / "surf/rh.white"), + "rh_curv": str(base / "surf/rh.curv"), + "rh_thickness": str(base / "surf/rh.thickness"), + "lh_annot": str(base / "label/lh.aparc.DKTatlas.mapped.annot"), + "lh_label": str(base / "label/lh.cortex.label"), + "rh_annot": str(base / "label/rh.aparc.DKTatlas.mapped.annot"), + "rh_label": str(base / "label/rh.cortex.label"), + } + From a0e8f961f4d37e5ee84d89869a347ca8233e7d0f Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 18:37:23 +0100 Subject: [PATCH 52/83] update tutorial ipynb --- pyproject.toml | 2 +- tutorials/whippersnappy_tutorial.ipynb | 468 +++++++++++++++---------- whippersnappy/utils/datasets.py | 14 +- 3 files changed, 283 insertions(+), 201 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2f95100..e040501 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,9 +70,9 @@ doc = [ 'ipywidgets', ] notebook = [ - 'pooch>=1.6', # For downloading tutorial data files in the notebook 'pythreejs', # Three.js for interactive 3D (works in all Jupyter environments) 'ipywidgets', # Required for pythreejs + 'pooch>=1.6', ] gui = [ 'PyQt6', diff --git a/tutorials/whippersnappy_tutorial.ipynb b/tutorials/whippersnappy_tutorial.ipynb index f89d2bb..0b124ea 100644 --- a/tutorials/whippersnappy_tutorial.ipynb +++ b/tutorials/whippersnappy_tutorial.ipynb @@ -1,203 +1,285 @@ { "cells": [ { + "cell_type": "markdown", + "id": "e2612940", "metadata": {}, - "cell_type": "raw", - "source": [ - "{\n", - " \"cells\": [\n", - " {\n", - " \"cell_type\": \"markdown\",\n", - " \"metadata\": {},\n", - " \"source\": [\n", - " \"# WhipperSnapPy Tutorial\\n\",\n", - " \"\\n\",\n", - " \"This notebook demonstrates static and interactive 3D brain surface visualization\\n\",\n", - " \"using WhipperSnapPy. It covers single-view snapshots (`snap1`), four-view overview\\n\",\n", - " \"images (`snap4`), and interactive WebGL rendering (`plot3d`).\\n\",\n", - " \"\\n\",\n", - " \"**Tutorial data** from the Rhineland Study (Koch et al.),\\n\",\n", - " \"[Zenodo: https://doi.org/10.5281/zenodo.11186582](https://doi.org/10.5281/zenodo.11186582), CC BY 4.0.\\n\",\n", - " \"\\n\",\n", - " \"Files are downloaded once and cached locally — no manual setup required.\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"code\",\n", - " \"execution_count\": null,\n", - " \"metadata\": {},\n", - " \"outputs\": [],\n", - " \"source\": [\n", - " \"from IPython.display import display\\n\",\n", - " \"\\n\",\n", - " \"from whippersnappy import fetch_sample_subject, snap1, snap4\\n\",\n", - " \"from whippersnappy.utils.types import ViewType\\n\",\n", - " \"\\n\",\n", - " \"# Downloads ~10 MB on first run; subsequent calls use the local cache.\\n\",\n", - " \"data = fetch_sample_subject()\\n\",\n", - " \"print(\\\"Sample subject cached at:\\\", data[\\\"sdir\\\"])\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"markdown\",\n", - " \"metadata\": {},\n", - " \"source\": [\n", - " \"## snap1 — Basic Single View\\n\",\n", - " \"\\n\",\n", - " \"`snap1` renders a single static view of a surface mesh into a PIL Image.\\n\",\n", - " \"Here we render the left hemisphere with curvature texturing only (no overlay),\\n\",\n", - " \"which gives the classic sulcal depth shading.\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"code\",\n", - " \"execution_count\": null,\n", - " \"metadata\": {},\n", - " \"outputs\": [],\n", - " \"source\": [\n", - " \"img = snap1(data[\\\"lh_white\\\"], curvpath=data[\\\"lh_curv\\\"])\\n\",\n", - " \"display(img)\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"markdown\",\n", - " \"metadata\": {},\n", - " \"source\": [\n", - " \"## snap1 — With Thickness Overlay\\n\",\n", - " \"\\n\",\n", - " \"By passing `overlaypath` and `labelpath`, the surface is colored by cortical\\n\",\n", - " \"thickness values, masked to the cortex label. The `view` parameter selects\\n\",\n", - " \"the lateral view of the left hemisphere.\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"code\",\n", - " \"execution_count\": null,\n", - " \"metadata\": {},\n", - " \"outputs\": [],\n", - " \"source\": [\n", - " \"img = snap1(\\n\",\n", - " \" data[\\\"lh_white\\\"],\\n\",\n", - " \" overlaypath=data[\\\"lh_thickness\\\"],\\n\",\n", - " \" curvpath=data[\\\"lh_curv\\\"],\\n\",\n", - " \" labelpath=data[\\\"lh_label\\\"],\\n\",\n", - " \" view=ViewType.LEFT,\\n\",\n", - " \")\\n\",\n", - " \"display(img)\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"markdown\",\n", - " \"metadata\": {},\n", - " \"source\": [\n", - " \"## snap1 — With Parcellation Annotation\\n\",\n", - " \"\\n\",\n", - " \"`annotpath` accepts a FreeSurfer `.annot` file and colors each vertex by\\n\",\n", - " \"its parcellation label. This example uses the DKTatlas parcellation.\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"code\",\n", - " \"execution_count\": null,\n", - " \"metadata\": {},\n", - " \"outputs\": [],\n", - " \"source\": [\n", - " \"img = snap1(\\n\",\n", - " \" data[\\\"lh_white\\\"],\\n\",\n", - " \" annotpath=data[\\\"lh_annot\\\"],\\n\",\n", - " \" curvpath=data[\\\"lh_curv\\\"],\\n\",\n", - " \")\\n\",\n", - " \"display(img)\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"markdown\",\n", - " \"metadata\": {},\n", - " \"source\": [\n", - " \"## snap4 — Four-View Overview\\n\",\n", - " \"\\n\",\n", - " \"`snap4` renders lateral and medial views of both hemispheres and stitches\\n\",\n", - " \"them into a single composed image. Here we use the parcellation annotation\\n\",\n", - " \"for both hemispheres.\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"code\",\n", - " \"execution_count\": null,\n", - " \"metadata\": {},\n", - " \"outputs\": [],\n", - " \"source\": [\n", - " \"img = snap4(\\n\",\n", - " \" sdir=data[\\\"sdir\\\"],\\n\",\n", - " \" lhannotpath=data[\\\"lh_annot\\\"],\\n\",\n", - " \" rhannotpath=data[\\\"rh_annot\\\"],\\n\",\n", - " \")\\n\",\n", - " \"display(img)\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"markdown\",\n", - " \"metadata\": {},\n", - " \"source\": [\n", - " \"## plot3d — Interactive 3D Viewer\\n\",\n", - " \"\\n\",\n", - " \"`plot3d` creates an interactive Three.js/WebGL viewer that works in all\\n\",\n", - " \"Jupyter environments. You can rotate, zoom, and pan with the mouse.\\n\",\n", - " \"Requires `pip install 'whippersnappy[notebook]'`.\"\n", - " ]\n", - " },\n", - " {\n", - " \"cell_type\": \"code\",\n", - " \"execution_count\": null,\n", - " \"metadata\": {},\n", - " \"outputs\": [],\n", - " \"source\": [\n", - " \"from whippersnappy import plot3d\\n\",\n", - " \"\\n\",\n", - " \"viewer = plot3d(\\n\",\n", - " \" meshpath=data[\\\"lh_white\\\"],\\n\",\n", - " \" curvpath=data[\\\"lh_curv\\\"],\\n\",\n", - " \" overlaypath=data[\\\"lh_thickness\\\"],\\n\",\n", - " \")\\n\",\n", - " \"display(viewer)\"\n", - " ]\n", - " }\n", - " ],\n", - " \"metadata\": {\n", - " \"kernelspec\": {\n", - " \"display_name\": \"Python 3\",\n", - " \"language\": \"python\",\n", - " \"name\": \"python3\"\n", - " },\n", - " \"language_info\": {\n", - " \"name\": \"python\",\n", - " \"version\": \"3.11.0\"\n", - " }\n", - " },\n", - " \"nbformat\": 4,\n", - " \"nbformat_minor\": 5\n", - "}\n" - ], - "id": "4f5ff62ca6bf4fd9" + "source": [ + "# WhipperSnapPy Tutorial\n", + "\n", + "This notebook demonstrates static and interactive 3D brain surface visualization\n", + "using WhipperSnapPy. It covers single-view snapshots (`snap1`), four-view overview\n", + "images (`snap4`), and interactive WebGL rendering (`plot3d`).\n", + "\n", + "**Tutorial data** from the Rhineland Study (Koch et al.),\n", + "[Zenodo: https://doi.org/10.5281/zenodo.11186582](https://doi.org/10.5281/zenodo.11186582), CC BY 4.0." + ] + }, + { + "cell_type": "markdown", + "id": "6a29b826", + "metadata": {}, + "source": [ + "## Subject Directory\n", + "\n", + "Set `sdir` to your own FreeSurfer subject directory.\n", + "If you leave it empty, the sample subject **sub-rs** (one anonymized subject\n", + "from the Rhineland Study) is downloaded automatically (~20 MB, cached locally\n", + "after the first run)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a328fa6", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from whippersnappy import fetch_sample_subject\n", + "\n", + "# Set sdir to your FreeSurfer subject directory.\n", + "# Leave empty (\"\") to automatically download the sample subject (sub-rs,\n", + "# one anonymized subject from the Rhineland Study, ~20 MB, cached after\n", + "# first download).\n", + "sdir = \"\"\n", + "# sdir = \"/path/to/your/subject\"\n", + "\n", + "if not sdir:\n", + " sdir = fetch_sample_subject()[\"sdir\"] # downloads to OS cache as \"sub-rs/\"\n", + "\n", + "print(\"Subject directory:\", sdir)" + ] + }, + { + "cell_type": "markdown", + "id": "f733ebe4", + "metadata": {}, + "source": [ + "### Derive file paths from `sdir`\n", + "\n", + "All paths are constructed from `sdir`, so switching between subjects only\n", + "requires changing the single variable above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50469f2b", + "metadata": {}, + "outputs": [], + "source": [ + "# Surfaces\n", + "lh_white = os.path.join(sdir, \"surf\", \"lh.white\")\n", + "rh_white = os.path.join(sdir, \"surf\", \"rh.white\")\n", + "\n", + "# Curvature\n", + "lh_curv = os.path.join(sdir, \"surf\", \"lh.curv\")\n", + "rh_curv = os.path.join(sdir, \"surf\", \"rh.curv\")\n", + "\n", + "# Thickness overlay\n", + "lh_thickness = os.path.join(sdir, \"surf\", \"lh.thickness\")\n", + "rh_thickness = os.path.join(sdir, \"surf\", \"rh.thickness\")\n", + "\n", + "# Cortex label (mask for overlay)\n", + "lh_label = os.path.join(sdir, \"label\", \"lh.cortex.label\")\n", + "rh_label = os.path.join(sdir, \"label\", \"rh.cortex.label\")\n", + "\n", + "# Parcellation annotation (DKTatlas)\n", + "lh_annot = os.path.join(sdir, \"label\", \"lh.aparc.DKTatlas.mapped.annot\")\n", + "rh_annot = os.path.join(sdir, \"label\", \"rh.aparc.DKTatlas.mapped.annot\")" + ] + }, + { + "cell_type": "markdown", + "id": "15372c3c", + "metadata": {}, + "source": [ + "## snap1 — Basic Single View\n", + "\n", + "`snap1` renders a single static view of a surface mesh into a PIL Image.\n", + "Here we render the left hemisphere with curvature texturing only (no overlay),\n", + "which gives the classic sulcal depth shading." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b06b95de", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display\n", + "from whippersnappy import snap1\n", + "\n", + "img = snap1(lh_white, curvpath=lh_curv)\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "8cc222d3", + "metadata": {}, + "source": [ + "## snap1 — With Thickness Overlay\n", + "\n", + "By passing `overlaypath` and `labelpath`, the surface is colored by cortical\n", + "thickness values, masked to the cortex label. The `view` parameter selects\n", + "the lateral view of the left hemisphere." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc3fb875", + "metadata": {}, + "outputs": [], + "source": [ + "from whippersnappy.utils.types import ViewType\n", + "\n", + "img = snap1(\n", + " lh_white,\n", + " overlaypath=lh_thickness,\n", + " curvpath=lh_curv,\n", + " labelpath=lh_label,\n", + " view=ViewType.LEFT,\n", + ")\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "07bc4394", + "metadata": {}, + "source": [ + "## snap1 — With Parcellation Annotation\n", + "\n", + "`annotpath` accepts a FreeSurfer `.annot` file and colors each vertex by\n", + "its parcellation label. This example uses the DKTatlas parcellation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e99cbdaa", + "metadata": {}, + "outputs": [], + "source": [ + "img = snap1(\n", + " lh_white,\n", + " annotpath=lh_annot,\n", + " curvpath=lh_curv,\n", + ")\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "8a84ffc4", + "metadata": {}, + "source": [ + "## snap4 — Four-View Overview\n", + "\n", + "`snap4` renders lateral and medial views of both hemispheres and stitches\n", + "them into a single composed image. Here we color both hemispheres by\n", + "cortical thickness, masked to the cortex label." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef4de420", + "metadata": {}, + "outputs": [], + "source": [ + "from whippersnappy import snap4\n", + "\n", + "img = snap4(\n", + " sdir=sdir,\n", + " lhoverlaypath=lh_thickness,\n", + " rhoverlaypath=rh_thickness,\n", + " colorbar=True,\n", + " caption=\"Cortical Thickness (mm)\",\n", + ")\n", + "display(img)" + ] + }, + { + "cell_type": "markdown", + "id": "f1699eee", + "metadata": {}, + "source": [ + "## plot3d — Interactive 3D Viewer\n", + "\n", + "`plot3d` creates an interactive Three.js/WebGL viewer that works in all\n", + "Jupyter environments. You can rotate, zoom, and pan with the mouse.\n", + "Requires `pip install 'whippersnappy[notebook]'`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7ef3287", + "metadata": {}, + "outputs": [], + "source": [ + "from whippersnappy import plot3d\n", + "\n", + "viewer = plot3d(\n", + " meshpath=lh_white,\n", + " curvpath=lh_curv,\n", + " overlaypath=lh_thickness,\n", + ")\n", + "display(viewer)" + ] + }, + { + "cell_type": "markdown", + "id": "6ded04e8", + "metadata": {}, + "source": [ + "## snap_rotate — Rotating 360° Animation\n", + "\n", + "`snap_rotate` renders a full 360° rotation of the surface. We output an\n", + "animated GIF so it displays inline in all Jupyter environments including\n", + "PyCharm. Use `.mp4` as `outpath` instead for a smaller file when playing\n", + "outside the notebook.\n", + "This cell takes the longest to run — execute it last." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ac0c275", + "metadata": {}, + "outputs": [], + "source": [ + "from whippersnappy import snap_rotate\n", + "from IPython.display import Image\n", + "\n", + "outpath_gif = \"/tmp/lh_thickness_rotate.gif\"\n", + "\n", + "snap_rotate(\n", + " meshpath=lh_white,\n", + " outpath=outpath_gif,\n", + " overlaypath=lh_thickness,\n", + " curvpath=lh_curv,\n", + " labelpath=lh_label,\n", + " n_frames=72,\n", + " fps=24,\n", + " width=800,\n", + " height=600,\n", + ")\n", + "print(\"GIF saved to:\", outpath_gif)\n", + "Image(filename=outpath_gif)" + ] } ], "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" } }, "nbformat": 4, diff --git a/whippersnappy/utils/datasets.py b/whippersnappy/utils/datasets.py index 3facfc2..ed1a47e 100644 --- a/whippersnappy/utils/datasets.py +++ b/whippersnappy/utils/datasets.py @@ -15,13 +15,13 @@ # GitHub release assets are flat (no subdirectories), so the URL uses only # the basename while pooch.retrieve() reconstructs the subdirectory locally. _FILES = { - "README.md": "sha256:ecb6ddf31cec17f3a8636fc3ecac90099c441228811efed56104e29fcd301bc5", - "surf/lh.white": "sha256:4ab049fb42ca882ba9b56f8fe0d0e8814973e7fa2e0575a794d8e468abf7d62f", - "surf/lh.curv": "sha256:9edbde57be8593cd9d89d9d1124e2175edd8ecfee55d53e066d89700c480b12a", - "surf/lh.thickness": "sha256:40ab3483284608c6c5cca2d3d794a60cd1bcbeb0140bb1ca6ad0fce7962c57c6", - "surf/rh.white": "sha256:43035c53a8b04bebe4e843c34f80588f253f79052a8dbf7194b706495b11f8d2", - "surf/rh.curv": "sha256:af2bc71133d7ef17ce1a3a6f4208d2495a5a4c96da00c80b59be03bb7c8ea83f", - "surf/rh.thickness": "sha256:50ec291c73928cd697156edd9e0e77f5c54d15c56cf84810d2564b496876e132", + "README.md": "sha256:ecb6ddf31cec17f3a8636fc3ecac90099c441228811efed56104e29fcd301bc5", + "surf/lh.white": "sha256:4ab049fb42ca882ba9b56f8fe0d0e8814973e7fa2e0575a794d8e468abf7d62f", + "surf/lh.curv": "sha256:9edbde57be8593cd9d89d9d1124e2175edd8ecfee55d53e066d89700c480b12a", + "surf/lh.thickness": "sha256:40ab3483284608c6c5cca2d3d794a60cd1bcbeb0140bb1ca6ad0fce7962c57c6", + "surf/rh.white": "sha256:43035c53a8b04bebe4e843c34f80588f253f79052a8dbf7194b706495b11f8d2", + "surf/rh.curv": "sha256:af2bc71133d7ef17ce1a3a6f4208d2495a5a4c96da00c80b59be03bb7c8ea83f", + "surf/rh.thickness": "sha256:50ec291c73928cd697156edd9e0e77f5c54d15c56cf84810d2564b496876e132", "label/lh.aparc.DKTatlas.mapped.annot": "sha256:4d48d33f4fd8278ab973a1552f6ea9c396dfc1791b707ed17ad8e761299c4960", "label/lh.cortex.label": "sha256:79ae17fcfde6b2e0a75a0652fcc0f3c072e4ea62a541843b7338e01c598b0b6e", "label/rh.aparc.DKTatlas.mapped.annot": "sha256:12217166d8ef43ee1fa280511ec2ba0796c6885f527a4455b93760acc73ce273", From dd880efb4dfe9c73d79f0d3a884b965397ca015a Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 19:55:36 +0100 Subject: [PATCH 53/83] sphinx build update --- DOCKER.md | 2 +- README.md | 188 ++++++++----------------- doc/api/cli.rst | 23 +++ doc/api/index.rst | 18 ++- doc/api/plot3d.rst | 6 + doc/api/snap.rst | 6 + doc/conf.py | 25 +++- doc/docker.rst | 4 - doc/index.rst | 4 +- doc/tutorials/index.rst | 12 ++ tutorials/whippersnappy_tutorial.ipynb | 183 +++++++----------------- whippersnappy/cli/whippersnap.py | 28 +++- whippersnappy/cli/whippersnap1.py | 80 ++++++++++- whippersnappy/cli/whippersnap4.py | 34 +++++ whippersnappy/utils/datasets.py | 45 ++++-- 15 files changed, 361 insertions(+), 297 deletions(-) create mode 100644 doc/api/cli.rst create mode 100644 doc/api/plot3d.rst create mode 100644 doc/api/snap.rst delete mode 100644 doc/docker.rst create mode 100644 doc/tutorials/index.rst diff --git a/DOCKER.md b/DOCKER.md index c8b4b7b..dc7cd44 100644 --- a/DOCKER.md +++ b/DOCKER.md @@ -1,4 +1,4 @@ -# WhipperSnapPy — Docker Guide +# Docker Guide The Docker image provides a fully headless rendering environment with EGL off-screen support. No display server or `xvfb` is required. diff --git a/README.md b/README.md index 799c397..bff6885 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,13 @@ # WhipperSnapPy -WhipperSnapPy is a Python OpenGL program to render FreeSurfer and -FastSurfer surface models with color overlays or parcellations and generate -screenshots. +WhipperSnapPy is a Python/OpenGL tool to render FreeSurfer and FastSurfer +surface models with color overlays or parcellations and generate screenshots +— from the command line, in Jupyter notebooks, or via a desktop GUI. -## Contents: - -- `snap1` — single-view surface snapshot -- `snap4` — four-view composed image (lateral/medial, both hemispheres) -- `snap_rotate` — 360° rotation video (MP4, WebM, or GIF) -- `plot3d` — interactive 3D WebGL viewer for Jupyter notebooks -- `whippersnap` — desktop GUI with live Qt controls - -## Installation: - -The `WhipperSnapPy` package can be installed from PyPI via: +## Installation ```bash -python3 -m pip install whippersnappy +pip install whippersnappy ``` For rotation video support (MP4/WebM): @@ -41,151 +31,91 @@ pip install 'whippersnappy[notebook]' Off-screen (headless) rendering is supported natively via EGL on Linux — no `xvfb` required. See the Docker guide for headless usage. -## Usage: +## Command-Line Usage -### Local: +After installation the following commands are available: -After installing the Python package, the command-line tools can be run as in -the following examples: +### Four-view snapshot (`whippersnap4`) -```bash -# Four-view batch rendering (both hemispheres) -whippersnap4 -lh $LH_OVERLAY -rh $RH_OVERLAY \ - -sd $SURF_SUBJECT_DIR \ - --fmax 4 --fthresh 2 --invert \ - --caption "My caption" \ - -o $OUTPUT_DIR/snap4.png - -# Single-view snapshot -whippersnap1 $SURF_SUBJECT_DIR/surf/lh.white \ - --overlay $LH_OVERLAY \ - --view left -o $OUTPUT_DIR/snap1.png +Renders lateral and medial views of both hemispheres into a single composed image: -# 360° rotation video -whippersnap1 $SURF_SUBJECT_DIR/surf/lh.white \ - --overlay $LH_OVERLAY \ - --rotate -o $OUTPUT_DIR/rotation.mp4 +```bash +whippersnap4 -lh $LH_OVERLAY \ + -rh $RH_OVERLAY \ + -sd $SUBJECT_DIR \ + --fmax 4 --fthresh 2 \ + --caption "Cortical Thickness" \ + -o snap4.png ``` -For more options see `whippersnap4 --help` or `whippersnap1 --help`. +### Single-view snapshot (`whippersnap1`) -## Quick Imports +Renders one surface view: -```python -from whippersnappy import snap1, snap4, snap_rotate, plot3d +```bash +whippersnap1 $SUBJECT_DIR/surf/lh.white \ + --overlay $LH_OVERLAY \ + --view left \ + -o snap1.png ``` -### Jupyter Notebooks: - -WhipperSnapPy supports both static and **fully interactive 3D visualization** in Jupyter notebooks. +### Rotation video (`whippersnap1 --rotate`) -#### Interactive 3D Plotting - -For **interactive mouse-controlled 3D rendering**: +Renders a 360° animation: ```bash -pip install 'whippersnappy[notebook]' -``` - -```python -from whippersnappy import plot3d -from IPython.display import display - -viewer = plot3d( - meshpath='/path/to/surf/lh.white', - curvpath='/path/to/surf/lh.curv', # curvature - overlaypath='/path/to/surf/lh.thickness', # optional: for colored overlays - labelpath='/path/to/label/lh.cortex', # optional: for masking - minval=0.0, - maxval=5.5, - width=800, - height=800, -) -display(viewer) +whippersnap1 $SUBJECT_DIR/surf/lh.white \ + --overlay $LH_OVERLAY \ + --rotate \ + -o rotation.mp4 ``` -**Features:** -- ✅ Works in ALL Jupyter environments (browser, JupyterLab, Colab, VS Code) -- ✅ Mouse-controlled rotation, zoom, and pan -- ✅ Professional lighting (Three.js/WebGL) -- ✅ Supports overlays, annotations, and curvature -- ✅ Same technology Plotly uses for 3D plots +### Desktop GUI (`whippersnap`) -#### Static Rendering +Launches an interactive Qt window with live threshold controls: -For static publication-quality images: - -```python -from whippersnappy import snap1 -from whippersnappy.utils.types import ViewType -from IPython.display import display - -img = snap1( - meshpath='/path/to/surf/lh.white', - overlaypath='/path/to/surf/lh.thickness', - curvpath='/path/to/surf/lh.curv', - view=ViewType.LEFT, # or RIGHT, FRONT, BACK, TOP, BOTTOM - width=800, - height=800, - brain_scale=1.5, - specular=True, -) -display(img) +```bash +pip install 'whippersnappy[gui]' +whippersnap --lh $LH_OVERLAY --sdir $SUBJECT_DIR ``` -**Benefits:** -- ✅ Full PyOpenGL control for custom lighting -- ✅ Publication-quality output -- ✅ Fast performance -- ✅ Identical to GUI version - -See `tutorials/whippersnappy_tutorial.ipynb` for complete examples. +For all options run `whippersnap4 --help`, `whippersnap1 --help`, or `whippersnap --help`. -### Desktop GUI: +## Python API -For interactive desktop visualization with Qt controls: - -```bash -whippersnap -lh /path/to/lh.thickness -sd /path/to/subject +```python +from whippersnappy import snap1, snap4, snap_rotate, plot3d ``` -This launches a native desktop GUI with a live OpenGL window and a -configuration panel for adjusting overlay thresholds at runtime. -Requires `pip install 'whippersnappy[gui]'`. +| Function | Description | +|---|---| +| `snap1` | Single-view surface snapshot → PIL Image | +| `snap4` | Four-view composed image (lateral/medial, both hemispheres) | +| `snap_rotate` | 360° rotation video (MP4, WebM, or GIF) | +| `plot3d` | Interactive 3D WebGL viewer for Jupyter notebooks | -### Docker: +### Example -The Docker image provides a fully headless EGL rendering environment — no -display server or `xvfb` required. - -Build the image: - -```bash -docker build --rm -t whippersnappy -f Dockerfile . +```python +from whippersnappy import snap4 +img = snap4(sdir='/path/to/subject', + lhoverlaypath='/path/to/lh.thickness', + rhoverlaypath='/path/to/rh.thickness', + colorbar=True, caption='Cortical Thickness (mm)') +img.save('snap4.png') ``` -Run a four-view batch snapshot: +See `tutorials/whippersnappy_tutorial.ipynb` for complete notebook examples. -```bash -docker run --rm --init \ - -v /path/to/subject:/subject \ - -v /path/to/output:/output \ - --user $(id -u):$(id -g) \ - whippersnappy \ - -lh /subject/surf/lh.thickness \ - -rh /subject/surf/rh.thickness \ - -sd /subject \ - -o /output/snap4.png -``` - -For single-view snapshots, rotation videos, annotation overlays, custom -thresholds, and more examples see DOCKER.md. +## Docker +The Docker image provides a fully headless EGL rendering environment — no +display server or `xvfb` required. See DOCKER.md for details. ## API Documentation -The API Documentation can be found at https://deep-mi.org/WhipperSnapPy . +https://deep-mi.org/WhipperSnapPy -## Links: +## Links -We also invite you to check out our lab webpage at https://deep-mi.org +Lab webpage: https://deep-mi.org diff --git a/doc/api/cli.rst b/doc/api/cli.rst new file mode 100644 index 0000000..6363ba2 --- /dev/null +++ b/doc/api/cli.rst @@ -0,0 +1,23 @@ +Command-Line Interfaces +======================= + +whippersnap1 +------------ + +.. automodule:: whippersnappy.cli.whippersnap1 + :members: + :member-order: bysource + +whippersnap4 +------------ + +.. automodule:: whippersnappy.cli.whippersnap4 + :members: + :member-order: bysource + +whippersnap +----------- + +.. automodule:: whippersnappy.cli.whippersnap + :members: + :member-order: bysource diff --git a/doc/api/index.rst b/doc/api/index.rst index 3083ecc..deed527 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -1,15 +1,13 @@ -API References -============== +.. _api_ref: +API Reference +============= .. currentmodule:: whippersnappy -.. autosummary:: - :toctree: generated/ +.. toctree:: + :maxdepth: 2 - - snap - plot3d - gui.config_app.ConfigWindow - cli.whippersnap - + cli.rst + snap.rst + plot3d.rst diff --git a/doc/api/plot3d.rst b/doc/api/plot3d.rst new file mode 100644 index 0000000..3546ce2 --- /dev/null +++ b/doc/api/plot3d.rst @@ -0,0 +1,6 @@ +plot3d +====== + +.. automodule:: whippersnappy.plot3d + :members: + :member-order: bysource diff --git a/doc/api/snap.rst b/doc/api/snap.rst new file mode 100644 index 0000000..5c8eea8 --- /dev/null +++ b/doc/api/snap.rst @@ -0,0 +1,6 @@ +snap +==== + +.. automodule:: whippersnappy.snap + :members: + :member-order: bysource diff --git a/doc/conf.py b/doc/conf.py index f274fc8..4707b81 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -45,6 +45,8 @@ "sphinxcontrib.bibtex", "sphinx_copybutton", "sphinx_design", + "nbsphinx", + "IPython.sphinxext.ipython_console_highlighting", ] # .md files are included via '.. include:: :parser: myst_parser.sphinx_' @@ -56,8 +58,8 @@ "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints", - "*.md", # exclude symlinked .md files inside doc/ - "../*.md", # exclude root-level .md files + "README.md", # symlinked from root — included inline via rst, not as a page + "../*.md", # exclude root-level .md files ] templates_path = ["_templates"] @@ -99,7 +101,8 @@ } # -- autosummary ------------------------------------------------------------- -autosummary_generate = True +# API stubs use automodule directly — no generated/ dir needed. +autosummary_generate = False # -- autodoc ----------------------------------------------------------------- autodoc_typehints = "none" @@ -188,6 +191,9 @@ r"\.__iter__", r"\.__div__", r"\.__neg__", + # Imported third-party objects exposed in plot3d module + r"\.HTML$", + r"\.VBox$", } # -- sphinxcontrib-bibtex ---------------------------------------------------- @@ -258,6 +264,19 @@ def linkcode_resolve(domain: str, info: Dict[str, str]) -> Optional[str]: # } +# -- nbsphinx ---------------------------------------------------------------- +# Re-execute notebooks during the Sphinx build so outputs appear in the docs. +# Notebooks are executed with the kernel specified by nbsphinx_kernel_name. +# The sample data is fetched from the GitHub release assets (or from the +# local sub-rs/ directory in the repo when the release is not yet published). +nbsphinx_execute = "auto" + +# Kernel to use for execution (must be installed: pip install ipykernel). +nbsphinx_kernel_name = "python3" + +# Maximum execution time per cell (seconds). +nbsphinx_timeout = 600 + # -- make sure pandoc gets installed ----------------------------------------- from inspect import getsourcefile import os diff --git a/doc/docker.rst b/doc/docker.rst deleted file mode 100644 index 467f6eb..0000000 --- a/doc/docker.rst +++ /dev/null @@ -1,4 +0,0 @@ -.. _docker: - -.. include:: DOCKER.md - :parser: myst_parser.sphinx_ diff --git a/doc/index.rst b/doc/index.rst index 3a32dc5..ed11c86 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -6,6 +6,6 @@ .. toctree:: :hidden: - docker + DOCKER.md + tutorials/index api/index - diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst new file mode 100644 index 0000000..aa08c0f --- /dev/null +++ b/doc/tutorials/index.rst @@ -0,0 +1,12 @@ +Tutorials +========= + +Hands-on notebooks demonstrating WhipperSnapPy's surface visualization +functionality. Each notebook can be downloaded and run locally — just +replace ``sdir = ""`` with the path to your own FastSurfer subject directory. + +.. toctree:: + :maxdepth: 1 + + whippersnappy_tutorial.ipynb + diff --git a/tutorials/whippersnappy_tutorial.ipynb b/tutorials/whippersnappy_tutorial.ipynb index 0b124ea..f779cce 100644 --- a/tutorials/whippersnappy_tutorial.ipynb +++ b/tutorials/whippersnappy_tutorial.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "e2612940", + "id": "2c795daa", "metadata": {}, "source": [ "# WhipperSnapPy Tutorial\n", @@ -17,7 +17,7 @@ }, { "cell_type": "markdown", - "id": "6a29b826", + "id": "8bd0a771", "metadata": {}, "source": [ "## Subject Directory\n", @@ -31,29 +31,14 @@ { "cell_type": "code", "execution_count": null, - "id": "0a328fa6", + "id": "30c14cc9", "metadata": {}, "outputs": [], - "source": [ - "import os\n", - "from whippersnappy import fetch_sample_subject\n", - "\n", - "# Set sdir to your FreeSurfer subject directory.\n", - "# Leave empty (\"\") to automatically download the sample subject (sub-rs,\n", - "# one anonymized subject from the Rhineland Study, ~20 MB, cached after\n", - "# first download).\n", - "sdir = \"\"\n", - "# sdir = \"/path/to/your/subject\"\n", - "\n", - "if not sdir:\n", - " sdir = fetch_sample_subject()[\"sdir\"] # downloads to OS cache as \"sub-rs/\"\n", - "\n", - "print(\"Subject directory:\", sdir)" - ] + "source": "import os\nfrom whippersnappy import fetch_sample_subject\n\n# Set sdir to your FreeSurfer subject directory.\n# Leave empty (\"\") to automatically use the sample subject (sub-rs,\n# one anonymized subject from the Rhineland Study). It is used directly\n# from the repository when available, otherwise downloaded (~20 MB) and\n# cached locally after the first run.\nsdir = \"\"\n# sdir = \"/path/to/your/subject\"\n\nif not sdir:\n sdir = fetch_sample_subject()[\"sdir\"]\n\nprint(\"Subject directory:\", sdir)\n" }, { "cell_type": "markdown", - "id": "f733ebe4", + "id": "73dd4d58", "metadata": {}, "source": [ "### Derive file paths from `sdir`\n", @@ -65,37 +50,17 @@ { "cell_type": "code", "execution_count": null, - "id": "50469f2b", + "id": "ec559dcd", "metadata": {}, "outputs": [], - "source": [ - "# Surfaces\n", - "lh_white = os.path.join(sdir, \"surf\", \"lh.white\")\n", - "rh_white = os.path.join(sdir, \"surf\", \"rh.white\")\n", - "\n", - "# Curvature\n", - "lh_curv = os.path.join(sdir, \"surf\", \"lh.curv\")\n", - "rh_curv = os.path.join(sdir, \"surf\", \"rh.curv\")\n", - "\n", - "# Thickness overlay\n", - "lh_thickness = os.path.join(sdir, \"surf\", \"lh.thickness\")\n", - "rh_thickness = os.path.join(sdir, \"surf\", \"rh.thickness\")\n", - "\n", - "# Cortex label (mask for overlay)\n", - "lh_label = os.path.join(sdir, \"label\", \"lh.cortex.label\")\n", - "rh_label = os.path.join(sdir, \"label\", \"rh.cortex.label\")\n", - "\n", - "# Parcellation annotation (DKTatlas)\n", - "lh_annot = os.path.join(sdir, \"label\", \"lh.aparc.DKTatlas.mapped.annot\")\n", - "rh_annot = os.path.join(sdir, \"label\", \"rh.aparc.DKTatlas.mapped.annot\")" - ] + "source": "# Surfaces\nlh_white = os.path.join(sdir, \"surf\", \"lh.white\")\nrh_white = os.path.join(sdir, \"surf\", \"rh.white\")\n\n# Curvature\nlh_curv = os.path.join(sdir, \"surf\", \"lh.curv\")\nrh_curv = os.path.join(sdir, \"surf\", \"rh.curv\")\n\n# Thickness overlay\nlh_thickness = os.path.join(sdir, \"surf\", \"lh.thickness\")\nrh_thickness = os.path.join(sdir, \"surf\", \"rh.thickness\")\n\n# Cortex label (mask for overlay)\nlh_label = os.path.join(sdir, \"label\", \"lh.cortex.label\")\nrh_label = os.path.join(sdir, \"label\", \"rh.cortex.label\")\n\n# Parcellation annotation (DKTatlas)\nlh_annot = os.path.join(sdir, \"label\", \"lh.aparc.DKTatlas.mapped.annot\")\nrh_annot = os.path.join(sdir, \"label\", \"rh.aparc.DKTatlas.mapped.annot\")\n" }, { "cell_type": "markdown", - "id": "15372c3c", + "id": "22be9ebf", "metadata": {}, "source": [ - "## snap1 — Basic Single View\n", + "## snap1 \u2014 Basic Single View\n", "\n", "`snap1` renders a single static view of a surface mesh into a PIL Image.\n", "Here we render the left hemisphere with curvature texturing only (no overlay),\n", @@ -105,23 +70,17 @@ { "cell_type": "code", "execution_count": null, - "id": "b06b95de", + "id": "783e547b", "metadata": {}, "outputs": [], - "source": [ - "from IPython.display import display\n", - "from whippersnappy import snap1\n", - "\n", - "img = snap1(lh_white, curvpath=lh_curv)\n", - "display(img)" - ] + "source": "from IPython.display import display\nfrom whippersnappy import snap1\n\nimg = snap1(lh_white, curvpath=lh_curv)\ndisplay(img)\n" }, { "cell_type": "markdown", - "id": "8cc222d3", + "id": "7173d312", "metadata": {}, "source": [ - "## snap1 — With Thickness Overlay\n", + "## snap1 \u2014 With Thickness Overlay\n", "\n", "By passing `overlaypath` and `labelpath`, the surface is colored by cortical\n", "thickness values, masked to the cortex label. The `view` parameter selects\n", @@ -131,28 +90,17 @@ { "cell_type": "code", "execution_count": null, - "id": "cc3fb875", + "id": "13407b67", "metadata": {}, "outputs": [], - "source": [ - "from whippersnappy.utils.types import ViewType\n", - "\n", - "img = snap1(\n", - " lh_white,\n", - " overlaypath=lh_thickness,\n", - " curvpath=lh_curv,\n", - " labelpath=lh_label,\n", - " view=ViewType.LEFT,\n", - ")\n", - "display(img)" - ] + "source": "from whippersnappy.utils.types import ViewType\n\nimg = snap1(\n lh_white,\n overlaypath=lh_thickness,\n curvpath=lh_curv,\n labelpath=lh_label,\n view=ViewType.LEFT,\n)\ndisplay(img)\n" }, { "cell_type": "markdown", - "id": "07bc4394", + "id": "4217c291", "metadata": {}, "source": [ - "## snap1 — With Parcellation Annotation\n", + "## snap1 \u2014 With Parcellation Annotation\n", "\n", "`annotpath` accepts a FreeSurfer `.annot` file and colors each vertex by\n", "its parcellation label. This example uses the DKTatlas parcellation." @@ -161,24 +109,17 @@ { "cell_type": "code", "execution_count": null, - "id": "e99cbdaa", + "id": "7271e902", "metadata": {}, "outputs": [], - "source": [ - "img = snap1(\n", - " lh_white,\n", - " annotpath=lh_annot,\n", - " curvpath=lh_curv,\n", - ")\n", - "display(img)" - ] + "source": "img = snap1(\n lh_white,\n annotpath=lh_annot,\n curvpath=lh_curv,\n)\ndisplay(img)\n" }, { "cell_type": "markdown", - "id": "8a84ffc4", + "id": "620c6c43", "metadata": {}, "source": [ - "## snap4 — Four-View Overview\n", + "## snap4 \u2014 Four-View Overview\n", "\n", "`snap4` renders lateral and medial views of both hemispheres and stitches\n", "them into a single composed image. Here we color both hemispheres by\n", @@ -188,28 +129,17 @@ { "cell_type": "code", "execution_count": null, - "id": "ef4de420", + "id": "903514b8", "metadata": {}, "outputs": [], - "source": [ - "from whippersnappy import snap4\n", - "\n", - "img = snap4(\n", - " sdir=sdir,\n", - " lhoverlaypath=lh_thickness,\n", - " rhoverlaypath=rh_thickness,\n", - " colorbar=True,\n", - " caption=\"Cortical Thickness (mm)\",\n", - ")\n", - "display(img)" - ] + "source": "from whippersnappy import snap4\n\nimg = snap4(\n sdir=sdir,\n lhoverlaypath=lh_thickness,\n rhoverlaypath=rh_thickness,\n colorbar=True,\n caption=\"Cortical Thickness (mm)\",\n)\ndisplay(img)\n" }, { "cell_type": "markdown", - "id": "f1699eee", + "id": "5d98c87b", "metadata": {}, "source": [ - "## plot3d — Interactive 3D Viewer\n", + "## plot3d \u2014 Interactive 3D Viewer\n", "\n", "`plot3d` creates an interactive Three.js/WebGL viewer that works in all\n", "Jupyter environments. You can rotate, zoom, and pan with the mouse.\n", @@ -219,60 +149,42 @@ { "cell_type": "code", "execution_count": null, - "id": "d7ef3287", + "id": "5d5a09c7", "metadata": {}, "outputs": [], - "source": [ - "from whippersnappy import plot3d\n", - "\n", - "viewer = plot3d(\n", - " meshpath=lh_white,\n", - " curvpath=lh_curv,\n", - " overlaypath=lh_thickness,\n", - ")\n", - "display(viewer)" - ] + "source": "from whippersnappy import plot3d\n\nviewer = plot3d(\n meshpath=lh_white,\n curvpath=lh_curv,\n overlaypath=lh_thickness,\n)\ndisplay(viewer)\n" }, { "cell_type": "markdown", - "id": "6ded04e8", + "id": "dc3970c3", "metadata": {}, "source": [ - "## snap_rotate — Rotating 360° Animation\n", + "## snap_rotate \u2014 Rotating 360\u00b0 Animation\n", "\n", - "`snap_rotate` renders a full 360° rotation of the surface. We output an\n", + "`snap_rotate` renders a full 360\u00b0 rotation of the surface. We output an\n", "animated GIF so it displays inline in all Jupyter environments including\n", "PyCharm. Use `.mp4` as `outpath` instead for a smaller file when playing\n", "outside the notebook.\n", - "This cell takes the longest to run — execute it last." + "This cell takes the longest to run \u2014 execute it last." ] }, { "cell_type": "code", "execution_count": null, - "id": "1ac0c275", + "id": "928a68ea", "metadata": {}, "outputs": [], - "source": [ - "from whippersnappy import snap_rotate\n", - "from IPython.display import Image\n", - "\n", - "outpath_gif = \"/tmp/lh_thickness_rotate.gif\"\n", - "\n", - "snap_rotate(\n", - " meshpath=lh_white,\n", - " outpath=outpath_gif,\n", - " overlaypath=lh_thickness,\n", - " curvpath=lh_curv,\n", - " labelpath=lh_label,\n", - " n_frames=72,\n", - " fps=24,\n", - " width=800,\n", - " height=600,\n", - ")\n", - "print(\"GIF saved to:\", outpath_gif)\n", - "Image(filename=outpath_gif)" - ] + "source": "from whippersnappy import snap_rotate\nfrom IPython.display import Image\n\noutpath_gif = \"/tmp/lh_thickness_rotate.gif\"\n\nsnap_rotate(\n meshpath=lh_white,\n outpath=outpath_gif,\n overlaypath=lh_thickness,\n curvpath=lh_curv,\n labelpath=lh_label,\n n_frames=72,\n fps=24,\n width=800,\n height=600,\n)\nprint(\"GIF saved to:\", outpath_gif)\n" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcd38db4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": "Image(filename=outpath_gif)\n" } ], "metadata": { @@ -280,6 +192,15 @@ "cell_metadata_filter": "-all", "main_language": "python", "notebook_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3" } }, "nbformat": 4, diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 940bc46..4c839fd 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -190,9 +190,35 @@ def run(): Raises ------ RuntimeError - If PyQt6 is not installed. + If PyQt6 is not installed (``pip install 'whippersnappy[gui]'``). ValueError For invalid or mutually exclusive argument combinations. + + Notes + ----- + **Required:** + + * ``-sd`` / ``--sdir`` — subject directory containing ``surf/`` and + ``label/`` subdirectories. + * One of the following (not both): + + * ``-lh`` / ``--lh_overlay`` **and** ``-rh`` / ``--rh_overlay`` — per-vertex + scalar overlay files. + * ``--lh_annot`` **and** ``--rh_annot`` — FreeSurfer ``.annot`` + parcellation files. + + **Optional:** + + * ``-s`` / ``--surf_name`` — surface basename (e.g. ``white``); + auto-detected if not provided. + * ``-c`` / ``--caption`` — caption text shown in the viewer. + * ``--fthresh`` — overlay threshold (default: 2.0); adjustable live in the GUI. + * ``--fmax`` — overlay saturation (default: 4.0); adjustable live in the GUI. + * ``--invert`` — invert the color scale. + * ``--diffuse`` — use diffuse-only shading (no specular highlights). + + Requires ``pip install 'whippersnappy[gui]'``. + For non-interactive four-view batch rendering use ``whippersnap4``. """ global current_fthresh_, current_fmax_, app_, app_window_ logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") diff --git a/whippersnappy/cli/whippersnap1.py b/whippersnappy/cli/whippersnap1.py index 0a3583b..4d2b479 100644 --- a/whippersnappy/cli/whippersnap1.py +++ b/whippersnappy/cli/whippersnap1.py @@ -1,5 +1,35 @@ #!/usr/bin/env python3 -"""CLI entry point for single-mesh snapshot via snap1.""" +"""CLI entry point for single-mesh snapshot and rotation video via snap1/snap_rotate. + +Renders a single surface mesh from a chosen viewpoint and saves it as a PNG +image. Alternatively, pass ``--rotate`` to produce a full 360° rotation video +(MP4, WebM, or GIF). + +Usage:: + + # Static single-view snapshot (lateral view, thickness overlay) + whippersnap1 /surf/lh.white \\ + --overlay /surf/lh.thickness \\ + --curv /surf/lh.curv \\ + --label /label/lh.cortex.label \\ + --view left --fthresh 1.5 --fmax 4.0 \\ + -o snap1.png + + # 360° rotation video + whippersnap1 /surf/lh.white \\ + --overlay /surf/lh.thickness \\ + --rotate --rotate-frames 72 --rotate-fps 24 \\ + -o rotation.mp4 + + # Parcellation annotation + whippersnap1 /surf/lh.white \\ + --annot /label/lh.aparc.annot \\ + --view lateral -o snap_annot.png + +See ``whippersnap1 --help`` for the full list of options. +For four-view batch rendering use ``whippersnap4``. +For the interactive GUI use ``whippersnap``. +""" import argparse import logging @@ -16,7 +46,53 @@ def run(): - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + """Command-line entry point for single-view snapshot or rotation video. + + Parses command-line arguments, validates them, and calls either + :func:`whippersnappy.snap1` (static snapshot) or + :func:`whippersnappy.snap_rotate` (360° rotation video) depending on + whether ``--rotate`` is passed. + + Parameters + ---------- + None + All input is read from ``sys.argv`` via :mod:`argparse`. + + Raises + ------ + FileNotFoundError + If the mesh file or any overlay/annotation/label file cannot be found. + RuntimeError + If the OpenGL context cannot be initialised. + ValueError + For invalid argument combinations. + + Notes + ----- + **Snapshot options** (default mode): + + * ``meshpath`` — path to the surface file (FreeSurfer binary, e.g. ``lh.white``). + * ``--overlay`` — per-vertex scalar overlay (e.g. ``lh.thickness``). + * ``--annot`` — FreeSurfer ``.annot`` parcellation file. + * ``--label`` — label file used to mask overlay values to the cortex. + * ``--curv`` — curvature file for sulcal depth shading of uncolored vertices. + * ``--view`` — camera direction: ``left``, ``right``, ``front``, ``back``, + ``top``, ``bottom`` (default: ``left``). + * ``--fthresh`` / ``--fmax`` — overlay threshold and saturation values. + * ``--invert`` — invert the color scale. + * ``--no-colorbar`` — suppress the color bar. + * ``--caption`` — text label placed on the image. + * ``--width`` / ``--height`` — output resolution in pixels (default: 700×500). + * ``-o`` / ``--output`` — output file path (default: temp ``.png``). + + **Rotation video options** (pass ``--rotate``): + + * ``--rotate-frames`` — number of frames for one full rotation (default: 72). + * ``--rotate-fps`` — output frame rate (default: 24). + * ``--rotate-start-view`` — starting camera direction (default: ``left``). + * ``-o`` — output path; extension controls format: ``.mp4``, ``.webm``, + or ``.gif`` (GIF requires no ffmpeg). + """ parser = argparse.ArgumentParser( prog="whippersnap1", diff --git a/whippersnappy/cli/whippersnap4.py b/whippersnappy/cli/whippersnap4.py index 5835055..3d1e8cb 100644 --- a/whippersnappy/cli/whippersnap4.py +++ b/whippersnappy/cli/whippersnap4.py @@ -39,6 +39,40 @@ def run(): If the OpenGL context cannot be initialised. FileNotFoundError If required surface files cannot be found. + + Notes + ----- + **Required:** + + * ``-sd`` / ``--sdir`` — subject directory containing ``surf/`` and + ``label/`` subdirectories. + * One of the following (not both): + + * ``-lh`` / ``--lh_overlay`` **and** ``-rh`` / ``--rh_overlay`` — per-vertex + scalar overlay files for left and right hemispheres (e.g. ``lh.thickness``). + * ``--lh_annot`` **and** ``--rh_annot`` — FreeSurfer ``.annot`` parcellation + files for left and right hemispheres. + + **Output:** + + * ``-o`` / ``--output_path`` — output image path (default: temp ``.png``). + + **Overlay appearance:** + + * ``--fthresh`` / ``--fmax`` — threshold and saturation values + (auto-estimated if not set). + * ``--invert`` — invert the color scale. + * ``--no-colorbar`` — suppress the color bar. + * ``-c`` / ``--caption`` — text label placed on the figure. + + **Rendering:** + + * ``-s`` / ``--surf_name`` — surface basename (e.g. ``white``); + auto-detected if not provided. + * ``--diffuse`` — use diffuse-only shading (no specular highlights). + * ``--brain-scale`` — geometry scale factor (default: 1.85). + * ``--ambient`` — ambient light strength (default: 0.0). + * ``--font`` — path to a TTF font file for captions. """ logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") diff --git a/whippersnappy/utils/datasets.py b/whippersnappy/utils/datasets.py index ed1a47e..61cc707 100644 --- a/whippersnappy/utils/datasets.py +++ b/whippersnappy/utils/datasets.py @@ -29,6 +29,23 @@ } +def _build_dict(base: Path) -> dict: + """Build the return dictionary of paths from a subject base directory.""" + return { + "sdir": str(base), + "lh_white": str(base / "surf/lh.white"), + "lh_curv": str(base / "surf/lh.curv"), + "lh_thickness": str(base / "surf/lh.thickness"), + "rh_white": str(base / "surf/rh.white"), + "rh_curv": str(base / "surf/rh.curv"), + "rh_thickness": str(base / "surf/rh.thickness"), + "lh_annot": str(base / "label/lh.aparc.DKTatlas.mapped.annot"), + "lh_label": str(base / "label/lh.cortex.label"), + "rh_annot": str(base / "label/rh.aparc.DKTatlas.mapped.annot"), + "rh_label": str(base / "label/rh.cortex.label"), + } + + def fetch_sample_subject() -> dict: """Download and cache the WhipperSnapPy sample subject (Rhineland Study). @@ -37,6 +54,11 @@ def fetch_sample_subject() -> dict: all files. Files are only downloaded once; subsequent calls use the local cache. + If a ``sub-rs/`` directory with all required files is found next to the + package root (i.e. inside the source repository), it is used directly + without any network access. This allows the Sphinx doc build to work + before the GitHub release assets are published. + Returns ------- dict @@ -80,6 +102,14 @@ def fetch_sample_subject() -> dict: "Install with: pip install 'whippersnappy[notebook]'" ) from e + # Use a local sub-rs/ directory (present in the source repo) when all + # required files are already there — no network access needed. + _pkg_root = Path(__file__).parent.parent.parent # .../whippersnappy/ + _local = _pkg_root / "sub-rs" + if _local.is_dir() and all((_local / p).exists() for p in _FILES): + return _build_dict(_local) + + # Otherwise download from the GitHub release and cache in the OS cache dir. base = Path(pooch.os_cache("whippersnappy")) / "sub-rs" for rel_path, known_hash in _FILES.items(): @@ -91,17 +121,4 @@ def fetch_sample_subject() -> dict: path=base / rel.parent, ) - return { - "sdir": str(base), - "lh_white": str(base / "surf/lh.white"), - "lh_curv": str(base / "surf/lh.curv"), - "lh_thickness": str(base / "surf/lh.thickness"), - "rh_white": str(base / "surf/rh.white"), - "rh_curv": str(base / "surf/rh.curv"), - "rh_thickness": str(base / "surf/rh.thickness"), - "lh_annot": str(base / "label/lh.aparc.DKTatlas.mapped.annot"), - "lh_label": str(base / "label/lh.cortex.label"), - "rh_annot": str(base / "label/rh.aparc.DKTatlas.mapped.annot"), - "rh_label": str(base / "label/rh.cortex.label"), - } - + return _build_dict(base) From efb720ea6929e77b148cd99057bec34fc843412f Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 20:06:40 +0100 Subject: [PATCH 54/83] fix ruff --- tutorials/whippersnappy_tutorial.ipynb | 137 ++++++++++++++++++++++--- 1 file changed, 120 insertions(+), 17 deletions(-) diff --git a/tutorials/whippersnappy_tutorial.ipynb b/tutorials/whippersnappy_tutorial.ipynb index f779cce..15e6d25 100644 --- a/tutorials/whippersnappy_tutorial.ipynb +++ b/tutorials/whippersnappy_tutorial.ipynb @@ -34,7 +34,24 @@ "id": "30c14cc9", "metadata": {}, "outputs": [], - "source": "import os\nfrom whippersnappy import fetch_sample_subject\n\n# Set sdir to your FreeSurfer subject directory.\n# Leave empty (\"\") to automatically use the sample subject (sub-rs,\n# one anonymized subject from the Rhineland Study). It is used directly\n# from the repository when available, otherwise downloaded (~20 MB) and\n# cached locally after the first run.\nsdir = \"\"\n# sdir = \"/path/to/your/subject\"\n\nif not sdir:\n sdir = fetch_sample_subject()[\"sdir\"]\n\nprint(\"Subject directory:\", sdir)\n" + "source": [ + "import os\n", + "\n", + "from whippersnappy import fetch_sample_subject\n", + "\n", + "# Set sdir to your FreeSurfer subject directory.\n", + "# Leave empty (\"\") to automatically use the sample subject (sub-rs,\n", + "# one anonymized subject from the Rhineland Study). It is used directly\n", + "# from the repository when available, otherwise downloaded (~20 MB) and\n", + "# cached locally after the first run.\n", + "sdir = \"\"\n", + "# sdir = \"/path/to/your/subject\"\n", + "\n", + "if not sdir:\n", + " sdir = fetch_sample_subject()[\"sdir\"]\n", + "\n", + "print(\"Subject directory:\", sdir)\n" + ] }, { "cell_type": "markdown", @@ -53,14 +70,34 @@ "id": "ec559dcd", "metadata": {}, "outputs": [], - "source": "# Surfaces\nlh_white = os.path.join(sdir, \"surf\", \"lh.white\")\nrh_white = os.path.join(sdir, \"surf\", \"rh.white\")\n\n# Curvature\nlh_curv = os.path.join(sdir, \"surf\", \"lh.curv\")\nrh_curv = os.path.join(sdir, \"surf\", \"rh.curv\")\n\n# Thickness overlay\nlh_thickness = os.path.join(sdir, \"surf\", \"lh.thickness\")\nrh_thickness = os.path.join(sdir, \"surf\", \"rh.thickness\")\n\n# Cortex label (mask for overlay)\nlh_label = os.path.join(sdir, \"label\", \"lh.cortex.label\")\nrh_label = os.path.join(sdir, \"label\", \"rh.cortex.label\")\n\n# Parcellation annotation (DKTatlas)\nlh_annot = os.path.join(sdir, \"label\", \"lh.aparc.DKTatlas.mapped.annot\")\nrh_annot = os.path.join(sdir, \"label\", \"rh.aparc.DKTatlas.mapped.annot\")\n" + "source": [ + "# Surfaces\n", + "lh_white = os.path.join(sdir, \"surf\", \"lh.white\")\n", + "rh_white = os.path.join(sdir, \"surf\", \"rh.white\")\n", + "\n", + "# Curvature\n", + "lh_curv = os.path.join(sdir, \"surf\", \"lh.curv\")\n", + "rh_curv = os.path.join(sdir, \"surf\", \"rh.curv\")\n", + "\n", + "# Thickness overlay\n", + "lh_thickness = os.path.join(sdir, \"surf\", \"lh.thickness\")\n", + "rh_thickness = os.path.join(sdir, \"surf\", \"rh.thickness\")\n", + "\n", + "# Cortex label (mask for overlay)\n", + "lh_label = os.path.join(sdir, \"label\", \"lh.cortex.label\")\n", + "rh_label = os.path.join(sdir, \"label\", \"rh.cortex.label\")\n", + "\n", + "# Parcellation annotation (DKTatlas)\n", + "lh_annot = os.path.join(sdir, \"label\", \"lh.aparc.DKTatlas.mapped.annot\")\n", + "rh_annot = os.path.join(sdir, \"label\", \"rh.aparc.DKTatlas.mapped.annot\")\n" + ] }, { "cell_type": "markdown", "id": "22be9ebf", "metadata": {}, "source": [ - "## snap1 \u2014 Basic Single View\n", + "## snap1 — Basic Single View\n", "\n", "`snap1` renders a single static view of a surface mesh into a PIL Image.\n", "Here we render the left hemisphere with curvature texturing only (no overlay),\n", @@ -73,14 +110,21 @@ "id": "783e547b", "metadata": {}, "outputs": [], - "source": "from IPython.display import display\nfrom whippersnappy import snap1\n\nimg = snap1(lh_white, curvpath=lh_curv)\ndisplay(img)\n" + "source": [ + "from IPython.display import display\n", + "\n", + "from whippersnappy import snap1\n", + "\n", + "img = snap1(lh_white, curvpath=lh_curv)\n", + "display(img)\n" + ] }, { "cell_type": "markdown", "id": "7173d312", "metadata": {}, "source": [ - "## snap1 \u2014 With Thickness Overlay\n", + "## snap1 — With Thickness Overlay\n", "\n", "By passing `overlaypath` and `labelpath`, the surface is colored by cortical\n", "thickness values, masked to the cortex label. The `view` parameter selects\n", @@ -93,14 +137,25 @@ "id": "13407b67", "metadata": {}, "outputs": [], - "source": "from whippersnappy.utils.types import ViewType\n\nimg = snap1(\n lh_white,\n overlaypath=lh_thickness,\n curvpath=lh_curv,\n labelpath=lh_label,\n view=ViewType.LEFT,\n)\ndisplay(img)\n" + "source": [ + "from whippersnappy.utils.types import ViewType\n", + "\n", + "img = snap1(\n", + " lh_white,\n", + " overlaypath=lh_thickness,\n", + " curvpath=lh_curv,\n", + " labelpath=lh_label,\n", + " view=ViewType.LEFT,\n", + ")\n", + "display(img)\n" + ] }, { "cell_type": "markdown", "id": "4217c291", "metadata": {}, "source": [ - "## snap1 \u2014 With Parcellation Annotation\n", + "## snap1 — With Parcellation Annotation\n", "\n", "`annotpath` accepts a FreeSurfer `.annot` file and colors each vertex by\n", "its parcellation label. This example uses the DKTatlas parcellation." @@ -112,14 +167,21 @@ "id": "7271e902", "metadata": {}, "outputs": [], - "source": "img = snap1(\n lh_white,\n annotpath=lh_annot,\n curvpath=lh_curv,\n)\ndisplay(img)\n" + "source": [ + "img = snap1(\n", + " lh_white,\n", + " annotpath=lh_annot,\n", + " curvpath=lh_curv,\n", + ")\n", + "display(img)\n" + ] }, { "cell_type": "markdown", "id": "620c6c43", "metadata": {}, "source": [ - "## snap4 \u2014 Four-View Overview\n", + "## snap4 — Four-View Overview\n", "\n", "`snap4` renders lateral and medial views of both hemispheres and stitches\n", "them into a single composed image. Here we color both hemispheres by\n", @@ -132,14 +194,25 @@ "id": "903514b8", "metadata": {}, "outputs": [], - "source": "from whippersnappy import snap4\n\nimg = snap4(\n sdir=sdir,\n lhoverlaypath=lh_thickness,\n rhoverlaypath=rh_thickness,\n colorbar=True,\n caption=\"Cortical Thickness (mm)\",\n)\ndisplay(img)\n" + "source": [ + "from whippersnappy import snap4\n", + "\n", + "img = snap4(\n", + " sdir=sdir,\n", + " lhoverlaypath=lh_thickness,\n", + " rhoverlaypath=rh_thickness,\n", + " colorbar=True,\n", + " caption=\"Cortical Thickness (mm)\",\n", + ")\n", + "display(img)\n" + ] }, { "cell_type": "markdown", "id": "5d98c87b", "metadata": {}, "source": [ - "## plot3d \u2014 Interactive 3D Viewer\n", + "## plot3d — Interactive 3D Viewer\n", "\n", "`plot3d` creates an interactive Three.js/WebGL viewer that works in all\n", "Jupyter environments. You can rotate, zoom, and pan with the mouse.\n", @@ -152,20 +225,29 @@ "id": "5d5a09c7", "metadata": {}, "outputs": [], - "source": "from whippersnappy import plot3d\n\nviewer = plot3d(\n meshpath=lh_white,\n curvpath=lh_curv,\n overlaypath=lh_thickness,\n)\ndisplay(viewer)\n" + "source": [ + "from whippersnappy import plot3d\n", + "\n", + "viewer = plot3d(\n", + " meshpath=lh_white,\n", + " curvpath=lh_curv,\n", + " overlaypath=lh_thickness,\n", + ")\n", + "display(viewer)\n" + ] }, { "cell_type": "markdown", "id": "dc3970c3", "metadata": {}, "source": [ - "## snap_rotate \u2014 Rotating 360\u00b0 Animation\n", + "## snap_rotate — Rotating 360° Animation\n", "\n", - "`snap_rotate` renders a full 360\u00b0 rotation of the surface. We output an\n", + "`snap_rotate` renders a full 360° rotation of the surface. We output an\n", "animated GIF so it displays inline in all Jupyter environments including\n", "PyCharm. Use `.mp4` as `outpath` instead for a smaller file when playing\n", "outside the notebook.\n", - "This cell takes the longest to run \u2014 execute it last." + "This cell takes the longest to run — execute it last." ] }, { @@ -174,7 +256,26 @@ "id": "928a68ea", "metadata": {}, "outputs": [], - "source": "from whippersnappy import snap_rotate\nfrom IPython.display import Image\n\noutpath_gif = \"/tmp/lh_thickness_rotate.gif\"\n\nsnap_rotate(\n meshpath=lh_white,\n outpath=outpath_gif,\n overlaypath=lh_thickness,\n curvpath=lh_curv,\n labelpath=lh_label,\n n_frames=72,\n fps=24,\n width=800,\n height=600,\n)\nprint(\"GIF saved to:\", outpath_gif)\n" + "source": [ + "from IPython.display import Image\n", + "\n", + "from whippersnappy import snap_rotate\n", + "\n", + "outpath_gif = \"/tmp/lh_thickness_rotate.gif\"\n", + "\n", + "snap_rotate(\n", + " meshpath=lh_white,\n", + " outpath=outpath_gif,\n", + " overlaypath=lh_thickness,\n", + " curvpath=lh_curv,\n", + " labelpath=lh_label,\n", + " n_frames=72,\n", + " fps=24,\n", + " width=800,\n", + " height=600,\n", + ")\n", + "print(\"GIF saved to:\", outpath_gif)\n" + ] }, { "cell_type": "code", @@ -184,7 +285,9 @@ "lines_to_next_cell": 2 }, "outputs": [], - "source": "Image(filename=outpath_gif)\n" + "source": [ + "Image(filename=outpath_gif)\n" + ] } ], "metadata": { From 2bb74121586e29099dbcff6dfff17053b1c3cc42 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 20:13:08 +0100 Subject: [PATCH 55/83] fix doc build and update show_window paramters --- doc/tutorials/whippersnappy_tutorial.ipynb | 1 + whippersnappy/cli/whippersnap.py | 41 ++++++++++------------ whippersnappy/cli/whippersnap1.py | 6 +--- 3 files changed, 21 insertions(+), 27 deletions(-) create mode 120000 doc/tutorials/whippersnappy_tutorial.ipynb diff --git a/doc/tutorials/whippersnappy_tutorial.ipynb b/doc/tutorials/whippersnappy_tutorial.ipynb new file mode 120000 index 0000000..9a82e1b --- /dev/null +++ b/doc/tutorials/whippersnappy_tutorial.ipynb @@ -0,0 +1 @@ +../../tutorials/whippersnappy_tutorial.ipynb \ No newline at end of file diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 4c839fd..de6e861 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -53,7 +53,6 @@ def show_window( overlaypath=None, annotpath=None, sdir=None, - caption=None, invert=False, labelname="cortex.label", surfname=None, @@ -77,19 +76,18 @@ def show_window( Path to a ``.annot`` file providing categorical labels for vertices. sdir : str or None, optional Subject directory containing ``surf/`` and ``label/`` subdirectories. - caption : str or None, optional - Caption text to display in the viewer window. - invert : bool, optional, default False - Invert the overlay color mapping. - labelname : str, optional, default 'cortex.label' - Label filename used to mask vertices. + invert : bool, optional + Invert the overlay color mapping. Default is ``False``. + labelname : str, optional + Label filename used to mask vertices. Default is ``'cortex.label'``. surfname : str or None, optional Surface basename (e.g. ``'white'``); if ``None`` the function will auto-detect a suitable surface in ``sdir``. - curvname : str or None, optional, default 'curv' + curvname : str or None, optional Curvature filename used to texture non-colored regions. - specular : bool, optional, default True - Enable specular highlights in the shader. + Default is ``'curv'``. + specular : bool, optional + Enable specular highlights in the shader. Default is ``True``. Raises ------ @@ -126,7 +124,8 @@ def show_window( rot_y = pyrr.Matrix44.from_y_rotation(0) meshdata, triangles, fthresh, fmax, neg = prepare_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, current_fthresh_, current_fmax_ + meshpath, overlaypath, annotpath, curvpath, labelpath, current_fthresh_, current_fmax_, + invert=invert, ) shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular) @@ -150,6 +149,7 @@ def show_window( meshdata, triangles, fthresh, fmax, neg = prepare_geometry( meshpath, overlaypath, annotpath, curvpath, labelpath, current_fthresh_, current_fmax_, + invert=invert, ) shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular) @@ -296,17 +296,14 @@ def run(): thread = threading.Thread( target=show_window, - args=( - "lh", - args.lh_overlay, - args.lh_annot, - args.sdir, - None, - False, - "cortex.label", - args.surf_name, - "curv", - args.specular, + args=("lh",), + kwargs=dict( + overlaypath=args.lh_overlay, + annotpath=args.lh_annot, + sdir=args.sdir, + invert=args.invert, + surfname=args.surf_name, + specular=args.specular, ), ) thread.start() diff --git a/whippersnappy/cli/whippersnap1.py b/whippersnappy/cli/whippersnap1.py index 4d2b479..edd0195 100644 --- a/whippersnappy/cli/whippersnap1.py +++ b/whippersnappy/cli/whippersnap1.py @@ -52,11 +52,7 @@ def run(): :func:`whippersnappy.snap1` (static snapshot) or :func:`whippersnappy.snap_rotate` (360° rotation video) depending on whether ``--rotate`` is passed. - - Parameters - ---------- - None - All input is read from ``sys.argv`` via :mod:`argparse`. + All input is read from ``sys.argv`` via :mod:`argparse`. Raises ------ From cea31fea580cdfa4877682ab462af8144f54ceda Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Fri, 20 Feb 2026 20:27:00 +0100 Subject: [PATCH 56/83] only publish on v* releases to enable data only releases --- .github/workflows/doc.yml | 2 +- whippersnappy/utils/datasets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 632ca48..78c3a1d 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -41,7 +41,7 @@ jobs: uses: actions/cache@v4 with: path: ~/.cache/whippersnappy - key: sample-data-sub-rs-v2.0.0 + key: sample-data-sub-rs-data-v1 - name: Build doc run: TZ=UTC sphinx-build ./main/doc ./doc-build/dev -W --keep-going - name: Upload documentation diff --git a/whippersnappy/utils/datasets.py b/whippersnappy/utils/datasets.py index 61cc707..4a88e1b 100644 --- a/whippersnappy/utils/datasets.py +++ b/whippersnappy/utils/datasets.py @@ -8,7 +8,7 @@ RELEASE_URL = ( "https://github.com/Deep-MI/WhipperSnapPy" - "/releases/download/v2.0.0/{file_name}" + "/releases/download/data-v1/{file_name}" ) # Mapping of relative path inside the subject directory → SHA-256 hash. From 536a5a0e37c5783c8baedd96491f45a9de7213aa Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 15:00:19 +0100 Subject: [PATCH 57/83] refactor to allow passing of arrays or files, add tests --- README.md | 33 ++- tests/test_array_inputs.py | 269 +++++++++++++++++++ tutorials/whippersnappy_tutorial.ipynb | 34 +-- whippersnappy/__init__.py | 8 +- whippersnappy/cli/whippersnap.py | 24 +- whippersnappy/cli/whippersnap1.py | 38 +-- whippersnappy/cli/whippersnap4.py | 8 +- whippersnappy/geometry/__init__.py | 24 +- whippersnappy/geometry/inputs.py | 234 +++++++++++++++++ whippersnappy/geometry/prepare.py | 343 +++++++++++++++---------- whippersnappy/plot3d.py | 46 ++-- whippersnappy/snap.py | 210 ++++++++------- 12 files changed, 971 insertions(+), 300 deletions(-) create mode 100644 tests/test_array_inputs.py create mode 100644 whippersnappy/geometry/inputs.py diff --git a/README.md b/README.md index bff6885..9a329cf 100644 --- a/README.md +++ b/README.md @@ -97,12 +97,39 @@ from whippersnappy import snap1, snap4, snap_rotate, plot3d ### Example ```python -from whippersnappy import snap4 +from whippersnappy import snap1, snap4 + +# File-path inputs (FreeSurfer subject directory) img = snap4(sdir='/path/to/subject', - lhoverlaypath='/path/to/lh.thickness', - rhoverlaypath='/path/to/rh.thickness', + lh_overlay='/path/to/lh.thickness', + rh_overlay='/path/to/rh.thickness', colorbar=True, caption='Cortical Thickness (mm)') img.save('snap4.png') + +# Single view with background shading and cortex mask +img = snap1('fsaverage/surf/lh.white', + overlay='fsaverage/surf/lh.thickness', + bg_map='fsaverage/surf/lh.curv', + roi='fsaverage/label/lh.cortex.label') +img.save('snap1.png') + +# Array inputs (e.g. from LaPy or trimesh) +import numpy as np +v = np.random.randn(1000, 3).astype(np.float32) +f = np.array([[0, 1, 2]], dtype=np.uint32) # minimal example +my_values = np.random.randn(1000).astype(np.float32) +tria_curv = np.random.randn(1000).astype(np.float32) +img = snap1((v, f), overlay=my_values, bg_map=tria_curv) +``` + +CLI usage: + +```bash +# Single view +whippersnap1 lh.white --overlay lh.thickness --bg-map lh.curv --roi lh.cortex.label -o snap1.png + +# Four-view batch +whippersnap4 -lh lh.thickness -rh rh.thickness -sd /path/to/subject -o snap4.png ``` See `tutorials/whippersnappy_tutorial.ipynb` for complete notebook examples. diff --git a/tests/test_array_inputs.py b/tests/test_array_inputs.py new file mode 100644 index 0000000..2a1757c --- /dev/null +++ b/tests/test_array_inputs.py @@ -0,0 +1,269 @@ +"""Tests for the array-input pathway introduced in v2.0-rc. + +These tests exercise the resolver functions and the geometry-preparation +pipeline entirely without touching any file on disk, using small synthetic +triangle meshes. +""" + +import numpy as np +import pytest + +from whippersnappy.geometry.inputs import ( + resolve_annot, + resolve_bg_map, + resolve_mesh, + resolve_overlay, + resolve_roi, +) +from whippersnappy.geometry.prepare import ( + _estimate_thresholds_from_array, + estimate_overlay_thresholds, + prepare_geometry, + prepare_geometry_from_arrays, +) + + +# --------------------------------------------------------------------------- +# Minimal synthetic mesh (tetrahedron) +# --------------------------------------------------------------------------- + +_V = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) +_F = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.uint32) +_N = _V.shape[0] # 4 vertices + + +# --------------------------------------------------------------------------- +# resolve_mesh +# --------------------------------------------------------------------------- + +class TestResolveMesh: + def test_tuple_input(self): + v, f = resolve_mesh((_V, _F)) + assert v.shape == (4, 3) + assert v.dtype == np.float32 + assert f.shape == (4, 3) + assert f.dtype == np.uint32 + + def test_list_input(self): + v, f = resolve_mesh([_V, _F]) + assert v.shape == (4, 3) + assert f.shape == (4, 3) + + def test_wrong_type_raises(self): + with pytest.raises(TypeError): + resolve_mesh(42) + + def test_wrong_shape_vertices_raises(self): + bad_v = np.ones((4, 4), dtype=np.float32) + with pytest.raises(ValueError): + resolve_mesh((bad_v, _F)) + + def test_wrong_shape_faces_raises(self): + bad_f = np.ones((4, 4), dtype=np.uint32) + with pytest.raises(ValueError): + resolve_mesh((_V, bad_f)) + + +# --------------------------------------------------------------------------- +# resolve_overlay / resolve_bg_map +# --------------------------------------------------------------------------- + +class TestResolveOverlay: + def test_none_returns_none(self): + assert resolve_overlay(None, n_vertices=_N) is None + + def test_array_input(self): + arr = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) + result = resolve_overlay(arr, n_vertices=_N) + assert result.shape == (_N,) + assert result.dtype == np.float32 + + def test_shape_mismatch_raises(self): + arr = np.array([0.1, 0.5], dtype=np.float32) + with pytest.raises(ValueError): + resolve_overlay(arr, n_vertices=_N) + + def test_n_vertices_none_skips_check(self): + arr = np.array([0.1, 0.5], dtype=np.float32) + result = resolve_overlay(arr, n_vertices=None) + assert result.shape == (2,) + + +class TestResolveBgMap: + def test_none_returns_none(self): + assert resolve_bg_map(None, n_vertices=_N) is None + + def test_array_input(self): + arr = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) + result = resolve_bg_map(arr, n_vertices=_N) + assert result.shape == (_N,) + + def test_shape_mismatch_raises(self): + with pytest.raises(ValueError): + resolve_bg_map(np.ones(2), n_vertices=_N) + + +# --------------------------------------------------------------------------- +# resolve_roi +# --------------------------------------------------------------------------- + +class TestResolveRoi: + def test_none_returns_none(self): + assert resolve_roi(None, n_vertices=_N) is None + + def test_bool_array(self): + roi = np.array([True, True, True, False], dtype=bool) + result = resolve_roi(roi, n_vertices=_N) + assert result.dtype == bool + assert result.shape == (_N,) + assert result[3] is np.bool_(False) + + def test_shape_mismatch_raises(self): + with pytest.raises(ValueError): + resolve_roi(np.ones(2, dtype=bool), n_vertices=_N) + + +# --------------------------------------------------------------------------- +# resolve_annot +# --------------------------------------------------------------------------- + +class TestResolveAnnot: + def test_none_returns_none(self): + assert resolve_annot(None, n_vertices=_N) is None + + def test_two_tuple(self): + labels = np.array([0, 1, 0, 1]) + ctab = np.array([[255, 0, 0, 0, 0], [0, 255, 0, 0, 1]]) + result = resolve_annot((labels, ctab), n_vertices=_N) + assert result is not None + assert len(result) == 3 + assert result[2] is None # names + + def test_three_tuple(self): + labels = np.zeros(_N, dtype=int) + ctab = np.array([[200, 100, 50, 0, 1]]) + names = ["region0"] + result = resolve_annot((labels, ctab, names), n_vertices=_N) + assert result[2] == names + + def test_shape_mismatch_raises(self): + labels = np.zeros(2, dtype=int) + ctab = np.array([[255, 0, 0, 0, 0]]) + with pytest.raises(ValueError): + resolve_annot((labels, ctab), n_vertices=_N) + + def test_wrong_type_raises(self): + with pytest.raises(TypeError): + resolve_annot(42, n_vertices=_N) + + +# --------------------------------------------------------------------------- +# estimate_overlay_thresholds +# --------------------------------------------------------------------------- + +class TestEstimateOverlayThresholds: + def test_array_input(self): + arr = np.array([1.0, 2.0, 3.0, -1.5], dtype=np.float32) + fmin, fmax = estimate_overlay_thresholds(arr) + assert fmin >= 0 + assert fmax == pytest.approx(3.0) + + def test_passthrough_when_provided(self): + arr = np.array([1.0, 2.0, 3.0], dtype=np.float32) + fmin, fmax = estimate_overlay_thresholds(arr, minval=0.5, maxval=5.0) + assert fmin == pytest.approx(0.5) + assert fmax == pytest.approx(5.0) + + +# --------------------------------------------------------------------------- +# prepare_geometry_from_arrays — pure array pipeline +# --------------------------------------------------------------------------- + +class TestPrepareGeometryFromArrays: + def test_no_overlay(self): + vdata, tris, fmin, fmax, pos, neg = prepare_geometry_from_arrays(_V, _F) + assert vdata.shape == (_N, 9) + assert tris.shape == (4, 3) + assert fmin is None and fmax is None + + def test_with_overlay(self): + overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) + vdata, tris, fmin, fmax, pos, neg = prepare_geometry_from_arrays( + _V, _F, overlay=overlay + ) + assert vdata.shape == (_N, 9) + assert fmin is not None and fmax is not None + + def test_with_bg_map(self): + bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) + vdata, tris, *_ = prepare_geometry_from_arrays(_V, _F, bg_map=bg) + assert vdata.shape == (_N, 9) + + def test_with_roi(self): + overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) + roi = np.array([True, True, True, False], dtype=bool) + vdata, tris, *_ = prepare_geometry_from_arrays(_V, _F, overlay=overlay, roi=roi) + assert vdata.shape == (_N, 9) + + def test_overlay_shape_mismatch_raises(self): + overlay = np.array([0.1, 0.5], dtype=np.float32) # wrong length + with pytest.raises(ValueError): + prepare_geometry_from_arrays(_V, _F, overlay=overlay) + + +# --------------------------------------------------------------------------- +# prepare_geometry — thin wrapper (array path) +# --------------------------------------------------------------------------- + +class TestPrepareGeometry: + def test_tuple_mesh_no_overlay(self): + vdata, tris, *_ = prepare_geometry((_V, _F)) + assert vdata.shape == (_N, 9) + + def test_tuple_mesh_with_overlay_and_roi(self): + overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) + roi = np.array([True, True, True, False], dtype=bool) + vdata, tris, fmin, fmax, pos, neg = prepare_geometry( + (_V, _F), overlay=overlay, roi=roi + ) + assert vdata.shape == (_N, 9) + assert fmin is not None + + def test_tuple_mesh_with_bg_map(self): + bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) + vdata, tris, *_ = prepare_geometry((_V, _F), bg_map=bg) + assert vdata.shape == (_N, 9) + + def test_invalid_mesh_type_raises(self): + with pytest.raises(TypeError): + prepare_geometry(12345) + + +# --------------------------------------------------------------------------- +# snap1 array-input integration (no OpenGL — just geometry prep) +# --------------------------------------------------------------------------- + +class TestSnap1ArrayInputs: + """Integration test for the array-input pathway of snap1. + + We only test the geometry-preparation layer here (not OpenGL rendering) + so that the test suite can run in headless CI without a display. + """ + + def test_prepare_geometry_called_by_snap1_path(self): + """Verify that prepare_geometry accepts the same args snap1 would pass.""" + overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) + roi = np.array([True, True, True, False], dtype=bool) + bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) + vdata, tris, fmin, fmax, pos, neg = prepare_geometry( + (_V, _F), overlay=overlay, roi=roi, bg_map=bg + ) + assert vdata is not None + assert tris is not None + + def test_prepare_geometry_bg_map_array_no_error(self): + """Verify bg_map as array raises no error.""" + bg = np.array([-0.5, 0.5, -0.3, 0.3], dtype=np.float32) + vdata, tris, *_ = prepare_geometry((_V, _F), bg_map=bg) + assert vdata.shape == (_N, 9) + diff --git a/tutorials/whippersnappy_tutorial.ipynb b/tutorials/whippersnappy_tutorial.ipynb index 15e6d25..5898a06 100644 --- a/tutorials/whippersnappy_tutorial.ipynb +++ b/tutorials/whippersnappy_tutorial.ipynb @@ -115,7 +115,7 @@ "\n", "from whippersnappy import snap1\n", "\n", - "img = snap1(lh_white, curvpath=lh_curv)\n", + "img = snap1(lh_white, bg_map=lh_curv)\n", "display(img)\n" ] }, @@ -126,7 +126,7 @@ "source": [ "## snap1 — With Thickness Overlay\n", "\n", - "By passing `overlaypath` and `labelpath`, the surface is colored by cortical\n", + "By passing `overlay` and `roi`, the surface is colored by cortical\n", "thickness values, masked to the cortex label. The `view` parameter selects\n", "the lateral view of the left hemisphere." ] @@ -142,9 +142,9 @@ "\n", "img = snap1(\n", " lh_white,\n", - " overlaypath=lh_thickness,\n", - " curvpath=lh_curv,\n", - " labelpath=lh_label,\n", + " overlay=lh_thickness,\n", + " bg_map=lh_curv,\n", + " roi=lh_label,\n", " view=ViewType.LEFT,\n", ")\n", "display(img)\n" @@ -157,7 +157,7 @@ "source": [ "## snap1 — With Parcellation Annotation\n", "\n", - "`annotpath` accepts a FreeSurfer `.annot` file and colors each vertex by\n", + "`annot` accepts a FreeSurfer `.annot` file and colors each vertex by\n", "its parcellation label. This example uses the DKTatlas parcellation." ] }, @@ -170,8 +170,8 @@ "source": [ "img = snap1(\n", " lh_white,\n", - " annotpath=lh_annot,\n", - " curvpath=lh_curv,\n", + " annot=lh_annot,\n", + " bg_map=lh_curv,\n", ")\n", "display(img)\n" ] @@ -199,8 +199,8 @@ "\n", "img = snap4(\n", " sdir=sdir,\n", - " lhoverlaypath=lh_thickness,\n", - " rhoverlaypath=rh_thickness,\n", + " lh_overlay=lh_thickness,\n", + " rh_overlay=rh_thickness,\n", " colorbar=True,\n", " caption=\"Cortical Thickness (mm)\",\n", ")\n", @@ -229,9 +229,9 @@ "from whippersnappy import plot3d\n", "\n", "viewer = plot3d(\n", - " meshpath=lh_white,\n", - " curvpath=lh_curv,\n", - " overlaypath=lh_thickness,\n", + " mesh=lh_white,\n", + " bg_map=lh_curv,\n", + " overlay=lh_thickness,\n", ")\n", "display(viewer)\n" ] @@ -264,11 +264,11 @@ "outpath_gif = \"/tmp/lh_thickness_rotate.gif\"\n", "\n", "snap_rotate(\n", - " meshpath=lh_white,\n", + " mesh=lh_white,\n", " outpath=outpath_gif,\n", - " overlaypath=lh_thickness,\n", - " curvpath=lh_curv,\n", - " labelpath=lh_label,\n", + " overlay=lh_thickness,\n", + " bg_map=lh_curv,\n", + " roi=lh_label,\n", " n_frames=72,\n", " fps=24,\n", " width=800,\n", diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 2e3a41e..83f119f 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -15,7 +15,7 @@ from whippersnappy.utils.types import ViewType from IPython.display import display - img = snap1(meshpath='path/to/surface.white', view=ViewType.LEFT) + img = snap1(mesh='path/to/surface.white', view=ViewType.LEFT) display(img) For interactive 3D in Jupyter notebooks: @@ -24,9 +24,9 @@ from whippersnappy import plot3d viewer = plot3d( - meshpath='path/to/surface.white', - curvpath='path/to/curv', - overlaypath='path/to/thickness.mgh' # optional: for colors + mesh='path/to/surface.white', + bg_map='path/to/curv', + overlay='path/to/thickness.mgh' # optional: for colors ) display(viewer) diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index de6e861..fca3c0e 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -50,8 +50,8 @@ def show_window( hemi, - overlaypath=None, - annotpath=None, + overlay=None, + annot=None, sdir=None, invert=False, labelname="cortex.label", @@ -70,9 +70,9 @@ def show_window( ---------- hemi : {'lh','rh'} Hemisphere to display. - overlaypath : str or None, optional + overlay : str or None, optional Path to a per-vertex overlay file (e.g. thickness). - annotpath : str or None, optional + annot : str or None, optional Path to a ``.annot`` file providing categorical labels for vertices. sdir : str or None, optional Subject directory containing ``surf/`` and ``label/`` subdirectories. @@ -112,19 +112,19 @@ def show_window( msg = f"Could not find a valid surf file in {sdir} for hemi: {hemi}!" logger.error(msg) raise FileNotFoundError(msg) - meshpath = os.path.join(sdir, "surf", hemi + "." + found_surfname) + mesh = os.path.join(sdir, "surf", hemi + "." + found_surfname) else: - meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) + mesh = os.path.join(sdir, "surf", hemi + "." + surfname) - curvpath = os.path.join(sdir, "surf", hemi + "." + curvname) if curvname else None - labelpath = os.path.join(sdir, "label", hemi + "." + labelname) if labelname else None + bg_map = os.path.join(sdir, "surf", hemi + "." + curvname) if curvname else None + roi = os.path.join(sdir, "label", hemi + "." + labelname) if labelname else None view_mats = get_view_matrices() viewmat = view_mats[ViewType.RIGHT] if hemi == "rh" else view_mats[ViewType.LEFT] rot_y = pyrr.Matrix44.from_y_rotation(0) meshdata, triangles, fthresh, fmax, neg = prepare_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, current_fthresh_, current_fmax_, + mesh, overlay, annot, bg_map, roi, current_fthresh_, current_fmax_, invert=invert, ) shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular) @@ -147,7 +147,7 @@ def show_window( current_fthresh_ = app_window_.get_fthresh_value() current_fmax_ = app_window_.get_fmax_value() meshdata, triangles, fthresh, fmax, neg = prepare_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, + mesh, overlay, annot, bg_map, roi, current_fthresh_, current_fmax_, invert=invert, ) @@ -298,8 +298,8 @@ def run(): target=show_window, args=("lh",), kwargs=dict( - overlaypath=args.lh_overlay, - annotpath=args.lh_annot, + overlay=args.lh_overlay, + annot=args.lh_annot, sdir=args.sdir, invert=args.invert, surfname=args.surf_name, diff --git a/whippersnappy/cli/whippersnap1.py b/whippersnappy/cli/whippersnap1.py index edd0195..aef3bf9 100644 --- a/whippersnappy/cli/whippersnap1.py +++ b/whippersnappy/cli/whippersnap1.py @@ -10,8 +10,8 @@ # Static single-view snapshot (lateral view, thickness overlay) whippersnap1 /surf/lh.white \\ --overlay /surf/lh.thickness \\ - --curv /surf/lh.curv \\ - --label /label/lh.cortex.label \\ + --bg-map /surf/lh.curv \\ + --roi /label/lh.cortex.label \\ --view left --fthresh 1.5 --fmax 4.0 \\ -o snap1.png @@ -67,7 +67,7 @@ def run(): ----- **Snapshot options** (default mode): - * ``meshpath`` — path to the surface file (FreeSurfer binary, e.g. ``lh.white``). + * ``mesh`` — path to the surface file (FreeSurfer binary, e.g. ``lh.white``). * ``--overlay`` — per-vertex scalar overlay (e.g. ``lh.thickness``). * ``--annot`` — FreeSurfer ``.annot`` parcellation file. * ``--label`` — label file used to mask overlay values to the cortex. @@ -102,7 +102,7 @@ def run(): # --- Required --- parser.add_argument( - "meshpath", + "mesh", type=str, help="Path to the surface file. FreeSurfer binary format (e.g. lh.white) " "or any mesh readable by the geometry module.", @@ -120,11 +120,15 @@ def run(): ), ) - # --- Optional overlay / annotation / label / curv --- + # --- Optional overlay / annotation / roi / bg-map --- parser.add_argument("--overlay", type=str, default=None, help="Per-vertex overlay file.") parser.add_argument("--annot", type=str, default=None, help="FreeSurfer .annot file.") - parser.add_argument("--label", type=str, default=None, help="Label file for masking.") - parser.add_argument("--curv", type=str, default=None, help="Curvature file for texturing.") + parser.add_argument("--roi", type=str, default=None, + help="Path to a FreeSurfer label file defining the region of interest " + "(vertices to include in overlay coloring).") + parser.add_argument("--bg-map", type=str, default=None, dest="bg_map", + help="Path to a per-vertex scalar file used as background shading " + "(sign determines light/dark).") # --- View --- parser.add_argument( @@ -206,16 +210,16 @@ def run(): tempfile.gettempdir(), "whippersnappy_rotation.mp4" ) snap_rotate( - meshpath=args.meshpath, + mesh=args.mesh, outpath=outpath, n_frames=args.rotate_frames, fps=args.rotate_fps, width=args.width, height=args.height, - overlaypath=args.overlay, - curvpath=args.curv, - annotpath=args.annot, - labelpath=args.label, + overlay=args.overlay, + bg_map=args.bg_map, + annot=args.annot, + roi=args.roi, fthresh=args.fthresh, fmax=args.fmax, invert=args.invert, @@ -231,12 +235,12 @@ def run(): tempfile.gettempdir(), "whippersnappy_snap1.png" ) img = snap1( - meshpath=args.meshpath, + mesh=args.mesh, outpath=outpath, - overlaypath=args.overlay, - annotpath=args.annot, - labelpath=args.label, - curvpath=args.curv, + overlay=args.overlay, + annot=args.annot, + roi=args.roi, + bg_map=args.bg_map, view=_VIEW_CHOICES[args.view], width=args.width, height=args.height, diff --git a/whippersnappy/cli/whippersnap4.py b/whippersnappy/cli/whippersnap4.py index 3d1e8cb..f8de394 100644 --- a/whippersnappy/cli/whippersnap4.py +++ b/whippersnappy/cli/whippersnap4.py @@ -155,10 +155,10 @@ def run(): try: img = snap4( - lhoverlaypath=args.lh_overlay, - rhoverlaypath=args.rh_overlay, - lhannotpath=args.lh_annot, - rhannotpath=args.rh_annot, + lh_overlay=args.lh_overlay, + rh_overlay=args.rh_overlay, + lh_annot=args.lh_annot, + rh_annot=args.rh_annot, sdir=args.sdir, caption=args.caption, surfname=args.surf_name, diff --git a/whippersnappy/geometry/__init__.py b/whippersnappy/geometry/__init__.py index 85deec0..d377e82 100644 --- a/whippersnappy/geometry/__init__.py +++ b/whippersnappy/geometry/__init__.py @@ -2,11 +2,29 @@ Expose prepare_geometry and small IO helpers under `whippersnappy.geometry`. """ -from .prepare import estimate_overlay_thresholds, prepare_geometry +from .inputs import resolve_annot, resolve_bg_map, resolve_mesh, resolve_overlay, resolve_roi +from .prepare import ( + estimate_overlay_thresholds, + prepare_and_validate_geometry, + prepare_geometry, + prepare_geometry_from_arrays, +) from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data from .surf_name import get_surf_name __all__ = [ - 'prepare_geometry', 'estimate_overlay_thresholds', - 'read_geometry', 'read_annot_data', 'read_mgh_data', 'read_morph_data', 'get_surf_name', + 'prepare_geometry', + 'prepare_geometry_from_arrays', + 'prepare_and_validate_geometry', + 'estimate_overlay_thresholds', + 'resolve_mesh', + 'resolve_overlay', + 'resolve_bg_map', + 'resolve_roi', + 'resolve_annot', + 'read_geometry', + 'read_annot_data', + 'read_mgh_data', + 'read_morph_data', + 'get_surf_name', ] diff --git a/whippersnappy/geometry/inputs.py b/whippersnappy/geometry/inputs.py new file mode 100644 index 0000000..35f6ca3 --- /dev/null +++ b/whippersnappy/geometry/inputs.py @@ -0,0 +1,234 @@ +"""Input resolver functions for WhipperSnapPy geometry loading. + +This module is the single source of truth for loading and validating all +user-facing inputs (mesh, overlay, background map, ROI, annotation). No +other module should call ``read_geometry``, ``read_morph_data``, +``read_mgh_data``, ``read_annot_data``, or ``mask_label`` directly — all +calls should go through the resolver functions defined here. +""" + +import os + +import numpy as np + +from ..utils.colormap import mask_label +from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data + + +def resolve_mesh(mesh): + """Resolve a mesh input to ``(vertices, faces)`` numpy arrays. + + Parameters + ---------- + mesh : str or tuple/list of two array-likes + Either a file path to a FreeSurfer-format surface file, or a + two-element sequence whose first element is the vertex coordinate + array and whose second element is the face index array. + + Returns + ------- + vertices : numpy.ndarray + Vertex coordinate array of shape (N, 3), dtype float32. + faces : numpy.ndarray + Triangle face index array of shape (M, 3), dtype uint32. + + Raises + ------ + TypeError + If *mesh* is neither a ``str`` nor a two-element tuple/list. + ValueError + If the resulting arrays do not have the expected shapes. + """ + if isinstance(mesh, str): + v_raw, f_raw = read_geometry(mesh, read_metadata=False) + vertices = np.asarray(v_raw, dtype=np.float32) + faces = np.asarray(f_raw, dtype=np.uint32) + elif isinstance(mesh, (tuple, list)) and len(mesh) == 2: + vertices = np.asarray(mesh[0], dtype=np.float32) + faces = np.asarray(mesh[1], dtype=np.uint32) + else: + raise TypeError( + f"mesh must be a file path (str) or a (vertices, faces) tuple/list, " + f"got {type(mesh).__name__!r}." + ) + + if vertices.ndim != 2 or vertices.shape[1] != 3: + raise ValueError( + f"vertices must be an array of shape (N, 3), got shape {vertices.shape}." + ) + if faces.ndim != 2 or faces.shape[1] != 3: + raise ValueError( + f"faces must be an array of shape (M, 3), got shape {faces.shape}." + ) + return vertices, faces + + +def _load_overlay_from_file(path): + """Load a 1-D per-vertex overlay array from a file path.""" + _, ext = os.path.splitext(path) + if ext == ".mgh": + return read_mgh_data(path) + return read_morph_data(path) + + +def resolve_overlay(overlay, *, n_vertices): + """Resolve an overlay input to a 1-D float32 numpy array, or ``None``. + + Parameters + ---------- + overlay : None, str, or array-like + * ``None`` — no overlay; returns ``None``. + * ``str`` — path to an overlay file (.mgh or FreeSurfer morph format). + * array-like — converted to ``np.float32``; must have shape + ``(n_vertices,)`` when *n_vertices* is not ``None``. + n_vertices : int or None + Expected number of vertices. Shape validation is skipped when + ``None`` (useful for ``estimate_overlay_thresholds``). + + Returns + ------- + numpy.ndarray of shape (n_vertices,) or None + + Raises + ------ + ValueError + If the loaded/converted array does not match *n_vertices*. + """ + if overlay is None: + return None + if isinstance(overlay, str): + arr = _load_overlay_from_file(overlay).astype(np.float32) + else: + arr = np.asarray(overlay, dtype=np.float32) + if n_vertices is not None and arr.shape != (n_vertices,): + raise ValueError( + f"overlay has shape {arr.shape} but mesh has {n_vertices} vertices." + ) + return arr + + +def resolve_bg_map(bg_map, *, n_vertices): + """Resolve a background-map input to a 1-D float32 numpy array, or ``None``. + + Identical logic to :func:`resolve_overlay`. + + Parameters + ---------- + bg_map : None, str, or array-like + Background shading data (typically curvature). + n_vertices : int or None + Expected number of vertices for shape validation. + + Returns + ------- + numpy.ndarray of shape (n_vertices,) or None + """ + if bg_map is None: + return None + if isinstance(bg_map, str): + arr = _load_overlay_from_file(bg_map).astype(np.float32) + else: + arr = np.asarray(bg_map, dtype=np.float32) + if n_vertices is not None and arr.shape != (n_vertices,): + raise ValueError( + f"bg_map has shape {arr.shape} but mesh has {n_vertices} vertices." + ) + return arr + + +def resolve_roi(roi, *, n_vertices): + """Resolve a region-of-interest input to a boolean numpy array, or ``None``. + + The returned boolean array has ``True`` for vertices that are *included* + in the overlay coloring and ``False`` for vertices that fall back to + background (``bg_map``) shading. + + Parameters + ---------- + roi : None, str, or array-like + * ``None`` — no masking; returns ``None``. + * ``str`` — path to a FreeSurfer label file. Vertices listed in the + file are marked ``True``; all others are ``False``. + * array-like — converted to ``np.bool_``; must have shape + ``(n_vertices,)``. + n_vertices : int + Expected number of vertices. + + Returns + ------- + numpy.ndarray of shape (n_vertices,) bool, or None + + Raises + ------ + ValueError + If the resolved array does not match *n_vertices*. + """ + if roi is None: + return None + if isinstance(roi, str): + # Use mask_label to get vertices included in the label (NaN = excluded). + sentinel = np.ones(n_vertices, dtype=np.float32) + masked = mask_label(sentinel, roi) + # Vertices NOT in the label were set to NaN → roi = ~isnan + arr = ~np.isnan(masked) + else: + arr = np.asarray(roi, dtype=bool) + if arr.shape != (n_vertices,): + raise ValueError( + f"roi has shape {arr.shape} but mesh has {n_vertices} vertices." + ) + return arr + + +def resolve_annot(annot, *, n_vertices): + """Resolve an annotation input to ``(labels, ctab, names)`` or ``None``. + + Parameters + ---------- + annot : None, str, or tuple of length 2 or 3 + * ``None`` — no annotation; returns ``None``. + * ``str`` — path to a FreeSurfer .annot file; loaded via + :func:`read_annot_data`. + * 2-tuple ``(labels, ctab)`` — validated and returned as + ``(labels, ctab, None)``. + * 3-tuple ``(labels, ctab, names)`` — validated and returned as-is. + n_vertices : int + Expected number of vertices for shape validation of *labels*. + + Returns + ------- + tuple of (labels, ctab, names) or None + * *labels* — integer array of shape (n_vertices,). + * *ctab* — color table array of shape (n_labels, ≥3). + * *names* — list of label names, or ``None``. + + Raises + ------ + TypeError + If *annot* is not one of the accepted types. + ValueError + If *labels* does not match *n_vertices*. + """ + if annot is None: + return None + if isinstance(annot, str): + labels, ctab, names = read_annot_data(annot) + elif isinstance(annot, (tuple, list)) and len(annot) == 2: + labels = np.asarray(annot[0]) + ctab = np.asarray(annot[1]) + names = None + elif isinstance(annot, (tuple, list)) and len(annot) == 3: + labels = np.asarray(annot[0]) + ctab = np.asarray(annot[1]) + names = annot[2] + else: + raise TypeError( + f"annot must be a file path (str), a (labels, ctab) tuple, " + f"or a (labels, ctab, names) tuple; got {type(annot).__name__!r}." + ) + if labels.shape != (n_vertices,): + raise ValueError( + f"annot labels have shape {labels.shape} but mesh has {n_vertices} vertices." + ) + return labels, ctab, names + diff --git a/whippersnappy/geometry/prepare.py b/whippersnappy/geometry/prepare.py index f4da0ce..d5df718 100644 --- a/whippersnappy/geometry/prepare.py +++ b/whippersnappy/geometry/prepare.py @@ -1,17 +1,20 @@ """Geometry helpers for mesh processing and GPU preparation (prepare.py). -This module contains the primary `prepare_geometry` function used to -normalize meshes, compute normals and assemble vertex arrays for OpenGL. +This module contains the primary geometry-preparation pipeline. The +low-level workhorse is :func:`prepare_geometry_from_arrays` which operates +entirely on numpy arrays. :func:`prepare_geometry` is a thin file-loading +wrapper that delegates to the resolver functions in +:mod:`whippersnappy.geometry.inputs` before calling +:func:`prepare_geometry_from_arrays`. """ -import os import warnings import numpy as np -from ..utils.colormap import binary_color, heat_color, mask_label, mask_sign, rescale_overlay +from ..utils.colormap import binary_color, heat_color, mask_sign, rescale_overlay from ..utils.types import ColorSelection -from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data +from .inputs import resolve_annot, resolve_bg_map, resolve_mesh, resolve_overlay, resolve_roi def normalize_mesh(v, scale=1.0): @@ -64,10 +67,6 @@ def vertex_normals(v, t): cr0 = np.cross(v1mv0, -v0mv2) cr1 = np.cross(v2mv1, -v1mv0) cr2 = np.cross(v0mv2, -v2mv1) - n = np.zeros(v.shape) - #np.add.at(n, t[:, 0], cr0) - #np.add.at(n, t[:, 1], cr1) - #np.add.at(n, t[:, 2], cr2) # Vectorized accumulation using bincount idx = np.concatenate([t[:, 0], t[:, 1], t[:, 2]]) contribs = np.vstack([cr0, cr1, cr2]) @@ -109,8 +108,8 @@ def _estimate_thresholds_from_array(mapdata, minval=None, maxval=None): return minval, maxval -def estimate_overlay_thresholds(overlaypath, minval=None, maxval=None): - """Estimate threshold and saturation values from an overlay file. +def estimate_overlay_thresholds(overlay, minval=None, maxval=None): + """Estimate threshold and saturation values from an overlay file or array. Reads the overlay data and derives ``fmin`` / ``fmax`` from the absolute values without performing any geometry or color work. Both values are @@ -119,8 +118,9 @@ def estimate_overlay_thresholds(overlaypath, minval=None, maxval=None): Parameters ---------- - overlaypath : str - Path to the overlay file (.mgh or FreeSurfer morph format). + overlay : str or array-like + Path to the overlay file (.mgh or FreeSurfer morph format), or a + numpy array / array-like of per-vertex scalar values. minval : float or None, optional If provided, used as-is for the threshold; otherwise estimated as the minimum absolute value in the overlay. @@ -135,52 +135,60 @@ def estimate_overlay_thresholds(overlaypath, minval=None, maxval=None): maxval : float Saturation value (upper bound of the color scale). """ - _, file_extension = os.path.splitext(overlaypath) - if file_extension == ".mgh": - mapdata = read_mgh_data(overlaypath) + if isinstance(overlay, str): + # Use resolve_overlay with n_vertices=None to skip shape validation + overlay_arr = resolve_overlay(overlay, n_vertices=None) else: - mapdata = read_morph_data(overlaypath) - return _estimate_thresholds_from_array(mapdata, minval, maxval) - - -def prepare_geometry( - surfpath, - overlaypath=None, - annotpath=None, - curvpath=None, - labelpath=None, + overlay_arr = np.asarray(overlay) + return _estimate_thresholds_from_array(overlay_arr, minval, maxval) + + +def prepare_geometry_from_arrays( + vertices, + faces, + overlay=None, + annot=None, + ctab=None, + bg_map=None, + roi=None, minval=None, maxval=None, invert=False, scale=1.85, color_mode=ColorSelection.BOTH, ): - """Prepare vertex and color arrays for GPU upload. + """Prepare vertex and color arrays for GPU upload from numpy arrays. - This function loads a surface geometry from ``surfpath``, optionally - loads an overlay (mgh/curv) or annotation (.annot) and produces an - interleaved vertex array containing positions, normals and colors - suitable for uploading to OpenGL (vertex buffer objects). + This is the core geometry preparation function. All inputs must already + be resolved numpy arrays; for file-path support use the thin wrapper + :func:`prepare_geometry`. Parameters ---------- - surfpath : str - Path to the surface file. - overlaypath : str or None, optional - Path to an overlay (mgh/curv) file providing per-vertex scalar - values used for coloring. - annotpath : str or None, optional - Path to a FreeSurfer .annot file for categorical labeling. - curvpath : str or None, optional - Path to curvature data used as fallback texture. - labelpath : str or None, optional - Path to a label file used to mask vertices. + vertices : numpy.ndarray + Vertex coordinate array of shape (N, 3), dtype float32. + faces : numpy.ndarray + Triangle index array of shape (M, 3), dtype uint32. + overlay : numpy.ndarray or None, optional + Per-vertex scalar values of shape (N,) float32 used for coloring. + annot : numpy.ndarray or None, optional + Per-vertex integer label indices of shape (N,) int32. + ctab : numpy.ndarray or None, optional + Color table array (n_labels, ≥3) associated with *annot*. + bg_map : numpy.ndarray or None, optional + Per-vertex scalar values of shape (N,) float32 whose sign determines + background shading (binary light/dark). When ``None`` a flat gray + background is used. + roi : numpy.ndarray of bool or None, optional + Boolean mask of shape (N,). ``True`` = vertex is inside the region + of interest and receives overlay coloring; ``False`` = vertex falls + back to background shading. When ``None`` all vertices are in-ROI. minval, maxval : float or None, optional Threshold and saturation values for overlay scaling. invert : bool, optional, default False Invert color mapping. scale : float, optional, default 1.85 - Geometry scaling factor applied by ``normalize_mesh``. + Geometry scaling factor applied by :func:`normalize_mesh`. color_mode : ColorSelection, optional, default ColorSelection.BOTH Which sign(s) of overlay values to use for coloring. @@ -200,120 +208,197 @@ def prepare_geometry( ValueError If overlay or annotation arrays do not match the surface vertex count. """ - surf = read_geometry(surfpath, read_metadata=False) - vertices = normalize_mesh(np.array(surf[0], dtype=np.float32), scale) - triangles = np.array(surf[1], dtype=np.uint32) + vertices = normalize_mesh(np.array(vertices, dtype=np.float32), scale) + triangles = np.array(faces, dtype=np.uint32) vnormals = np.array(vertex_normals(vertices, triangles), dtype=np.float32) num_vertices = vertices.shape[0] - # try to load sulcal colormap - sulcmap = 0.5 * np.ones(vertices.shape, dtype=np.float32) - if curvpath: - curv = read_morph_data(curvpath) - if curv.shape[0] != num_vertices: - warnings.warn(f"Curvature file {curvpath} has {curv.shape[0]} values, but mesh has {num_vertices}.", - stacklevel=2) + # Build background (sulcal) colormap + if bg_map is not None: + if bg_map.shape[0] != num_vertices: + warnings.warn( + f"bg_map has {bg_map.shape[0]} values but mesh has {num_vertices}.", + stacklevel=2, + ) + sulcmap = 0.5 * np.ones(vertices.shape, dtype=np.float32) else: - sulcmap = binary_color(curv, 0.0, color_low=0.5, color_high=0.33) + sulcmap = binary_color(bg_map, 0.0, color_low=0.5, color_high=0.33) + else: + sulcmap = 0.5 * np.ones(vertices.shape, dtype=np.float32) # Initialize defaults for overlay outputs fmin = None fmax = None pos = None neg = None - colors = sulcmap # use as default - - # try to load overlay data - if overlaypath: - _, file_extension = os.path.splitext(overlaypath) - if file_extension == ".mgh": - mapdata = read_mgh_data(overlaypath) - else: - mapdata = read_morph_data(overlaypath) + colors = sulcmap # use as default - # Check if overlay length matches number of vertices. If not, raise an error. - if mapdata.shape[0] != num_vertices: + # Apply overlay coloring + if overlay is not None: + if overlay.shape[0] != num_vertices: raise ValueError( - f"Overlay file {overlaypath} has {mapdata.shape[0]} values but mesh has {num_vertices}.\n" + f"overlay has {overlay.shape[0]} values but mesh has {num_vertices}.\n" "This usually means the overlay does not match the provided surface " - "(e.g. RH overlay used with LH surface). Provide the correct overlay " - "file." + "(e.g. RH overlay used with LH surface). Provide the correct overlay." ) - else: - minval, maxval = _estimate_thresholds_from_array(mapdata, minval, maxval) - - mapdata = mask_sign(mapdata, color_mode) - mapdata, fmin, fmax, pos, neg = rescale_overlay(mapdata, minval, maxval) - colors = heat_color(mapdata, invert) - # some mapdata values could be nan (below min threshold) - missing = np.isnan(mapdata) - if np.any(missing): - colors[missing, :] = sulcmap[missing, :] - # alternatively try to load annotation data - elif annotpath: - # Read annotation (per-vertex labels) and colormap table. - annot, ctab, _ = read_annot_data(annotpath) - - # Check if annotation length matches number of vertices. If not, raise an error. + mapdata = overlay.copy().astype(np.float64) + minval, maxval = _estimate_thresholds_from_array(mapdata, minval, maxval) + mapdata = mask_sign(mapdata, color_mode) + mapdata, fmin, fmax, pos, neg = rescale_overlay(mapdata, minval, maxval) + colors = heat_color(mapdata, invert) + # Some mapdata values could be nan (below min threshold) — fall back to bg + missing = np.isnan(mapdata) + if np.any(missing): + colors[missing, :] = sulcmap[missing, :] + + elif annot is not None and ctab is not None: + # Per-vertex annotation coloring if annot.shape[0] != num_vertices: raise ValueError( - f"Annotation file {annotpath} has {annot.shape[0]} values but mesh has {num_vertices}.\n" + f"annot has {annot.shape[0]} values but mesh has {num_vertices}.\n" "This usually means the .annot does not match the provided surface " - "(e.g. RH annot used with LH surface). Provide the correct annot " - "file." + "(e.g. RH annot used with LH surface). Provide the correct annot file." ) - else: - # Ensure integer type for safe indexing - annot = annot.astype(np.int32) - - # Start with sulcmap as the default and only overwrite valid label indices - colors = np.array(sulcmap, dtype=np.float32) - - # Normalize colortable: detect whether ctab is 0-255 or 0-1 - ctab_rgb = ctab[:, 0:3].astype(np.float32) - denom = 255.0 if np.max(ctab_rgb) > 1 else 1.0 - - # Only assign colors for valid annotation indices (>=0 and within the color table) - valid = (annot >= 0) & (annot < ctab.shape[0]) - if np.any(valid): - colors[valid, :] = ctab_rgb[annot[valid], :] / denom + annot = annot.astype(np.int32) + colors = np.array(sulcmap, dtype=np.float32) + ctab_rgb = np.asarray(ctab[:, 0:3], dtype=np.float32) + denom = 255.0 if np.max(ctab_rgb) > 1 else 1.0 + valid = (annot >= 0) & (annot < ctab.shape[0]) + if np.any(valid): + colors[valid, :] = ctab_rgb[annot[valid], :] / denom # Ensure colors dtype matches vertices/normals colors = np.asarray(colors, dtype=np.float32) - # Apply label mask to colors if labelpath is provided, - # regardless of whether overlay or annot data was loaded - if labelpath: - mask = np.isnan(mask_label(np.ones(num_vertices), labelpath)) - if np.any(mask): - colors[mask, :] = sulcmap[mask, :] + # Apply ROI mask: vertices where roi == False fall back to sulcmap + if roi is not None: + outside = ~roi + if np.any(outside): + colors[outside, :] = sulcmap[outside, :] vertexdata = np.concatenate((vertices, vnormals, colors), axis=1) return vertexdata, triangles, fmin, fmax, pos, neg +def prepare_geometry( + mesh, + overlay=None, + annot=None, + bg_map=None, + roi=None, + minval=None, + maxval=None, + invert=False, + scale=1.85, + color_mode=ColorSelection.BOTH, +): + """Prepare vertex and color arrays for GPU upload. + + This is a thin file-loading wrapper around + :func:`prepare_geometry_from_arrays`. Inputs are resolved via the + functions in :mod:`whippersnappy.geometry.inputs` so that every + parameter can be either a file path or a numpy array. + + Parameters + ---------- + mesh : str or tuple of (array-like, array-like) + Surface file path (FreeSurfer format) **or** a ``(vertices, faces)`` + tuple/list where *vertices* is (N, 3) float and *faces* is (M, 3) int. + overlay : str, array-like, or None, optional + Path to an overlay (.mgh / FreeSurfer morph) file, or a (N,) array + of per-vertex scalar values. + annot : str, tuple, or None, optional + Path to a FreeSurfer .annot file, or a ``(labels, ctab)`` / + ``(labels, ctab, names)`` tuple. + bg_map : str, array-like, or None, optional + Path to a curvature/morph file used for background shading, or a + (N,) array whose sign determines light/dark shading. + roi : str, array-like, or None, optional + Path to a FreeSurfer label file or a (N,) boolean array. Vertices + with ``True`` receive overlay coloring; others fall back to *bg_map*. + minval, maxval : float or None, optional + Threshold and saturation values for overlay scaling. + invert : bool, optional, default False + Invert color mapping. + scale : float, optional, default 1.85 + Geometry scaling factor applied by :func:`normalize_mesh`. + color_mode : ColorSelection, optional, default ColorSelection.BOTH + Which sign(s) of overlay values to use for coloring. + + Returns + ------- + vertexdata : numpy.ndarray + Nx9 array (position x3, normal x3, color x3) ready for GPU upload. + triangles : numpy.ndarray + Mx3 uint32 triangle index array. + fmin, fmax : float or None + Final threshold and saturation values used for color mapping. + pos, neg : bool or None + Flags indicating whether positive/negative overlay values are present. + + Raises + ------ + TypeError + If *mesh* is not a valid type. + ValueError + If overlay or annotation arrays do not match the surface vertex count. + + Examples + -------- + File-path usage:: + + vdata, tris, fmin, fmax, pos, neg = prepare_geometry( + 'fsaverage/surf/lh.white', + overlay='fsaverage/surf/lh.thickness', + bg_map='fsaverage/surf/lh.curv', + roi='fsaverage/label/lh.cortex.label', + ) + + Array inputs:: + + import numpy as np + v = np.array([[0,0,0],[1,0,0],[0,1,0],[0,0,1]], dtype=np.float32) + f = np.array([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], dtype=np.uint32) + vdata, tris, *_ = prepare_geometry((v, f)) + """ + vertices, faces = resolve_mesh(mesh) + n = vertices.shape[0] + overlay_arr = resolve_overlay(overlay, n_vertices=n) + bg_map_arr = resolve_bg_map(bg_map, n_vertices=n) + roi_arr = resolve_roi(roi, n_vertices=n) + annot_result = resolve_annot(annot, n_vertices=n) + annot_arr = annot_result[0] if annot_result is not None else None + ctab_arr = annot_result[1] if annot_result is not None else None + return prepare_geometry_from_arrays( + vertices, faces, overlay_arr, annot_arr, ctab_arr, + bg_map_arr, roi_arr, minval, maxval, invert, scale, color_mode, + ) + + def prepare_and_validate_geometry( - meshpath, - overlaypath, - annotpath, - curvpath, - labelpath, - fthresh, - fmax, - invert, - scale, - color_mode, + mesh, + overlay=None, + annot=None, + bg_map=None, + roi=None, + fthresh=None, + fmax=None, + invert=False, + scale=1.85, + color_mode=ColorSelection.BOTH, ): """Load and validate mesh geometry and overlay/annotation inputs. - This is a small wrapper around :func:`prepare_geometry` - that performs the same overlay presence validation used throughout the - static snapshot helpers. + This is a small wrapper around :func:`prepare_geometry` that performs + the same overlay-presence validation used throughout the static snapshot + helpers. Parameters ---------- - meshpath, overlaypath, annotpath, curvpath, labelpath : str or None - Paths passed through to :func:`prepare_geometry`. + mesh : str or tuple + Passed through to :func:`prepare_geometry`. + overlay, annot, bg_map, roi : str, array-like, or None + Passed through to :func:`prepare_geometry`. fthresh, fmax : float or None Threshold and saturation values passed to the geometry preparer. invert : bool @@ -337,11 +422,11 @@ def prepare_and_validate_geometry( import logging logger = logging.getLogger(__name__) meshdata, triangles, out_fthresh, out_fmax, pos, neg = prepare_geometry( - meshpath, - overlaypath, - annotpath, - curvpath, - labelpath, + mesh, + overlay, + annot, + bg_map, + roi, fthresh, fmax, invert, @@ -350,7 +435,7 @@ def prepare_and_validate_geometry( ) # Validate overlay presence similar to previous inline checks - if overlaypath is not None: + if overlay is not None: if color_mode == ColorSelection.POSITIVE: if not pos and neg: logger.error("Overlay has no values to display with positive color_mode") diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py index 4be6af9..f43c684 100644 --- a/whippersnappy/plot3d.py +++ b/whippersnappy/plot3d.py @@ -8,7 +8,7 @@ Usage: from whippersnappy import plot3d - viewer = plot3d(meshpath='path/to/lh.white', curvpath='path/to/lh.curv') + viewer = plot3d(mesh='path/to/lh.white', bg_map='path/to/lh.curv') display(viewer) Dependencies: @@ -33,11 +33,11 @@ def plot3d( - meshpath, - overlaypath=None, - annotpath=None, - curvpath=None, - labelpath=None, + mesh, + overlay=None, + annot=None, + bg_map=None, + roi=None, minval=None, maxval=None, invert=False, @@ -56,16 +56,21 @@ def plot3d( Parameters ---------- - meshpath : str - Path to the surface file (FreeSurfer-style surface, e.g. "lh.white"). - overlaypath : str or None, optional - Path to a per-vertex overlay (thickness/curvature) file. - annotpath : str or None, optional - Path to a FreeSurfer .annot file for categorical labeling. - curvpath : str or None, optional - Path to a curvature file used as grayscale texture for unlabeled regions. - labelpath : str or None, optional - Path to a label file used to mask out vertices. + mesh : str or tuple of (array-like, array-like) + Path to the surface file (FreeSurfer-style surface, e.g. ``"lh.white"``) + **or** a ``(vertices, faces)`` tuple. + overlay : str, array-like, or None, optional + Path to a per-vertex overlay (thickness/curvature) file, or a (N,) + array of per-vertex scalar values. + annot : str, tuple, or None, optional + Path to a FreeSurfer .annot file, or a ``(labels, ctab)`` / + ``(labels, ctab, names)`` tuple for categorical labeling. + bg_map : str, array-like, or None, optional + Path to a curvature file **or** a (N,) array used as grayscale + texture for non-overlay regions. + roi : str, array-like, or None, optional + Path to a FreeSurfer label file **or** a (N,) boolean array to + restrict overlay coloring. minval, maxval : float or None, optional Threshold and saturation values used for color mapping (passed to :func:`prepare_geometry`). If ``None``, sensible defaults are chosen. @@ -98,13 +103,14 @@ def plot3d( from whippersnappy import plot3d from IPython.display import display - viewer = plot3d('fsaverage/surf/lh.white', overlaypath='fsaverage/surf/lh.thickness') + + viewer = plot3d('fsaverage/surf/lh.white', overlay='fsaverage/surf/lh.thickness') display(viewer) """ # Load and prepare mesh data color_mode = color_mode or ColorSelection.BOTH meshdata, triangles, fmin, fmax, pos, neg = prepare_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, + mesh, overlay, annot, bg_map, roi, minval, maxval, invert, scale, color_mode ) @@ -159,9 +165,9 @@ def plot3d( Drag to rotate, scroll to zoom, right-drag to pan
""" - if overlaypath or annotpath: + if overlay or annot: info_text += "
📊 Overlay/annotation loaded" - elif curvpath: + elif bg_map: info_text += "
🧠 Curvature (grayscale is correct)" info_text += "" diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index 7095d6c..52efa0d 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -20,12 +20,12 @@ def snap1( - meshpath, + mesh, outpath=None, - overlaypath=None, - annotpath=None, - labelpath=None, - curvpath=None, + overlay=None, + annot=None, + bg_map=None, + roi=None, view=ViewType.LEFT, viewmat=None, width=700, @@ -58,19 +58,26 @@ def snap1( Parameters ---------- - meshpath : str - Path to the surface file (FreeSurfer-format, e.g. "lh.white"). + mesh : str or tuple of (array-like, array-like) + Path to the surface file (FreeSurfer-format, e.g. ``"lh.white"``) **or** + a ``(vertices, faces)`` tuple where *vertices* is (N, 3) float and + *faces* is (M, 3) int. outpath : str or None, optional When provided, the resulting image is saved to this path. - overlaypath : str or None, optional - Path to overlay/mgh file providing per-vertex values to color the - surface. If ``None``, coloring falls back to curvature/annotation. - annotpath : str or None, optional - Path to a FreeSurfer .annot file with per-vertex labels. - labelpath : str or None, optional - Path to a label file (cortex.label) used to mask overlay values. - curvpath : str or None, optional - Path to curvature file used to texture non-colored regions. + overlay : str, array-like, or None, optional + Overlay file path (.mgh or FreeSurfer morph) **or** a (N,) array of + per-vertex scalar values. If ``None``, coloring falls back to + background shading / annotation. + annot : str, tuple, or None, optional + Path to a FreeSurfer .annot file **or** a ``(labels, ctab)`` / + ``(labels, ctab, names)`` tuple with per-vertex labels. + bg_map : str, array-like, or None, optional + Path to a curvature/morph file **or** a (N,) array whose sign + determines light/dark background shading for non-overlay vertices. + roi : str, array-like, or None, optional + Path to a FreeSurfer label file **or** a (N,) boolean array. + Vertices with ``True`` receive overlay coloring; others fall back + to *bg_map* shading. view : ViewType, optional Which pre-defined view to render (left, right, front, ...). Default is ``ViewType.LEFT``. viewmat : 4x4 matrix-like, optional @@ -119,8 +126,16 @@ def snap1( Examples -------- >>> from whippersnappy import snap1 - >>> img = snap1('fsaverage/surf/lh.white', overlaypath='fsaverage/surf/lh.thickness') + >>> img = snap1('fsaverage/surf/lh.white', overlay='fsaverage/surf/lh.thickness', + ... bg_map='fsaverage/surf/lh.curv', roi='fsaverage/label/lh.cortex.label') >>> img.save('/tmp/lh.png') + + Array inputs:: + + >>> import numpy as np + >>> v = np.random.randn(100, 3).astype(np.float32) + >>> f = np.array([[0, 1, 2]], dtype=np.uint32) + >>> img = snap1((v, f)) """ ref_width = 700 ref_height = 500 @@ -153,11 +168,11 @@ def snap1( window = create_window_with_fallback(brain_display_width, brain_display_height, "WhipperSnapPy", visible=True) try: meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( - meshpath, - overlaypath, - annotpath, - curvpath, - labelpath, + mesh, + overlay, + annot, + bg_map, + roi, fthresh, fmax, invert, @@ -182,7 +197,7 @@ def snap1( create_colorbar( fthresh, fmax, invert, orientation, colorbar_scale * ui_scale, pos, neg, font_file=font_file ) - if overlaypath is not None and colorbar + if overlay is not None and colorbar else None ) font = ( @@ -241,18 +256,18 @@ def snap1( def snap4( - lhoverlaypath=None, - rhoverlaypath=None, - lhannotpath=None, - rhannotpath=None, + lh_overlay=None, + rh_overlay=None, + lh_annot=None, + rh_annot=None, fthresh=None, fmax=None, sdir=None, caption=None, invert=False, - labelname="cortex.label", + roi_name="cortex.label", surfname=None, - curvname="curv", + bg_map_name="curv", colorbar=True, outpath=None, font_file=None, @@ -261,9 +276,9 @@ def snap4( brain_scale=1.85, color_mode=ColorSelection.BOTH, ): - """Render four snapshot views (left/right hemispheres, front/back). + """Render four snapshot views (left/right hemispheres, lateral/medial). - This convenience function renders four views (top/bottom for each + This convenience function renders four views (lateral/medial for each hemisphere), stitches them together into a single PIL Image and returns it (and saves it to ``outpath`` when provided). It is typically used to produce publication-ready overview figures composed from both @@ -271,27 +286,34 @@ def snap4( Parameters ---------- - lhoverlaypath, rhoverlaypath : str or None - Paths to left/right hemisphere overlay files (mutually required if - either is provided). - lhannotpath, rhannotpath : str or None - Paths to left/right hemisphere annotation (.annot) files. + lh_overlay, rh_overlay : str, array-like, or None + Left/right hemisphere overlay — either a file path (FreeSurfer morph + or .mgh) or a per-vertex scalar array. Mutually required if either + is provided. + lh_annot, rh_annot : str, tuple, or None + Left/right hemisphere annotation — either a path to a .annot file or + a ``(labels, ctab)`` / ``(labels, ctab, names)`` tuple. fthresh, fmax : float or None - Threshold and saturation for overlay coloring. + Threshold and saturation for overlay coloring. Auto-estimated when + ``None``. sdir : str or None - Subject directory (used when surfname is not provided). If not - supplied the environment variable ``SUBJECTS_DIR`` is consulted. + Subject directory containing ``surf/`` and ``label/`` subdirectories. + Falls back to ``$SUBJECTS_DIR`` when ``None``. caption : str or None Caption string to place on the final image. invert : bool, optional Invert color scale. Default is ``False``. - labelname : str, optional - Name of the label file (default 'cortex.label'). + roi_name : str, optional + Basename of the label file used to restrict overlay coloring (default + ``'cortex.label'``). The full path is constructed as + ``/label/.``. surfname : str or None, optional - Surface basename to load (if None the function will auto-discover a - suitable surface). - curvname : str or None, optional - Curvature file basename to load for texturing non-colored regions. Default is ``curv``. + Surface basename to load (e.g. ``'white'``); auto-detected when + ``None``. + bg_map_name : str, optional + Basename of the curvature/morph file used for background shading + (default ``'curv'``). The full path is constructed as + ``/surf/.``. colorbar : bool, optional Whether to draw a colorbar on the composed image. Default is ``True``. outpath : str or None, optional @@ -324,10 +346,10 @@ def snap4( -------- >>> from whippersnappy import snap4 >>> img = snap4( - >>> lhoverlaypath='fsaverage/surf/lh.thickness', - >>> rhoverlaypath='fsaverage/surf/rh.thickness', - >>> sdir='./fsaverage' - >>> ) + ... lh_overlay='fsaverage/surf/lh.thickness', + ... rh_overlay='fsaverage/surf/rh.thickness', + ... sdir='./fsaverage' + ... ) >>> img.save('/tmp/whippersnappy_overview.png') """ wwidth = 540 @@ -342,15 +364,15 @@ def snap4( raise ValueError("No sdir or SUBJECTS_DIR provided") if not sdir and surfname is not None: logger.error("surfname provided but sdir is None") - raise ValueError("surfname provided but sdir is None; cannot construct meshpath.") + raise ValueError("surfname provided but sdir is None; cannot construct mesh path.") # Pre-pass: estimate missing fthresh/fmax from overlays for global color scale - has_overlay = lhoverlaypath is not None or rhoverlaypath is not None + has_overlay = lh_overlay is not None or rh_overlay is not None if has_overlay and (fthresh is None or fmax is None): est_fthreshs = [] est_fmaxs = [] - for overlaypath in filter(None, (lhoverlaypath, rhoverlaypath)): - h_fthresh, h_fmax = estimate_overlay_thresholds(overlaypath, fthresh, fmax) + for _overlay in filter(None, (lh_overlay, rh_overlay)): + h_fthresh, h_fmax = estimate_overlay_thresholds(_overlay, fthresh, fmax) est_fthreshs.append(h_fthresh) est_fmaxs.append(h_fmax) if fthresh is None and est_fthreshs: @@ -379,33 +401,38 @@ def snap4( if found_surfname is None: logger.error("Could not find valid surface in %s for hemi: %s!", sdir, hemi) raise FileNotFoundError(f"Could not find valid surface in {sdir} for hemi: {hemi}") - meshpath = os.path.join(sdir, "surf", hemi + "." + found_surfname) + mesh = os.path.join(sdir, "surf", hemi + "." + found_surfname) else: - meshpath = os.path.join(sdir, "surf", hemi + "." + surfname) + mesh = os.path.join(sdir, "surf", hemi + "." + surfname) + + # Assign derived paths for bg_map and roi + bg_map = os.path.join(sdir, "surf", hemi + "." + bg_map_name) if bg_map_name else None + roi = os.path.join(sdir, "label", hemi + "." + roi_name) if roi_name else None + overlay = lh_overlay if hemi == "lh" else rh_overlay + annot = lh_annot if hemi == "lh" else rh_annot - # Assign derived paths - curvpath = os.path.join(sdir, "surf", hemi + "." + curvname) if curvname else None - labelpath = os.path.join(sdir, "label", hemi + "." + labelname) if labelname else None - overlaypath = lhoverlaypath if hemi == "lh" else rhoverlaypath - annotpath = lhannotpath if hemi == "lh" else rhannotpath + # If overlay is an array, it doesn't have a path to log; handle gracefully + if isinstance(overlay, str): + logger.debug("overlay=%s exists=%s", overlay, os.path.exists(overlay)) + elif overlay is not None: + logger.debug("overlay=", getattr(overlay, 'shape', None)) # Diagnostic: report mesh and overlay paths and whether they exist logger.debug("hemisphere=%s", hemi) - logger.debug("meshpath=%s exists=%s", meshpath, os.path.exists(meshpath)) - if overlaypath is not None: - logger.debug("overlaypath=%s exists=%s", overlaypath, os.path.exists(overlaypath)) - if annotpath is not None: - logger.debug("annotpath=%s exists=%s", annotpath, os.path.exists(annotpath)) - if curvpath is not None: - logger.debug("curvpath=%s exists=%s", curvpath, os.path.exists(curvpath)) + if isinstance(mesh, str): + logger.debug("mesh=%s exists=%s", mesh, os.path.exists(mesh)) + if isinstance(annot, str) and annot is not None: + logger.debug("annot=%s exists=%s", annot, os.path.exists(annot)) + if bg_map is not None: + logger.debug("bg_map=%s exists=%s", bg_map, os.path.exists(bg_map)) try: meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( - meshpath, overlaypath, annotpath, curvpath, labelpath, fthresh, fmax, invert, + mesh, overlay, annot, bg_map, roi, fthresh, fmax, invert, scale=brain_scale, color_mode=color_mode ) except Exception as e: - logger.error("prepare_geometry failed for %s: %s", meshpath, e) + logger.error("prepare_geometry failed for %s: %s", mesh, e) raise # Diagnostics about mesh data @@ -456,7 +483,7 @@ def snap4( caption_y = image.height - bottom_pad - text_h bar = ( create_colorbar(fthresh, fmax, invert, pos=pos, neg=neg) - if lhannotpath is None and rhannotpath is None and colorbar + if lh_annot is None and rh_annot is None and colorbar else None ) bar_h = bar.height if bar is not None else 0 @@ -481,16 +508,16 @@ def snap4( def snap_rotate( - meshpath, + mesh, outpath, n_frames=72, fps=24, width=700, height=500, - overlaypath=None, - curvpath=None, - annotpath=None, - labelpath=None, + overlay=None, + bg_map=None, + annot=None, + roi=None, fthresh=None, fmax=None, invert=False, @@ -512,8 +539,9 @@ def snap_rotate( Parameters ---------- - meshpath : str - Path to the surface file (FreeSurfer binary format, e.g. ``lh.white``). + mesh : str or tuple of (array-like, array-like) + Path to the surface file (FreeSurfer binary format, e.g. ``lh.white``) + **or** a ``(vertices, faces)`` tuple. outpath : str Destination file path. The extension controls the output format: @@ -528,14 +556,14 @@ def snap_rotate( Output frame rate in frames per second. Default is ``24``. width, height : int, optional Render resolution in pixels. Defaults are ``700`` and ``500``. - overlaypath : str or None, optional - Path to per-vertex overlay file (e.g. thickness). - curvpath : str or None, optional - Path to curvature file for texturing non-colored regions. - annotpath : str or None, optional - Path to FreeSurfer ``.annot`` file. - labelpath : str or None, optional - Path to label file used to mask overlay values. + overlay : str, array-like, or None, optional + Per-vertex overlay file path or array (e.g. thickness). + bg_map : str, array-like, or None, optional + Curvature/morph file path or array for background shading. + annot : str, tuple, or None, optional + FreeSurfer ``.annot`` file path or ``(labels, ctab)`` tuple. + roi : str, array-like, or None, optional + Label file path or boolean array to restrict overlay coloring. fthresh : float or None, optional Overlay threshold value. fmax : float or None, optional @@ -576,7 +604,7 @@ def snap_rotate( >>> snap_rotate( ... 'fsaverage/surf/lh.white', ... '/tmp/rotation.mp4', - ... overlaypath='fsaverage/surf/lh.thickness', + ... overlay='fsaverage/surf/lh.thickness', ... ) '/tmp/rotation.mp4' """ @@ -607,11 +635,11 @@ def snap_rotate( window = create_window_with_fallback(width, height, "WhipperSnapPy", visible=True) try: meshdata, triangles, fthresh, fmax, pos, neg = prepare_and_validate_geometry( - meshpath, - overlaypath, - annotpath, - curvpath, - labelpath, + mesh, + overlay, + annot, + bg_map, + roi, fthresh, fmax, invert, From 26f24f93a3d8e8ed35151966aa4bbe7725b5481c Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 15:04:39 +0100 Subject: [PATCH 58/83] fix ruff --- tests/test_array_inputs.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_array_inputs.py b/tests/test_array_inputs.py index 2a1757c..1c557a2 100644 --- a/tests/test_array_inputs.py +++ b/tests/test_array_inputs.py @@ -16,13 +16,11 @@ resolve_roi, ) from whippersnappy.geometry.prepare import ( - _estimate_thresholds_from_array, estimate_overlay_thresholds, prepare_geometry, prepare_geometry_from_arrays, ) - # --------------------------------------------------------------------------- # Minimal synthetic mesh (tetrahedron) # --------------------------------------------------------------------------- From ddf9a95807ca59d3d8acd8a2488189115267b0cf Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 15:21:59 +0100 Subject: [PATCH 59/83] add import for OFF, PLY and VTK ASCII triangle meshes --- tests/data/tetra.off | 12 + tests/test_mesh_io.py | 397 +++++++++++++++++++++++ whippersnappy/geometry/__init__.py | 5 + whippersnappy/geometry/inputs.py | 35 +- whippersnappy/geometry/mesh_io.py | 501 +++++++++++++++++++++++++++++ 5 files changed, 943 insertions(+), 7 deletions(-) create mode 100644 tests/data/tetra.off create mode 100644 tests/test_mesh_io.py create mode 100644 whippersnappy/geometry/mesh_io.py diff --git a/tests/data/tetra.off b/tests/data/tetra.off new file mode 100644 index 0000000..def3b95 --- /dev/null +++ b/tests/data/tetra.off @@ -0,0 +1,12 @@ +OFF +# Minimal tetrahedron — 4 vertices, 4 triangular faces +4 4 6 +0.0 0.0 0.0 +1.0 0.0 0.0 +0.0 1.0 0.0 +0.0 0.0 1.0 +3 0 2 1 +3 0 1 3 +3 0 3 2 +3 1 2 3 + diff --git a/tests/test_mesh_io.py b/tests/test_mesh_io.py new file mode 100644 index 0000000..11a79db --- /dev/null +++ b/tests/test_mesh_io.py @@ -0,0 +1,397 @@ +"""Tests for whippersnappy/geometry/mesh_io.py and the updated resolve_mesh. + +All tests use in-memory strings written to temporary files so no external +data is required (except the bundled tetra.off sample). +""" + +import os +import tempfile + +import numpy as np +import pytest + +from whippersnappy.geometry.inputs import resolve_mesh +from whippersnappy.geometry.mesh_io import ( + read_mesh, + read_off, + read_ply_ascii, + read_vtk_ascii_polydata, +) + +# --------------------------------------------------------------------------- +# Shared sample content strings +# --------------------------------------------------------------------------- + +_TETRA_OFF = """\ +OFF +# tetrahedron +4 4 6 +0.0 0.0 0.0 +1.0 0.0 0.0 +0.0 1.0 0.0 +0.0 0.0 1.0 +3 0 2 1 +3 0 1 3 +3 0 3 2 +3 1 2 3 +""" + +_TETRA_VTK = """\ +# vtk DataFile Version 3.0 +tetrahedron +ASCII +DATASET POLYDATA +POINTS 4 float +0.0 0.0 0.0 +1.0 0.0 0.0 +0.0 1.0 0.0 +0.0 0.0 1.0 +POLYGONS 4 16 +3 0 2 1 +3 0 1 3 +3 0 3 2 +3 1 2 3 +""" + +_TETRA_PLY = """\ +ply +format ascii 1.0 +comment tetrahedron +element vertex 4 +property float x +property float y +property float z +element face 4 +property list uchar int vertex_indices +end_header +0.0 0.0 0.0 +1.0 0.0 0.0 +0.0 1.0 0.0 +0.0 0.0 1.0 +3 0 2 1 +3 0 1 3 +3 0 3 2 +3 1 2 3 +""" + + +def _write_tmp(content, suffix): + """Write *content* to a named temp file, return its path.""" + fd, path = tempfile.mkstemp(suffix=suffix) + with os.fdopen(fd, "w") as fh: + fh.write(content) + return path + + +def _expected_verts(): + return np.array( + [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32 + ) + + +def _expected_faces(): + return np.array( + [[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]], dtype=np.uint32 + ) + + +# --------------------------------------------------------------------------- +# read_off +# --------------------------------------------------------------------------- + +class TestReadOff: + def test_basic(self): + path = _write_tmp(_TETRA_OFF, ".off") + try: + v, f = read_off(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + assert v.dtype == np.float32 + assert f.shape == (4, 3) + assert f.dtype == np.uint32 + np.testing.assert_array_equal(v, _expected_verts()) + np.testing.assert_array_equal(f, _expected_faces()) + + def test_bundled_sample(self): + """Verify the bundled tests/data/tetra.off file loads correctly.""" + here = os.path.dirname(__file__) + sample = os.path.join(here, "data", "tetra.off") + v, f = read_off(sample) + assert v.shape == (4, 3) + assert f.shape == (4, 3) + + def test_bad_header_raises(self): + content = "NOFF\n4 4 6\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n3 0 1 2\n3 0 1 3\n3 0 2 3\n3 1 2 3\n" + path = _write_tmp(content, ".off") + try: + with pytest.raises(ValueError, match="OFF"): + read_off(path) + finally: + os.unlink(path) + + def test_quad_face_raises(self): + content = "OFF\n4 1 4\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n4 0 1 2 3\n" + path = _write_tmp(content, ".off") + try: + with pytest.raises(ValueError, match="triangles"): + read_off(path) + finally: + os.unlink(path) + + def test_out_of_range_indices_raises(self): + content = "OFF\n3 1 3\n0 0 0\n1 0 0\n0 1 0\n3 0 1 99\n" + path = _write_tmp(content, ".off") + try: + with pytest.raises(ValueError, match="out of range"): + read_off(path) + finally: + os.unlink(path) + + def test_empty_file_raises(self): + path = _write_tmp("", ".off") + try: + with pytest.raises(ValueError, match="empty"): + read_off(path) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# read_vtk_ascii_polydata +# --------------------------------------------------------------------------- + +class TestReadVtkAsciiPolydata: + def test_basic(self): + path = _write_tmp(_TETRA_VTK, ".vtk") + try: + v, f = read_vtk_ascii_polydata(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + assert v.dtype == np.float32 + assert f.shape == (4, 3) + assert f.dtype == np.uint32 + np.testing.assert_array_equal(v, _expected_verts()) + np.testing.assert_array_equal(f, _expected_faces()) + + def test_binary_vtk_raises(self): + content = "# vtk DataFile Version 3.0\ntest\nBINARY\nDATASET POLYDATA\n" + path = _write_tmp(content, ".vtk") + try: + with pytest.raises(ValueError, match="BINARY"): + read_vtk_ascii_polydata(path) + finally: + os.unlink(path) + + def test_non_polydata_raises(self): + content = "# vtk DataFile Version 3.0\ntest\nASCII\nDATASET UNSTRUCTURED_GRID\n" + path = _write_tmp(content, ".vtk") + try: + with pytest.raises(ValueError, match="POLYDATA"): + read_vtk_ascii_polydata(path) + finally: + os.unlink(path) + + def test_quad_polygon_raises(self): + content = ( + "# vtk DataFile Version 3.0\ntest\nASCII\nDATASET POLYDATA\n" + "POINTS 4 float\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n" + "POLYGONS 1 5\n4 0 1 2 3\n" + ) + path = _write_tmp(content, ".vtk") + try: + with pytest.raises(ValueError, match="triangles"): + read_vtk_ascii_polydata(path) + finally: + os.unlink(path) + + def test_missing_points_raises(self): + content = ( + "# vtk DataFile Version 3.0\ntest\nASCII\nDATASET POLYDATA\n" + "POLYGONS 1 4\n3 0 1 2\n" + ) + path = _write_tmp(content, ".vtk") + try: + with pytest.raises(ValueError, match="POINTS"): + read_vtk_ascii_polydata(path) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# read_ply_ascii +# --------------------------------------------------------------------------- + +class TestReadPlyAscii: + def test_basic(self): + path = _write_tmp(_TETRA_PLY, ".ply") + try: + v, f = read_ply_ascii(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + assert v.dtype == np.float32 + assert f.shape == (4, 3) + assert f.dtype == np.uint32 + np.testing.assert_array_equal(v, _expected_verts()) + np.testing.assert_array_equal(f, _expected_faces()) + + def test_extra_vertex_props(self): + """PLY with extra per-vertex properties (e.g. nx ny nz) should still load.""" + content = """\ +ply +format ascii 1.0 +element vertex 3 +property float x +property float y +property float z +property float nx +property float ny +property float nz +element face 1 +property list uchar int vertex_indices +end_header +0.0 0.0 0.0 0.0 0.0 1.0 +1.0 0.0 0.0 0.0 0.0 1.0 +0.0 1.0 0.0 0.0 0.0 1.0 +3 0 1 2 +""" + path = _write_tmp(content, ".ply") + try: + v, f = read_ply_ascii(path) + finally: + os.unlink(path) + assert v.shape == (3, 3) + assert f.shape == (1, 3) + + def test_binary_ply_raises(self): + content = "ply\nformat binary_little_endian 1.0\nelement vertex 4\nend_header\n" + path = _write_tmp(content, ".ply") + try: + with pytest.raises(ValueError, match="binary"): + read_ply_ascii(path) + finally: + os.unlink(path) + + def test_not_ply_raises(self): + content = "OFF\n4 4 6\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n3 0 1 2\n" + path = _write_tmp(content, ".ply") + try: + with pytest.raises(ValueError, match="ply"): + read_ply_ascii(path) + finally: + os.unlink(path) + + def test_quad_face_raises(self): + content = """\ +ply +format ascii 1.0 +element vertex 4 +property float x +property float y +property float z +element face 1 +property list uchar int vertex_indices +end_header +0.0 0.0 0.0 +1.0 0.0 0.0 +0.0 1.0 0.0 +0.0 0.0 1.0 +4 0 1 2 3 +""" + path = _write_tmp(content, ".ply") + try: + with pytest.raises(ValueError, match="triangles"): + read_ply_ascii(path) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# read_mesh dispatcher +# --------------------------------------------------------------------------- + +class TestReadMeshDispatcher: + def test_off_dispatch(self): + path = _write_tmp(_TETRA_OFF, ".off") + try: + v, f = read_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + + def test_vtk_dispatch(self): + path = _write_tmp(_TETRA_VTK, ".vtk") + try: + v, f = read_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + + def test_ply_dispatch(self): + path = _write_tmp(_TETRA_PLY, ".ply") + try: + v, f = read_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + + def test_unknown_extension_raises(self): + with pytest.raises(ValueError, match="Unsupported"): + read_mesh("/some/file.stl") + + def test_case_insensitive_extension(self): + """Uppercase .OFF extension should be recognised.""" + path = _write_tmp(_TETRA_OFF, ".OFF") + try: + v, f = read_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + + +# --------------------------------------------------------------------------- +# resolve_mesh routing +# --------------------------------------------------------------------------- + +class TestResolveMeshRouting: + def test_off_path_routed(self): + path = _write_tmp(_TETRA_OFF, ".off") + try: + v, f = resolve_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + assert v.dtype == np.float32 + assert f.shape == (4, 3) + assert f.dtype == np.uint32 + + def test_vtk_path_routed(self): + path = _write_tmp(_TETRA_VTK, ".vtk") + try: + v, f = resolve_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + + def test_ply_path_routed(self): + path = _write_tmp(_TETRA_PLY, ".ply") + try: + v, f = resolve_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + + def test_array_tuple_still_works(self): + v_in = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) + f_in = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.uint32) + v, f = resolve_mesh((v_in, f_in)) + assert v.shape == (4, 3) + assert f.shape == (4, 3) + + def test_array_out_of_range_raises(self): + v_in = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32) + f_in = np.array([[0, 1, 99]], dtype=np.uint32) # index 99 out of range + with pytest.raises(ValueError, match="out of range"): + resolve_mesh((v_in, f_in)) diff --git a/whippersnappy/geometry/__init__.py b/whippersnappy/geometry/__init__.py index d377e82..d2a57b0 100644 --- a/whippersnappy/geometry/__init__.py +++ b/whippersnappy/geometry/__init__.py @@ -3,6 +3,7 @@ Expose prepare_geometry and small IO helpers under `whippersnappy.geometry`. """ from .inputs import resolve_annot, resolve_bg_map, resolve_mesh, resolve_overlay, resolve_roi +from .mesh_io import read_mesh, read_off, read_ply_ascii, read_vtk_ascii_polydata from .prepare import ( estimate_overlay_thresholds, prepare_and_validate_geometry, @@ -22,6 +23,10 @@ 'resolve_bg_map', 'resolve_roi', 'resolve_annot', + 'read_mesh', + 'read_off', + 'read_vtk_ascii_polydata', + 'read_ply_ascii', 'read_geometry', 'read_annot_data', 'read_mgh_data', diff --git a/whippersnappy/geometry/inputs.py b/whippersnappy/geometry/inputs.py index 35f6ca3..0091938 100644 --- a/whippersnappy/geometry/inputs.py +++ b/whippersnappy/geometry/inputs.py @@ -12,8 +12,12 @@ import numpy as np from ..utils.colormap import mask_label +from .mesh_io import read_mesh as _read_mesh_by_ext from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data +# Extensions handled by the lightweight ASCII mesh readers in mesh_io.py +_MESH_IO_EXTS = frozenset({".off", ".vtk", ".ply"}) + def resolve_mesh(mesh): """Resolve a mesh input to ``(vertices, faces)`` numpy arrays. @@ -21,9 +25,13 @@ def resolve_mesh(mesh): Parameters ---------- mesh : str or tuple/list of two array-likes - Either a file path to a FreeSurfer-format surface file, or a - two-element sequence whose first element is the vertex coordinate - array and whose second element is the face index array. + * ``str`` — path to a mesh file. Files with extensions ``.off``, + ``.vtk``, or ``.ply`` are loaded by the lightweight ASCII readers + in :mod:`whippersnappy.geometry.mesh_io`. All other paths (e.g. + FreeSurfer surfaces such as ``lh.white``) are loaded via + :func:`whippersnappy.geometry.read_geometry`. + * Two-element tuple/list — ``(vertices, faces)`` array-likes + converted to ``float32`` and ``uint32`` numpy arrays respectively. Returns ------- @@ -37,12 +45,17 @@ def resolve_mesh(mesh): TypeError If *mesh* is neither a ``str`` nor a two-element tuple/list. ValueError - If the resulting arrays do not have the expected shapes. + If the resulting arrays do not have the expected shapes or if face + indices are out of range. """ if isinstance(mesh, str): - v_raw, f_raw = read_geometry(mesh, read_metadata=False) - vertices = np.asarray(v_raw, dtype=np.float32) - faces = np.asarray(f_raw, dtype=np.uint32) + ext = os.path.splitext(mesh)[1].lower() + if ext in _MESH_IO_EXTS: + vertices, faces = _read_mesh_by_ext(mesh) + else: + v_raw, f_raw = read_geometry(mesh, read_metadata=False) + vertices = np.asarray(v_raw, dtype=np.float32) + faces = np.asarray(f_raw, dtype=np.uint32) elif isinstance(mesh, (tuple, list)) and len(mesh) == 2: vertices = np.asarray(mesh[0], dtype=np.float32) faces = np.asarray(mesh[1], dtype=np.uint32) @@ -60,6 +73,14 @@ def resolve_mesh(mesh): raise ValueError( f"faces must be an array of shape (M, 3), got shape {faces.shape}." ) + # Bounds check for array inputs (file readers do their own check) + if not isinstance(mesh, str) and faces.size > 0: + n_verts = vertices.shape[0] + if int(faces.max()) >= n_verts or int(faces.min()) < 0: + raise ValueError( + f"Face indices out of range [0, {n_verts}): " + f"min={int(faces.min())}, max={int(faces.max())}." + ) return vertices, faces diff --git a/whippersnappy/geometry/mesh_io.py b/whippersnappy/geometry/mesh_io.py new file mode 100644 index 0000000..fc70ce2 --- /dev/null +++ b/whippersnappy/geometry/mesh_io.py @@ -0,0 +1,501 @@ +"""Lightweight ASCII mesh readers for common open formats. + +This module implements pure-Python (stdlib + numpy only) readers for: + +* **OFF** — Object File Format, ASCII triangles +* **VTK legacy ASCII PolyData** — ``DATASET POLYDATA`` with POINTS/POLYGONS +* **PLY ASCII** — Stanford PLY, ASCII encoding, triangles only + +All readers return ``(vertices, faces)`` where + +* ``vertices`` — ``float32`` array of shape ``(N, 3)`` +* ``faces`` — ``uint32`` array of shape ``(M, 3)`` + +The public dispatcher :func:`read_mesh` routes by file extension +(``.off``, ``.vtk``, ``.ply``). For FreeSurfer surfaces (no standard +extension) use the existing :func:`whippersnappy.geometry.read_geometry` +directly, or go through :func:`whippersnappy.geometry.inputs.resolve_mesh` +which handles the routing automatically. +""" + +import numpy as np + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _non_empty_lines(path): + """Yield stripped, non-empty, non-comment lines from a text file.""" + with open(path, encoding="utf-8", errors="replace") as fh: + for raw in fh: + line = raw.strip() + if line and not line.startswith("#"): + yield line + + +# --------------------------------------------------------------------------- +# OFF reader +# --------------------------------------------------------------------------- + +def read_off(path): + """Read an ASCII OFF (Object File Format) triangle mesh. + + Parameters + ---------- + path : str + Path to the ``.off`` file. + + Returns + ------- + vertices : numpy.ndarray, shape (N, 3), dtype float32 + faces : numpy.ndarray, shape (M, 3), dtype uint32 + + Raises + ------ + ValueError + If the file does not start with ``OFF``, if any face is not a + triangle, or if the declared counts don't match the data. + IOError + If the file cannot be opened. + + Notes + ----- + Comments (lines starting with ``#``) and blank lines are ignored + anywhere in the file. The optional ``COFF`` / ``NOFF`` / ``CNOFF`` + variants are *not* supported; only plain ``OFF`` is accepted. + """ + lines = list(_non_empty_lines(path)) + if not lines: + raise ValueError(f"OFF file is empty: {path!r}") + + # First line must be exactly "OFF" + header = lines[0].upper() + if header != "OFF": + raise ValueError( + f"Expected 'OFF' header on first non-comment line, got {lines[0]!r} " + f"in {path!r}. Only plain ASCII OFF is supported." + ) + + if len(lines) < 2: + raise ValueError(f"OFF file has no count line after header: {path!r}") + + # Second line: n_vertices n_faces n_edges + parts = lines[1].split() + if len(parts) < 2: + raise ValueError( + f"OFF count line must have at least 2 integers (n_vertices n_faces), " + f"got {lines[1]!r} in {path!r}." + ) + try: + n_verts = int(parts[0]) + n_faces = int(parts[1]) + except ValueError as exc: + raise ValueError( + f"Could not parse OFF count line {lines[1]!r} in {path!r}." + ) from exc + + # Validate we have enough data lines + data_lines = lines[2:] + if len(data_lines) < n_verts + n_faces: + raise ValueError( + f"OFF file declares {n_verts} vertices and {n_faces} faces " + f"but only {len(data_lines)} data lines follow in {path!r}." + ) + + # Parse vertices + vertices = np.empty((n_verts, 3), dtype=np.float32) + for i in range(n_verts): + try: + coords = data_lines[i].split()[:3] + vertices[i] = [float(c) for c in coords] + except (ValueError, IndexError) as exc: + raise ValueError( + f"Could not parse vertex {i} in OFF file {path!r}: " + f"{data_lines[i]!r}" + ) from exc + + # Parse faces + faces = np.empty((n_faces, 3), dtype=np.uint32) + for j in range(n_faces): + try: + tokens = data_lines[n_verts + j].split() + n_poly = int(tokens[0]) + except (ValueError, IndexError) as exc: + raise ValueError( + f"Could not parse face {j} in OFF file {path!r}: " + f"{data_lines[n_verts + j]!r}" + ) from exc + if n_poly != 3: + raise ValueError( + f"OFF face {j} has {n_poly} vertices; only triangles (3) are " + f"supported in {path!r}. Convert to a triangle mesh first." + ) + try: + faces[j] = [int(tokens[k]) for k in range(1, 4)] + except (ValueError, IndexError) as exc: + raise ValueError( + f"Could not parse face indices at face {j} in {path!r}: " + f"{data_lines[n_verts + j]!r}" + ) from exc + + # Bounds check + if n_faces > 0: + if int(faces.max()) >= n_verts or int(faces.min()) < 0: + raise ValueError( + f"OFF face indices out of range [0, {n_verts}) in {path!r}." + ) + + return vertices, faces + + +# --------------------------------------------------------------------------- +# Legacy VTK ASCII PolyData reader +# --------------------------------------------------------------------------- + +def read_vtk_ascii_polydata(path): + """Read a legacy ASCII VTK PolyData triangle mesh. + + Only the *ASCII legacy format* with ``DATASET POLYDATA`` is supported. + Binary VTK files are explicitly rejected. + + Parameters + ---------- + path : str + Path to the ``.vtk`` file. + + Returns + ------- + vertices : numpy.ndarray, shape (N, 3), dtype float32 + faces : numpy.ndarray, shape (M, 3), dtype uint32 + + Raises + ------ + ValueError + If the file is binary, is not POLYDATA, contains non-triangle + polygons, or if required sections are missing. + IOError + If the file cannot be opened. + """ + with open(path, encoding="utf-8", errors="replace") as fh: + raw_lines = fh.readlines() + + # Legacy VTK format: + # line 0: "# vtk DataFile Version x.x" + # line 1: title (arbitrary free text) + # line 2: "ASCII" or "BINARY" + # line 3: "DATASET " + if len(raw_lines) < 3: + raise ValueError(f"VTK file too short: {path!r}") + + fmt_line = raw_lines[2].strip().upper() + if "BINARY" in fmt_line: + raise ValueError( + f"Only ASCII legacy VTK POLYDATA is supported; " + f"file appears to be BINARY: {path!r}. " + f"Convert with: vtk-convert or meshio-convert." + ) + if "ASCII" not in fmt_line: + raise ValueError( + f"Could not determine VTK format from line 3 (expected 'ASCII' or " + f"'BINARY'): {raw_lines[2]!r} in {path!r}." + ) + + # Scan for DATASET POLYDATA + dataset_found = False + for line in raw_lines: + if line.strip().upper().startswith("DATASET"): + if "POLYDATA" not in line.upper(): + raise ValueError( + f"Only POLYDATA VTK datasets are supported, " + f"got: {line.strip()!r} in {path!r}." + ) + dataset_found = True + break + if not dataset_found: + raise ValueError(f"No DATASET line found in VTK file {path!r}.") + + # Tokenise everything into a flat list for easy sectioned parsing + lines = [raw_ln.strip() for raw_ln in raw_lines if raw_ln.strip() and not raw_ln.strip().startswith("#")] + + vertices = None + faces = None + i = 0 + while i < len(lines): + upper = lines[i].upper() + + if upper.startswith("POINTS"): + parts = lines[i].split() + n_pts = int(parts[1]) + # Collect 3*n_pts floats; they may span multiple lines + floats = [] + i += 1 + while len(floats) < 3 * n_pts and i < len(lines): + # Stop at next keyword section + if lines[i].upper().split()[0] in ( + "POLYGONS", "LINES", "STRIPS", "VERTICES", + "POINT_DATA", "CELL_DATA", "FIELD", "NORMALS", + "TEXTURE_COORDINATES", "SCALARS", "LOOKUP_TABLE", + ): + break + floats.extend(float(x) for x in lines[i].split()) + i += 1 + if len(floats) < 3 * n_pts: + raise ValueError( + f"Expected {3 * n_pts} floats for POINTS but got " + f"{len(floats)} in {path!r}." + ) + vertices = np.array(floats[: 3 * n_pts], dtype=np.float32).reshape(n_pts, 3) + continue # i already advanced + + elif upper.startswith("POLYGONS"): + parts = lines[i].split() + n_polys = int(parts[1]) + face_list = [] + i += 1 + while len(face_list) < n_polys and i < len(lines): + if lines[i].upper().split()[0] in ( + "POINTS", "LINES", "STRIPS", "VERTICES", + "POINT_DATA", "CELL_DATA", "FIELD", "NORMALS", + "TEXTURE_COORDINATES", "SCALARS", "LOOKUP_TABLE", + ): + break + tokens = lines[i].split() + n_poly = int(tokens[0]) + if n_poly != 3: + raise ValueError( + f"VTK polygon {len(face_list)} has {n_poly} vertices; " + f"only triangles (3) are supported in {path!r}. " + f"Triangulate the mesh before loading." + ) + face_list.append([int(tokens[1]), int(tokens[2]), int(tokens[3])]) + i += 1 + if len(face_list) < n_polys: + raise ValueError( + f"Expected {n_polys} polygons but only parsed " + f"{len(face_list)} in {path!r}." + ) + faces = np.array(face_list, dtype=np.uint32) + continue # i already advanced + + else: + i += 1 + + if vertices is None: + raise ValueError(f"No POINTS section found in VTK file {path!r}.") + if faces is None: + raise ValueError(f"No POLYGONS section found in VTK file {path!r}.") + + # Bounds check + n_verts = vertices.shape[0] + if faces.size > 0 and (int(faces.max()) >= n_verts or int(faces.min()) < 0): + raise ValueError( + f"VTK face indices out of range [0, {n_verts}) in {path!r}." + ) + + return vertices, faces + + +# --------------------------------------------------------------------------- +# PLY ASCII reader +# --------------------------------------------------------------------------- + +def read_ply_ascii(path): + """Read an ASCII PLY triangle mesh. + + Only ASCII PLY files are supported. Binary PLY files are explicitly + rejected with a helpful error message. + + Parameters + ---------- + path : str + Path to the ``.ply`` file. + + Returns + ------- + vertices : numpy.ndarray, shape (N, 3), dtype float32 + faces : numpy.ndarray, shape (M, 3), dtype uint32 + + Raises + ------ + ValueError + If the file is binary PLY, if faces are not triangles, or if the + header is malformed. + IOError + If the file cannot be opened. + """ + with open(path, encoding="utf-8", errors="replace") as fh: + raw_lines = fh.readlines() + + if not raw_lines or raw_lines[0].strip() != "ply": + raise ValueError( + f"File does not start with 'ply' magic; not a PLY file: {path!r}." + ) + + # Check encoding + for line in raw_lines[1:4]: + stripped = line.strip().lower() + if stripped.startswith("format"): + if "ascii" not in stripped: + raise ValueError( + f"PLY binary format not supported; only ASCII PLY is " + f"accepted: {path!r}. " + f"Convert with: plyconvert or meshio-convert." + ) + break + + # Parse header + n_verts = None + n_faces = None + vertex_props = [] # ordered list of property names for the vertex element + in_vertex = False + in_face = False + header_end = 0 + + for idx, line in enumerate(raw_lines): + stripped = line.strip() + lower = stripped.lower() + + if lower == "end_header": + header_end = idx + 1 + break + if lower.startswith("element vertex"): + n_verts = int(stripped.split()[2]) + in_vertex = True + in_face = False + elif lower.startswith("element face"): + n_faces = int(stripped.split()[2]) + in_face = True + in_vertex = False + elif lower.startswith("property") and in_vertex: + # e.g. "property float x" + parts = stripped.split() + if len(parts) >= 3: + vertex_props.append(parts[-1]) + elif lower.startswith("element") and not lower.startswith("element vertex") and not lower.startswith("element face"): + in_vertex = False + + if n_verts is None: + raise ValueError(f"No 'element vertex' found in PLY header: {path!r}.") + if n_faces is None: + raise ValueError(f"No 'element face' found in PLY header: {path!r}.") + + # Determine column indices for x, y, z + try: + xi = vertex_props.index("x") + yi = vertex_props.index("y") + zi = vertex_props.index("z") + except ValueError as exc: + raise ValueError( + f"PLY vertex element missing x/y/z properties in {path!r}; " + f"found: {vertex_props!r}." + ) from exc + + data_lines = [raw_line.strip() for raw_line in raw_lines[header_end:] if raw_line.strip()] + + if len(data_lines) < n_verts + n_faces: + raise ValueError( + f"PLY file has {len(data_lines)} data lines but expects " + f"{n_verts} vertices + {n_faces} faces in {path!r}." + ) + + # Parse vertices + vertices = np.empty((n_verts, 3), dtype=np.float32) + for i in range(n_verts): + try: + tokens = data_lines[i].split() + vertices[i] = [float(tokens[xi]), float(tokens[yi]), float(tokens[zi])] + except (ValueError, IndexError) as exc: + raise ValueError( + f"Could not parse PLY vertex {i} in {path!r}: " + f"{data_lines[i]!r}" + ) from exc + + # Parse faces + faces = np.empty((n_faces, 3), dtype=np.uint32) + for j in range(n_faces): + try: + tokens = data_lines[n_verts + j].split() + count = int(tokens[0]) + except (ValueError, IndexError) as exc: + raise ValueError( + f"Could not parse PLY face {j} in {path!r}: " + f"{data_lines[n_verts + j]!r}" + ) from exc + if count != 3: + raise ValueError( + f"PLY face {j} has {count} vertices; only triangles (3) are " + f"supported in {path!r}. Triangulate the mesh first." + ) + try: + faces[j] = [int(tokens[1]), int(tokens[2]), int(tokens[3])] + except (ValueError, IndexError) as exc: + raise ValueError( + f"Could not parse PLY face indices at face {j} in {path!r}: " + f"{data_lines[n_verts + j]!r}" + ) from exc + + # Bounds check + if n_faces > 0 and (int(faces.max()) >= n_verts or int(faces.min()) < 0): + raise ValueError( + f"PLY face indices out of range [0, {n_verts}) in {path!r}." + ) + + return vertices, faces + + +# --------------------------------------------------------------------------- +# Dispatcher +# --------------------------------------------------------------------------- + +_READERS = { + ".off": read_off, + ".vtk": read_vtk_ascii_polydata, + ".ply": read_ply_ascii, +} + +_SUPPORTED = ", ".join(sorted(_READERS)) + + +def read_mesh(path): + """Read a triangle mesh from an OFF, VTK, or PLY file. + + Dispatches to the appropriate reader based on the file extension. + For FreeSurfer binary surfaces (which typically have no standard + extension, e.g. ``lh.white``) use + :func:`whippersnappy.geometry.read_geometry` directly, or pass the + path through :func:`whippersnappy.geometry.inputs.resolve_mesh` which + handles the routing automatically. + + Parameters + ---------- + path : str + Path to a mesh file. Extension must be one of: + ``.off``, ``.vtk``, ``.ply`` (case-insensitive). + + Returns + ------- + vertices : numpy.ndarray, shape (N, 3), dtype float32 + faces : numpy.ndarray, shape (M, 3), dtype uint32 + + Raises + ------ + ValueError + If the extension is not recognised. + """ + import os + ext = os.path.splitext(path)[1].lower() + reader = _READERS.get(ext) + if reader is None: + raise ValueError( + f"Unsupported mesh file extension {ext!r} for {path!r}. " + f"Supported formats: {_SUPPORTED}. " + f"For FreeSurfer surfaces (no extension) use resolve_mesh() " + f"or read_geometry() directly." + ) + return reader(path) + + + + + From 1b622c3d5f2786ac5f8028a2f238b5bc1883a126 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 16:15:46 +0100 Subject: [PATCH 60/83] update data path --- .github/workflows/doc.yml | 2 +- whippersnappy/utils/datasets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 78c3a1d..73850ce 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -41,7 +41,7 @@ jobs: uses: actions/cache@v4 with: path: ~/.cache/whippersnappy - key: sample-data-sub-rs-data-v1 + key: sample-data-sub-rs-data-v1.0 - name: Build doc run: TZ=UTC sphinx-build ./main/doc ./doc-build/dev -W --keep-going - name: Upload documentation diff --git a/whippersnappy/utils/datasets.py b/whippersnappy/utils/datasets.py index 4a88e1b..7f2293f 100644 --- a/whippersnappy/utils/datasets.py +++ b/whippersnappy/utils/datasets.py @@ -8,7 +8,7 @@ RELEASE_URL = ( "https://github.com/Deep-MI/WhipperSnapPy" - "/releases/download/data-v1/{file_name}" + "/releases/download/data-v1.0/{file_name}" ) # Mapping of relative path inside the subject directory → SHA-256 hash. From 2d89603c2d3ff7f52088f6b513eba99a936b6234 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 16:25:03 +0100 Subject: [PATCH 61/83] ruff fix --- whippersnappy/geometry/mesh_io.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/whippersnappy/geometry/mesh_io.py b/whippersnappy/geometry/mesh_io.py index fc70ce2..c30f04e 100644 --- a/whippersnappy/geometry/mesh_io.py +++ b/whippersnappy/geometry/mesh_io.py @@ -20,7 +20,6 @@ import numpy as np - # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- @@ -349,7 +348,6 @@ def read_ply_ascii(path): n_faces = None vertex_props = [] # ordered list of property names for the vertex element in_vertex = False - in_face = False header_end = 0 for idx, line in enumerate(raw_lines): @@ -362,17 +360,17 @@ def read_ply_ascii(path): if lower.startswith("element vertex"): n_verts = int(stripped.split()[2]) in_vertex = True - in_face = False elif lower.startswith("element face"): n_faces = int(stripped.split()[2]) - in_face = True in_vertex = False elif lower.startswith("property") and in_vertex: # e.g. "property float x" parts = stripped.split() if len(parts) >= 3: vertex_props.append(parts[-1]) - elif lower.startswith("element") and not lower.startswith("element vertex") and not lower.startswith("element face"): + elif (lower.startswith("element") + and not lower.startswith("element vertex") + and not lower.startswith("element face")): in_vertex = False if n_verts is None: From 2f73f2a7cc7f3f408d6f457c0cc030ed0b1bbdff Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 16:27:18 +0100 Subject: [PATCH 62/83] add imports to doc build for rotating gif --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e040501..4c05456 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ doc = [ 'pooch>=1.6', 'pythreejs', 'ipywidgets', + 'imageio>=2.28', ] notebook = [ 'pythreejs', # Three.js for interactive 3D (works in all Jupyter environments) From 4021352aaed0a30ce3d23b54a99b6321fcd93754 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 17:24:37 +0100 Subject: [PATCH 63/83] update doc for any mesh input and also whippersnap GUI CLI --- README.md | 67 ++++-- pyproject.toml | 2 +- whippersnappy/__init__.py | 33 ++- whippersnappy/cli/whippersnap.py | 330 +++++++++++++++++++++--------- whippersnappy/cli/whippersnap1.py | 36 +++- whippersnappy/plot3d.py | 72 ++++--- whippersnappy/snap.py | 83 ++++---- 7 files changed, 410 insertions(+), 213 deletions(-) diff --git a/README.md b/README.md index 9a329cf..dd21af9 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,12 @@ # WhipperSnapPy -WhipperSnapPy is a Python/OpenGL tool to render FreeSurfer and FastSurfer -surface models with color overlays or parcellations and generate screenshots -— from the command line, in Jupyter notebooks, or via a desktop GUI. +WhipperSnapPy is a Python/OpenGL tool to render triangular surface meshes +with color overlays or parcellations and generate screenshots — from the +command line, in Jupyter notebooks, or via a desktop GUI. + +It works with FreeSurfer and FastSurfer brain surfaces as well as any +triangle mesh in OFF, legacy ASCII VTK PolyData, or ASCII PLY format, or +passed directly as a NumPy ``(vertices, faces)`` tuple. ## Installation @@ -50,18 +54,23 @@ whippersnap4 -lh $LH_OVERLAY \ ### Single-view snapshot (`whippersnap1`) -Renders one surface view: +Renders one view of any triangular surface mesh: ```bash whippersnap1 $SUBJECT_DIR/surf/lh.white \ --overlay $LH_OVERLAY \ + --bg-map $SUBJECT_DIR/surf/lh.curv \ + --roi $SUBJECT_DIR/label/lh.cortex.label \ --view left \ -o snap1.png + +# Also works with OFF / VTK / PLY +whippersnap1 mesh.off --overlay values.mgh -o snap1.png ``` ### Rotation video (`whippersnap1 --rotate`) -Renders a 360° animation: +Renders a 360° animation of any triangular surface mesh: ```bash whippersnap1 $SUBJECT_DIR/surf/lh.white \ @@ -72,11 +81,21 @@ whippersnap1 $SUBJECT_DIR/surf/lh.white \ ### Desktop GUI (`whippersnap`) -Launches an interactive Qt window with live threshold controls: +Launches an interactive Qt window with live threshold controls. + +**General mode** — any triangular mesh: ```bash pip install 'whippersnappy[gui]' -whippersnap --lh $LH_OVERLAY --sdir $SUBJECT_DIR +whippersnap --mesh mesh.off --overlay values.mgh +whippersnap --mesh lh.white --overlay lh.thickness --bg-map lh.curv +``` + +**FreeSurfer shortcut** — derive all paths from a subject directory: + +```bash +whippersnap -sd $SUBJECT_DIR --hemi lh -lh $LH_OVERLAY +whippersnap -sd $SUBJECT_DIR --hemi rh --annot rh.aparc.annot ``` For all options run `whippersnap4 --help`, `whippersnap1 --help`, or `whippersnap --help`. @@ -89,37 +108,42 @@ from whippersnappy import snap1, snap4, snap_rotate, plot3d | Function | Description | |---|---| -| `snap1` | Single-view surface snapshot → PIL Image | -| `snap4` | Four-view composed image (lateral/medial, both hemispheres) | -| `snap_rotate` | 360° rotation video (MP4, WebM, or GIF) | +| `snap1` | Single-view snapshot of any triangular mesh → PIL Image | +| `snap4` | Four-view composed image (FreeSurfer subject, lateral/medial both hemispheres) | +| `snap_rotate` | 360° rotation video of any triangular mesh (MP4, WebM, or GIF) | | `plot3d` | Interactive 3D WebGL viewer for Jupyter notebooks | +**Supported mesh inputs for `snap1`, `snap_rotate`, and `plot3d`:** +FreeSurfer binary surfaces (e.g. `lh.white`), OFF (`.off`), legacy ASCII VTK PolyData (`.vtk`), ASCII PLY (`.ply`), or a `(vertices, faces)` NumPy array tuple. + ### Example ```python from whippersnappy import snap1, snap4 -# File-path inputs (FreeSurfer subject directory) +# FreeSurfer surface with overlay +img = snap1('lh.white', + overlay='lh.thickness', + bg_map='lh.curv', + roi='lh.cortex.label') +img.save('snap1.png') + +# Four-view overview (FreeSurfer subject directory) img = snap4(sdir='/path/to/subject', lh_overlay='/path/to/lh.thickness', rh_overlay='/path/to/rh.thickness', colorbar=True, caption='Cortical Thickness (mm)') img.save('snap4.png') -# Single view with background shading and cortex mask -img = snap1('fsaverage/surf/lh.white', - overlay='fsaverage/surf/lh.thickness', - bg_map='fsaverage/surf/lh.curv', - roi='fsaverage/label/lh.cortex.label') -img.save('snap1.png') +# OFF / VTK / PLY mesh +img = snap1('mesh.off', overlay='values.mgh') # Array inputs (e.g. from LaPy or trimesh) import numpy as np v = np.random.randn(1000, 3).astype(np.float32) -f = np.array([[0, 1, 2]], dtype=np.uint32) # minimal example -my_values = np.random.randn(1000).astype(np.float32) -tria_curv = np.random.randn(1000).astype(np.float32) -img = snap1((v, f), overlay=my_values, bg_map=tria_curv) +f = np.array([[0, 1, 2]], dtype=np.uint32) +overlay = np.random.randn(1000).astype(np.float32) +img = snap1((v, f), overlay=overlay) ``` CLI usage: @@ -134,6 +158,7 @@ whippersnap4 -lh lh.thickness -rh rh.thickness -sd /path/to/subject -o snap4.png See `tutorials/whippersnappy_tutorial.ipynb` for complete notebook examples. + ## Docker The Docker image provides a fully headless EGL rendering environment — no diff --git a/pyproject.toml b/pyproject.toml index 4c05456..d7d81d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,10 +26,10 @@ classifiers = [ 'Operating System :: Unix', 'Operating System :: MacOS', 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'Natural Language :: English', 'License :: OSI Approved :: MIT License', 'Intended Audience :: Science/Research', diff --git a/whippersnappy/__init__.py b/whippersnappy/__init__.py index 83f119f..d656cbc 100644 --- a/whippersnappy/__init__.py +++ b/whippersnappy/__init__.py @@ -5,41 +5,38 @@ - **Static rendering**: `snap1()` and `snap4()` functions for publication-quality images - **3D plotting**: For Jupyter notebooks with mouse-controlled 3D (via Three.js) -- **GUI**: Desktop application with `--interactive` flag -- **CLI tools**: Command-line interface for batch processing -- **Custom shaders**: Full control over OpenGL lighting and rendering +- **GUI**: Interactive desktop viewer via the ``whippersnap`` command +- **CLI tools**: ``whippersnap1`` and ``whippersnap4`` for batch processing +- **Local mesh IO**: OFF, VTK ASCII PolyData, and PLY in addition to FreeSurfer surfaces -For static image generation: +For static image generation:: from whippersnappy import snap1, snap4 from whippersnappy.utils.types import ViewType from IPython.display import display - img = snap1(mesh='path/to/surface.white', view=ViewType.LEFT) + img = snap1('path/to/lh.white', view=ViewType.LEFT) display(img) -For interactive 3D in Jupyter notebooks: +For interactive 3D in Jupyter notebooks:: # Requires: pip install 'whippersnappy[notebook]' from whippersnappy import plot3d viewer = plot3d( - mesh='path/to/surface.white', - bg_map='path/to/curv', - overlay='path/to/thickness.mgh' # optional: for colors + mesh='path/to/lh.white', + bg_map='path/to/lh.curv', + overlay='path/to/lh.thickness', # optional: for colors ) display(viewer) -For desktop GUI: +For the interactive desktop GUI:: - # Command line - whippersnap --interactive -lh path/to/lh.white -rh path/to/rh.white - -Features: -- Works in ALL Jupyter environments (browser, JupyterLab, Colab, VS Code) -- Mouse-controlled rotation, zoom, and pan -- Professional lighting via Three.js/WebGL -- Same technology Plotly uses for 3D plots + # Requires: pip install 'whippersnappy[gui]' + # General mode — any mesh file: + whippersnap --mesh lh.white --overlay lh.thickness --bg-map lh.curv + # FreeSurfer shortcut — derive paths from subject directory: + whippersnap -sd path/to/subject_dir --hemi lh -lh lh.thickness """ diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index fca3c0e..b9da7bb 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -2,17 +2,26 @@ """Interactive GUI viewer for WhipperSnapPy. -Opens a live OpenGL window for a single hemisphere together with a -Qt-based configuration panel that allows adjusting overlay thresholds -at runtime. +Opens a live OpenGL window for any triangular surface mesh together with a +Qt-based configuration panel that allows adjusting overlay thresholds at +runtime. -Usage:: +Two input modes are supported: - whippersnap -lh -sd - whippersnap --lh_annot --rh_annot -sd +**General mode** — pass any mesh file directly:: + + whippersnap --mesh mesh.off --overlay values.mgh + whippersnap --mesh lh.white --overlay lh.thickness --bg-map lh.curv + +**FreeSurfer shortcut** — pass a subject directory and hemisphere; all +FreeSurfer paths are derived automatically:: + + whippersnap -sd --hemi lh --overlay lh.thickness + whippersnap -sd --hemi rh --annot rh.aparc.annot See ``whippersnap --help`` for the full list of options. For non-interactive four-view batch rendering use ``whippersnap4``. +For single-view non-interactive snapshots use ``whippersnap1``. """ import argparse @@ -49,52 +58,51 @@ def show_window( - hemi, + mesh, overlay=None, annot=None, - sdir=None, + bg_map=None, + roi=None, invert=False, - labelname="cortex.label", - surfname=None, - curvname="curv", specular=True, + view=ViewType.LEFT, ): - """Start a live interactive OpenGL window for viewing a hemisphere. + """Start a live interactive OpenGL window for viewing a triangular mesh. + + The function initializes a GLFW window and renders the provided mesh + with any supplied overlay or annotation. It polls for threshold updates + from the Qt configuration panel and re-renders whenever the thresholds + change. - The function initializes a GLFW window and renders the requested - hemisphere with any provided overlay/annotation. It polls for - configuration updates from the separate configuration GUI and updates - the rendered scene accordingly. + ``mesh`` is a fully resolved path (or ``(vertices, faces)`` tuple); + all FreeSurfer path-building is performed in :func:`run` before this + function is called. Parameters ---------- - hemi : {'lh','rh'} - Hemisphere to display. - overlay : str or None, optional - Path to a per-vertex overlay file (e.g. thickness). - annot : str or None, optional - Path to a ``.annot`` file providing categorical labels for vertices. - sdir : str or None, optional - Subject directory containing ``surf/`` and ``label/`` subdirectories. + mesh : str or tuple of (array-like, array-like) + Path to any mesh file supported by :func:`whippersnappy.geometry.inputs.resolve_mesh` + (FreeSurfer binary, ``.off``, ``.vtk``, ``.ply``) **or** a + ``(vertices, faces)`` array tuple. + overlay : str, array-like, or None, optional + Per-vertex scalar overlay — file path or (N,) array. + annot : str, tuple, or None, optional + FreeSurfer ``.annot`` file path or ``(labels, ctab[, names])`` tuple. + bg_map : str, array-like, or None, optional + Per-vertex scalar file or array for background shading (sign → light/dark). + roi : str, array-like, or None, optional + FreeSurfer label file path or boolean (N,) array masking overlay coloring. invert : bool, optional Invert the overlay color mapping. Default is ``False``. - labelname : str, optional - Label filename used to mask vertices. Default is ``'cortex.label'``. - surfname : str or None, optional - Surface basename (e.g. ``'white'``); if ``None`` the function will - auto-detect a suitable surface in ``sdir``. - curvname : str or None, optional - Curvature filename used to texture non-colored regions. - Default is ``'curv'``. specular : bool, optional Enable specular highlights in the shader. Default is ``True``. + view : ViewType, optional + Initial camera view direction. Default is ``ViewType.LEFT``. Raises ------ RuntimeError If the GLFW window or OpenGL context could not be created. - FileNotFoundError - If a requested surface file cannot be located in ``sdir``. """ global current_fthresh_, current_fmax_, app_, app_window_, app_window_closed_ @@ -105,22 +113,8 @@ def show_window( logger.error("Could not create any GLFW window/context. OpenGL context unavailable.") raise RuntimeError("Could not create any GLFW window/context. OpenGL context unavailable.") - if surfname is None: - logger.info("No surf_name provided. Looking for options in surf directory...") - found_surfname = get_surf_name(sdir, hemi) - if found_surfname is None: - msg = f"Could not find a valid surf file in {sdir} for hemi: {hemi}!" - logger.error(msg) - raise FileNotFoundError(msg) - mesh = os.path.join(sdir, "surf", hemi + "." + found_surfname) - else: - mesh = os.path.join(sdir, "surf", hemi + "." + surfname) - - bg_map = os.path.join(sdir, "surf", hemi + "." + curvname) if curvname else None - roi = os.path.join(sdir, "label", hemi + "." + labelname) if labelname else None - view_mats = get_view_matrices() - viewmat = view_mats[ViewType.RIGHT] if hemi == "rh" else view_mats[ViewType.LEFT] + viewmat = view_mats[view] rot_y = pyrr.Matrix44.from_y_rotation(0) meshdata, triangles, fthresh, fmax, neg = prepare_geometry( @@ -187,6 +181,15 @@ def run(): viewer thread and launches the PyQt6 configuration window in the main thread. + Two mutually exclusive input modes are supported: + + **General mode** — supply a mesh file directly with ``--mesh``. + Works with FreeSurfer binary surfaces, OFF, VTK, and PLY files. + + **FreeSurfer shortcut** — supply ``-sd``/``--sdir`` and + ``--hemi``; the surface, curvature, and cortex-label paths are all + derived automatically from the subject directory. + Raises ------ RuntimeError @@ -196,82 +199,210 @@ def run(): Notes ----- - **Required:** - - * ``-sd`` / ``--sdir`` — subject directory containing ``surf/`` and - ``label/`` subdirectories. - * One of the following (not both): + **General mode** (``--mesh`` required): - * ``-lh`` / ``--lh_overlay`` **and** ``-rh`` / ``--rh_overlay`` — per-vertex - scalar overlay files. - * ``--lh_annot`` **and** ``--rh_annot`` — FreeSurfer ``.annot`` - parcellation files. + * ``--mesh`` — path to any triangular mesh file (FreeSurfer binary, + ``.off``, ``.vtk``, ``.ply``). + * ``--overlay`` — per-vertex scalar overlay file path or ``.mgh``. + * ``--annot`` — FreeSurfer ``.annot`` file. + * ``--bg-map`` — per-vertex scalar file for background shading. + * ``--roi`` — FreeSurfer label file or boolean mask for overlay region. - **Optional:** + **FreeSurfer shortcut** (``-sd``/``--sdir`` + ``--hemi`` required): + * ``-sd`` / ``--sdir`` — subject directory containing ``surf/`` and + ``label/`` subdirectories. + * ``--hemi`` — hemisphere to display: ``lh`` or ``rh``. + * ``-lh`` / ``--lh_overlay`` or ``-rh`` / ``--rh_overlay`` — overlay + for the respective hemisphere (shorthand; equivalent to ``--overlay``). + * ``--annot`` — annotation file (full path). * ``-s`` / ``--surf_name`` — surface basename (e.g. ``white``); auto-detected if not provided. - * ``-c`` / ``--caption`` — caption text shown in the viewer. + * ``--curv-name`` — curvature basename (default: ``curv``). + * ``--label-name`` — cortex label basename (default: ``cortex.label``). + + **Common options** (both modes): + * ``--fthresh`` — overlay threshold (default: 2.0); adjustable live in the GUI. * ``--fmax`` — overlay saturation (default: 4.0); adjustable live in the GUI. * ``--invert`` — invert the color scale. * ``--diffuse`` — use diffuse-only shading (no specular highlights). + * ``--view`` — initial camera view (default: ``left``). Requires ``pip install 'whippersnappy[gui]'``. - For non-interactive four-view batch rendering use ``whippersnap4``. + For non-interactive batch rendering use ``whippersnap4`` or ``whippersnap1``. """ global current_fthresh_, current_fmax_, app_, app_window_ logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + _VIEW_CHOICES = {v.name.lower(): v for v in ViewType} + parser = argparse.ArgumentParser( prog="whippersnap", description=( - "Interactive GUI viewer for a single hemisphere. " - "For batch four-view rendering use whippersnap4." + "Interactive GUI viewer for any triangular surface mesh. " + "Pass --mesh for a direct mesh file, or -sd/--sdir + --hemi " + "for the FreeSurfer subject-directory shortcut. " + "For non-interactive batch rendering use whippersnap4 or whippersnap1." ), ) parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") - parser.add_argument("-lh", "--lh_overlay", type=str, default=None, - help="Path to the lh overlay file.") - parser.add_argument("-rh", "--rh_overlay", type=str, default=None, - help="Path to the rh overlay file.") - parser.add_argument("--lh_annot", type=str, default=None, - help="Path to the lh annotation file.") - parser.add_argument("--rh_annot", type=str, default=None, - help="Path to the rh annotation file.") - parser.add_argument("-sd", "--sdir", type=str, required=True, - help="Subject directory containing surf/ and label/ subdirectories.") - parser.add_argument("-s", "--surf_name", type=str, default=None, - help="Surface basename to load (e.g. 'white').") - parser.add_argument("-c", "--caption", type=str, default="", - help="Caption text.") - parser.add_argument("--fmax", type=float, default=4.0, - help="Overlay saturation value (default: 4.0).") - parser.add_argument("--fthresh", type=float, default=2.0, - help="Overlay threshold value (default: 2.0).") - parser.add_argument("--invert", action="store_true", - help="Invert the color scale.") - parser.add_argument("--diffuse", dest="specular", action="store_false", default=True, - help="Diffuse-only shading (no specular).") + + # --- General mesh mode --- + general = parser.add_argument_group( + "general mode", + "Load any triangular mesh directly (OFF, VTK, PLY, FreeSurfer binary).", + ) + general.add_argument( + "--mesh", type=str, default=None, + help="Path to any triangular mesh file (.off, .vtk, .ply, or FreeSurfer binary).", + ) + general.add_argument( + "--bg-map", dest="bg_map", type=str, default=None, + help="Per-vertex scalar file for background shading (sign → light/dark).", + ) + general.add_argument( + "--roi", type=str, default=None, + help="FreeSurfer label file or boolean mask restricting overlay coloring.", + ) + + # --- FreeSurfer shortcut mode --- + fs = parser.add_argument_group( + "FreeSurfer shortcut", + "Derive mesh, curvature, and label paths from a subject directory.", + ) + fs.add_argument( + "-sd", "--sdir", type=str, default=None, + help="Subject directory containing surf/ and label/ subdirectories.", + ) + fs.add_argument( + "--hemi", type=str, default=None, choices=["lh", "rh"], + help="Hemisphere to display: lh or rh.", + ) + fs.add_argument( + "-s", "--surf_name", type=str, default=None, + help="Surface basename (e.g. 'white'); auto-detected if not provided.", + ) + fs.add_argument( + "--curv-name", dest="curv_name", type=str, default="curv", + help="Curvature file basename for background shading (default: curv).", + ) + fs.add_argument( + "--label-name", dest="label_name", type=str, default="cortex.label", + help="Cortex label basename for overlay masking (default: cortex.label).", + ) + + # Hemisphere-prefixed overlay shortcuts (FreeSurfer convention) + fs.add_argument( + "-lh", "--lh_overlay", type=str, default=None, + help="Shorthand for --overlay when using lh hemisphere (e.g. lh.thickness).", + ) + fs.add_argument( + "-rh", "--rh_overlay", type=str, default=None, + help="Shorthand for --overlay when using rh hemisphere (e.g. rh.thickness).", + ) + + # --- Inputs common to both modes --- + common = parser.add_argument_group("overlay / annotation (both modes)") + common.add_argument( + "--overlay", type=str, default=None, + help="Per-vertex scalar overlay file path.", + ) + common.add_argument( + "--annot", type=str, default=None, + help="FreeSurfer .annot file for parcellation coloring.", + ) + + # --- Appearance / rendering --- + rend = parser.add_argument_group("appearance") + rend.add_argument("--fmax", type=float, default=4.0, + help="Overlay saturation value (default: 4.0).") + rend.add_argument("--fthresh", type=float, default=2.0, + help="Overlay threshold value (default: 2.0).") + rend.add_argument("--invert", action="store_true", + help="Invert the color scale.") + rend.add_argument("--diffuse", dest="specular", action="store_false", default=True, + help="Diffuse-only shading (no specular highlights).") + rend.add_argument( + "--view", type=str, default="left", choices=list(_VIEW_CHOICES), + help="Initial camera view direction (default: left).", + ) args = parser.parse_args() + # ------------------------------------------------------------------ + # Resolve the two modes and build the final mesh / bg_map / roi paths + # ------------------------------------------------------------------ + fs_mode = args.sdir is not None or args.hemi is not None + general_mode = args.mesh is not None + try: - if (args.lh_overlay or args.rh_overlay) and (args.lh_annot or args.rh_annot): + if fs_mode and general_mode: + raise ValueError( + "Cannot combine --mesh with -sd/--sdir or --hemi. " + "Use either general mode (--mesh) or FreeSurfer mode (-sd + --hemi)." + ) + if not fs_mode and not general_mode: raise ValueError( - "Cannot use lh_overlay/rh_overlay and lh_annot/rh_annot at the same time." + "Either --mesh (general mode) or both -sd/--sdir and --hemi " + "(FreeSurfer shortcut) must be provided." ) - if not any([args.lh_overlay, args.rh_overlay, args.lh_annot, args.rh_annot]): + + if fs_mode: + if args.sdir is None or args.hemi is None: + raise ValueError( + "FreeSurfer mode requires both -sd/--sdir and --hemi." + ) + + # Resolve overlay: --overlay takes precedence; -lh/-rh are shorthands + overlay = args.overlay + if overlay is None: + if args.hemi == "lh" and args.lh_overlay: + overlay = args.lh_overlay + elif args.hemi == "rh" and args.rh_overlay: + overlay = args.rh_overlay + elif args.lh_overlay and not fs_mode: + raise ValueError( + "-lh/--lh_overlay is only valid in FreeSurfer mode (with --hemi lh)." + ) + elif args.rh_overlay and not fs_mode: + raise ValueError( + "-rh/--rh_overlay is only valid in FreeSurfer mode (with --hemi rh)." + ) + + if overlay and args.annot: + raise ValueError("Cannot combine --overlay/hemisphere overlay and --annot.") + if not overlay and not args.annot and not general_mode: raise ValueError( - "Either lh_overlay/rh_overlay or lh_annot/rh_annot must be present." + "Either an overlay (-lh/-rh/--overlay) or --annot must be provided." ) - if (args.lh_overlay is None) != (args.rh_overlay is None): - raise ValueError("Both -lh and -rh overlays must be provided together.") - if (args.lh_annot is None) != (args.rh_annot is None): - raise ValueError("Both --lh_annot and --rh_annot must be provided together.") + except ValueError as e: parser.error(str(e)) + # Build resolved mesh / bg_map / roi + if fs_mode: + hemi = args.hemi + sdir = args.sdir + if args.surf_name is None: + found = get_surf_name(sdir, hemi) + if found is None: + parser.error(f"Could not find a valid surface in {sdir} for hemi {hemi!r}.") + mesh_path = os.path.join(sdir, "surf", f"{hemi}.{found}") + else: + mesh_path = os.path.join(sdir, "surf", f"{hemi}.{args.surf_name}") + bg_map = os.path.join(sdir, "surf", f"{hemi}.{args.curv_name}") if args.curv_name else None + roi = os.path.join(sdir, "label", f"{hemi}.{args.label_name}") if args.label_name else None + view = ViewType.RIGHT if hemi == "rh" else ViewType.LEFT + else: + mesh_path = args.mesh + bg_map = args.bg_map + roi = args.roi + view = _VIEW_CHOICES[args.view] + + # ------------------------------------------------------------------ + # Start the Qt app + OpenGL thread + # ------------------------------------------------------------------ if QApplication is None: print( "ERROR: Interactive mode requires PyQt6. " @@ -292,18 +423,19 @@ def run(): ) from e current_fthresh_ = args.fthresh - current_fmax_ = args.fmax + current_fmax_ = args.fmax thread = threading.Thread( target=show_window, - args=("lh",), kwargs=dict( - overlay=args.lh_overlay, - annot=args.lh_annot, - sdir=args.sdir, + mesh=mesh_path, + overlay=overlay, + annot=args.annot, + bg_map=bg_map, + roi=roi, invert=args.invert, - surfname=args.surf_name, specular=args.specular, + view=view, ), ) thread.start() diff --git a/whippersnappy/cli/whippersnap1.py b/whippersnappy/cli/whippersnap1.py index aef3bf9..2ec9777 100644 --- a/whippersnappy/cli/whippersnap1.py +++ b/whippersnappy/cli/whippersnap1.py @@ -1,13 +1,17 @@ #!/usr/bin/env python3 """CLI entry point for single-mesh snapshot and rotation video via snap1/snap_rotate. -Renders a single surface mesh from a chosen viewpoint and saves it as a PNG -image. Alternatively, pass ``--rotate`` to produce a full 360° rotation video -(MP4, WebM, or GIF). +Renders any triangular surface mesh from a chosen viewpoint and saves it as a +PNG image. Alternatively, pass ``--rotate`` to produce a full 360° rotation +video (MP4, WebM, or GIF). + +The mesh can be a FreeSurfer binary surface (e.g. ``lh.white``), an ASCII OFF +file (``mesh.off``), a legacy ASCII VTK PolyData file (``mesh.vtk``), or an +ASCII PLY file (``mesh.ply``). Usage:: - # Static single-view snapshot (lateral view, thickness overlay) + # FreeSurfer surface — lateral view with thickness overlay whippersnap1 /surf/lh.white \\ --overlay /surf/lh.thickness \\ --bg-map /surf/lh.curv \\ @@ -15,6 +19,11 @@ --view left --fthresh 1.5 --fmax 4.0 \\ -o snap1.png + # OFF / VTK / PLY mesh with a numpy-saved overlay + whippersnap1 mesh.off --overlay values.mgh -o snap1.png + whippersnap1 mesh.vtk -o snap1.png + whippersnap1 mesh.ply --overlay values.mgh -o snap1.png + # 360° rotation video whippersnap1 /surf/lh.white \\ --overlay /surf/lh.thickness \\ @@ -24,7 +33,7 @@ # Parcellation annotation whippersnap1 /surf/lh.white \\ --annot /label/lh.aparc.annot \\ - --view lateral -o snap_annot.png + --view left -o snap_annot.png See ``whippersnap1 --help`` for the full list of options. For four-view batch rendering use ``whippersnap4``. @@ -67,11 +76,13 @@ def run(): ----- **Snapshot options** (default mode): - * ``mesh`` — path to the surface file (FreeSurfer binary, e.g. ``lh.white``). - * ``--overlay`` — per-vertex scalar overlay (e.g. ``lh.thickness``). + * ``mesh`` — path to any triangular surface mesh: FreeSurfer binary + (e.g. ``lh.white``), ASCII OFF (``.off``), legacy ASCII VTK PolyData + (``.vtk``), or ASCII PLY (``.ply``). + * ``--overlay`` — per-vertex scalar overlay (e.g. ``lh.thickness`` or a ``.mgh`` file). * ``--annot`` — FreeSurfer ``.annot`` parcellation file. - * ``--label`` — label file used to mask overlay values to the cortex. - * ``--curv`` — curvature file for sulcal depth shading of uncolored vertices. + * ``--roi`` — FreeSurfer label file defining vertices to include in overlay coloring. + * ``--bg-map`` — per-vertex scalar file whose sign controls light/dark background shading. * ``--view`` — camera direction: ``left``, ``right``, ``front``, ``back``, ``top``, ``bottom`` (default: ``left``). * ``--fthresh`` / ``--fmax`` — overlay threshold and saturation values. @@ -104,8 +115,11 @@ def run(): parser.add_argument( "mesh", type=str, - help="Path to the surface file. FreeSurfer binary format (e.g. lh.white) " - "or any mesh readable by the geometry module.", + help=( + "Path to the surface mesh file. Supported formats: " + "FreeSurfer binary surface (e.g. lh.white, rh.pial), " + "ASCII OFF (.off), legacy ASCII VTK PolyData (.vtk), ASCII PLY (.ply)." + ), ) # --- Output --- diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py index f43c684..3b70a6e 100644 --- a/whippersnappy/plot3d.py +++ b/whippersnappy/plot3d.py @@ -1,21 +1,20 @@ """3D plotting for WhipperSnapPy using pythreejs (Three.js) for Jupyter. -This module provides interactive 3D brain visualization for Jupyter notebooks using -Three.js/WebGL. It works in all Jupyter environments (browser, JupyterLab, Colab, VS Code). +This module provides interactive 3D brain visualization for Jupyter notebooks +using Three.js/WebGL. It works in all Jupyter environments (browser, +JupyterLab, Colab, VS Code). -Unlike the desktop GUI (launched with --interactive flag), this plots in the browser -using WebGL and is specifically designed for notebook environments. +Unlike the desktop GUI (``whippersnap`` command), this renders entirely in the +browser via WebGL and is designed for notebook environments. + +Usage:: -Usage: from whippersnappy import plot3d viewer = plot3d(mesh='path/to/lh.white', bg_map='path/to/lh.curv') display(viewer) Dependencies: pythreejs, ipywidgets, numpy - -@Author : Martin Reuter -@Created : 14.02.2026 """ import logging @@ -54,37 +53,42 @@ def plot3d( pythreejs renderer and controls wrapped in an ``ipywidgets.VBox`` for display inside a Jupyter notebook. + The mesh can be any triangular surface — not just brain surfaces. + Supported file formats: FreeSurfer binary surface (e.g. ``lh.white``), + ASCII OFF (``.off``), legacy ASCII VTK PolyData (``.vtk``), ASCII PLY + (``.ply``), or a ``(vertices, faces)`` numpy array tuple. + Parameters ---------- mesh : str or tuple of (array-like, array-like) - Path to the surface file (FreeSurfer-style surface, e.g. ``"lh.white"``) - **or** a ``(vertices, faces)`` tuple. + Path to a mesh file (FreeSurfer binary, ``.off``, ``.vtk``, or + ``.ply``) **or** a ``(vertices, faces)`` tuple. overlay : str, array-like, or None, optional - Path to a per-vertex overlay (thickness/curvature) file, or a (N,) - array of per-vertex scalar values. + Path to a per-vertex scalar file, or a (N,) array of per-vertex + scalar values. annot : str, tuple, or None, optional Path to a FreeSurfer .annot file, or a ``(labels, ctab)`` / ``(labels, ctab, names)`` tuple for categorical labeling. bg_map : str, array-like, or None, optional - Path to a curvature file **or** a (N,) array used as grayscale - texture for non-overlay regions. + Path to a per-vertex scalar file **or** a (N,) array used as + grayscale background shading for non-overlay regions. roi : str, array-like, or None, optional Path to a FreeSurfer label file **or** a (N,) boolean array to - restrict overlay coloring. + restrict overlay coloring to a subset of vertices. minval, maxval : float or None, optional - Threshold and saturation values used for color mapping (passed to - :func:`prepare_geometry`). If ``None``, sensible defaults are chosen. - invert : bool, optional, default False - If True, invert the overlay color map. - scale : float, optional, default 1.85 - Global geometry scale applied during preparation. + Threshold and saturation values used for color mapping. + If ``None``, sensible defaults are chosen automatically. + invert : bool, optional + If True, invert the overlay color map. Default is ``False``. + scale : float, optional + Global geometry scale applied during preparation. Default is ``1.85``. color_mode : ColorSelection or None, optional Which sign of overlay values to color (BOTH/POSITIVE/NEGATIVE). If None, defaults to ``ColorSelection.BOTH``. - width, height : int, optional, default 800 - Canvas dimensions for the generated renderer. - ambient : float, optional, default 0.1 - Ambient lighting strength for the shader (passed to Three.js uniform). + width, height : int, optional + Canvas dimensions for the generated renderer. Default is ``800``. + ambient : float, optional + Ambient lighting strength passed to the Three.js shader. Default is ``0.1``. Returns ------- @@ -94,8 +98,8 @@ def plot3d( Raises ------ ValueError, FileNotFoundError - Errors originating from :func:`prepare_geometry` (for example when - input arrays don't match the mesh vertex count) are propagated. + Errors from :func:`prepare_geometry` are propagated (for example + shape mismatches between overlay and mesh vertex count). Examples -------- @@ -104,7 +108,19 @@ def plot3d( from whippersnappy import plot3d from IPython.display import display - viewer = plot3d('fsaverage/surf/lh.white', overlay='fsaverage/surf/lh.thickness') + # FreeSurfer surface + viewer = plot3d('lh.white', overlay='lh.thickness', bg_map='lh.curv') + display(viewer) + + # Any triangular mesh via OFF / VTK / PLY + viewer = plot3d('mesh.off', overlay='values.mgh') + display(viewer) + + # Array inputs + import numpy as np + v = np.random.randn(500, 3).astype(np.float32) + f = np.zeros((1, 3), dtype=np.uint32) + viewer = plot3d((v, f)) display(viewer) """ # Load and prepare mesh data diff --git a/whippersnappy/snap.py b/whippersnappy/snap.py index 52efa0d..176d98e 100644 --- a/whippersnappy/snap.py +++ b/whippersnappy/snap.py @@ -48,54 +48,59 @@ def snap1( brain_scale=1.5, ambient=0.0, ): - """Render a single static snapshot of a surface view. + """Render a single static snapshot of a surface mesh. This function opens an OpenGL context, uploads the provided surface geometry and colors (overlay or annotation), renders the scene - for a single view, captures the framebuffer, and returns a PIL Image - containing the rendered brain view. When ``outpath`` is provided the - image is also written to disk. + for a single view, captures the framebuffer, and returns a PIL Image. + When ``outpath`` is provided the image is also written to disk. + + The mesh can be any triangular surface — not just brain surfaces. + Supported file formats: FreeSurfer binary surface (e.g. ``lh.white``), + ASCII OFF (``.off``), legacy ASCII VTK PolyData (``.vtk``), ASCII PLY + (``.ply``), or a ``(vertices, faces)`` numpy array tuple. Parameters ---------- mesh : str or tuple of (array-like, array-like) - Path to the surface file (FreeSurfer-format, e.g. ``"lh.white"``) **or** - a ``(vertices, faces)`` tuple where *vertices* is (N, 3) float and - *faces* is (M, 3) int. + Path to a mesh file (FreeSurfer binary, ``.off``, ``.vtk``, or + ``.ply``) **or** a ``(vertices, faces)`` tuple where *vertices* is + (N, 3) float and *faces* is (M, 3) int. outpath : str or None, optional When provided, the resulting image is saved to this path. overlay : str, array-like, or None, optional - Overlay file path (.mgh or FreeSurfer morph) **or** a (N,) array of - per-vertex scalar values. If ``None``, coloring falls back to + Overlay file path (``.mgh`` or FreeSurfer morph) **or** a (N,) array + of per-vertex scalar values. If ``None``, coloring falls back to background shading / annotation. annot : str, tuple, or None, optional Path to a FreeSurfer .annot file **or** a ``(labels, ctab)`` / ``(labels, ctab, names)`` tuple with per-vertex labels. bg_map : str, array-like, or None, optional - Path to a curvature/morph file **or** a (N,) array whose sign + Path to a per-vertex scalar file **or** a (N,) array whose sign determines light/dark background shading for non-overlay vertices. roi : str, array-like, or None, optional Path to a FreeSurfer label file **or** a (N,) boolean array. Vertices with ``True`` receive overlay coloring; others fall back to *bg_map* shading. view : ViewType, optional - Which pre-defined view to render (left, right, front, ...). Default is ``ViewType.LEFT``. + Which pre-defined view to render (left, right, front, ...). + Default is ``ViewType.LEFT``. viewmat : 4x4 matrix-like, optional Optional view matrix to override the pre-defined view. width, height : int, optional - Requested overall canvas width/height in pixels. Defaults to (700x500). + Output canvas size in pixels. Defaults to (700×500). fthresh, fmax : float or None, optional Threshold and saturation values for overlay coloring. caption, caption_x, caption_y, caption_scale : str/float, optional - Caption text and layout parameters. Caption defaults to ``None`` and caption_scale defaults to 1. + Caption text and layout parameters. invert : bool, optional Invert the color scale. Default is ``False``. colorbar : bool, optional If True, render a colorbar when an overlay is present. Default is ``True``. colorbar_x, colorbar_y, colorbar_scale : float, optional - Colorbar positioning and scale flags. Scale defaults to 1. + Colorbar positioning and scale. Scale defaults to 1. orientation : OrientationType, optional - Orientation of the colorbar (HORIZONTAL/VERTICAL). Default is ``OrientationType.HORIZONTAL``. + Colorbar orientation (HORIZONTAL/VERTICAL). Default is ``OrientationType.HORIZONTAL``. color_mode : ColorSelection, optional Which sign of overlay to color (POSITIVE/NEGATIVE/BOTH). Default is ``ColorSelection.BOTH``. font_file : str or None, optional @@ -110,7 +115,7 @@ def snap1( Returns ------- PIL.Image.Image - Returns a PIL Image object containing the rendered snapshot. + Rendered snapshot as a PIL Image. Raises ------ @@ -120,22 +125,27 @@ def snap1( If the overlay contains no values to display for the chosen color_mode. FileNotFoundError - If required surface files cannot be found when deriving from - SUBJECTS_DIR in multi-view helpers. + If a required file cannot be found. Examples -------- - >>> from whippersnappy import snap1 - >>> img = snap1('fsaverage/surf/lh.white', overlay='fsaverage/surf/lh.thickness', - ... bg_map='fsaverage/surf/lh.curv', roi='fsaverage/label/lh.cortex.label') - >>> img.save('/tmp/lh.png') + FreeSurfer surface with overlay:: + + >>> from whippersnappy import snap1 + >>> img = snap1('lh.white', overlay='lh.thickness', + ... bg_map='lh.curv', roi='lh.cortex.label') + >>> img.save('/tmp/lh.png') + + Array inputs (any triangular mesh):: + + >>> import numpy as np + >>> v = np.random.randn(100, 3).astype(np.float32) + >>> f = np.array([[0, 1, 2]], dtype=np.uint32) + >>> img = snap1((v, f)) - Array inputs:: + OFF / VTK / PLY file:: - >>> import numpy as np - >>> v = np.random.randn(100, 3).astype(np.float32) - >>> f = np.array([[0, 1, 2]], dtype=np.uint32) - >>> img = snap1((v, f)) + >>> img = snap1('mesh.off', overlay='values.mgh') """ ref_width = 700 ref_height = 500 @@ -288,11 +298,12 @@ def snap4( ---------- lh_overlay, rh_overlay : str, array-like, or None Left/right hemisphere overlay — either a file path (FreeSurfer morph - or .mgh) or a per-vertex scalar array. Mutually required if either - is provided. + or .mgh) or a per-vertex scalar array. Typically provided as a pair + for a coherent two-hemisphere color scale. lh_annot, rh_annot : str, tuple, or None Left/right hemisphere annotation — either a path to a .annot file or a ``(labels, ctab)`` / ``(labels, ctab, names)`` tuple. + Cannot be combined with ``lh_overlay``/``rh_overlay``. fthresh, fmax : float or None Threshold and saturation for overlay coloring. Auto-estimated when ``None``. @@ -531,17 +542,19 @@ def snap_rotate( Rotates the view around the vertical (Y) axis in ``n_frames`` equal steps, captures each frame via OpenGL, and encodes the result into a - compressed video file using ``imageio`` with the ``ffmpeg`` backend - (provided by ``imageio-ffmpeg``). + video file. An animated GIF can be produced by passing an ``outpath`` + ending in ``.gif``; in that case ``imageio-ffmpeg`` is not required. - An animated GIF can also be produced by passing an ``outpath`` ending - in ``.gif``; in that case ``imageio-ffmpeg`` is not required. + The mesh can be any triangular surface — not just brain surfaces. + Supported file formats: FreeSurfer binary surface, ASCII OFF (``.off``), + legacy ASCII VTK PolyData (``.vtk``), ASCII PLY (``.ply``), or a + ``(vertices, faces)`` numpy array tuple. Parameters ---------- mesh : str or tuple of (array-like, array-like) - Path to the surface file (FreeSurfer binary format, e.g. ``lh.white``) - **or** a ``(vertices, faces)`` tuple. + Path to a mesh file (FreeSurfer binary, ``.off``, ``.vtk``, or + ``.ply``) **or** a ``(vertices, faces)`` tuple. outpath : str Destination file path. The extension controls the output format: From 1ab18bf994408e04f80923514eb00ec14f4cf681 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 17:43:05 +0100 Subject: [PATCH 64/83] add support for reading multiple overlay file formats and gifti surfaces --- tests/test_mesh_io.py | 120 +++++++++ tests/test_overlay_io.py | 351 +++++++++++++++++++++++++++ whippersnappy/geometry/__init__.py | 9 +- whippersnappy/geometry/inputs.py | 64 ++++- whippersnappy/geometry/mesh_io.py | 144 ++++++++++- whippersnappy/geometry/overlay_io.py | 346 ++++++++++++++++++++++++++ 6 files changed, 1009 insertions(+), 25 deletions(-) create mode 100644 tests/test_overlay_io.py create mode 100644 whippersnappy/geometry/overlay_io.py diff --git a/tests/test_mesh_io.py b/tests/test_mesh_io.py index 11a79db..543d7b3 100644 --- a/tests/test_mesh_io.py +++ b/tests/test_mesh_io.py @@ -12,6 +12,7 @@ from whippersnappy.geometry.inputs import resolve_mesh from whippersnappy.geometry.mesh_io import ( + read_gifti_surface, read_mesh, read_off, read_ply_ascii, @@ -395,3 +396,122 @@ def test_array_out_of_range_raises(self): f_in = np.array([[0, 1, 99]], dtype=np.uint32) # index 99 out of range with pytest.raises(ValueError, match="out of range"): resolve_mesh((v_in, f_in)) + + +# --------------------------------------------------------------------------- +# GIfTI surface reader +# --------------------------------------------------------------------------- + +def _make_surf_gii(verts, faces, suffix=".surf.gii"): + """Write a minimal GIfTI surface file and return its path.""" + import nibabel as nib + from nibabel import nifti1 + _INTENT_POINTSET = 1008 + _INTENT_TRIANGLE = 1009 + coords_da = nib.gifti.GiftiDataArray( + data=verts.astype(np.float32), + intent=_INTENT_POINTSET, + datatype="NIFTI_TYPE_FLOAT32", + ) + faces_da = nib.gifti.GiftiDataArray( + data=faces.astype(np.int32), + intent=_INTENT_TRIANGLE, + datatype="NIFTI_TYPE_INT32", + ) + img = nib.gifti.GiftiImage(darrays=[coords_da, faces_da]) + fd, path = tempfile.mkstemp(suffix=suffix) + os.close(fd) + nib.save(img, path) + return path + + +_V4 = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) +_F4 = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.uint32) + + +class TestReadGiftiSurface: + def test_surf_gii_basic(self): + path = _make_surf_gii(_V4, _F4, ".surf.gii") + try: + v, f = read_gifti_surface(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + assert v.dtype == np.float32 + assert f.shape == (4, 3) + assert f.dtype == np.uint32 + np.testing.assert_allclose(v, _V4, atol=1e-6) + np.testing.assert_array_equal(f, _F4) + + def test_plain_gii_extension(self): + """A .gii file with POINTSET+TRIANGLE should also be loaded as surface.""" + path = _make_surf_gii(_V4, _F4, ".gii") + try: + v, f = read_gifti_surface(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + assert f.shape == (4, 3) + + def test_no_pointset_raises(self): + """A plain scalar .gii without POINTSET arrays should raise.""" + import nibabel as nib + scalar_da = nib.gifti.GiftiDataArray( + data=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + intent=0, # NIFTI_INTENT_NONE + ) + img = nib.gifti.GiftiImage(darrays=[scalar_da]) + fd, path = tempfile.mkstemp(suffix=".gii") + os.close(fd) + nib.save(img, path) + try: + with pytest.raises(ValueError, match="POINTSET"): + read_gifti_surface(path) + finally: + os.unlink(path) + + def test_no_triangle_raises(self): + """A .gii with only a POINTSET but no TRIANGLE array should raise.""" + import nibabel as nib + coords_da = nib.gifti.GiftiDataArray( + data=_V4.astype(np.float32), + intent=1008, + ) + img = nib.gifti.GiftiImage(darrays=[coords_da]) + fd, path = tempfile.mkstemp(suffix=".gii") + os.close(fd) + nib.save(img, path) + try: + with pytest.raises(ValueError, match="TRIANGLE"): + read_gifti_surface(path) + finally: + os.unlink(path) + + +class TestReadMeshGiftiDispatch: + def test_surf_gii_dispatched(self): + path = _make_surf_gii(_V4, _F4, ".surf.gii") + try: + v, f = read_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + + def test_gii_dispatched(self): + path = _make_surf_gii(_V4, _F4, ".gii") + try: + v, f = read_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + + def test_resolve_mesh_surf_gii(self): + path = _make_surf_gii(_V4, _F4, ".surf.gii") + try: + v, f = resolve_mesh(path) + finally: + os.unlink(path) + assert v.shape == (4, 3) + assert v.dtype == np.float32 + assert f.dtype == np.uint32 + diff --git a/tests/test_overlay_io.py b/tests/test_overlay_io.py new file mode 100644 index 0000000..a80de80 --- /dev/null +++ b/tests/test_overlay_io.py @@ -0,0 +1,351 @@ +"""Tests for whippersnappy/geometry/overlay_io.py and the updated +_load_overlay_from_file routing in inputs.py. + +All tests write content to temporary files so no external data is required. +""" + +import os +import tempfile + +import numpy as np +import pytest + +from whippersnappy.geometry.inputs import resolve_overlay, resolve_bg_map, resolve_roi +from whippersnappy.geometry.overlay_io import ( + read_npy, + read_npz, + read_overlay, + read_txt, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _write_tmp(content, suffix, binary=False): + fd, path = tempfile.mkstemp(suffix=suffix) + mode = "wb" if binary else "w" + with os.fdopen(fd, mode) as fh: + fh.write(content) + return path + + +_FLOAT_VALUES = [0.1, -1.5, 2.0, 0.0] +_INT_VALUES = [0, 1, 3, 2] + + +# --------------------------------------------------------------------------- +# read_txt — float +# --------------------------------------------------------------------------- + +class TestReadTxtFloat: + def test_basic_floats(self): + content = "\n".join(str(v) for v in _FLOAT_VALUES) + "\n" + path = _write_tmp(content, ".txt") + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.shape == (4,) + assert arr.dtype == np.float32 + np.testing.assert_allclose(arr, _FLOAT_VALUES, atol=1e-6) + + def test_hash_comment_skipped(self): + content = "# this is a comment\n0.5\n1.5\n" + path = _write_tmp(content, ".txt") + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.shape == (2,) + + def test_text_header_skipped(self): + content = "value\n0.1\n0.2\n0.3\n" + path = _write_tmp(content, ".txt") + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.shape == (3,) + np.testing.assert_allclose(arr, [0.1, 0.2, 0.3], atol=1e-6) + + def test_csv_first_column_used(self): + content = "label,ignore\n1.0,extra\n2.0,extra\n" + path = _write_tmp(content, ".csv") + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.shape == (2,) + + def test_empty_file_raises(self): + path = _write_tmp("", ".txt") + try: + with pytest.raises(ValueError, match="No numeric"): + read_txt(path) + finally: + os.unlink(path) + + def test_bad_value_raises(self): + content = "1.0\nbadvalue\n3.0\n" + path = _write_tmp(content, ".txt") + try: + with pytest.raises(ValueError, match="Could not parse"): + read_txt(path) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# read_txt — integer promotion +# --------------------------------------------------------------------------- + +class TestReadTxtInt: + def test_integer_values_promoted_to_int32(self): + content = "\n".join(str(v) for v in _INT_VALUES) + "\n" + path = _write_tmp(content, ".txt") + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.dtype == np.int32 + np.testing.assert_array_equal(arr, _INT_VALUES) + + def test_mixed_float_stays_float32(self): + content = "0\n1\n2.5\n3\n" + path = _write_tmp(content, ".txt") + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.dtype == np.float32 + + +# --------------------------------------------------------------------------- +# read_npy / read_npz +# --------------------------------------------------------------------------- + +class TestReadNpy: + def test_basic(self): + arr_in = np.array([1.0, 2.0, 3.0], dtype=np.float32) + fd, path = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path, arr_in) + try: + arr = read_npy(path) + finally: + os.unlink(path) + np.testing.assert_array_equal(arr, arr_in) + + def test_2d_raises(self): + arr_in = np.ones((3, 4), dtype=np.float32) + fd, path = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path, arr_in) + try: + with pytest.raises(ValueError, match="1-D"): + read_npy(path) + finally: + os.unlink(path) + + def test_column_vector_squeezed(self): + """Shape (N,1) should be squeezed to (N,) successfully.""" + arr_in = np.ones((5, 1), dtype=np.float32) + fd, path = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path, arr_in) + try: + arr = read_npy(path) + finally: + os.unlink(path) + assert arr.shape == (5,) + + +class TestReadNpz: + def test_data_key(self): + arr_in = np.array([0, 1, 2], dtype=np.int32) + fd, path = tempfile.mkstemp(suffix=".npz") + os.close(fd) + np.savez(path, data=arr_in, other=np.zeros(3)) + try: + arr = read_npz(path) + finally: + os.unlink(path) + np.testing.assert_array_equal(arr, arr_in) + + def test_first_array_fallback(self): + arr_in = np.array([9.0, 8.0], dtype=np.float32) + fd, path = tempfile.mkstemp(suffix=".npz") + os.close(fd) + np.savez(path, arr_0=arr_in) + try: + arr = read_npz(path) + finally: + os.unlink(path) + np.testing.assert_array_equal(arr, arr_in) + + def test_empty_raises(self): + fd, path = tempfile.mkstemp(suffix=".npz") + os.close(fd) + np.savez(path) + try: + with pytest.raises(ValueError, match="no arrays"): + read_npz(path) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# read_overlay dispatcher +# --------------------------------------------------------------------------- + +class TestReadOverlayDispatcher: + def test_txt_dispatched(self): + path = _write_tmp("1.0\n2.0\n3.0\n", ".txt") + try: + arr = read_overlay(path) + finally: + os.unlink(path) + assert arr.shape == (3,) + + def test_csv_dispatched(self): + path = _write_tmp("0.5\n1.5\n", ".csv") + try: + arr = read_overlay(path) + finally: + os.unlink(path) + assert arr.shape == (2,) + + def test_npy_dispatched(self): + arr_in = np.array([1.0, 2.0], dtype=np.float32) + fd, path = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path, arr_in) + try: + arr = read_overlay(path) + finally: + os.unlink(path) + np.testing.assert_array_equal(arr, arr_in) + + def test_npz_dispatched(self): + arr_in = np.array([0, 1, 2], dtype=np.int32) + fd, path = tempfile.mkstemp(suffix=".npz") + os.close(fd) + np.savez(path, data=arr_in) + try: + arr = read_overlay(path) + finally: + os.unlink(path) + np.testing.assert_array_equal(arr, arr_in) + + def test_surface_gii_rejected_with_helpful_error(self): + """A .surf.gii passed to read_gifti (overlay reader) should raise clearly.""" + import nibabel as nib + coords_da = nib.gifti.GiftiDataArray( + data=np.array([[0,0,0],[1,0,0],[0,1,0],[0,0,1]], dtype=np.float32), + intent=1008, # POINTSET + ) + faces_da = nib.gifti.GiftiDataArray( + data=np.array([[0,1,2],[0,1,3]], dtype=np.int32), + intent=1009, # TRIANGLE + ) + img = nib.gifti.GiftiImage(darrays=[coords_da, faces_da]) + fd, path = tempfile.mkstemp(suffix=".surf.gii") + os.close(fd) + nib.save(img, path) + try: + with pytest.raises(ValueError, match="surface geometry"): + from whippersnappy.geometry.overlay_io import read_gifti + read_gifti(path) + finally: + os.unlink(path) + + def test_unknown_extension_raises(self): + with pytest.raises(ValueError, match="Unsupported"): + read_overlay("/some/file.xyz") + + def test_case_insensitive_extension(self): + path = _write_tmp("1.0\n2.0\n", ".TXT") + try: + arr = read_overlay(path) + finally: + os.unlink(path) + assert arr.shape == (2,) + + +# --------------------------------------------------------------------------- +# resolve_overlay / resolve_bg_map / resolve_roi routing via inputs.py +# --------------------------------------------------------------------------- + +class TestResolveOverlayRouting: + """Verify that resolve_overlay and resolve_bg_map pick up .txt/.npy files.""" + + def test_txt_path_routed_as_float(self): + content = "0.1\n0.5\n0.9\n0.3\n" + path = _write_tmp(content, ".txt") + try: + arr = resolve_overlay(path, n_vertices=4) + finally: + os.unlink(path) + assert arr.shape == (4,) + assert arr.dtype == np.float32 + + def test_npy_path_routed(self): + arr_in = np.array([1.0, 2.0, 3.0], dtype=np.float32) + fd, path = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path, arr_in) + try: + arr = resolve_overlay(path, n_vertices=3) + finally: + os.unlink(path) + np.testing.assert_array_equal(arr, arr_in) + + def test_shape_mismatch_raises(self): + content = "0.1\n0.5\n" + path = _write_tmp(content, ".txt") + try: + with pytest.raises(ValueError, match="vertices"): + resolve_overlay(path, n_vertices=5) + finally: + os.unlink(path) + + def test_bg_map_txt_routed(self): + content = "1\n-1\n1\n-1\n" + path = _write_tmp(content, ".txt") + try: + arr = resolve_bg_map(path, n_vertices=4) + finally: + os.unlink(path) + assert arr.shape == (4,) + assert arr.dtype == np.float32 # always cast to float32 by resolve_bg_map + + def test_roi_from_bool_npy(self): + """Boolean .npy array is a valid ROI input after resolve_roi casts it.""" + arr_in = np.array([True, False, True, True]) + fd, path = tempfile.mkstemp(suffix=".npy") + os.close(fd) + np.save(path, arr_in) + try: + # resolve_roi receives a str path; _load_overlay_from_file loads npy + # and returns the array; then resolve_roi casts it to bool. + arr = resolve_roi(path, n_vertices=4) + finally: + os.unlink(path) + assert arr.dtype == bool + np.testing.assert_array_equal(arr, arr_in) + + def test_label_txt_integer_values(self): + """Integer .txt file (parcellation) should be loadable as overlay.""" + content = "3\n0\n1\n3\n" + path = _write_tmp(content, ".txt") + try: + # resolve_overlay casts to float32; original int32 from read_txt + arr = resolve_overlay(path, n_vertices=4) + finally: + os.unlink(path) + assert arr.shape == (4,) + # Values should be preserved numerically + np.testing.assert_array_equal(arr, [3.0, 0.0, 1.0, 3.0]) + diff --git a/whippersnappy/geometry/__init__.py b/whippersnappy/geometry/__init__.py index d2a57b0..b5cf706 100644 --- a/whippersnappy/geometry/__init__.py +++ b/whippersnappy/geometry/__init__.py @@ -3,7 +3,8 @@ Expose prepare_geometry and small IO helpers under `whippersnappy.geometry`. """ from .inputs import resolve_annot, resolve_bg_map, resolve_mesh, resolve_overlay, resolve_roi -from .mesh_io import read_mesh, read_off, read_ply_ascii, read_vtk_ascii_polydata +from .mesh_io import read_gifti_surface, read_mesh, read_off, read_ply_ascii, read_vtk_ascii_polydata +from .overlay_io import read_gifti, read_npy, read_npz, read_overlay, read_txt from .prepare import ( estimate_overlay_thresholds, prepare_and_validate_geometry, @@ -27,6 +28,12 @@ 'read_off', 'read_vtk_ascii_polydata', 'read_ply_ascii', + 'read_gifti_surface', + 'read_overlay', + 'read_txt', + 'read_npy', + 'read_npz', + 'read_gifti', 'read_geometry', 'read_annot_data', 'read_mgh_data', diff --git a/whippersnappy/geometry/inputs.py b/whippersnappy/geometry/inputs.py index 0091938..03e3e96 100644 --- a/whippersnappy/geometry/inputs.py +++ b/whippersnappy/geometry/inputs.py @@ -13,10 +13,15 @@ from ..utils.colormap import mask_label from .mesh_io import read_mesh as _read_mesh_by_ext +from .overlay_io import read_overlay as _read_overlay_by_ext from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data # Extensions handled by the lightweight ASCII mesh readers in mesh_io.py -_MESH_IO_EXTS = frozenset({".off", ".vtk", ".ply"}) +# (includes GIfTI surface via nibabel) +_MESH_IO_EXTS = frozenset({".off", ".vtk", ".ply", ".gii"}) + +# Extensions handled by overlay_io.py (everything except FreeSurfer morph / MGH) +_OVERLAY_IO_EXTS = frozenset({".txt", ".csv", ".npy", ".npz", ".gii"}) def resolve_mesh(mesh): @@ -49,8 +54,9 @@ def resolve_mesh(mesh): indices are out of range. """ if isinstance(mesh, str): - ext = os.path.splitext(mesh)[1].lower() - if ext in _MESH_IO_EXTS: + lower = mesh.lower() + # Compound extension must be checked before os.path.splitext + if lower.endswith(".surf.gii") or os.path.splitext(lower)[1] in _MESH_IO_EXTS: vertices, faces = _read_mesh_by_ext(mesh) else: v_raw, f_raw = read_geometry(mesh, read_metadata=False) @@ -85,10 +91,25 @@ def resolve_mesh(mesh): def _load_overlay_from_file(path): - """Load a 1-D per-vertex overlay array from a file path.""" - _, ext = os.path.splitext(path) - if ext == ".mgh": + """Load a 1-D per-vertex overlay array from a file path. + + Routing logic: + + * ``.mgh`` / ``.mgz`` → :func:`read_mgh_data` + * ``.txt``, ``.csv``, ``.npy``, ``.npz``, ``.gii``, + ``.func.gii``, ``.label.gii`` → :func:`overlay_io.read_overlay` + * anything else → :func:`read_morph_data` (FreeSurfer binary morph) + """ + lower = path.lower() + # Compound GIfTI extensions must be checked before splitext + if lower.endswith(".func.gii") or lower.endswith(".label.gii"): + return _read_overlay_by_ext(path) + _, ext = os.path.splitext(lower) + if ext in (".mgh", ".mgz"): return read_mgh_data(path) + if ext in _OVERLAY_IO_EXTS: + return _read_overlay_by_ext(path) + # Default: FreeSurfer binary morph (curv, thickness, etc. — often no ext) return read_morph_data(path) @@ -168,8 +189,16 @@ def resolve_roi(roi, *, n_vertices): ---------- roi : None, str, or array-like * ``None`` — no masking; returns ``None``. - * ``str`` — path to a FreeSurfer label file. Vertices listed in the - file are marked ``True``; all others are ``False``. + * ``str`` — file path. Routing by extension: + + - ``.txt``, ``.csv``, ``.npy``, ``.npz``, ``.gii``, + ``.func.gii``, ``.label.gii`` — loaded via + :func:`_load_overlay_from_file` and cast to ``bool``. + Non-zero / ``True`` values → vertex included. + - Any other path (including FreeSurfer label files with no + standard extension, e.g. ``lh.cortex.label``) — loaded via + :func:`~whippersnappy.utils.colormap.mask_label`. + * array-like — converted to ``np.bool_``; must have shape ``(n_vertices,)``. n_vertices : int @@ -187,11 +216,20 @@ def resolve_roi(roi, *, n_vertices): if roi is None: return None if isinstance(roi, str): - # Use mask_label to get vertices included in the label (NaN = excluded). - sentinel = np.ones(n_vertices, dtype=np.float32) - masked = mask_label(sentinel, roi) - # Vertices NOT in the label were set to NaN → roi = ~isnan - arr = ~np.isnan(masked) + lower = roi.lower() + use_overlay_io = ( + lower.endswith(".func.gii") + or lower.endswith(".label.gii") + or os.path.splitext(lower)[1] in _OVERLAY_IO_EXTS + ) + if use_overlay_io: + raw = _load_overlay_from_file(roi) + arr = raw.astype(bool) + else: + # FreeSurfer label file: vertices listed → True, rest → False + sentinel = np.ones(n_vertices, dtype=np.float32) + masked = mask_label(sentinel, roi) + arr = ~np.isnan(masked) else: arr = np.asarray(roi, dtype=bool) if arr.shape != (n_vertices,): diff --git a/whippersnappy/geometry/mesh_io.py b/whippersnappy/geometry/mesh_io.py index c30f04e..105b7af 100644 --- a/whippersnappy/geometry/mesh_io.py +++ b/whippersnappy/geometry/mesh_io.py @@ -6,16 +6,23 @@ * **VTK legacy ASCII PolyData** — ``DATASET POLYDATA`` with POINTS/POLYGONS * **PLY ASCII** — Stanford PLY, ASCII encoding, triangles only +And a nibabel-backed reader for: + +* **GIfTI surface** (``.surf.gii`` / ``.gii``) — loaded via nibabel; + requires the ``NIFTI_INTENT_POINTSET`` (1008) + ``NIFTI_INTENT_TRIANGLE`` + (1009) data arrays that every standard surface GIfTI file contains. + All readers return ``(vertices, faces)`` where * ``vertices`` — ``float32`` array of shape ``(N, 3)`` * ``faces`` — ``uint32`` array of shape ``(M, 3)`` The public dispatcher :func:`read_mesh` routes by file extension -(``.off``, ``.vtk``, ``.ply``). For FreeSurfer surfaces (no standard -extension) use the existing :func:`whippersnappy.geometry.read_geometry` -directly, or go through :func:`whippersnappy.geometry.inputs.resolve_mesh` -which handles the routing automatically. +(``.off``, ``.vtk``, ``.ply``, ``.surf.gii``, ``.gii``). For FreeSurfer +surfaces (no standard extension) use the existing +:func:`whippersnappy.geometry.read_geometry` directly, or go through +:func:`whippersnappy.geometry.inputs.resolve_mesh` which handles the routing +automatically. """ import numpy as np @@ -442,21 +449,131 @@ def read_ply_ascii(path): return vertices, faces +# --------------------------------------------------------------------------- +# GIfTI surface reader +# --------------------------------------------------------------------------- + +# NIFTI intent codes relevant to GIfTI surface files +_GIFTI_INTENT_POINTSET = 1008 # NIFTI_INTENT_POINTSET — vertex coordinates +_GIFTI_INTENT_TRIANGLE = 1009 # NIFTI_INTENT_TRIANGLE — face indices + + +def read_gifti_surface(path): + """Read a GIfTI surface file (``.surf.gii`` or ``.gii``). + + A standard GIfTI surface file contains exactly two data arrays: + + * One with intent ``NIFTI_INTENT_POINTSET`` (1008) — vertex coordinates, + shape ``(N, 3)``, dtype ``float32``. + * One with intent ``NIFTI_INTENT_TRIANGLE`` (1009) — face indices, + shape ``(M, 3)``, dtype ``int32``. + + This is the format produced by FreeSurfer's ``mris_convert``, + Connectome Workbench, fMRIPrep, and most HCP pipelines. + + Parameters + ---------- + path : str + Path to a GIfTI surface file. + + Returns + ------- + vertices : numpy.ndarray, shape (N, 3), dtype float32 + faces : numpy.ndarray, shape (M, 3), dtype uint32 + + Raises + ------ + ImportError + If ``nibabel`` is not installed. + ValueError + If the file does not contain the expected POINTSET and TRIANGLE + data arrays, or if the arrays have unexpected shapes. + Also raised if the file appears to be a functional/label GIfTI + rather than a surface GIfTI (i.e. no POINTSET array found) — in + that case, use ``resolve_overlay()`` instead. + IOError + If the file cannot be opened or is not a valid GIfTI file. + + Examples + -------- + >>> v, f = read_gifti_surface('lh.white.surf.gii') + >>> v.shape # (N, 3) + >>> f.shape # (M, 3) + """ + try: + import nibabel as nib # noqa: PLC0415 + except ImportError as exc: + raise ImportError( + "Reading GIfTI surface files requires nibabel. " + "Install with: pip install nibabel" + ) from exc + + img = nib.load(path) + if not hasattr(img, "darrays") or not img.darrays: + raise ValueError( + f"GIfTI file {path!r} contains no data arrays. " + f"Expected a surface file with POINTSET and TRIANGLE arrays." + ) + + coords_arr = None + faces_arr = None + for da in img.darrays: + if da.intent == _GIFTI_INTENT_POINTSET and coords_arr is None: + coords_arr = da.data + elif da.intent == _GIFTI_INTENT_TRIANGLE and faces_arr is None: + faces_arr = da.data + + if coords_arr is None: + raise ValueError( + f"GIfTI file {path!r} has no POINTSET (intent 1008) array. " + f"If this is a functional or label file, use resolve_overlay() instead." + ) + if faces_arr is None: + raise ValueError( + f"GIfTI file {path!r} has no TRIANGLE (intent 1009) array. " + f"A valid surface GIfTI must contain both POINTSET and TRIANGLE arrays." + ) + + vertices = np.asarray(coords_arr, dtype=np.float32) + faces = np.asarray(faces_arr, dtype=np.uint32) + + if vertices.ndim != 2 or vertices.shape[1] != 3: + raise ValueError( + f"GIfTI POINTSET array in {path!r} has unexpected shape " + f"{vertices.shape}; expected (N, 3)." + ) + if faces.ndim != 2 or faces.shape[1] != 3: + raise ValueError( + f"GIfTI TRIANGLE array in {path!r} has unexpected shape " + f"{faces.shape}; expected (M, 3)." + ) + if faces.size > 0: + n_verts = vertices.shape[0] + if int(faces.max()) >= n_verts or int(faces.min()) < 0: + raise ValueError( + f"GIfTI face indices out of range [0, {n_verts}) in {path!r}." + ) + + return vertices, faces + + # --------------------------------------------------------------------------- # Dispatcher # --------------------------------------------------------------------------- _READERS = { - ".off": read_off, - ".vtk": read_vtk_ascii_polydata, - ".ply": read_ply_ascii, + ".off": read_off, + ".vtk": read_vtk_ascii_polydata, + ".ply": read_ply_ascii, + ".gii": read_gifti_surface, + ".surf.gii": read_gifti_surface, # compound extension — checked first } _SUPPORTED = ", ".join(sorted(_READERS)) def read_mesh(path): - """Read a triangle mesh from an OFF, VTK, or PLY file. + """Read a triangle mesh from an OFF, VTK, PLY, or GIfTI surface file. Dispatches to the appropriate reader based on the file extension. For FreeSurfer binary surfaces (which typically have no standard @@ -469,7 +586,8 @@ def read_mesh(path): ---------- path : str Path to a mesh file. Extension must be one of: - ``.off``, ``.vtk``, ``.ply`` (case-insensitive). + ``.off``, ``.vtk``, ``.ply``, ``.surf.gii``, ``.gii`` + (case-insensitive). Returns ------- @@ -481,8 +599,12 @@ def read_mesh(path): ValueError If the extension is not recognised. """ - import os - ext = os.path.splitext(path)[1].lower() + import os as _os + lower = path.lower() + # Check compound extension first (.surf.gii before .gii) + if lower.endswith(".surf.gii"): + return read_gifti_surface(path) + ext = _os.path.splitext(lower)[1] reader = _READERS.get(ext) if reader is None: raise ValueError( diff --git a/whippersnappy/geometry/overlay_io.py b/whippersnappy/geometry/overlay_io.py new file mode 100644 index 0000000..5c6f204 --- /dev/null +++ b/whippersnappy/geometry/overlay_io.py @@ -0,0 +1,346 @@ +"""Lightweight per-vertex scalar and label readers for common open formats. + +This module implements pure-Python (stdlib + numpy only) readers for simple +per-vertex data files, plus a GIfTI reader that reuses the nibabel dependency +already present in the project. + +Supported formats +----------------- +* **ASCII text** (``.txt``, ``.csv``) — one numeric value per line; optional + single non-numeric header line (skipped automatically); whitespace or + comma-separated. Integer values are loaded as ``int32``; all others as + ``float32``. + +* **NumPy array** (``.npy``) — single 1-D array saved with + ``numpy.save``. Any numeric dtype is accepted and kept as-is; callers + cast to the required dtype. + +* **NumPy archive** (``.npz``) — multi-array archive saved with + ``numpy.savez``. The array named ``"data"`` is used if present, + otherwise the first array in the archive is used. + +* **GIfTI functional / label** (``.func.gii``, ``.label.gii``, ``.gii``) — + loaded via ``nibabel``; the first data array is returned. Covers HCP, + fMRIPrep, and Connectome Workbench outputs. + +The public dispatcher :func:`read_overlay` routes by file extension. +FreeSurfer binary morph files and MGH/MGZ files are *not* handled here — +they are loaded by the existing :mod:`whippersnappy.geometry.read_geometry` +functions and dispatched from :func:`whippersnappy.geometry.inputs`. + +All readers return a flat ``numpy.ndarray`` of shape ``(N,)``. The caller +is responsible for casting to the desired dtype (``float32`` for overlays and +background maps, ``bool`` for ROI masks, ``int32`` for label/parcellation +maps). +""" + +import os + +import numpy as np + +# --------------------------------------------------------------------------- +# ASCII text / CSV reader +# --------------------------------------------------------------------------- + +def read_txt(path): + """Read a per-vertex scalar file in plain ASCII format. + + The file must contain exactly one numeric value per line. An optional + single-line text header (non-numeric first line) is silently skipped. + Whitespace and comma separators are both accepted; only the *first* + value on each line is used (allowing simple CSV files with a single + data column). + + Integer-valued files (every value equal to its ``int`` cast) are + returned as ``int32``; all others as ``float32``. + + Parameters + ---------- + path : str + Path to the ``.txt`` or ``.csv`` file. + + Returns + ------- + numpy.ndarray, shape (N,), dtype float32 or int32 + + Raises + ------ + ValueError + If no numeric values can be parsed from the file. + IOError + If the file cannot be opened. + + Examples + -------- + A valid ``overlay.txt``:: + + # optional comment line (skipped) + 0.123 + -1.456 + 2.0 + + A valid ``labels.csv`` (first column used, header skipped):: + + label + 3 + 0 + 1 + 3 + """ + values = [] + with open(path, encoding="utf-8", errors="replace") as fh: + for lineno, raw in enumerate(fh, start=1): + line = raw.strip() + if not line or line.startswith("#"): + continue + # Take only the first token (handles CSV with a single data column) + token = line.split(",")[0].split()[0] + try: + values.append(float(token)) + except ValueError as exc: + if lineno == 1: + # Treat the very first non-numeric line as a header and skip it + continue + raise ValueError( + f"Could not parse numeric value on line {lineno} of {path!r}: " + f"{raw.strip()!r}" + ) from exc + + if not values: + raise ValueError(f"No numeric values found in {path!r}.") + + arr = np.array(values, dtype=np.float32) + + # Promote to int32 if all values are integers (label / parcellation file) + if np.all(arr == arr.astype(np.int32)): + return arr.astype(np.int32) + return arr + + +# --------------------------------------------------------------------------- +# NumPy readers +# --------------------------------------------------------------------------- + +def read_npy(path): + """Read a per-vertex scalar array from a NumPy ``.npy`` file. + + Parameters + ---------- + path : str + Path to the ``.npy`` file. + + Returns + ------- + numpy.ndarray, shape (N,) + The stored array, squeezed to 1-D. + + Raises + ------ + ValueError + If the stored array is not 1-D after squeezing, or is empty. + IOError + If the file cannot be opened. + """ + arr = np.load(path) + arr = np.squeeze(arr) + if arr.ndim != 1: + raise ValueError( + f"NumPy file {path!r} contains an array of shape {arr.shape}; " + f"expected a 1-D per-vertex array." + ) + if arr.size == 0: + raise ValueError(f"NumPy file {path!r} contains an empty array.") + return arr + + +def read_npz(path): + """Read a per-vertex scalar array from a NumPy ``.npz`` archive. + + The array named ``"data"`` is returned if it exists; otherwise the + first array in the archive is used. + + Parameters + ---------- + path : str + Path to the ``.npz`` file. + + Returns + ------- + numpy.ndarray, shape (N,) + The selected array, squeezed to 1-D. + + Raises + ------ + ValueError + If no arrays are found, or the selected array is not 1-D after + squeezing. + IOError + If the file cannot be opened. + """ + archive = np.load(path) + keys = list(archive.keys()) + if not keys: + raise ValueError(f"NumPy archive {path!r} contains no arrays.") + + key = "data" if "data" in keys else keys[0] + arr = np.squeeze(archive[key]) + if arr.ndim != 1: + raise ValueError( + f"NumPy archive {path!r}, array {key!r} has shape {arr.shape}; " + f"expected a 1-D per-vertex array." + ) + if arr.size == 0: + raise ValueError(f"NumPy archive {path!r}, array {key!r} is empty.") + return arr + + +# --------------------------------------------------------------------------- +# GIfTI reader +# --------------------------------------------------------------------------- + +def read_gifti(path): + """Read a per-vertex scalar array from a GIfTI functional or label file. + + Supports ``.func.gii`` (continuous scalars, e.g. HCP thickness) and + ``.label.gii`` (integer parcellation labels, e.g. HCP parcellation). + Plain ``.gii`` files are also accepted provided they contain a scalar + data array — **not** a surface geometry file. For surface GIfTI files + (``.surf.gii`` or ``.gii`` files with POINTSET+TRIANGLE arrays) use + :func:`whippersnappy.geometry.mesh_io.read_gifti_surface` or pass the + path to :func:`whippersnappy.geometry.inputs.resolve_mesh`. + + The first non-POINTSET, non-TRIANGLE data array in the file is returned. + + Parameters + ---------- + path : str + Path to a GIfTI file. + + Returns + ------- + numpy.ndarray, shape (N,) + The first scalar data array, squeezed to 1-D. + + Raises + ------ + ImportError + If ``nibabel`` is not installed. + ValueError + If the file is a surface GIfTI (POINTSET+TRIANGLE only), contains + no usable scalar arrays, or the first scalar array is not 1-D. + IOError + If the file cannot be opened or is not a valid GIfTI file. + """ + try: + import nibabel as nib # noqa: PLC0415 + except ImportError as exc: + raise ImportError( + "Reading GIfTI files requires nibabel. " + "Install with: pip install nibabel" + ) from exc + + img = nib.load(path) + if not hasattr(img, "darrays") or not img.darrays: + raise ValueError( + f"GIfTI file {path!r} contains no data arrays." + ) + + # Intent codes for surface geometry — skip these + _SURFACE_INTENTS = {1008, 1009} # POINTSET, TRIANGLE + + scalar_da = None + has_surface_arrays = False + for da in img.darrays: + if da.intent in _SURFACE_INTENTS: + has_surface_arrays = True + elif scalar_da is None: + scalar_da = da + + if scalar_da is None: + if has_surface_arrays: + raise ValueError( + f"GIfTI file {path!r} appears to be a surface geometry file " + f"(contains only POINTSET/TRIANGLE arrays). " + f"Use resolve_mesh() or read_gifti_surface() to load it as a mesh." + ) + raise ValueError( + f"GIfTI file {path!r} contains no scalar data arrays." + ) + + arr = np.squeeze(scalar_da.data) + if arr.ndim != 1: + raise ValueError( + f"GIfTI file {path!r}: first scalar data array has shape " + f"{scalar_da.data.shape}; expected a 1-D per-vertex array." + ) + if arr.size == 0: + raise ValueError(f"GIfTI file {path!r}: first scalar data array is empty.") + return arr + + +# --------------------------------------------------------------------------- +# Dispatcher +# --------------------------------------------------------------------------- + +# Map from lower-case file extension to reader function. +# Note: ".func.gii" and ".label.gii" have a compound extension; we handle +# them by matching the last *two* dot-separated components as well. +_READERS = { + ".txt": read_txt, + ".csv": read_txt, + ".npy": read_npy, + ".npz": read_npz, + ".gii": read_gifti, + ".func.gii": read_gifti, + ".label.gii": read_gifti, +} + +_SUPPORTED = ", ".join(sorted(_READERS)) + + +def read_overlay(path): + """Read a per-vertex scalar or label array from a file. + + Dispatches to the appropriate reader based on the file extension. + FreeSurfer binary morph files (e.g. ``lh.curv``, ``lh.thickness``) and + MGH/MGZ files are **not** handled here — pass them through + :func:`whippersnappy.geometry.inputs._load_overlay_from_file` which + already routes those formats via :func:`read_morph_data` / + :func:`read_mgh_data`. + + Parameters + ---------- + path : str + Path to an overlay/label file. Recognised extensions: + + * ``.txt``, ``.csv`` — plain ASCII, one value per line + * ``.npy`` — NumPy binary array + * ``.npz`` — NumPy archive (key ``"data"`` or first array) + * ``.gii``, ``.func.gii``, ``.label.gii`` — GIfTI + + Returns + ------- + numpy.ndarray, shape (N,) + + Raises + ------ + ValueError + If the file extension is not recognised. + """ + # Check compound extensions first (.func.gii, .label.gii) + lower = path.lower() + for compound in (".func.gii", ".label.gii"): + if lower.endswith(compound): + return _READERS[compound](path) + + ext = os.path.splitext(path)[1].lower() + reader = _READERS.get(ext) + if reader is None: + raise ValueError( + f"Unsupported overlay file extension {ext!r} for {path!r}. " + f"Supported formats: {_SUPPORTED}. " + f"For FreeSurfer morph files (no extension) or .mgh/.mgz files " + f"the routing is handled automatically by resolve_overlay()." + ) + return reader(path) + From d13728a7dcbb04a8048fa1121569108dd25d4ed9 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 18:03:01 +0100 Subject: [PATCH 65/83] rename read_geometry to freesurfer_io.py --- whippersnappy/geometry/__init__.py | 51 +++++++++++++++++-- .../{read_geometry.py => freesurfer_io.py} | 46 ++++++++++++----- whippersnappy/geometry/inputs.py | 20 +++++--- whippersnappy/geometry/overlay_io.py | 7 ++- 4 files changed, 95 insertions(+), 29 deletions(-) rename whippersnappy/geometry/{read_geometry.py => freesurfer_io.py} (82%) diff --git a/whippersnappy/geometry/__init__.py b/whippersnappy/geometry/__init__.py index b5cf706..3307510 100644 --- a/whippersnappy/geometry/__init__.py +++ b/whippersnappy/geometry/__init__.py @@ -1,7 +1,43 @@ -"""Geometry subpackage exports. +"""Geometry subpackage — mesh IO, overlay IO, and rendering preparation. -Expose prepare_geometry and small IO helpers under `whippersnappy.geometry`. +Architecture +------------ +The subpackage has three layers: + +**Layer 1 — low-level format readers** (one file per format family): + +* :mod:`~whippersnappy.geometry.freesurfer_io` — FreeSurfer binary formats: + surface geometry (``read_geometry``), morphometry scalars + (``read_morph_data``), MGH overlays (``read_mgh_data``), and annotation + files (``read_annot_data``). Contains derived nibabel code (MIT licence). +* :mod:`~whippersnappy.geometry.mesh_io` — open ASCII mesh formats: + OFF (``read_off``), legacy VTK PolyData (``read_vtk_ascii_polydata``), + ASCII PLY (``read_ply_ascii``), and GIfTI surface (``read_gifti_surface``). + Pure stdlib + numpy, except GIfTI which uses nibabel. +* :mod:`~whippersnappy.geometry.overlay_io` — open scalar/label formats: + plain ASCII (``read_txt``), NumPy binary (``read_npy``, ``read_npz``), + and GIfTI functional/label (``read_gifti``). Pure stdlib + numpy, except + GIfTI. + +Each family also exposes a dispatcher (``read_mesh`` / ``read_overlay``) +that routes by file extension. + +**Layer 2 — resolvers** (:mod:`~whippersnappy.geometry.inputs`): + +``resolve_mesh``, ``resolve_overlay``, ``resolve_bg_map``, ``resolve_roi``, +``resolve_annot`` — the **single public interface** for the rest of the +package. Each resolver accepts a file path *or* a numpy array *or* ``None``, +dispatches to the correct layer-1 reader, validates shapes and dtypes, and +returns a clean numpy array. All higher-level code (``prepare.py``, +``snap.py``, ``plot3d.py``, CLIs) should go through resolvers only. + +**Layer 3 — geometry preparation** (:mod:`~whippersnappy.geometry.prepare`): + +``prepare_geometry`` / ``prepare_geometry_from_arrays`` — load, normalise, +colour, and pack vertex data into the GPU-ready format consumed by the +OpenGL shaders. """ +from .freesurfer_io import read_annot_data, read_geometry, read_mgh_data, read_morph_data from .inputs import resolve_annot, resolve_bg_map, resolve_mesh, resolve_overlay, resolve_roi from .mesh_io import read_gifti_surface, read_mesh, read_off, read_ply_ascii, read_vtk_ascii_polydata from .overlay_io import read_gifti, read_npy, read_npz, read_overlay, read_txt @@ -11,32 +47,37 @@ prepare_geometry, prepare_geometry_from_arrays, ) -from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data from .surf_name import get_surf_name __all__ = [ + # Layer 3 — geometry preparation 'prepare_geometry', 'prepare_geometry_from_arrays', 'prepare_and_validate_geometry', 'estimate_overlay_thresholds', + # Layer 2 — resolvers (preferred public interface) 'resolve_mesh', 'resolve_overlay', 'resolve_bg_map', 'resolve_roi', 'resolve_annot', + # Layer 1 — mesh readers 'read_mesh', 'read_off', 'read_vtk_ascii_polydata', 'read_ply_ascii', 'read_gifti_surface', + # Layer 1 — overlay/scalar readers 'read_overlay', 'read_txt', 'read_npy', 'read_npz', 'read_gifti', + # Layer 1 — FreeSurfer binary readers 'read_geometry', - 'read_annot_data', - 'read_mgh_data', 'read_morph_data', + 'read_mgh_data', + 'read_annot_data', + # Utilities 'get_surf_name', ] diff --git a/whippersnappy/geometry/read_geometry.py b/whippersnappy/geometry/freesurfer_io.py similarity index 82% rename from whippersnappy/geometry/read_geometry.py rename to whippersnappy/geometry/freesurfer_io.py index 2387ada..7d7fd39 100644 --- a/whippersnappy/geometry/read_geometry.py +++ b/whippersnappy/geometry/freesurfer_io.py @@ -1,8 +1,26 @@ -"""Read FreeSurfer geometry (fix for dev, ll 126-128); - -Code was taken from nibabel.freesurfer package -(https://github.com/nipy/nibabel/blob/master/nibabel/freesurfer/io.py). -This software is licensed under the following license: +"""FreeSurfer binary IO — surface geometry, morphometry, MGH, and annotation. + +This module contains readers for all FreeSurfer binary formats: + +* :func:`read_geometry` — triangle surface (e.g. ``lh.white``, ``rh.pial``) +* :func:`read_morph_data` — per-vertex morphometry scalar (e.g. ``lh.curv``, + ``lh.thickness``); the format FreeSurfer internally calls "curv files" +* :func:`read_mgh_data` — MGH/MGZ volumetric image used as per-vertex + overlay (shape ``N×1×1``) +* :func:`read_annot_data` — FreeSurfer parcellation annotation (``.annot``) + +These functions are **low-level readers**. In normal use you should go +through the resolver functions in :mod:`whippersnappy.geometry.inputs` +(:func:`~whippersnappy.geometry.inputs.resolve_mesh`, +:func:`~whippersnappy.geometry.inputs.resolve_overlay`, etc.) which handle +format dispatch, dtype conversion, and shape validation for you. + +License notice +-------------- +The binary parsing code in this module is derived from the ``nibabel`` +FreeSurfer IO module +(https://github.com/nipy/nibabel/blob/master/nibabel/freesurfer/io.py), +used under the MIT licence reproduced below. The MIT License @@ -14,12 +32,12 @@ Copyright (c) 2011-2019 Yaroslav Halchenko Copyright (c) 2015-2019 Chris Markiewicz -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. @@ -28,9 +46,9 @@ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. """ import warnings diff --git a/whippersnappy/geometry/inputs.py b/whippersnappy/geometry/inputs.py index 03e3e96..3c3afa5 100644 --- a/whippersnappy/geometry/inputs.py +++ b/whippersnappy/geometry/inputs.py @@ -1,10 +1,18 @@ """Input resolver functions for WhipperSnapPy geometry loading. -This module is the single source of truth for loading and validating all -user-facing inputs (mesh, overlay, background map, ROI, annotation). No -other module should call ``read_geometry``, ``read_morph_data``, -``read_mgh_data``, ``read_annot_data``, or ``mask_label`` directly — all -calls should go through the resolver functions defined here. +This module is the **single entry point** for all input loading and +validation. No other module should call ``read_geometry``, +``read_morph_data``, ``read_mgh_data``, ``read_annot_data``, or +``mask_label`` directly — all calls go through the resolver functions +defined here. + +Each resolver accepts ``None``, a file path (``str``), or a numpy +array-like, validates shape and dtype, and returns a clean numpy array +(or ``None``). Format dispatch is handled internally by: + +* :mod:`~whippersnappy.geometry.freesurfer_io` — FreeSurfer binary formats +* :mod:`~whippersnappy.geometry.mesh_io` — OFF / VTK / PLY / GIfTI surfaces +* :mod:`~whippersnappy.geometry.overlay_io` — TXT / CSV / NPY / NPZ / GIfTI scalars """ import os @@ -12,9 +20,9 @@ import numpy as np from ..utils.colormap import mask_label +from .freesurfer_io import read_annot_data, read_geometry, read_mgh_data, read_morph_data from .mesh_io import read_mesh as _read_mesh_by_ext from .overlay_io import read_overlay as _read_overlay_by_ext -from .read_geometry import read_annot_data, read_geometry, read_mgh_data, read_morph_data # Extensions handled by the lightweight ASCII mesh readers in mesh_io.py # (includes GIfTI surface via nibabel) diff --git a/whippersnappy/geometry/overlay_io.py b/whippersnappy/geometry/overlay_io.py index 5c6f204..04ea510 100644 --- a/whippersnappy/geometry/overlay_io.py +++ b/whippersnappy/geometry/overlay_io.py @@ -25,8 +25,8 @@ The public dispatcher :func:`read_overlay` routes by file extension. FreeSurfer binary morph files and MGH/MGZ files are *not* handled here — -they are loaded by the existing :mod:`whippersnappy.geometry.read_geometry` -functions and dispatched from :func:`whippersnappy.geometry.inputs`. +they are loaded by :mod:`whippersnappy.geometry.freesurfer_io` and +dispatched from :func:`whippersnappy.geometry.inputs._load_overlay_from_file`. All readers return a flat ``numpy.ndarray`` of shape ``(N,)``. The caller is responsible for casting to the desired dtype (``float32`` for overlays and @@ -305,8 +305,7 @@ def read_overlay(path): FreeSurfer binary morph files (e.g. ``lh.curv``, ``lh.thickness``) and MGH/MGZ files are **not** handled here — pass them through :func:`whippersnappy.geometry.inputs._load_overlay_from_file` which - already routes those formats via :func:`read_morph_data` / - :func:`read_mgh_data`. + already routes those formats via :mod:`~whippersnappy.geometry.freesurfer_io`. Parameters ---------- From cfda0799c1a84ecfccf94a136a1838c5cda9d1fe Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 18:17:09 +0100 Subject: [PATCH 66/83] fix ruff --- tests/test_mesh_io.py | 1 - tests/test_overlay_io.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_mesh_io.py b/tests/test_mesh_io.py index 543d7b3..d8d76a4 100644 --- a/tests/test_mesh_io.py +++ b/tests/test_mesh_io.py @@ -405,7 +405,6 @@ def test_array_out_of_range_raises(self): def _make_surf_gii(verts, faces, suffix=".surf.gii"): """Write a minimal GIfTI surface file and return its path.""" import nibabel as nib - from nibabel import nifti1 _INTENT_POINTSET = 1008 _INTENT_TRIANGLE = 1009 coords_da = nib.gifti.GiftiDataArray( diff --git a/tests/test_overlay_io.py b/tests/test_overlay_io.py index a80de80..18c82d3 100644 --- a/tests/test_overlay_io.py +++ b/tests/test_overlay_io.py @@ -10,7 +10,7 @@ import numpy as np import pytest -from whippersnappy.geometry.inputs import resolve_overlay, resolve_bg_map, resolve_roi +from whippersnappy.geometry.inputs import resolve_bg_map, resolve_overlay, resolve_roi from whippersnappy.geometry.overlay_io import ( read_npy, read_npz, @@ -18,7 +18,6 @@ read_txt, ) - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- From 8fad5b1b51d3984eca926cea9200701d958853b6 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 19:39:57 +0100 Subject: [PATCH 67/83] minor fixes, fixing pillow min version --- DOCKER.md | 4 ++-- pyproject.toml | 2 +- whippersnappy/gl/utils.py | 7 ++----- whippersnappy/plot3d.py | 4 ++-- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/DOCKER.md b/DOCKER.md index dc7cd44..8648005 100644 --- a/DOCKER.md +++ b/DOCKER.md @@ -90,7 +90,7 @@ docker run --rm --init \ whippersnappy \ /subject/surf/lh.white \ --overlay /subject/surf/lh.thickness \ - --curv /subject/surf/lh.curv \ + --bg-map /subject/surf/lh.curv \ --view left \ --fthresh 2.0 --fmax 4.0 \ -o /output/snap1.png @@ -121,7 +121,7 @@ docker run --rm --init \ whippersnappy \ /subject/surf/lh.white \ --overlay /subject/surf/lh.thickness \ - --curv /subject/surf/lh.curv \ + --bg-map /subject/surf/lh.curv \ --rotate \ --rotate-frames 72 \ --rotate-fps 24 \ diff --git a/pyproject.toml b/pyproject.toml index d7d81d5..e2fc827 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ 'glfw', 'numpy>=1.21', 'pyrr', - 'pillow', + 'pillow>=9.1', 'pyopengl>=3.1.8', 'nibabel', 'psutil' diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index de279d7..0e546e0 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -51,7 +51,7 @@ def compile_shader_program(vertex_src, fragment_src): int OpenGL program handle. """ - return gl.shaders.compileProgram( + return shaders.compileProgram( shaders.compileShader(vertex_src, gl.GL_VERTEX_SHADER), shaders.compileShader(fragment_src, gl.GL_FRAGMENT_SHADER), ) @@ -401,10 +401,7 @@ def capture_window(window): gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) img_buf = gl.glReadPixels(0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE) image = Image.frombytes("RGB", (width, height), img_buf) - # Image.Transpose.FLIP_TOP_BOTTOM is the preferred form since Pillow 9.1; - # fall back to the legacy integer constant for older installations. - _flip = getattr(Image, "Transpose", Image).FLIP_TOP_BOTTOM - image = image.transpose(_flip) + image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM) if x_scale != 1 or y_scale != 1: rwidth = int(round(width / x_scale)) diff --git a/whippersnappy/plot3d.py b/whippersnappy/plot3d.py index 3b70a6e..8fefcce 100644 --- a/whippersnappy/plot3d.py +++ b/whippersnappy/plot3d.py @@ -245,5 +245,5 @@ def create_threejs_mesh_with_custom_shaders(vertices, faces, colors, normals, am } ) - mesh = p3js.Mesh(geometry=geometry, material=material) - return mesh + three_mesh = p3js.Mesh(geometry=geometry, material=material) + return three_mesh From 70572cac5e5a554c528df260cbe2442e33ce69a4 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 20:04:04 +0100 Subject: [PATCH 68/83] readme mention gifti and overlay support --- README.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index dd21af9..cf2a1dc 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ with color overlays or parcellations and generate screenshots — from the command line, in Jupyter notebooks, or via a desktop GUI. It works with FreeSurfer and FastSurfer brain surfaces as well as any -triangle mesh in OFF, legacy ASCII VTK PolyData, or ASCII PLY format, or +triangle mesh in OFF, legacy ASCII VTK PolyData, ASCII PLY, or GIfTI (.gii, .surf.gii) format, or passed directly as a NumPy ``(vertices, faces)`` tuple. ## Installation @@ -66,6 +66,7 @@ whippersnap1 $SUBJECT_DIR/surf/lh.white \ # Also works with OFF / VTK / PLY whippersnap1 mesh.off --overlay values.mgh -o snap1.png +whippersnap1 surface.surf.gii --overlay overlay.func.gii -o snap1.png ``` ### Rotation video (`whippersnap1 --rotate`) @@ -114,7 +115,10 @@ from whippersnappy import snap1, snap4, snap_rotate, plot3d | `plot3d` | Interactive 3D WebGL viewer for Jupyter notebooks | **Supported mesh inputs for `snap1`, `snap_rotate`, and `plot3d`:** -FreeSurfer binary surfaces (e.g. `lh.white`), OFF (`.off`), legacy ASCII VTK PolyData (`.vtk`), ASCII PLY (`.ply`), or a `(vertices, faces)` NumPy array tuple. +FreeSurfer binary surfaces (e.g. `lh.white`), OFF (`.off`), legacy ASCII VTK PolyData (`.vtk`), ASCII PLY (`.ply`), GIfTI surface (`.gii`, `.surf.gii`), or a `(vertices, faces)` NumPy array tuple. + +**Supported overlay/label inputs:** +FreeSurfer morph (`.curv`, `.thickness`), MGH/MGZ, ASCII (`.txt`, `.csv`), NumPy (`.npy`, `.npz`), GIfTI functional/label (`.func.gii`, `.label.gii`, `.gii`). ### Example @@ -135,8 +139,9 @@ img = snap4(sdir='/path/to/subject', colorbar=True, caption='Cortical Thickness (mm)') img.save('snap4.png') -# OFF / VTK / PLY mesh +# OFF / VTK / PLY / GIfTI mesh img = snap1('mesh.off', overlay='values.mgh') +img = snap1('surface.surf.gii', overlay='overlay.func.gii') # Array inputs (e.g. from LaPy or trimesh) import numpy as np From 0b6f8a4eb2b367e201ac394a9657e36a3512ba14 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 20:10:08 +0100 Subject: [PATCH 69/83] move redundant CLI doc in README --- README.md | 9 --------- 1 file changed, 9 deletions(-) diff --git a/README.md b/README.md index cf2a1dc..d672be9 100644 --- a/README.md +++ b/README.md @@ -151,15 +151,6 @@ overlay = np.random.randn(1000).astype(np.float32) img = snap1((v, f), overlay=overlay) ``` -CLI usage: - -```bash -# Single view -whippersnap1 lh.white --overlay lh.thickness --bg-map lh.curv --roi lh.cortex.label -o snap1.png - -# Four-view batch -whippersnap4 -lh lh.thickness -rh rh.thickness -sd /path/to/subject -o snap4.png -``` See `tutorials/whippersnappy_tutorial.ipynb` for complete notebook examples. From fdcbc3248c273439a93955de2b8dc1212bdc50dd Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 20:21:04 +0100 Subject: [PATCH 70/83] remove CLI imports in init --- whippersnappy/cli/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whippersnappy/cli/__init__.py b/whippersnappy/cli/__init__.py index a6d22a9..b9d0b91 100644 --- a/whippersnappy/cli/__init__.py +++ b/whippersnappy/cli/__init__.py @@ -1 +1 @@ -from .whippersnap import run +# CLI modules should not be imported here; keep __init__.py empty. From 6868634be8b0659bd9522605a0f3fa111557c1fb Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 21:49:28 +0100 Subject: [PATCH 71/83] add --mesh to whippersnap1 --- README.md | 8 ++--- whippersnappy/cli/whippersnap1.py | 49 +++++++++++++++++++++++-------- whippersnappy/cli/whippersnap4.py | 9 ++++++ 3 files changed, 50 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index d672be9..f38f86b 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ whippersnap4 -lh $LH_OVERLAY \ Renders one view of any triangular surface mesh: ```bash -whippersnap1 $SUBJECT_DIR/surf/lh.white \ +whippersnap1 --mesh $SUBJECT_DIR/surf/lh.white \ --overlay $LH_OVERLAY \ --bg-map $SUBJECT_DIR/surf/lh.curv \ --roi $SUBJECT_DIR/label/lh.cortex.label \ @@ -65,8 +65,8 @@ whippersnap1 $SUBJECT_DIR/surf/lh.white \ -o snap1.png # Also works with OFF / VTK / PLY -whippersnap1 mesh.off --overlay values.mgh -o snap1.png -whippersnap1 surface.surf.gii --overlay overlay.func.gii -o snap1.png +whippersnap1 --mesh mesh.off --overlay values.mgh -o snap1.png +whippersnap1 --mesh surface.surf.gii --overlay overlay.func.gii -o snap1.png ``` ### Rotation video (`whippersnap1 --rotate`) @@ -74,7 +74,7 @@ whippersnap1 surface.surf.gii --overlay overlay.func.gii -o snap1.png Renders a 360° animation of any triangular surface mesh: ```bash -whippersnap1 $SUBJECT_DIR/surf/lh.white \ +whippersnap1 --mesh $SUBJECT_DIR/surf/lh.white \ --overlay $LH_OVERLAY \ --rotate \ -o rotation.mp4 diff --git a/whippersnappy/cli/whippersnap1.py b/whippersnappy/cli/whippersnap1.py index 2ec9777..4d409f6 100644 --- a/whippersnappy/cli/whippersnap1.py +++ b/whippersnappy/cli/whippersnap1.py @@ -12,7 +12,7 @@ Usage:: # FreeSurfer surface — lateral view with thickness overlay - whippersnap1 /surf/lh.white \\ + whippersnap1 --mesh /surf/lh.white \\ --overlay /surf/lh.thickness \\ --bg-map /surf/lh.curv \\ --roi /label/lh.cortex.label \\ @@ -20,18 +20,18 @@ -o snap1.png # OFF / VTK / PLY mesh with a numpy-saved overlay - whippersnap1 mesh.off --overlay values.mgh -o snap1.png - whippersnap1 mesh.vtk -o snap1.png - whippersnap1 mesh.ply --overlay values.mgh -o snap1.png + whippersnap1 --mesh mesh.off --overlay values.mgh -o snap1.png + whippersnap1 --mesh mesh.vtk -o snap1.png + whippersnap1 --mesh mesh.ply --overlay values.mgh -o snap1.png # 360° rotation video - whippersnap1 /surf/lh.white \\ + whippersnap1 --mesh /surf/lh.white \\ --overlay /surf/lh.thickness \\ --rotate --rotate-frames 72 --rotate-fps 24 \\ -o rotation.mp4 # Parcellation annotation - whippersnap1 /surf/lh.white \\ + whippersnap1 --mesh /surf/lh.white \\ --annot /label/lh.aparc.annot \\ --view left -o snap_annot.png @@ -45,6 +45,10 @@ import os import tempfile +if __name__ == "__main__" and __package__ is None: + import sys + os.execv(sys.executable, [sys.executable, "-m", "whippersnappy.cli.whippersnap1"] + sys.argv[1:]) + from .. import snap1, snap_rotate from .._version import __version__ from ..utils.types import ColorSelection, OrientationType, ViewType @@ -76,7 +80,7 @@ def run(): ----- **Snapshot options** (default mode): - * ``mesh`` — path to any triangular surface mesh: FreeSurfer binary + * ``--mesh`` — path to any triangular surface mesh: FreeSurfer binary (e.g. ``lh.white``), ASCII OFF (``.off``), legacy ASCII VTK PolyData (``.vtk``), or ASCII PLY (``.ply``). * ``--overlay`` — per-vertex scalar overlay (e.g. ``lh.thickness`` or a ``.mgh`` file). @@ -111,16 +115,26 @@ def run(): ) parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") - # --- Required --- + # --- Mesh input: --mesh flag (preferred) or bare positional (legacy) --- parser.add_argument( - "mesh", + "--mesh", type=str, + default=None, help=( "Path to the surface mesh file. Supported formats: " "FreeSurfer binary surface (e.g. lh.white, rh.pial), " - "ASCII OFF (.off), legacy ASCII VTK PolyData (.vtk), ASCII PLY (.ply)." + "ASCII OFF (.off), legacy ASCII VTK PolyData (.vtk), ASCII PLY (.ply), " + "GIfTI surface (.gii, .surf.gii)." ), ) + # Keep positional for backward compatibility (silently accepted) + parser.add_argument( + "_mesh_positional", + nargs="?", + default=None, + metavar="MESH", + help=argparse.SUPPRESS, + ) # --- Output --- parser.add_argument( @@ -216,6 +230,11 @@ def run(): args = parser.parse_args() + # Resolve mesh: --mesh takes precedence over bare positional argument + mesh_path = args.mesh or args._mesh_positional + if mesh_path is None: + parser.error("A mesh file is required: use --mesh .") + log = logging.getLogger(__name__) try: @@ -224,7 +243,7 @@ def run(): tempfile.gettempdir(), "whippersnappy_rotation.mp4" ) snap_rotate( - mesh=args.mesh, + mesh=mesh_path, outpath=outpath, n_frames=args.rotate_frames, fps=args.rotate_fps, @@ -249,7 +268,7 @@ def run(): tempfile.gettempdir(), "whippersnappy_snap1.png" ) img = snap1( - mesh=args.mesh, + mesh=mesh_path, outpath=outpath, overlay=args.overlay, annot=args.annot, @@ -273,3 +292,9 @@ def run(): log.info("Snapshot saved to %s (%dx%d)", outpath, img.width, img.height) except (RuntimeError, FileNotFoundError, ValueError, ImportError) as e: parser.error(str(e)) + + +if __name__ == "__main__": + run() + + diff --git a/whippersnappy/cli/whippersnap4.py b/whippersnappy/cli/whippersnap4.py index f8de394..c1370b3 100644 --- a/whippersnappy/cli/whippersnap4.py +++ b/whippersnappy/cli/whippersnap4.py @@ -18,6 +18,10 @@ import os import tempfile +if __name__ == "__main__" and __package__ is None: + import sys + os.execv(sys.executable, [sys.executable, "-m", "whippersnappy.cli.whippersnap4"] + sys.argv[1:]) + from .. import snap4 from .._version import __version__ @@ -178,3 +182,8 @@ def run(): except (RuntimeError, FileNotFoundError, ValueError) as e: parser.error(str(e)) + +if __name__ == "__main__": + run() + + From f58f700457d5474fd743bff72e35dcad1e82a2de Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 21:49:46 +0100 Subject: [PATCH 72/83] add --mesh to whippersnap1 --- DOCKER.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/DOCKER.md b/DOCKER.md index 8648005..fbdbc33 100644 --- a/DOCKER.md +++ b/DOCKER.md @@ -88,7 +88,7 @@ docker run --rm --init \ -v /path/to/output:/output \ --user $(id -u):$(id -g) \ whippersnappy \ - /subject/surf/lh.white \ + --mesh /subject/surf/lh.white \ --overlay /subject/surf/lh.thickness \ --bg-map /subject/surf/lh.curv \ --view left \ @@ -119,7 +119,7 @@ docker run --rm --init \ -v /path/to/output:/output \ --user $(id -u):$(id -g) \ whippersnappy \ - /subject/surf/lh.white \ + --mesh /subject/surf/lh.white \ --overlay /subject/surf/lh.thickness \ --bg-map /subject/surf/lh.curv \ --rotate \ @@ -137,7 +137,7 @@ docker run --rm --init \ -v /path/to/output:/output \ --user $(id -u):$(id -g) \ whippersnappy \ - /subject/surf/lh.white \ + --mesh /subject/surf/lh.white \ --overlay /subject/surf/lh.thickness \ --rotate \ --rotate-frames 36 \ From 86d286d6462376cf9626d82bb7621adffad2444a Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 21:56:44 +0100 Subject: [PATCH 73/83] fix GUI on mac --- whippersnappy/cli/whippersnap.py | 188 ++++++++++++++++++------------- 1 file changed, 108 insertions(+), 80 deletions(-) diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index b9da7bb..b082628 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -27,9 +27,13 @@ import argparse import logging import os -import signal import sys -import threading + +if __name__ == "__main__" and __package__ is None: + # Replace the current process with `python -m whippersnappy.cli.whippersnap` + # so that relative imports work. os.execv replaces this process in-place + # (no child process, no blocking wait, signals work correctly). + os.execv(sys.executable, [sys.executable, "-m", "whippersnappy.cli.whippersnap"] + sys.argv[1:]) import glfw import OpenGL.GL as gl @@ -49,11 +53,10 @@ # Module logger logger = logging.getLogger(__name__) -# Global state shared between the GL thread and the Qt main thread +# Global state shared between the GL render loop and the Qt config panel. +# All access is from the main thread — no locking needed. current_fthresh_ = None current_fmax_ = None -app_ = None -app_window_ = None app_window_closed_ = False @@ -66,109 +69,129 @@ def show_window( invert=False, specular=True, view=ViewType.LEFT, + app=None, + config_window=None, ): - """Start a live interactive OpenGL window for viewing a triangular mesh. - - The function initializes a GLFW window and renders the provided mesh - with any supplied overlay or annotation. It polls for threshold updates - from the Qt configuration panel and re-renders whenever the thresholds - change. + """Start a live interactive OpenGL+Qt window for viewing a triangular mesh. - ``mesh`` is a fully resolved path (or ``(vertices, faces)`` tuple); - all FreeSurfer path-building is performed in :func:`run` before this - function is called. + On macOS both GLFW/Cocoa and Qt require the main thread. This function + creates a GLFW window, then hands control to a ``QTimer``-driven render + loop so that GLFW polling and Qt event processing share the main thread. Parameters ---------- mesh : str or tuple of (array-like, array-like) - Path to any mesh file supported by :func:`whippersnappy.geometry.inputs.resolve_mesh` - (FreeSurfer binary, ``.off``, ``.vtk``, ``.ply``) **or** a - ``(vertices, faces)`` array tuple. + Path to any mesh file or a ``(vertices, faces)`` array tuple. overlay : str, array-like, or None, optional - Per-vertex scalar overlay — file path or (N,) array. + Per-vertex scalar overlay. annot : str, tuple, or None, optional - FreeSurfer ``.annot`` file path or ``(labels, ctab[, names])`` tuple. + FreeSurfer ``.annot`` file or ``(labels, ctab[, names])`` tuple. bg_map : str, array-like, or None, optional - Per-vertex scalar file or array for background shading (sign → light/dark). + Per-vertex scalar file or array for background shading. roi : str, array-like, or None, optional - FreeSurfer label file path or boolean (N,) array masking overlay coloring. + FreeSurfer label file or boolean array masking overlay coloring. invert : bool, optional Invert the overlay color mapping. Default is ``False``. specular : bool, optional - Enable specular highlights in the shader. Default is ``True``. + Enable specular highlights. Default is ``True``. view : ViewType, optional Initial camera view direction. Default is ``ViewType.LEFT``. + app : QApplication + The already-created ``QApplication`` instance. + config_window : ConfigWindow + The already-created Qt configuration panel. Raises ------ RuntimeError If the GLFW window or OpenGL context could not be created. """ - global current_fthresh_, current_fmax_, app_, app_window_, app_window_closed_ + global current_fthresh_, current_fmax_, app_window_closed_ - wwidth = 720 + from PyQt6.QtCore import QTimer # noqa: PLC0415 + + wwidth = 720 wheight = 600 - window = init_window(wwidth, wheight, "WhipperSnapPy", visible=True) + window = init_window(wwidth, wheight, "WhipperSnapPy", visible=True) if not window: - logger.error("Could not create any GLFW window/context. OpenGL context unavailable.") - raise RuntimeError("Could not create any GLFW window/context. OpenGL context unavailable.") + raise RuntimeError( + "Could not create a GLFW window/context. OpenGL context unavailable." + ) view_mats = get_view_matrices() - viewmat = view_mats[view] - rot_y = pyrr.Matrix44.from_y_rotation(0) + viewmat = view_mats[view] + rot_y = pyrr.Matrix44.from_y_rotation(0) + ypos = 0.0 - meshdata, triangles, fthresh, fmax, neg = prepare_geometry( - mesh, overlay, annot, bg_map, roi, current_fthresh_, current_fmax_, + meshdata, triangles, _fthresh, _fmax, _pos, _neg = prepare_geometry( + mesh, overlay, annot, bg_map, roi, + current_fthresh_, current_fmax_, invert=invert, ) shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular) - logger.info("\nKeys:\nLeft - Right : Rotate Geometry\nESC : Quit\n") + logger.info("Keys: Left/Right arrows → rotate ESC → quit") + + def _render_frame(): + """Called by QTimer every frame; does one GLFW poll + one GL draw.""" + global current_fthresh_, current_fmax_, app_window_closed_ + nonlocal meshdata, triangles, shader, rot_y, ypos - ypos = 0 - while glfw.get_key(window, glfw.KEY_ESCAPE) != glfw.PRESS and not glfw.window_should_close(window): - if app_window_closed_: - break + # Check GLFW close conditions + if ( + glfw.get_key(window, glfw.KEY_ESCAPE) == glfw.PRESS + or glfw.window_should_close(window) + or app_window_closed_ + ): + timer.stop() + glfw.terminate() + app.quit() + return glfw.poll_events() gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - if app_window_ is not None: - if ( - app_window_.get_fthresh_value() != current_fthresh_ - or app_window_.get_fmax_value() != current_fmax_ - ): - current_fthresh_ = app_window_.get_fthresh_value() - current_fmax_ = app_window_.get_fmax_value() - meshdata, triangles, fthresh, fmax, neg = prepare_geometry( + # Re-render if Qt sliders changed the thresholds + if config_window is not None: + new_fthresh = config_window.get_fthresh_value() + new_fmax = config_window.get_fmax_value() + if new_fthresh != current_fthresh_ or new_fmax != current_fmax_: + current_fthresh_ = new_fthresh + current_fmax_ = new_fmax + meshdata, triangles, _fthresh, _fmax, _pos, _neg = prepare_geometry( mesh, overlay, annot, bg_map, roi, current_fthresh_, current_fmax_, invert=invert, ) - shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular) - - transformLoc = gl.glGetUniformLocation(shader, "transform") - gl.glUniformMatrix4fv(transformLoc, 1, gl.GL_FALSE, rot_y * viewmat) + shader = setup_shader( + meshdata, triangles, wwidth, wheight, specular=specular + ) + # Keyboard rotation if glfw.get_key(window, glfw.KEY_RIGHT) == glfw.PRESS: - ypos += 0.0004 + ypos += 0.004 if glfw.get_key(window, glfw.KEY_LEFT) == glfw.PRESS: - ypos -= 0.0004 + ypos -= 0.004 rot_y = pyrr.Matrix44.from_y_rotation(ypos) + transform_loc = gl.glGetUniformLocation(shader, "transform") + gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, rot_y * viewmat) gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) glfw.swap_buffers(window) - glfw.terminate() - # Signal the main thread to tear down the Qt app - app_window_closed_ = True + # ~60 fps timer — fires every 16 ms on the main thread + timer = QTimer() + timer.timeout.connect(_render_frame) + timer.start(16) + + app.exec() def config_app_exit_handler(): - """Mark the configuration application as closed. + """Mark the configuration window as closed. - Connected to the Qt app's ``aboutToQuit`` signal so the OpenGL loop - in the worker thread terminates cleanly. + Connected to ``QApplication.aboutToQuit`` so the render timer stops + cleanly when the user closes the Qt panel. """ global app_window_closed_ app_window_closed_ = True @@ -232,7 +255,7 @@ def run(): Requires ``pip install 'whippersnappy[gui]'``. For non-interactive batch rendering use ``whippersnap4`` or ``whippersnap1``. """ - global current_fthresh_, current_fmax_, app_, app_window_ + global current_fthresh_, current_fmax_ logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") _VIEW_CHOICES = {v.name.lower(): v for v in ViewType} @@ -425,33 +448,38 @@ def run(): current_fthresh_ = args.fthresh current_fmax_ = args.fmax - thread = threading.Thread( - target=show_window, - kwargs=dict( - mesh=mesh_path, - overlay=overlay, - annot=args.annot, - bg_map=bg_map, - roi=roi, - invert=args.invert, - specular=args.specular, - view=view, - ), - ) - thread.start() - - app_ = QApplication([]) - app_.setStyle("Fusion") - app_.aboutToQuit.connect(config_app_exit_handler) + # Both QApplication/Qt and GLFW/Cocoa require the main thread on macOS. + # Create Qt objects here on the main thread, then pass them into + # show_window which drives rendering via a QTimer (no extra threads). + app = QApplication(sys.argv) + app.setStyle("Fusion") + app.aboutToQuit.connect(config_app_exit_handler) - screen_geometry = app_.primaryScreen().availableGeometry() - app_window_ = ConfigWindow( + screen_geometry = app.primaryScreen().availableGeometry() + config_window = ConfigWindow( screen_dims=(screen_geometry.width(), screen_geometry.height()), initial_fthresh_value=current_fthresh_, initial_fmax_value=current_fmax_, ) + config_window.show() + + # show_window creates the GLFW window, sets up a QTimer render loop, + # then calls app.exec() — returns when either window is closed. + show_window( + mesh=mesh_path, + overlay=overlay, + annot=args.annot, + bg_map=bg_map, + roi=roi, + invert=args.invert, + specular=args.specular, + view=view, + app=app, + config_window=config_window, + ) + + +if __name__ == "__main__": + run() - signal.signal(signal.SIGINT, signal.SIG_DFL) - app_window_.show() - app_.exec() From ddff35e542332cb8781ee399d2e4b14e6089ff15 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 22:12:58 +0100 Subject: [PATCH 74/83] more tests also with offscreen render --- tests/test_array_inputs.py | 84 +++++++++++++++++++++++++++++++++++--- whippersnappy/gl/utils.py | 6 ++- 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/tests/test_array_inputs.py b/tests/test_array_inputs.py index 1c557a2..494edf5 100644 --- a/tests/test_array_inputs.py +++ b/tests/test_array_inputs.py @@ -242,14 +242,13 @@ def test_invalid_mesh_type_raises(self): # --------------------------------------------------------------------------- class TestSnap1ArrayInputs: - """Integration test for the array-input pathway of snap1. + """Tests for the geometry-preparation layer used by snap1. - We only test the geometry-preparation layer here (not OpenGL rendering) - so that the test suite can run in headless CI without a display. + These run without any OpenGL context and are safe for headless CI. """ def test_prepare_geometry_called_by_snap1_path(self): - """Verify that prepare_geometry accepts the same args snap1 would pass.""" + """prepare_geometry accepts the same args snap1 would pass.""" overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) roi = np.array([True, True, True, False], dtype=bool) bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) @@ -260,8 +259,83 @@ def test_prepare_geometry_called_by_snap1_path(self): assert tris is not None def test_prepare_geometry_bg_map_array_no_error(self): - """Verify bg_map as array raises no error.""" + """bg_map as array raises no error.""" bg = np.array([-0.5, 0.5, -0.3, 0.3], dtype=np.float32) vdata, tris, *_ = prepare_geometry((_V, _F), bg_map=bg) assert vdata.shape == (_N, 9) + +# --------------------------------------------------------------------------- +# snap1 rendering — actual OpenGL image output +# --------------------------------------------------------------------------- + +def _snap1_offscreen(**kwargs): + """Call snap1 with an invisible (offscreen) GLFW context. + + On macOS a visible GLFW window goes through the Cocoa compositor; the + first glReadPixels call may return all-black before the compositor has + finished its first composite pass. An invisible context renders + directly to the driver framebuffer and reads back correctly. + + Skips the test automatically if no OpenGL context can be created + (headless CI without GPU or EGL support). + """ + import whippersnappy.gl.utils as gl_utils # noqa: PLC0415 + + original = gl_utils.create_window_with_fallback + + def _invisible(*args, **kw): + kw["visible"] = False + return original(*args, **kw) + + gl_utils.create_window_with_fallback = _invisible + try: + from whippersnappy import snap1 # noqa: PLC0415 + return snap1(width=200, height=200, colorbar=False, **kwargs) + except RuntimeError as exc: + if "context" in str(exc).lower() or "opengl" in str(exc).lower(): + pytest.skip(f"No OpenGL context available: {exc}") + raise + finally: + gl_utils.create_window_with_fallback = original + + +class TestSnap1Rendering: + """End-to-end rendering tests: snap1 must return a non-empty PIL Image. + + All tests use the tetrahedron mesh (_V, _F) — a true 3-D shape that is + visible from any camera direction, unlike a flat surface which can + appear edge-on and produce an all-black image. + + Tests use an offscreen GLFW context (see ``_snap1_offscreen``) and are + skipped automatically when no OpenGL context is available. + """ + + def test_snap1_basic(self): + """snap1 returns the right size and renders a non-uniform image.""" + img = _snap1_offscreen(mesh=(_V, _F)) + assert img.width == 200 + assert img.height == 200 + arr = np.array(img) + assert arr.min() != arr.max(), ( + "Rendered image is completely uniform — shading is not working." + ) + + def test_snap1_with_overlay_and_roi(self): + """Overlay + ROI mask: image is non-uniform and correct size.""" + overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) + roi = np.array([True, True, True, False], dtype=bool) + img = _snap1_offscreen( + mesh=(_V, _F), overlay=overlay, roi=roi, fthresh=0.0, fmax=1.0 + ) + assert img.width == 200 + arr = np.array(img) + assert arr.min() != arr.max(), "Image with overlay+ROI is completely uniform." + + def test_snap1_with_bg_map(self): + """bg_map array: image is non-uniform.""" + bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) + img = _snap1_offscreen(mesh=(_V, _F), bg_map=bg) + arr = np.array(img) + assert arr.min() != arr.max(), "Image with bg_map is completely uniform." + diff --git a/whippersnappy/gl/utils.py b/whippersnappy/gl/utils.py index 0e546e0..288e383 100644 --- a/whippersnappy/gl/utils.py +++ b/whippersnappy/gl/utils.py @@ -392,7 +392,11 @@ def capture_window(window): # --- GLFW path: read from the default framebuffer --- monitor = glfw.get_primary_monitor() - x_scale, y_scale = glfw.get_monitor_content_scale(monitor) + if monitor is None: + # Invisible / offscreen GLFW window — no monitor, no HiDPI scaling. + x_scale, y_scale = 1.0, 1.0 + else: + x_scale, y_scale = glfw.get_monitor_content_scale(monitor) width, height = glfw.get_framebuffer_size(window) logger.debug("Framebuffer size = (%s,%s)", width, height) From 7d7c3a38e8a1c73634c34a8dc463aa25134a09b3 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sat, 21 Feb 2026 22:39:56 +0100 Subject: [PATCH 75/83] reduce and adapt tests --- ..._inputs.py => test_array_and_rendering.py} | 204 +++------- tests/test_mesh_io.py | 354 ++++++------------ tests/test_overlay_io.py | 269 +++++-------- 3 files changed, 261 insertions(+), 566 deletions(-) rename tests/{test_array_inputs.py => test_array_and_rendering.py} (57%) diff --git a/tests/test_array_inputs.py b/tests/test_array_and_rendering.py similarity index 57% rename from tests/test_array_inputs.py rename to tests/test_array_and_rendering.py index 494edf5..6ffce2e 100644 --- a/tests/test_array_inputs.py +++ b/tests/test_array_and_rendering.py @@ -35,70 +35,48 @@ # --------------------------------------------------------------------------- class TestResolveMesh: - def test_tuple_input(self): + def test_valid_inputs(self): v, f = resolve_mesh((_V, _F)) - assert v.shape == (4, 3) - assert v.dtype == np.float32 - assert f.shape == (4, 3) - assert f.dtype == np.uint32 + assert v.shape == (4, 3) and v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 + # list input should also work + v2, f2 = resolve_mesh([_V, _F]) + assert v2.shape == (4, 3) and f2.shape == (4, 3) - def test_list_input(self): - v, f = resolve_mesh([_V, _F]) - assert v.shape == (4, 3) - assert f.shape == (4, 3) - - def test_wrong_type_raises(self): + def test_invalid_inputs_raise(self): with pytest.raises(TypeError): resolve_mesh(42) - - def test_wrong_shape_vertices_raises(self): - bad_v = np.ones((4, 4), dtype=np.float32) with pytest.raises(ValueError): - resolve_mesh((bad_v, _F)) - - def test_wrong_shape_faces_raises(self): - bad_f = np.ones((4, 4), dtype=np.uint32) + resolve_mesh((np.ones((4, 4), dtype=np.float32), _F)) with pytest.raises(ValueError): - resolve_mesh((_V, bad_f)) + resolve_mesh((_V, np.ones((4, 4), dtype=np.uint32))) # --------------------------------------------------------------------------- -# resolve_overlay / resolve_bg_map +# resolve_overlay / resolve_bg_map (identical logic, tested together) # --------------------------------------------------------------------------- -class TestResolveOverlay: - def test_none_returns_none(self): - assert resolve_overlay(None, n_vertices=_N) is None +class TestResolveScalarOverlay: + """Tests for resolve_overlay and resolve_bg_map (same logic).""" - def test_array_input(self): + @pytest.mark.parametrize("fn", [resolve_overlay, resolve_bg_map]) + def test_none_returns_none(self, fn): + assert fn(None, n_vertices=_N) is None + + @pytest.mark.parametrize("fn", [resolve_overlay, resolve_bg_map]) + def test_array_input_shape_and_dtype(self, fn): arr = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) - result = resolve_overlay(arr, n_vertices=_N) - assert result.shape == (_N,) - assert result.dtype == np.float32 + result = fn(arr, n_vertices=_N) + assert result.shape == (_N,) and result.dtype == np.float32 - def test_shape_mismatch_raises(self): - arr = np.array([0.1, 0.5], dtype=np.float32) + @pytest.mark.parametrize("fn", [resolve_overlay, resolve_bg_map]) + def test_shape_mismatch_raises(self, fn): with pytest.raises(ValueError): - resolve_overlay(arr, n_vertices=_N) + fn(np.ones(2), n_vertices=_N) - def test_n_vertices_none_skips_check(self): + def test_n_vertices_none_skips_shape_check(self): arr = np.array([0.1, 0.5], dtype=np.float32) - result = resolve_overlay(arr, n_vertices=None) - assert result.shape == (2,) - - -class TestResolveBgMap: - def test_none_returns_none(self): - assert resolve_bg_map(None, n_vertices=_N) is None - - def test_array_input(self): - arr = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) - result = resolve_bg_map(arr, n_vertices=_N) - assert result.shape == (_N,) - - def test_shape_mismatch_raises(self): - with pytest.raises(ValueError): - resolve_bg_map(np.ones(2), n_vertices=_N) + assert resolve_overlay(arr, n_vertices=None).shape == (2,) # --------------------------------------------------------------------------- @@ -112,8 +90,7 @@ def test_none_returns_none(self): def test_bool_array(self): roi = np.array([True, True, True, False], dtype=bool) result = resolve_roi(roi, n_vertices=_N) - assert result.dtype == bool - assert result.shape == (_N,) + assert result.dtype == bool and result.shape == (_N,) assert result[3] is np.bool_(False) def test_shape_mismatch_raises(self): @@ -129,30 +106,22 @@ class TestResolveAnnot: def test_none_returns_none(self): assert resolve_annot(None, n_vertices=_N) is None - def test_two_tuple(self): + def test_two_and_three_tuple(self): labels = np.array([0, 1, 0, 1]) ctab = np.array([[255, 0, 0, 0, 0], [0, 255, 0, 0, 1]]) - result = resolve_annot((labels, ctab), n_vertices=_N) - assert result is not None - assert len(result) == 3 - assert result[2] is None # names - - def test_three_tuple(self): - labels = np.zeros(_N, dtype=int) - ctab = np.array([[200, 100, 50, 0, 1]]) - names = ["region0"] - result = resolve_annot((labels, ctab, names), n_vertices=_N) - assert result[2] == names - - def test_shape_mismatch_raises(self): - labels = np.zeros(2, dtype=int) - ctab = np.array([[255, 0, 0, 0, 0]]) - with pytest.raises(ValueError): - resolve_annot((labels, ctab), n_vertices=_N) - - def test_wrong_type_raises(self): + # two-tuple: names should be None + r2 = resolve_annot((labels, ctab), n_vertices=_N) + assert len(r2) == 3 and r2[2] is None + # three-tuple: names passed through + names = ["a", "b"] + r3 = resolve_annot((labels, ctab, names), n_vertices=_N) + assert r3[2] == names + + def test_invalid_inputs_raise(self): with pytest.raises(TypeError): resolve_annot(42, n_vertices=_N) + with pytest.raises(ValueError): + resolve_annot((np.zeros(2, dtype=int), np.array([[255, 0, 0, 0, 0]])), n_vertices=_N) # --------------------------------------------------------------------------- @@ -160,17 +129,13 @@ def test_wrong_type_raises(self): # --------------------------------------------------------------------------- class TestEstimateOverlayThresholds: - def test_array_input(self): + def test_auto_and_passthrough(self): arr = np.array([1.0, 2.0, 3.0, -1.5], dtype=np.float32) fmin, fmax = estimate_overlay_thresholds(arr) - assert fmin >= 0 - assert fmax == pytest.approx(3.0) - - def test_passthrough_when_provided(self): - arr = np.array([1.0, 2.0, 3.0], dtype=np.float32) - fmin, fmax = estimate_overlay_thresholds(arr, minval=0.5, maxval=5.0) - assert fmin == pytest.approx(0.5) - assert fmax == pytest.approx(5.0) + assert fmin >= 0 and fmax == pytest.approx(3.0) + # explicit values passed through unchanged + fmin2, fmax2 = estimate_overlay_thresholds(arr, minval=0.5, maxval=5.0) + assert fmin2 == pytest.approx(0.5) and fmax2 == pytest.approx(5.0) # --------------------------------------------------------------------------- @@ -180,33 +145,21 @@ def test_passthrough_when_provided(self): class TestPrepareGeometryFromArrays: def test_no_overlay(self): vdata, tris, fmin, fmax, pos, neg = prepare_geometry_from_arrays(_V, _F) - assert vdata.shape == (_N, 9) - assert tris.shape == (4, 3) + assert vdata.shape == (_N, 9) and tris.shape == (4, 3) assert fmin is None and fmax is None - def test_with_overlay(self): + def test_with_overlay_bg_map_roi(self): overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) - vdata, tris, fmin, fmax, pos, neg = prepare_geometry_from_arrays( - _V, _F, overlay=overlay - ) - assert vdata.shape == (_N, 9) - assert fmin is not None and fmax is not None - - def test_with_bg_map(self): bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) - vdata, tris, *_ = prepare_geometry_from_arrays(_V, _F, bg_map=bg) - assert vdata.shape == (_N, 9) - - def test_with_roi(self): - overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) roi = np.array([True, True, True, False], dtype=bool) - vdata, tris, *_ = prepare_geometry_from_arrays(_V, _F, overlay=overlay, roi=roi) - assert vdata.shape == (_N, 9) + vdata, tris, fmin, fmax, pos, neg = prepare_geometry_from_arrays( + _V, _F, overlay=overlay, bg_map=bg, roi=roi + ) + assert vdata.shape == (_N, 9) and fmin is not None def test_overlay_shape_mismatch_raises(self): - overlay = np.array([0.1, 0.5], dtype=np.float32) # wrong length with pytest.raises(ValueError): - prepare_geometry_from_arrays(_V, _F, overlay=overlay) + prepare_geometry_from_arrays(_V, _F, overlay=np.array([0.1, 0.5])) # --------------------------------------------------------------------------- @@ -214,57 +167,21 @@ def test_overlay_shape_mismatch_raises(self): # --------------------------------------------------------------------------- class TestPrepareGeometry: - def test_tuple_mesh_no_overlay(self): - vdata, tris, *_ = prepare_geometry((_V, _F)) - assert vdata.shape == (_N, 9) - - def test_tuple_mesh_with_overlay_and_roi(self): + def test_tuple_mesh_various_inputs(self): + """One call covers: tuple mesh, overlay, roi, bg_map — all together.""" overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) roi = np.array([True, True, True, False], dtype=bool) + bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) vdata, tris, fmin, fmax, pos, neg = prepare_geometry( - (_V, _F), overlay=overlay, roi=roi + (_V, _F), overlay=overlay, roi=roi, bg_map=bg ) - assert vdata.shape == (_N, 9) - assert fmin is not None - - def test_tuple_mesh_with_bg_map(self): - bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) - vdata, tris, *_ = prepare_geometry((_V, _F), bg_map=bg) - assert vdata.shape == (_N, 9) + assert vdata.shape == (_N, 9) and fmin is not None def test_invalid_mesh_type_raises(self): with pytest.raises(TypeError): prepare_geometry(12345) -# --------------------------------------------------------------------------- -# snap1 array-input integration (no OpenGL — just geometry prep) -# --------------------------------------------------------------------------- - -class TestSnap1ArrayInputs: - """Tests for the geometry-preparation layer used by snap1. - - These run without any OpenGL context and are safe for headless CI. - """ - - def test_prepare_geometry_called_by_snap1_path(self): - """prepare_geometry accepts the same args snap1 would pass.""" - overlay = np.array([0.1, 0.5, 0.9, 0.3], dtype=np.float32) - roi = np.array([True, True, True, False], dtype=bool) - bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) - vdata, tris, fmin, fmax, pos, neg = prepare_geometry( - (_V, _F), overlay=overlay, roi=roi, bg_map=bg - ) - assert vdata is not None - assert tris is not None - - def test_prepare_geometry_bg_map_array_no_error(self): - """bg_map as array raises no error.""" - bg = np.array([-0.5, 0.5, -0.3, 0.3], dtype=np.float32) - vdata, tris, *_ = prepare_geometry((_V, _F), bg_map=bg) - assert vdata.shape == (_N, 9) - - # --------------------------------------------------------------------------- # snap1 rendering — actual OpenGL image output # --------------------------------------------------------------------------- @@ -314,8 +231,7 @@ class TestSnap1Rendering: def test_snap1_basic(self): """snap1 returns the right size and renders a non-uniform image.""" img = _snap1_offscreen(mesh=(_V, _F)) - assert img.width == 200 - assert img.height == 200 + assert img.width == 200 and img.height == 200 arr = np.array(img) assert arr.min() != arr.max(), ( "Rendered image is completely uniform — shading is not working." @@ -329,13 +245,13 @@ def test_snap1_with_overlay_and_roi(self): mesh=(_V, _F), overlay=overlay, roi=roi, fthresh=0.0, fmax=1.0 ) assert img.width == 200 - arr = np.array(img) - assert arr.min() != arr.max(), "Image with overlay+ROI is completely uniform." + assert np.array(img).min() != np.array(img).max(), ( + "Image with overlay+ROI is completely uniform." + ) def test_snap1_with_bg_map(self): """bg_map array: image is non-uniform.""" bg = np.array([-1.0, 1.0, -0.5, 0.5], dtype=np.float32) - img = _snap1_offscreen(mesh=(_V, _F), bg_map=bg) - arr = np.array(img) + arr = np.array(_snap1_offscreen(mesh=(_V, _F), bg_map=bg)) assert arr.min() != arr.max(), "Image with bg_map is completely uniform." diff --git a/tests/test_mesh_io.py b/tests/test_mesh_io.py index d8d76a4..7a4b589 100644 --- a/tests/test_mesh_io.py +++ b/tests/test_mesh_io.py @@ -75,6 +75,12 @@ 3 1 2 3 """ +_SAMPLES = { + ".off": _TETRA_OFF, + ".vtk": _TETRA_VTK, + ".ply": _TETRA_PLY, +} + def _write_tmp(content, suffix): """Write *content* to a named temp file, return its path.""" @@ -107,55 +113,31 @@ def test_basic(self): v, f = read_off(path) finally: os.unlink(path) - assert v.shape == (4, 3) - assert v.dtype == np.float32 - assert f.shape == (4, 3) - assert f.dtype == np.uint32 + assert v.shape == (4, 3) and v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 np.testing.assert_array_equal(v, _expected_verts()) np.testing.assert_array_equal(f, _expected_faces()) def test_bundled_sample(self): """Verify the bundled tests/data/tetra.off file loads correctly.""" here = os.path.dirname(__file__) - sample = os.path.join(here, "data", "tetra.off") - v, f = read_off(sample) - assert v.shape == (4, 3) - assert f.shape == (4, 3) - - def test_bad_header_raises(self): - content = "NOFF\n4 4 6\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n3 0 1 2\n3 0 1 3\n3 0 2 3\n3 1 2 3\n" - path = _write_tmp(content, ".off") - try: - with pytest.raises(ValueError, match="OFF"): - read_off(path) - finally: - os.unlink(path) - - def test_quad_face_raises(self): - content = "OFF\n4 1 4\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n4 0 1 2 3\n" - path = _write_tmp(content, ".off") - try: - with pytest.raises(ValueError, match="triangles"): - read_off(path) - finally: - os.unlink(path) - - def test_out_of_range_indices_raises(self): - content = "OFF\n3 1 3\n0 0 0\n1 0 0\n0 1 0\n3 0 1 99\n" - path = _write_tmp(content, ".off") - try: - with pytest.raises(ValueError, match="out of range"): - read_off(path) - finally: - os.unlink(path) - - def test_empty_file_raises(self): - path = _write_tmp("", ".off") - try: - with pytest.raises(ValueError, match="empty"): - read_off(path) - finally: - os.unlink(path) + v, f = read_off(os.path.join(here, "data", "tetra.off")) + assert v.shape == (4, 3) and f.shape == (4, 3) + + def test_error_cases(self): + cases = [ + ("", ".off", "empty"), + ("NOFF\n4 4 6\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n3 0 1 2\n3 0 1 3\n3 0 2 3\n3 1 2 3\n", ".off", "OFF"), + ("OFF\n4 1 4\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n4 0 1 2 3\n", ".off", "triangles"), + ("OFF\n3 1 3\n0 0 0\n1 0 0\n0 1 0\n3 0 1 99\n", ".off", "out of range"), + ] + for content, suffix, match in cases: + path = _write_tmp(content, suffix) + try: + with pytest.raises(ValueError, match=match): + read_off(path) + finally: + os.unlink(path) # --------------------------------------------------------------------------- @@ -169,55 +151,33 @@ def test_basic(self): v, f = read_vtk_ascii_polydata(path) finally: os.unlink(path) - assert v.shape == (4, 3) - assert v.dtype == np.float32 - assert f.shape == (4, 3) - assert f.dtype == np.uint32 + assert v.shape == (4, 3) and v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 np.testing.assert_array_equal(v, _expected_verts()) np.testing.assert_array_equal(f, _expected_faces()) - def test_binary_vtk_raises(self): - content = "# vtk DataFile Version 3.0\ntest\nBINARY\nDATASET POLYDATA\n" - path = _write_tmp(content, ".vtk") - try: - with pytest.raises(ValueError, match="BINARY"): - read_vtk_ascii_polydata(path) - finally: - os.unlink(path) - - def test_non_polydata_raises(self): - content = "# vtk DataFile Version 3.0\ntest\nASCII\nDATASET UNSTRUCTURED_GRID\n" - path = _write_tmp(content, ".vtk") - try: - with pytest.raises(ValueError, match="POLYDATA"): - read_vtk_ascii_polydata(path) - finally: - os.unlink(path) - - def test_quad_polygon_raises(self): - content = ( - "# vtk DataFile Version 3.0\ntest\nASCII\nDATASET POLYDATA\n" - "POINTS 4 float\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n" - "POLYGONS 1 5\n4 0 1 2 3\n" - ) - path = _write_tmp(content, ".vtk") - try: - with pytest.raises(ValueError, match="triangles"): - read_vtk_ascii_polydata(path) - finally: - os.unlink(path) - - def test_missing_points_raises(self): - content = ( - "# vtk DataFile Version 3.0\ntest\nASCII\nDATASET POLYDATA\n" - "POLYGONS 1 4\n3 0 1 2\n" - ) - path = _write_tmp(content, ".vtk") - try: - with pytest.raises(ValueError, match="POINTS"): - read_vtk_ascii_polydata(path) - finally: - os.unlink(path) + def test_error_cases(self): + cases = [ + ("# vtk DataFile Version 3.0\ntest\nBINARY\nDATASET POLYDATA\n", "BINARY"), + ("# vtk DataFile Version 3.0\ntest\nASCII\nDATASET UNSTRUCTURED_GRID\n", "POLYDATA"), + ( + "# vtk DataFile Version 3.0\ntest\nASCII\nDATASET POLYDATA\n" + "POINTS 4 float\n0 0 0\n1 0 0\n0 1 0\n0 0 1\nPOLYGONS 1 5\n4 0 1 2 3\n", + "triangles", + ), + ( + "# vtk DataFile Version 3.0\ntest\nASCII\nDATASET POLYDATA\n" + "POLYGONS 1 4\n3 0 1 2\n", + "POINTS", + ), + ] + for content, match in cases: + path = _write_tmp(content, ".vtk") + try: + with pytest.raises(ValueError, match=match): + read_vtk_ascii_polydata(path) + finally: + os.unlink(path) # --------------------------------------------------------------------------- @@ -231,10 +191,8 @@ def test_basic(self): v, f = read_ply_ascii(path) finally: os.unlink(path) - assert v.shape == (4, 3) - assert v.dtype == np.float32 - assert f.shape == (4, 3) - assert f.dtype == np.uint32 + assert v.shape == (4, 3) and v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 np.testing.assert_array_equal(v, _expected_verts()) np.testing.assert_array_equal(f, _expected_faces()) @@ -263,29 +221,10 @@ def test_extra_vertex_props(self): v, f = read_ply_ascii(path) finally: os.unlink(path) - assert v.shape == (3, 3) - assert f.shape == (1, 3) + assert v.shape == (3, 3) and f.shape == (1, 3) - def test_binary_ply_raises(self): - content = "ply\nformat binary_little_endian 1.0\nelement vertex 4\nend_header\n" - path = _write_tmp(content, ".ply") - try: - with pytest.raises(ValueError, match="binary"): - read_ply_ascii(path) - finally: - os.unlink(path) - - def test_not_ply_raises(self): - content = "OFF\n4 4 6\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n3 0 1 2\n" - path = _write_tmp(content, ".ply") - try: - with pytest.raises(ValueError, match="ply"): - read_ply_ascii(path) - finally: - os.unlink(path) - - def test_quad_face_raises(self): - content = """\ + def test_error_cases(self): + quad_ply = """\ ply format ascii 1.0 element vertex 4 @@ -301,37 +240,40 @@ def test_quad_face_raises(self): 0.0 0.0 1.0 4 0 1 2 3 """ - path = _write_tmp(content, ".ply") - try: - with pytest.raises(ValueError, match="triangles"): - read_ply_ascii(path) - finally: - os.unlink(path) + cases = [ + ("ply\nformat binary_little_endian 1.0\nelement vertex 4\nend_header\n", "binary"), + ("OFF\n4 4 6\n0 0 0\n1 0 0\n0 1 0\n0 0 1\n3 0 1 2\n", "ply"), + (quad_ply, "triangles"), + ] + for content, match in cases: + path = _write_tmp(content, ".ply") + try: + with pytest.raises(ValueError, match=match): + read_ply_ascii(path) + finally: + os.unlink(path) # --------------------------------------------------------------------------- -# read_mesh dispatcher +# read_mesh dispatcher + resolve_mesh routing (combined) +# Each format is tested end-to-end through resolve_mesh (highest-level call). +# read_mesh itself is only tested for error cases not covered above. # --------------------------------------------------------------------------- -class TestReadMeshDispatcher: - def test_off_dispatch(self): - path = _write_tmp(_TETRA_OFF, ".off") - try: - v, f = read_mesh(path) - finally: - os.unlink(path) - assert v.shape == (4, 3) - - def test_vtk_dispatch(self): - path = _write_tmp(_TETRA_VTK, ".vtk") +class TestMeshDispatchAndRouting: + @pytest.mark.parametrize("suffix,content", list(_SAMPLES.items())) + def test_resolve_mesh_path(self, suffix, content): + """resolve_mesh routes each format to the right reader and returns correct dtypes.""" + path = _write_tmp(content, suffix) try: - v, f = read_mesh(path) + v, f = resolve_mesh(path) finally: os.unlink(path) - assert v.shape == (4, 3) + assert v.shape == (4, 3) and v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 - def test_ply_dispatch(self): - path = _write_tmp(_TETRA_PLY, ".ply") + def test_case_insensitive_extension(self): + path = _write_tmp(_TETRA_OFF, ".OFF") try: v, f = read_mesh(path) finally: @@ -342,60 +284,16 @@ def test_unknown_extension_raises(self): with pytest.raises(ValueError, match="Unsupported"): read_mesh("/some/file.stl") - def test_case_insensitive_extension(self): - """Uppercase .OFF extension should be recognised.""" - path = _write_tmp(_TETRA_OFF, ".OFF") - try: - v, f = read_mesh(path) - finally: - os.unlink(path) - assert v.shape == (4, 3) - - -# --------------------------------------------------------------------------- -# resolve_mesh routing -# --------------------------------------------------------------------------- - -class TestResolveMeshRouting: - def test_off_path_routed(self): - path = _write_tmp(_TETRA_OFF, ".off") - try: - v, f = resolve_mesh(path) - finally: - os.unlink(path) - assert v.shape == (4, 3) - assert v.dtype == np.float32 - assert f.shape == (4, 3) - assert f.dtype == np.uint32 - - def test_vtk_path_routed(self): - path = _write_tmp(_TETRA_VTK, ".vtk") - try: - v, f = resolve_mesh(path) - finally: - os.unlink(path) - assert v.shape == (4, 3) - - def test_ply_path_routed(self): - path = _write_tmp(_TETRA_PLY, ".ply") - try: - v, f = resolve_mesh(path) - finally: - os.unlink(path) - assert v.shape == (4, 3) - def test_array_tuple_still_works(self): v_in = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) f_in = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.uint32) v, f = resolve_mesh((v_in, f_in)) - assert v.shape == (4, 3) - assert f.shape == (4, 3) + assert v.shape == (4, 3) and f.shape == (4, 3) def test_array_out_of_range_raises(self): v_in = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32) - f_in = np.array([[0, 1, 99]], dtype=np.uint32) # index 99 out of range with pytest.raises(ValueError, match="out of range"): - resolve_mesh((v_in, f_in)) + resolve_mesh((v_in, np.array([[0, 1, 99]], dtype=np.uint32))) # --------------------------------------------------------------------------- @@ -405,17 +303,11 @@ def test_array_out_of_range_raises(self): def _make_surf_gii(verts, faces, suffix=".surf.gii"): """Write a minimal GIfTI surface file and return its path.""" import nibabel as nib - _INTENT_POINTSET = 1008 - _INTENT_TRIANGLE = 1009 coords_da = nib.gifti.GiftiDataArray( - data=verts.astype(np.float32), - intent=_INTENT_POINTSET, - datatype="NIFTI_TYPE_FLOAT32", + data=verts.astype(np.float32), intent=1008, datatype="NIFTI_TYPE_FLOAT32", ) faces_da = nib.gifti.GiftiDataArray( - data=faces.astype(np.int32), - intent=_INTENT_TRIANGLE, - datatype="NIFTI_TYPE_INT32", + data=faces.astype(np.int32), intent=1009, datatype="NIFTI_TYPE_INT32", ) img = nib.gifti.GiftiImage(darrays=[coords_da, faces_da]) fd, path = tempfile.mkstemp(suffix=suffix) @@ -429,35 +321,39 @@ def _make_surf_gii(verts, faces, suffix=".surf.gii"): class TestReadGiftiSurface: - def test_surf_gii_basic(self): + def test_basic_and_dispatch(self): + """read_gifti_surface, read_mesh, and resolve_mesh all load .surf.gii correctly.""" path = _make_surf_gii(_V4, _F4, ".surf.gii") try: v, f = read_gifti_surface(path) + assert v.shape == (4, 3) and v.dtype == np.float32 + assert f.shape == (4, 3) and f.dtype == np.uint32 + np.testing.assert_allclose(v, _V4, atol=1e-6) + # dispatch through read_mesh and resolve_mesh also work + v2, _ = read_mesh(path) + assert v2.shape == (4, 3) + v3, f3 = resolve_mesh(path) + assert v3.dtype == np.float32 and f3.dtype == np.uint32 finally: os.unlink(path) - assert v.shape == (4, 3) - assert v.dtype == np.float32 - assert f.shape == (4, 3) - assert f.dtype == np.uint32 - np.testing.assert_allclose(v, _V4, atol=1e-6) - np.testing.assert_array_equal(f, _F4) def test_plain_gii_extension(self): - """A .gii file with POINTSET+TRIANGLE should also be loaded as surface.""" path = _make_surf_gii(_V4, _F4, ".gii") try: v, f = read_gifti_surface(path) + assert v.shape == (4, 3) and f.shape == (4, 3) + # dispatch also works for plain .gii + v2, _ = read_mesh(path) + assert v2.shape == (4, 3) finally: os.unlink(path) - assert v.shape == (4, 3) - assert f.shape == (4, 3) - def test_no_pointset_raises(self): - """A plain scalar .gii without POINTSET arrays should raise.""" + def test_missing_arrays_raise(self): + """Missing POINTSET or TRIANGLE array raises a clear ValueError.""" import nibabel as nib + # No POINTSET — only a scalar array scalar_da = nib.gifti.GiftiDataArray( - data=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), - intent=0, # NIFTI_INTENT_NONE + data=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), intent=0, ) img = nib.gifti.GiftiImage(darrays=[scalar_da]) fd, path = tempfile.mkstemp(suffix=".gii") @@ -469,48 +365,14 @@ def test_no_pointset_raises(self): finally: os.unlink(path) - def test_no_triangle_raises(self): - """A .gii with only a POINTSET but no TRIANGLE array should raise.""" - import nibabel as nib - coords_da = nib.gifti.GiftiDataArray( - data=_V4.astype(np.float32), - intent=1008, - ) - img = nib.gifti.GiftiImage(darrays=[coords_da]) - fd, path = tempfile.mkstemp(suffix=".gii") + # POINTSET but no TRIANGLE + coords_da = nib.gifti.GiftiDataArray(data=_V4.astype(np.float32), intent=1008) + img2 = nib.gifti.GiftiImage(darrays=[coords_da]) + fd, path2 = tempfile.mkstemp(suffix=".gii") os.close(fd) - nib.save(img, path) + nib.save(img2, path2) try: with pytest.raises(ValueError, match="TRIANGLE"): - read_gifti_surface(path) + read_gifti_surface(path2) finally: - os.unlink(path) - - -class TestReadMeshGiftiDispatch: - def test_surf_gii_dispatched(self): - path = _make_surf_gii(_V4, _F4, ".surf.gii") - try: - v, f = read_mesh(path) - finally: - os.unlink(path) - assert v.shape == (4, 3) - - def test_gii_dispatched(self): - path = _make_surf_gii(_V4, _F4, ".gii") - try: - v, f = read_mesh(path) - finally: - os.unlink(path) - assert v.shape == (4, 3) - - def test_resolve_mesh_surf_gii(self): - path = _make_surf_gii(_V4, _F4, ".surf.gii") - try: - v, f = resolve_mesh(path) - finally: - os.unlink(path) - assert v.shape == (4, 3) - assert v.dtype == np.float32 - assert f.dtype == np.uint32 - + os.unlink(path2) diff --git a/tests/test_overlay_io.py b/tests/test_overlay_io.py index 18c82d3..f5bb5e4 100644 --- a/tests/test_overlay_io.py +++ b/tests/test_overlay_io.py @@ -35,75 +35,21 @@ def _write_tmp(content, suffix, binary=False): # --------------------------------------------------------------------------- -# read_txt — float +# read_txt # --------------------------------------------------------------------------- -class TestReadTxtFloat: - def test_basic_floats(self): - content = "\n".join(str(v) for v in _FLOAT_VALUES) + "\n" - path = _write_tmp(content, ".txt") +class TestReadTxt: + def test_float_basic(self): + path = _write_tmp("\n".join(str(v) for v in _FLOAT_VALUES) + "\n", ".txt") try: arr = read_txt(path) finally: os.unlink(path) - assert arr.shape == (4,) - assert arr.dtype == np.float32 + assert arr.shape == (4,) and arr.dtype == np.float32 np.testing.assert_allclose(arr, _FLOAT_VALUES, atol=1e-6) - def test_hash_comment_skipped(self): - content = "# this is a comment\n0.5\n1.5\n" - path = _write_tmp(content, ".txt") - try: - arr = read_txt(path) - finally: - os.unlink(path) - assert arr.shape == (2,) - - def test_text_header_skipped(self): - content = "value\n0.1\n0.2\n0.3\n" - path = _write_tmp(content, ".txt") - try: - arr = read_txt(path) - finally: - os.unlink(path) - assert arr.shape == (3,) - np.testing.assert_allclose(arr, [0.1, 0.2, 0.3], atol=1e-6) - - def test_csv_first_column_used(self): - content = "label,ignore\n1.0,extra\n2.0,extra\n" - path = _write_tmp(content, ".csv") - try: - arr = read_txt(path) - finally: - os.unlink(path) - assert arr.shape == (2,) - - def test_empty_file_raises(self): - path = _write_tmp("", ".txt") - try: - with pytest.raises(ValueError, match="No numeric"): - read_txt(path) - finally: - os.unlink(path) - - def test_bad_value_raises(self): - content = "1.0\nbadvalue\n3.0\n" - path = _write_tmp(content, ".txt") - try: - with pytest.raises(ValueError, match="Could not parse"): - read_txt(path) - finally: - os.unlink(path) - - -# --------------------------------------------------------------------------- -# read_txt — integer promotion -# --------------------------------------------------------------------------- - -class TestReadTxtInt: - def test_integer_values_promoted_to_int32(self): - content = "\n".join(str(v) for v in _INT_VALUES) + "\n" - path = _write_tmp(content, ".txt") + def test_integer_promoted_to_int32(self): + path = _write_tmp("\n".join(str(v) for v in _INT_VALUES) + "\n", ".txt") try: arr = read_txt(path) finally: @@ -112,79 +58,95 @@ def test_integer_values_promoted_to_int32(self): np.testing.assert_array_equal(arr, _INT_VALUES) def test_mixed_float_stays_float32(self): - content = "0\n1\n2.5\n3\n" - path = _write_tmp(content, ".txt") + path = _write_tmp("0\n1\n2.5\n3\n", ".txt") try: arr = read_txt(path) finally: os.unlink(path) assert arr.dtype == np.float32 + def test_headers_and_comments_skipped(self): + # hash comment, text header, CSV first column + for content, suffix, expected_len in [ + ("# comment\n0.5\n1.5\n", ".txt", 2), + ("value\n0.1\n0.2\n0.3\n", ".txt", 3), + ("label,ignore\n1.0,extra\n2.0,extra\n", ".csv", 2), + ]: + path = _write_tmp(content, suffix) + try: + arr = read_txt(path) + finally: + os.unlink(path) + assert arr.shape == (expected_len,) + + def test_error_cases(self): + for content, match in [ + ("", "No numeric"), + ("1.0\nbadvalue\n3.0\n", "Could not parse"), + ]: + path = _write_tmp(content, ".txt") + try: + with pytest.raises(ValueError, match=match): + read_txt(path) + finally: + os.unlink(path) + # --------------------------------------------------------------------------- # read_npy / read_npz # --------------------------------------------------------------------------- -class TestReadNpy: - def test_basic(self): +class TestReadNpyNpz: + def test_npy_basic(self): arr_in = np.array([1.0, 2.0, 3.0], dtype=np.float32) fd, path = tempfile.mkstemp(suffix=".npy") os.close(fd) np.save(path, arr_in) try: - arr = read_npy(path) + np.testing.assert_array_equal(read_npy(path), arr_in) finally: os.unlink(path) - np.testing.assert_array_equal(arr, arr_in) - def test_2d_raises(self): - arr_in = np.ones((3, 4), dtype=np.float32) + def test_npy_column_vector_squeezed(self): fd, path = tempfile.mkstemp(suffix=".npy") os.close(fd) - np.save(path, arr_in) + np.save(path, np.ones((5, 1), dtype=np.float32)) try: - with pytest.raises(ValueError, match="1-D"): - read_npy(path) + assert read_npy(path).shape == (5,) finally: os.unlink(path) - def test_column_vector_squeezed(self): - """Shape (N,1) should be squeezed to (N,) successfully.""" - arr_in = np.ones((5, 1), dtype=np.float32) + def test_npy_2d_raises(self): fd, path = tempfile.mkstemp(suffix=".npy") os.close(fd) - np.save(path, arr_in) + np.save(path, np.ones((3, 4), dtype=np.float32)) try: - arr = read_npy(path) + with pytest.raises(ValueError, match="1-D"): + read_npy(path) finally: os.unlink(path) - assert arr.shape == (5,) - -class TestReadNpz: - def test_data_key(self): + def test_npz_data_key_and_fallback(self): arr_in = np.array([0, 1, 2], dtype=np.int32) + # named 'data' key fd, path = tempfile.mkstemp(suffix=".npz") os.close(fd) np.savez(path, data=arr_in, other=np.zeros(3)) try: - arr = read_npz(path) + np.testing.assert_array_equal(read_npz(path), arr_in) finally: os.unlink(path) - np.testing.assert_array_equal(arr, arr_in) - - def test_first_array_fallback(self): - arr_in = np.array([9.0, 8.0], dtype=np.float32) - fd, path = tempfile.mkstemp(suffix=".npz") + # first-array fallback + arr2 = np.array([9.0, 8.0], dtype=np.float32) + fd, path2 = tempfile.mkstemp(suffix=".npz") os.close(fd) - np.savez(path, arr_0=arr_in) + np.savez(path2, arr_0=arr2) try: - arr = read_npz(path) + np.testing.assert_array_equal(read_npz(path2), arr2) finally: - os.unlink(path) - np.testing.assert_array_equal(arr, arr_in) + os.unlink(path2) - def test_empty_raises(self): + def test_npz_empty_raises(self): fd, path = tempfile.mkstemp(suffix=".npz") os.close(fd) np.savez(path) @@ -197,57 +159,32 @@ def test_empty_raises(self): # --------------------------------------------------------------------------- # read_overlay dispatcher +# Dispatch is implicitly covered by TestResolveOverlayRouting below; here we +# only test the error cases and the gifti rejection that are not exercised +# through resolve_overlay. # --------------------------------------------------------------------------- class TestReadOverlayDispatcher: - def test_txt_dispatched(self): - path = _write_tmp("1.0\n2.0\n3.0\n", ".txt") - try: - arr = read_overlay(path) - finally: - os.unlink(path) - assert arr.shape == (3,) - - def test_csv_dispatched(self): - path = _write_tmp("0.5\n1.5\n", ".csv") - try: - arr = read_overlay(path) - finally: - os.unlink(path) - assert arr.shape == (2,) - - def test_npy_dispatched(self): - arr_in = np.array([1.0, 2.0], dtype=np.float32) - fd, path = tempfile.mkstemp(suffix=".npy") - os.close(fd) - np.save(path, arr_in) - try: - arr = read_overlay(path) - finally: - os.unlink(path) - np.testing.assert_array_equal(arr, arr_in) + def test_unknown_extension_raises(self): + with pytest.raises(ValueError, match="Unsupported"): + read_overlay("/some/file.xyz") - def test_npz_dispatched(self): - arr_in = np.array([0, 1, 2], dtype=np.int32) - fd, path = tempfile.mkstemp(suffix=".npz") - os.close(fd) - np.savez(path, data=arr_in) + def test_case_insensitive_extension(self): + path = _write_tmp("1.0\n2.0\n", ".TXT") try: - arr = read_overlay(path) + assert read_overlay(path).shape == (2,) finally: os.unlink(path) - np.testing.assert_array_equal(arr, arr_in) def test_surface_gii_rejected_with_helpful_error(self): - """A .surf.gii passed to read_gifti (overlay reader) should raise clearly.""" import nibabel as nib coords_da = nib.gifti.GiftiDataArray( data=np.array([[0,0,0],[1,0,0],[0,1,0],[0,0,1]], dtype=np.float32), - intent=1008, # POINTSET + intent=1008, ) faces_da = nib.gifti.GiftiDataArray( data=np.array([[0,1,2],[0,1,3]], dtype=np.int32), - intent=1009, # TRIANGLE + intent=1009, ) img = nib.gifti.GiftiImage(darrays=[coords_da, faces_da]) fd, path = tempfile.mkstemp(suffix=".surf.gii") @@ -260,91 +197,71 @@ def test_surface_gii_rejected_with_helpful_error(self): finally: os.unlink(path) - def test_unknown_extension_raises(self): - with pytest.raises(ValueError, match="Unsupported"): - read_overlay("/some/file.xyz") - - def test_case_insensitive_extension(self): - path = _write_tmp("1.0\n2.0\n", ".TXT") - try: - arr = read_overlay(path) - finally: - os.unlink(path) - assert arr.shape == (2,) - # --------------------------------------------------------------------------- # resolve_overlay / resolve_bg_map / resolve_roi routing via inputs.py # --------------------------------------------------------------------------- class TestResolveOverlayRouting: - """Verify that resolve_overlay and resolve_bg_map pick up .txt/.npy files.""" - - def test_txt_path_routed_as_float(self): - content = "0.1\n0.5\n0.9\n0.3\n" - path = _write_tmp(content, ".txt") + """End-to-end: file path → resolve_overlay / resolve_bg_map / resolve_roi.""" + + @pytest.mark.parametrize("suffix,content,n", [ + (".txt", "0.1\n0.5\n0.9\n0.3\n", 4), + (".csv", "0.5\n1.5\n", 2), + ]) + def test_txt_csv_routed(self, suffix, content, n): + path = _write_tmp(content, suffix) try: - arr = resolve_overlay(path, n_vertices=4) + arr = resolve_overlay(path, n_vertices=n) finally: os.unlink(path) - assert arr.shape == (4,) - assert arr.dtype == np.float32 + assert arr.shape == (n,) and arr.dtype == np.float32 - def test_npy_path_routed(self): + def test_npy_routed(self): arr_in = np.array([1.0, 2.0, 3.0], dtype=np.float32) fd, path = tempfile.mkstemp(suffix=".npy") os.close(fd) np.save(path, arr_in) try: - arr = resolve_overlay(path, n_vertices=3) + np.testing.assert_array_equal(resolve_overlay(path, n_vertices=3), arr_in) finally: os.unlink(path) - np.testing.assert_array_equal(arr, arr_in) def test_shape_mismatch_raises(self): - content = "0.1\n0.5\n" - path = _write_tmp(content, ".txt") + path = _write_tmp("0.1\n0.5\n", ".txt") try: with pytest.raises(ValueError, match="vertices"): resolve_overlay(path, n_vertices=5) finally: os.unlink(path) - def test_bg_map_txt_routed(self): - content = "1\n-1\n1\n-1\n" - path = _write_tmp(content, ".txt") + def test_bg_map_and_roi_routed(self): + """resolve_bg_map and resolve_roi also correctly route .txt and .npy.""" + # bg_map from txt — always float32 + path = _write_tmp("1\n-1\n1\n-1\n", ".txt") try: arr = resolve_bg_map(path, n_vertices=4) finally: os.unlink(path) - assert arr.shape == (4,) - assert arr.dtype == np.float32 # always cast to float32 by resolve_bg_map + assert arr.shape == (4,) and arr.dtype == np.float32 - def test_roi_from_bool_npy(self): - """Boolean .npy array is a valid ROI input after resolve_roi casts it.""" + # roi from bool npy arr_in = np.array([True, False, True, True]) - fd, path = tempfile.mkstemp(suffix=".npy") + fd, path2 = tempfile.mkstemp(suffix=".npy") os.close(fd) - np.save(path, arr_in) + np.save(path2, arr_in) try: - # resolve_roi receives a str path; _load_overlay_from_file loads npy - # and returns the array; then resolve_roi casts it to bool. - arr = resolve_roi(path, n_vertices=4) + roi = resolve_roi(path2, n_vertices=4) finally: - os.unlink(path) - assert arr.dtype == bool - np.testing.assert_array_equal(arr, arr_in) + os.unlink(path2) + assert roi.dtype == bool + np.testing.assert_array_equal(roi, arr_in) - def test_label_txt_integer_values(self): - """Integer .txt file (parcellation) should be loadable as overlay.""" - content = "3\n0\n1\n3\n" - path = _write_tmp(content, ".txt") + def test_integer_txt_as_overlay(self): + """Integer txt (parcellation) values are numerically preserved.""" + path = _write_tmp("3\n0\n1\n3\n", ".txt") try: - # resolve_overlay casts to float32; original int32 from read_txt arr = resolve_overlay(path, n_vertices=4) finally: os.unlink(path) - assert arr.shape == (4,) - # Values should be preserved numerically np.testing.assert_array_equal(arr, [3.0, 0.0, 1.0, 3.0]) - From 972f995b47c569e1fbab029992da790556d235b8 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sun, 22 Feb 2026 00:59:55 +0100 Subject: [PATCH 76/83] add GUI mouse navigation --- whippersnappy/cli/whippersnap.py | 218 +++++++++++++++++++++++++------ whippersnappy/gl/__init__.py | 11 +- whippersnappy/gl/shaders.py | 8 +- whippersnappy/gl/views.py | 147 ++++++++++++++++++++- 4 files changed, 338 insertions(+), 46 deletions(-) diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index b082628..b5b4ac8 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -47,17 +47,24 @@ from .._version import __version__ from ..geometry import get_surf_name, prepare_geometry -from ..gl import get_view_matrices, init_window, setup_shader +from ..gl import ( + ViewState, + arcball_rotation_matrix, + arcball_vector, + compute_view_matrix, + get_view_matrices, + init_window, + setup_shader, +) from ..utils.types import ViewType # Module logger logger = logging.getLogger(__name__) -# Global state shared between the GL render loop and the Qt config panel. +# Global thresholds shared between the GL render loop and the Qt config panel. # All access is from the main thread — no locking needed. current_fthresh_ = None current_fmax_ = None -app_window_closed_ = False def show_window( @@ -75,8 +82,18 @@ def show_window( """Start a live interactive OpenGL+Qt window for viewing a triangular mesh. On macOS both GLFW/Cocoa and Qt require the main thread. This function - creates a GLFW window, then hands control to a ``QTimer``-driven render - loop so that GLFW polling and Qt event processing share the main thread. + creates a GLFW window, registers GLFW input callbacks, then hands control + to a ``QTimer``-driven render loop so GLFW polling and Qt event processing + share the main thread. + + Interaction + ----------- + * **Left-drag** — arcball rotation in world space (no gimbal lock). + * **Right-drag / Middle-drag** — pan in screen space. + * **Scroll wheel** — zoom (Z-translation). + * **Arrow keys** — rotate in 2° increments. + * **R key / double-click** — reset view to initial preset. + * **Q key / ESC** — quit. Parameters ---------- @@ -106,8 +123,9 @@ def show_window( RuntimeError If the GLFW window or OpenGL context could not be created. """ - global current_fthresh_, current_fmax_, app_window_closed_ + global current_fthresh_, current_fmax_ + import numpy as np # noqa: PLC0415 from PyQt6.QtCore import QTimer # noqa: PLC0415 wwidth = 720 @@ -118,11 +136,23 @@ def show_window( "Could not create a GLFW window/context. OpenGL context unavailable." ) + # ------------------------------------------------------------------ + # Initialise view state and base view matrix + # ------------------------------------------------------------------ view_mats = get_view_matrices() - viewmat = view_mats[view] - rot_y = pyrr.Matrix44.from_y_rotation(0) - ypos = 0.0 + base_view = view_mats[view] # fixed orientation preset (→ transform uniform) + vs = ViewState(zoom=0.0) # zoom/pan packed into transform + + _last_left_press_time = [0.0] + + def _reset_view(): + vs.rotation = np.eye(4, dtype=np.float32) + vs.pan = np.zeros(2, dtype=np.float32) + vs.zoom = 0.0 + # ------------------------------------------------------------------ + # Load mesh and compile shader + # ------------------------------------------------------------------ meshdata, triangles, _fthresh, _fmax, _pos, _neg = prepare_geometry( mesh, overlay, annot, bg_map, roi, current_fthresh_, current_fmax_, @@ -130,35 +160,152 @@ def show_window( ) shader = setup_shader(meshdata, triangles, wwidth, wheight, specular=specular) - logger.info("Keys: Left/Right arrows → rotate ESC → quit") + logger.info( + "Mouse: left-drag=rotate right/middle-drag=pan scroll=zoom " + "R/double-click=reset Q/ESC=quit" + ) + + # ------------------------------------------------------------------ + # GLFW input callbacks + # ------------------------------------------------------------------ + + def _mouse_button_cb(win, button, action, _mods): + x, y = glfw.get_cursor_pos(win) + pos = np.array([x, y], dtype=np.float64) + + if button == glfw.MOUSE_BUTTON_LEFT: + vs.left_button_down = (action == glfw.PRESS) + if action == glfw.PRESS: + # Double-click detection (threshold: 300 ms) + now = glfw.get_time() + if now - _last_left_press_time[0] < 0.3: + _reset_view() + _last_left_press_time[0] = now + vs.last_mouse_pos = pos + else: + vs.last_mouse_pos = None + + elif button == glfw.MOUSE_BUTTON_RIGHT: + vs.right_button_down = (action == glfw.PRESS) + vs.last_mouse_pos = pos if action == glfw.PRESS else None + + elif button == glfw.MOUSE_BUTTON_MIDDLE: + vs.middle_button_down = (action == glfw.PRESS) + vs.last_mouse_pos = pos if action == glfw.PRESS else None + + def _cursor_pos_cb(win, x, y): + if vs.last_mouse_pos is None: + return + dx = x - vs.last_mouse_pos[0] + dy = y - vs.last_mouse_pos[1] + + if vs.left_button_down: + # Arcball rotation — amplify drag for snappier feel + _sensitivity = 2.5 + mx, my = vs.last_mouse_pos + v1 = arcball_vector(mx, my, wwidth, wheight) + v2 = arcball_vector( + mx + (x - mx) * _sensitivity, + my + (y - my) * _sensitivity, + wwidth, wheight, + ) + delta = arcball_rotation_matrix(v2, v1) + vs.rotation = vs.rotation @ delta + + elif vs.right_button_down or vs.middle_button_down: + # Pan in camera space — scale to normalised mesh units + pan_sensitivity = 1.0 / min(wwidth, wheight) + vs.pan[0] += dx * pan_sensitivity + vs.pan[1] -= dy * pan_sensitivity # y is flipped + + vs.last_mouse_pos = np.array([x, y], dtype=np.float64) + + def _scroll_cb(_win, _x_off, y_off): + # scroll up (y_off > 0) → move camera closer (positive Z in camera space) + vs.zoom += y_off * 0.05 + # Allow much further zoom out (e.g. -20.0) + vs.zoom = float(np.clip(vs.zoom, -20.0, 4.5)) + + _arrow_keys = {glfw.KEY_RIGHT, glfw.KEY_LEFT, glfw.KEY_UP, glfw.KEY_DOWN} + + def _key_cb(win, key, _scancode, action, _mods): + if action not in (glfw.PRESS, glfw.REPEAT): + # On key release, restore normal render rate + if key in _arrow_keys: + timer.setInterval(16) + return + delta = np.radians(3.0) + if key == glfw.KEY_RIGHT: + rot = np.array(pyrr.Matrix44.from_y_rotation(-delta), dtype=np.float32) + elif key == glfw.KEY_LEFT: + rot = np.array(pyrr.Matrix44.from_y_rotation(+delta), dtype=np.float32) + elif key == glfw.KEY_UP: + rot = np.array(pyrr.Matrix44.from_x_rotation(+delta), dtype=np.float32) + elif key == glfw.KEY_DOWN: + rot = np.array(pyrr.Matrix44.from_x_rotation(-delta), dtype=np.float32) + elif key == glfw.KEY_R: + _reset_view() + return + elif key == glfw.KEY_Q: + glfw.set_window_should_close(win, True) + return + else: + return + # Speed up render loop while key is held for smooth rotation + if key in _arrow_keys: + timer.setInterval(8) + vs.rotation = vs.rotation @ rot + + glfw.set_mouse_button_callback(window, _mouse_button_cb) + glfw.set_cursor_pos_callback(window, _cursor_pos_cb) + glfw.set_scroll_callback(window, _scroll_cb) + glfw.set_key_callback(window, _key_cb) + + from PyQt6.QtCore import QEventLoop # noqa: PLC0415 + loop = QEventLoop() + + _quitting = [False] # guard so we only shut down once + + def _begin_quit(): + """Stop rendering and GLFW, then exit the event loop.""" + if _quitting[0]: + return + _quitting[0] = True + timer.stop() + try: + glfw.terminate() + except Exception: + pass + # Defer loop.quit() so this callback fully unwinds first, + # avoiding QThreadStorage destruction mid-stack. + QTimer.singleShot(0, loop.quit) def _render_frame(): - """Called by QTimer every frame; does one GLFW poll + one GL draw.""" - global current_fthresh_, current_fmax_, app_window_closed_ - nonlocal meshdata, triangles, shader, rot_y, ypos + """Called by QTimer every 16 ms on the main thread.""" + global current_fthresh_, current_fmax_ + nonlocal meshdata, triangles, shader + + if _quitting[0]: + return - # Check GLFW close conditions if ( glfw.get_key(window, glfw.KEY_ESCAPE) == glfw.PRESS or glfw.window_should_close(window) - or app_window_closed_ ): - timer.stop() - glfw.terminate() - app.quit() + _begin_quit() return glfw.poll_events() gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) - # Re-render if Qt sliders changed the thresholds + # Re-prepare geometry if Qt sliders changed thresholds if config_window is not None: new_fthresh = config_window.get_fthresh_value() new_fmax = config_window.get_fmax_value() if new_fthresh != current_fthresh_ or new_fmax != current_fmax_: current_fthresh_ = new_fthresh current_fmax_ = new_fmax - meshdata, triangles, _fthresh, _fmax, _pos, _neg = prepare_geometry( + meshdata, triangles, _ft, _fm, _p, _n = prepare_geometry( mesh, overlay, annot, bg_map, roi, current_fthresh_, current_fmax_, invert=invert, @@ -167,34 +314,24 @@ def _render_frame(): meshdata, triangles, wwidth, wheight, specular=specular ) - # Keyboard rotation - if glfw.get_key(window, glfw.KEY_RIGHT) == glfw.PRESS: - ypos += 0.004 - if glfw.get_key(window, glfw.KEY_LEFT) == glfw.PRESS: - ypos -= 0.004 - rot_y = pyrr.Matrix44.from_y_rotation(ypos) - - transform_loc = gl.glGetUniformLocation(shader, "transform") - gl.glUniformMatrix4fv(transform_loc, 1, gl.GL_FALSE, rot_y * viewmat) + # Identical to snap_rotate: transl * rotation * base_view → transform uniform. + # model and view uniforms are left as set by setup_shader (identity / camera). + gl.glUniformMatrix4fv( + gl.glGetUniformLocation(shader, "transform"), 1, gl.GL_FALSE, + compute_view_matrix(vs, base_view), + ) gl.glDrawElements(gl.GL_TRIANGLES, triangles.size, gl.GL_UNSIGNED_INT, None) glfw.swap_buffers(window) - # ~60 fps timer — fires every 16 ms on the main thread timer = QTimer() timer.timeout.connect(_render_frame) timer.start(16) - app.exec() - + # If the Qt config panel is closed, shut down GLFW and exit the loop. + if config_window is not None: + config_window.destroyed.connect(lambda: _begin_quit()) -def config_app_exit_handler(): - """Mark the configuration window as closed. - - Connected to ``QApplication.aboutToQuit`` so the render timer stops - cleanly when the user closes the Qt panel. - """ - global app_window_closed_ - app_window_closed_ = True + loop.exec() def run(): @@ -453,7 +590,6 @@ def run(): # show_window which drives rendering via a QTimer (no extra threads). app = QApplication(sys.argv) app.setStyle("Fusion") - app.aboutToQuit.connect(config_app_exit_handler) screen_geometry = app.primaryScreen().availableGeometry() config_window = ConfigWindow( diff --git a/whippersnappy/gl/__init__.py b/whippersnappy/gl/__init__.py index a2654d7..4b458bf 100644 --- a/whippersnappy/gl/__init__.py +++ b/whippersnappy/gl/__init__.py @@ -25,7 +25,14 @@ setup_vertex_attributes, terminate_context, ) -from .views import get_view_matrices, get_view_matrix +from .views import ( + ViewState, + arcball_rotation_matrix, + arcball_vector, + compute_view_matrix, + get_view_matrices, + get_view_matrix, +) __all__ = [ 'create_vao', 'compile_shader_program', 'setup_buffers', 'setup_vertex_attributes', @@ -34,4 +41,6 @@ 'make_model', 'make_projection', 'make_view', 'get_default_shaders', 'get_view_matrices', 'get_view_matrix', 'get_webgl_shaders', 'terminate_context', + 'ViewState', 'compute_view_matrix', + 'arcball_vector', 'arcball_rotation_matrix', ] diff --git a/whippersnappy/gl/shaders.py b/whippersnappy/gl/shaders.py index c6c237a..fd3c789 100644 --- a/whippersnappy/gl/shaders.py +++ b/whippersnappy/gl/shaders.py @@ -44,8 +44,8 @@ def get_default_shaders(): fragment_shader = """ #version 330 - in vec3 Normal; in vec3 FragPos; + in vec3 Normal; in vec3 Color; out vec4 FragColor; @@ -93,12 +93,14 @@ def get_default_shaders(): diff = max(dot(norm, lightDir), 0.0); diffuse = diffuse + diffweights[3] * diff * lightColor; - // specular + // specular — camera is at (0,0,-5) in world space (from make_view), + // not at origin, so viewDir must point from FragPos toward (0,0,-5). vec3 result; if (doSpecular) { float specularStrength = 0.5; - vec3 viewDir = normalize(-FragPos); + vec3 cameraPos = vec3(0.0, 0.0, -5.0); + vec3 viewDir = normalize(cameraPos - FragPos); vec3 reflectDir = reflect(ohlightDir, norm); float spec = pow(max(dot(viewDir, reflectDir), 0.0), 32); vec3 specular = specularStrength * spec * lightColor; diff --git a/whippersnappy/gl/views.py b/whippersnappy/gl/views.py index 507b89f..055248c 100644 --- a/whippersnappy/gl/views.py +++ b/whippersnappy/gl/views.py @@ -1,9 +1,154 @@ -"""View matrices and presets under gl package.""" +"""View matrices, presets, and interactive view state under gl package.""" + +from dataclasses import dataclass, field import numpy as np +import pyrr # still needed for compute_view_matrix from ..utils.types import ViewType +# --------------------------------------------------------------------------- +# ViewState — single source of truth for all mutable view parameters +# --------------------------------------------------------------------------- + +@dataclass +class ViewState: + """Mutable view parameters for the interactive GUI render loop. + + All mouse/keyboard interaction updates this object; the view matrix is + recomputed from it each frame via :func:`compute_view_matrix`. + + Attributes + ---------- + rotation : np.ndarray + 4×4 float32 rotation matrix (identity = no rotation applied). + pan : np.ndarray + (x, y) pan offset in normalised screen-space units. + zoom : float + Z-translation applied before the base view. Larger = further away. + last_mouse_pos : np.ndarray or None + Last recorded mouse position (pixels); ``None`` when no button held. + left_button_down, right_button_down, middle_button_down : bool + Current pressed state of each mouse button. + """ + rotation: np.ndarray = field( + default_factory=lambda: np.eye(4, dtype=np.float32) + ) + pan: np.ndarray = field( + default_factory=lambda: np.zeros(2, dtype=np.float32) + ) + zoom: float = 0.4 + last_mouse_pos: np.ndarray | None = None + left_button_down: bool = False + right_button_down: bool = False + middle_button_down: bool = False + + +def compute_view_matrix(view_state: ViewState, base_view: np.ndarray) -> np.ndarray: + """Return the ``transform`` uniform — exactly as snap_rotate does it. + + Packs ``transl * rotation * base_view`` into a single matrix, matching + the snap_rotate convention (line: ``viewmat = transl * rot * base_view``). + The ``model`` and ``view`` uniforms are left as set by ``setup_shader`` + (identity and camera respectively) and must not be overwritten. + + Parameters + ---------- + view_state : ViewState + Current interactive view state. + base_view : np.ndarray + Fixed 4×4 orientation preset from :func:`get_view_matrices`. + + Returns + ------- + np.ndarray + 4×4 float32 matrix for the ``transform`` shader uniform. + """ + transl = pyrr.Matrix44.from_translation(( + view_state.pan[0], + view_state.pan[1], + 0.4 + view_state.zoom, + )) + rot = pyrr.Matrix44(view_state.rotation) + return np.array(transl * rot * pyrr.Matrix44(base_view), dtype=np.float32) + + + +# --------------------------------------------------------------------------- +# Arcball helpers +# --------------------------------------------------------------------------- + +def arcball_vector(x: float, y: float, width: int, height: int) -> np.ndarray: + """Map a 2-D screen pixel to a point on the unit arcball sphere. + + Normalises (x, y) to [-1, 1] NDC, then projects onto the unit sphere. + Points outside the sphere radius are clamped to the rim (z = 0). + + Parameters + ---------- + x, y : float + Mouse position in pixels. + width, height : int + Window dimensions in pixels. + + Returns + ------- + np.ndarray + Unit 3-vector on (or clamped to) the arcball sphere. + """ + s = min(width, height) + p = np.array([ + (2.0 * x - width) / s, + -(2.0 * y - height) / s, + 0.0, + ], dtype=np.float64) + sq = p[0] ** 2 + p[1] ** 2 + if sq <= 1.0: + p[2] = np.sqrt(1.0 - sq) + else: + p /= np.sqrt(sq) # clamp to rim + n = np.linalg.norm(p) + return p / n if n > 0 else p + + +def arcball_rotation_matrix(v1: np.ndarray, v2: np.ndarray) -> np.ndarray: + """Return a 4×4 rotation matrix that rotates unit vector *v1* to *v2*. + + Uses Rodrigues' rotation formula in pure numpy — no pyrr dependency. + Returns identity when *v1* and *v2* are coincident. + + Parameters + ---------- + v1, v2 : np.ndarray + Unit 3-vectors on the arcball sphere. + + Returns + ------- + np.ndarray + 4×4 float32 rotation matrix compatible with pyrr. + """ + axis = np.cross(v1, v2) + axis_len = np.linalg.norm(axis) + if axis_len < 1e-10: + return np.eye(4, dtype=np.float32) + + axis = axis / axis_len + angle = np.arctan2(axis_len, np.dot(v1, v2)) + + # Rodrigues' formula: R = I cos(a) + sin(a) [axis]× + (1-cos(a)) axis⊗axis + c, s = np.cos(angle), np.sin(angle) + t = 1.0 - c + x, y, z = axis + r3 = np.array([ + [t*x*x + c, t*x*y - s*z, t*x*z + s*y], + [t*x*y + s*z, t*y*y + c, t*y*z - s*x], + [t*x*z - s*y, t*y*z + s*x, t*z*z + c ], + ], dtype=np.float32) + + r4 = np.eye(4, dtype=np.float32) + r4[:3, :3] = r3 + return r4 + def get_view_matrices(): """Return canonical 4x4 view matrices for common brain orientations. From 50d6091b4083e5021497c19f16907cada228d005 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sun, 22 Feb 2026 01:06:06 +0100 Subject: [PATCH 77/83] GUI add s key for snapshot --- whippersnappy/cli/whippersnap.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index b5b4ac8..16eb539 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -51,6 +51,7 @@ ViewState, arcball_rotation_matrix, arcball_vector, + capture_window, compute_view_matrix, get_view_matrices, init_window, @@ -91,8 +92,9 @@ def show_window( * **Left-drag** — arcball rotation in world space (no gimbal lock). * **Right-drag / Middle-drag** — pan in screen space. * **Scroll wheel** — zoom (Z-translation). - * **Arrow keys** — rotate in 2° increments. + * **Arrow keys** — rotate in 3° increments. * **R key / double-click** — reset view to initial preset. + * **S key** — save snapshot (opens file dialog). * **Q key / ESC** — quit. Parameters @@ -162,7 +164,7 @@ def _reset_view(): logger.info( "Mouse: left-drag=rotate right/middle-drag=pan scroll=zoom " - "R/double-click=reset Q/ESC=quit" + "R/double-click=reset S=snapshot Q/ESC=quit" ) # ------------------------------------------------------------------ @@ -226,6 +228,21 @@ def _scroll_cb(_win, _x_off, y_off): # Allow much further zoom out (e.g. -20.0) vs.zoom = float(np.clip(vs.zoom, -20.0, 4.5)) + def _save_snapshot(): + """Capture the current frame and open a Qt save-file dialog.""" + from PyQt6.QtWidgets import QFileDialog # noqa: PLC0415 + path, _ = QFileDialog.getSaveFileName( + None, + "Save snapshot", + "snapshot.png", + "Images (*.png *.jpg *.jpeg *.tiff *.bmp);;All files (*)", + ) + if not path: + return # user cancelled + img = capture_window(window) + img.save(path) + logger.info("Snapshot saved to %s", path) + _arrow_keys = {glfw.KEY_RIGHT, glfw.KEY_LEFT, glfw.KEY_UP, glfw.KEY_DOWN} def _key_cb(win, key, _scancode, action, _mods): @@ -249,6 +266,9 @@ def _key_cb(win, key, _scancode, action, _mods): elif key == glfw.KEY_Q: glfw.set_window_should_close(win, True) return + elif key == glfw.KEY_S and action == glfw.PRESS: + _save_snapshot() + return else: return # Speed up render loop while key is held for smooth rotation From 8548b4bb29553a466ecfcf8cbd12427ea1872e60 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sun, 22 Feb 2026 01:25:16 +0100 Subject: [PATCH 78/83] doc workflow remove comments --- .github/workflows/doc.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 73850ce..bf01390 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -27,10 +27,8 @@ jobs: cache: pip - name: Install system dependencies run: | - # sudo apt update sudo apt-get update sudo apt-get -y install libgl1 libegl1 libxcb-cursor0 pandoc - # sudo apt install libxcb-cursor0 - name: Install package run: | python -m pip install --progress-bar off --upgrade pip setuptools wheel From 9386743632fccaa3fbfd1989565d9bb13fc89f37 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sun, 22 Feb 2026 01:28:43 +0100 Subject: [PATCH 79/83] fix sphinx --- whippersnappy/cli/whippersnap.py | 22 ++++++++++++---------- whippersnappy/gl/views.py | 14 +++++++++----- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 16eb539..4a7f2e5 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -87,16 +87,6 @@ def show_window( to a ``QTimer``-driven render loop so GLFW polling and Qt event processing share the main thread. - Interaction - ----------- - * **Left-drag** — arcball rotation in world space (no gimbal lock). - * **Right-drag / Middle-drag** — pan in screen space. - * **Scroll wheel** — zoom (Z-translation). - * **Arrow keys** — rotate in 3° increments. - * **R key / double-click** — reset view to initial preset. - * **S key** — save snapshot (opens file dialog). - * **Q key / ESC** — quit. - Parameters ---------- mesh : str or tuple of (array-like, array-like) @@ -124,6 +114,18 @@ def show_window( ------ RuntimeError If the GLFW window or OpenGL context could not be created. + + Notes + ----- + **Mouse and keyboard interaction:** + + * **Left-drag** — arcball rotation (view-relative, no gimbal lock). + * **Right-drag / Middle-drag** — pan in screen space. + * **Scroll wheel** — zoom in/out. + * **Arrow keys** — rotate in 3° increments (faster while held). + * **R / double-click** — reset view to initial preset. + * **S** — save snapshot (opens a file-save dialog). + * **Q / ESC** — quit. """ global current_fthresh_, current_fmax_ diff --git a/whippersnappy/gl/views.py b/whippersnappy/gl/views.py index 055248c..0d2af24 100644 --- a/whippersnappy/gl/views.py +++ b/whippersnappy/gl/views.py @@ -18,18 +18,22 @@ class ViewState: All mouse/keyboard interaction updates this object; the view matrix is recomputed from it each frame via :func:`compute_view_matrix`. - Attributes + Parameters ---------- rotation : np.ndarray 4×4 float32 rotation matrix (identity = no rotation applied). pan : np.ndarray (x, y) pan offset in normalised screen-space units. zoom : float - Z-translation applied before the base view. Larger = further away. + Z-translation packed into the transform matrix. last_mouse_pos : np.ndarray or None - Last recorded mouse position (pixels); ``None`` when no button held. - left_button_down, right_button_down, middle_button_down : bool - Current pressed state of each mouse button. + Last recorded mouse position in pixels; ``None`` when no button held. + left_button_down : bool + Whether the left mouse button is currently pressed. + right_button_down : bool + Whether the right mouse button is currently pressed. + middle_button_down : bool + Whether the middle mouse button is currently pressed. """ rotation: np.ndarray = field( default_factory=lambda: np.eye(4, dtype=np.float32) From 3cf84e382dcdeb01d78ac46c97a2f086bcb18bef Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sun, 22 Feb 2026 11:30:58 +0100 Subject: [PATCH 80/83] support for also reading LUT from command line --- README.md | 2 +- tests/test_array_and_rendering.py | 19 ++++++++++++++++ whippersnappy/cli/whippersnap.py | 4 +++- whippersnappy/cli/whippersnap1.py | 22 +++++++++++++++++-- whippersnappy/cli/whippersnap4.py | 36 ++++++++++++++++++++++++++++--- 5 files changed, 76 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index f38f86b..0a7f81a 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ from whippersnappy import snap1, snap4, snap_rotate, plot3d |---|---| | `snap1` | Single-view snapshot of any triangular mesh → PIL Image | | `snap4` | Four-view composed image (FreeSurfer subject, lateral/medial both hemispheres) | -| `snap_rotate` | 360° rotation video of any triangular mesh (MP4, WebM, or GIF) | +| `snap_rotate` | 360° rotation video of any triangular surface mesh (MP4, WebM, or GIF) | | `plot3d` | Interactive 3D WebGL viewer for Jupyter notebooks | **Supported mesh inputs for `snap1`, `snap_rotate`, and `plot3d`:** diff --git a/tests/test_array_and_rendering.py b/tests/test_array_and_rendering.py index 6ffce2e..d426b63 100644 --- a/tests/test_array_and_rendering.py +++ b/tests/test_array_and_rendering.py @@ -255,3 +255,22 @@ def test_snap1_with_bg_map(self): arr = np.array(_snap1_offscreen(mesh=(_V, _F), bg_map=bg)) assert arr.min() != arr.max(), "Image with bg_map is completely uniform." + def test_label_map_and_lut_rendering(self): + import numpy as np + + from whippersnappy.snap import snap1 + # Minimal tetra mesh + v = np.array([[0,0,0],[1,0,0],[0,1,0],[0,0,1]], dtype=np.float32) + f = np.array([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], dtype=np.uint32) + # Label map: 4 vertices, 2 labels + labels = np.array([1,2,1,2], dtype=int) + # LUT: label id, R, G, B (values in 0-255) + lut = np.array([[1,255,0,0],[2,0,255,0]], dtype=float) + # Normalize LUT colors + lut[:,1:] = lut[:,1:] / 255.0 + annot = (labels, lut) + img = snap1((v, f), annot=annot) + assert img is not None + arr = np.array(img) + assert arr.shape[0] > 0 and arr.shape[1] > 0 + assert np.any(arr != arr[0,0]) # Not uniform diff --git a/whippersnappy/cli/whippersnap.py b/whippersnappy/cli/whippersnap.py index 4a7f2e5..9e9bd19 100644 --- a/whippersnappy/cli/whippersnap.py +++ b/whippersnappy/cli/whippersnap.py @@ -494,6 +494,9 @@ def run(): "--annot", type=str, default=None, help="FreeSurfer .annot file for parcellation coloring.", ) + common.add_argument("--lut", type=str, default=None, + help="Path to a label look-up-table (LUT) file (csv/txt) with label IDs " + "and RGB(A) colors. Required if --annot is a csv/txt label map.") # --- Appearance / rendering --- rend = parser.add_argument_group("appearance") @@ -640,4 +643,3 @@ def run(): if __name__ == "__main__": run() - diff --git a/whippersnappy/cli/whippersnap1.py b/whippersnappy/cli/whippersnap1.py index 4d409f6..f3684e3 100644 --- a/whippersnappy/cli/whippersnap1.py +++ b/whippersnappy/cli/whippersnap1.py @@ -45,6 +45,8 @@ import os import tempfile +import numpy as np + if __name__ == "__main__" and __package__ is None: import sys os.execv(sys.executable, [sys.executable, "-m", "whippersnappy.cli.whippersnap1"] + sys.argv[1:]) @@ -157,6 +159,9 @@ def run(): parser.add_argument("--bg-map", type=str, default=None, dest="bg_map", help="Path to a per-vertex scalar file used as background shading " "(sign determines light/dark).") + parser.add_argument("--lut", type=str, default=None, + help="Path to a label look-up-table (LUT) file (csv/txt) with label IDs and RGB(A) colors. " + "Required if --annot is a csv/txt label map.") # --- View --- parser.add_argument( @@ -267,11 +272,25 @@ def run(): outpath = args.output or os.path.join( tempfile.gettempdir(), "whippersnappy_snap1.png" ) + if args.annot is not None and args.lut is not None: + # Load label map + labels = np.loadtxt(args.annot, delimiter=None, dtype=int) + # Load LUT + lut = np.loadtxt(args.lut, delimiter=None) + # Normalize color values if needed + if lut.shape[1] in (4,5): # label + RGB(A) + rgb = lut[:,1:] + if np.any(rgb > 1): + rgb = rgb / 255.0 + lut[:,1:] = rgb + annot_tuple = (labels, lut) + else: + annot_tuple = args.annot img = snap1( mesh=mesh_path, outpath=outpath, overlay=args.overlay, - annot=args.annot, + annot=annot_tuple, roi=args.roi, bg_map=args.bg_map, view=_VIEW_CHOICES[args.view], @@ -297,4 +316,3 @@ def run(): if __name__ == "__main__": run() - diff --git a/whippersnappy/cli/whippersnap4.py b/whippersnappy/cli/whippersnap4.py index c1370b3..5457b34 100644 --- a/whippersnappy/cli/whippersnap4.py +++ b/whippersnappy/cli/whippersnap4.py @@ -18,6 +18,8 @@ import os import tempfile +import numpy as np + if __name__ == "__main__" and __package__ is None: import sys os.execv(sys.executable, [sys.executable, "-m", "whippersnappy.cli.whippersnap4"] + sys.argv[1:]) @@ -99,6 +101,12 @@ def run(): help="Path to the lh annotation (.annot) file.") parser.add_argument("--rh_annot", type=str, default=None, help="Path to the rh annotation (.annot) file.") + parser.add_argument("--lh_lut", type=str, default=None, + help="Path to the lh label look-up-table (LUT) file (csv/txt) with label IDs " + "and RGB(A) colors. Required if --lh_annot is a csv/txt label map.") + parser.add_argument("--rh_lut", type=str, default=None, + help="Path to the rh label look-up-table (LUT) file (csv/txt) with label IDs " + "and RGB(A) colors. Required if --rh_annot is a csv/txt label map.") # --- Subject directory / surface --- parser.add_argument("-sd", "--sdir", type=str, required=True, @@ -158,11 +166,34 @@ def run(): logger.debug("Parsed args: %s", vars(args)) try: + if args.lh_annot is not None and args.lh_lut is not None: + labels = np.loadtxt(args.lh_annot, delimiter=None, dtype=int) + lut = np.loadtxt(args.lh_lut, delimiter=None) + if lut.shape[1] in (4,5): + rgb = lut[:,1:] + if np.any(rgb > 1): + rgb = rgb / 255.0 + lut[:,1:] = rgb + lh_annot_tuple = (labels, lut) + else: + lh_annot_tuple = args.lh_annot + if args.rh_annot is not None and args.rh_lut is not None: + labels = np.loadtxt(args.rh_annot, delimiter=None, dtype=int) + lut = np.loadtxt(args.rh_lut, delimiter=None) + if lut.shape[1] in (4,5): + rgb = lut[:,1:] + if np.any(rgb > 1): + rgb = rgb / 255.0 + lut[:,1:] = rgb + rh_annot_tuple = (labels, lut) + else: + rh_annot_tuple = args.rh_annot + img = snap4( lh_overlay=args.lh_overlay, rh_overlay=args.rh_overlay, - lh_annot=args.lh_annot, - rh_annot=args.rh_annot, + lh_annot=lh_annot_tuple, + rh_annot=rh_annot_tuple, sdir=args.sdir, caption=args.caption, surfname=args.surf_name, @@ -186,4 +217,3 @@ def run(): if __name__ == "__main__": run() - From 48c8f3e032d80b8b047186f8c479e6d2779c11bd Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sun, 22 Feb 2026 11:42:28 +0100 Subject: [PATCH 81/83] skip test if no gl context is available --- tests/test_array_and_rendering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_array_and_rendering.py b/tests/test_array_and_rendering.py index d426b63..89244ca 100644 --- a/tests/test_array_and_rendering.py +++ b/tests/test_array_and_rendering.py @@ -269,7 +269,7 @@ def test_label_map_and_lut_rendering(self): # Normalize LUT colors lut[:,1:] = lut[:,1:] / 255.0 annot = (labels, lut) - img = snap1((v, f), annot=annot) + img = _snap1_offscreen((v, f), annot=annot) assert img is not None arr = np.array(img) assert arr.shape[0] > 0 and arr.shape[1] > 0 From 3b1d6e986407d5b894bbfd3fc3b7760c5b00fa1f Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sun, 22 Feb 2026 11:49:43 +0100 Subject: [PATCH 82/83] fix lut test --- tests/test_array_and_rendering.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/test_array_and_rendering.py b/tests/test_array_and_rendering.py index 89244ca..9cbffd6 100644 --- a/tests/test_array_and_rendering.py +++ b/tests/test_array_and_rendering.py @@ -256,20 +256,14 @@ def test_snap1_with_bg_map(self): assert arr.min() != arr.max(), "Image with bg_map is completely uniform." def test_label_map_and_lut_rendering(self): - import numpy as np - - from whippersnappy.snap import snap1 - # Minimal tetra mesh - v = np.array([[0,0,0],[1,0,0],[0,1,0],[0,0,1]], dtype=np.float32) - f = np.array([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], dtype=np.uint32) - # Label map: 4 vertices, 2 labels + """Label map + LUT: image is non-uniform and correct size. + Skips on platforms without OpenGL context (Windows/macOS headless). + Re-uses the tetra mesh (_V, _F) as in other tests.""" labels = np.array([1,2,1,2], dtype=int) - # LUT: label id, R, G, B (values in 0-255) lut = np.array([[1,255,0,0],[2,0,255,0]], dtype=float) - # Normalize LUT colors lut[:,1:] = lut[:,1:] / 255.0 annot = (labels, lut) - img = _snap1_offscreen((v, f), annot=annot) + img = _snap1_offscreen(mesh=(_V, _F), annot=annot) assert img is not None arr = np.array(img) assert arr.shape[0] > 0 and arr.shape[1] > 0 From e58777b13d9cf877d6a7182ffc0002324281d611 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Sun, 22 Feb 2026 11:54:31 +0100 Subject: [PATCH 83/83] fix typo --- tests/test_array_and_rendering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_array_and_rendering.py b/tests/test_array_and_rendering.py index 9cbffd6..b5c723d 100644 --- a/tests/test_array_and_rendering.py +++ b/tests/test_array_and_rendering.py @@ -258,7 +258,7 @@ def test_snap1_with_bg_map(self): def test_label_map_and_lut_rendering(self): """Label map + LUT: image is non-uniform and correct size. Skips on platforms without OpenGL context (Windows/macOS headless). - Re-uses the tetra mesh (_V, _F) as in other tests.""" + Reuses the tetra mesh (_V, _F) as in other tests.""" labels = np.array([1,2,1,2], dtype=int) lut = np.array([[1,255,0,0],[2,0,255,0]], dtype=float) lut[:,1:] = lut[:,1:] / 255.0